log_mmm
- pymc_marketing.mlflow.log_mmm(mmm, artifact_path='model', registered_model_name=None, extend_idata=False, combined=True, include_last_observations=False, original_scale=True)
Log a PyMC-Marketing MMM as a native MLflow model for the current run.
- Parameters:
  - mmm : MMM
    The MMM to be logged.
  - artifact_path : str, optional
    The path to the artifact to be logged. Defaults to “model”.
  - conda_env : dict, optional
    A dictionary representation of a Conda environment. Defaults to the default Conda environment.
  - registered_model_name : str, optional
    The name under which to register the model. Defaults to None. If specified, the model is registered under this name; otherwise it is not registered.
  - extend_idata : bool, optional
    Whether to extend the inference data with predictions. Used by all prediction methods. Defaults to False.
  - combined : bool, optional
    Whether to combine the chain and draw dims into a single sample dim. This fails if a dim named sample already exists. Used for posterior/prior predictive sampling. Defaults to True.
  - include_last_observations : bool, optional
    Whether to include the last observations of the training data in the adstock transformation. Assumes X contains the periods immediately following the training data. Used by all prediction methods. Defaults to False.
  - original_scale : bool, optional
    Whether to return predictions in the original scale of the target variable. Used by all prediction methods. Defaults to True.
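To illustrate what include_last_observations guards against, the sketch below uses a simplified, unnormalized geometric adstock in plain Python. It is an assumption made for illustration, not pymc-marketing's implementation (which may, for example, normalize the adstock weights and operate on scaled data). The function names geometric_adstock and adstock_new_period are hypothetical. Without the training history, the first new periods lose the carryover from spend that happened just before them.

```python
# A simplified, unnormalized geometric adstock used only for illustration;
# this is NOT pymc-marketing's implementation.
def geometric_adstock(x: list[float], alpha: float, l_max: int) -> list[float]:
    """Return y[t] = sum over k in [0, l_max) of alpha**k * x[t - k], skipping t - k < 0."""
    out = []
    for t in range(len(x)):
        total = 0.0
        for k in range(l_max):
            if t - k < 0:
                break  # no observations before the start of the series
            total += alpha**k * x[t - k]
        out.append(total)
    return out


def adstock_new_period(
    x_train: list[float],
    x_new: list[float],
    alpha: float,
    l_max: int,
    include_last_observations: bool,
) -> list[float]:
    """Adstock a new period that directly follows the training data (hypothetical helper)."""
    if include_last_observations and l_max > 1:
        # Prepend the last l_max - 1 training observations so their carryover
        # reaches the first new periods, then drop them from the output.
        history = x_train[-(l_max - 1):]
        full = geometric_adstock(history + x_new, alpha, l_max)
        return full[len(history):]
    # Otherwise the new period is treated as if nothing was spent before it.
    return geometric_adstock(x_new, alpha, l_max)


x_train = [10.0, 0.0, 5.0, 0.0]  # spend during training (made-up numbers)
x_new = [0.0, 0.0, 2.0]          # spend in the periods right after training
print(adstock_new_period(x_train, x_new, alpha=0.5, l_max=3, include_last_observations=False))
# [0.0, 0.0, 2.0]
print(adstock_new_period(x_train, x_new, alpha=0.5, l_max=3, include_last_observations=True))
# [1.25, 0.0, 2.0]
```

With the history included, the first new period picks up 1.25 units of carryover from the training spend of 5.0 two periods earlier. Without it, that effect is lost. Only the last l_max - 1 observations matter, because older spend falls outside the adstock window.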
Notes
This function logs the model as a native MLflow model. This differs from the full model object, which also includes the InferenceData. Logging the model this way allows it to be stored in the MLflow Model Registry, which helps with model versioning and deployment.
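Because the model is stored as an MLflow model, it can be addressed through standard MLflow model URIs. These take the form runs:/<run_id>/<artifact_path> for the copy logged in a run, and models:/<name>/<version> for a registered version. The helper names below are hypothetical, and the run ID "abc123" is a placeholder.

```python
from typing import Union


# Hypothetical helpers that only format standard MLflow model URI strings.
def run_model_uri(run_id: str, artifact_path: str = "model") -> str:
    """URI of a model stored as an artifact of a given run."""
    return f"runs:/{run_id}/{artifact_path}"


def registered_model_uri(name: str, version: Union[int, str]) -> str:
    """URI of one version of a model in the MLflow Model Registry."""
    return f"models:/{name}/{version}"


print(run_model_uri("abc123"))                     # runs:/abc123/model
print(registered_model_uri("my_amazing_mmm", 1))   # models:/my_amazing_mmm/1
```

Either string could then be passed to mlflow.pyfunc.load_model, assuming the MMM is logged in MLflow's pyfunc flavor.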
Examples
Registering a PyMC-Marketing MMM with MLflow:
```python
import pandas as pd

import mlflow

from pymc_marketing.mmm import (
    GeometricAdstock,
    LogisticSaturation,
    MMM,
)
import pymc_marketing.mlflow
from pymc_marketing.mlflow import log_mmm

pymc_marketing.mlflow.autolog(log_mmm=True)

# Usual PyMC-Marketing model code
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
data = pd.read_csv(data_url, parse_dates=["date_week"])

X = data.drop("y", axis=1)
y = data["y"]

mmm = MMM(
    adstock=GeometricAdstock(l_max=8),
    saturation=LogisticSaturation(),
    date_column="date_week",
    channel_columns=["x1", "x2"],
    control_columns=[
        "event_1",
        "event_2",
        "t",
    ],
    yearly_seasonality=2,
)

mlflow.set_experiment("MMM Experiment")

with mlflow.start_run():
    idata = mmm.fit(X, y)

    # Additional specific logging
    fig = mmm.plot_components_contributions()
    mlflow.log_figure(fig, "components.png")

    model_info = log_mmm(
        mmm=mmm,
        registered_model_name="my_amazing_mmm",
        include_last_observations=True,
        original_scale=False,
    )
```