log_mmm

pymc_marketing.mlflow.log_mmm(mmm, artifact_path='model', registered_model_name=None, extend_idata=False, combined=True, include_last_observations=False, original_scale=True)

Log a PyMC-Marketing MMM as a native MLflow model for the current run.

Parameters:
mmm : MMM

The MMM to be logged.

artifact_path : str, optional

The artifact path, relative to the run's artifact root, under which the model is logged. Defaults to “model”.

registered_model_name : str, optional

Name under which to register the model in the MLflow Model Registry. Defaults to None, in which case the model is logged to the run but not registered.

extend_idata : bool, optional

Whether to extend the inference data with predictions. Used for all prediction methods. Defaults to False.

combined : bool, optional

Whether to combine the chain and draw dimensions into a single sample dimension. This fails if a dimension named sample already exists. Used for posterior/prior predictive sampling. Defaults to True.

include_last_observations : bool, optional

Whether to include the last observations of the training data so that the adstock transformation carries their effect over into the new period. Assumes X contains the periods that immediately follow the training data. Used for all prediction methods. Defaults to False.

original_scale : bool, optional

Whether to return predictions in original scale of target variable. Used for all prediction methods. Defaults to True.
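The toy sketch below illustrates what two of the prediction options control. It is not pymc-marketing's implementation: `geometric_adstock`, `predict_adstocked`, and `combine_chain_draw` are hypothetical helpers, and the adstock here is an unnormalized geometric decay (the library's `GeometricAdstock` may normalize its weights). It shows that, without `include_last_observations`, spend at the end of the training data contributes no carryover to the new period, and that `combined=True` amounts to flattening the `(chain, draw)` axes into one `sample` axis.

```python
import numpy as np


def geometric_adstock(x, alpha, l_max):
    """Hypothetical, unnormalized geometric adstock (illustration only).

    out[t] = sum over k in [0, l_max) of alpha**k * x[t - k]
    """
    x = np.asarray(x, dtype=float)
    out = np.zeros_like(x)
    for k in range(min(l_max, len(x))):
        # Spend from k periods ago, decayed by alpha**k.
        out[k:] += alpha**k * x[: len(x) - k]
    return out


def predict_adstocked(x_train, x_new, alpha, l_max, include_last_observations):
    """Hypothetical helper: adstock a new period, optionally seeded with training history."""
    x_new = np.asarray(x_new, dtype=float)
    if not include_last_observations:
        # The new period starts with zero carryover.
        return geometric_adstock(x_new, alpha, l_max)
    x_train = np.asarray(x_train, dtype=float)
    # Only the last (l_max - 1) training periods can still carry over.
    history = x_train[max(len(x_train) - (l_max - 1), 0):]
    extended = np.concatenate([history, x_new])
    # Adstock the extended series, then drop the prepended history.
    return geometric_adstock(extended, alpha, l_max)[len(history):]


def combine_chain_draw(samples):
    """Hypothetical helper: flatten leading (chain, draw) axes into one sample axis."""
    samples = np.asarray(samples)
    return samples.reshape(-1, *samples.shape[2:])


x_train = [0.0, 0.0, 10.0]  # toy spend: all of it in the last training week
x_new = [0.0, 0.0]          # no new spend in the forecast period
print(predict_adstocked(x_train, x_new, 0.5, 3, include_last_observations=False))  # zero carryover
print(predict_adstocked(x_train, x_new, 0.5, 3, include_last_observations=True))   # 5.0, then 2.5
print(combine_chain_draw(np.zeros((4, 250, 3))).shape)  # (1000, 3)
```

In this toy setup, leaving `include_last_observations=False` drops the carryover from the 10 units spent in the final training week, which understates the forecast whenever spend was recent.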

Notes

This function logs the model as a native MLflow model, which is different from the full model object that includes the InferenceData. Logging it this way allows the model to be stored in the MLflow Model Registry, which helps with model versioning and deployment.
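Once logged, the model can be addressed by URI. The helper below is a hypothetical illustration of MLflow's standard URI schemes, `runs:/<run_id>/<artifact_path>` for the run artifact and `models:/<name>/<version>` for a Model Registry version; `model_uris`, the run ID, and the version number are placeholders, not values returned by `log_mmm`.

```python
def model_uris(run_id, artifact_path="model", registered_model_name=None, version=None):
    """Hypothetical helper: build MLflow URIs for a logged (and optionally registered) model."""
    # Every logged model lives under the run's artifacts at artifact_path.
    uris = {"run": f"runs:/{run_id}/{artifact_path}"}
    # A registered model is addressed by its name plus a registry version number.
    if registered_model_name is not None and version is not None:
        uris["registry"] = f"models:/{registered_model_name}/{version}"
    return uris


print(model_uris("<run_id>", registered_model_name="my_amazing_mmm", version=1))
```

Assuming the MMM is logged in MLflow's pyfunc flavor, either string could then be passed to `mlflow.pyfunc.load_model`; check the version actually assigned in the Model Registry rather than assuming `1`.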

Examples

Registering a PyMC-Marketing MMM with MLflow:

import pandas as pd

import mlflow

from pymc_marketing.mmm import (
    GeometricAdstock,
    LogisticSaturation,
    MMM,
)
import pymc_marketing.mlflow
from pymc_marketing.mlflow import log_mmm

pymc_marketing.mlflow.autolog(log_mmm=True)

# Usual PyMC-Marketing model code

data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
data = pd.read_csv(data_url, parse_dates=["date_week"])

X = data.drop("y", axis=1)
y = data["y"]

mmm = MMM(
    adstock=GeometricAdstock(l_max=8),
    saturation=LogisticSaturation(),
    date_column="date_week",
    channel_columns=["x1", "x2"],
    control_columns=[
        "event_1",
        "event_2",
        "t",
    ],
    yearly_seasonality=2,
)

mlflow.set_experiment("MMM Experiment")

with mlflow.start_run():
    idata = mmm.fit(X, y)

    # Additional specific logging
    fig = mmm.plot_components_contributions()
    mlflow.log_figure(fig, "components.png")

    model_info = log_mmm(
        mmm=mmm,
        registered_model_name="my_amazing_mmm",
        include_last_observations=True,
        original_scale=False,
    )