
Dask Cluster Architecture

The Dask integration in this SDK allows Flyte tasks to provision and interact with a transient Dask cluster. This architecture consists of a centralized scheduler and one or more worker groups, all managed through the DaskTask plugin.

Defining the Cluster Topology

The cluster configuration is defined using three primary classes in flyteplugins.dask.task: Dask, Scheduler, and WorkerGroup. These classes allow you to specify the resources and images for different components of the Dask cluster.

Scheduler Configuration

The Scheduler class configures the Dask scheduler pod. It allows for custom images and resource requirements:

@dataclass
class Scheduler:
    image: Optional[str] = None  # container image for the scheduler pod
    resources: Optional[Resources] = None  # CPU/memory requests for the scheduler pod

Worker Groups

The WorkerGroup class defines the pool of workers that will execute the distributed computations. You can specify the number of workers, their hardware resources, and a specific container image:

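@dataclass
class WorkerGroup:
    number_of_workers: Optional[int] = 1  # how many worker pods to start
    image: Optional[str] = None  # container image used by every worker
    resources: Optional[Resources] = None  # resources for each worker pod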
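The total capacity a cluster requests grows with `number_of_workers`, assuming (as the wording above suggests) that `resources` on a `WorkerGroup` describes each worker pod rather than the whole group. The sketch below estimates that footprint. It uses hypothetical stand-in dataclasses (`StubResources`, `StubScheduler`, `StubWorkerGroup`), not the plugin's classes, and represents CPU as a plain number of cores. The `default_cpu` fallback for components without explicit resources is also an assumption for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubResources:
    """Hypothetical stand-in for flyte.Resources; CPU as a number of cores."""
    cpu: float = 1.0


@dataclass
class StubScheduler:
    resources: Optional[StubResources] = None


@dataclass
class StubWorkerGroup:
    number_of_workers: int = 1
    resources: Optional[StubResources] = None


def total_cpu(scheduler: StubScheduler, workers: StubWorkerGroup, default_cpu: float = 1.0) -> float:
    """Estimate requested cores: one scheduler pod plus N identical worker pods.

    `default_cpu` is a hypothetical fallback for components with no explicit resources.
    """
    sched_cpu = scheduler.resources.cpu if scheduler.resources else default_cpu
    worker_cpu = workers.resources.cpu if workers.resources else default_cpu
    return sched_cpu + workers.number_of_workers * worker_cpu


print(total_cpu(StubScheduler(), StubWorkerGroup(number_of_workers=4, resources=StubResources(cpu=2))))
```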

The Dask Configuration Object

The Dask class combines these configurations into a single object, which you pass to a TaskEnvironment as its plugin_config:

@dataclass
class Dask:
    scheduler: Scheduler = field(default_factory=Scheduler)
    workers: WorkerGroup = field(default_factory=WorkerGroup)
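The `field(default_factory=...)` declarations give every `Dask()` instance its own `Scheduler` and `WorkerGroup`. As a result, changing one configuration never leaks into another. On Python 3.11 and later, dataclasses also reject an instance such as `Scheduler()` as a plain default, because non-frozen dataclasses are unhashable.

The sketch below shows the effect. It uses simplified stand-in classes that mirror the field layout above; they are not imports from `flyteplugins.dask`.

```python
from dataclasses import dataclass, field
from typing import Optional


# Simplified stand-ins mirroring the field layout described above (not the plugin's classes).
@dataclass
class Scheduler:
    image: Optional[str] = None


@dataclass
class WorkerGroup:
    number_of_workers: Optional[int] = 1
    image: Optional[str] = None


@dataclass
class Dask:
    # default_factory builds a fresh sub-config for every Dask() instance
    scheduler: Scheduler = field(default_factory=Scheduler)
    workers: WorkerGroup = field(default_factory=WorkerGroup)


a = Dask()
b = Dask()
a.workers.number_of_workers = 8  # only affects `a`

print(a.workers.number_of_workers, b.workers.number_of_workers)
print(a.workers is b.workers)
```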

The DaskTask Lifecycle

The DaskTask class (found in plugins/dask/src/flyteplugins/dask/task.py) turns a standard Flyte task into a distributed Dask job. It handles two phases: serialization, when the task is registered, and initialization, just before the task function runs.

Serialization and Custom Config

When a task is registered, DaskTask.custom_config converts the Python configuration into a DaskJob protobuf message and returns it as a dictionary via MessageToDict. The Flyte backend reads this message to provision the Kubernetes pods for the scheduler and workers.

def custom_config(self, sctx: SerializationContext) -> Dict[str, Any]:
    scheduler = self.plugin_config.scheduler
    wg = self.plugin_config.workers

    # Build the DaskJob message describing the scheduler and worker pods
    job = DaskJob(
        scheduler=DaskScheduler(image=scheduler.image, resources=get_proto_resources(scheduler.resources)),
        workers=DaskWorkerGroup(
            number_of_workers=wg.number_of_workers, image=wg.image, resources=get_proto_resources(wg.resources)
        ),
    )

    # Convert the protobuf message into a plain, JSON-compatible dict
    return MessageToDict(job)
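By default, `MessageToDict` emits lowerCamelCase field names and skips fields that were never set. A `WorkerGroup(number_of_workers=4)` with no image or resources therefore serializes to little more than a worker count.

The stdlib sketch below approximates that output shape from plain dicts. `approx_message_to_dict` is a hypothetical helper, not the protobuf library. Confirm the exact keys the backend receives against the `DaskJob` message definition.

```python
from typing import Any, Dict


def to_camel(name: str) -> str:
    """snake_case -> lowerCamelCase, the key style MessageToDict uses by default."""
    head, *rest = name.split("_")
    return head + "".join(part.capitalize() for part in rest)


def approx_message_to_dict(fields: Dict[str, Any]) -> Dict[str, Any]:
    """Rough, hypothetical stand-in for MessageToDict: camelCase keys, unset (None) fields dropped.

    Nested dicts are kept even when empty, like a sub-message that was set
    but has no populated fields.
    """
    out: Dict[str, Any] = {}
    for key, value in fields.items():
        if value is None:
            continue
        if isinstance(value, dict):
            value = approx_message_to_dict(value)
        out[to_camel(key)] = value
    return out


# Mirrors custom_config for Dask(scheduler=Scheduler(), workers=WorkerGroup(number_of_workers=4))
job = {
    "scheduler": {"image": None, "resources": None},
    "workers": {"number_of_workers": 4, "image": None, "resources": None},
}
print(approx_message_to_dict(job))
```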

The Pre-Execution Hook

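The pre method runs in the same process as the task function, before the user's code starts. If the task is running in a cluster and has a code bundle, pre connects a Dask Client to the provisioned cluster. It then registers two plugins, so that the scheduler and every worker download the same code bundle the task itself uses.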

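async def pre(self, *args, **kwargs) -> Dict[str, Any]:
    ctx = flyte.ctx()
    code_bundle = ctx.code_bundle
    # Only in-cluster runs that carry a code bundle need to sync it to the Dask nodes
    if ctx.is_in_cluster() and code_bundle:
        # With no address, Client() uses the scheduler address from Dask's configuration
        client = Client()
        # One plugin per side of the cluster: workers and scheduler each fetch the bundle
        client.register_plugin(DownloadCodeBundleWorkerPlugin(code_bundle))
        client.register_plugin(DownloadCodeBundleSchedulerPlugin(code_bundle))

    return {}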
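The guard in `pre` registers the plugins only when both conditions hold: the task is running in a cluster, and a code bundle exists. If either is missing, as in a local run, registration is skipped.

The sketch below reproduces just that branching, so it runs without a cluster. `RecordingClient` and `pre_sketch` are hypothetical stand-ins for `distributed.Client` and the real hook, and plugin objects are replaced by plain strings.

```python
from typing import List, Optional


class RecordingClient:
    """Hypothetical stand-in for distributed.Client that records registered plugins."""

    def __init__(self) -> None:
        self.plugins: List[str] = []

    def register_plugin(self, plugin: str) -> None:
        self.plugins.append(plugin)


def pre_sketch(in_cluster: bool, code_bundle: Optional[str]) -> Optional[RecordingClient]:
    """Mirror of the pre() guard: connect and register only in-cluster with a bundle."""
    if in_cluster and code_bundle:
        client = RecordingClient()
        client.register_plugin(f"worker:{code_bundle}")
        client.register_plugin(f"scheduler:{code_bundle}")
        return client
    return None  # nothing to synchronize


print(pre_sketch(True, "bundle.tgz").plugins)
print(pre_sketch(False, "bundle.tgz"))
print(pre_sketch(True, None))
```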

Automated Code Synchronization

A common challenge in distributed computing is making user-defined functions (UDFs) and their dependencies available on every worker. This SDK addresses it with two Dask plugins that distribute the Flyte code bundle.

DownloadCodeBundleWorkerPlugin

This plugin runs on every Dask worker: on existing workers when it is registered, and on any worker that joins later. It puts the current working directory (".") at the front of sys.path and then downloads the Flyte code bundle:

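class DownloadCodeBundleWorkerPlugin(WorkerPlugin):
    # The constructor (omitted here) stores its code_bundle argument as self.code_bundle
    async def setup(self, worker):
        # Make the working directory importable, then fetch the code bundle
        sys.path.insert(0, ".")
        await download_code_bundle(self.code_bundle)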
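The two steps in `setup` are what make the bundle usable. Once `.` is on `sys.path` and the bundle's files are in the working directory, functions unpickled on that worker can import the user's modules by name.

The sketch below simulates this using only the standard library:

- `fake_download` is a hypothetical stand-in for `download_code_bundle`.
- `user_udfs` is a hypothetical user module.
- A temporary directory plays the worker's working directory.

```python
import asyncio
import importlib
import os
import sys
import tempfile


async def fake_download(target_dir: str) -> None:
    """Hypothetical stand-in for download_code_bundle: writes one user module."""
    with open(os.path.join(target_dir, "user_udfs.py"), "w") as f:
        f.write("def add_one(x):\n    return x + 1\n")


async def setup_sketch() -> int:
    # Same two steps as DownloadCodeBundleWorkerPlugin.setup
    sys.path.insert(0, ".")
    await fake_download(os.getcwd())
    importlib.invalidate_caches()  # let the import system notice the new file
    sys.path_importer_cache.pop(".", None)  # drop any cached finder for "."
    module = importlib.import_module("user_udfs")
    return module.add_one(41)


original_cwd = os.getcwd()
with tempfile.TemporaryDirectory() as workdir:
    os.chdir(workdir)  # the temp dir plays the worker's working directory
    try:
        result = asyncio.run(setup_sketch())
    finally:
        # Undo the global changes so the demo leaves the interpreter clean
        os.chdir(original_cwd)
        sys.path.remove(".")
        sys.modules.pop("user_udfs", None)

print(result)
```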

DownloadCodeBundleSchedulerPlugin

The scheduler also needs the code bundle, because it may have to import user modules when it deserializes the task graphs it receives. DownloadCodeBundleSchedulerPlugin performs the same setup on the scheduler, using the start hook that Dask calls on scheduler plugins:

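class DownloadCodeBundleSchedulerPlugin(SchedulerPlugin):
    # The constructor (omitted here) stores its code_bundle argument as self.code_bundle
    async def start(self, scheduler):
        # Same steps as the worker plugin, run once on the scheduler
        sys.path.insert(0, ".")
        await download_code_bundle(self.code_bundle)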
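The worker and scheduler plugins run the same two steps from different hooks (`setup(worker)` versus `start(scheduler)`). One possible refactor is a shared coroutine that both hooks await. This is not what the SDK does, as far as this page shows.

The sketch below also adds an idempotence guard. If a hook runs more than once in the same process, for example because the plugin is registered again, it does not stack duplicate `.` entries on `sys.path`. `prepare_environment` and `noop_download` are hypothetical.

```python
import asyncio
import sys
from typing import Awaitable, Callable


async def prepare_environment(download: Callable[[], Awaitable[None]]) -> None:
    """Hypothetical shared body for both plugin hooks.

    Adds "." to the front of sys.path only if it is not already present, then downloads.
    """
    if "." not in sys.path:
        sys.path.insert(0, ".")
    await download()


async def noop_download() -> None:
    """Hypothetical stand-in for download_code_bundle."""
    return None


async def main() -> int:
    # Simulate the hook running twice in one interpreter
    await prepare_environment(noop_download)
    await prepare_environment(noop_download)
    return sys.path.count(".")


already_present = "." in sys.path
count = asyncio.run(main())
if not already_present:
    sys.path.remove(".")  # undo the demo's change
print(count)
```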

Implementation Example

To use this architecture, you define a Dask configuration and associate it with a TaskEnvironment. The following example from examples/plugins/dask_example.py demonstrates how to set up a cluster with 4 workers:

import typing

from flyteplugins.dask import Dask, Scheduler, WorkerGroup
from flyte import Resources, TaskEnvironment

# Define the cluster topology
dask_config = Dask(
    scheduler=Scheduler(),
    workers=WorkerGroup(number_of_workers=4),
)

# Create the environment
dask_env = TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    resources=Resources(cpu="1", memory="1Gi"),
)


@dask_env.task
async def hello_dask_nested(n: int = 3) -> typing.List[int]:
    from distributed import Client

    # The client automatically connects to the scheduler provisioned by Flyte
    client = Client()
    futures = client.map(lambda x: x + 1, range(n))  # one Dask task per input
    res = client.gather(futures)  # wait for all results, returned in input order
    return res
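`client.map` submits one task per input and returns a list of futures in input order. `client.gather` waits for them and returns the results in that same order, so `hello_dask_nested(3)` should return `[1, 2, 3]`.

When no cluster is available, the same submit-then-gather shape can be checked locally with `concurrent.futures`, as in the sketch below. This is only an analogy: a thread pool has none of Dask's scheduling, data locality, or serialization. `hello_local` is a hypothetical name.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List


def hello_local(n: int = 3) -> List[int]:
    """Local analogue of hello_dask_nested using a thread pool instead of Dask."""
    with ThreadPoolExecutor(max_workers=4) as pool:  # 4 mirrors number_of_workers
        futures = [pool.submit(lambda x: x + 1, i) for i in range(n)]  # like client.map
        return [f.result() for f in futures]  # like client.gather: input order kept


print(hello_local())
```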

Critical Requirements

  • Dask Installation: Any custom image provided to Scheduler or WorkerGroup must have dask[distributed] installed.
  • Environment Consistency: Use the same image for the scheduler, workers, and task runner wherever possible. Mismatched Python or library versions between them can cause deserialization (pickling) errors when functions and results move through the cluster.
  • Pathing: The code bundle plugins insert . into sys.path, so the bundle is expected to be unpacked into the current working directory. Ensure that directory is writable on the scheduler and worker pods.
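To act on the consistency requirement at runtime, `distributed` provides `Client.get_versions(check=True)`. It compares the package versions reported by the client, scheduler, and workers, and raises when they do not match.

The stdlib sketch below illustrates the same comparison on plain dictionaries so it runs without a cluster. `find_mismatches` is a hypothetical helper, and the sample version data is invented for illustration.

```python
from typing import Dict, List, Tuple

NodeVersions = Dict[str, Dict[str, str]]  # node name -> {package: version}


def find_mismatches(nodes: NodeVersions) -> List[Tuple[str, Dict[str, str]]]:
    """Return (package, {node: version}) for every package whose versions differ."""
    packages = sorted({pkg for versions in nodes.values() for pkg in versions})
    mismatches = []
    for pkg in packages:
        seen = {node: versions.get(pkg, "missing") for node, versions in nodes.items()}
        if len(set(seen.values())) > 1:
            mismatches.append((pkg, seen))
    return mismatches


# Hypothetical sample data: the worker image ships a different dask release
nodes = {
    "client": {"python": "3.11", "dask": "2.0.0"},
    "scheduler": {"python": "3.11", "dask": "2.0.0"},
    "worker-0": {"python": "3.11", "dask": "2.1.0"},
}
for pkg, seen in find_mismatches(nodes):
    print(pkg, seen)
```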