
Building Custom Async Connectors

In this tutorial, you will build a custom asynchronous connector to integrate an external batch processing service with Flyte. This allows Flyte to trigger long-running jobs on external platforms, poll for their status, and retrieve results once they complete.

By the end of this guide, you will have implemented a BatchJobConnector that simulates an external service and a corresponding BatchJobTask that users can call from their Flyte workflows.

Prerequisites

To follow this tutorial, you need the flyte-sdk package installed in your environment and a basic understanding of Flyte tasks and task templates.

Step 1: Define Job Metadata

Every asynchronous job needs a handle, such as a job ID, that identifies it on the external service. You define this handle by subclassing ResourceMeta. Flyte serializes and stores this metadata so that it can keep tracking the job across restarts.

Create a file named connector.py and add the following:

from dataclasses import dataclass
from flyte.connectors import ResourceMeta

@dataclass
class BatchJobMetadata(ResourceMeta):
    job_id: str
    created_at: float

The ResourceMeta base class provides encode and decode methods that use JSON serialization by default, which lets Flyte store your metadata in its database. As a consequence, every field must be JSON-serializable; str and float both qualify.
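To see what a JSON-based encode/decode round trip implies for your fields, here is a stdlib-only sketch. It is not flyte-sdk's implementation: BatchJobMetadataSketch is a hypothetical stand-in for BatchJobMetadata, and its encode/decode bodies are an assumption modeled on the description above.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class BatchJobMetadataSketch:
    # Hypothetical stand-in; in the tutorial the class subclasses ResourceMeta
    job_id: str
    created_at: float

    def encode(self) -> bytes:
        # Assumed JSON encoding: dataclass -> dict -> UTF-8 JSON bytes
        return json.dumps(asdict(self)).encode("utf-8")

    @classmethod
    def decode(cls, data: bytes) -> "BatchJobMetadataSketch":
        # Inverse of encode(): JSON bytes -> dict -> dataclass instance
        return cls(**json.loads(data.decode("utf-8")))
```

A field such as a datetime or a set would make json.dumps raise TypeError, which is why the metadata above sticks to a string and a float (a Unix timestamp).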

Step 2: Implement the Async Connector

The AsyncConnector is the core of the integration. You must implement three primary methods: create, get, and delete.

Add the connector implementation to connector.py:

import time
import uuid
from typing import Any, Dict, Optional
from flyteidl2.core.execution_pb2 import TaskExecution
from flyte.connectors import AsyncConnector, Resource

class BatchJobConnector(AsyncConnector):
    name = "Batch Job Connector"
    task_type_name = "batch_job"
    metadata_type = BatchJobMetadata

    async def create(
        self,
        task_template,
        output_prefix: str,
        inputs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> BatchJobMetadata:
        # Simulate submitting a job to an external API
        job_id = str(uuid.uuid4())[:8]
        print(f"Submitted batch job {job_id}")
        return BatchJobMetadata(job_id=job_id, created_at=time.time())

    async def get(self, resource_meta: BatchJobMetadata, **kwargs) -> Resource:
        # Simulate polling the external API
        elapsed = time.time() - resource_meta.created_at

        if elapsed < 10:  # Simulate a 10-second job
            return Resource(phase=TaskExecution.RUNNING, message="Job in progress")

        # Return SUCCEEDED phase and the job outputs
        return Resource(
            phase=TaskExecution.SUCCEEDED,
            message="Job completed successfully",
            outputs={"result": f"output-from-{resource_meta.job_id}"},
        )

    async def delete(self, resource_meta: BatchJobMetadata, **kwargs):
        # Clean up or cancel the job on the external service
        print(f"Cancelled job {resource_meta.job_id}")

Key Components of the Connector

  • create: Invoked when the Flyte task starts. It submits the job and returns an instance of your ResourceMeta subclass (here, BatchJobMetadata) containing the external job's identifier.
  • get: Invoked periodically by Flyte to check the job's status. It returns a Resource object containing the current phase (e.g., RUNNING, SUCCEEDED, FAILED) and, once the job succeeds, its outputs.
  • delete: Invoked if the Flyte task is aborted. It should be idempotent, so that cancelling a job that is already cancelled or finished does no harm.
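The get bullet mentions FAILED, but the simulated connector never returns it. A real connector typically translates the external service's own status strings into Flyte phases, and makes cancellation safe to repeat. The stdlib-only sketch below illustrates both ideas: Phase is a stand-in for the TaskExecution values, and the status vocabulary (QUEUED, DONE, ERROR, ...) and FakeBatchService are hypothetical.

```python
from enum import Enum
from typing import Dict


class Phase(Enum):
    # Stand-ins for TaskExecution.RUNNING / SUCCEEDED / FAILED
    RUNNING = "RUNNING"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"


# Hypothetical status vocabulary of the external batch service
_STATUS_TO_PHASE: Dict[str, Phase] = {
    "QUEUED": Phase.RUNNING,
    "RUNNING": Phase.RUNNING,
    "DONE": Phase.SUCCEEDED,
    "ERROR": Phase.FAILED,
    "CANCELLED": Phase.FAILED,
}


def to_phase(external_status: str) -> Phase:
    # Treat unknown statuses as FAILED instead of polling indefinitely
    return _STATUS_TO_PHASE.get(external_status.upper(), Phase.FAILED)


class FakeBatchService:
    """In-memory stand-in for the external service."""

    def __init__(self) -> None:
        self.jobs: Dict[str, str] = {}

    def submit(self, job_id: str) -> None:
        self.jobs[job_id] = "QUEUED"

    def cancel(self, job_id: str) -> bool:
        # Idempotent: unknown, finished, or already-cancelled jobs are left alone
        if self.jobs.get(job_id) not in ("QUEUED", "RUNNING"):
            return False
        self.jobs[job_id] = "CANCELLED"
        return True
```

A delete method built on cancel can call it unconditionally and ignore the return value. Mapping unknown statuses to FAILED is a design choice for this sketch; some services may warrant treating them as still running.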

Step 3: Register the Connector

For Flyte to discover your connector, you must register it with the ConnectorRegistry.

Add this to the bottom of connector.py:

from flyte.connectors import ConnectorRegistry

ConnectorRegistry.register(BatchJobConnector())
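Conceptually, a registry like this maps a task type to a connector instance. The toy class below is a hypothetical, stdlib-only illustration of that idea, not the ConnectorRegistry API; refusing duplicate registrations is a choice made for this sketch. It also shows why the connector's task_type_name ("batch_job") has to agree with the task type declared by the task class in Step 4.

```python
from typing import Any, Dict


class ToyConnectorRegistry:
    """Hypothetical registry keyed by task type; not flyte-sdk's ConnectorRegistry."""

    def __init__(self) -> None:
        self._by_type: Dict[str, Any] = {}

    def register(self, connector: Any) -> None:
        task_type = connector.task_type_name
        if task_type in self._by_type:
            # Refuse silent overwrites so two connectors can't claim one task type
            raise ValueError(f"duplicate connector for task type {task_type!r}")
        self._by_type[task_type] = connector

    def get_connector(self, task_type: str) -> Any:
        # Lookup happens by task type, so a typo here means "no connector found"
        try:
            return self._by_type[task_type]
        except KeyError:
            raise KeyError(f"no connector registered for task type {task_type!r}") from None
```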

In a production environment, you can also use the flyte.connectors entry point in your setup.py or pyproject.toml to automatically register connectors when your package is installed.

Step 4: Create the User-Facing Task

To make the connector usable from Python code, create a task class that inherits from both AsyncConnectorExecutorMixin and TaskTemplate. The mixin enables local execution by delegating calls to your registered connector.

Create a file named task.py:

from typing import Any, Dict, Optional, Type
from flyte.connectors import AsyncConnectorExecutorMixin
from flyte.extend import TaskTemplate
from flyte.models import NativeInterface, SerializationContext

class BatchJobTask(AsyncConnectorExecutorMixin, TaskTemplate):
    # Should match BatchJobConnector.task_type_name so the connector can be found
    _TASK_TYPE = "batch_job"

    def __init__(
        self,
        name: str,
        inputs: Optional[Dict[str, Type]] = None,
        outputs: Optional[Dict[str, Type]] = None,
        **kwargs,
    ):
        super().__init__(
            name=name,
            interface=NativeInterface(
                {k: (v, None) for k, v in inputs.items()} if inputs else {},
                outputs or {},
            ),
            task_type=self._TASK_TYPE,
            image=None,
            **kwargs,
        )

    def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
        # You can pass configuration to the connector here
        return {"timeout": 300}
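custom_config returns {"timeout": 300}, but the tutorial does not show the connector using that value. The hypothetical helpers below sketch one way to apply it: resolve the timeout once in create and check it in get. They assume the connector has already obtained the dict returned by custom_config; where flyte-sdk exposes that dict (for example, on task_template) is not covered here, so check the SDK before relying on a particular attribute.

```python
import time
from typing import Any, Dict, Optional

# Hypothetical fallback used when custom_config() omits "timeout"
DEFAULT_TIMEOUT_S = 300.0


def resolve_timeout(custom: Optional[Dict[str, Any]]) -> float:
    """Read the timeout (seconds) from the dict returned by custom_config()."""
    if not custom:
        return DEFAULT_TIMEOUT_S
    return float(custom.get("timeout", DEFAULT_TIMEOUT_S))


def is_timed_out(created_at: float, timeout_s: float, now: Optional[float] = None) -> bool:
    """True once more than timeout_s seconds have passed since create() ran."""
    now = time.time() if now is None else now
    return now - created_at > timeout_s
```

Because get only receives the metadata, one option (again an assumption) is to store the resolved value in a hypothetical `timeout_s` field on BatchJobMetadata during create, then have get return a FAILED Resource when `is_timed_out(resource_meta.created_at, resource_meta.timeout_s)` is true.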

Step 5: Local Execution and Testing

Because BatchJobTask uses AsyncConnectorExecutorMixin, you can run it locally. The mixin looks up the connector registered for the task's type (batch_job), calls create, and then polls get until the job reaches a terminal phase.

import asyncio
from task import BatchJobTask
import connector  # noqa: F401  Importing this module registers BatchJobConnector

async def run_locally():
    task = BatchJobTask(
        name="my_batch_job",
        inputs={"val": int},
        outputs={"result": str},
    )

    # Calls the connector's create, then polls get until the simulated
    # 10-second job succeeds
    result = await task.execute(val=42)
    print(f"Task result: {result}")

if __name__ == "__main__":
    asyncio.run(run_locally())
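To make the create/get cycle concrete without Flyte installed, here is a rough, stdlib-only approximation of such a loop. It is not how AsyncConnectorExecutorMixin is implemented: the phase strings stand in for TaskExecution values, ToyBatchConnector finishes after a fixed number of polls rather than after 10 seconds, and the timeout and delete-on-abort handling are assumptions added for illustration.

```python
import asyncio
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

# Stand-ins for TaskExecution phases
RUNNING, SUCCEEDED, FAILED = "RUNNING", "SUCCEEDED", "FAILED"


@dataclass
class ToyResource:
    phase: str
    message: str = ""
    outputs: Dict[str, Any] = field(default_factory=dict)


class ToyBatchConnector:
    """Hypothetical connector that finishes after a fixed number of polls."""

    def __init__(self, polls_until_done: int = 3) -> None:
        self.polls_until_done = polls_until_done
        self.polls = 0
        self.deleted = False

    async def create(self, inputs: Dict[str, Any]) -> str:
        # Return a job ID, like BatchJobConnector.create does via BatchJobMetadata
        return str(uuid.uuid4())[:8]

    async def get(self, job_id: str) -> ToyResource:
        self.polls += 1
        if self.polls < self.polls_until_done:
            return ToyResource(phase=RUNNING, message="Job in progress")
        return ToyResource(phase=SUCCEEDED, outputs={"result": f"output-from-{job_id}"})

    async def delete(self, job_id: str) -> None:
        self.deleted = True


async def run_create_get_loop(
    connector: ToyBatchConnector,
    inputs: Dict[str, Any],
    poll_interval: float = 0.0,
    timeout: Optional[float] = None,
) -> Dict[str, Any]:
    job_id = await connector.create(inputs)
    deadline = None if timeout is None else time.monotonic() + timeout
    try:
        while True:
            resource = await connector.get(job_id)
            if resource.phase == SUCCEEDED:
                return resource.outputs
            if resource.phase == FAILED:
                raise RuntimeError(resource.message or f"job {job_id} failed")
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(f"job {job_id} did not finish within {timeout}s")
            await asyncio.sleep(poll_interval)
    except (asyncio.CancelledError, TimeoutError):
        # Mirror the delete-on-abort contract: cancel the external job, then re-raise
        await connector.delete(job_id)
        raise
```

With a real service you would poll every few seconds rather than with poll_interval=0.0, which is used here only so the sketch runs instantly.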

Advanced: Handling Secrets

If your connector needs credentials, such as an API key, add ConnectorSecretsMixin to its base classes. The mixin passes secrets from the Flyte environment to your connector methods.

from typing import Dict
from flyte.connectors import AsyncConnector, ConnectorSecretsMixin, ResourceMeta

class SecureBatchJobConnector(ConnectorSecretsMixin, AsyncConnector):
    name = "Secure Batch Job"
    task_type_name = "secure_batch"

    def __init__(self, secrets: Dict[str, str]):
        super().__init__(secrets=secrets)

    async def create(self, task_template, output_prefix, **kwargs) -> ResourceMeta:
        # Secrets arrive in **kwargs under the keys defined in self._secrets
        # (the mapping passed to __init__), e.g. "my_api_key"
        api_key = kwargs.get("my_api_key")
        # Use api_key to authenticate with the external service
        ...

    # get() and delete() are omitted here; implement them as in BatchJobConnector

When you instantiate the connector, pass a mapping from each secret key that your methods read from kwargs (here, my_api_key) to the name of the environment variable that holds its value:

from flyte.connectors import ConnectorRegistry

connector = SecureBatchJobConnector(secrets={"my_api_key": "EXTERNAL_SERVICE_API_KEY"})
ConnectorRegistry.register(connector)
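To make the key-to-environment-variable mapping concrete, here is a hypothetical helper that resolves a mapping like the one above into the keyword arguments that create reads. Whether ConnectorSecretsMixin resolves secrets this way, or from a Flyte secret store rather than plain environment variables, is an assumption; check the flyte-sdk documentation.

```python
import os
from typing import Dict, Mapping, Optional


def resolve_secrets(
    secret_env_map: Mapping[str, str],
    environ: Optional[Mapping[str, str]] = None,
) -> Dict[str, str]:
    """Turn {kwarg_key: ENV_VAR_NAME} into {kwarg_key: secret_value}."""
    env = os.environ if environ is None else environ
    missing = sorted(var for var in secret_env_map.values() if var not in env)
    if missing:
        # Fail fast instead of letting kwargs.get() hand None to create()
        raise KeyError(f"missing secret environment variables: {', '.join(missing)}")
    return {key: env[var] for key, var in secret_env_map.items()}
```

For example, `resolve_secrets({"my_api_key": "EXTERNAL_SERVICE_API_KEY"})` returns a dict whose "my_api_key" entry holds the variable's value. Raising early is a design choice: with `kwargs.get("my_api_key")` alone, a missing secret only surfaces later as an authentication failure.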