Getting Started with PySpark Tasks
Learn how to define and execute Spark tasks natively on Kubernetes using the Flyte Spark plugin. By the end of this guide, you will have a working PySpark task that processes data using a managed SparkSession.
Prerequisites
To follow this tutorial, ensure you have the following:
- The Flyte SDK and Spark plugin installed.
- A base image with Spark installed (e.g., apache/spark-py:v3.4.0).
- Access to a Flyte cluster with the Spark operator enabled.
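Before moving on, you can confirm that the Python side of these prerequisites is in place. The sketch below uses a hypothetical helper that is not part of Flyte. It assumes the SDK, the plugin, and PySpark are importable as `flyte`, `flyteplugins.spark`, and `pyspark` (the module names used later in this guide). It only checks the local environment, not the cluster or the Spark operator.

```python
import importlib.util
from typing import Iterable, List


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the module names from `names` that cannot be found locally."""
    missing = []
    for name in names:
        try:
            # find_spec imports the parent packages of dotted names, so a
            # missing or broken parent raises ImportError instead of returning None
            found = importlib.util.find_spec(name) is not None
        except ImportError:
            found = False
        if not found:
            missing.append(name)
    return missing


if __name__ == "__main__":
    # Module names taken from the imports used in this guide
    print(missing_modules(["flyte", "flyteplugins.spark", "pyspark"]) or "all modules found")
```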
Step 1: Define Spark Configuration
The first step is to define the infrastructure requirements for your Spark job. You use the Spark class from flyteplugins.spark.task to specify memory, CPU, and instance counts for both the driver and executors.
from flyteplugins.spark.task import Spark
spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "1000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
    },
)
The Spark class acts as a configuration container. When the task is serialized, these settings are converted into a SparkJob specification for the Kubernetes Spark operator.
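Because every value in spark_conf is a string, it can help to tally what the job asks for before submitting it. The sketch below uses hypothetical helpers (not part of flyteplugins) to add up the driver and executor settings from Step 1. It assumes memory strings end in m, g, or t, which Spark reads as MiB, GiB, and TiB. It counts JVM heap only: in cluster mode Spark also requests per-pod memory overhead on top of these figures.

```python
from typing import Dict, Tuple

# MiB per unit suffix for Spark memory strings such as "1000M" or "2g"
_MIB_PER_UNIT = {"m": 1, "g": 1024, "t": 1024 * 1024}


def parse_memory_mib(value: str) -> int:
    """Convert a Spark memory string ('1000M', '2g', ...) to MiB."""
    value = value.strip().lower()
    unit = value[-1:]
    if unit not in _MIB_PER_UNIT:
        raise ValueError(f"unsupported memory string: {value!r}")
    return int(value[:-1]) * _MIB_PER_UNIT[unit]


def requested_totals(conf: Dict[str, str]) -> Tuple[int, int]:
    """Return (heap MiB, cores) requested by the driver plus all executors."""
    executors = int(conf["spark.executor.instances"])
    memory = parse_memory_mib(conf["spark.driver.memory"]) + executors * parse_memory_mib(
        conf["spark.executor.memory"]
    )
    cores = int(conf["spark.driver.cores"]) + executors * int(conf["spark.executor.cores"])
    return memory, cores


# The same values as the Step 1 configuration
step1_conf = {
    "spark.driver.memory": "1000M",
    "spark.executor.memory": "1000M",
    "spark.executor.cores": "1",
    "spark.executor.instances": "2",
    "spark.driver.cores": "1",
}
print(requested_totals(step1_conf))  # (3000, 3)
```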
Step 2: Create a Task Environment
Next, you must associate this configuration with a TaskEnvironment. This environment acts as a template for your tasks, ensuring they receive the correct resources and plugin settings.
import flyte
spark_env = flyte.TaskEnvironment(
    name="spark_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("2000Mi", "3000Mi")),
    plugin_config=spark_conf,
)
By passing spark_conf to plugin_config, you tell Flyte that any task decorated with this environment should be treated as a Spark task.
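To make the template idea concrete, here is a toy model of an environment whose decorator stamps its settings onto every function it wraps. ToyEnvironment and the task_type and plugin_config attributes are hypothetical and only illustrate the pattern; this is not how flyte.TaskEnvironment is implemented.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional


@dataclass
class ToyEnvironment:
    """Hypothetical stand-in for a task environment (illustration only)."""
    name: str
    plugin_config: Optional[Dict[str, str]] = None
    task_names: List[str] = field(default_factory=list)

    def task(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        # Register the function and attach this environment's shared settings to it
        self.task_names.append(fn.__name__)
        fn.task_type = "spark" if self.plugin_config is not None else "python"
        fn.plugin_config = self.plugin_config
        return fn


spark_like_env = ToyEnvironment("spark_env", plugin_config={"spark.executor.instances": "2"})


@spark_like_env.task
def count_rows() -> int:
    return 0


print(count_rows.task_type, spark_like_env.task_names)  # spark ['count_rows']
```

In this toy model, every function decorated through the same environment receives the same settings, so editing the configuration in one place changes all of its tasks.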
Step 3: Define the Spark Task
Now you can define your processing logic. Use the @spark_env.task decorator. Inside the task, you do not create your own SparkSession; instead, you retrieve the one automatically managed by Flyte from the context.
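import random
import flyte
@spark_env.task
async def calculate_pi(partitions: int = 10) -> float:
    # Retrieve the SparkSession that Flyte created for this task
    spark = flyte.ctx().data["spark_session"]
    # Draw 100,000 random points per partition so the estimate is meaningful
    n = 100_000 * partitions
    def f(_):
        # Sample a point in the square [-1, 1] x [-1, 1]; score 1 if it lies inside the unit circle
        x = random.random() * 2 - 1
        y = random.random() * 2 - 1
        return 1 if x ** 2 + y ** 2 <= 1 else 0
    # Spread the samples across `partitions` slices, score each one, and add up the hits
    count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(lambda a, b: a + b)
    # The circle covers pi/4 of the square, so 4 * (hits / samples) approximates pi
    return 4.0 * count / n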
The PysparkFunctionTask plugin handles the lifecycle of the SparkSession. When running in a cluster, it also automatically zips your code bundle and adds it to the SparkContext via addPyFile(), ensuring your dependencies are available on all executors.
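The task estimates π with a Monte Carlo method: points drawn uniformly from the square [-1, 1] × [-1, 1] land inside the unit circle with probability π/4, so four times the hit rate approximates π. The sketch below runs the same map and reduce steps in plain Python, using a seeded random.Random instance so results are repeatable. This can be handy for checking the logic before submitting it to a cluster. The function names are illustrative and are not part of Flyte or PySpark.

```python
import random
from functools import reduce
from operator import add


def in_unit_circle(rng: random.Random) -> int:
    """Return 1 if a random point in [-1, 1] x [-1, 1] falls inside the unit circle."""
    x = rng.random() * 2 - 1
    y = rng.random() * 2 - 1
    return 1 if x ** 2 + y ** 2 <= 1 else 0


def estimate_pi(num_samples: int, seed: int = 0) -> float:
    rng = random.Random(seed)
    # "map" each sample index to 0 or 1, then "reduce" the hits with addition
    hits = reduce(add, map(lambda _: in_unit_circle(rng), range(num_samples)), 0)
    return 4.0 * hits / num_samples


print(estimate_pi(100_000))  # close to 3.14; exact digits depend on the seed
```

On a cluster, each executor process has its own random state, so the Spark version does not reproduce exact numbers the way this seeded sketch does.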
Step 4: Working with DataFrames
Flyte natively supports pyspark.sql.DataFrame as input and output types. You can use type annotations to pass data between Spark tasks and other Flyte tasks.
import pyspark.sql
from typing import Annotated
# Define column metadata if needed
columns = [("name", str), ("age", int)]
@spark_env.task
async def sum_of_all_ages(sd: Annotated[pyspark.sql.DataFrame, columns]) -> int:
    # The DataFrame 'sd' is automatically loaded into the SparkSession
    total_age = sd.groupBy().sum("age").collect()[0][0]
    return total_age
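The columns list pairs each column name with a Python type. As a plain-Python illustration of what that metadata describes and what sum_of_all_ages computes, the hypothetical helpers below check rows against such a schema and then sum the age column. They do not use Flyte's type engine or Spark. The sample rows reuse the two records from the complete script at the end of this guide.

```python
from typing import Any, List, Sequence, Tuple

Schema = List[Tuple[str, type]]

columns: Schema = [("name", str), ("age", int)]


def validate_rows(rows: Sequence[Tuple[Any, ...]], schema: Schema) -> None:
    """Raise TypeError if a row's width or value types do not match the schema."""
    for row in rows:
        if len(row) != len(schema):
            raise TypeError(f"expected {len(schema)} columns, got {len(row)}: {row!r}")
        for value, (name, expected) in zip(row, schema):
            if not isinstance(value, expected):
                raise TypeError(f"column {name!r} expects {expected.__name__}, got {value!r}")


def sum_column(rows: Sequence[Tuple[Any, ...]], schema: Schema, column: str) -> int:
    """Plain-Python counterpart of df.groupBy().sum(column) for a single group."""
    index = [name for name, _ in schema].index(column)
    return sum(row[index] for row in rows)


rows = [("Alice", 34), ("Bob", 45)]
validate_rows(rows, columns)
print(sum_column(rows, columns, "age"))  # 79
```

One difference to keep in mind: a global Spark aggregation over an empty DataFrame returns a single row containing null, whereas Python's sum of an empty sequence returns 0.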
Step 5: Advanced Configuration and Overrides
If you need to change the Spark configuration for a specific execution without modifying the environment, you can use the .override() method.
# Override the number of executor instances for a specific call
custom_spark_conf = Spark(spark_conf={"spark.executor.instances": "10"})
# Execute with the new configuration; `await` must run inside an async function, such as another task
result = await calculate_pi.override(plugin_config=custom_spark_conf)(partitions=100)
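If the override should change only some keys while keeping the rest of your Step 1 settings, you can build the full dictionary yourself before wrapping it in Spark. The sketch below shows that merge on plain dictionaries. merge_spark_conf is a hypothetical helper, base_conf copies the values from Step 1, and the sketch makes no assumption about how .override() itself combines configurations.

```python
from typing import Dict


def merge_spark_conf(base: Dict[str, str], overrides: Dict[str, str]) -> Dict[str, str]:
    """Return a new dict in which `overrides` take precedence over `base`."""
    # Later keys win in a dict display, and neither input is mutated
    return {**base, **overrides}


base_conf = {
    "spark.driver.memory": "1000M",
    "spark.executor.memory": "1000M",
    "spark.executor.cores": "1",
    "spark.executor.instances": "2",
    "spark.driver.cores": "1",
}
merged = merge_spark_conf(base_conf, {"spark.executor.instances": "10"})
print(merged["spark.executor.instances"], merged["spark.executor.memory"])  # 10 1000M
```

The merged dictionary can then be passed as Spark(spark_conf=merged) to .override().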
Additionally, you can use PodTemplate within the Spark configuration to customize the Kubernetes pods used for the driver and executors, such as adding custom labels or environment variables.
Complete Result
Your final script should look like this:
import flyte
from flyteplugins.spark.task import Spark
# 1. Configuration
spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "1000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
    },
)
# 2. Environment
spark_env = flyte.TaskEnvironment(
    name="spark_env",
    plugin_config=spark_conf,
)
# 3. Task
@spark_env.task
async def process_spark_data(n: int) -> int:
    spark = flyte.ctx().data["spark_session"]
    df = spark.createDataFrame([("Alice", 34), ("Bob", 45)], ["name", "age"])
    return df.count()
# 4. Execution: submit the task to the cluster set in your Flyte config file
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(process_spark_data, n=10)
    print(run.url)
Next Steps
- Explore using hadoop_conf in the Spark class for custom S3/HDFS settings.
- Learn how to use PodTemplate to attach volumes to your Spark executors.
- Check the examples/plugins/pandera/pyspark_sql_schema.py file for advanced schema validation with Spark.