Distributed PyTorch Training
Distributed PyTorch training in this codebase is implemented via the TorchFunctionTask plugin, which integrates Flyte tasks with the TorchElastic (now part of torch.distributed) framework. This allows developers to run multi-node, multi-process training jobs with built-in fault tolerance and environment management.
The TorchFunctionTask Plugin
The TorchFunctionTask (defined in plugins/pytorch/src/flyteplugins/pytorch/task.py) is a specialized task template that transforms a standard Python function into a distributed PyTorch job. It manages the lifecycle of the training process, including:
- Environment Setup: Configuring NCCL timeouts, OpenMP threads, and unbuffered logging.
- Process Launching: Using torch.distributed.launcher.api.elastic_launch to start worker processes.
- Fault Detection: Monitoring worker health and handling failures like CUDA Out-of-Memory (OOM).
- Deadlock Prevention: Running a "zombie watchdog" to detect and recover from elastic agent hangs.
Configuring Distributed Execution
Distributed parameters are defined using the Elastic configuration class, which is passed to a TaskEnvironment.
from flyteplugins.pytorch.task import Elastic
import flyte

# Define the environment with distributed configuration
torch_env = flyte.TaskEnvironment(
    name="pytorch-dist-env",
    plugin_config=Elastic(
        nnodes=2,  # Number of nodes (can be a range like "2:4")
        nproc_per_node=1,  # Processes per node (usually 1 per GPU)
        max_restarts=3,  # Retries for transient failures
        nccl_async_error_handling=True,
    ),
)

@torch_env.task
def train_model(epochs: int):
    import torch.distributed as dist

    # Initialize the process group within the task
    dist.init_process_group(backend="gloo")
    # ... training logic ...
When nnodes is set to 1, TorchFunctionTask automatically downgrades the task_type to python-task for standard local execution, as seen in its __post_init__ method:
# plugins/pytorch/src/flyteplugins/pytorch/task.py
def __post_init__(self):
    super().__post_init__()
    self.task_type = "python-task" if self.plugin_config.nnodes == 1 else "pytorch"
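Because nnodes may also be given as a range such as "2:4", a single-node check might need to normalize the value first. The following is a minimal sketch under that assumption, not the plugin's code: parse_nnodes and select_task_type are hypothetical helpers, and the "MIN:MAX" format is inferred from the comment in the configuration example.

```python
from typing import Tuple, Union


def parse_nnodes(nnodes: Union[int, str]) -> Tuple[int, int]:
    """Return (min_nodes, max_nodes) for an int or a "MIN:MAX" string."""
    parts = str(nnodes).split(":")
    if len(parts) == 1:
        lo = hi = int(parts[0])
    elif len(parts) == 2:
        lo, hi = int(parts[0]), int(parts[1])
    else:
        raise ValueError(f"invalid nnodes spec: {nnodes!r}")
    if lo < 1 or hi < lo:
        raise ValueError(f"invalid nnodes range: {nnodes!r}")
    return lo, hi


def select_task_type(nnodes: Union[int, str]) -> str:
    """Pick "python-task" only when exactly one node is requested."""
    return "python-task" if parse_nnodes(nnodes) == (1, 1) else "pytorch"
```

With this helper, both 1 and "1" select "python-task", while "1:2" still selects "pytorch" because the job may scale beyond one node.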
Environment Initialization
Before the training function executes, the pre() method in TorchFunctionTask configures several critical environment variables to ensure stable distributed execution:
- PYTHONUNBUFFERED: Set to "1" to ensure logs are visible even if a process crashes.
- OMP_NUM_THREADS: If nproc_per_node > 1 and not otherwise set, this defaults to 1 to prevent system overloading from nested parallelism.
- NCCL Timeouts: The plugin propagates nccl_heartbeat_timeout_sec and nccl_collective_timeout_sec to the worker processes.
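The settings listed above can be approximated as a pure function over an environment mapping. This is an illustrative sketch, not the plugin's pre() method: build_worker_env is a hypothetical helper, and the FLYTE_NCCL_HEARTBEAT_TIMEOUT_SEC variable name is an assumption made by analogy with FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC.

```python
from typing import Dict, Mapping, Optional


def build_worker_env(
    base_env: Mapping[str, str],
    nproc_per_node: int,
    heartbeat_timeout_sec: Optional[int] = None,
    collective_timeout_sec: Optional[int] = None,
) -> Dict[str, str]:
    """Return a copy of base_env with the distributed-training defaults applied."""
    env = dict(base_env)

    # Unbuffered output so the last log lines survive a crash.
    env["PYTHONUNBUFFERED"] = "1"

    # Avoid nested parallelism when several workers share one node,
    # but respect a value the user has already chosen.
    if nproc_per_node > 1 and "OMP_NUM_THREADS" not in env:
        env["OMP_NUM_THREADS"] = "1"

    # Pass the timeouts down to the worker processes.
    if heartbeat_timeout_sec is not None:
        # Assumed variable name, not confirmed by the plugin source.
        env["FLYTE_NCCL_HEARTBEAT_TIMEOUT_SEC"] = str(heartbeat_timeout_sec)
    if collective_timeout_sec is not None:
        env["FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC"] = str(collective_timeout_sec)
    return env
```

Returning a new dictionary instead of mutating os.environ keeps the logic easy to test; a real launcher would apply the result to the process environment before spawning workers.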
The launcher_entrypoint function specifically patches torch.distributed to respect these timeouts before the user's code calls init_process_group():
# plugins/pytorch/src/flyteplugins/pytorch/task.py
def launcher_entrypoint(tctx: TaskContext, fn: bytes, kwargs: dict):
    # ...
    nccl_timeout = os.environ.get("FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC")
    if nccl_timeout is not None:
        from datetime import timedelta

        import torch.distributed.constants
        import torch.distributed.distributed_c10d

        td = timedelta(seconds=int(nccl_timeout))
        torch.distributed.constants.default_pg_nccl_timeout = td
        torch.distributed.distributed_c10d.default_pg_nccl_timeout = td
    # ...
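The excerpt above calls int() on the environment value directly, so a malformed value would raise inside the worker. A more defensive read might look like the sketch below. Here read_timeout is a hypothetical helper, and falling back to None for invalid values is an assumption, not something the plugin is known to do.

```python
from datetime import timedelta
from typing import Mapping, Optional


def read_timeout(env: Mapping[str, str], name: str) -> Optional[timedelta]:
    """Parse a positive whole-second timeout from env, or return None."""
    raw = env.get(name)
    if raw is None:
        return None
    try:
        seconds = int(raw)
    except ValueError:
        # Not an integer: treat as unset rather than crash the worker.
        return None
    if seconds <= 0:
        return None
    return timedelta(seconds=seconds)
```

A caller could pass os.environ and "FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC", then assign the result to default_pg_nccl_timeout only when it is not None.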
Fault Tolerance and the Zombie Watchdog
Distributed training is prone to specific failure modes, such as one rank hitting a CUDA OOM while others wait indefinitely on a collective operation (e.g., all_reduce).
NCCL Failure Detection
The Elastic config allows fine-tuning failure detection:
- nccl_heartbeat_timeout_sec: Defaults to 300s (5 minutes) in this plugin, significantly lower than the PyTorch default of 30 minutes. This ensures faster recovery when a worker becomes unresponsive.
- nccl_async_error_handling: When enabled, it allows NCCL to abort stuck collectives asynchronously, causing the worker to crash-exit so the elastic agent can detect it immediately.
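Why a timeout matters can be illustrated without GPUs. In the sketch below, a threading.Barrier stands in for a collective such as all_reduce, and one rank returns early to mimic a crash. This is only an analogy: run_collective is a hypothetical helper, and nothing here uses NCCL or PyTorch.

```python
import threading
from typing import List, Optional


def run_collective(
    world_size: int, crashed_rank: Optional[int], timeout_sec: float
) -> List[str]:
    """Simulate one collective; return the outcome reported by each rank."""
    barrier = threading.Barrier(world_size)
    results = [""] * world_size

    def worker(rank: int) -> None:
        if rank == crashed_rank:
            # This rank dies before reaching the collective.
            results[rank] = "crashed"
            return
        try:
            barrier.wait(timeout=timeout_sec)
            results[rank] = "completed"
        except threading.BrokenBarrierError:
            # The timeout breaks the barrier and releases every waiter.
            results[rank] = "timed out"

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

With three ranks and rank 1 crashed, the two surviving ranks report "timed out" instead of blocking forever; without a timeout they would wait indefinitely, which is the situation the shorter heartbeat timeout is meant to cut short.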
The Zombie Watchdog
A known issue in PyTorch's elastic agent is a deadlock that occurs when all workers die from SIGABRT (common during NCCL timeouts). The agent can hang while trying to acquire a shared semaphore.
To mitigate this, TorchFunctionTask.execute() starts a watchdog thread via _start_zombie_watchdog. The watchdog monitors /proc to count zombie child processes. If the number of zombies reaches nproc_per_node, it assumes the agent is deadlocked and force-exits the process:
# plugins/pytorch/src/flyteplugins/pytorch/task.py
if len(zombie_pids) >= nproc:
    logger.error("Zombie watchdog: %d worker processes are zombies... Force-exiting.", len(zombie_pids))
    os._exit(1)
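The excerpt does not show how zombie_pids is collected. One plausible approach on Linux is to read /proc/<pid>/stat for every process and keep those whose parent is the current process and whose state is Z. The sketch below assumes that approach; parse_stat and find_zombie_children are hypothetical names, not the plugin's functions.

```python
import os
from typing import List, Tuple


def parse_stat(stat_line: str) -> Tuple[int, str, int]:
    """Return (pid, state, ppid) from the contents of /proc/<pid>/stat."""
    # The command name sits in parentheses and may itself contain spaces
    # or ')', so split on the first '(' and the last ')'.
    open_idx = stat_line.index("(")
    close_idx = stat_line.rindex(")")
    pid = int(stat_line[:open_idx].strip())
    fields = stat_line[close_idx + 1:].split()
    return pid, fields[0], int(fields[1])


def find_zombie_children(parent_pid: int, proc_root: str = "/proc") -> List[int]:
    """List PIDs of zombie processes whose parent is parent_pid."""
    zombies: List[int] = []
    if not os.path.isdir(proc_root):
        return zombies  # not a Linux-style procfs; nothing to scan
    for entry in os.listdir(proc_root):
        if not entry.isdigit():
            continue
        stat_path = os.path.join(proc_root, entry, "stat")
        try:
            with open(stat_path, errors="replace") as f:
                pid, state, ppid = parse_stat(f.read())
        except (OSError, ValueError, IndexError):
            continue  # process vanished mid-scan or stat was unreadable
        if ppid == parent_pid and state == "Z":
            zombies.append(pid)
    return zombies
```

A watchdog thread could call find_zombie_children(os.getpid()) on an interval and compare the length of the result against nproc_per_node, as the excerpt does.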
Kubernetes Run Policies
For execution on Kubernetes, the RunPolicy class allows you to define how the underlying PyTorchJob pods are managed. This is converted into a DistributedPyTorchTrainingTask during serialization.
| Parameter | Description |
|---|---|
| clean_pod_policy | Controls pod cleanup after completion ("None", "all", or "Running"). |
| ttl_seconds_after_finished | How long to keep the job record after it finishes. |
| active_deadline_seconds | Maximum duration the job can remain active. |
| backoff_limit | Number of retries before marking the job as failed. |
Example configuration:
from flyteplugins.pytorch.task import RunPolicy, Elastic

config = Elastic(
    nnodes=2,
    nproc_per_node=1,
    run_policy=RunPolicy(
        clean_pod_policy="all",
        ttl_seconds_after_finished=3600,
    ),
)
With clean_pod_policy="all", the Kubernetes pods are cleaned up once training completes, which prevents resource leakage in the cluster. Setting ttl_seconds_after_finished=3600 keeps the finished job record for one hour.
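For orientation, the four parameters in the table correspond to the runPolicy fields of a Kubeflow PyTorchJob: cleanPodPolicy, ttlSecondsAfterFinished, activeDeadlineSeconds, and backoffLimit. The sketch below shows one way such a mapping could be produced. RunPolicySketch and to_run_policy_spec are hypothetical stand-ins, and the plugin's actual serialization into DistributedPyTorchTrainingTask may differ.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class RunPolicySketch:
    """Hypothetical mirror of the RunPolicy parameters in the table."""
    clean_pod_policy: Optional[str] = None
    ttl_seconds_after_finished: Optional[int] = None
    active_deadline_seconds: Optional[int] = None
    backoff_limit: Optional[int] = None


# Python attribute name -> PyTorchJob runPolicy field name.
_FIELD_NAMES = {
    "clean_pod_policy": "cleanPodPolicy",
    "ttl_seconds_after_finished": "ttlSecondsAfterFinished",
    "active_deadline_seconds": "activeDeadlineSeconds",
    "backoff_limit": "backoffLimit",
}


def to_run_policy_spec(policy: RunPolicySketch) -> Dict[str, Any]:
    """Build a runPolicy mapping, omitting fields that were left unset."""
    spec: Dict[str, Any] = {}
    for attr, key in _FIELD_NAMES.items():
        value = getattr(policy, attr)
        if value is not None:
            spec[key] = value
    return spec
```

Omitting unset fields leaves the operator's own defaults in effect, which is usually preferable to sending explicit nulls.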