
State Persistence with Checkpoints

The flyte.Checkpoint helper provides a standardized way to persist and restore task state across retries and failures. By saving intermediate progress to object storage, long-running tasks can resume from their last successful state rather than restarting from scratch.

Core Concepts

Checkpointing in this SDK is built around the BaseCheckpoint abstract class and its primary implementation, Checkpoint, located in src/flyte/_checkpoint.py.

The Checkpoint Helper

The Checkpoint class manages a local workspace and handles the I/O between that workspace and remote object storage. It is typically accessed via the task context:

import flyte

@flyte.task(retries=3)
def my_task():
    checkpoint = flyte.ctx().checkpoint
    # Use checkpoint to load or save state

Local Workspace

Every Checkpoint instance maintains a temporary local directory (accessible via the path property). This directory serves as the staging area for data being saved or the extraction target for data being restored.

Persistence Models

The Checkpoint helper adapts its behavior based on the type of data provided to the save or save_sync methods.

1. Raw Bytes

For simple state (like a loop index or a small JSON string), you can save and load raw bytes directly.

# Saving bytes
checkpoint.save_sync(b"iteration_10")

# Loading bytes
path = checkpoint.load_sync()
if path:
    state = path.read_bytes()
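
For state that is slightly richer than a single counter, one option is to serialize a small dict to JSON bytes before saving. The sketch below uses only the standard library; encode_state and decode_state are hypothetical helper names chosen for illustration and are not part of the SDK.

```python
import json


def encode_state(state: dict) -> bytes:
    """Serialize a small state dict to UTF-8 JSON bytes (hypothetical helper)."""
    return json.dumps(state, sort_keys=True).encode("utf-8")


def decode_state(raw: bytes) -> dict:
    """Inverse of encode_state: parse UTF-8 JSON bytes back into a dict."""
    return json.loads(raw.decode("utf-8"))


# Round trip without any SDK involvement.
state = {"iteration": 10, "best_loss": 0.25}
raw = encode_state(state)
restored = decode_state(raw)
```

Inside a task, the encoded bytes would presumably be passed as checkpoint.save_sync(encode_state(state)), and the restored payload read back with decode_state(path.read_bytes()).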

2. Single Files

When you pass a string path or a pathlib.Path that points to a file, the SDK uploads that file directly to the remote destination.

Restoration Detail: When a single-file checkpoint is restored via load() or load_sync(), it is placed in the local workspace and renamed to payload. The method returns the path to this payload file.

3. Directories (Tarballs)

If you pass a directory path to save() or save_sync(), the SDK automatically creates a gzip-compressed tarball (.tar.gz) of the directory's top-level entries before uploading.

Restoration Detail: When a tarball checkpoint is restored, it is automatically extracted into the Checkpoint.path directory. The load() method returns the root of this extracted tree.
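
To make the tarball round trip concrete, here is a minimal standard-library sketch of packing a directory's top-level entries into a .tar.gz and extracting them again. It illustrates the general technique only and is not the SDK's implementation; pack_dir and unpack_dir are hypothetical names.

```python
import tarfile
import tempfile
from pathlib import Path


def pack_dir(src: Path, archive: Path) -> None:
    """Write each top-level entry of src into a gzip-compressed tarball."""
    with tarfile.open(archive, "w:gz") as tar:
        for entry in sorted(src.iterdir()):
            # arcname drops the parent directory, so entries sit at the archive
            # root; subdirectories are added recursively by tar.add.
            tar.add(entry, arcname=entry.name)


def unpack_dir(archive: Path, dest: Path) -> Path:
    """Extract the tarball into dest and return dest as the tree root."""
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Newer Python versions support extraction filters.
            tar.extractall(dest, filter="data")
        else:
            tar.extractall(dest)
    return dest


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    src = root / "state"
    (src / "weights").mkdir(parents=True)
    (src / "weights" / "w.bin").write_bytes(b"\x00\x01")
    (src / "meta.txt").write_text("epoch=3")

    archive = root / "checkpoint.tar.gz"
    pack_dir(src, archive)
    restored = unpack_dir(archive, root / "restored")

    restored_names = sorted(p.name for p in restored.iterdir())
    restored_meta = (restored / "meta.txt").read_text()
    restored_weights = (restored / "weights" / "w.bin").read_bytes()
```

Because the entries are stored relative to the source directory, the extracted tree starts directly with meta.txt and weights/ rather than with a wrapping state/ folder.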

Sync vs. Async Tasks

The SDK provides distinct methods for standard synchronous tasks and async tasks to ensure non-blocking I/O where appropriate.

| Task Type | Load Method | Save Method |
| --- | --- | --- |
| Synchronous (def) | load_sync() | save_sync(data) |
| Asynchronous (async def) | await load() | await save(data) |

Example: Sync Task with Bytes

In examples/checkpoint/generic_data_checkpoint.py, a task uses save_sync to persist a loop index:

@env.task(retries=3)
def use_checkpoint(n_iterations: int) -> int:
    checkpoint = flyte.ctx().checkpoint
    path = checkpoint.load_sync()

    start = 0
    if path:
        start = int(path.read_bytes().decode())

    for index in range(start, n_iterations):
        # ... logic ...
        checkpoint.save_sync(f"{index + 1}".encode())
    return index
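
The resume logic can be exercised without a cluster by substituting a local file for the checkpoint. The sketch below is a simulation under that assumption: resume_index, run, and the fail_at parameter are hypothetical names, and a plain file write stands in for save_sync. It also returns a completed-iteration count instead of the loop variable, since a loop variable is never assigned when the restored index already equals n_iterations.

```python
import tempfile
from pathlib import Path
from typing import Optional


def resume_index(payload: Optional[Path]) -> int:
    """Return the iteration to resume from; 0 when nothing was restored."""
    if payload is None:
        return 0
    return int(payload.read_bytes().decode())


def run(n_iterations: int, payload: Path, fail_at: Optional[int] = None) -> int:
    """Run the loop, recording progress in payload after each iteration."""
    start = resume_index(payload if payload.exists() else None)
    for index in range(start, n_iterations):
        if fail_at is not None and index == fail_at:
            raise RuntimeError(f"simulated failure at iteration {index}")
        # Stand-in for checkpoint.save_sync(...): store the NEXT index to run.
        payload.write_bytes(f"{index + 1}".encode())
    # Defined even when the loop body never executes.
    return max(start, n_iterations)


with tempfile.TemporaryDirectory() as tmp:
    payload = Path(tmp) / "payload"
    try:
        run(5, payload, fail_at=3)  # first attempt dies at iteration 3
    except RuntimeError:
        pass
    saved_after_failure = payload.read_bytes()
    resumed_from = resume_index(payload)
    completed = run(5, payload)  # "retry" picks up where it left off
    final_saved = payload.read_bytes()
```

Saving index + 1 rather than index means the retry starts at the first iteration that has not finished, so completed work is not repeated.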

Example: Async Task with PyTorch

In examples/checkpoint/pytorch_task_checkpoint.py, an async task saves a model state dict:

@env.task(retries=3)
async def train_linear(epochs: int = 3) -> float:
    cp = flyte.ctx().checkpoint
    prev_cp_path = await cp.load()

    if prev_cp_path:
        blob = torch.load(prev_cp_path, map_location="cpu")
        model.load_state_dict(blob["model"])
        start = int(blob["epoch"]) + 1

    # ... training loop ...
    await cp.save(wpath)  # wpath is a pathlib.Path

Framework Integration and Utilities

Finding the Latest Checkpoint

When working with frameworks like PyTorch Lightning that may save multiple checkpoint files in a directory, the latest_checkpoint utility (in src/flyte/_checkpoint.py) helps identify the most recent file based on modification time or a custom key.

from flyte import latest_checkpoint

path = checkpoint.load_sync()
if path:
    # Find the newest .ckpt file in the restored directory
    last_file = latest_checkpoint(path, glob_pattern="**/*.ckpt")
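
For intuition, selecting the newest file by modification time can be approximated with pathlib alone. The sketch below is a standard-library illustration of that idea, not the source of latest_checkpoint; newest_file is a hypothetical name, and the real utility's handling of empty results or custom keys may differ.

```python
import os
import tempfile
from pathlib import Path
from typing import Optional


def newest_file(root: Path, glob_pattern: str = "**/*.ckpt") -> Optional[Path]:
    """Return the most recently modified file matching glob_pattern, or None."""
    candidates = [p for p in root.glob(glob_pattern) if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "run").mkdir()
    older = root / "run" / "epoch=0.ckpt"
    newer = root / "run" / "epoch=1.ckpt"
    for offset, ckpt in enumerate((older, newer)):
        ckpt.write_bytes(b"x")
        # Set explicit timestamps so the ordering does not depend on write speed.
        os.utime(ckpt, (1_000_000 + offset, 1_000_000 + offset))

    picked_name = newest_file(root).name
    none_found = newest_file(root, "**/*.pt")
```

Modification time is a reasonable default, but it can mislead if files are copied or extracted in a different order than they were written, which is one reason a custom key can be useful.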

URI Repair for Retries

The Checkpoint implementation also handles a retry edge case in remote execution environments: the runtime may provide a checkpoint URI that points to the current attempt's prefix rather than to the output of the previous attempt. During initialization, Checkpoint uses repair_union_prev_checkpoint_uri to ensure that checkpoint_src points to the previous attempt's output. This is managed internally.

Internal Mechanics: BaseCheckpoint

The BaseCheckpoint class in src/flyte/_checkpoint.py defines the contract that all checkpoint helpers must follow:

import pathlib
from abc import ABC, abstractmethod
from typing import Optional

class BaseCheckpoint(ABC):
    @property
    @abstractmethod
    def path(self) -> pathlib.Path:
        """Local directory for reading and writing checkpoint files."""

    @abstractmethod
    def prev_exists(self) -> bool:
        """Whether a previous-checkpoint exists (retry / resume)."""

    @abstractmethod
    async def load(self) -> Optional[pathlib.Path]: ...

    @abstractmethod
    async def save(self, data: pathlib.Path | str | bytes) -> None: ...

The standard Checkpoint implementation uses flyte.io.File for all blob I/O, ensuring compatibility with the various storage backends supported by the Flyte platform.
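
When unit-testing task logic, a local stand-in with the same method names as this contract can be handy. The sketch below is a hypothetical test double: LocalCheckpoint, its constructor arguments, and the use of a single local file in place of object storage are all assumptions made for illustration. It does not subclass the SDK's BaseCheckpoint, handles only bytes and single files, and does not reflect how Checkpoint is implemented.

```python
import asyncio
import shutil
import tempfile
from pathlib import Path
from typing import Optional, Union


class LocalCheckpoint:
    """Hypothetical test double: contract-shaped methods, local disk only."""

    def __init__(self, store: Path, workspace: Path) -> None:
        self._store = store  # one local file standing in for object storage
        self._path = workspace
        self._path.mkdir(parents=True, exist_ok=True)

    @property
    def path(self) -> Path:
        return self._path

    def prev_exists(self) -> bool:
        return self._store.is_file()

    async def load(self) -> Optional[Path]:
        if not self.prev_exists():
            return None
        # Mirror the documented behavior of restoring a file as "payload".
        target = self._path / "payload"
        shutil.copyfile(self._store, target)
        return target

    async def save(self, data: Union[Path, str, bytes]) -> None:
        if isinstance(data, bytes):
            self._store.write_bytes(data)
        else:
            # Directories are out of scope for this sketch.
            shutil.copyfile(Path(data), self._store)


async def demo() -> tuple:
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        store = root / "remote-object"
        first = LocalCheckpoint(store, root / "attempt-0")
        before = await first.load()  # nothing saved yet
        await first.save(b"epoch=2")
        second = LocalCheckpoint(store, root / "attempt-1")  # simulated retry
        restored = await second.load()
        return before, second.prev_exists(), restored.name, restored.read_bytes()


before, had_prev, restored_name, restored_bytes = asyncio.run(demo())
```

Keeping the double free of SDK imports means it runs anywhere, at the cost of not verifying conformance to the real abstract class.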