LLM Inference with Token-Aware Batching
To optimize Large Language Model (LLM) inference in high-throughput environments, you can use the TokenBatcher to aggregate concurrent requests into batches based on token counts rather than just record counts.
Basic Token-Aware Batching
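The TokenBatcher lets you set a target token budget per batch. Sizing batches by estimated tokens instead of record count helps keep GPU memory well utilized while reducing the risk of oversized batches. The budget is applied to estimated token counts, so leave some headroom if the estimates are rough.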
import asyncio

from flyte.extras import TokenBatcher, Prompt


# 1. Define your inference logic
async def vllm_inference(batch: list[Prompt]) -> list[str]:
    # Simulate LLM generation
    # Note: Results must be returned in the same order as the input batch
    return [f"Response to: {p.text}" for p in batch]


async def main():
    # 2. Initialize the batcher with a token target (e.g., 32k tokens)
    async with TokenBatcher(
        inference_fn=vllm_inference,
        target_batch_tokens=32_000,
        batch_timeout_s=0.05,
    ) as batcher:
        # 3. Submit prompts
        # Prompt class automatically estimates tokens (~4 chars/token)
        future1 = await batcher.submit(Prompt(text="What is Flyte?"))
        future2 = await batcher.submit(Prompt(text="Explain dynamic batching."))

        # 4. Wait for individual results
        results = await asyncio.gather(future1, future2)
        for res in results:
            print(res)


if __name__ == "__main__":
    asyncio.run(main())
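To make the idea of a token budget concrete, here is a minimal, standalone sketch of greedy token-based grouping. It is not the TokenBatcher implementation; `estimate_tokens` and `pack_by_token_budget` are hypothetical helpers written for illustration, and the 4-characters-per-token heuristic simply mirrors the estimate mentioned above.

```python
def estimate_tokens(text: str) -> int:
    """Rough heuristic: about 4 characters per token, never zero."""
    return max(1, len(text) // 4)


def pack_by_token_budget(
    texts: list[str],
    target_batch_tokens: int,
    max_batch_size: int,
) -> list[list[str]]:
    """Greedily group texts so each batch stays within the token budget."""
    batches: list[list[str]] = []
    current: list[str] = []
    current_tokens = 0
    for text in texts:
        cost = estimate_tokens(text)
        over_budget = current_tokens + cost > target_batch_tokens
        full = len(current) >= max_batch_size
        # Close the current batch before it would exceed either limit.
        if current and (over_budget or full):
            batches.append(current)
            current, current_tokens = [], 0
        # An item larger than the whole budget still gets a batch of its own.
        current.append(text)
        current_tokens += cost
    if current:
        batches.append(current)
    return batches


if __name__ == "__main__":
    prompts = ["a" * 40] * 5  # about 10 estimated tokens each
    for batch in pack_by_token_budget(prompts, target_batch_tokens=25, max_batch_size=10):
        print(len(batch), "prompts in batch")
```

Grouping by record count alone would treat a 10-token prompt and a 10,000-token prompt as equal, which is why a token budget gives more predictable memory use.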
Implementing Custom Token Estimators
If the default Prompt class's estimation (length // 4) is insufficient, you can implement the TokenEstimator protocol on your own dataclasses. The TokenBatcher will automatically call estimate_tokens() to calculate the batch "cost."
from dataclasses import dataclass

from flyte.extras import TokenBatcher


@dataclass
class CustomPrompt:
    system_prompt: str
    user_query: str

    def estimate_tokens(self) -> int:
        # Use a more accurate estimation or a real tokenizer
        return (len(self.system_prompt) + len(self.user_query)) // 3


async def inference_fn(batch: list[CustomPrompt]) -> list[str]:
    # Process the custom prompt objects
    return ["result" for _ in batch]


# Usage
batcher = TokenBatcher(
    inference_fn=inference_fn,
    target_batch_tokens=4096,
)
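The protocol-based approach relies on structural typing: any object with an `estimate_tokens()` method qualifies, with no base class required. The sketch below illustrates that idea using only the standard library. `SupportsTokenEstimate` is a local stand-in defined here for illustration, not the library's own protocol, and `ChatPrompt`, `batch_cost`, and the word-based heuristic are likewise assumptions.

```python
from dataclasses import dataclass
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsTokenEstimate(Protocol):
    """Local stand-in for a token-estimator protocol (hypothetical name)."""

    def estimate_tokens(self) -> int: ...


@dataclass
class ChatPrompt:
    system_prompt: str
    user_query: str

    def estimate_tokens(self) -> int:
        # Word-based heuristic: about 4 tokens per 3 words, rounded up, never zero.
        words = len(self.system_prompt.split()) + len(self.user_query.split())
        return max(1, (words * 4 + 2) // 3)


def batch_cost(batch: list[SupportsTokenEstimate]) -> int:
    """Sum the estimated tokens of every item in a batch."""
    return sum(item.estimate_tokens() for item in batch)


if __name__ == "__main__":
    prompt = ChatPrompt("You are helpful.", "What is Flyte?")
    # ChatPrompt never inherits from the protocol, yet it satisfies it.
    print(isinstance(prompt, SupportsTokenEstimate))
    print(batch_cost([prompt, prompt]))
```

If you swap the heuristic for a real tokenizer, consider caching the result on the object, since the estimate may be requested for every submitted record.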
High-Throughput Inference in Flyte Tasks
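In production Flyte workflows, you often want to share a single TokenBatcher (and its underlying model) across multiple concurrent task invocations on the same worker. Two pieces make this possible: flyte.ReusePolicy keeps worker processes alive and lets each one run several task invocations concurrently, and alru_cache (from the async_lru package) ensures the batcher is created only once per process.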
import asyncio

from async_lru import alru_cache

import flyte
from flyte.extras import TokenBatcher, Prompt

# Define a reusable environment with concurrency
gpu_env = flyte.TaskEnvironment(
    name="gpu_worker",
    resources=flyte.Resources(gpu=1),
    reusable=flyte.ReusePolicy(
        replicas=2,
        concurrency=10,  # 10 concurrent tasks per replica
    ),
)


@alru_cache(maxsize=1)
async def get_shared_batcher():
    """Load model and start batcher once per process."""

    async def inference(batch: list[Prompt]) -> list[str]:
        # Real model inference here (e.g., vLLM or PyTorch)
        return [f"Generated text for: {p.text}" for p in batch]

    batcher = TokenBatcher(
        inference_fn=inference,
        target_batch_tokens=32_000,
        batch_timeout_s=0.05,
    )
    await batcher.start()  # Manual start required when not using 'async with'
    return batcher


@gpu_env.task
async def infer_task(texts: list[str]) -> list[str]:
    batcher = await get_shared_batcher()

    # Submit all texts to the shared batcher
    futures = [await batcher.submit(Prompt(text=t)) for t in texts]

    # All concurrent infer_task calls will have their prompts
    # aggregated into the same batches by the singleton batcher.
    return list(await asyncio.gather(*futures))
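The once-per-process behavior is what lets many concurrent invocations feed one batcher. The sketch below reproduces that behavior with plain asyncio so it can be run without Flyte or async_lru. It is a hand-rolled substitute, not how alru_cache is implemented, and all names in it (`LOAD_CALLS`, `get_shared_resource`, `demo`) are hypothetical.

```python
import asyncio
from typing import Optional

LOAD_CALLS = {"count": 0}  # counts how often the expensive loader runs
_loader_task: Optional[asyncio.Task] = None


async def _load_expensive_resource() -> dict:
    """Stand-in for loading a model and starting a batcher."""
    LOAD_CALLS["count"] += 1
    await asyncio.sleep(0.01)  # simulate slow startup
    return {"name": "shared-resource"}


async def get_shared_resource() -> dict:
    """Return the one shared instance, creating it on first use."""
    global _loader_task
    if _loader_task is None:
        # There is no await between the check and the assignment, so only
        # one caller on the event loop can ever create the task.
        _loader_task = asyncio.create_task(_load_expensive_resource())
    return await _loader_task


async def demo() -> list:
    # Ten concurrent callers, similar to concurrency=10 on one replica.
    return await asyncio.gather(*(get_shared_resource() for _ in range(10)))


if __name__ == "__main__":
    resources = asyncio.run(demo())
    print("loader ran", LOAD_CALLS["count"], "time(s)")
    print("all callers share one object:", all(r is resources[0] for r in resources))
```

Callers that arrive while loading is still in progress wait on the same task instead of starting a second load, which matters when startup means loading model weights onto a GPU.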
Configuration Parameters
The TokenBatcher (defined in src/flyte/extras/_dynamic_batcher.py) supports several parameters to tune performance:
| Parameter | Default | Description |
|---|---|---|
| target_batch_tokens | 32,000 | The token budget the batcher tries to fill before processing. |
| max_batch_size | 256 | Maximum number of records in a single batch, regardless of tokens. |
| batch_timeout_s | 0.05 | Max time (seconds) to wait for a batch to fill before processing it anyway. |
| max_queue_size | 5,000 | Size of the internal queue. submit() will await if this is full, providing backpressure. |
| prefetch_batches | 2 | Number of batches to prepare ahead of the processing loop. |
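The following standalone sketch shows how a token budget, a size cap, and a timeout can interact when assembling one batch from a bounded queue. It is an illustration under simplifying assumptions, not the library's internal loop: `collect_batch` and `demo` are hypothetical, tokens are estimated as 4 characters each, and this version may overshoot the budget by the last item it accepts.

```python
import asyncio
import time


async def collect_batch(
    queue: asyncio.Queue,
    target_batch_tokens: int,
    max_batch_size: int,
    batch_timeout_s: float,
) -> list:
    """Gather one batch, stopping at the token budget, size cap, or timeout."""
    # Block until at least one item exists; the timeout clock starts here.
    first = await queue.get()
    batch = [first]
    tokens = max(1, len(first) // 4)
    deadline = time.monotonic() + batch_timeout_s

    while len(batch) < max_batch_size and tokens < target_batch_tokens:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break  # waited long enough: ship a partial batch
        try:
            item = await asyncio.wait_for(queue.get(), timeout=remaining)
        except asyncio.TimeoutError:
            break
        batch.append(item)
        tokens += max(1, len(item) // 4)
    return batch


async def demo() -> list:
    # A bounded queue gives backpressure: put() waits while the queue is full.
    queue: asyncio.Queue = asyncio.Queue(maxsize=8)
    for _ in range(5):
        await queue.put("x" * 40)  # about 10 estimated tokens each
    first = await collect_batch(queue, 30, 256, 0.05)   # ends on the token budget
    second = await collect_batch(queue, 30, 256, 0.05)  # ends on the timeout
    return [first, second]


if __name__ == "__main__":
    for batch in asyncio.run(demo()):
        print(len(batch), "items")
```

A shorter batch_timeout_s lowers latency for sparse traffic at the cost of smaller batches, while a larger token budget favors throughput when requests arrive quickly.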
Troubleshooting
Result Ordering
The inference_fn must return a list of results that exactly matches the order and length of the input batch. If the lengths do not match, the TokenBatcher will raise a ValueError and fail all requests in that batch.
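Ordering bugs typically appear when the inference function reorders inputs internally, for example sorting by length to reduce padding. A minimal sketch of restoring the original order before returning follows. `fake_model` and `run_sorted_then_restore` are hypothetical stand-ins, and the length check is the sketch's own guard, not the batcher's.

```python
def fake_model(texts: list[str]) -> list[str]:
    """Stand-in for a real model call; returns one output per input."""
    return [text.upper() for text in texts]


def run_sorted_then_restore(batch: list[str]) -> list[str]:
    """Process longest-first, then return results in the original input order."""
    # order[k] is the original index of the k-th item after sorting.
    order = sorted(range(len(batch)), key=lambda i: len(batch[i]), reverse=True)
    sorted_inputs = [batch[i] for i in order]
    sorted_outputs = fake_model(sorted_inputs)

    if len(sorted_outputs) != len(batch):
        raise ValueError(
            f"expected {len(batch)} results, got {len(sorted_outputs)}"
        )

    # Scatter each output back to the position its input came from.
    results = [""] * len(batch)
    for position, output in zip(order, sorted_outputs):
        results[position] = output
    return results


if __name__ == "__main__":
    print(run_sorted_then_restore(["bb", "a", "cccc"]))
```

Wrapping this logic in the async inference function keeps the reordering an internal detail, so each caller's future resolves with the result for its own prompt.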
Starting the Batcher
If you are not using the async with TokenBatcher(...) context manager, you must call await batcher.start() before calling submit(). Failure to do so will result in a RuntimeError.
Monitoring Performance
You can inspect the efficiency of your batching using the .stats property, which returns a BatchStats object:
stats = batcher.stats
print(f"Utilization: {stats.utilization:.2%}")
print(f"Total Batches: {stats.total_batches}")
print(f"Average Batch Size: {stats.avg_batch_size}")