Machine Learning Operations
Machine learning operations (MLOps) support in this SDK centers on two capabilities: efficient resource allocation for training and robust state management through checkpointing. Together they let tasks use specialized hardware such as GPUs and TPUs while ensuring that long-running training jobs can resume from the last successful state after a failure or preemption.
Resource Management
The Resources class in src/flyte/_resources.py is the primary interface for specifying compute requirements. It supports standard CPU and memory requests as well as fine-grained accelerator configurations.
Accelerators and GPUs
You can request GPUs by passing a string (e.g., "T4:1") or by using the GPU helper function for advanced configurations like Multi-Instance GPU (MIG) partitioning.
from flyte import Resources, GPU
# Requesting a specific GPU type and quantity
res = Resources(gpu="A100:8")
# Advanced configuration with MIG partitioning
res_mig = Resources(gpu=GPU(device="A100", quantity=1, partition="1g.5gb"))
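The string form reads as a device name and a count joined by a colon. As a rough illustration of that convention only, here is a minimal sketch of how such a string could be split. `parse_gpu_spec` and `GpuSpec` are hypothetical names and not part of the SDK, and the rule that a bare device name means one GPU is an assumption.

```python
from typing import NamedTuple


class GpuSpec(NamedTuple):
    device: str
    quantity: int


def parse_gpu_spec(spec: str) -> GpuSpec:
    """Split a "DEVICE:COUNT" string (hypothetical helper, not the SDK's parser)."""
    device, sep, count = spec.partition(":")
    if not device:
        raise ValueError(f"missing device name in {spec!r}")
    # Assumption: a bare device name such as "T4" means a single GPU.
    quantity = int(count) if sep else 1
    if quantity < 1:
        raise ValueError(f"quantity must be positive in {spec!r}")
    return GpuSpec(device, quantity)


print(parse_gpu_spec("A100:8"))
print(parse_gpu_spec("T4"))
```

Validating the string early like this surfaces a typo in the request at definition time instead of at scheduling time.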
The SDK supports a wide range of accelerators defined in the Accelerators literal, including:
- NVIDIA GPUs: T4, L4, A10, A100 (40G/80G), H100, H200, V100.
- Google Cloud TPUs: V5P, V6E (via the TPU helper).
- AWS Neuron: Inf1, Inf2, Trn1, Trn2 (via the Neuron helper).
- AMD GPUs: MI100, MI210, MI300X, etc.
Shared Memory
For ML workloads that involve heavy data loading (e.g., using PyTorch DataLoader with multiple workers), you can configure shared memory (/dev/shm) using the shm parameter:
# Set shared memory to 16 GiB
res = Resources(cpu=8, memory="32Gi", shm="16Gi")
# Automatically set to the maximum available on the node
res_auto = Resources(shm="auto")
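Because /dev/shm is RAM-backed, an explicit shm size is normally chosen to fit inside the memory request. The sketch below shows one way to sanity-check a pair of quantity strings before building a Resources object. `to_bytes` and `shm_fits` are hypothetical helpers, and the set of accepted suffixes is an assumption, not the SDK's own parsing rules.

```python
# Binary suffixes assumed for this sketch; the SDK may accept more forms.
_UNITS = {"Ki": 2**10, "Mi": 2**20, "Gi": 2**30, "Ti": 2**40}


def to_bytes(quantity: str) -> int:
    """Convert a binary-suffixed quantity such as "16Gi" to bytes (hypothetical helper)."""
    for suffix, factor in _UNITS.items():
        if quantity.endswith(suffix):
            return int(float(quantity[: -len(suffix)]) * factor)
    return int(quantity)  # assume a plain byte count when no suffix is present


def shm_fits(memory: str, shm: str) -> bool:
    """Return True when the shared-memory size does not exceed the memory request."""
    return to_bytes(shm) <= to_bytes(memory)


print(shm_fits("32Gi", "16Gi"))
print(shm_fits("8Gi", "16Gi"))
```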
Task Checkpointing
Checkpointing allows a task to save its state to durable storage and recover it in a subsequent attempt. This is managed by the Checkpoint class, which is accessible via flyte.ctx().checkpoint.
The Checkpoint Lifecycle
A typical checkpointing flow involves loading the previous state at the start of the task and saving the new state periodically (e.g., after every epoch).
import io

import flyte
import torch

# `env` (the task environment) and `model` are assumed to be defined elsewhere.
@env.task(retries=3)
def train_task(epochs: int):
    cp = flyte.ctx().checkpoint
    start_epoch = 0  # default when no earlier attempt left a checkpoint

    # 1. Load previous state if it exists
    prev_path = cp.load_sync()
    if prev_path:
        # If the checkpoint was a single file, it is restored as 'payload'
        # If it was a directory, prev_path is the directory root
        state = torch.load(prev_path)
        model.load_state_dict(state["model"])
        start_epoch = state["epoch"]

    for epoch in range(start_epoch, epochs):
        # ... training logic ...

        # 2. Save current state
        # Passing a directory tars it; passing bytes or a file saves it directly.
        buffer = io.BytesIO()
        torch.save({"model": model.state_dict(), "epoch": epoch + 1}, buffer)
        cp.save_sync(buffer.getvalue())
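To see the load-then-save cycle behave across a retry without a cluster, the following sketch replaces the Checkpoint object with a local stand-in and simulates a preemption. `LocalCheckpoint` is hypothetical and mirrors only the two method names used above. The real class works with durable storage, and its return types may differ.

```python
import json
import tempfile
from pathlib import Path
from typing import Optional


class LocalCheckpoint:
    """Hypothetical stand-in that keeps one checkpoint file on local disk."""

    def __init__(self, root: Path) -> None:
        self._file = root / "payload"

    def load_sync(self) -> Optional[Path]:
        return self._file if self._file.exists() else None

    def save_sync(self, data: bytes) -> None:
        self._file.write_bytes(data)


def train(cp: LocalCheckpoint, epochs: int, fail_at: Optional[int] = None) -> list:
    """Run epochs, resuming from the saved epoch counter; return the epochs executed."""
    start_epoch = 0
    prev = cp.load_sync()
    if prev is not None:
        start_epoch = json.loads(prev.read_bytes())["epoch"]
    ran = []
    for epoch in range(start_epoch, epochs):
        if epoch == fail_at:
            raise RuntimeError("simulated preemption")
        ran.append(epoch)
        cp.save_sync(json.dumps({"epoch": epoch + 1}).encode())
    return ran


with tempfile.TemporaryDirectory() as tmp:
    cp = LocalCheckpoint(Path(tmp))
    try:
        train(cp, epochs=5, fail_at=3)  # first attempt dies before epoch 3
    except RuntimeError:
        pass
    resumed = train(cp, epochs=5)  # retry picks up where the first attempt stopped
    print(resumed)
```

Saving `epoch + 1` rather than `epoch` is what lets the retry skip work that already completed.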
Sync vs. Async Operations
The Checkpoint class provides both synchronous and asynchronous methods to match the task type:
- Sync Tasks: Use load_sync() and save_sync().
- Async Tasks: Use await load() and await save().
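The async pair follows the same load-then-save shape, only awaited. Here is a minimal sketch using an in-memory stand-in. `InMemoryCheckpoint` is hypothetical and returns raw bytes from `load()` for simplicity, which is an assumption and not necessarily what the SDK's method returns.

```python
import asyncio
from typing import Optional


class InMemoryCheckpoint:
    """Hypothetical stand-in exposing the async method names used above."""

    def __init__(self) -> None:
        self._data: Optional[bytes] = None

    async def load(self) -> Optional[bytes]:
        await asyncio.sleep(0)  # yield to the event loop, as real I/O would
        return self._data

    async def save(self, data: bytes) -> None:
        await asyncio.sleep(0)
        self._data = data


async def count_steps(cp: InMemoryCheckpoint, steps: int) -> int:
    """Resume from the stored step counter and advance it up to `steps`."""
    prev = await cp.load()
    done = int(prev.decode()) if prev is not None else 0
    for step in range(done, steps):
        await cp.save(str(step + 1).encode())
    final = await cp.load()
    return int(final.decode()) if final is not None else 0


async def main() -> tuple:
    cp = InMemoryCheckpoint()
    first = await count_steps(cp, 2)   # first attempt stops after two steps
    second = await count_steps(cp, 5)  # second call resumes at step 2
    return first, second


result = asyncio.run(main())
print(result)
```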
State Recovery with latest_checkpoint
When Checkpoint.load() restores a directory tree (because a directory was saved), you often need to find the most recent checkpoint file within that tree. The latest_checkpoint utility in src/flyte/_checkpoint.py simplifies this by searching for files matching a glob pattern and sorting them by modification time.
from flyte import latest_checkpoint
# Find the newest .ckpt file in the restored directory
ckpt_path = latest_checkpoint(cp.path, glob_pattern="**/*.ckpt")
if ckpt_path:
    model.load_from_checkpoint(ckpt_path)
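The glob-then-sort behavior described above can be approximated with the standard library. This is a sketch of the idea only, not the SDK's implementation. `newest_matching` is a hypothetical name, and returning None when nothing matches is an assumption.

```python
import os
import tempfile
from pathlib import Path
from typing import Optional


def newest_matching(root: Path, glob_pattern: str = "**/*.ckpt") -> Optional[Path]:
    """Return the most recently modified file under root matching the pattern, or None."""
    candidates = [p for p in root.glob(glob_pattern) if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "run").mkdir()
    old = root / "run" / "epoch=0.ckpt"
    new = root / "run" / "epoch=1.ckpt"
    old.write_bytes(b"old")
    new.write_bytes(b"new")
    # Set modification times explicitly so the ordering is deterministic.
    os.utime(old, (1_000, 1_000))
    os.utime(new, (2_000, 2_000))
    newest_name = newest_matching(root).name
    missing = newest_matching(root, "**/*.pt")
    print(newest_name, missing)
```

Sorting by modification time avoids parsing framework-specific file names, at the cost of trusting that restored files keep their original timestamps.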
Framework Integrations
The SDK's checkpointing system is designed to integrate with popular ML frameworks by using callbacks. This ensures that when the framework saves a checkpoint to local disk, it is automatically synchronized with Flyte's durable storage.
PyTorch Lightning Example
In examples/checkpoint/pytorch_lightning_checkpoint.py, a custom callback is used to trigger save_sync at the end of each epoch:
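import pathlib

import flyte
# Import path assumes the unified `lightning` package; older installs use pytorch_lightning.callbacks.
from lightning.pytorch.callbacks import ModelCheckpoint


class FlyteLightningCheckpointCallback(ModelCheckpoint):
    def __init__(self, flyte_checkpoint: flyte.Checkpoint, **kwargs) -> None:
        super().__init__(**kwargs)
        self._flyte_checkpoint = flyte_checkpoint

    def on_train_epoch_end(self, trainer, pl_module):
        # Let Lightning write its checkpoint files to dirpath first
        super().on_train_epoch_end(trainer, pl_module)
        if self.dirpath:
            # Sync the entire checkpoint directory to Flyte
            self._flyte_checkpoint.save_sync(pathlib.Path(self.dirpath))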
Hugging Face Trainer Example
Similarly, for Hugging Face, a TrainerCallback can be used to sync the output directory:
import pathlib

import flyte
from transformers import TrainerCallback


class FlyteTrainerCheckpointCallback(TrainerCallback):
    def __init__(self, flyte_checkpoint: flyte.Checkpoint):
        self._flyte_checkpoint = flyte_checkpoint

    def on_save(self, args, state, control, **kwargs):
        # Sync the Trainer output directory (which holds the saved checkpoints) to Flyte
        self._flyte_checkpoint.save_sync(pathlib.Path(args.output_dir))
Implementation Details
URI Repair
The SDK includes a specialized repair_union_prev_checkpoint_uri function to handle a specific backend behavior where the "previous checkpoint" URI might incorrectly point to the current attempt's directory. If FLYTE_ATTEMPT_NUMBER is detected, the SDK automatically rewrites the URI to point to the correct previous attempt (n-1), ensuring that retries can actually find the state from the failed attempt.
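The rewrite from attempt n to n-1 can be pictured with a small string transformation. The sketch below is illustrative only. `point_to_previous_attempt` is a hypothetical function, and the URI layout it assumes (the attempt number as the second-to-last path segment) is invented for the example and may not match the backend's real layout.

```python
from typing import Optional


def point_to_previous_attempt(uri: str, attempt: int) -> Optional[str]:
    """Rewrite the attempt segment of a checkpoint URI from n to n-1.

    Hypothetical layout: ".../<attempt>/<leaf>", e.g. "s3://bucket/run/task/2/ckpt".
    """
    if attempt <= 0:
        return None  # the first attempt has no predecessor
    head, sep, leaf = uri.rstrip("/").rpartition("/")
    parent, sep2, segment = head.rpartition("/")
    if not sep or not sep2 or segment != str(attempt):
        return uri  # does not point at the current attempt; leave it untouched
    return f"{parent}/{attempt - 1}/{leaf}"


# In a task, the attempt number would come from the FLYTE_ATTEMPT_NUMBER
# environment variable; fixed values are used here to keep the demo deterministic.
print(point_to_previous_attempt("s3://bucket/run/task/2/ckpt", 2))
print(point_to_previous_attempt("s3://bucket/run/task/1/ckpt", 2))
```

Leaving a URI untouched when it does not name the current attempt keeps the repair idempotent.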
Storage Optimization
To minimize task startup latency, flyte.io and flyte.storage are imported lazily within Checkpoint methods. This prevents heavy dependencies like pydantic or fsspec from being loaded for tasks that do not utilize checkpointing features.
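The lazy-import pattern itself is ordinary Python: the import statement moves from module level into the method that needs it. Here is a minimal sketch with a standard-library module standing in for the heavy dependency. `ArchiveWriter` is a hypothetical class, not part of the SDK.

```python
class ArchiveWriter:
    """Hypothetical class whose heavier dependency is imported only on first use."""

    def describe(self) -> str:
        # No deferred import is triggered on this code path.
        return "cheap: no heavy imports yet"

    def pack(self, payload: bytes) -> bytes:
        # Deferred import: paid only by callers that reach this method.
        import zlib

        return zlib.compress(payload)


writer = ArchiveWriter()
print(writer.describe())
packed = writer.pack(b"state" * 100)
print(len(packed) < 500)
```

Python caches imported modules, so after the first call the deferred import costs only a dictionary lookup.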