Experiment Tracking with MLflow

Integrate MLflow into your Flyte tasks to track parameters, metrics, and artifacts. This plugin manages the MLflow run lifecycle, supports autologging, and automatically generates UI links in the Flyte console.

Basic Task Tracking

To track a task, use the @mlflow_run decorator. It must be the outermost decorator, placed above @env.task.

import flyte
from flyteplugins.mlflow import mlflow_run, get_mlflow_run

# Task environment referenced as `env` in the examples on this page (the name is illustrative)
env = flyte.TaskEnvironment(name="mlflow-example")

@mlflow_run(
    tracking_uri="http://mlflow.example.com",
    experiment_name="my-experiment",
    tags={"team": "ml-ops"},
)
@env.task
async def train_task(learning_rate: float):
    import mlflow

    # Log manually using the standard MLflow API
    mlflow.log_param("lr", learning_rate)
    mlflow.log_metric("accuracy", 0.95)

    # Access the active run object if needed
    run = get_mlflow_run()
    print(f"Active Run ID: {run.info.run_id}")

Automatic Logging (Autologging)

Enable autologging by setting autolog=True, and use the framework argument to select the framework to instrument (e.g., sklearn, pytorch, tensorflow).

from flyteplugins.mlflow import mlflow_run

@mlflow_run(
    autolog=True,
    framework="sklearn",
    log_models=True,
    log_datasets=False,
)
@env.task
async def train_sklearn_model():
    from sklearn.linear_model import LogisticRegression
    import numpy as np

    X = np.random.randn(100, 4)
    y = (X[:, 0] > 0).astype(int)

    # Autolog captures parameters, metrics, and the model automatically
    model = LogisticRegression()
    model.fit(X, y)

Configuring Global Tracking

Use mlflow_config() within flyte.with_runcontext to set tracking configurations for an entire workflow or execution. This avoids hardcoding URIs in every task.

from flyteplugins.mlflow import mlflow_config
import flyte

# Define global configuration
ml_ctx = mlflow_config(
    tracking_uri="http://mlflow.example.com",
    experiment_name="/shared/experiments/project-a",
    link_host="http://mlflow.example.com",
    link_template="{host}/#/experiments/{experiment_id}/runs/{run_id}",
)

# Run workflow with the context
run = flyte.with_runcontext(custom_context=ml_ctx).run(my_workflow)

MLflow UI Links

The Mlflow link class adds a direct link to the MLflow UI in the Flyte console. The URL is resolved from the link_host and link_template set in the configuration.

from flyteplugins.mlflow import Mlflow, mlflow_run

@mlflow_run
@env.task(links=[Mlflow()])
async def tracked_task_with_link():
    import mlflow
    mlflow.log_metric("score", 1.0)
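
To make the resolution step concrete, here is a minimal sketch of how a template with {host}, {experiment_id}, and {run_id} placeholders could be expanded. It uses plain Python string formatting. The helper name resolve_link, the trailing-slash handling, and the sample IDs are hypothetical, and the plugin's actual implementation may differ.

```python
def resolve_link(link_host: str, link_template: str,
                 experiment_id: str, run_id: str) -> str:
    """Expand a link template (hypothetical helper, not part of the plugin API)."""
    return link_template.format(
        # Assumption: drop a trailing slash so the result has no "//" after the host
        host=link_host.rstrip("/"),
        experiment_id=experiment_id,
        run_id=run_id,
    )


# Sample IDs are made up for illustration
url = resolve_link(
    link_host="http://mlflow.example.com",
    link_template="{host}/#/experiments/{experiment_id}/runs/{run_id}",
    experiment_id="12",
    run_id="abc123",
)
print(url)
```

Because the template is ordinary text with named placeholders, pointing the link at a differently structured UI only requires changing link_template, not the task code.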

Managing Run Hierarchy

The run_mode parameter in mlflow_run or mlflow_config controls how tasks interact with existing runs:

  • auto (default): Reuses the parent's MLflow run if one exists; otherwise, creates a new one.
  • new: Always creates a fresh, independent MLflow run.
  • nested: Creates a child run under the current parent run.

from flyteplugins.mlflow import mlflow_run, mlflow_config

@mlflow_run(run_mode="new")
@env.task
async def independent_subtask():
    # This task will always have its own unique MLflow run
    pass

@mlflow_run
@env.task
async def parent_task():
    # Create a nested run for a specific block of code
    with mlflow_config(run_mode="nested", tags={"type": "hpo-iteration"}):
        await independent_subtask()
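
The three modes can be summarized as a small decision function. The sketch below is illustrative only: RunDecision and decide_run are hypothetical names, run IDs are random placeholders, and the handling of nested with no parent run (falling back to a top-level run) is an assumption, not documented plugin behavior.

```python
import uuid
from dataclasses import dataclass
from typing import Optional


@dataclass
class RunDecision:
    run_id: str                   # run the task would log to
    parent_run_id: Optional[str]  # set only for nested child runs
    created: bool                 # True if a new run would be created


def decide_run(run_mode: str, active_run_id: Optional[str]) -> RunDecision:
    """Model the auto / new / nested semantics described above (hypothetical)."""
    if run_mode == "auto":
        if active_run_id is not None:
            # Reuse the parent's run
            return RunDecision(active_run_id, None, created=False)
        return RunDecision(uuid.uuid4().hex, None, created=True)
    if run_mode == "new":
        # Always a fresh, independent run, even when a parent run exists
        return RunDecision(uuid.uuid4().hex, None, created=True)
    if run_mode == "nested":
        # Assumption: with no parent run, this becomes a top-level run
        return RunDecision(uuid.uuid4().hex, active_run_id, created=True)
    raise ValueError(f"unknown run_mode: {run_mode!r}")


print(decide_run("auto", "parent-run"))
print(decide_run("nested", "parent-run").parent_run_id)
```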

Distributed Training

The plugin automatically handles distributed training environments. It detects the process rank using the RANK environment variable and ensures that only rank 0 logs to the MLflow tracking server to prevent race conditions and duplicate data.
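
The plugin does this for you, but the gating idea is simple enough to sketch. The example below is a minimal illustration, not the plugin's code: is_primary_rank and log_if_primary are hypothetical helpers, and treating a missing RANK variable as rank 0 (a single-process run) is an assumption.

```python
import os
from typing import Callable, Mapping, Optional


def is_primary_rank(environ: Optional[Mapping[str, str]] = None) -> bool:
    """Return True when this process is rank 0 according to RANK."""
    env = os.environ if environ is None else environ
    # Assumption: no RANK variable means a single-process run, i.e. rank 0
    return int(env.get("RANK", "0")) == 0


def log_if_primary(log_fn: Callable[[str, float], None], key: str, value: float,
                   environ: Optional[Mapping[str, str]] = None) -> bool:
    """Call log_fn only on rank 0; report whether anything was logged."""
    if not is_primary_rank(environ):
        return False
    log_fn(key, value)
    return True


# Simulate two workers with an in-memory "tracking server"
logged = []
record = lambda k, v: logged.append((k, v))
log_if_primary(record, "loss", 0.1, environ={"RANK": "0"})
log_if_primary(record, "loss", 0.1, environ={"RANK": "1"})
print(logged)  # [('loss', 0.1)]
```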

Troubleshooting

Decorator Order

The @mlflow_run decorator must be placed above @env.task. If placed below, the MLflow context will not be correctly initialized when the task execution starts.

Correct:

@mlflow_run(...)
@env.task
async def my_task(): ...

Incorrect:

@env.task
@mlflow_run(...)
async def my_task(): ...
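
Python applies stacked decorators from the bottom up, so the topmost decorator receives whatever the decorator beneath it returns. The toy decorators below are stand-ins, not the plugin's or Flyte's code (ToyTask, toy_task, and toy_tracker are hypothetical names); they only show what each position is handed.

```python
class ToyTask:
    """Stand-in for the object a task decorator returns (hypothetical)."""

    def __init__(self, fn):
        self.fn = fn


def toy_task(fn):
    # Wraps the function in a task object, as a task decorator typically does
    return ToyTask(fn)


received = []


def toy_tracker(obj):
    # Records what kind of object it was handed, then passes it through unchanged
    received.append(type(obj).__name__)
    return obj


@toy_tracker  # outermost: applied last, sees the finished task object
@toy_task
def correct_order(): ...


@toy_task
@toy_tracker  # innermost: applied first, sees only the plain function
def reversed_order(): ...


print(received)  # ['ToyTask', 'function']
```

In the reversed order the tracking decorator never sees the task object, which is consistent with the requirement that @mlflow_run sit above @env.task.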

Missing MLflow Link

If the MLflow link does not appear in the Flyte console:

  1. Ensure Mlflow() is included in the links list of the @env.task decorator.
  2. Verify that link_host is set in your mlflow_config.
  3. If using a custom UI (like Databricks), ensure the link_template matches your platform's URL structure (e.g., {host}/ml/experiments/{experiment_id}/runs/{run_id}).