
LLM Inference with Token-Aware Batching

To optimize Large Language Model (LLM) inference in high-throughput environments, you can use the TokenBatcher to aggregate concurrent requests into batches based on token counts rather than just record counts.

Basic Token-Aware Batching

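The TokenBatcher lets you define a target token budget per batch. Because the budget tracks the estimated amount of work rather than the number of requests, batches stay large enough to keep the GPU busy while the approximate memory cost of each batch stays bounded.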

```python
import asyncio
from flyte.extras import TokenBatcher, Prompt

# 1. Define your inference logic
async def vllm_inference(batch: list[Prompt]) -> list[str]:
    # Simulate LLM generation
    # Note: Results must be returned in the same order as the input batch
    return [f"Response to: {p.text}" for p in batch]

async def main():
    # 2. Initialize the batcher with a token target (e.g., 32k tokens)
    async with TokenBatcher(
        inference_fn=vllm_inference,
        target_batch_tokens=32_000,
        batch_timeout_s=0.05,
    ) as batcher:
        # 3. Submit prompts
        # Prompt class automatically estimates tokens (~4 chars/token)
        future1 = await batcher.submit(Prompt(text="What is Flyte?"))
        future2 = await batcher.submit(Prompt(text="Explain dynamic batching."))

        # 4. Wait for individual results
        results = await asyncio.gather(future1, future2)
        for res in results:
            print(res)

if __name__ == "__main__":
    asyncio.run(main())
```
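The batching rule itself can be pictured as a packing problem. The sketch below is a simplified, stdlib-only model, not TokenBatcher's actual implementation: it assumes requests are packed greedily in arrival order, that a batch closes when the next request would exceed target_batch_tokens or when it already holds max_batch_size records, and it ignores timeouts entirely. pack_by_tokens is a hypothetical helper.

```python
from typing import List, Sequence


def pack_by_tokens(
    token_counts: Sequence[int],
    target_batch_tokens: int,
    max_batch_size: int,
) -> List[List[int]]:
    """Greedily group request indices into batches, in arrival order."""
    batches: List[List[int]] = []
    current: List[int] = []
    current_tokens = 0
    for i, tokens in enumerate(token_counts):
        # Close the open batch if this request would overflow the token budget
        # or if the batch already holds max_batch_size requests.
        over_budget = bool(current) and current_tokens + tokens > target_batch_tokens
        full = len(current) >= max_batch_size
        if over_budget or full:
            batches.append(current)
            current, current_tokens = [], 0
        # A single request larger than the budget still gets a batch of its own.
        current.append(i)
        current_tokens += tokens
    if current:
        batches.append(current)
    return batches


print(pack_by_tokens([10, 20, 5, 40, 3], target_batch_tokens=32, max_batch_size=3))
# [[0, 1], [2], [3], [4]]
```

By contrast, batching on record count alone with the same max_batch_size would produce [[0, 1, 2], [3, 4]], putting 35 and 43 tokens into batches that were meant to hold at most 32.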

Implementing Custom Token Estimators

If the default estimate used by the Prompt class (len(text) // 4) is not accurate enough for your model, you can implement the TokenEstimator protocol on your own dataclasses. The TokenBatcher will automatically call estimate_tokens() on each item to calculate the batch "cost."

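```python
import asyncio
from dataclasses import dataclass
from flyte.extras import TokenBatcher

@dataclass
class CustomPrompt:
    system_prompt: str
    user_query: str

    def estimate_tokens(self) -> int:
        # Rough heuristic (~3 chars/token); swap in a real tokenizer for accurate counts
        return (len(self.system_prompt) + len(self.user_query)) // 3

async def inference_fn(batch: list[CustomPrompt]) -> list[str]:
    # Process the custom prompt objects; return one result per input, in input order
    return ["result" for _ in batch]

# Usage: the context manager starts the batcher before submit() is called
async def main():
    async with TokenBatcher(
        inference_fn=inference_fn,
        target_batch_tokens=4096,
    ) as batcher:
        future = await batcher.submit(
            CustomPrompt(system_prompt="You are a helpful assistant.", user_query="What is Flyte?")
        )
        print(await future)

if __name__ == "__main__":
    asyncio.run(main())
```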
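Because the protocol is structural, a class never has to inherit from anything; it only needs an estimate_tokens() method. The sketch below assumes the real TokenEstimator protocol requires nothing beyond estimate_tokens() -> int (as the text above implies) and uses a hypothetical local stand-in, TokenEstimatorLike, plus a hypothetical batch_cost helper to show how a batch "cost" is the sum of per-item estimates.

```python
from dataclasses import dataclass
from typing import Iterable, Protocol, runtime_checkable


@runtime_checkable
class TokenEstimatorLike(Protocol):
    """Hypothetical stand-in for flyte.extras.TokenEstimator (assumed shape)."""

    def estimate_tokens(self) -> int: ...


@dataclass
class ChatPrompt:
    """Hypothetical prompt type; it never mentions the protocol explicitly."""

    system_prompt: str
    user_query: str
    chars_per_token: int = 4  # same rough ratio as the default Prompt heuristic

    def estimate_tokens(self) -> int:
        total_chars = len(self.system_prompt) + len(self.user_query)
        return total_chars // self.chars_per_token


def batch_cost(items: Iterable[TokenEstimatorLike]) -> int:
    """Total estimated tokens for a candidate batch."""
    return sum(item.estimate_tokens() for item in items)


prompts = [
    ChatPrompt(system_prompt="You are terse.", user_query="What is Flyte?"),
    ChatPrompt(system_prompt="", user_query="x" * 40),
]
print(isinstance(prompts[0], TokenEstimatorLike))  # True
print(batch_cost(prompts))  # 17
```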

High-Throughput Inference in Flyte Tasks

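In production Flyte workflows, you often want to share a single TokenBatcher (and its underlying model) across multiple concurrent task invocations on the same worker. Two pieces make this work: flyte.ReusePolicy keeps worker processes alive and lets each replica run several task invocations concurrently, and alru_cache(maxsize=1) ensures the model and batcher are created only once per process.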

```python
import asyncio
from async_lru import alru_cache
import flyte
from flyte.extras import TokenBatcher, Prompt

# Define a reusable environment with concurrency
gpu_env = flyte.TaskEnvironment(
    name="gpu_worker",
    resources=flyte.Resources(gpu="1"),
    reusable=flyte.ReusePolicy(
        replicas=2,
        concurrency=10,  # 10 concurrent tasks per replica
    ),
)

@alru_cache(maxsize=1)
async def get_shared_batcher():
    """Load model and start batcher once per process."""
    async def inference(batch: list[Prompt]) -> list[str]:
        # Real model inference here (e.g., vLLM or PyTorch)
        return [f"Generated text for {len(batch)} items" for _ in batch]

    batcher = TokenBatcher(
        inference_fn=inference,
        target_batch_tokens=32_000,
        batch_timeout_s=0.05,
    )
    await batcher.start()  # Manual start required when not using 'async with'
    return batcher

@gpu_env.task
async def infer_task(texts: list[str]) -> list[str]:
    batcher = await get_shared_batcher()

    # Submit all texts to the shared batcher
    futures = [await batcher.submit(Prompt(text=t)) for t in texts]

    # All concurrent infer_task calls will have their prompts
    # aggregated into the same batches by the singleton batcher.
    return list(await asyncio.gather(*futures))
```
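The reason one cached coroutine yields one batcher per process, even when ten invocations arrive at once, is that concurrent callers share a single in-flight initialization. The stdlib-only sketch below illustrates that "initialize once, share the result" idea; it is not how async_lru is implemented, and AsyncOnce and load_model_and_batcher are hypothetical names.

```python
import asyncio
from typing import Awaitable, Callable, Generic, List, Optional, TypeVar

T = TypeVar("T")


class AsyncOnce(Generic[T]):
    """Runs an async factory at most once; concurrent callers share its result."""

    def __init__(self, factory: Callable[[], Awaitable[T]]) -> None:
        self._factory = factory
        self._task: Optional["asyncio.Future[T]"] = None

    async def get(self) -> T:
        if self._task is None:
            # No await between the check and the assignment, so on a single event
            # loop only the first caller schedules the factory.
            self._task = asyncio.ensure_future(self._factory())
        # Note: if the factory raises, every later call re-raises the same error.
        return await self._task


load_count = 0


async def load_model_and_batcher() -> str:
    """Hypothetical stand-in for loading a model and starting a TokenBatcher."""
    global load_count
    load_count += 1
    await asyncio.sleep(0.01)  # simulate slow model loading
    return "shared-batcher"


shared = AsyncOnce(load_model_and_batcher)


async def handle_request(i: int) -> str:
    batcher = await shared.get()
    return f"{batcher}:{i}"


async def main() -> List[str]:
    # Ten "task invocations" arrive at once, like concurrency=10 on one replica.
    return list(await asyncio.gather(*(handle_request(i) for i in range(10))))


results = asyncio.run(main())
print(load_count)  # 1
```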

Configuration Parameters

The TokenBatcher (defined in src/flyte/extras/_dynamic_batcher.py) supports several parameters to tune performance:

| Parameter | Default | Description |
|---|---|---|
| target_batch_tokens | 32,000 | The token budget the batcher tries to fill before processing a batch. |
| max_batch_size | 256 | Maximum number of records in a single batch, regardless of token count. |
| batch_timeout_s | 0.05 | Maximum time (seconds) to wait for a batch to fill before processing it anyway. |
| max_queue_size | 5,000 | Capacity of the internal queue. When it is full, submit() waits, which provides backpressure. |
| prefetch_batches | 2 | Number of batches to prepare ahead of the processing loop. |
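To see how the first four parameters interact, here is a toy asyncio batcher. It is an illustration only, not the code in _dynamic_batcher.py: ToyTokenBatcher is hypothetical, prefetch_batches is not modeled, token counts are passed in explicitly rather than estimated, and it assumes a request that would overflow the budget is carried over to start the next batch. It also applies the length check described under Troubleshooting.

```python
import asyncio
import contextlib
from typing import Awaitable, Callable, List, Optional, Tuple

Item = Tuple[str, int, asyncio.Future]  # (text, token count, result future)


class ToyTokenBatcher:
    """Illustrative model of the table's knobs; not Flyte's implementation."""

    def __init__(self, inference_fn: Callable[[List[str]], Awaitable[List[str]]],
                 target_batch_tokens: int = 32_000, max_batch_size: int = 256,
                 batch_timeout_s: float = 0.05, max_queue_size: int = 5_000) -> None:
        self._fn = inference_fn
        self._target = target_batch_tokens
        self._max_size = max_batch_size
        self._timeout = batch_timeout_s
        # Create inside a running event loop (older Pythons bind queues to a loop).
        self._queue: "asyncio.Queue[Item]" = asyncio.Queue(maxsize=max_queue_size)
        self.batch_sizes: List[int] = []

    async def submit(self, text: str, tokens: int) -> asyncio.Future:
        fut = asyncio.get_running_loop().create_future()
        await self._queue.put((text, tokens, fut))  # suspends while full: backpressure
        return fut

    async def run(self) -> None:
        loop = asyncio.get_running_loop()
        carry: Optional[Item] = None  # a request that did not fit in the previous batch
        while True:
            first = carry if carry is not None else await self._queue.get()
            carry = None
            batch, tokens = [first], first[1]
            deadline = loop.time() + self._timeout
            while len(batch) < self._max_size and tokens < self._target:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    item = await asyncio.wait_for(self._queue.get(), remaining)
                except asyncio.TimeoutError:
                    break  # batch_timeout_s elapsed: dispatch what we have
                if tokens + item[1] > self._target:
                    carry = item  # would overflow the budget: start the next batch with it
                    break
                batch.append(item)
                tokens += item[1]
            await self._dispatch(batch)

    async def _dispatch(self, batch: List[Item]) -> None:
        self.batch_sizes.append(len(batch))
        try:
            results = await self._fn([text for text, _, _ in batch])
            if len(results) != len(batch):
                raise ValueError(f"got {len(results)} results for {len(batch)} inputs")
        except Exception as exc:
            for _, _, fut in batch:  # fail every request instead of leaving callers hanging
                if not fut.done():
                    fut.set_exception(exc)
            return
        for (_, _, fut), result in zip(batch, results):
            if not fut.done():
                fut.set_result(result)


async def upper_inference(texts: List[str]) -> List[str]:
    return [t.upper() for t in texts]  # one result per input, same order


async def demo() -> Tuple[List[str], List[int]]:
    batcher = ToyTokenBatcher(upper_inference, target_batch_tokens=10,
                              max_batch_size=3, batch_timeout_s=0.05)
    worker = asyncio.create_task(batcher.run())
    items = [("a", 6), ("b", 3), ("c", 5), ("d", 1), ("e", 1), ("f", 1), ("g", 1)]
    futures = [await batcher.submit(text, tokens) for text, tokens in items]
    results = list(await asyncio.gather(*futures))
    worker.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await worker
    return results, batcher.batch_sizes


results, sizes = asyncio.run(demo())
print(results)  # ['A', 'B', 'C', 'D', 'E', 'F', 'G']
print(sizes)    # [2, 3, 2]: token budget, then size cap, then timeout closed each batch
```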

Troubleshooting

Result Ordering

The inference_fn must return a list of results that exactly matches the order and length of the input batch. If the lengths do not match, the TokenBatcher will raise a ValueError and fail all requests in that batch.
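A common way to break this contract is to collect per-request results as they finish. The sketch below assumes a hypothetical per-prompt backend call, fake_generate, with artificial latencies chosen so that earlier prompts finish last. asyncio.gather returns results in argument order and therefore satisfies the contract; iterating asyncio.as_completed yields results in completion order and would pair responses with the wrong requests.

```python
import asyncio
from typing import List


async def fake_generate(prompt: str, delay: float) -> str:
    """Hypothetical per-prompt backend call with a fixed artificial latency."""
    await asyncio.sleep(delay)
    return f"Response to: {prompt}"


def latencies(n: int) -> List[float]:
    # Earlier prompts take longer, so completion order is the reverse of input order.
    return [0.01 * (n - i) for i in range(n)]


async def ordered_inference(batch: List[str]) -> List[str]:
    # asyncio.gather returns results in argument order: safe for TokenBatcher.
    delays = latencies(len(batch))
    return list(await asyncio.gather(*(fake_generate(p, d) for p, d in zip(batch, delays))))


async def completion_order_inference(batch: List[str]) -> List[str]:
    # as_completed yields in finish order: this violates the ordering contract.
    delays = latencies(len(batch))
    out: List[str] = []
    for next_done in asyncio.as_completed([fake_generate(p, d) for p, d in zip(batch, delays)]):
        out.append(await next_done)
    return out


batch = ["a", "b", "c"]
print(asyncio.run(ordered_inference(batch)))
# ['Response to: a', 'Response to: b', 'Response to: c']
print(asyncio.run(completion_order_inference(batch)))
# ['Response to: c', 'Response to: b', 'Response to: a']
```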

Starting the Batcher

If you are not using the async with TokenBatcher(...) context manager, you must call await batcher.start() before calling submit(). Failure to do so will result in a RuntimeError.
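The context manager avoids this error because entering it performs the start for you. The minimal sketch below illustrates that contract with a hypothetical StartGuardedBatcher; its stop() method and echo behavior are assumptions for illustration, not TokenBatcher's API.

```python
import asyncio
from typing import Optional, Tuple


class StartGuardedBatcher:
    """Hypothetical class illustrating the start()/submit() contract only."""

    def __init__(self) -> None:
        self._started = False

    async def start(self) -> None:
        # A real batcher would also launch its background batching loop here.
        self._started = True

    async def stop(self) -> None:
        self._started = False

    async def submit(self, item: str) -> "asyncio.Future[str]":
        if not self._started:
            raise RuntimeError("call 'await batcher.start()' before submit()")
        fut: "asyncio.Future[str]" = asyncio.get_running_loop().create_future()
        fut.set_result(f"echo: {item}")  # stands in for real batched inference
        return fut

    async def __aenter__(self) -> "StartGuardedBatcher":
        await self.start()  # this is why 'async with' never hits the RuntimeError
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        await self.stop()


async def demo() -> Tuple[Optional[str], str]:
    batcher = StartGuardedBatcher()
    early_error: Optional[str] = None
    try:
        await batcher.submit("too early")
    except RuntimeError as exc:
        early_error = str(exc)
    async with batcher:
        fut = await batcher.submit("hello")
        echoed = await fut
    return early_error, echoed


early, echoed = asyncio.run(demo())
print(early)   # call 'await batcher.start()' before submit()
print(echoed)  # echo: hello
```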

Monitoring Performance

You can inspect the efficiency of your batching using the .stats property, which returns a BatchStats object:

```python
stats = batcher.stats
print(f"Utilization: {stats.utilization:.2%}")
print(f"Total Batches: {stats.total_batches}")
print(f"Average Batch Size: {stats.avg_batch_size}")
```
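This page does not define how utilization or avg_batch_size are computed. To make the numbers easier to reason about, here is a hypothetical BatchStatsSketch that assumes utilization means tokens actually processed divided by the total token budget (total_batches * target_batch_tokens), and avg_batch_size means records per batch. Check the BatchStats definition in _dynamic_batcher.py before relying on either formula.

```python
from dataclasses import dataclass


@dataclass
class BatchStatsSketch:
    """Hypothetical stats tracker; the formulas below are assumptions."""

    target_batch_tokens: int
    total_batches: int = 0
    total_records: int = 0
    total_tokens: int = 0

    def record_batch(self, num_records: int, num_tokens: int) -> None:
        self.total_batches += 1
        self.total_records += num_records
        self.total_tokens += num_tokens

    @property
    def avg_batch_size(self) -> float:
        # Records per batch; 0.0 before any batch has run.
        return self.total_records / self.total_batches if self.total_batches else 0.0

    @property
    def utilization(self) -> float:
        # ASSUMPTION: fraction of the combined token budget that was actually filled.
        if not self.total_batches:
            return 0.0
        return self.total_tokens / (self.total_batches * self.target_batch_tokens)


stats = BatchStatsSketch(target_batch_tokens=32_000)
stats.record_batch(num_records=100, num_tokens=24_000)
stats.record_batch(num_records=50, num_tokens=16_000)
print(f"Utilization: {stats.utilization:.2%}")      # Utilization: 62.50%
print(f"Average Batch Size: {stats.avg_batch_size}")  # Average Batch Size: 75.0
```

Under this assumed definition, a persistently low utilization means batches are being dispatched well short of target_batch_tokens, typically because batch_timeout_s or max_batch_size closes them first.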