Data Persistence and Checkpointing
Checkpointing allows you to build fault-tolerant tasks that can resume from a saved state after a failure or retry. In this tutorial, you will implement a task that persists its progress to an object store and restores it upon retry.
What You Will Build
You will create a task that performs a multi-step calculation. If the task fails midway, it will use the flyte.Checkpoint helper to:
- Check for a previously saved state.
- Resume execution from the last successful iteration.
- Save progress periodically to ensure minimal work is lost on the next retry.
Prerequisites
To use checkpointing, your task must be configured with retries greater than 0.
- flyte installed in your environment.
- A task decorated with @env.task(retries=N).
Step 1: Accessing the Checkpoint Helper
The checkpoint helper is available through the task context. You access it using flyte.ctx().checkpoint.
import flyte
from flyte import env

@env.task(retries=3)
def resilient_task(n_iterations: int) -> int:
    # Access the checkpoint helper
    checkpoint = flyte.ctx().checkpoint
    if checkpoint is None:
        # Checkpointing is not configured for this execution
        return 0
    # ... logic follows
The flyte.ctx().checkpoint property returns an instance of flyte._checkpoint.Checkpoint. If the runtime has not provided a checkpoint path, this will be None.
Step 2: Loading Previous State
Before starting your main logic, check if a previous attempt saved any data. Use load_sync() for standard tasks or await load() for async tasks.
    # Load the path to the previous checkpoint data
    prev_path = checkpoint.load_sync()
    start_iteration = 0
    if prev_path:
        # If prev_path is not None, it points to a local file or directory
        # containing the data from the last successful save_sync()
        state_bytes = prev_path.read_bytes()
        start_iteration = int(state_bytes.decode())
        print(f"Resuming from iteration {start_iteration}")
When load_sync() is called:
- If it's the first attempt, it returns None.
- If a previous attempt saved data, it downloads that data to a local temporary workspace and returns the pathlib.Path to it.
- If the remote data was a single file or raw bytes, the returned path points to a file named payload inside the checkpoint directory.
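The snippet above stores a single integer, but the same bytes-based pattern extends to richer state. The following is a minimal sketch, assuming the state is small enough to serialize as JSON. The helper names encode_state and decode_state are hypothetical and not part of flyte, and the temporary payload file only imitates the restored file described above.

```python
import json
import pathlib
import tempfile
from typing import Optional


def encode_state(iteration: int, running_total: float) -> bytes:
    """Serialize progress into bytes suitable for a bytes-based save."""
    state = {"iteration": iteration, "running_total": running_total}
    return json.dumps(state).encode("utf-8")


def decode_state(path: Optional[pathlib.Path]) -> dict:
    """Return saved progress, or a fresh state when no checkpoint exists."""
    if path is None:
        return {"iteration": 0, "running_total": 0.0}
    return json.loads(path.read_bytes().decode("utf-8"))


# Imitate a restored checkpoint locally; the file name "payload" mirrors
# the name described above and is otherwise arbitrary for this sketch.
with tempfile.TemporaryDirectory() as tmp:
    payload = pathlib.Path(tmp) / "payload"
    payload.write_bytes(encode_state(iteration=7, running_total=12.5))
    restored = decode_state(payload)

# A first attempt has no checkpoint, so it starts from the defaults.
fresh = decode_state(None)
```

In a task, the bytes returned by encode_state would be passed to save_sync(), and the path returned by load_sync() would be passed to decode_state.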
Step 3: Saving Progress
As your task progresses, save the state so that subsequent retries can pick up where you left off.
    for i in range(start_iteration, n_iterations):
        # Perform work...
        current_result = i + 1
        # Save state after each successful iteration
        checkpoint.save_sync(f"{current_result}".encode())
    return n_iterations
The save_sync() method accepts bytes, a str path, or a pathlib.Path.
- Bytes: Uploaded directly as a single object.
- File Path: The file is uploaded directly.
- Directory Path: The directory is automatically compressed into a .tar.gz archive before upload.
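To make the directory case concrete, here is a rough sketch of a .tar.gz round trip using only the Python standard library. It illustrates the general packaging technique and is not Flyte's actual implementation. The function names pack_directory and unpack_directory are hypothetical.

```python
import pathlib
import tarfile
import tempfile


def pack_directory(src_dir: pathlib.Path, archive_path: pathlib.Path) -> pathlib.Path:
    """Compress the contents of src_dir into a gzip-compressed tar archive."""
    with tarfile.open(archive_path, "w:gz") as tar:
        for child in sorted(src_dir.iterdir()):
            # Store each entry relative to src_dir instead of by absolute path
            tar.add(child, arcname=child.name)
    return archive_path


def unpack_directory(archive_path: pathlib.Path, dest_dir: pathlib.Path) -> pathlib.Path:
    """Extract the archive into dest_dir and return dest_dir."""
    dest_dir.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(dest_dir)
    return dest_dir


with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    src = root / "checkpoints"
    src.mkdir()
    (src / "epoch_0.pt").write_bytes(b"weights-0")
    (src / "meta.txt").write_text("epoch=0")

    archive = pack_directory(src, root / "checkpoint.tar.gz")
    restored = unpack_directory(archive, root / "restored")
    restored_names = sorted(p.name for p in restored.iterdir())
    restored_meta = (restored / "meta.txt").read_text()
```

Because the whole directory travels as one archive, every save re-uploads all files in it, so it may be worth keeping only the artifacts needed for a resume in that directory.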
Step 4: Handling Complex State (Directories)
For machine learning or complex data, you often need to save entire directories. When you load a checkpoint that was saved as a directory, load_sync() returns the path to the root of the extracted directory.
You can use flyte.latest_checkpoint to find specific files within that restored tree, which is particularly useful for framework-specific files like PyTorch Lightning's last.ckpt.
import pathlib

import flyte
import torch
from flyte import env

@env.task(retries=2)
async def train_model(epochs: int) -> None:
    cp = flyte.ctx().checkpoint
    if cp is None:
        # Checkpointing is not configured for this execution
        return
    # Placeholder model for illustration; substitute your own architecture
    model = torch.nn.Linear(4, 1)
    # Restore directory
    prev_cp_path = await cp.load()
    if prev_cp_path:
        # Find the newest checkpoint file matching a pattern
        last_ckpt = flyte.latest_checkpoint(prev_cp_path, "*.pt")
        if last_ckpt:
            model_state = torch.load(last_ckpt)
            # Load state into model
            model.load_state_dict(model_state)
    # Directory that collects the artifacts to upload
    checkpoint_dir = pathlib.Path("checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)
    # Training loop...
    for epoch in range(epochs):
        # Save a directory of artifacts
        torch.save(model.state_dict(), checkpoint_dir / f"epoch_{epoch}.pt")
        # This uploads the entire 'checkpoints' directory as a tarball
        await cp.save(checkpoint_dir)
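The selection step performed by flyte.latest_checkpoint can be pictured with a small standalone helper. This is a sketch under the assumption that "newest" means the most recent modification time. The exact rule flyte.latest_checkpoint applies is not stated here, and newest_match is a hypothetical name.

```python
import os
import pathlib
import tempfile
from typing import Optional


def newest_match(root: pathlib.Path, pattern: str) -> Optional[pathlib.Path]:
    """Return the most recently modified file under root matching pattern."""
    candidates = [p for p in root.rglob(pattern) if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    for epoch in range(3):
        ckpt = root / f"epoch_{epoch}.pt"
        ckpt.write_bytes(b"x")
        # Set explicit timestamps so ordering does not depend on clock resolution
        os.utime(ckpt, (1_000_000 + epoch, 1_000_000 + epoch))

    newest_name = newest_match(root, "*.pt").name
    missing = newest_match(root, "*.ckpt")
```

Returning None when nothing matches lets the caller fall back to a fresh start, which mirrors the `if last_ckpt:` guard in the task above.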
Complete Example
Here is the complete resilient task combining these steps:
import flyte
from flyte import env

@env.task(retries=5)
def persistent_counter(target: int) -> int:
    checkpoint = flyte.ctx().checkpoint
    if not checkpoint:
        return 0
    # 1. Restore
    prev_path = checkpoint.load_sync()
    count = 0
    if prev_path:
        count = int(prev_path.read_text())
    # 2. Execute and Persist
    while count < target:
        count += 1
        # Simulate a failure on the first attempt only; failing on every
        # attempt would stop each retry at step 5 and exhaust the retries
        if count == 5 and prev_path is None:
            raise RuntimeError("Simulated failure at step 5")
        # Save current progress
        checkpoint.save_sync(str(count).encode())
    return count
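The resume behavior can be exercised locally without a cluster. The sketch below is a rough simulation: LocalCheckpoint and run_with_retries are hypothetical stand-ins for the checkpoint helper and the runtime's retry loop, and they only imitate the load_sync()/save_sync() behavior described in this tutorial for bytes payloads. The simulated failure fires only on the first attempt so that the retry can proceed.

```python
import pathlib
import tempfile
from typing import Callable, List, Optional


class LocalCheckpoint:
    """Hypothetical local stand-in: stores one bytes payload in a file."""

    def __init__(self, workdir: pathlib.Path) -> None:
        self._payload = workdir / "payload"

    def load_sync(self) -> Optional[pathlib.Path]:
        # None on the first attempt, otherwise the path to the saved payload
        return self._payload if self._payload.exists() else None

    def save_sync(self, data: bytes) -> None:
        self._payload.write_bytes(data)


def counter_body(checkpoint: LocalCheckpoint, target: int, starts: List[int]) -> int:
    prev_path = checkpoint.load_sync()
    count = int(prev_path.read_text()) if prev_path else 0
    starts.append(count)  # record where this attempt began
    while count < target:
        count += 1
        # Fail once, on the first attempt only
        if count == 5 and prev_path is None:
            raise RuntimeError("Simulated failure at step 5")
        checkpoint.save_sync(str(count).encode())
    return count


def run_with_retries(fn: Callable[[], int], retries: int) -> int:
    """Call fn up to retries + 1 times, re-raising the last error."""
    last_error: Optional[Exception] = None
    for _ in range(retries + 1):
        try:
            return fn()
        except RuntimeError as exc:
            last_error = exc
    assert last_error is not None
    raise last_error


with tempfile.TemporaryDirectory() as tmp:
    cp = LocalCheckpoint(pathlib.Path(tmp))
    starts: List[int] = []
    result = run_with_retries(lambda: counter_body(cp, 8, starts), retries=5)
```

Here the first attempt starts at 0 and fails after saving step 4, and the second attempt starts at 4 and runs to completion, so only the failed step is repeated.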
Next Steps
- Async Tasks: Use await checkpoint.load() and await checkpoint.save() for non-blocking I/O in async tasks.
- ML Frameworks: Integrate with PyTorch Lightning or HuggingFace by passing flyte.ctx().checkpoint into custom callbacks that trigger save_sync at the end of every epoch.
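The callback idea can be sketched in a framework-neutral way. Everything below is illustrative: EpochCheckpointCallback, its on_epoch_end hook, and RecordingCheckpoint are hypothetical names, and a real integration would subclass the framework's own callback base class and use its hook signature. The only assumption about the checkpoint object is that it exposes save_sync() accepting bytes, as described in Step 3.

```python
import json
from typing import List, Protocol


class SupportsSaveSync(Protocol):
    """Anything with a bytes-accepting save_sync, such as the checkpoint helper."""

    def save_sync(self, data: bytes) -> None: ...


class EpochCheckpointCallback:
    """Hypothetical callback that persists a small JSON state per epoch."""

    def __init__(self, checkpoint: SupportsSaveSync) -> None:
        self._checkpoint = checkpoint

    def on_epoch_end(self, epoch: int, loss: float) -> None:
        state = {"epoch": epoch, "loss": loss}
        self._checkpoint.save_sync(json.dumps(state).encode("utf-8"))


class RecordingCheckpoint:
    """Test double that records payloads instead of uploading them."""

    def __init__(self) -> None:
        self.saved: List[bytes] = []

    def save_sync(self, data: bytes) -> None:
        self.saved.append(data)


recorder = RecordingCheckpoint()
callback = EpochCheckpointCallback(recorder)

# Stand-in training loop: one callback invocation per epoch
for epoch, loss in enumerate([0.9, 0.5, 0.3]):
    callback.on_epoch_end(epoch, loss)

last_state = json.loads(recorder.saved[-1])
```

Passing the checkpoint object into the callback, rather than calling flyte.ctx() inside it, keeps the callback testable with a double like RecordingCheckpoint.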