Distributed Computing
Flyte provides native integrations for distributed computing frameworks, allowing you to scale workloads across Spark, Ray, and Dask clusters. These integrations are implemented through a plugin system where specific configuration classes are passed to a TaskEnvironment.
The Plugin Architecture
Distributed computing in this SDK is built on the AsyncFunctionTaskTemplate. When you define a TaskEnvironment with a plugin_config, Flyte uses the TaskPluginRegistry to map that configuration to a specific task implementation.
For example, when a Spark configuration is provided, the environment automatically uses the PysparkFunctionTask plugin to handle the setup and execution of the Spark job.
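The lookup from a configuration object to a task implementation can be pictured as a registry keyed by configuration type. The following is a minimal, self-contained sketch of that idea and not the SDK's actual code: `PluginRegistry`, `SparkConfig`, `DefaultTask`, and `SparkTask` are hypothetical stand-ins for TaskPluginRegistry, Spark, AsyncFunctionTaskTemplate, and PysparkFunctionTask.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Type


class DefaultTask:
    """Hypothetical stand-in for the base task template."""

    def __init__(self, plugin_config: Any = None):
        self.plugin_config = plugin_config


class SparkTask(DefaultTask):
    """Hypothetical stand-in for a Spark-specific task implementation."""


@dataclass
class SparkConfig:
    """Hypothetical stand-in for a Spark configuration class."""

    spark_conf: Dict[str, str] = field(default_factory=dict)


class PluginRegistry:
    """Maps a configuration type to the task class that handles it."""

    _plugins: Dict[type, Type[DefaultTask]] = {}

    @classmethod
    def register(cls, config_type: type, task_type: Type[DefaultTask]) -> None:
        cls._plugins[config_type] = task_type

    @classmethod
    def find(cls, config: Optional[Any]) -> Type[DefaultTask]:
        # Fall back to the default task when no plugin matches the config type.
        return cls._plugins.get(type(config), DefaultTask)


PluginRegistry.register(SparkConfig, SparkTask)

task_cls = PluginRegistry.find(SparkConfig({"spark.executor.instances": "2"}))
plain_cls = PluginRegistry.find(None)
```

With this layout, supporting a new framework only requires registering one more (configuration type, task class) pair; the environment code that performs the lookup stays unchanged.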
Apache Spark
Spark integration allows you to run PySpark code natively on Kubernetes. The SDK handles the creation of the SparkSession and ensures your code is distributed to all executors.
Configuration
The Spark class (defined in plugins/spark/src/flyteplugins/spark/task.py) is used to configure the Spark cluster:
from flyteplugins.spark.task import Spark

spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.instances": "2",
    },
    hadoop_conf={
        "fs.s3a.endpoint": "s3.amazonaws.com",
    },
)
How it Works
The PysparkFunctionTask plugin performs several automated steps:
- Session Injection: In the pre method, it builds a SparkSession and makes it available via the Flyte context.
- Code Distribution: If running in a cluster, it zips the local code bundle and uses sess.sparkContext.addPyFile() to distribute it to all Spark workers.
- Lifecycle Management: It automatically stops the SparkSession after the task completes if running in debug mode.
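The code distribution step can be illustrated without a Spark cluster. The sketch below is an assumption-laden illustration rather than the plugin's implementation: it zips a throwaway directory with `shutil.make_archive` and then imports from the archive by placing it on `sys.path`, which is roughly the effect `addPyFile()` has on each worker. The module name `bundle_demo_mod` and the archive name `code_bundle` are hypothetical.

```python
import importlib
import shutil
import sys
import tempfile
from pathlib import Path

# Create a throwaway "code bundle" containing a single module.
# The module name bundle_demo_mod is hypothetical.
src_dir = Path(tempfile.mkdtemp())
(src_dir / "bundle_demo_mod.py").write_text("def square(x):\n    return x * x\n")

# Zip the directory; make_archive returns the path of the archive it wrote.
out_dir = Path(tempfile.mkdtemp())
archive_path = shutil.make_archive(
    str(out_dir / "code_bundle"), "zip", root_dir=str(src_dir)
)

# On a real cluster the archive would be handed to Spark, for example:
#     sess.sparkContext.addPyFile(archive_path)
# Locally, putting the zip on sys.path has a similar effect for this process.
sys.path.insert(0, archive_path)
mod = importlib.import_module("bundle_demo_mod")
result = mod.square(7)
```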
Usage
You access the session within your task using flyte.ctx():
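import flyte

# spark_env is assumed to be a TaskEnvironment created with plugin_config=spark_conf
@spark_env.task
async def my_spark_task(partitions: int) -> float:
    # The session is automatically created by the plugin
    spark = flyte.ctx().data["spark_session"]
    df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "val"])
    # count() returns an int; convert it to match the declared return type
    return float(df.count())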
Ray
Ray integration enables distributed Python applications to run on a Ray cluster managed by KubeRay.
Configuration
Ray tasks are configured using RayJobConfig, which defines the head and worker node specifications:
from flyteplugins.ray.task import RayJobConfig, HeadNodeConfig, WorkerNodeConfig

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    runtime_env={"pip": ["numpy", "pandas"]},
)
How it Works
The RayFunctionTask (in plugins/ray/src/flyteplugins/ray/task.py) manages the Ray lifecycle:
- Initialization: The pre method calls ray.init().
- Environment Sync: It automatically sets the working_dir in the Ray runtime_env to the current working directory, ensuring your code is available on all nodes.
- KubeRay Integration: The custom_config method transforms the RayJobConfig into a RayJob protobuf message, which Flyte uses to provision the cluster.
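The environment sync step amounts to adding a `working_dir` entry to the user's `runtime_env` before Ray is initialized. The helper below is a hypothetical sketch of that merge, not the plugin's code; in particular, leaving a user-supplied `working_dir` untouched is an assumption of this example.

```python
import os
from typing import Any, Dict, Optional


def with_working_dir(runtime_env: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return a copy of runtime_env with working_dir set if it is not already."""
    merged = dict(runtime_env or {})
    # Assumption: an explicit user-provided working_dir is kept as-is.
    merged.setdefault("working_dir", os.getcwd())
    return merged


user_env = {"pip": ["numpy", "pandas"]}
effective_env = with_working_dir(user_env)

# In a real task the result would be passed to Ray, for example:
#     ray.init(runtime_env=effective_env)
```

Copying the dictionary instead of mutating it keeps the configuration object reusable across tasks that run from different directories.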
Usage
Once configured, you can use standard Ray primitives like .remote() and ray.get():
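import ray

# ray_env is assumed to be a TaskEnvironment created with plugin_config=ray_config
@ray_env.task
async def hello_ray(n: int) -> list[int]:
    @ray.remote
    def f(x):
        return x * x

    # Launch n remote calls, then block until all results are available
    futures = [f.remote(i) for i in range(n)]
    return ray.get(futures)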
Dask
Dask integration provides an ephemeral Dask cluster for parallelizing Python code using Dask DataFrames or Bags.
Configuration
The Dask configuration class (in plugins/dask/src/flyteplugins/dask/task.py) defines the scheduler and worker groups:
import flyte
from flyteplugins.dask.task import Dask, Scheduler, WorkerGroup

dask_config = Dask(
    scheduler=Scheduler(resources=flyte.Resources(cpu="1", memory="2Gi")),
    workers=WorkerGroup(number_of_workers=4, resources=flyte.Resources(cpu="2", memory="4Gi")),
)
How it Works
Dask tasks use custom plugins to synchronize code across the cluster:
- Code Sync: The DaskTask plugin registers DownloadCodeBundleWorkerPlugin and DownloadCodeBundleSchedulerPlugin.
- Dynamic Setup: These plugins run on every worker and the scheduler as they start, downloading the Flyte code bundle and adding it to the Python path.
- Client Connectivity: Inside the task, calling Client() without arguments automatically connects to the scheduler managed by Flyte.
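The dynamic setup step relies on a startup hook that each worker calls once. The sketch below imitates that hook with a plain class so it runs without Dask installed; `AddCodeToPathPlugin` and `dask_demo_mod` are hypothetical names, and the download step is replaced by writing a file locally. It is not the implementation of DownloadCodeBundleWorkerPlugin.

```python
import importlib
import sys
import tempfile
from pathlib import Path


class AddCodeToPathPlugin:
    """Hypothetical plugin that makes a code directory importable on a worker.

    It mirrors the setup(worker) hook of Dask's worker plugin interface but does
    not subclass it, so the sketch runs without Dask installed.
    """

    def __init__(self, code_dir: str):
        self.code_dir = code_dir

    def setup(self, worker=None) -> None:
        # A real plugin would first download the code bundle into code_dir.
        if self.code_dir not in sys.path:
            sys.path.insert(0, self.code_dir)


# Simulate a worker starting up with a freshly "downloaded" bundle.
code_dir = tempfile.mkdtemp()
Path(code_dir, "dask_demo_mod.py").write_text("VALUE = 42\n")

plugin = AddCodeToPathPlugin(code_dir)
plugin.setup()

# The import succeeds only because setup() extended sys.path.
demo = importlib.import_module("dask_demo_mod")
```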
Usage
from distributed import Client

# dask_env is assumed to be a TaskEnvironment created with plugin_config=dask_config
@dask_env.task
async def my_dask_task(n: int) -> list[int]:
    client = Client()  # Connects to the ephemeral cluster
    futures = client.map(lambda x: x + 1, range(n))
    return client.gather(futures)
Common Patterns and Requirements
Image Requirements
For all distributed frameworks, the Docker image used in the TaskEnvironment must have the corresponding framework installed:
- Spark: Requires pyspark.
- Ray: Requires ray.
- Dask: Requires dask[distributed].
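One way to catch a missing dependency early is to check for the framework's import names inside the image before submitting work. The snippet below is a small, optional sanity check using only the standard library; the `REQUIRED_MODULES` table and the `missing_modules` helper are hypothetical and not part of the SDK.

```python
import importlib.util
from typing import Dict, List

# Top-level import names for each framework's Python package.
REQUIRED_MODULES: Dict[str, List[str]] = {
    "spark": ["pyspark"],
    "ray": ["ray"],
    "dask": ["dask", "distributed"],
}


def missing_modules(framework: str) -> List[str]:
    """Return the required modules that cannot be found in this environment."""
    return [
        name
        for name in REQUIRED_MODULES[framework]
        if importlib.util.find_spec(name) is None
    ]


for fw in REQUIRED_MODULES:
    print(fw, "missing:", missing_modules(fw))
```

Running this inside the container (for example as the image's entrypoint during a smoke test) lists the modules that still need to be installed for each framework.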
Resource Overrides
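While defaults are set in the TaskEnvironment, you can override settings such as resources or the plugin configuration at call time using the .override() method, provided the environment is not marked as reusable.
from dataclasses import replace

# Override the number of Spark executors for one call, keeping the other Spark settings.
# The task is async, so the call is awaited from within another async task.
await my_spark_task.override(
    plugin_config=replace(spark_conf, spark_conf={**spark_conf.spark_conf, "spark.executor.instances": "10"})
)(partitions=100)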
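A detail worth keeping in mind with `dataclasses.replace` is that a dictionary field is swapped wholesale, not merged. The sketch below demonstrates the difference using `SparkLikeConfig`, a hypothetical stand-in dataclass; it assumes the real configuration class behaves like an ordinary dataclass with a `spark_conf` dictionary field.

```python
from dataclasses import dataclass, field, replace
from typing import Dict


@dataclass
class SparkLikeConfig:
    """Hypothetical stand-in for a plugin configuration dataclass."""

    spark_conf: Dict[str, str] = field(default_factory=dict)


base = SparkLikeConfig(
    spark_conf={"spark.driver.memory": "3000M", "spark.executor.instances": "2"}
)

# Replacing the whole dict drops every key that is not listed again.
replaced = replace(base, spark_conf={"spark.executor.instances": "10"})

# Merging with the existing dict keeps the other settings.
merged = replace(
    base, spark_conf={**base.spark_conf, "spark.executor.instances": "10"}
)
```

In both cases `base` is left unchanged, so the environment's default configuration is not affected by a per-call override.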