MMMWrapper

class pymc_marketing.mlflow.MMMWrapper(model, predict_method='predict', extend_idata=False, combined=True, include_last_observations=False, original_scale=True, var_names=None, **sample_kwargs)

A class to prepare a PyMC Marketing Mix Model (MMM) for logging and registering in MLflow.

This class extends MLflow’s PythonModel to handle prediction tasks using a PyMC-based MMM. It supports three prediction methods: point prediction, posterior predictive sampling, and prior predictive sampling.

Parameters:
model : pymc_marketing.mmm.MMM

The marketing mix model to be registered and used for predictions.

predict_method : str, optional, default=“predict”

The default prediction method to use, such as “predict”, “sample_posterior_predictive”, or “sample_prior_predictive”.
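The way a method name plus extra keyword arguments could select and drive a model method is sketched below. This is a minimal, self-contained illustration, not the library's implementation: `ToyModel`, `ToyWrapper`, `SUPPORTED_METHODS`, and the `random_seed` keyword are hypothetical stand-ins, and the real wrapper operates on a fitted MMM rather than a list of floats.

```python
from typing import Any, Dict, List


class ToyModel:
    """Hypothetical stand-in for a fitted MMM; each method reports how it was called."""

    def predict(self, X: List[float], **kwargs: Any) -> Dict[str, Any]:
        return {"method": "predict", "n_rows": len(X), "kwargs": kwargs}

    def sample_posterior_predictive(self, X: List[float], **kwargs: Any) -> Dict[str, Any]:
        return {"method": "sample_posterior_predictive", "n_rows": len(X), "kwargs": kwargs}

    def sample_prior_predictive(self, X: List[float], **kwargs: Any) -> Dict[str, Any]:
        return {"method": "sample_prior_predictive", "n_rows": len(X), "kwargs": kwargs}


# Hypothetical whitelist mirroring the three method names listed above.
SUPPORTED_METHODS = ("predict", "sample_posterior_predictive", "sample_prior_predictive")


class ToyWrapper:
    """Hypothetical wrapper: picks a model method by name and forwards stored kwargs."""

    def __init__(self, model: ToyModel, predict_method: str = "predict", **sample_kwargs: Any) -> None:
        # Reject unknown names early instead of failing at prediction time.
        if predict_method not in SUPPORTED_METHODS:
            raise ValueError(f"Unsupported predict_method: {predict_method!r}")
        self.model = model
        self.predict_method = predict_method
        self.sample_kwargs = sample_kwargs

    def predict(self, model_input: List[float]) -> Dict[str, Any]:
        # Look up the bound method by name, then forward the stored keyword arguments.
        method = getattr(self.model, self.predict_method)
        return method(model_input, **self.sample_kwargs)


wrapper = ToyWrapper(ToyModel(), predict_method="sample_posterior_predictive", random_seed=42)
result = wrapper.predict([1.0, 2.0, 3.0])
print(result["method"], result["n_rows"], result["kwargs"])
```

Validating the name in the constructor means a typo surfaces when the wrapper is built, before it is logged, rather than when a served model is first queried.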

extend_idata : bool, default=False

Whether to add the predictions to the model’s InferenceData object. Defaults to False.

combined : bool, default=True

Whether to combine the chain and draw dimensions into a single sample dimension. This fails if a dimension named sample already exists. Defaults to True.
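Combining chain and draw into one sample axis can be illustrated with plain NumPy. The sketch below assumes posterior predictive draws laid out as (chain, draw, date); the sizes are arbitrary illustrative values.

```python
import numpy as np

n_chains, n_draws, n_dates = 4, 250, 10
rng = np.random.default_rng(0)

# Illustrative posterior predictive draws with layout (chain, draw, date).
draws = rng.normal(size=(n_chains, n_draws, n_dates))

# Combining flattens chain and draw into one sample axis: (chain * draw, date).
# With NumPy's default C order, row c * n_draws + d holds draws[c, d].
combined = draws.reshape(n_chains * n_draws, n_dates)

print(draws.shape, "->", combined.shape)
```

On labeled data, xarray’s `stack(sample=("chain", "draw"))` performs the equivalent operation while keeping chain and draw as coordinates, which is presumably why an existing dimension named sample conflicts with this option.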

include_last_observations : bool, default=False

Whether to include the last observations of the training data so that the adstock transformation carries over costs from the training period. Assumes that X contains the periods immediately following the training data. Defaults to False.
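Why the trailing training rows matter can be shown with a simple geometric adstock. This is a minimal sketch, assuming weights `alpha**i` for lags `i < l_max` and a prepend-then-trim approach; `geometric_adstock`, `alpha`, `l_max`, and the spend values are illustrative choices, not the library’s implementation or defaults.

```python
import numpy as np


def geometric_adstock(x: np.ndarray, alpha: float, l_max: int) -> np.ndarray:
    """Add alpha**i times the spend from i periods back, for lags i < l_max."""
    out = np.zeros_like(x, dtype=float)
    for t in range(len(x)):
        for i in range(min(l_max, t + 1)):
            out[t] += alpha**i * x[t - i]
    return out


alpha, l_max = 0.5, 4
train_spend = np.array([10.0, 10.0, 10.0, 10.0, 10.0, 10.0])
new_spend = np.array([0.0, 0.0, 5.0])

# Without history, the new periods start "cold": no carryover from training spend.
cold = geometric_adstock(new_spend, alpha, l_max)

# With history, prepend the last l_max - 1 training rows, transform, then drop them.
history = train_spend[-(l_max - 1):]
warm = geometric_adstock(np.concatenate([history, new_spend]), alpha, l_max)[l_max - 1:]

print(cold)
print(warm)
```

The cold result attributes nothing to the first two new periods, while the warm result still credits them with decaying carryover from training spend. That is only meaningful if the new rows really follow the training data in time, which is the assumption stated above.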

original_scale : bool, default=True

Whether to return the predictions in the original scale of the target variable. Defaults to True.
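The arithmetic behind returning to the original scale is sketched below, assuming the target was divided by its maximum absolute value during preprocessing. The scaler choice and all numbers here are illustrative assumptions; check the fitted model’s own preprocessing before relying on this.

```python
import numpy as np

y_train = np.array([120.0, 80.0, 200.0, 150.0])

# Assumed preprocessing: divide the target by its maximum absolute value.
scale = np.max(np.abs(y_train))
y_scaled = y_train / scale

# Suppose the model returns predictions on that scaled axis.
pred_scaled = np.array([0.5, 0.9])

# Returning to the original scale multiplies by the same factor.
pred_original = pred_scaled * scale

print(scale, pred_original)
```

With original_scale set to False, the values would stay on the scaled axis (here, fractions of the largest observed target).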

var_names : list of str, optional, default=None

The variable names to include in the predictions.

sample_kwargs : dict, optional

Additional keyword arguments to pass to the selected sampling method.

Methods

MMMWrapper.__init__(model[, predict_method, ...])

MMMWrapper.load_context(context)

Loads artifacts from the specified PythonModelContext that can be used by predict() when evaluating inputs.

MMMWrapper.predict(context, model_input[, ...])

Perform predictions or sampling using the specified prediction method.

MMMWrapper.predict_stream(context, model_input)

Evaluates a pyfunc-compatible input and produces an iterator of output.

Attributes

predict_type_hints

Internal method to get type hints from the predict function signature.