Hyperparameter Sweeps with W&B

Hyperparameter sweeps automate the search for the best hyperparameters by running multiple trials with different configurations. By integrating Weights & Biases (W&B) with Flyte, you can distribute these trials across a cluster, manage them through the W&B dashboard, and automatically track results in the Flyte UI.

In this tutorial, you will build a parallel hyperparameter sweep that launches multiple agents across different Flyte tasks to optimize a simulated training objective.

Prerequisites

Before starting, ensure you have the following:

  1. W&B API Key: You must have a W&B account and an API key.
  2. Flyte Secret: Configure a secret named wandb_api_key in your Flyte environment to be exposed as the WANDB_API_KEY environment variable.
  3. Plugin Installed: The flyteplugins-wandb package must be installed in your task environment.
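
A fail-fast check at the top of a task body is one way to confirm that the secret actually reached the task. The sketch below is an assumption and not part of flyteplugins-wandb: require_env is a hypothetical helper that only verifies the variable named in prerequisite 2 is present and non-empty.

```python
import os
from typing import Mapping, Optional


def require_env(name: str, environ: Optional[Mapping[str, str]] = None) -> str:
    """Return the value of an environment variable, or raise if it is missing or blank.

    Hypothetical helper for illustration; `environ` defaults to os.environ
    and can be replaced with a plain dict for testing.
    """
    env = os.environ if environ is None else environ
    value = env.get(name, "").strip()
    if not value:
        raise RuntimeError(
            f"{name} is not set; check that the Flyte secret is exposed as an env var"
        )
    return value
```

Calling require_env("WANDB_API_KEY") before any W&B call would surface a missing secret as a clear error instead of a login failure later in the run.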

Step 1: Define the Objective Function

The objective function is the code that W&B agents will execute for each trial. It receives hyperparameters from W&B and logs metrics back to the platform. Use the @wandb_init decorator to automatically handle W&B run initialization.

import wandb
import time
from flyteplugins.wandb import wandb_init

@wandb_init
def objective():
    """Objective function for W&B sweep - trains a model with hyperparameters."""
    run = wandb.run
    config = run.config

    print(f"Training with lr={config.learning_rate}, batch_size={config.batch_size}")

    # Simulate training loop
    for epoch in range(config.epochs):
        # Simulate training metrics
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        accuracy = min(0.95, config.learning_rate * config.batch_size * (epoch + 1) * 0.01)

        run.log({
            "epoch": epoch,
            "loss": loss,
            "accuracy": accuracy,
        })
        time.sleep(0.5)

The @wandb_init decorator ensures that wandb.init() is called correctly within the Flyte task, using the configuration provided by the sweep controller.
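
To see what the simulated objective logs without connecting to W&B, the same formulas can be run against a stand-in config. This is a minimal sketch: simulate_metrics and example_config are hypothetical names, and SimpleNamespace only imitates the attribute access of run.config.

```python
from types import SimpleNamespace


def simulate_metrics(config):
    """Return the per-epoch metrics that the simulated objective would log."""
    rows = []
    for epoch in range(config.epochs):
        scale = config.learning_rate * config.batch_size
        loss = 1.0 / scale + epoch * 0.1
        accuracy = min(0.95, scale * (epoch + 1) * 0.01)
        rows.append({"epoch": epoch, "loss": loss, "accuracy": accuracy})
    return rows


# Stand-in for run.config; the values are arbitrary examples.
example_config = SimpleNamespace(learning_rate=0.01, batch_size=32, epochs=3)

if __name__ == "__main__":
    for row in simulate_metrics(example_config):
        print(row)
```

In this toy formula, a larger product of learning_rate and batch_size lowers the loss, so a sweep that minimizes loss should drift toward the upper end of both ranges.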

Step 2: Create the Sweep Agent Task

A sweep agent is a Flyte task that pulls trials from the W&B cloud controller and executes the objective function. When you apply the @wandb_sweep decorator, Flyte automatically adds a "Weights & Biases Sweep" link to the task in the UI.

import flyte
import wandb
from flyteplugins.wandb import wandb_sweep, get_wandb_context

# Define the task environment with the necessary secret
env = flyte.TaskEnvironment(
    name="wandb-sweep",
    image=flyte.Image.from_debian_base(name="wandb-sweep").with_pip_packages("flyteplugins-wandb"),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)

@wandb_sweep
@env.task
async def sweep_agent(agent_id: int, sweep_id: str, count: int = 5) -> int:
    """
    Single sweep agent that pulls trials from W&B cloud controller.
    """
    print(f"[Agent {agent_id}] Starting agent for sweep {sweep_id}")

    # Run the agent: execute up to `count` trials of the objective from Step 1
    wandb.agent(
        sweep_id,
        function=objective,
        count=count,
        project=get_wandb_context().project
    )

    return agent_id

Step 3: Orchestrate Parallel Agents

To run trials in parallel, create an orchestrator task that generates the sweep and launches multiple agents. Use get_wandb_sweep_id() to retrieve the ID of the sweep created by the decorator.

import asyncio
from datetime import timedelta
from flyteplugins.wandb import wandb_sweep, get_wandb_sweep_id

@wandb_sweep
@env.task
async def run_parallel_sweep(total_trials: int = 15, trials_per_agent: int = 5) -> str:
    # Retrieve the sweep ID created by the @wandb_sweep decorator
    sweep_id = get_wandb_sweep_id()

    # Calculate how many agents to launch
    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent

    # Launch multiple agents in parallel as separate Flyte tasks
    agent_tasks = [
        sweep_agent.override(
            resources=flyte.Resources(cpu="1", memory="2Gi"),
            retries=2,
            timeout=timedelta(minutes=30),
        )(agent_id=i + 1, sweep_id=sweep_id, count=trials_per_agent)
        for i in range(num_agents)
    ]

    # Wait for all agents to complete
    await asyncio.gather(*agent_tasks)

    return sweep_id
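
The agent count above is a ceiling division, and every agent receives count=trials_per_agent. When total_trials is not a multiple of trials_per_agent, the agents together may therefore run more than total_trials trials. The sketch below shows one possible way to split the trials exactly and fan them out with asyncio.gather. It is an illustration under stated assumptions: split_trials, fake_agent, and run_all are hypothetical names, and fake_agent stands in for the real sweep_agent task.

```python
import asyncio
from typing import List


def split_trials(total_trials: int, trials_per_agent: int) -> List[int]:
    """Split total_trials into per-agent counts of at most trials_per_agent each."""
    if total_trials < 0 or trials_per_agent <= 0:
        raise ValueError("total_trials must be >= 0 and trials_per_agent must be > 0")
    # Ceiling division: the same formula the orchestrator uses for num_agents.
    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent
    counts = [trials_per_agent] * num_agents
    remainder = total_trials % trials_per_agent
    if remainder:
        counts[-1] = remainder  # the last agent picks up only the leftover trials
    return counts


async def fake_agent(agent_id: int, count: int) -> int:
    """Stand-in for sweep_agent: pretends to run `count` trials and reports how many."""
    await asyncio.sleep(0)  # yield to the event loop, as a remote call would
    return count


async def run_all(total_trials: int, trials_per_agent: int) -> int:
    """Launch one stand-in agent per count and return the total trials run."""
    counts = split_trials(total_trials, trials_per_agent)
    results = await asyncio.gather(
        *(fake_agent(i + 1, c) for i, c in enumerate(counts))
    )
    return sum(results)


if __name__ == "__main__":
    print(split_trials(11, 5))
    print(asyncio.run(run_all(11, 5)))
```

With 11 trials and 5 per agent, split_trials yields [5, 5, 1], whereas passing count=5 to all three agents would allow up to 15 trials.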

Step 4: Configure and Execute the Sweep

Finally, define the sweep configuration (search method, metrics, and parameter ranges) and execute the orchestrator. Use wandb_sweep_config and flyte.with_runcontext to pass this configuration into the Flyte execution.

from flyteplugins.wandb import wandb_config, wandb_sweep_config

if __name__ == "__main__":
    flyte.init_from_config()

    # Define the sweep and project configuration
    run = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-flyte-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64]},
                    "epochs": {"value": 5},
                },
            ),
        },
    ).run(run_parallel_sweep, total_trials=10, trials_per_agent=2)

    print(f"Sweep Execution URL: {run.url}")
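
To build intuition for what a random search over this parameter space produces, the sketch below draws trial configurations locally. It is only an approximation: the real sampling happens in the W&B sweep controller, sample_trial and PARAMETERS are hypothetical names, and treating a min/max range as a uniform distribution is an assumption made here.

```python
import random
from typing import Any, Dict

# The same search space as in the sweep configuration, written as a plain dict.
PARAMETERS: Dict[str, Dict[str, Any]] = {
    "learning_rate": {"min": 0.0001, "max": 0.1},
    "batch_size": {"values": [16, 32, 64]},
    "epochs": {"value": 5},
}


def sample_trial(parameters: Dict[str, Dict[str, Any]], rng: random.Random) -> Dict[str, Any]:
    """Draw one configuration from the search space.

    "value" is a constant, "values" is a random choice from the list,
    and a "min"/"max" pair is sampled uniformly (an assumption of this sketch).
    """
    trial = {}
    for name, spec in parameters.items():
        if "value" in spec:
            trial[name] = spec["value"]
        elif "values" in spec:
            trial[name] = rng.choice(spec["values"])
        else:
            trial[name] = rng.uniform(spec["min"], spec["max"])
    return trial


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed so repeated runs draw the same trials
    for _ in range(3):
        print(sample_trial(PARAMETERS, rng))
```

Each draw keeps epochs fixed at 5, picks batch_size from the three listed values, and places learning_rate somewhere between 0.0001 and 0.1.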

Step 5: Automatic Log Downloading

To archive the logs for all runs in a sweep back to Flyte's metadata store, set download_logs=True in the @wandb_sweep decorator or in wandb_sweep_config.

@wandb_sweep(download_logs=True)
@env.task
async def run_sweep_with_logs() -> str:
    sweep_id = get_wandb_sweep_id()
    # ... launch agents ...
    return sweep_id

When download_logs is enabled, the plugin will automatically call download_wandb_sweep_logs(sweep_id) after the task completes. This downloads all run files from the sweep and attaches them as a directory output, which can be viewed in the Flyte UI's "Trace" or "Outputs" section.

Summary

By following this tutorial, you have:

  1. Defined a training objective using @wandb_init.
  2. Created a distributed agent task using @wandb_sweep.
  3. Orchestrated parallel execution of agents using asyncio and get_wandb_sweep_id.
  4. Configured hyperparameter ranges using wandb_sweep_config.

The "Weights & Biases Sweep" link now appears in the Flyte UI on each task decorated with @wandb_sweep, providing a direct path to the W&B dashboard for real-time monitoring.