Dask Cluster Architecture
The Dask integration in this SDK allows Flyte tasks to provision and interact with a transient Dask cluster. This architecture consists of a centralized scheduler and one or more worker groups, all managed through the DaskTask plugin.
Defining the Cluster Topology
The cluster configuration is defined using three primary classes in flyteplugins.dask.task: Dask, Scheduler, and WorkerGroup. These classes allow you to specify the resources and images for different components of the Dask cluster.
Scheduler Configuration
The Scheduler class configures the Dask scheduler pod. It allows for custom images and resource requirements:
```python
@dataclass
class Scheduler:
    image: Optional[str] = None
    resources: Optional[Resources] = None
```
Worker Groups
The WorkerGroup class defines the pool of workers that will execute the distributed computations. You can specify the number of workers, their hardware resources, and a specific container image:
```python
@dataclass
class WorkerGroup:
    number_of_workers: Optional[int] = 1
    image: Optional[str] = None
    resources: Optional[Resources] = None
```
The Dask Configuration Object
The Dask class aggregates these configurations into a single object that is passed to the Flyte task environment:
```python
@dataclass
class Dask:
    scheduler: Scheduler = field(default_factory=Scheduler)
    workers: WorkerGroup = field(default_factory=WorkerGroup)
```
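To see how these defaults compose, the stdlib-only sketch below mirrors the three classes as plain dataclasses. It is a stand-in, not the plugin's source: `Resources` here is a hypothetical placeholder for `flyte.Resources` so the snippet runs on its own. It shows that `field(default_factory=...)` gives every `Dask()` its own `Scheduler` and `WorkerGroup` objects instead of one shared default (Python 3.11 and later also reject an unhashable dataclass instance as a plain field default).

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Resources:
    # Hypothetical placeholder for flyte.Resources, only to keep the sketch self-contained.
    cpu: Optional[str] = None
    memory: Optional[str] = None


@dataclass
class Scheduler:
    image: Optional[str] = None
    resources: Optional[Resources] = None


@dataclass
class WorkerGroup:
    number_of_workers: Optional[int] = 1
    image: Optional[str] = None
    resources: Optional[Resources] = None


@dataclass
class Dask:
    # default_factory builds fresh objects for each Dask() call, so configs never share state.
    scheduler: Scheduler = field(default_factory=Scheduler)
    workers: WorkerGroup = field(default_factory=WorkerGroup)


default_cfg = Dask()
big_cfg = Dask(workers=WorkerGroup(number_of_workers=4, resources=Resources(cpu="2", memory="4Gi")))
print(default_cfg.workers.number_of_workers)  # 1
print(big_cfg.workers.number_of_workers)  # 4
```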
The DaskTask Lifecycle
The DaskTask class (found in plugins/dask/src/flyteplugins/dask/task.py) manages the transition from a standard Flyte task to a distributed Dask job. It handles two critical phases: serialization and initialization.
Serialization and Custom Config
When a task is registered, DaskTask.custom_config transforms the Python configuration into a DaskJob protobuf message. This message tells the Flyte backend how to provision the Kubernetes pods for the scheduler and workers.
```python
def custom_config(self, sctx: SerializationContext) -> Dict[str, Any]:
    scheduler = self.plugin_config.scheduler
    wg = self.plugin_config.workers
    # Build the DaskJob message describing the scheduler pod and the worker group
    job = DaskJob(
        scheduler=DaskScheduler(image=scheduler.image, resources=get_proto_resources(scheduler.resources)),
        workers=DaskWorkerGroup(
            number_of_workers=wg.number_of_workers, image=wg.image, resources=get_proto_resources(wg.resources)
        ),
    )
    # Convert the protobuf message into a plain, JSON-compatible dict
    return MessageToDict(job)
```
The Pre-Execution Hook
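Before the user's task code runs, the pre method executes in the task process. When the task is running in a cluster and a code bundle is available, it connects a Dask Client to the provisioned scheduler and registers two plugins that make the bundle available on the scheduler and on every worker; otherwise it does nothing.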
```python
async def pre(self, *args, **kwargs) -> Dict[str, Any]:
    ctx = flyte.ctx()
    code_bundle = ctx.code_bundle
    # Only ship code when running remotely with a bundle; local runs skip this step
    if ctx.is_in_cluster() and code_bundle:
        client = Client()  # connects to the scheduler provisioned for this task
        client.register_plugin(DownloadCodeBundleWorkerPlugin(code_bundle))
        client.register_plugin(DownloadCodeBundleSchedulerPlugin(code_bundle))
    return {}
```
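The control flow of pre can be sketched without Flyte or Dask by swapping in a recording client. Everything below (`RecordingClient`, the two stand-in plugin classes, `pre_hook`, the `"bundle-ref"` value) is hypothetical; it only shows that a client is created, and both plugins registered, solely when the task is in a cluster and has a bundle. For context, distributed's `Client()` with no address falls back to the `scheduler-address` Dask config value (settable through the `DASK_SCHEDULER_ADDRESS` environment variable) and otherwise starts a local cluster; that is presumably how the no-argument call finds the Flyte-provisioned scheduler.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional


@dataclass
class RecordingClient:
    """Hypothetical stand-in for distributed.Client that only records plugin registrations."""

    registered: List[Any] = field(default_factory=list)

    def register_plugin(self, plugin: Any) -> None:
        self.registered.append(plugin)


@dataclass
class WorkerBundlePlugin:  # hypothetical stand-in for DownloadCodeBundleWorkerPlugin
    code_bundle: str


@dataclass
class SchedulerBundlePlugin:  # hypothetical stand-in for DownloadCodeBundleSchedulerPlugin
    code_bundle: str


def pre_hook(
    make_client: Callable[[], RecordingClient],
    in_cluster: bool,
    code_bundle: Optional[str],
) -> Dict[str, Any]:
    # Same guard as DaskTask.pre: the client is only created when there is a bundle to ship.
    if in_cluster and code_bundle:
        client = make_client()
        client.register_plugin(WorkerBundlePlugin(code_bundle))
        client.register_plugin(SchedulerBundlePlugin(code_bundle))
    return {}


created: List[RecordingClient] = []


def make_client() -> RecordingClient:
    client = RecordingClient()
    created.append(client)
    return client


pre_hook(make_client, in_cluster=False, code_bundle="bundle-ref")  # local run: no client
pre_hook(make_client, in_cluster=True, code_bundle=None)  # no bundle: no client
pre_hook(make_client, in_cluster=True, code_bundle="bundle-ref")  # remote run with a bundle
print(len(created), [type(p).__name__ for p in created[0].registered])
# 1 ['WorkerBundlePlugin', 'SchedulerBundlePlugin']
```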
Automated Code Synchronization
A common challenge in distributed computing is ensuring that user-defined functions (UDFs) and their dependencies are available on all workers. This SDK addresses it with two code bundle plugins: one for the workers and one for the scheduler.
DownloadCodeBundleWorkerPlugin
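Dask calls this plugin's setup method on every worker that is connected when the plugin is registered, and on any worker that joins afterwards. It prepends the worker's current working directory (".") to sys.path and then downloads the Flyte code bundle: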
```python
class DownloadCodeBundleWorkerPlugin(WorkerPlugin):
    async def setup(self, worker):
        # Make modules in the working directory importable, then fetch the bundle
        sys.path.insert(0, ".")
        await download_code_bundle(self.code_bundle)
```
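The sys.path step matters because cloudpickle serializes functions defined in importable modules by reference (module name plus qualified name), so a worker can only run them if it can import that module. The stdlib sketch below imitates the worker plugin under stated assumptions: `LocalBundleWorkerPlugin` is hypothetical, it unpacks a local zip instead of downloading a Flyte bundle, and it adds an explicit directory rather than "." so the demo does not depend on the current working directory.

```python
import asyncio
import importlib
import sys
import tempfile
import zipfile
from pathlib import Path


class LocalBundleWorkerPlugin:
    """Hypothetical stand-in for DownloadCodeBundleWorkerPlugin that unpacks a local zip."""

    def __init__(self, bundle_zip: Path, target_dir: Path):
        self.bundle_zip = bundle_zip
        self.target_dir = target_dir

    async def setup(self, worker=None):
        # Same order as the real plugin: make the directory importable, then fetch the code.
        sys.path.insert(0, str(self.target_dir))
        # Unpacking is blocking I/O, so run it in a thread instead of on the event loop.
        await asyncio.to_thread(self._unpack)

    def _unpack(self) -> None:
        with zipfile.ZipFile(self.bundle_zip) as zf:
            zf.extractall(self.target_dir)


# Build a tiny "bundle" containing one user module.
workdir = Path(tempfile.mkdtemp())
bundle = workdir / "bundle.zip"
with zipfile.ZipFile(bundle, "w") as zf:
    zf.writestr("my_udfs.py", "def add_one(x):\n    return x + 1\n")

target = workdir / "code"
target.mkdir()
asyncio.run(LocalBundleWorkerPlugin(bundle, target).setup())
importlib.invalidate_caches()  # make sure the import system sees the newly extracted file
my_udfs = importlib.import_module("my_udfs")
print(my_udfs.add_one(41))  # 42
```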
DownloadCodeBundleSchedulerPlugin
The scheduler gets the same treatment so that user modules are importable there as well. When the plugin is registered, the scheduler calls its start method, which prepends "." to sys.path and downloads the code bundle:
```python
class DownloadCodeBundleSchedulerPlugin(SchedulerPlugin):
    async def start(self, scheduler):
        # Same two steps as the worker plugin, run once on the scheduler
        sys.path.insert(0, ".")
        await download_code_bundle(self.code_bundle)
```
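Both plugins repeat the same two steps, and each sys.path.insert(0, ".") call adds another entry. That can happen more than once in a single process, for example when a LocalCluster created with processes=False runs the scheduler and its workers in one process. The sketch below is a hypothetical refactor, not the SDK's code: a shared coroutine (`prepare_code_bundle`, an illustrative name) that both hooks could call, which inserts the path entry only once and delegates the download to a caller-supplied coroutine.

```python
import asyncio
import sys
from typing import Awaitable, Callable, List, Optional


async def prepare_code_bundle(
    fetch: Callable[[], Awaitable[None]],
    path_entry: str = ".",
    path: Optional[List[str]] = None,
) -> None:
    """Hypothetical helper shared by a worker setup hook and a scheduler start hook."""
    search_path = sys.path if path is None else path
    # Insert only once, even if the hook runs several times in the same process.
    if path_entry not in search_path:
        search_path.insert(0, path_entry)
    # The caller supplies the actual download (for example a wrapper around download_code_bundle).
    await fetch()


fetch_calls: List[str] = []


async def fake_fetch() -> None:
    # Stand-in for the real bundle download.
    fetch_calls.append("fetched")


demo_path = ["/opt/venv/lib/python3.11/site-packages"]
asyncio.run(prepare_code_bundle(fake_fetch, path=demo_path))
asyncio.run(prepare_code_bundle(fake_fetch, path=demo_path))
print(demo_path)  # ['.', '/opt/venv/lib/python3.11/site-packages']
print(fetch_calls)  # ['fetched', 'fetched']
```

The guard only deduplicates the path entry; whether repeated downloads should also be skipped is a separate decision.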
Implementation Example
To use this architecture, you define a Dask configuration and associate it with a TaskEnvironment. The following example from examples/plugins/dask_example.py demonstrates how to set up a cluster with 4 workers:
```python
import typing

from flyteplugins.dask import Dask, Scheduler, WorkerGroup
from flyte import Resources, TaskEnvironment

# Define the cluster topology
dask_config = Dask(
    scheduler=Scheduler(),
    workers=WorkerGroup(number_of_workers=4),
)

# Create the environment
dask_env = TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    resources=Resources(cpu="1", memory="1Gi"),
)


@dask_env.task
async def hello_dask_nested(n: int = 3) -> typing.List[int]:
    from distributed import Client

    # The client automatically connects to the scheduler provisioned by Flyte
    client = Client()
    futures = client.map(lambda x: x + 1, range(n))  # one future per input
    res = client.gather(futures)  # results come back in input order
    return res
```
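Client.map submits one task per element and returns a list of futures, and Client.gather on that list returns the results in the same order, so with the default n=3 the task returns [1, 2, 3]. To check that logic without a cluster, the same submit-then-collect pattern can be reproduced with the standard library's concurrent.futures. The sketch below is a local stand-in (`hello_local` is a hypothetical name), not Dask code.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List


def hello_local(n: int = 3) -> List[int]:
    # Stand-in for hello_dask_nested: submit one task per input, then collect in input order.
    with ThreadPoolExecutor() as pool:
        futures = [pool.submit(lambda x: x + 1, x) for x in range(n)]
        return [f.result() for f in futures]


print(hello_local())  # [1, 2, 3]
```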
Critical Requirements
- Dask installation: Any custom image provided to `Scheduler` or `WorkerGroup` must have `dask[distributed]` installed.
- Environment consistency: Use the same image for the scheduler, the workers, and the task runner whenever possible. Mismatched Python or library versions between them can cause serialization (pickling) errors.
- Pathing: The code bundle plugins insert `.` into `sys.path`, so modules are resolved relative to each process's current working directory. Ensure your task's working directory is writable.
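On a live cluster, distributed provides Client.get_versions, which collects package versions from the client, scheduler, and workers (with check=True it raises on a mismatch). The stdlib sketch below shows the same comparison idea on plain dictionaries so it can run anywhere; `find_version_mismatches` and the node names and version numbers in the demo are hypothetical illustrations, not output from a real cluster.

```python
from typing import Dict, Mapping, Set


def find_version_mismatches(
    versions_by_node: Mapping[str, Mapping[str, str]],
) -> Dict[str, Dict[str, str]]:
    """Return {package: {node: version}} for every package whose version differs across nodes."""
    packages: Set[str] = set()
    for pkgs in versions_by_node.values():
        packages.update(pkgs)
    mismatches: Dict[str, Dict[str, str]] = {}
    for pkg in sorted(packages):
        # A package absent on some node counts as a mismatch too.
        seen = {node: pkgs.get(pkg, "missing") for node, pkgs in versions_by_node.items()}
        if len(set(seen.values())) > 1:
            mismatches[pkg] = seen
    return mismatches


report = find_version_mismatches({
    "task-runner": {"python": "3.11.9", "cloudpickle": "3.0.0"},
    "scheduler": {"python": "3.11.9", "cloudpickle": "3.0.0"},
    "worker-0": {"python": "3.10.14", "cloudpickle": "3.0.0"},
})
print(report)
# {'python': {'task-runner': '3.11.9', 'scheduler': '3.11.9', 'worker-0': '3.10.14'}}
```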