Model Prefetching

Model prefetching lets you download HuggingFace models into your Flyte storage, and optionally shard them, before inference tasks need them. Because the weights are already in your object store when an inference task starts, this reduces cold-start times for LLM workloads. Pre-computing tensor-parallel shards also means a multi-GPU inference server can load them without re-sharding.

In this tutorial, you will learn how to prefetch a model for standard use and how to shard a large model for high-performance inference.

Prerequisites

  • Flyte SDK: Ensure you have flyte-sdk installed.
  • HuggingFace Token: You must have a HuggingFace token with access to the models you wish to prefetch.
  • Flyte Secret: Store your HuggingFace token as a Flyte secret. By default, the prefetch utility looks for a secret named HF_TOKEN.

Step 1: Basic Model Prefetching

The simplest way to prefetch a model is to use the hf_model function. This will download the model files from HuggingFace and stream them directly to your Flyte remote storage.

import flyte
from flyte.prefetch import hf_model

# Initialize your Flyte connection
flyte.init(endpoint="your-flyte-endpoint")

# Prefetch a model to remote storage
run = hf_model(
    repo="meta-llama/Llama-3.1-8B",
    hf_token_key="HF_TOKEN",
)

# Wait for the prefetch task to complete
run.wait()

# Access the remote path of the prefetched model
model_dir = run.outputs()[0]
print(f"Model stored at: {model_dir.path}")

When you call hf_model, Flyte triggers a remote task that:

  1. Authenticates with HuggingFace using the secret provided in hf_token_key.
  2. Attempts to stream the model files directly to your remote object store (e.g., S3 or GCS).
  3. Falls back to a local download and upload if streaming is not supported by the storage backend.
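The stream-then-fall-back behavior in steps 2 and 3 can be sketched in plain Python. This is only an illustration of the control flow, not the prefetch task's actual code: `StreamingNotSupported`, `prefetch_with_fallback`, the two stand-in transfer functions, and the `s3://example-bucket` path are all hypothetical names introduced here.

```python
from typing import Callable


class StreamingNotSupported(Exception):
    """Hypothetical error: the storage backend cannot accept streamed writes."""


def prefetch_with_fallback(
    stream_to_store: Callable[[str], str],
    download_then_upload: Callable[[str], str],
    repo: str,
) -> tuple[str, str]:
    """Try direct streaming first, then fall back to download-and-upload.

    Returns a (strategy_used, remote_path) pair.
    """
    try:
        return "stream", stream_to_store(repo)
    except StreamingNotSupported:
        return "download_upload", download_then_upload(repo)


# Stand-in transfer functions, for illustration only.
def failing_stream(repo: str) -> str:
    raise StreamingNotSupported(repo)


def local_copy(repo: str) -> str:
    return f"s3://example-bucket/models/{repo}"


result = prefetch_with_fallback(failing_stream, local_copy, "meta-llama/Llama-3.1-8B")
print(result)
```

Only the hypothetical `StreamingNotSupported` error triggers the fallback in this sketch, so other failures (for example, an authentication error) would still surface to the caller.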

Step 2: Sharding Large Models with vLLM

For large models (like Llama-3.1-70B), you often need to shard the weights across multiple GPUs using tensor parallelism. The hf_model utility can perform this sharding during the prefetch phase using the vllm engine.

To do this, provide a ShardConfig and ensure the task has sufficient GPU resources.

from flyte import Resources
from flyte.prefetch import hf_model, ShardConfig, VLLMShardArgs

# Configure sharding for 4 GPUs
shard_config = ShardConfig(
    engine="vllm",
    args=VLLMShardArgs(
        tensor_parallel_size=4,
        dtype="auto",
        max_model_len=8192,
    ),
)

# Prefetch and shard the model
run = hf_model(
    repo="meta-llama/Llama-3.1-70B",
    shard_config=shard_config,
    resources=Resources(
        cpu="32",
        memory="256Gi",
        gpu="A100:4",
        disk="500Gi",
    ),
)

run.wait()

In this step:

  • VLLMShardArgs defines how the model should be sharded. tensor_parallel_size=4 tells vLLM to split the model for 4-way parallelism.
  • Resources must include the GPUs required for the sharding process. The hf_model function dynamically builds a container image with the necessary CUDA toolkit and vLLM dependencies to perform the sharding.
  • The resulting artifact in Flyte storage will contain the sharded .safetensors files, ready to be loaded by a vLLM inference server without further processing.
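As a rough sanity check on the choice of tensor_parallel_size, you can estimate how much weight memory each GPU holds after sharding. The sketch below assumes about 70 billion parameters stored as 16-bit values (2 bytes each) and an even split across GPUs; these figures are assumptions for illustration, and the helper `weight_gib_per_gpu` is a hypothetical name, not part of Flyte or vLLM.

```python
GIB = 1024 ** 3


def weight_gib_per_gpu(num_params: int, bytes_per_param: int, tensor_parallel_size: int) -> float:
    """Approximate weight memory per GPU, assuming an even split across GPUs."""
    total_bytes = num_params * bytes_per_param
    return total_bytes / tensor_parallel_size / GIB


# Assumed values: ~70e9 parameters, 2 bytes per parameter, 4-way tensor parallelism.
per_gpu_gib = weight_gib_per_gpu(70_000_000_000, 2, 4)
print(f"Approximate weights per GPU: {per_gpu_gib:.1f} GiB")
```

This counts weights only. The KV cache, activations, and any layers that are replicated rather than split add to the real per-GPU footprint, so treat the result as a lower bound.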

Step 3: Using the Prefetched Model in a Workflow

Once a model is prefetched, you can pass the resulting Dir (directory) to other tasks. The Run object returned by hf_model provides access to these outputs.

import flyte.io
from flyte import task, workflow
from flyte.prefetch import hf_model

@task
def generate_text(model_dir: flyte.io.Dir, prompt: str) -> str:
    # Use the prefetched model path for inference
    print(f"Loading model from {model_dir.path}")
    # ... inference logic ...
    return "Generated response"

@workflow
def inference_workflow(prompt: str):
    # Trigger prefetch (or use cached result)
    prefetch_run = hf_model(repo="meta-llama/Llama-3.1-8B")

    # Pass the output Dir to the inference task
    generate_text(model_dir=prefetch_run.outputs()[0], prompt=prompt)

Automation via CLI

You can also trigger model prefetching directly from the command line using the flyte CLI. This is useful for CI/CD pipelines or manual environment setup.

Basic prefetch:

flyte prefetch hf-model meta-llama/Llama-3.1-8B --wait

Prefetch with sharding:

First, create a shard_config.yaml:

engine: vllm
args:
  tensor_parallel_size: 2
  dtype: float16

Then run the command:

flyte prefetch hf-model meta-llama/Llama-3.1-8B \
  --shard-config shard_config.yaml \
  --gpu L4:2 \
  --wait
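In a CI/CD pipeline it can be convenient to assemble this command in a script instead of hard-coding it. The following is a minimal sketch that uses only the subcommand and flags shown above; `build_prefetch_command` is a hypothetical helper introduced for illustration, not part of the flyte CLI or SDK.

```python
import shlex
from typing import Optional


def build_prefetch_command(
    repo: str,
    shard_config: Optional[str] = None,
    gpu: Optional[str] = None,
    wait: bool = True,
) -> list[str]:
    """Build the argument list for `flyte prefetch hf-model` (hypothetical helper)."""
    cmd = ["flyte", "prefetch", "hf-model", repo]
    if shard_config:
        cmd += ["--shard-config", shard_config]
    if gpu:
        cmd += ["--gpu", gpu]
    if wait:
        cmd.append("--wait")
    return cmd


cmd = build_prefetch_command(
    "meta-llama/Llama-3.1-8B",
    shard_config="shard_config.yaml",
    gpu="L4:2",
)
# shlex.join renders the list as a single shell-safe string for logging.
print(shlex.join(cmd))
```

Passing the list to `subprocess.run(cmd, check=True)` would then execute the prefetch and raise an error if the command exits with a non-zero status.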

Important Considerations

  • Artifact Naming: By default, the artifact name is derived from the repo (e.g., meta-llama/Llama-3.1-8B becomes Llama-3-1-8B). You can override this using the artifact_name parameter.
  • Caching: Prefetch runs are cached by Flyte. If you call hf_model with the same parameters, it will return the existing Run immediately. To force a re-prefetch, use the force parameter (e.g., force=1).
  • Storage Path: You can specify a custom object store path using raw_data_path. If not provided, Flyte uses its default metadata storage.
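To make the default artifact naming concrete, here is a small sketch that reproduces the example above (meta-llama/Llama-3.1-8B becomes Llama-3-1-8B). It assumes the rule is "drop the organization prefix and replace dots with hyphens", which is inferred from that single example. The real normalization may handle other characters differently, and `default_artifact_name` is a hypothetical function, not the SDK's.

```python
def default_artifact_name(repo: str) -> str:
    """Derive an artifact name from a repo ID (assumed rule, see lead-in)."""
    # Keep only the model name after the organization prefix, if any.
    model_name = repo.split("/")[-1]
    # Assumption: dots are replaced with hyphens, as in the documented example.
    return model_name.replace(".", "-")


name = default_artifact_name("meta-llama/Llama-3.1-8B")
print(name)
```

If the derived name is not what you want, pass artifact_name explicitly instead of relying on the default.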