
High-Throughput Inference with Dynamic Batching

High-throughput inference, particularly for Large Language Models (LLMs), often requires aggregating many small, concurrent requests into larger batches to saturate GPU compute. The flyte.extras._dynamic_batcher module provides this through DynamicBatcher and its LLM-specialized subclass, TokenBatcher.

These tools allow you to decouple request submission from batch execution, automatically assembling batches based on cost budgets (like token counts) and time windows.

Core Architecture

The DynamicBatcher operates using two internal asynchronous loops that coordinate via asyncio.Queue:

  1. Aggregation Loop: Drains the submission queue and assembles batches, respecting target_batch_cost, max_batch_size, and batch_timeout_s. If a batch does not reach the target cost within the timeout, it is dispatched anyway so that latency stays bounded.
  2. Processing Loop: Pulls assembled batches and executes the user-provided process_fn. It resolves the asyncio.Future associated with each individual record in the batch.
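The aggregation rules above can be sketched in plain asyncio. This is a simplified illustration, not the flyte implementation: the Item class, the None shutdown sentinel, and the choice to open the timeout window when the first record of a batch arrives are all assumptions of the sketch. A batch closes as soon as its accumulated cost reaches target_batch_cost or it holds max_batch_size records, whichever comes first, so the last record added may push the cost past the target.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class Item:
    """Hypothetical stand-in for a submitted record and its estimated cost."""
    payload: str
    cost: int


async def aggregation_loop(
    queue: "asyncio.Queue[Optional[Item]]",
    batches: "asyncio.Queue[Optional[list[Item]]]",
    target_batch_cost: int,
    max_batch_size: int,
    batch_timeout_s: float,
) -> None:
    """Group items into batches; a None item is this sketch's shutdown sentinel."""
    loop = asyncio.get_running_loop()
    while True:
        first = await queue.get()  # wait (with no deadline) for a batch to start
        if first is None:
            await batches.put(None)
            return
        batch, cost = [first], first.cost
        deadline = loop.time() + batch_timeout_s  # window opens at the first record
        shutting_down = False
        while cost < target_batch_cost and len(batch) < max_batch_size:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break  # timeout: dispatch the partial batch to bound latency
            try:
                item = await asyncio.wait_for(queue.get(), timeout=remaining)
            except asyncio.TimeoutError:
                break
            if item is None:
                shutting_down = True
                break
            batch.append(item)
            cost += item.cost
        await batches.put(batch)  # a processing loop would consume from here
        if shutting_down:
            await batches.put(None)
            return


async def demo() -> list[list[str]]:
    queue: asyncio.Queue = asyncio.Queue()
    batches: asyncio.Queue = asyncio.Queue()
    for text in ["a", "b", "c", "d", "e"]:
        queue.put_nowait(Item(text, cost=40))
    queue.put_nowait(None)
    await aggregation_loop(queue, batches, target_batch_cost=100,
                           max_batch_size=8, batch_timeout_s=0.05)
    out = []
    while (batch := batches.get_nowait()) is not None:
        out.append([item.payload for item in batch])
    return out
```

With five records of cost 40 and a target of 100, asyncio.run(demo()) yields [['a', 'b', 'c'], ['d', 'e']]: the first batch closes once its cost reaches 120, and the remainder is flushed at shutdown.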

The Submission Lifecycle

When you call submit(record), the batcher wraps your record in an internal _Envelope, which pairs the data with an asyncio.Future.

```python
# From src/flyte/extras/_dynamic_batcher.py
class _Envelope(Generic[RecordT, ResultT]):
    """Internal wrapper pairing a record with its result future."""
    record: RecordT
    estimated_cost: int
    future: asyncio.Future[ResultT] = field(default_factory=_make_future)
```

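The submit method is a coroutine that provides natural backpressure: if the internal queue is full (its capacity is set by max_queue_size), it waits until space is available. Awaiting submit() returns the record's asyncio.Future, which resolves once the record's batch has been processed.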
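A minimal sketch of how a bounded asyncio.Queue produces this backpressure, assuming an envelope shaped like _Envelope above. Envelope, MiniSubmitter, and the default cost of 1 are hypothetical names for illustration, not part of flyte.extras.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any


def _make_future() -> "asyncio.Future[Any]":
    # Requires a running event loop, so envelopes are built inside coroutines.
    return asyncio.get_running_loop().create_future()


@dataclass
class Envelope:
    """Hypothetical analogue of _Envelope: a record paired with its result future."""
    record: Any
    estimated_cost: int
    future: "asyncio.Future[Any]" = field(default_factory=_make_future)


class MiniSubmitter:
    """Hypothetical submitter showing backpressure from a bounded queue."""

    def __init__(self, max_queue_size: int) -> None:
        self.queue: "asyncio.Queue[Envelope]" = asyncio.Queue(maxsize=max_queue_size)

    async def submit(self, record: Any, estimated_cost: int = 1) -> "asyncio.Future[Any]":
        envelope = Envelope(record, estimated_cost)
        await self.queue.put(envelope)  # suspends while the queue is full
        return envelope.future


async def demo() -> tuple[bool, int]:
    submitter = MiniSubmitter(max_queue_size=2)
    await submitter.submit("r1")
    await submitter.submit("r2")  # the queue is now full
    pending = asyncio.create_task(submitter.submit("r3"))
    await asyncio.sleep(0)  # let the third submit run until it blocks
    was_blocked = not pending.done()
    submitter.queue.get_nowait()  # a consumer takes one envelope, freeing a slot
    await pending  # the blocked submit now completes
    return was_blocked, submitter.queue.qsize()
```

The third submit stays suspended until a consumer frees a slot, which is what keeps producers from outrunning the GPU.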

Cost and Token Estimation

To maximize utilization without exceeding hardware limits (like VRAM), the batcher needs to know the "cost" of each record. This is handled via the CostEstimator protocol.

```python
# From src/flyte/extras/_dynamic_batcher.py
class CostEstimator(Protocol):
    def estimate_cost(self) -> int: ...
```

You can implement this protocol directly on your data classes. For LLM workloads, TokenBatcher similarly looks for an estimate_tokens() method.
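One way to express this lookup in plain Python is with runtime-checkable protocols. The lookup order (estimate_cost first, then estimate_tokens, then a fixed default) and the names TokenEstimator, resolve_cost, and default_cost are assumptions for illustration; the library's actual fallback for records without an estimator is not described here. Note that isinstance checks against a runtime-checkable protocol only verify that the method exists, not its signature.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class CostEstimator(Protocol):
    def estimate_cost(self) -> int: ...


@runtime_checkable
class TokenEstimator(Protocol):
    # Hypothetical protocol for the estimate_tokens() convention.
    def estimate_tokens(self) -> int: ...


def resolve_cost(record: object, default_cost: int = 1) -> int:
    """Hypothetical lookup: estimate_cost, then estimate_tokens, then a default."""
    if isinstance(record, CostEstimator):
        return record.estimate_cost()
    if isinstance(record, TokenEstimator):
        return record.estimate_tokens()
    return default_cost


@dataclass
class Image:
    pixels: int

    def estimate_cost(self) -> int:
        return self.pixels // 1_000


@dataclass
class Chat:
    text: str

    def estimate_tokens(self) -> int:
        return len(self.text) // 4 + 1
```

Here resolve_cost(Image(50_000)) gives 50, resolve_cost(Chat("x" * 40)) gives 11, and a plain string falls back to the default of 1.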

Example: Custom Record with Estimation

In examples/ml/batch_inference_saturate.py, a Prompt class is defined to estimate tokens based on character length:

```python
from dataclasses import dataclass


@dataclass
class Prompt:
    task_id: str
    index: int
    text: str

    def estimate_tokens(self) -> int:
        """Rough token estimate (~4 chars per token)."""
        return len(self.text) // 4 + 1
```
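To see what this heuristic implies for batch composition, the sketch below computes how many equally sized prompts reach a token budget. It assumes a batch closes once its estimated tokens reach the target and that max_batch_size is not the binding limit; prompts_per_batch is a hypothetical helper. A 2,000-character prompt estimates to 2000 // 4 + 1 = 501 tokens, so 64 such prompts reach a 32,000-token target (63 would total only 31,563).

```python
from dataclasses import dataclass


@dataclass
class Prompt:
    task_id: str
    index: int
    text: str

    def estimate_tokens(self) -> int:
        """Rough token estimate (~4 chars per token)."""
        return len(self.text) // 4 + 1


def prompts_per_batch(prompt: Prompt, target_batch_tokens: int) -> int:
    """Hypothetical helper: equally sized prompts needed to reach the budget."""
    per_prompt = prompt.estimate_tokens()
    return -(-target_batch_tokens // per_prompt)  # ceiling division


sample = Prompt(task_id="t", index=0, text="x" * 2_000)
```

The + 1 term also guarantees that even an empty prompt carries a nonzero cost of one token.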

Implementing Persistent Workers in Flyte

A particularly effective way to use DynamicBatcher in Flyte is the Persistent Worker pattern. By combining flyte.ReusePolicy with async_lru.alru_cache, you can keep a single batcher instance alive across multiple concurrent task executions in the same container, so records from all of those executions are batched together.

1. Define the Shared Batcher

Use alru_cache to ensure the model and batcher are initialized only once per process.

```python
from async_lru import alru_cache
from flyte.extras import TokenBatcher


@alru_cache(maxsize=1)
async def get_batcher() -> TokenBatcher[Prompt, str]:
    # inference_fn must accept list[RecordT] and return list[ResultT].
    # get_model_inference_fn() is defined elsewhere; it loads the model
    # and returns its batch inference callable.
    inference_fn = await get_model_inference_fn()

    batcher = TokenBatcher[Prompt, str](
        inference_fn=inference_fn,
        target_batch_tokens=32_000,
        batch_timeout_s=0.05,
    )
    await batcher.start()  # start the background loops before any submit()
    return batcher
```
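If you would rather not depend on async_lru, the same initialize-once behavior can be sketched with the standard library: an asyncio.Lock plus a double-checked flag ensures that concurrent task invocations share a single expensive setup call. AsyncOnce and build_batcher are hypothetical names; in practice build_batcher would load the model and start a TokenBatcher, as get_batcher above does.

```python
import asyncio
from typing import Awaitable, Callable, Generic, Optional, TypeVar, cast

T = TypeVar("T")


class AsyncOnce(Generic[T]):
    """Run an async factory at most once; concurrent callers share the result."""

    def __init__(self, factory: Callable[[], Awaitable[T]]) -> None:
        self._factory = factory
        self._lock = asyncio.Lock()
        self._value: Optional[T] = None
        self._ready = False

    async def get(self) -> T:
        if self._ready:
            return cast(T, self._value)  # fast path after initialization
        async with self._lock:
            if not self._ready:  # re-check: another caller may have finished first
                self._value = await self._factory()
                self._ready = True
        return cast(T, self._value)


async def demo() -> tuple[list[str], int]:
    calls = 0

    async def build_batcher() -> str:
        # Hypothetical stand-in for loading a model and starting a batcher.
        nonlocal calls
        calls += 1
        await asyncio.sleep(0.01)
        return f"batcher-{calls}"

    once = AsyncOnce(build_batcher)
    results = await asyncio.gather(*(once.get() for _ in range(10)))
    return list(results), calls
```

Ten concurrent callers all receive "batcher-1" and the factory runs once. If the factory raises, the flag stays unset and the next caller retries initialization.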

2. Configure the Task Environment

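Set a ReusePolicy with concurrency > 1 so that multiple Flyte task executions run concurrently in the same container, all feeding into the same batcher. With replicas=2 and concurrency=10 as below, up to 20 executions can run at once, 10 per container, with each container sharing its own batcher.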

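```python
import asyncio
import uuid

import flyte

gpu_env = flyte.TaskEnvironment(
    name="gpu_worker",
    resources=flyte.Resources(gpu="L4:1"),
    reusable=flyte.ReusePolicy(
        replicas=2,
        concurrency=10,  # 10 concurrent tasks per replica
    ),
)


@gpu_env.task
async def infer_batch(prompts: list[str]) -> list[str]:
    batcher = await get_batcher()
    # Prompt also requires task_id and index; any per-execution identifier
    # works for task_id, and a random UUID is used here.
    task_id = uuid.uuid4().hex
    futures = [
        await batcher.submit(Prompt(task_id=task_id, index=i, text=p))
        for i, p in enumerate(prompts)
    ]
    # gather returns results in submission order, aligned with prompts.
    return await asyncio.gather(*futures)
```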
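Because several executions share one batcher, a single execution's prompts may be split across batches that finish at different times. asyncio.gather still returns results in the order its arguments were passed, as this small stdlib sketch shows; the staggered call_later timings are an illustrative stand-in for batches completing out of order.

```python
import asyncio


async def demo() -> list[str]:
    loop = asyncio.get_running_loop()
    futures = [loop.create_future() for _ in range(3)]
    # Resolve the futures in reverse order, as separate batches might.
    for i in reversed(range(3)):
        loop.call_later(0.01 * (3 - i), futures[i].set_result, f"result-{i}")
    # gather preserves argument order regardless of completion order.
    return await asyncio.gather(*futures)
```

asyncio.run(demo()) returns ['result-0', 'result-1', 'result-2'] even though result-2 resolves first.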

Monitoring and Statistics

The BatchStats class provides real-time visibility into the batcher's performance. You can access these metrics via the batcher.stats property.

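| Attribute | Description |
|---|---|
| total_submitted | Total records received via submit(). |
| total_completed | Total records successfully processed. |
| total_batches | Number of batches dispatched to process_fn. |
| avg_batch_size | Average number of records per dispatched batch. |
| utilization | Fraction of time (0.0 to 1.0) spent processing rather than idle. |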

Example of logging utilization from a task:

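```python
import logging

logger = logging.getLogger(__name__)

# Inside a task, after `batcher = await get_batcher()`:
logger.info(
    "Batcher utilization: %.1f%% | Avg Batch Size: %.1f",
    batcher.stats.utilization * 100,
    batcher.stats.avg_batch_size,
)
```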
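For intuition about what these metrics measure, here is a hypothetical counter class. It assumes avg_batch_size is completed records divided by dispatched batches and that utilization is cumulative processing time divided by wall-clock time since the batcher started; BatchStats may define them differently (for example, in how failed batches are counted).

```python
import time
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class StatsSketch:
    """Hypothetical counters in the spirit of BatchStats, not its implementation."""

    total_submitted: int = 0
    total_completed: int = 0
    total_batches: int = 0
    busy_s: float = 0.0  # cumulative time spent inside process_fn
    started_at: float = field(default_factory=time.monotonic)

    def record_submit(self) -> None:
        self.total_submitted += 1

    def record_batch(self, size: int, duration_s: float) -> None:
        self.total_batches += 1
        self.total_completed += size
        self.busy_s += duration_s

    @property
    def avg_batch_size(self) -> float:
        # Assumption: completed records per dispatched batch.
        return self.total_completed / self.total_batches if self.total_batches else 0.0

    def utilization_at(self, now: Optional[float] = None) -> float:
        # Assumption: busy time over wall-clock time since start, capped at 1.0.
        now = time.monotonic() if now is None else now
        elapsed = now - self.started_at
        return min(1.0, self.busy_s / elapsed) if elapsed > 0 else 0.0

    @property
    def utilization(self) -> float:
        return self.utilization_at()


stats = StatsSketch(started_at=0.0)
stats.record_batch(size=8, duration_s=2.0)
stats.record_batch(size=4, duration_s=1.0)
```

Two batches of 8 and 4 records give an average batch size of 6.0; 3 seconds of processing over a 10-second window gives utilization_at(now=10.0) of 0.3. Low utilization with small batches usually suggests batch_timeout_s is too short for the arrival rate.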

Important Constraints

  • Order Preservation: The process_fn (or inference_fn) must return a list of results in the exact same order and length as the input batch. If the lengths mismatch, DynamicBatcher raises a ValueError and fails all futures in that batch.
  • Error Handling: If the process_fn raises an exception, that exception is propagated to every asyncio.Future in the current batch.
  • Lifecycle: You must call await batcher.start() before submitting records. Using the batcher as an asynchronous context manager (async with DynamicBatcher(...) as b:) handles start and stop automatically.
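The first two constraints can be sketched as a single processing step. process_batch is a hypothetical helper, not the library's code: it validates the result length, fans a single exception out to every future in the batch, and otherwise resolves futures by position.

```python
import asyncio
from typing import Any, Awaitable, Callable


async def process_batch(
    process_fn: Callable[[list[Any]], Awaitable[list[Any]]],
    batch: list[tuple[Any, "asyncio.Future[Any]"]],
) -> None:
    """Hypothetical processing step mirroring the constraints listed above."""
    records = [record for record, _ in batch]
    try:
        results = await process_fn(records)
        if len(results) != len(records):
            raise ValueError(
                f"process_fn returned {len(results)} results for {len(records)} records"
            )
    except Exception as exc:  # one failure fails every future in the batch
        for _, fut in batch:
            if not fut.done():
                fut.set_exception(exc)
        return
    for (_, fut), result in zip(batch, results):  # i-th result -> i-th record
        if not fut.done():
            fut.set_result(result)


async def demo() -> tuple[list[int], list[str]]:
    loop = asyncio.get_running_loop()

    async def double(xs: list[int]) -> list[int]:
        return [2 * x for x in xs]

    async def drops_one(xs: list[int]) -> list[int]:
        return [2 * x for x in xs[:-1]]  # buggy: one result short

    async def boom(xs: list[int]) -> list[int]:
        raise RuntimeError("simulated inference failure")

    ok = [(x, loop.create_future()) for x in (1, 2, 3)]
    await process_batch(double, ok)

    errors: list[str] = []
    for fn in (drops_one, boom):
        batch = [(x, loop.create_future()) for x in (1, 2)]
        await process_batch(fn, batch)
        outcomes = await asyncio.gather(*(f for _, f in batch), return_exceptions=True)
        errors += [type(o).__name__ for o in outcomes]
    return [f.result() for _, f in ok], errors
```

Because results are matched to futures by position, a process_fn that reorders its inputs internally (for example, sorting by length to reduce padding) must restore the original order before returning.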