
Configuring Elastic Training and Fault Tolerance

Distributed training jobs often suffer from transient failures such as CUDA Out-of-Memory (OOM) errors, spot instance preemptions, or network desynchronization. The Elastic configuration in this SDK allows you to manage these failures by configuring PyTorch's elastic agent and NCCL timeout parameters.

Configuring Elastic Training

To enable elastic training, pass an Elastic object as the plugin_config argument of your TaskEnvironment. This configuration controls how many nodes and processes are launched and how the system responds when a worker fails.

```python
import flyte
from flyteplugins.pytorch.task import Elastic

# Define a task environment with elastic training configuration
torch_env = flyte.TaskEnvironment(
    name="pytorch-elastic-env",
    resources=flyte.Resources(cpu=(2, 4), memory=("2Gi", "4Gi"), gpu=2),
    plugin_config=Elastic(
        nnodes=1,            # Number of nodes (can be a range like "1:2")
        nproc_per_node=2,    # Number of GPUs/processes per node
        max_restarts=3,      # Retry the whole group up to 3 times on failure
        monitor_interval=3,  # Poll worker health every 3 seconds
    ),
)


@torch_env.task
def train_task():
    import torch.distributed as dist

    # The elastic agent sets RANK, WORLD_SIZE, MASTER_ADDR, etc. for each worker
    dist.init_process_group("nccl")
    # ... training logic ...
```
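The nnodes comment above mentions ranges such as "1:2". These follow torchrun's MIN:MAX convention, under which the worker group may run with anywhere from MIN to MAX nodes. The helpers below are hypothetical (not part of the plugin) and assume the plugin accepts the same formats as torchrun; they sketch how such a value resolves into node bounds and the resulting range of total worker processes.

```python
from typing import Tuple, Union


def parse_nnodes(nnodes: Union[int, str]) -> Tuple[int, int]:
    """Hypothetical helper: resolve nnodes into (min_nodes, max_nodes)."""
    if isinstance(nnodes, int):
        min_nodes = max_nodes = nnodes
    else:
        parts = nnodes.split(":")
        if len(parts) == 1:
            min_nodes = max_nodes = int(parts[0])  # "2" means exactly 2 nodes
        elif len(parts) == 2:
            min_nodes, max_nodes = int(parts[0]), int(parts[1])  # "MIN:MAX"
        else:
            raise ValueError(f"invalid nnodes value: {nnodes!r}")
    if min_nodes < 1 or min_nodes > max_nodes:
        raise ValueError(f"invalid node range: {min_nodes}:{max_nodes}")
    return min_nodes, max_nodes


def world_size_range(nnodes: Union[int, str], nproc_per_node: int) -> Tuple[int, int]:
    """Total worker processes at the smallest and largest allowed node count."""
    min_nodes, max_nodes = parse_nnodes(nnodes)
    return min_nodes * nproc_per_node, max_nodes * nproc_per_node


print(world_size_range(1, 2))      # (2, 2)
print(world_size_range("1:2", 2))  # (2, 4)
```

With nproc_per_node=2, a "1:2" range means the job may run with as few as 2 or as many as 4 worker processes.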

Handling CUDA OOM and Hangs

A common failure mode in PyTorch DDP occurs when one worker hits an OOM and skips a collective operation (such as the gradient all-reduce that DDP triggers during loss.backward()), so all other workers block on that collective indefinitely. You can configure aggressive timeouts to detect these hangs and fail the job quickly.

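Because each of the max_restarts + 1 attempts can block for the full collective timeout and then wait out the full heartbeat timeout, the worst-case time before a job fails is:

(max_restarts + 1) * (nccl_collective_timeout_sec + nccl_heartbeat_timeout_sec)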

```python
from flyteplugins.pytorch.task import Elastic

# Aggressive configuration for fast failure detection
elastic_config = Elastic(
    nproc_per_node=2,
    nnodes=1,
    max_restarts=0,                  # Fail immediately on first error
    nccl_collective_timeout_sec=60,  # Collective ops time out after 60s
    nccl_heartbeat_timeout_sec=60,   # Heartbeat watchdog kills process 60s later
    nccl_async_error_handling=True,  # Abort stuck collectives asynchronously
)
```
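To make the worst-case formula concrete, the hypothetical helper below evaluates it for the aggressive configuration above and for the default timeouts (600 s collective, 300 s heartbeat) combined with max_restarts=3 from the first example. It mirrors the formula as stated and does not add the monitor_interval polling delay.

```python
def worst_case_failure_seconds(
    max_restarts: int,
    nccl_collective_timeout_sec: int,
    nccl_heartbeat_timeout_sec: int,
) -> int:
    """Hypothetical helper: worst-case seconds before a hung job fails.

    Each of the (max_restarts + 1) attempts can block for the full collective
    timeout and then wait the full heartbeat timeout before the worker is killed.
    """
    attempts = max_restarts + 1
    return attempts * (nccl_collective_timeout_sec + nccl_heartbeat_timeout_sec)


# Aggressive configuration: 1 * (60 + 60)
print(worst_case_failure_seconds(0, 60, 60))    # 120
# Default timeouts with max_restarts=3: 4 * (600 + 300)
print(worst_case_failure_seconds(3, 600, 300))  # 3600
```

In other words, the aggressive settings surface a hang within about two minutes, whereas the default timeouts with three restarts can leave a hung job running for up to an hour.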

Key Parameters for Fault Tolerance

| Parameter | Description |
| --- | --- |
| max_restarts | Maximum number of worker group restarts. Set to 0 for deterministic failures (e.g., a model too large for the GPU) to avoid pointless retry cycles. |
| nccl_collective_timeout_sec | Timeout for individual NCCL operations (e.g., all-reduce). When a worker desyncs, the others block for this long. Default: 600 s. |
| nccl_heartbeat_timeout_sec | The second phase of failure detection. After a collective timeout, the heartbeat monitor waits this long before sending SIGABRT to kill the worker. Default: 300 s. |
| nccl_async_error_handling | When True, sets TORCH_NCCL_ASYNC_ERROR_HANDLING=1. This makes the worker crash-exit on a stuck collective, which the agent detects within monitor_interval seconds. |
| nccl_enable_monitoring | Activates PyTorch's NCCL monitoring thread (required for heartbeat timeouts). Default: True. |

Managing Kubernetes Job Policies

You can further control the lifecycle of the underlying Kubernetes PyTorchJob using the RunPolicy class. This is passed via the run_policy argument of the Elastic config.

```python
from flyteplugins.pytorch.task import Elastic, RunPolicy

elastic_with_policy = Elastic(
    nnodes=2,
    nproc_per_node=2,
    run_policy=RunPolicy(
        clean_pod_policy="all",          # Clean up all pods after completion
        backoff_limit=4,                 # K8s-level retries
        ttl_seconds_after_finished=100,  # Delete job 100s after finishing
        active_deadline_seconds=3600,    # Hard limit on job duration (1 hour)
    ),
)
```

Troubleshooting and Internal Mechanics

The Zombie Watchdog

PyTorch's elastic agent has a known deadlock where it can hang if all workers die simultaneously (e.g., from a SIGABRT triggered by NCCL). This plugin implements an internal _start_zombie_watchdog (found in plugins/pytorch/src/flyteplugins/pytorch/task.py) that monitors the /proc filesystem. If it detects that all worker processes have become zombies, it force-exits the agent to ensure the Flyte task actually fails instead of hanging forever.
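The sketch below illustrates the detection idea, assuming Linux /proc semantics, where a process in state Z has exited but has not yet been reaped by its parent. All names here (start_watchdog, all_workers_dead, and the other helpers) are hypothetical; this is not the plugin's private _start_zombie_watchdog, whose details may differ. Treating PIDs that have vanished from /proc as dead, and exiting with status 1, are assumptions of this sketch.

```python
import os
import threading
import time
from typing import Callable, List, Optional

# States treated as "dead": Z = zombie (exited, not yet reaped), X = dead.
DEAD_STATES = {"Z", "X"}


def proc_state(stat_text: str) -> str:
    """Return the one-letter state from a /proc/<pid>/stat line.

    Field 2 (the command name) is wrapped in parentheses and may itself
    contain spaces or ')', so parse from the last ')' onward.
    """
    return stat_text[stat_text.rindex(")") + 1:].split()[0]


def read_stat(pid: int) -> Optional[str]:
    """Read /proc/<pid>/stat, or return None if the process no longer exists."""
    try:
        with open(f"/proc/{pid}/stat") as f:
            return f.read()
    except (FileNotFoundError, ProcessLookupError):
        return None


def all_workers_dead(
    pids: List[int], reader: Callable[[int], Optional[str]] = read_stat
) -> bool:
    """True if there is at least one worker and every worker is dead or gone."""
    if not pids:
        return False
    for pid in pids:
        text = reader(pid)
        if text is not None and proc_state(text) not in DEAD_STATES:
            return False  # at least one worker is still alive
    return True


def start_watchdog(pids: List[int], interval_sec: float = 5.0) -> threading.Thread:
    """Hypothetical stand-in for the plugin's private _start_zombie_watchdog."""

    def _loop() -> None:
        while not all_workers_dead(pids):
            time.sleep(interval_sec)
        # sys.exit() here would only end this thread; os._exit terminates the
        # whole (possibly deadlocked) agent process immediately with a failure code.
        os._exit(1)

    thread = threading.Thread(target=_loop, daemon=True)
    thread.start()
    return thread
```

Running the check in a daemon thread means it keeps polling even while the agent's main thread is stuck, and os._exit bypasses any cleanup that could itself block on the hung agent.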

Environment Variable Precedence

The plugin propagates configuration to workers via environment variables. However, it will not override variables that are already set in your environment. For example:

  • nccl_heartbeat_timeout_sec sets TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC.
  • nccl_async_error_handling sets TORCH_NCCL_ASYNC_ERROR_HANDLING.
  • nccl_collective_timeout_sec sets FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC.
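The no-override rule amounts to a "set if absent" merge. A minimal sketch, assuming the mapping in the list above; apply_nccl_env is a hypothetical helper rather than the plugin's code, and leaving TORCH_NCCL_ASYNC_ERROR_HANDLING unset when the option is False is an assumption.

```python
from typing import Dict, MutableMapping


def apply_nccl_env(
    env: MutableMapping[str, str],
    nccl_heartbeat_timeout_sec: int,
    nccl_async_error_handling: bool,
    nccl_collective_timeout_sec: int,
) -> None:
    """Hypothetical sketch: export settings as env vars without overriding user values."""
    settings: Dict[str, str] = {
        "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC": str(nccl_heartbeat_timeout_sec),
        "FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC": str(nccl_collective_timeout_sec),
    }
    if nccl_async_error_handling:
        # Assumption: nothing is exported when async error handling is disabled
        settings["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    for key, value in settings.items():
        # setdefault writes the key only if it is not already present
        env.setdefault(key, value)


# A value the user exported beforehand wins over the configured one
env = {"TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC": "120"}
apply_nccl_env(env, nccl_heartbeat_timeout_sec=60,
               nccl_async_error_handling=True, nccl_collective_timeout_sec=60)
print(env["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"])   # 120
print(env["FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC"])  # 60
```

Passing os.environ as env applies the same logic to the real process environment, since os.environ supports setdefault.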

CPU Threading Defaults

If nproc_per_node > 1 and OMP_NUM_THREADS is not explicitly set, the plugin defaults it to 1. This prevents multiple workers on the same node from over-subscribing CPU resources, which can lead to severe performance degradation.
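The rule can be expressed as a short sketch; default_omp_threads is a hypothetical name, not the plugin's function. Because the check only fills in a missing value, exporting OMP_NUM_THREADS yourself (for example, when each worker benefits from several threads) takes precedence, as the second call shows.

```python
from typing import MutableMapping


def default_omp_threads(env: MutableMapping[str, str], nproc_per_node: int) -> None:
    """Pin OMP_NUM_THREADS to 1 on multi-process nodes unless the user set it."""
    if nproc_per_node > 1 and "OMP_NUM_THREADS" not in env:
        env["OMP_NUM_THREADS"] = "1"


env = {}
default_omp_threads(env, nproc_per_node=2)
print(env)  # {'OMP_NUM_THREADS': '1'}

env = {"OMP_NUM_THREADS": "4"}
default_omp_threads(env, nproc_per_node=2)
print(env)  # {'OMP_NUM_THREADS': '4'}
```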