Hyperparameter Sweeps with W&B
Hyperparameter sweeps automate the search for the best hyperparameters by running many trials, each with a different configuration. By integrating Weights & Biases (W&B) with Flyte, you can distribute these trials across a cluster, monitor them in the W&B dashboard, and reach each sweep from the Flyte UI through an automatically added link.
In this tutorial, you will build a parallel hyperparameter sweep that runs multiple W&B agents, each in its own Flyte task, to optimize a simulated training objective.
Prerequisites
Before starting, ensure you have the following:
- W&B API Key: You must have a W&B account and an API key.
- Flyte Secret: Configure a secret named wandb_api_key in your Flyte environment so that it can be exposed to tasks as the WANDB_API_KEY environment variable.
- Plugin Installed: The flyteplugins-wandb package must be installed in your task environment.
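If the secret is missing, the failure typically surfaces only when the first W&B call tries to authenticate, deep inside a trial. A small guard at the top of a task body makes the problem obvious immediately. The sketch below assumes nothing beyond the environment variable named above; require_wandb_api_key is a hypothetical helper, not part of flyteplugins-wandb.

```python
import os


def require_wandb_api_key(var_name: str = "WANDB_API_KEY") -> str:
    """Fail fast if the W&B API key was not injected into the environment.

    Hypothetical helper for illustration; it only checks that the variable
    exposed from the 'wandb_api_key' Flyte secret is present and non-empty.
    """
    key = os.environ.get(var_name, "").strip()
    if not key:
        raise RuntimeError(
            f"{var_name} is not set; check that the 'wandb_api_key' Flyte secret "
            "is attached to the task environment."
        )
    return key
```

Calling require_wandb_api_key() as the first statement of an agent task turns a confusing mid-sweep authentication error into a clear message before any trial starts.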
Step 1: Define the Objective Function
The objective function is the code that W&B agents will execute for each trial. It receives hyperparameters from W&B and logs metrics back to the platform. Use the @wandb_init decorator to automatically handle W&B run initialization.
```python
import wandb
import time
from flyteplugins.wandb import wandb_init


@wandb_init
def objective():
    """Objective function for W&B sweep - trains a model with hyperparameters."""
    run = wandb.run
    config = run.config

    print(f"Training with lr={config.learning_rate}, batch_size={config.batch_size}")

    # Simulate training loop
    for epoch in range(config.epochs):
        # Simulate training metrics
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        accuracy = min(0.95, config.learning_rate * config.batch_size * (epoch + 1) * 0.01)

        run.log({
            "epoch": epoch,
            "loss": loss,
            "accuracy": accuracy,
        })
        time.sleep(0.5)
```
The @wandb_init decorator calls wandb.init() for each trial inside the Flyte task, so run.config holds the hyperparameters that the sweep controller selected for that trial.
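Because the objective is a closed-form simulation, you can preview what it will log without W&B or Flyte. The sketch below copies the loss and accuracy formulas from the objective into a plain function; simulate_metrics is a hypothetical helper introduced only for this preview.

```python
from typing import Dict, List


def simulate_metrics(learning_rate: float, batch_size: int, epochs: int) -> List[Dict[str, float]]:
    """Reproduce the per-epoch metrics the objective logs, without W&B.

    Hypothetical helper; the formulas are copied from the objective function.
    """
    history = []
    for epoch in range(epochs):
        loss = 1.0 / (learning_rate * batch_size) + epoch * 0.1
        accuracy = min(0.95, learning_rate * batch_size * (epoch + 1) * 0.01)
        history.append({"epoch": epoch, "loss": loss, "accuracy": accuracy})
    return history


# Compare the final-epoch metrics for three learning rates at batch_size=32.
for lr in (0.001, 0.01, 0.1):
    final = simulate_metrics(lr, batch_size=32, epochs=5)[-1]
    print(f"lr={lr}: final loss={final['loss']:.4f}, accuracy={final['accuracy']:.4f}")
```

Since the simulated loss is dominated by 1 / (learning_rate × batch_size), larger values of both always score better, so a sweep minimizing loss should drift toward the upper end of each range.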
Step 2: Create the Sweep Agent Task
A sweep agent is a Flyte task that pulls trials from the W&B cloud controller and executes the objective function. By using the @wandb_sweep decorator, Flyte automatically adds a "Weights & Biases Sweep" link to the task in the UI.
```python
import flyte
import wandb
from flyteplugins.wandb import wandb_sweep, get_wandb_context

# `objective` is the function defined in Step 1 (same module).

# Define the task environment with the necessary secret
env = flyte.TaskEnvironment(
    name="wandb-sweep",
    image=flyte.Image.from_debian_base(name="wandb-sweep").with_pip_packages("flyteplugins-wandb"),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)


@wandb_sweep
@env.task
async def sweep_agent(agent_id: int, sweep_id: str, count: int = 5) -> int:
    """
    Single sweep agent that pulls trials from the W&B cloud controller.
    """
    print(f"[Agent {agent_id}] Starting agent for sweep {sweep_id}")

    # Run up to `count` trials, calling `objective` once per trial
    wandb.agent(
        sweep_id,
        function=objective,
        count=count,
        project=get_wandb_context().project,
    )

    return agent_id
```
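wandb.agent repeatedly asks the sweep controller for the next configuration and runs the function with it, stopping after count runs. The toy model below mimics that pull loop in plain Python to show why several agents can share one sweep: every request draws a fresh configuration from the same central controller. ToySweepController and toy_agent are hypothetical stand-ins; the real controller lives in the W&B cloud and delivers configuration through wandb.config rather than as a function argument.

```python
import random
from typing import Any, Callable, Dict, List


class ToySweepController:
    """In-memory stand-in for a random-search sweep controller (hypothetical)."""

    def __init__(self, choices: Dict[str, List[Any]], seed: int = 0) -> None:
        self.choices = choices
        self.rng = random.Random(seed)
        self.issued = 0  # total configurations handed out across all agents

    def next_config(self) -> Dict[str, Any]:
        self.issued += 1
        return {name: self.rng.choice(options) for name, options in self.choices.items()}


def toy_agent(controller: ToySweepController,
              function: Callable[[Dict[str, Any]], float],
              count: int) -> List[float]:
    """Pull `count` trials from the shared controller and run `function` on each."""
    return [function(controller.next_config()) for _ in range(count)]


def toy_loss(cfg: Dict[str, Any]) -> float:
    return 1.0 / (cfg["learning_rate"] * cfg["batch_size"])


controller = ToySweepController({"learning_rate": [0.001, 0.01, 0.1], "batch_size": [16, 32, 64]})
# Two agents sharing one controller, five trials each.
results = toy_agent(controller, toy_loss, count=5) + toy_agent(controller, toy_loss, count=5)
print(f"{controller.issued} trials issued, best loss {min(results):.4f}")
```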
Step 3: Orchestrate Parallel Agents
To run trials in parallel, create an orchestrator task that generates the sweep and launches multiple agents. Use get_wandb_sweep_id() to retrieve the ID of the sweep created by the decorator.
```python
import asyncio
from datetime import timedelta

from flyteplugins.wandb import wandb_sweep, get_wandb_sweep_id


@wandb_sweep
@env.task
async def run_parallel_sweep(total_trials: int = 15, trials_per_agent: int = 5) -> str:
    # Retrieve the sweep ID created by the @wandb_sweep decorator
    sweep_id = get_wandb_sweep_id()

    # Calculate how many agents to launch (ceiling division)
    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent

    # Launch multiple agents in parallel as separate Flyte tasks
    agent_tasks = [
        sweep_agent.override(
            resources=flyte.Resources(cpu="1", memory="2Gi"),
            retries=2,
            timeout=timedelta(minutes=30),
        )(agent_id=i + 1, sweep_id=sweep_id, count=trials_per_agent)
        for i in range(num_agents)
    ]

    # Wait for all agents to complete
    await asyncio.gather(*agent_tasks)

    return sweep_id
```
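The orchestrator sizes the fan-out with ceiling division, (a + b - 1) // b, so 10 trials at 2 per agent launches 5 agents. Because every agent receives count=trials_per_agent, a total that is not a multiple of trials_per_agent runs extra trials: 15 trials at 4 per agent launches 4 agents, i.e. up to 16 runs. The sketch below shows one way to cap the last agent and reproduces the build-then-gather pattern in plain asyncio. plan_agent_counts, fake_agent, and run_all are hypothetical; in the real orchestrator each coroutine comes from calling sweep_agent.override(...)(...).

```python
import asyncio
from typing import List


def plan_agent_counts(total_trials: int, trials_per_agent: int) -> List[int]:
    """Split total_trials into per-agent counts that sum exactly to the total (hypothetical)."""
    if total_trials < 0 or trials_per_agent <= 0:
        raise ValueError("total_trials must be >= 0 and trials_per_agent > 0")
    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent  # ceiling division
    counts = [trials_per_agent] * num_agents
    if counts:
        # The last agent takes whatever remains.
        counts[-1] = total_trials - trials_per_agent * (num_agents - 1)
    return counts


async def fake_agent(agent_id: int, count: int) -> int:
    """Stand-in for sweep_agent; sleeps briefly instead of running trials."""
    await asyncio.sleep(0.01 * count)
    return agent_id


async def run_all(total_trials: int, trials_per_agent: int) -> List[int]:
    counts = plan_agent_counts(total_trials, trials_per_agent)
    # Build one coroutine per agent, then await them concurrently.
    coros = [fake_agent(agent_id=i + 1, count=c) for i, c in enumerate(counts)]
    return list(await asyncio.gather(*coros))


print(plan_agent_counts(10, 2))      # [2, 2, 2, 2, 2]
print(plan_agent_counts(15, 4))      # [4, 4, 4, 3]
print(asyncio.run(run_all(15, 4)))   # [1, 2, 3, 4]
```

asyncio.gather returns results in the order the coroutines were passed, not the order they finish, so the agent IDs come back sorted.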
Step 4: Configure and Execute the Sweep
Finally, define the sweep configuration (search method, target metric, and parameter ranges) and execute the orchestrator. Use wandb_sweep_config and flyte.with_runcontext to pass this configuration into the Flyte execution.
```python
from flyteplugins.wandb import wandb_config, wandb_sweep_config

if __name__ == "__main__":
    flyte.init_from_config()

    # Define the sweep and project configuration
    run = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-flyte-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64]},
                    "epochs": {"value": 5},
                },
            ),
        },
    ).run(run_parallel_sweep, total_trials=10, trials_per_agent=2)

    print(f"Sweep Execution URL: {run.url}")
```
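The parameters block uses three spec shapes: value (a constant), values (a categorical list), and min/max (a range). The sketch below mimics how a random search could draw one configuration from these shapes, assuming uniform sampling for min/max ranges, with integers when both bounds are ints. sample_random_config is a hypothetical helper for reasoning about the search space, not W&B's sampler.

```python
import random
from typing import Any, Dict


def sample_random_config(parameters: Dict[str, Dict[str, Any]], rng: random.Random) -> Dict[str, Any]:
    """Draw one configuration from a W&B-style parameters block (hypothetical sketch)."""
    config: Dict[str, Any] = {}
    for name, spec in parameters.items():
        if "value" in spec:
            config[name] = spec["value"]                 # constant
        elif "values" in spec:
            config[name] = rng.choice(spec["values"])    # categorical
        elif "min" in spec and "max" in spec:
            lo, hi = spec["min"], spec["max"]
            if isinstance(lo, int) and isinstance(hi, int):
                config[name] = rng.randint(lo, hi)       # assumed integer uniform
            else:
                config[name] = rng.uniform(lo, hi)       # assumed float uniform
        else:
            raise ValueError(f"Unsupported spec for {name!r}: {spec}")
    return config


parameters = {
    "learning_rate": {"min": 0.0001, "max": 0.1},
    "batch_size": {"values": [16, 32, 64]},
    "epochs": {"value": 5},
}
rng = random.Random(42)
for _ in range(3):
    print(sample_random_config(parameters, rng))
```

Under uniform sampling, about 90% of learning_rate draws land above 0.01, since (0.1 - 0.01) / (0.1 - 0.0001) ≈ 0.9. If you want equal coverage per order of magnitude, W&B's sweep configuration also accepts a distribution key (for example log_uniform_values).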
Step 5: Automatic Log Downloading
To archive the logs of every run in a sweep to Flyte's metadata store, set download_logs=True either in the @wandb_sweep decorator or in wandb_sweep_config.
```python
@wandb_sweep(download_logs=True)
@env.task
async def run_sweep_with_logs() -> str:
    sweep_id = get_wandb_sweep_id()
    # ... launch agents ...
    return sweep_id
```
When download_logs is enabled, the plugin will automatically call download_wandb_sweep_logs(sweep_id) after the task completes. This downloads all run files from the sweep and attaches them as a directory output, which can be viewed in the Flyte UI's "Trace" or "Outputs" section.
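@wandb_sweep works both bare and with arguments, and with download_logs=True it performs an extra step after the task returns. The generic Python pattern behind that kind of decorator is sketched below: a function that accepts either the decorated function or keyword options, and runs a post-completion hook on the return value. sweep_like and on_complete are hypothetical names, the sketch is synchronous for brevity, and it does not reproduce the plugin's actual implementation.

```python
import functools
from typing import Any, Callable, List, Optional


def sweep_like(func: Optional[Callable[..., str]] = None, *,
               download_logs: bool = False,
               on_complete: Optional[Callable[[str], None]] = None):
    """Decorator usable as @sweep_like or @sweep_like(download_logs=True, ...) (hypothetical)."""
    def decorate(fn: Callable[..., str]) -> Callable[..., str]:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> str:
            sweep_id = fn(*args, **kwargs)
            if download_logs and on_complete is not None:
                on_complete(sweep_id)  # post-completion hook runs after the body returns
            return sweep_id
        return wrapper

    # Bare @sweep_like passes the function directly; @sweep_like(...) does not.
    return decorate(func) if func is not None else decorate


downloaded: List[str] = []


@sweep_like(download_logs=True, on_complete=downloaded.append)
def run_sweep_with_logs() -> str:
    return "abc123"  # hypothetical sweep ID


@sweep_like
def run_sweep_without_logs() -> str:
    return "def456"  # hypothetical sweep ID


run_sweep_with_logs()
run_sweep_without_logs()
print(downloaded)  # ['abc123']
```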
Summary
By following this tutorial, you have:
- Defined a training objective using @wandb_init.
- Created a distributed agent task using @wandb_sweep.
- Orchestrated parallel execution of agents using asyncio and get_wandb_sweep_id.
- Configured hyperparameter ranges using wandb_sweep_config.
The Weights & Biases Sweep link now appears in the Flyte UI for every task decorated with @wandb_sweep, giving you a direct path to the W&B dashboard for real-time monitoring.