Model Prefetching
Model prefetching allows you to download and optionally shard HuggingFace models into your Flyte storage before they are needed by inference tasks. This significantly reduces cold-start times for LLM workloads and enables efficient multi-GPU inference by pre-calculating tensor-parallel shards.
In this tutorial, you will learn how to prefetch a model for standard use and how to shard a large model for high-performance inference.
Prerequisites
- Flyte SDK: Ensure you have flyte-sdk installed.
- HuggingFace Token: You must have a HuggingFace token with access to the models you wish to prefetch.
- Flyte Secret: Store your HuggingFace token as a Flyte secret. By default, the prefetch utility looks for a secret named HF_TOKEN.
Step 1: Basic Model Prefetching
The simplest way to prefetch a model is to use the hf_model function. This will download the model files from HuggingFace and stream them directly to your Flyte remote storage.
```python
import flyte
from flyte.prefetch import hf_model

# Initialize your Flyte connection
flyte.init(endpoint="your-flyte-endpoint")

# Prefetch a model to remote storage
run = hf_model(
    repo="meta-llama/Llama-3.1-8B",
    hf_token_key="HF_TOKEN",  # name of the Flyte secret holding your HuggingFace token
)

# Wait for the prefetch task to complete
run.wait()

# Access the remote path of the prefetched model
model_dir = run.outputs()[0]
print(f"Model stored at: {model_dir.path}")
```
When you call hf_model, Flyte triggers a remote task that:
- Authenticates with HuggingFace using the secret named in hf_token_key.
- Attempts to stream the model files directly to your remote object store (e.g., S3 or GCS).
- Falls back to a local download and upload if streaming is not supported by the storage backend.
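The stream-first, fall-back-second strategy described above can be illustrated with a small, self-contained sketch. Everything here is hypothetical: MemoryStore stands in for S3 or GCS, StreamingNotSupported stands in for whatever signal a backend gives when it cannot accept streamed writes, and none of it is Flyte's actual implementation. It only shows the shape of the logic. The sketch tries a direct chunked copy first. If that fails, it downloads to a local temporary file and uploads the file in one call.

```python
import io
import os
import tempfile
from typing import BinaryIO, Callable, Dict

CHUNK_SIZE = 1024 * 1024  # 1 MiB per chunk; an illustrative value


class StreamingNotSupported(Exception):
    """Hypothetical error raised by a backend that cannot accept streamed writes."""


def copy_in_chunks(src: BinaryIO, dst: BinaryIO, chunk_size: int = CHUNK_SIZE) -> int:
    """Copy src into dst one chunk at a time, never holding the whole file in memory."""
    total = 0
    while chunk := src.read(chunk_size):
        dst.write(chunk)
        total += len(chunk)
    return total


class MemoryStore:
    """Toy in-memory object store standing in for S3 or GCS."""

    def __init__(self, supports_streaming: bool) -> None:
        self.supports_streaming = supports_streaming
        self.objects: Dict[str, bytes] = {}

    def open_stream(self, key: str) -> BinaryIO:
        """Return a writable stream that commits the object when it is closed."""
        if not self.supports_streaming:
            raise StreamingNotSupported(key)
        objects = self.objects

        class _Writer(io.BytesIO):
            def close(self) -> None:
                if not self.closed:  # commit once, on the first close
                    objects[key] = self.getvalue()
                super().close()

        return _Writer()

    def upload_file(self, key: str, path: str) -> None:
        """Upload a complete local file in one call."""
        with open(path, "rb") as f:
            self.objects[key] = f.read()


def transfer(open_source: Callable[[], BinaryIO], store: MemoryStore, key: str) -> str:
    """Stream directly when possible; otherwise download to a temp file, then upload."""
    try:
        with store.open_stream(key) as dst, open_source() as src:
            copy_in_chunks(src, dst)
        return "streamed"
    except StreamingNotSupported:
        fd, local_path = tempfile.mkstemp()
        try:
            with os.fdopen(fd, "wb") as tmp, open_source() as src:
                copy_in_chunks(src, tmp)
            store.upload_file(key, local_path)
        finally:
            os.remove(local_path)  # clean up the local copy after uploading
        return "downloaded-then-uploaded"


payload = b"fake model bytes " * 1000
for streaming in (True, False):
    store = MemoryStore(supports_streaming=streaming)
    mode = transfer(lambda: io.BytesIO(payload), store, "model.safetensors")
    print(mode, store.objects["model.safetensors"] == payload)
```

The source is reopened for the fallback path because a partially consumed stream from the failed attempt cannot be reused.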
Step 2: Sharding Large Models with vLLM
For large models (like Llama-3.1-70B), you often need to shard the weights across multiple GPUs using tensor parallelism. The hf_model utility can perform this sharding during the prefetch phase using the vllm engine.
To do this, provide a ShardConfig and ensure the task has sufficient GPU resources.
```python
from flyte import Resources
from flyte.prefetch import hf_model, ShardConfig, VLLMShardArgs

# Configure sharding for 4 GPUs
shard_config = ShardConfig(
    engine="vllm",
    args=VLLMShardArgs(
        tensor_parallel_size=4,  # one shard per GPU requested below
        dtype="auto",
        max_model_len=8192,
    ),
)

# Prefetch and shard the model on a node with 4x A100 GPUs
run = hf_model(
    repo="meta-llama/Llama-3.1-70B",
    shard_config=shard_config,
    resources=Resources(
        cpu="32",
        memory="256Gi",
        gpu="A100:4",
        disk="500Gi",
    ),
)
run.wait()
```
In this step:
- VLLMShardArgs defines how the model should be sharded.
- tensor_parallel_size=4 tells vLLM to split the model for 4-way tensor parallelism, one shard per GPU.
- Resources must include the GPUs required for the sharding process. The hf_model function dynamically builds a container image with the necessary CUDA toolkit and vLLM dependencies to perform the sharding.
- The resulting artifact in Flyte storage contains the sharded .safetensors files. A vLLM inference server running with the same tensor_parallel_size can load them without further processing.
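To build intuition for what tensor_parallel_size=4 means, here is a minimal pure-Python sketch of column-parallel splitting for a single linear layer. This is a simplified illustration of the general Megatron-style idea, not vLLM's sharding code. Real engines also row-split some layers and partition attention heads. Each "rank" keeps only a slice of the weight matrix's output columns. Concatenating the per-rank outputs reproduces the unsharded result.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Naive (n x k) @ (k x m) -> (n x m) matrix product."""
    return [
        [sum(x[i][p] * w[p][j] for p in range(len(w))) for j in range(len(w[0]))]
        for i in range(len(x))
    ]


def column_shards(w: Matrix, tp_size: int) -> List[Matrix]:
    """Split the output columns of w into tp_size equal, contiguous shards (one per rank)."""
    cols = len(w[0])
    if cols % tp_size != 0:
        raise ValueError(f"{cols} columns cannot be split evenly across {tp_size} ranks")
    width = cols // tp_size
    return [[row[r * width:(r + 1) * width] for row in w] for r in range(tp_size)]


def parallel_forward(x: Matrix, shards: List[Matrix]) -> Matrix:
    """Each rank multiplies x by its own shard; concatenating the partial outputs
    column-wise (an all-gather in a real multi-GPU system) rebuilds the full result."""
    partials = [matmul(x, shard) for shard in shards]
    return [sum((part[i] for part in partials), []) for i in range(len(x))]


x = [[1.0, 2.0], [3.0, 4.0]]
w = [[1.0, 0.0, 2.0, -1.0], [0.5, 1.0, 0.0, 3.0]]
shards = column_shards(w, tp_size=4)  # each rank stores only a 2 x 1 slice of w
print(parallel_forward(x, shards) == matmul(x, w))  # True
```

Precomputing these slices once during prefetch is what lets each GPU load only its own portion of the weights at inference time.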
Step 3: Using the Prefetched Model in a Workflow
Once a model is prefetched, you can pass the resulting Dir (directory) to other tasks. The Run object returned by hf_model provides access to these outputs.
```python
import flyte
from flyte.io import Dir
from flyte.prefetch import hf_model

# In the Flyte 2 SDK, tasks are declared on a TaskEnvironment
env = flyte.TaskEnvironment(name="inference")


@env.task
def generate_text(model_dir: Dir, prompt: str) -> str:
    # Use the prefetched model path for inference
    print(f"Loading model from {model_dir.path}")
    # ... inference logic ...
    return "Generated response"


if __name__ == "__main__":
    flyte.init(endpoint="your-flyte-endpoint")
    # Trigger prefetch (or use cached result) from the client
    prefetch_run = hf_model(repo="meta-llama/Llama-3.1-8B")
    prefetch_run.wait()
    # Pass the output Dir to the inference task
    run = flyte.run(generate_text, model_dir=prefetch_run.outputs()[0], prompt="Hello")
```
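If the directory passed to generate_text was produced with a ShardConfig (as in Step 2), the task may want to check that it was sharded for the tensor-parallel size it is about to use. This is a minimal sketch under several assumptions. It assumes the directory is already available at a local path, and how Flyte materializes a Dir inside a task is not shown here. It also assumes the shard files follow vLLM's sharded-state naming pattern model-rank-{rank}-part-{part}.safetensors, so confirm this against the files in your artifact. The helper names are hypothetical.

```python
import re
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple

# Assumed per-rank naming pattern; confirm it against the files in your artifact.
SHARD_PATTERN = re.compile(r"model-rank-(\d+)-part-(\d+)\.safetensors")


def group_shards_by_rank(local_dir: str) -> Dict[int, List[Path]]:
    """Map each tensor-parallel rank to its shard files, ordered by part number."""
    found: Dict[int, List[Tuple[int, Path]]] = defaultdict(list)
    for path in Path(local_dir).iterdir():
        match = SHARD_PATTERN.fullmatch(path.name)
        if match:
            rank, part = int(match.group(1)), int(match.group(2))
            found[rank].append((part, path))
    return {rank: [p for _, p in sorted(items)] for rank, items in sorted(found.items())}


def check_tensor_parallel_size(local_dir: str, expected_tp: int) -> None:
    """Fail fast if the directory was sharded for a different tensor_parallel_size."""
    ranks = sorted(group_shards_by_rank(local_dir))
    if ranks != list(range(expected_tp)):
        raise ValueError(f"expected ranks 0..{expected_tp - 1}, found {ranks}")


with tempfile.TemporaryDirectory() as demo_dir:
    for name in ["config.json", "model-rank-0-part-0.safetensors", "model-rank-1-part-0.safetensors"]:
        (Path(demo_dir) / name).touch()
    check_tensor_parallel_size(demo_dir, expected_tp=2)  # passes: ranks 0 and 1 are present
    print({rank: [p.name for p in files] for rank, files in group_shards_by_rank(demo_dir).items()})
```

Running the check before starting the inference engine surfaces a mismatched artifact as a clear error rather than a failed model load.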
Automation via CLI
You can also trigger model prefetching directly from the command line using the flyte CLI. This is useful for CI/CD pipelines or manual environment setup.
Basic prefetch:
```shell
flyte prefetch hf-model meta-llama/Llama-3.1-8B --wait
```
Prefetch with sharding:
First, create a shard_config.yaml:
```yaml
engine: vllm
args:
  tensor_parallel_size: 2
  dtype: float16
```
Then run the command:
```shell
flyte prefetch hf-model meta-llama/Llama-3.1-8B \
  --shard-config shard_config.yaml \
  --gpu L4:2 \
  --wait
```
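Because --gpu L4:2 supplies exactly the two GPUs that tensor_parallel_size: 2 asks for, the two settings need to stay in step. Here is a hypothetical pre-flight check you could run in CI before invoking the CLI. The function names are assumptions of this sketch, not behavior of the flyte CLI. So are the one-GPU-per-rank rule and treating a spec without a count as one GPU. The optional attention-head check reflects vLLM's requirement that the number of attention heads be divisible by the tensor-parallel size.

```python
from typing import Any, Dict, Optional, Tuple


def parse_gpu_spec(spec: str) -> Tuple[str, int]:
    """Split a spec such as 'L4:2' into ('L4', 2); a bare 'L4' is treated as one GPU."""
    accelerator, sep, count = spec.partition(":")
    return accelerator, int(count) if sep else 1


def validate_shard_request(
    shard_config: Dict[str, Any], gpu_spec: str, num_attention_heads: Optional[int] = None
) -> int:
    """Return the tensor-parallel size, or raise if the request looks inconsistent."""
    tp = int(shard_config.get("args", {}).get("tensor_parallel_size", 1))
    _, gpus = parse_gpu_spec(gpu_spec)
    if tp > gpus:
        raise ValueError(f"tensor_parallel_size={tp} needs {tp} GPUs, but only {gpus} requested")
    if num_attention_heads is not None and num_attention_heads % tp != 0:
        raise ValueError(f"{num_attention_heads} attention heads are not divisible by {tp}")
    return tp


# Mirrors shard_config.yaml above; in practice you might load it with yaml.safe_load.
config = {"engine": "vllm", "args": {"tensor_parallel_size": 2, "dtype": "float16"}}
print(validate_shard_request(config, "L4:2"))  # 2
```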
Important Considerations
- Artifact Naming: By default, the artifact name is derived from the repo (e.g., meta-llama/Llama-3.1-8B becomes Llama-3-1-8B). You can override this using the artifact_name parameter.
- Caching: Prefetch runs are cached by Flyte. If you call hf_model with the same parameters, it returns the existing Run immediately. To force a re-prefetch, use the force parameter (e.g., force=1).
- Storage Path: You can specify a custom object store path using raw_data_path. If it is not provided, Flyte uses its default metadata storage.
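The naming and caching behavior above can be illustrated with a small sketch. Both functions are hypothetical. derive_artifact_name only reproduces the single documented example (meta-llama/Llama-3.1-8B to Llama-3-1-8B) and may not match Flyte's exact rule. cache_key shows the general idea behind force: identical inputs map to the same cached result, while changing any input, including force, produces a new key. Flyte's actual cache-key computation is not shown here.

```python
import hashlib
import json
import re
from typing import Any, Dict


def derive_artifact_name(repo: str) -> str:
    """Hypothetical rule: keep the part after the last '/', replace other characters with '-'."""
    name = repo.rsplit("/", 1)[-1]
    return re.sub(r"[^A-Za-z0-9-]", "-", name)


def cache_key(params: Dict[str, Any], force: int = 0) -> str:
    """Hash a canonical JSON encoding of the inputs; any change, including force, yields a new key."""
    canonical = json.dumps({**params, "force": force}, sort_keys=True)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


params = {"repo": "meta-llama/Llama-3.1-8B", "shard_config": None}
print(derive_artifact_name(params["repo"]))  # Llama-3-1-8B
same = cache_key(params) == cache_key(dict(params))
forced = cache_key(params) != cache_key(params, force=1)
print(same, forced)  # True True
```

Sorting keys before hashing keeps the key independent of the order in which parameters are supplied.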