MMMWrapper
- class pymc_marketing.mlflow.MMMWrapper(model, predict_method='predict', extend_idata=False, combined=True, include_last_observations=False, original_scale=True, var_names=None, **sample_kwargs)
A class to prepare a PyMC Marketing Mix Model (MMM) for logging and registering in MLflow.
This class extends MLflow’s PythonModel to handle prediction tasks using a PyMC-based MMM. It supports several prediction methods: point prediction, posterior predictive sampling, and prior predictive sampling.
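The wrapper's core job is to turn one configured method name into a call on the underlying model. The following is a minimal, hypothetical sketch of that dispatch pattern in plain Python, not the pymc_marketing implementation: `FakeMMM` and `DispatchSketch` are illustrative names, the stand-in model only echoes its arguments, and only `sample_kwargs` is forwarded (documented options such as `extend_idata` and `original_scale` are omitted for brevity).

```python
from typing import Any, Callable


class FakeMMM:
    """Hypothetical stand-in for a fitted MMM; each method echoes how it was called."""

    def predict(self, X: Any, **kwargs: Any) -> dict:
        return {"method": "predict", "X": X, "kwargs": kwargs}

    def sample_posterior_predictive(self, X: Any, **kwargs: Any) -> dict:
        return {"method": "sample_posterior_predictive", "X": X, "kwargs": kwargs}

    def sample_prior_predictive(self, X: Any, **kwargs: Any) -> dict:
        return {"method": "sample_prior_predictive", "X": X, "kwargs": kwargs}


# The three method names listed in the class description
SUPPORTED_METHODS = ("predict", "sample_posterior_predictive", "sample_prior_predictive")


class DispatchSketch:
    """Hypothetical sketch: store a method name at construction, resolve it at predict time."""

    def __init__(self, model: Any, predict_method: str = "predict", **sample_kwargs: Any) -> None:
        # Reject unknown names early (a design choice of this sketch)
        if predict_method not in SUPPORTED_METHODS:
            raise ValueError(f"Unsupported predict_method: {predict_method!r}")
        self.model = model
        self.predict_method = predict_method
        self.sample_kwargs = sample_kwargs

    def predict(self, context: Any, model_input: Any) -> Any:
        # context is unused here; MLflow passes it to every predict() call
        method: Callable[..., Any] = getattr(self.model, self.predict_method)
        # Forward the extra keyword arguments captured in __init__
        return method(model_input, **self.sample_kwargs)


# Usage: configure posterior predictive sampling with one extra keyword argument
wrapper = DispatchSketch(FakeMMM(), predict_method="sample_posterior_predictive", random_seed=42)
result = wrapper.predict(None, [1, 2, 3])
print(result["method"])  # sample_posterior_predictive
```

Validating `predict_method` in the constructor means a typo surfaces when the wrapper is built rather than when the logged model is first served; whether MMMWrapper itself validates at construction is not stated on this page.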
- Parameters:
- model (pymc_marketing.mmm.MMM): The marketing mix model to be registered and used for predictions.
- predict_method (str, optional, default="predict"): The default prediction method to use: "predict", "sample_posterior_predictive", or "sample_prior_predictive".
- extend_idata (bool, default=False): Whether to add the predictions to the model's inference data object.
- combined (bool, default=True): Whether to combine the chain and draw dimensions into a single sample dimension. Not supported if a dimension named sample already exists.
- include_last_observations (bool, default=False): Whether to include the last observations of the training data so that the adstock transformation carries over costs from the training period. Assumes that X contains the periods immediately following the training data.
- original_scale (bool, default=True): Whether to return the predictions in the original scale of the target variable.
- var_names (list of str, optional, default=None): The variable names to include in the predictions.
- sample_kwargs (dict, optional): Additional keyword arguments passed to the selected sampling method.
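To see why `include_last_observations` matters, consider a geometric adstock, in which spend in one period keeps contributing with weight `alpha**lag` for up to `l_max` periods. The sketch below assumes that adstock form (the model's configured adstock may differ), and `geometric_adstock` and `adstock_new_periods` are hypothetical helpers, not pymc_marketing functions. Transformed on their own, the new periods receive no carryover from spend at the end of the training data; prepending the last `l_max - 1` training observations restores it.

```python
from typing import List


def geometric_adstock(spend: List[float], alpha: float, l_max: int) -> List[float]:
    """out[t] = sum over lag in [0, l_max) of alpha**lag * spend[t - lag]; periods before the start count as 0."""
    out = []
    for t in range(len(spend)):
        total = 0.0
        for lag in range(l_max):
            if t - lag >= 0:
                total += alpha**lag * spend[t - lag]
        out.append(total)
    return out


def adstock_new_periods(
    train_spend: List[float],
    new_spend: List[float],
    alpha: float,
    l_max: int,
    include_last_observations: bool,
) -> List[float]:
    """Hypothetical helper: adstock only the new periods, optionally seeded with training history."""
    if include_last_observations:
        # Only the last l_max - 1 training periods can still carry over into the new ones
        history = train_spend[-(l_max - 1):] if l_max > 1 else []
        combined = history + new_spend
        # Transform history + new periods together, then keep only the new periods
        return geometric_adstock(combined, alpha, l_max)[len(history):]
    return geometric_adstock(new_spend, alpha, l_max)


train_spend = [100.0, 100.0, 100.0]
new_spend = [0.0, 0.0]  # no spend in the forecast periods
print(adstock_new_periods(train_spend, new_spend, alpha=0.5, l_max=3, include_last_observations=False))  # [0.0, 0.0]
print(adstock_new_periods(train_spend, new_spend, alpha=0.5, l_max=3, include_last_observations=True))  # [75.0, 25.0]
```

With the history included, the new periods match what a single transform over the full training-plus-forecast series would produce, which is why the option assumes X directly follows the training data.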
Methods
- MMMWrapper.__init__(model[, predict_method, ...])
- MMMWrapper.load_context(context): Loads artifacts from the specified PythonModelContext that can be used by predict() when evaluating inputs.
- MMMWrapper.predict(context, model_input[, ...]): Perform predictions or sampling using the specified prediction method.
- MMMWrapper.predict_stream(context, model_input): Evaluates a pyfunc-compatible input and produces an iterator of output.
Attributes
- predict_type_hints: Internal method to get type hints from the predict function signature.