Experiment Tracking with MLflow
Integrate MLflow into your Flyte tasks to track parameters, metrics, and artifacts. This plugin manages the MLflow run lifecycle, supports autologging, and automatically generates UI links in the Flyte console.
Basic Task Tracking
To track a task, use the @mlflow_run decorator. It must be the outermost decorator, placed above @env.task.
```python
import flyte
from flyteplugins.mlflow import mlflow_run, get_mlflow_run

# Task environment that hosts the task (name is illustrative)
env = flyte.TaskEnvironment(name="mlflow-example")

@mlflow_run(
    tracking_uri="http://mlflow.example.com",
    experiment_name="my-experiment",
    tags={"team": "ml-ops"},
)
@env.task
async def train_task(learning_rate: float):
    import mlflow

    # Log manually using the standard MLflow API
    mlflow.log_param("lr", learning_rate)
    mlflow.log_metric("accuracy", 0.95)

    # Access the active run object if needed
    run = get_mlflow_run()
    print(f"Active Run ID: {run.info.run_id}")
```
Automatic Logging (Autologging)
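Set autolog=True to enable MLflow autologging, and use framework to choose the integration (e.g., sklearn, pytorch, tensorflow). The log_models and log_datasets options control whether trained models and input datasets are recorded.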
```python
import flyte
from flyteplugins.mlflow import mlflow_run

# Task environment that hosts the task (name is illustrative)
env = flyte.TaskEnvironment(name="mlflow-example")

@mlflow_run(
    autolog=True,
    framework="sklearn",
    log_models=True,
    log_datasets=False,
)
@env.task
async def train_sklearn_model():
    import numpy as np
    from sklearn.linear_model import LogisticRegression

    X = np.random.randn(100, 4)
    y = (X[:, 0] > 0).astype(int)

    # Autolog captures parameters, metrics, and the model automatically
    model = LogisticRegression()
    model.fit(X, y)
```
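The framework argument presumably selects one of MLflow's own autolog entry points. MLflow provides a generic mlflow.autolog() and per-library functions such as mlflow.sklearn.autolog(). The sketch below shows one way such a mapping could work. It is not the plugin's code: autolog_path and resolve_autolog are hypothetical helpers, and treating a missing framework as the generic call is an assumption.

```python
import importlib
from typing import Callable, Optional

def autolog_path(framework: Optional[str]) -> str:
    """Return the dotted path of the MLflow autolog function to call (hypothetical)."""
    if framework is None:
        # Assumption: no framework means MLflow's generic autolog,
        # which enables every installed integration it supports.
        return "mlflow.autolog"
    # A named framework maps to that flavor's module, e.g. mlflow.sklearn.autolog
    return f"mlflow.{framework}.autolog"

def resolve_autolog(framework: Optional[str]) -> Callable[..., None]:
    """Import the module lazily and return its autolog function (requires mlflow)."""
    module_name, _, attr = autolog_path(framework).rpartition(".")
    return getattr(importlib.import_module(module_name), attr)
```

If MLflow is installed, resolve_autolog("sklearn")(log_models=True, log_datasets=False) would enable scikit-learn autologging with the same options as the task above.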
Configuring Global Tracking
Pass mlflow_config() as the custom_context of flyte.with_runcontext to apply tracking settings to an entire workflow or execution. Tasks decorated with @mlflow_run then pick up these settings, so you do not have to hardcode the tracking URI in every task.
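```python
import flyte
from flyteplugins.mlflow import mlflow_config, mlflow_run

# Task environment that hosts the task (name is illustrative)
env = flyte.TaskEnvironment(name="mlflow-example")

# Top-level task to run; tracking settings come from the run context
@mlflow_run
@env.task
async def my_workflow() -> None:
    ...

# Define global configuration
ml_ctx = mlflow_config(
    tracking_uri="http://mlflow.example.com",
    experiment_name="/shared/experiments/project-a",
    link_host="http://mlflow.example.com",
    link_template="{host}/#/experiments/{experiment_id}/runs/{run_id}",
)

# Run workflow with the context
run = flyte.with_runcontext(custom_context=ml_ctx).run(my_workflow)
```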
Generating MLflow UI Links
The Mlflow link class adds a direct link to the task's MLflow run in the Flyte console. It builds the URL from the link_host and link_template values provided in the configuration.
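```python
import flyte
from flyteplugins.mlflow import Mlflow, mlflow_run

# Task environment that hosts the task (name is illustrative)
env = flyte.TaskEnvironment(name="mlflow-example")

@mlflow_run
@env.task(links=[Mlflow()])
async def tracked_task_with_link():
    import mlflow

    mlflow.log_metric("score", 1.0)
```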
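Resolving the link is essentially string substitution into link_template. The sketch below uses a hypothetical helper, mlflow_run_url, to fill in two templates: the open-source MLflow template from the global configuration example and the Databricks-style template from the troubleshooting notes. The hosts, experiment ID, and run ID are made-up placeholders, and stripping a trailing slash from the host is this sketch's own choice.

```python
OSS_TEMPLATE = "{host}/#/experiments/{experiment_id}/runs/{run_id}"
DATABRICKS_TEMPLATE = "{host}/ml/experiments/{experiment_id}/runs/{run_id}"

def mlflow_run_url(link_template: str, host: str, experiment_id: str, run_id: str) -> str:
    """Build an MLflow UI URL from a link template (hypothetical helper)."""
    # Drop a trailing slash so "{host}/..." never yields a double slash
    return link_template.format(
        host=host.rstrip("/"),
        experiment_id=experiment_id,
        run_id=run_id,
    )

print(mlflow_run_url(OSS_TEMPLATE, "http://mlflow.example.com", "12", "abc123"))
```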
Managing Run Hierarchy
The run_mode parameter in mlflow_run or mlflow_config controls how tasks interact with existing runs:
- auto (default): Reuses the parent's MLflow run if one exists; otherwise, creates a new one.
- new: Always creates a fresh, independent MLflow run.
- nested: Creates a child run under the current parent run.
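```python
import flyte
from flyteplugins.mlflow import mlflow_config, mlflow_run

env = flyte.TaskEnvironment(name="mlflow-example")  # name is illustrative

@mlflow_run(run_mode="new")
@env.task
async def independent_subtask():
    # This task always gets its own, independent MLflow run
    pass

@mlflow_run
@env.task
async def hpo_trial():
    # Illustrative subtask: its run mode comes from the surrounding context
    pass

@mlflow_run
@env.task
async def parent_task():
    await independent_subtask()
    # Tasks called inside this block create child runs under the parent's run
    with mlflow_config(run_mode="nested", tags={"type": "hpo-iteration"}):
        await hpo_trial()
```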
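The small in-memory model below shows the decision each of the three modes makes. It is a hypothetical illustration, not the plugin's implementation. Run and resolve_run are invented names. The behavior of nested when no parent run exists (falling back to a top-level run) is an assumption.

```python
from dataclasses import dataclass
from itertools import count
from typing import Optional

_next_id = count(1)

@dataclass
class Run:
    run_id: str
    parent_id: Optional[str] = None

def _new_run(parent_id: Optional[str] = None) -> Run:
    return Run(run_id=f"run-{next(_next_id)}", parent_id=parent_id)

def resolve_run(mode: str, parent: Optional[Run]) -> Run:
    """Pick the run a task logs to, given the caller's run (if any)."""
    if mode == "auto":
        # Reuse the parent's run when one exists; otherwise start a new run
        return parent if parent is not None else _new_run()
    if mode == "new":
        # Always start a fresh, top-level run, ignoring any parent
        return _new_run()
    if mode == "nested":
        # Start a child run linked to the parent (top-level if there is none)
        return _new_run(parent.run_id if parent is not None else None)
    raise ValueError(f"unknown run_mode: {mode!r}")

root = resolve_run("auto", None)
child = resolve_run("nested", root)
print(root.run_id, child.parent_id)
```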
Distributed Training
The plugin handles distributed training automatically. It reads the process rank from the RANK environment variable and lets only rank 0 log to the MLflow tracking server, which prevents race conditions and duplicate data.
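A rank check of this kind takes only a few lines. The sketch below is illustrative: is_rank_zero is a hypothetical helper, and treating a missing RANK as rank 0 (a single-process run) is an assumption. Launchers such as torchrun set RANK for each worker process.

```python
import os
from typing import Mapping, Optional

def is_rank_zero(environ: Optional[Mapping[str, str]] = None) -> bool:
    """Return True if this process should log to MLflow (rank 0 only)."""
    env = os.environ if environ is None else environ
    # Assumption: a missing RANK means a single-process run, i.e. rank 0
    return int(env.get("RANK", "0")) == 0

print(is_rank_zero({"RANK": "1"}))
```

The plugin already gates MLflow logging this way. You would only need a check like this for your own rank-sensitive side effects, such as writing a checkpoint exactly once.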
Troubleshooting
Decorator Order
The @mlflow_run decorator must be placed above @env.task. If placed below, the MLflow context will not be correctly initialized when the task execution starts.
Correct:
```python
@mlflow_run(...)
@env.task
async def my_task(): ...
```
Incorrect:
```python
@env.task
@mlflow_run(...)
async def my_task(): ...
```
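Python applies stacked decorators from the bottom up. The decorator listed first therefore wraps everything beneath it: its wrapper runs before all the others and finishes after them. The toy example below uses a hypothetical tracing decorator, not part of Flyte or the plugin, to show that ordering.

```python
import functools
from typing import Callable, List

events: List[str] = []

def tracing(name: str) -> Callable:
    """Hypothetical decorator factory that records when its wrapper runs."""
    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            events.append(f"{name}:enter")
            try:
                return fn(*args, **kwargs)
            finally:
                events.append(f"{name}:exit")
        return wrapper
    return decorator

@tracing("outer")  # listed first, applied last, so it wraps everything below
@tracing("inner")
def my_task() -> str:
    events.append("body")
    return "done"

result = my_task()
print(events)
```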
Missing UI Links
If the MLflow link does not appear in the Flyte console:
- Ensure Mlflow() is included in the links list of the @env.task decorator.
- Verify that link_host is set in your mlflow_config.
- If using a custom UI (like Databricks), ensure the link_template matches your platform's URL structure (e.g., {host}/ml/experiments/{experiment_id}/runs/{run_id}).