High-Throughput Inference with Dynamic Batching
High-throughput inference, particularly for Large Language Models (LLMs), often requires aggregating many small, concurrent requests into larger batches to saturate GPU compute. The flyte.extras._dynamic_batcher module supports this through DynamicBatcher and its LLM-specialized subclass, TokenBatcher.
These tools allow you to decouple request submission from batch execution, automatically assembling batches based on cost budgets (like token counts) and time windows.
Core Architecture
The DynamicBatcher operates using two internal asynchronous loops that coordinate via asyncio.Queue:
- Aggregation Loop: Drains the submission queue and assembles batches. It respects target_batch_cost, max_batch_size, and batch_timeout_s. If a batch doesn't reach the target cost within the timeout, it is dispatched anyway to maintain latency.
- Processing Loop: Pulls assembled batches and executes the user-provided process_fn. It resolves the asyncio.Future associated with each individual record in the batch.
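The aggregation rule described above can be sketched with plain asyncio. This is a minimal illustration and not the module's actual implementation: the function name assemble_batch and the (record, cost) tuple layout are assumptions made for the example, while the three parameters mirror the limits named in the text.

```python
import asyncio
from typing import List, Tuple


async def assemble_batch(
    queue: "asyncio.Queue[Tuple[str, int]]",
    target_batch_cost: int,
    max_batch_size: int,
    batch_timeout_s: float,
) -> List[str]:
    """Collect records until the cost budget, size cap, or time window is hit."""
    loop = asyncio.get_running_loop()
    # Block until the first record arrives; the time window starts here.
    record, cost = await queue.get()
    batch, total_cost = [record], cost
    deadline = loop.time() + batch_timeout_s

    while total_cost < target_batch_cost and len(batch) < max_batch_size:
        remaining = deadline - loop.time()
        if remaining <= 0:
            break  # Window elapsed: dispatch the partial batch.
        try:
            record, cost = await asyncio.wait_for(queue.get(), timeout=remaining)
        except asyncio.TimeoutError:
            break  # No further record arrived within the window.
        batch.append(record)
        total_cost += cost
    return batch


async def demo() -> List[List[str]]:
    queue: asyncio.Queue = asyncio.Queue()
    for text in ["aaaa", "bbbb", "cccc", "dd"]:
        queue.put_nowait((text, len(text)))  # cost = character count
    limits = dict(target_batch_cost=8, max_batch_size=16, batch_timeout_s=0.05)
    first = await assemble_batch(queue, **limits)   # closes on the cost budget
    second = await assemble_batch(queue, **limits)  # closes on the timeout
    return [first, second]


batches = asyncio.run(demo())
```

In the demo, the first batch closes as soon as its accumulated cost reaches the budget, while the second is dispatched under budget once the timeout expires.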
The Submission Lifecycle
When you call submit(record), the batcher wraps your record in an internal _Envelope, which pairs the data with an asyncio.Future.
# From src/flyte/extras/_dynamic_batcher.py
class _Envelope(Generic[RecordT, ResultT]):
    """Internal wrapper pairing a record with its result future."""
    record: RecordT
    estimated_cost: int
    future: asyncio.Future[ResultT] = field(default_factory=_make_future)
The submit method is a coroutine that provides natural backpressure; if the internal queue is full (defined by max_queue_size), it will await until space is available.
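The backpressure behaviour can be illustrated with a bounded asyncio.Queue on its own. This is only a sketch of the general mechanism, assuming the batcher's internal queue behaves like a standard bounded queue; the producer and consumer coroutines are hypothetical stand-ins for callers of submit and for the aggregation loop.

```python
import asyncio
from typing import List


async def demo() -> List[str]:
    events: List[str] = []
    # maxsize plays the role of max_queue_size in this sketch.
    queue: asyncio.Queue = asyncio.Queue(maxsize=2)

    async def producer() -> None:
        for i in range(4):
            await queue.put(i)  # Suspends while the queue is full.
            events.append(f"put {i}")

    async def consumer() -> None:
        for _ in range(4):
            await asyncio.sleep(0.01)  # Simulate slow batch processing.
            item = await queue.get()
            events.append(f"got {item}")

    await asyncio.gather(producer(), consumer())
    return events


events = asyncio.run(demo())
```

The first two puts complete immediately; the third cannot complete until the consumer has removed an item, which is the point at which a caller of submit would be held back.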
Cost and Token Estimation
To maximize utilization without exceeding hardware limits (like VRAM), the batcher needs to know the "cost" of each record. This is handled via the CostEstimator protocol.
# From src/flyte/extras/_dynamic_batcher.py
class CostEstimator(Protocol):
    def estimate_cost(self) -> int: ...
You can implement this protocol directly on your data classes. For LLM workloads, TokenBatcher similarly looks for an estimate_tokens() method.
Example: Custom Record with Estimation
In examples/ml/batch_inference_saturate.py, a Prompt class is defined to estimate tokens based on character length:
@dataclass
class Prompt:
    task_id: str
    index: int
    text: str

    def estimate_tokens(self) -> int:
        """Rough token estimate (~4 chars per token)."""
        return len(self.text) // 4 + 1
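To make the estimate concrete, the sketch below applies the same heuristic to two prompts and shows one way a batcher might look up a record's cost. The helper resolve_tokens and its default of 1 are assumptions for illustration, not part of the flyte API; the Prompt class is repeated so the block runs on its own.

```python
from dataclasses import dataclass


@dataclass
class Prompt:
    task_id: str
    index: int
    text: str

    def estimate_tokens(self) -> int:
        """Rough token estimate (~4 chars per token)."""
        return len(self.text) // 4 + 1


def resolve_tokens(record: object, default: int = 1) -> int:
    """Hypothetical lookup: use estimate_tokens() if present, else a default."""
    estimator = getattr(record, "estimate_tokens", None)
    return estimator() if callable(estimator) else default


short_prompt = Prompt(task_id="t0", index=0, text="x" * 40)
long_prompt = Prompt(task_id="t0", index=1, text="x" * 4000)

costs = [
    resolve_tokens(short_prompt),    # 40 // 4 + 1
    resolve_tokens(long_prompt),     # 4000 // 4 + 1
    resolve_tokens("plain string"),  # no estimate_tokens(), falls back to default
]

# Roughly how many long prompts fit under a 32,000-token budget.
prompts_per_batch = 32_000 // resolve_tokens(long_prompt)
```

Under this heuristic a 4,000-character prompt costs 1,001 tokens, so about 31 such prompts fit under a 32,000-token budget.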
Implementing Persistent Workers in Flyte
A particularly effective way to use DynamicBatcher in Flyte is the Persistent Worker pattern. By combining flyte.ReusePolicy with async_lru.alru_cache, you can maintain a single batcher instance across multiple concurrent task executions in the same container.
1. Define the Shared Batcher
Use alru_cache to ensure the model and batcher are initialized only once per process.
from async_lru import alru_cache
from flyte.extras import TokenBatcher

@alru_cache(maxsize=1)
async def get_batcher() -> TokenBatcher[Prompt, str]:
    # inference_fn must accept list[RecordT] and return list[ResultT]
    inference_fn = await get_model_inference_fn()
    batcher = TokenBatcher[Prompt, str](
        inference_fn=inference_fn,
        target_batch_tokens=32_000,
        batch_timeout_s=0.05,
    )
    await batcher.start()
    return batcher
2. Configure the Task Environment
Set a ReusePolicy with concurrency > 1. This allows multiple Flyte task executions to run concurrently in the same container, all feeding into the same batcher.
import asyncio

import flyte

gpu_env = flyte.TaskEnvironment(
    name="gpu_worker",
    resources=flyte.Resources(gpu="L4:1"),
    reusable=flyte.ReusePolicy(
        replicas=2,
        concurrency=10,  # 10 concurrent tasks per replica
    ),
)

@gpu_env.task
async def infer_batch(prompts: list[str]) -> list[str]:
    batcher = await get_batcher()
    # Prompt requires task_id and index as well as text; the task_id value here is illustrative.
    futures = [
        await batcher.submit(Prompt(task_id="infer_batch", index=i, text=p))
        for i, p in enumerate(prompts)
    ]
    return await asyncio.gather(*futures)
Monitoring and Statistics
The BatchStats class provides real-time visibility into the batcher's performance. You can access these metrics via the batcher.stats property.
| Attribute | Description |
|---|---|
| total_submitted | Total records received via submit. |
| total_completed | Total records successfully processed. |
| total_batches | Number of batches dispatched to process_fn. |
| utilization | Fraction of time (0.0-1.0) spent processing vs. idle. |
Example of logging utilization from a task:
logger.info(
    "Batcher utilization: %.1f%% | Avg Batch Size: %.1f",
    batcher.stats.utilization * 100,
    batcher.stats.avg_batch_size,
)
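As a rough illustration of how such figures relate to each other, the sketch below derives both logged values from raw counters. SketchStats, busy_time_s, and elapsed_time_s are assumptions made for this example; the real BatchStats may track and compute these values differently.

```python
from dataclasses import dataclass


@dataclass
class SketchStats:
    """Hypothetical stand-in for BatchStats, for illustration only."""
    total_submitted: int = 0
    total_completed: int = 0
    total_batches: int = 0
    busy_time_s: float = 0.0     # assumed: time spent inside process_fn
    elapsed_time_s: float = 0.0  # assumed: wall-clock time since start

    @property
    def utilization(self) -> float:
        # Fraction of elapsed time spent processing, clamped to [0.0, 1.0].
        if self.elapsed_time_s <= 0:
            return 0.0
        return min(1.0, self.busy_time_s / self.elapsed_time_s)

    @property
    def avg_batch_size(self) -> float:
        # Completed records per dispatched batch.
        if self.total_batches == 0:
            return 0.0
        return self.total_completed / self.total_batches


stats = SketchStats(
    total_submitted=120,
    total_completed=96,
    total_batches=8,
    busy_time_s=4.5,
    elapsed_time_s=6.0,
)
summary = "Batcher utilization: %.1f%% | Avg Batch Size: %.1f" % (
    stats.utilization * 100,
    stats.avg_batch_size,
)
```

A low utilization combined with a small average batch size would suggest that batches are being dispatched on the timeout and not on the cost budget.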
Important Constraints
- Order Preservation: The process_fn (or inference_fn) must return a list of results in the exact same order and length as the input batch. If the lengths mismatch, DynamicBatcher raises a ValueError and fails all futures in that batch.
- Error Handling: If the process_fn raises an exception, that exception is propagated to every asyncio.Future in the current batch.
- Lifecycle: You must call await batcher.start() before submitting records. Using the batcher as an asynchronous context manager (async with DynamicBatcher(...) as b:) handles start and stop automatically.
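The first two constraints can be sketched as a single processing step. This is a minimal illustration of the described behaviour, not the library's code: run_batch and the two sample process functions are hypothetical names introduced for the example.

```python
import asyncio
from typing import Awaitable, Callable, List, Tuple


async def run_batch(
    process_fn: Callable[[List[str]], Awaitable[List[str]]],
    records: List[str],
    futures: "List[asyncio.Future[str]]",
) -> None:
    """Run one batch and resolve or fail every future that belongs to it."""
    try:
        results = await process_fn(records)
        if len(results) != len(records):
            raise ValueError(
                f"process_fn returned {len(results)} results for {len(records)} records"
            )
    except Exception as exc:
        # Any failure, including a length mismatch, fails the whole batch.
        for fut in futures:
            if not fut.done():
                fut.set_exception(exc)
        return
    # Results are matched to futures purely by position.
    for fut, result in zip(futures, results):
        if not fut.done():
            fut.set_result(result)


async def upper(batch: List[str]) -> List[str]:
    return [r.upper() for r in batch]


async def drops_one(batch: List[str]) -> List[str]:
    return batch[:-1]  # Violates the length constraint.


async def demo() -> Tuple[List[str], List[str]]:
    loop = asyncio.get_running_loop()
    ok = [loop.create_future() for _ in range(2)]
    await run_batch(upper, ["a", "b"], ok)
    bad = [loop.create_future() for _ in range(2)]
    await run_batch(drops_one, ["a", "b"], bad)
    return [f.result() for f in ok], [type(f.exception()).__name__ for f in bad]


ok_results, bad_errors = asyncio.run(demo())
```

Because results are paired with futures by position, a process_fn that reorders its output would silently hand callers the wrong results, which is why the ordering requirement matters even when the lengths match.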