MMM Example Notebook#
In this notebook we work out a simulated example to showcase the Media Mix Model (MMM) API from pymc-marketing. This package provides a pymc implementation of the MMM presented in the paper Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017). We work with synthetic data as we want to do parameter recovery to better understand the model assumptions. That is, we explicitly set values for our adstock and saturation parameters (see model specification below) and recover them back from the model. The data generation process is an adaptation of the blog post “Media Effect Estimation with PyMC: Adstock, Saturation & Diminishing Returns” by Juan Orduz.
Business Problem#
Before jumping into the data, let’s first define the business problem we are trying to solve. We are a marketing agency and we want to optimize the marketing budget of a client. We have access to the following data:
Sales data: weekly sales of the client.
Media spend data: weekly spend on different media channels (e.g. TV, radio, online, etc.). In this example we consider 2 media channels: \(x_{1}\) and \(x_{2}\).
Domain knowledge:
We know that there has been a positive sales trend which we believe comes from strong economic growth.
We also know that there is a yearly seasonality effect.
In addition, we were informed about two outliers in the data during the weeks 2019-05-13 and 2020-09-14.
Given this information we can draw a Directed Acyclic Graph (DAG) or graphical model of how we believe our variables are related. In other words, it represents how we believe our system is causally related.
Show code cell source
import graphviz as gr
g = gr.Digraph()
g.node(name="Sales", label="Sales", color="deepskyblue", style="filled")
g.node(name="Marketing", label="Marketing", color="deeppink", style="filled")
g.edge(tail_name="Special Events", head_name="Sales")
g.edge(tail_name="Marketing", head_name="Sales")
g.edge(tail_name="Exogenous Variables", head_name="Sales")
g
In this example, we will consider a simple system where:
Marketing: It represents the actions generated by \(x_{1}\) and \(x_{2}\).
Special Events: Outliers on specific days, which are possibly given by special dates.
Exogenous Variables: Variables determined by external factors that are not modeled explicitly (e.g. country economic growth or weather conditions that drive seasonal behavior).
Understanding this ecosystem is essential to create a model that reveals the true causal signals and allows us to optimize our advertising budget. But what do we mean by optimizing the marketing budget? We want to find the optimal media mix that maximizes sales.
Now, given the DAG outlined above, we understand that there is a causal relationship between marketing and sales, but what is the nature of that relationship? In this case, we will assume that this relationship is not linear, for example, a \(10\%\) increase in channel \(x_{1}\) spend does not necessarily translate into a \(10\%\) increase in sales. This can be explained by two phenomena:
On the one hand, there is a carry-over effect: the effect of spend on sales is not instantaneous but accumulates over time.
On the other hand, there is a saturation effect: the effect of spend on sales is not linear but saturates at some point.
The equation implemented to describe the DAG presented above will be the one expressed in Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017), adding a causal assumption around the media effects and their exclusively positive impact. Concretely, given a time series target variable \(y_{t}\) (e.g. sales or conversions), media variables \(x_{m, t}\) (e.g. impressions, clicks or costs) and a set of control covariates \(z_{c, t}\) (e.g. holidays, special events) we consider a linear model of the form

\[
y_{t} = \alpha + \sum_{m=1}^{M}\beta_{m}f(x_{m, t}) + \sum_{c=1}^{C}\gamma_{c}z_{c, t} + \varepsilon_{t},
\]

where \(\alpha\) is the intercept, \(f\) is a media transformation function and \(\varepsilon_{t}\) is the error term which we assume is normally distributed. The function \(f\) encodes the positive media contribution on the target variable. Typically we consider two types of transformation: adstock (carry-over) and saturation effects.
In PyMC-Marketing, we offer an API for a Bayesian Media Mix Model (MMM) with various specifications. In this example, we’ll implement Geometric Adstock and Logistic Saturation as the chosen transformations for our previously discussed structural causal equation.
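To build intuition for these two transformations before using the API, here is a minimal NumPy sketch of geometric adstock and logistic saturation. The helper names geometric_adstock_np and logistic_saturation_np are illustrative only; later in the notebook we use the implementations from pymc_marketing.mmm.transformers.
import numpy as np

def geometric_adstock_np(x: np.ndarray, alpha: float, l_max: int) -> np.ndarray:
    """Weight the spend of the previous l_max periods by alpha**lag (normalized)."""
    weights = alpha ** np.arange(l_max)
    weights = weights / weights.sum()
    return np.array(
        [
            np.sum(weights[: t + 1][::-1] * x[max(0, t - l_max + 1) : t + 1])
            for t in range(len(x))
        ]
    )

def logistic_saturation_np(x: np.ndarray, lam: float) -> np.ndarray:
    """Map (adstocked) spend to (0, 1) with diminishing returns."""
    return (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))

spend = np.array([1.0, 0.0, 0.0, 0.5, 0.0])
print(logistic_saturation_np(geometric_adstock_np(spend, alpha=0.4, l_max=3), lam=4.0))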
Tip
The MMM model in pymc-marketing provides additional features on top of this base model:
Experiment Calibration: We have the option to add empirical experiments (lift tests) to calibrate the model using custom likelihood functions. See Lift Test Calibration.
Time-varying Intercept: Capture time-varying baseline contributions in your model (using modern and efficient Gaussian processes approximation methods). That is, we allow the intercept term \(\alpha = \alpha(t)\) to vary over time. See mmm_tvp_example.
Budget Optimization: Allocate your marketing budget based on the parameters recovered by the model, finding the spend distribution that maximizes the contribution given a limited budget. See Budget Allocation with PyMC-Marketing.
Part I: Data Generation Process#
In Part I of this notebook we focus on the data generating process. That is, we want to construct the target variable \(y_{t}\) (sales) by adding each of the components described in the Business Problem section.
Prepare Notebook#
import warnings
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation
from pymc_marketing.prior import Prior
warnings.filterwarnings("ignore", category=FutureWarning)
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
Generate Data#
1. Date Range#
First we set a time range for our data. We consider a bit more than 2 years of data at weekly granularity.
seed: int = sum(map(ord, "mmm"))
rng: np.random.Generator = np.random.default_rng(seed=seed)
# date range
min_date = pd.to_datetime("2018-04-01")
max_date = pd.to_datetime("2021-09-01")
df = pd.DataFrame(
data={"date_week": pd.date_range(start=min_date, end=max_date, freq="W-MON")}
).assign(
year=lambda x: x["date_week"].dt.year,
month=lambda x: x["date_week"].dt.month,
dayofyear=lambda x: x["date_week"].dt.dayofyear,
)
n = df.shape[0]
print(f"Number of observations: {n}")
Number of observations: 179
2. Media Costs Data#
Now we generate synthetic data from two channels \(x_1\) and \(x_2\). We refer to it as the raw signal, as it is going to be the input for the modeling phase. We expect the contribution of each channel to be different, based on the carryover and saturation parameters.
Raw Signal
# media data
x1 = rng.uniform(low=0.0, high=1.0, size=n)
df["x1"] = np.where(x1 > 0.9, x1, x1 / 2)
x2 = rng.uniform(low=0.0, high=1.0, size=n)
df["x2"] = np.where(x2 > 0.8, x2, 0)
fig, ax = plt.subplots(
nrows=2, ncols=1, figsize=(10, 7), sharex=True, sharey=True, layout="constrained"
)
sns.lineplot(x="date_week", y="x1", data=df, color="C0", ax=ax[0])
sns.lineplot(x="date_week", y="x2", data=df, color="C1", ax=ax[1])
ax[1].set(xlabel="date")
fig.suptitle("Media Costs Data", fontsize=16);

Remark: By design, \(x_{1}\) should resemble a typical paid social channel and \(x_{2}\) an offline (e.g. TV) spend time series.
Effect Signal
Next, we pass the raw signal through the two transformations: first the geometric adstock (carryover effect) and then the logistic saturation. Note that we set the parameters ourselves, but we will recover them back from the model.
Let’s start with the adstock transformation. We set the adstock parameter \(0 < \alpha < 1\) to be \(0.4\) and \(0.2\) for \(x_1\) and \(x_2\) respectively. We set a maximum lag effect of \(8\) weeks.
# apply geometric adstock transformation
alpha1: float = 0.4
alpha2: float = 0.2
df["x1_adstock"] = (
geometric_adstock(x=df["x1"].to_numpy(), alpha=alpha1, l_max=8, normalize=True)
.eval()
.flatten()
)
df["x2_adstock"] = (
geometric_adstock(x=df["x2"].to_numpy(), alpha=alpha2, l_max=8, normalize=True)
.eval()
.flatten()
)
Next, we compose the resulting adstock signals with the logistic saturation function. We set the parameter \(\lambda > 0\) to be \(4\) and \(3\) for the adstocked signals of \(x_1\) and \(x_2\) respectively.
# apply saturation transformation
lam1: float = 4.0
lam2: float = 3.0
df["x1_adstock_saturated"] = logistic_saturation(
x=df["x1_adstock"].to_numpy(), lam=lam1
).eval()
df["x2_adstock_saturated"] = logistic_saturation(
x=df["x2_adstock"].to_numpy(), lam=lam2
).eval()
We can now visualize the effect signal for each channel after each transformation:
fig, ax = plt.subplots(
nrows=3, ncols=2, figsize=(16, 9), sharex=True, sharey=False, layout="constrained"
)
sns.lineplot(x="date_week", y="x1", data=df, color="C0", ax=ax[0, 0])
sns.lineplot(x="date_week", y="x2", data=df, color="C1", ax=ax[0, 1])
sns.lineplot(x="date_week", y="x1_adstock", data=df, color="C0", ax=ax[1, 0])
sns.lineplot(x="date_week", y="x2_adstock", data=df, color="C1", ax=ax[1, 1])
sns.lineplot(x="date_week", y="x1_adstock_saturated", data=df, color="C0", ax=ax[2, 0])
sns.lineplot(x="date_week", y="x2_adstock_saturated", data=df, color="C1", ax=ax[2, 1])
fig.suptitle("Media Costs Data - Transformed", fontsize=16);

3. Trend & Seasonal Components#
Now we add synthetic trend and seasonal components to the effect signal.
df["trend"] = (np.linspace(start=0.0, stop=50, num=n) + 10) ** (1 / 4) - 1
df["cs"] = -np.sin(2 * 2 * np.pi * df["dayofyear"] / 365.5)
df["cc"] = np.cos(1 * 2 * np.pi * df["dayofyear"] / 365.5)
df["seasonality"] = 0.5 * (df["cs"] + df["cc"])
fig, ax = plt.subplots()
sns.lineplot(x="date_week", y="trend", color="C2", label="trend", data=df, ax=ax)
sns.lineplot(
x="date_week", y="seasonality", color="C3", label="seasonality", data=df, ax=ax
)
ax.legend(loc="upper left")
ax.set(title="Trend & Seasonality Components", xlabel="date", ylabel=None);

4. Control Variables#
We add two events where there was a remarkable peak in our target variable. We assume they are independent and not seasonal (e.g. launch of a particular product).
df["event_1"] = (df["date_week"] == "2019-05-13").astype(float)
df["event_2"] = (df["date_week"] == "2020-09-14").astype(float)
5. Target Variable#
Finally, we define the target variable (sales) \(y\). We assume it is a linear combination of the effect signal, the trend and the seasonal components, plus the two events and an intercept. We also add some Gaussian noise.
df["intercept"] = 2.0
df["epsilon"] = rng.normal(loc=0.0, scale=0.25, size=n)
amplitude = 1
beta_1 = 3.0
beta_2 = 2.0
betas = [beta_1, beta_2]
df["y"] = amplitude * (
df["intercept"]
+ df["trend"]
+ df["seasonality"]
+ 1.5 * df["event_1"]
+ 2.5 * df["event_2"]
+ beta_1 * df["x1_adstock_saturated"]
+ beta_2 * df["x2_adstock_saturated"]
+ df["epsilon"]
)
fig, ax = plt.subplots()
sns.lineplot(x="date_week", y="y", color="black", data=df, ax=ax)
ax.set(title="Sales (Target Variable)", xlabel="date", ylabel="y (thousands)");

We can visualize the true component contributions over the historical period:
fig, ax = plt.subplots()
contributions = [
df["intercept"].sum(),
(beta_1 * df["x1_adstock_saturated"]).sum(),
(beta_2 * df["x2_adstock_saturated"]).sum(),
1.5 * df["event_1"].sum(),
2.5 * df["event_2"].sum(),
df["trend"].sum(),
df["seasonality"].sum(),
]
ax.bar(
["intercept", "x1", "x2", "event_1", "event_2", "trend", "seasonality"],
contributions,
color=["C0" if x >= 0 else "C3" for x in contributions],
alpha=0.8,
)
ax.bar_label(
ax.containers[0],
fmt="{:,.2f}",
label_type="edge",
padding=2,
fontsize=15,
fontweight="bold",
)
ax.set(title="Sales Attribution", ylabel="Sales (thousands)");

We would like to recover these values from the model.
6. Media Contribution Interpretation#
From the data generating process we can compute the relative contribution of each channel to the target variable. We will recover these values back from the model.
contribution_share_x1: float = (beta_1 * df["x1_adstock_saturated"]).sum() / (
beta_1 * df["x1_adstock_saturated"] + beta_2 * df["x2_adstock_saturated"]
).sum()
contribution_share_x2: float = (beta_2 * df["x2_adstock_saturated"]).sum() / (
beta_1 * df["x1_adstock_saturated"] + beta_2 * df["x2_adstock_saturated"]
).sum()
print(f"Contribution Share of x1: {contribution_share_x1:.2f}")
print(f"Contribution Share of x2: {contribution_share_x2:.2f}")
Contribution Share of x1: 0.81
Contribution Share of x2: 0.19
We can obtain the contribution plots for each channel where we clearly see the effect of the adstock and saturation transformations.
fig, ax = plt.subplots(
nrows=2, ncols=1, figsize=(12, 8), sharex=True, sharey=False, layout="constrained"
)
for i, x in enumerate(["x1", "x2"]):
sns.scatterplot(
x=df[x],
y=amplitude * betas[i] * df[f"{x}_adstock_saturated"],
color=f"C{i}",
ax=ax[i],
)
ax[i].set(
title=f"$x_{i + 1}$ contribution",
ylabel=f"$\\beta_{i + 1} \\cdot x_{i + 1}$ adstocked & saturated",
xlabel="x",
)

This plot shows some interesting aspects of the media contribution:
The adstock effect is reflected in the non-zero contribution of the channel even when the spend is zero.
One can clearly see the saturation effect, as the contribution growth (slope) decreases as the spend increases.
As we will see in Part II of this notebook, we will recover these plots from the model!
We see that channel \(x_{1}\) has a higher contribution than \(x_{2}\). This could be explained by the fact that there was more spend in channel \(x_{1}\) than in channel \(x_{2}\):
fig, ax = plt.subplots(figsize=(7, 5))
df[["x1", "x2"]].sum().plot(kind="bar", color=["C0", "C1"], ax=ax)
ax.set(title="Total Media Spend", xlabel="Media Channel", ylabel="Costs (thousands)");

However, one usually is not only interested in the contribution itself but rather the Return on Ad Spend (ROAS). That is, the contribution divided by the cost. We can compute the ROAS for each channel as follows:
roas_1 = (amplitude * beta_1 * df["x1_adstock_saturated"]).sum() / df["x1"].sum()
roas_2 = (amplitude * beta_2 * df["x2_adstock_saturated"]).sum() / df["x2"].sum()
fig, ax = plt.subplots(figsize=(7, 5))
(
pd.Series(data=[roas_1, roas_2], index=["x1", "x2"]).plot(
kind="bar", color=["C0", "C1"]
)
)
ax.set(title="ROAS (Approximation)", xlabel="Media Channel", ylabel="ROAS");

That is, channel \(x_{1}\) seems to be more efficient than channel \(x_{2}\).
Note
We recommend reading Section 4.1 in Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017) for a detailed explanation of the ROAS (and mROAS). In particular:
If we transform our target variable \(y\) (e.g. with a log transformation), one needs to be careful with the ROAS computation, as setting the spend to zero does not commute with the transformation (see the small example after this note).
One has to be careful with the adstock effect so that we include a carryover period to fully account for the effect of the spend. The ROAS estimation above is an approximation.
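As a tiny numerical illustration of the first point (not part of the original analysis), differences of log-scale predictions must be mapped back to the original scale before computing the incremental contribution:
import numpy as np

# Hypothetical log-scale predictions with spend on vs. spend set to zero
log_pred_with_spend = np.log(120.0)
log_pred_without_spend = np.log(100.0)

# Correct: take the difference on the original scale
contribution = np.exp(log_pred_with_spend) - np.exp(log_pred_without_spend)  # 20.0

# Incorrect: exponentiating the log-scale difference yields a ratio, not incremental sales
not_a_contribution = np.exp(log_pred_with_spend - log_pred_without_spend)  # 1.2

print(contribution, not_a_contribution)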
7. Data Output#
We of course will not have all of these features in our real data. Let’s filter out the features we will use for modeling:
columns_to_keep = [
"date_week",
"y",
"x1",
"x2",
"event_1",
"event_2",
"dayofyear",
]
data = df[columns_to_keep].copy()
data.head()
|   | date_week | y | x1 | x2 | event_1 | event_2 | dayofyear |
|---|---|---|---|---|---|---|---|
| 0 | 2018-04-02 | 3.984662 | 0.318580 | 0.0 | 0.0 | 0.0 | 92 |
| 1 | 2018-04-09 | 3.762872 | 0.112388 | 0.0 | 0.0 | 0.0 | 99 |
| 2 | 2018-04-16 | 4.466967 | 0.292400 | 0.0 | 0.0 | 0.0 | 106 |
| 3 | 2018-04-23 | 3.864219 | 0.071399 | 0.0 | 0.0 | 0.0 | 113 |
| 4 | 2018-04-30 | 4.441625 | 0.386745 | 0.0 | 0.0 | 0.0 | 120 |
Part II: Modeling#
In this second part, we focus on the modeling process. We will use the data generated in Part I.
1. Feature Engineering#
Assuming we did an EDA and we have a good understanding of the data (we did not do it here as we generated the data ourselves, but please never skip the EDA!), we can start building our model. One thing we immediately see is the seasonality and the trend component. We can generate features ourselves as control variables, for example using a uniformly increasing straight line to model the trend component. In addition, we include dummy variables to encode the event_1 and event_2 contributions.
For the seasonality component we use Fourier modes (similar to Prophet). We do not need to add the Fourier modes by hand, as they are handled by the model API through the yearly_seasonality argument (see below). We use 4 modes for the seasonality component.
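For intuition, here is a minimal sketch of how such yearly Fourier features can be constructed from the date. The helper yearly_fourier_features is hypothetical and purely illustrative; the MMM API generates the corresponding features (sin_1, cos_1, sin_2, cos_2 for yearly_seasonality=2) internally.
def yearly_fourier_features(dayofyear: pd.Series, n_order: int = 2) -> pd.DataFrame:
    """Sine/cosine features for yearly seasonality (2 * n_order columns)."""
    periods = dayofyear / 365.25
    return pd.DataFrame(
        {
            f"{func}_{order}": getattr(np, func)(2 * np.pi * periods * order)
            for order in range(1, n_order + 1)
            for func in ("sin", "cos")
        }
    )

yearly_fourier_features(data["dayofyear"], n_order=2).head()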
# trend feature
data["t"] = range(n)
data.head()
|   | date_week | y | x1 | x2 | event_1 | event_2 | dayofyear | t |
|---|---|---|---|---|---|---|---|---|
| 0 | 2018-04-02 | 3.984662 | 0.318580 | 0.0 | 0.0 | 0.0 | 92 | 0 |
| 1 | 2018-04-09 | 3.762872 | 0.112388 | 0.0 | 0.0 | 0.0 | 99 | 1 |
| 2 | 2018-04-16 | 4.466967 | 0.292400 | 0.0 | 0.0 | 0.0 | 106 | 2 |
| 3 | 2018-04-23 | 3.864219 | 0.071399 | 0.0 | 0.0 | 0.0 | 113 | 3 |
| 4 | 2018-04-30 | 4.441625 | 0.386745 | 0.0 | 0.0 | 0.0 | 120 | 4 |
2. Model Specification#
We can specify the model structure using the MMM class. This class handles a lot of internal boilerplate code for us, such as scaling the data (see details below), and provides handy diagnostics and reporting plots. One great feature is that we can specify the channel prior distributions ourselves, which is a fundamental component of the Bayesian workflow as we can incorporate our prior knowledge into the model. This is one of the most important advantages of using a Bayesian approach. Let’s see how we can do it.
As we do not know much more about the channels, we start with a simple heuristic:
The channel contributions should be positive, so we can, for example, use a HalfNormal distribution as prior. We need to set the sigma parameter per channel: the higher the sigma, the more “freedom” it has to fit the data. To specify sigma we can use the following point.
We expect channels where we spend the most to have more attributed sales, before seeing the data. This is a very reasonable assumption (note that we are not imposing anything at the level of efficiency!).
How do we incorporate this heuristic into the model? To begin with, it is important to note that the MMM class scales the target and input variables through a MaxAbsScaler transformer from scikit-learn, so it is important to specify the priors in the scaled space (i.e. between 0 and 1). One way to do it is to use the spend share as the sigma parameter for the HalfNormal distribution. We can actually add a scaling factor to take into account the support of the distribution.
First, let’s compute the share of spend per channel:
total_spend_per_channel = data[["x1", "x2"]].sum(axis=0)
spend_share = total_spend_per_channel / total_spend_per_channel.sum()
spend_share
x1 0.65632
x2 0.34368
dtype: float64
Next, we specify the sigma parameter per channel:
n_channels = 2
prior_sigma = n_channels * spend_share.to_numpy()
prior_sigma.tolist()
[1.3126390269400678, 0.687360973059932]
The MMM class follows the sklearn convention, so we need to split our data into X (predictors) and y (target value).
X = data.drop("y", axis=1)
y = data["y"]
You can use the optional parameter model_config to apply your own priors to the model. Each key in model_config corresponds to a registered distribution name in the model, and its value is a Prior object describing the parameters of that specific distribution.
If you’re unsure how to define your own priors, you can use the default_model_config property of MMM to see the required structure.
dummy_model = MMM(
date_column="",
channel_columns=[""],
adstock=GeometricAdstock(l_max=4),
saturation=LogisticSaturation(),
)
dummy_model.default_model_config
{'intercept': Prior("Normal", mu=0, sigma=2),
'likelihood': Prior("Normal", sigma=Prior("HalfNormal", sigma=2)),
'gamma_control': Prior("Normal", mu=0, sigma=2, dims="control"),
'gamma_fourier': Prior("Laplace", mu=0, b=1, dims="fourier_mode"),
'adstock_alpha': Prior("Beta", alpha=1, beta=3, dims="channel"),
'saturation_lam': Prior("Gamma", alpha=3, beta=1, dims="channel"),
'saturation_beta': Prior("HalfNormal", sigma=2, dims="channel")}
You can change only the prior parameters that you wish, no need to alter all of them, unless you’d like to!
my_model_config = {
"intercept": Prior("Normal", mu=0.5, sigma=0.2),
"saturation_beta": Prior("HalfNormal", sigma=prior_sigma),
"gamma_control": Prior("Normal", mu=0, sigma=0.05),
"gamma_fourier": Prior("Laplace", mu=0, b=0.2),
"likelihood": Prior("Normal", sigma=Prior("HalfNormal", sigma=6)),
}
Remark: For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the MMM class has some default priors that you can use as a starting point.
The sampler_config argument allows specifying a set of parameters that will be passed to fit, the same way the kwargs are passed. It doesn’t disable the fit kwargs, but rather extends them, enabling a customizable and preservable configuration. By default, the sampler_config for MMM is empty, but if you’d like to use it, you can define it as shown below:
my_sampler_config = {"progressbar": True}
Now we are ready to use the MMM class to define the model.
mmm = MMM(
model_config=my_model_config,
sampler_config=my_sampler_config,
date_column="date_week",
adstock=GeometricAdstock(l_max=8),
saturation=LogisticSaturation(),
channel_columns=["x1", "x2"],
control_columns=["event_1", "event_2", "t"],
yearly_seasonality=2,
)
Observe how the media transformations were handled by the class MMM.
To assess the model prior parameters we can look into the prior predictive plot:
# Generate prior predictive samples
mmm.sample_prior_predictive(X, y, samples=2_000)
fig, ax = plt.subplots()
mmm.plot_prior_predictive(ax=ax, original_scale=True)
ax.legend(loc="lower center", bbox_to_anchor=(0.5, -0.2), ncol=4);
Sampling: [adstock_alpha, gamma_control, gamma_fourier, intercept, saturation_beta, saturation_lam, y, y_sigma]

The prior predictive plot shows that the priors are not too informative.
3. Model Fitting#
We can now fit the model:
Tip
You can use other NUTS samplers to fit the model as one can do with PyMC models. You just need to make sure to have the packages installed in your local environment. See Other NUTS Samplers.
mmm.fit(X=X, y=y, chains=4, target_accept=0.85, nuts_sampler="numpyro", random_seed=rng)
-
<xarray.Dataset> Size: 64MB Dimensions: (chain: 4, draw: 1000, channel: 2, date: 179, control: 3, fourier_mode: 4) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 ... 997 998 999 * channel (channel) <U2 16B 'x1' 'x2' * date (date) datetime64[ns] 1kB 2018-04-02 ...... * control (control) <U7 84B 'event_1' 'event_2' 't' * fourier_mode (fourier_mode) <U5 80B 'sin_1' ... 'cos_2' Data variables: (12/13) adstock_alpha (chain, draw, channel) float64 64kB 0.37... channel_contributions (chain, draw, date, channel) float64 11MB ... control_contributions (chain, draw, date, control) float64 17MB ... fourier_contributions (chain, draw, date, fourier_mode) float64 23MB ... gamma_control (chain, draw, control) float64 96kB 0.17... gamma_fourier (chain, draw, fourier_mode) float64 128kB ... ... ... mu (chain, draw, date) float64 6MB 0.5024 .... saturation_beta (chain, draw, channel) float64 64kB 0.33... saturation_lam (chain, draw, channel) float64 64kB 3.86... total_contributions (chain, draw) float64 32kB 38.0 ... 43.7 y_sigma (chain, draw) float64 32kB 0.03083 ... 0... yearly_seasonality_contribution (chain, draw, date) float64 6MB -0.00084... Attributes: created_at: 2025-01-25T21:40:54.250230+00:00 arviz_version: 0.20.0 inference_library: numpyro inference_library_version: 0.16.1 sampling_time: 13.58799 tuning_steps: 1000
-
<xarray.Dataset> Size: 204kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999 Data variables: acceptance_rate (chain, draw) float64 32kB 0.9917 0.8058 ... 0.9826 0.993 diverging (chain, draw) bool 4kB False False False ... False False energy (chain, draw) float64 32kB -339.1 -339.2 ... -339.1 -335.4 lp (chain, draw) float64 32kB -345.7 -346.1 ... -346.9 -346.1 n_steps (chain, draw) int64 32kB 511 511 511 511 ... 511 511 511 step_size (chain, draw) float64 32kB 0.006979 0.006979 ... 0.006447 tree_depth (chain, draw) int64 32kB 9 9 9 9 9 8 9 10 ... 9 9 9 9 9 9 9 Attributes: created_at: 2025-01-25T21:40:54.264810+00:00 arviz_version: 0.20.0
-
<xarray.Dataset> Size: 32MB Dimensions: (chain: 1, draw: 2000, channel: 2, date: 179, control: 3, fourier_mode: 4) Coordinates: * chain (chain) int64 8B 0 * draw (draw) int64 16kB 0 1 2 ... 1997 1998 1999 * channel (channel) <U2 16B 'x1' 'x2' * date (date) datetime64[ns] 1kB 2018-04-02 ...... * control (control) <U7 84B 'event_1' 'event_2' 't' * fourier_mode (fourier_mode) <U5 80B 'sin_1' ... 'cos_2' Data variables: (12/13) adstock_alpha (chain, draw, channel) float64 32kB 0.55... channel_contributions (chain, draw, date, channel) float64 6MB ... control_contributions (chain, draw, date, control) float64 9MB ... fourier_contributions (chain, draw, date, fourier_mode) float64 11MB ... gamma_control (chain, draw, control) float64 48kB -0.0... gamma_fourier (chain, draw, fourier_mode) float64 64kB ... ... ... mu (chain, draw, date) float64 3MB 1.042 ..... saturation_beta (chain, draw, channel) float64 32kB 0.12... saturation_lam (chain, draw, channel) float64 32kB 3.18... total_contributions (chain, draw) float64 16kB 24.58 ... 9.465 y_sigma (chain, draw) float64 16kB 4.913 ... 4.627 yearly_seasonality_contribution (chain, draw, date) float64 3MB 0.4942 .... Attributes: created_at: 2025-01-25T21:40:38.508011+00:00 arviz_version: 0.20.0 inference_library: pymc inference_library_version: 5.20.0
-
<xarray.Dataset> Size: 3MB Dimensions: (chain: 1, draw: 2000, date: 179) Coordinates: * chain (chain) int64 8B 0 * draw (draw) int64 16kB 0 1 2 3 4 5 6 ... 1994 1995 1996 1997 1998 1999 * date (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30 Data variables: y (chain, draw, date) float64 3MB 10.64 1.141 ... 0.5361 -2.597 Attributes: created_at: 2025-01-25T21:40:38.511725+00:00 arviz_version: 0.20.0 inference_library: pymc inference_library_version: 5.20.0
-
<xarray.Dataset> Size: 3kB Dimensions: (date: 179) Coordinates: * date (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30 Data variables: y (date) float64 1kB 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625 Attributes: created_at: 2025-01-25T21:40:54.265945+00:00 arviz_version: 0.20.0 inference_library: numpyro inference_library_version: 0.16.1 sampling_time: 13.58799 tuning_steps: 1000
-
<xarray.Dataset> Size: 9kB Dimensions: (date: 179, channel: 2, control: 3) Coordinates: * date (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30 * channel (channel) <U2 16B 'x1' 'x2' * control (control) <U7 84B 'event_1' 'event_2' 't' Data variables: channel_data (date, channel) float64 3kB 0.3196 0.0 0.1128 ... 0.4403 0.0 control_data (date, control) float64 4kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0 dayofyear (date) int32 716B 92 99 106 113 120 ... 214 221 228 235 242 Attributes: created_at: 2025-01-25T21:40:54.267940+00:00 arviz_version: 0.20.0 inference_library: numpyro inference_library_version: 0.16.1 sampling_time: 13.58799 tuning_steps: 1000
-
<xarray.Dataset> Size: 12kB Dimensions: (index: 179) Coordinates: * index (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178 Data variables: date_week (index) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30 x1 (index) float64 1kB 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389 x2 (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.8633 0.0 0.0 0.0 event_1 (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 event_2 (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 dayofyear (index) int32 716B 92 99 106 113 120 127 ... 214 221 228 235 242 t (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178 y (index) float64 1kB 3.985 3.763 4.467 3.864 ... 4.138 4.479 4.676
You can access the pymc model as mmm.model.
type(mmm.model)
pymc.model.core.Model
print(f"Model was train using the {mmm.saturation.__class__.__name__} function")
print(f"and the {mmm.adstock.__class__.__name__} function")
Model was trained using the LogisticSaturation function
and the GeometricAdstock function
We can easily see the explicit model structure:
mmm.graphviz()
Note: You may notice that the graph here is an explicit version of our initial drawing (DAG), where we can now explicitly see all the different components that were included in each node, including their dimensionality. This graph is another way of looking at the same causal assumptions made during the construction of the Bayesian generative model.
4. Model Diagnostics#
A good place to start is checking whether the model had any divergences:
# Number of diverging samples
mmm.idata["sample_stats"]["diverging"].sum().item()
0
We got none! 🙌
The fit_result attribute contains the pymc trace object.
mmm.fit_result
<xarray.Dataset> Size: 64MB Dimensions: (chain: 4, draw: 1000, channel: 2, date: 179, control: 3, fourier_mode: 4) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 ... 997 998 999 * channel (channel) <U2 16B 'x1' 'x2' * date (date) datetime64[ns] 1kB 2018-04-02 ...... * control (control) <U7 84B 'event_1' 'event_2' 't' * fourier_mode (fourier_mode) <U5 80B 'sin_1' ... 'cos_2' Data variables: (12/13) adstock_alpha (chain, draw, channel) float64 64kB 0.37... channel_contributions (chain, draw, date, channel) float64 11MB ... control_contributions (chain, draw, date, control) float64 17MB ... fourier_contributions (chain, draw, date, fourier_mode) float64 23MB ... gamma_control (chain, draw, control) float64 96kB 0.17... gamma_fourier (chain, draw, fourier_mode) float64 128kB ... ... ... mu (chain, draw, date) float64 6MB 0.5024 .... saturation_beta (chain, draw, channel) float64 64kB 0.33... saturation_lam (chain, draw, channel) float64 64kB 3.86... total_contributions (chain, draw) float64 32kB 38.0 ... 43.7 y_sigma (chain, draw) float64 32kB 0.03083 ... 0... yearly_seasonality_contribution (chain, draw, date) float64 6MB -0.00084... Attributes: created_at: 2025-01-25T21:40:54.250230+00:00 arviz_version: 0.20.0 inference_library: numpyro inference_library_version: 0.16.1 sampling_time: 13.58799 tuning_steps: 1000
We can therefore use all the pymc machinery to run model diagnostics. First, let’s see the summary of the trace:
az.summary(
data=mmm.fit_result,
var_names=[
"intercept",
"y_sigma",
"saturation_beta",
"saturation_lam",
"adstock_alpha",
"gamma_control",
"gamma_fourier",
],
)
|   | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| intercept | 0.355 | 0.013 | 0.330 | 0.380 | 0.000 | 0.000 | 2585.0 | 2711.0 | 1.0 |
| y_sigma | 0.031 | 0.002 | 0.028 | 0.035 | 0.000 | 0.000 | 3182.0 | 2862.0 | 1.0 |
| saturation_beta[x1] | 0.362 | 0.020 | 0.326 | 0.402 | 0.000 | 0.000 | 2218.0 | 2375.0 | 1.0 |
| saturation_beta[x2] | 0.270 | 0.083 | 0.193 | 0.396 | 0.003 | 0.002 | 1404.0 | 1101.0 | 1.0 |
| saturation_lam[x1] | 3.952 | 0.379 | 3.232 | 4.637 | 0.007 | 0.005 | 2566.0 | 2270.0 | 1.0 |
| saturation_lam[x2] | 3.140 | 1.188 | 1.074 | 5.356 | 0.031 | 0.022 | 1378.0 | 1134.0 | 1.0 |
| adstock_alpha[x1] | 0.402 | 0.031 | 0.341 | 0.458 | 0.001 | 0.000 | 2582.0 | 2532.0 | 1.0 |
| adstock_alpha[x2] | 0.188 | 0.041 | 0.117 | 0.271 | 0.001 | 0.001 | 1820.0 | 1833.0 | 1.0 |
| gamma_control[event_1] | 0.176 | 0.028 | 0.123 | 0.226 | 0.000 | 0.000 | 3408.0 | 2689.0 | 1.0 |
| gamma_control[event_2] | 0.231 | 0.028 | 0.178 | 0.282 | 0.000 | 0.000 | 3310.0 | 2774.0 | 1.0 |
| gamma_control[t] | 0.001 | 0.000 | 0.001 | 0.001 | 0.000 | 0.000 | 3057.0 | 3034.0 | 1.0 |
| gamma_fourier[sin_1] | 0.003 | 0.003 | -0.004 | 0.010 | 0.000 | 0.000 | 5758.0 | 2452.0 | 1.0 |
| gamma_fourier[sin_2] | -0.058 | 0.004 | -0.064 | -0.051 | 0.000 | 0.000 | 5624.0 | 2990.0 | 1.0 |
| gamma_fourier[cos_1] | 0.062 | 0.003 | 0.057 | 0.069 | 0.000 | 0.000 | 5881.0 | 2848.0 | 1.0 |
| gamma_fourier[cos_2] | 0.001 | 0.003 | -0.005 | 0.008 | 0.000 | 0.000 | 5009.0 | 3071.0 | 1.0 |
Observe that the estimated parameters for \(\alpha\) and \(\lambda\) are very close to the ones we set in the data generation process! Let’s plot the trace for the parameters:
_ = az.plot_trace(
data=mmm.fit_result,
var_names=[
"intercept",
"y_sigma",
"saturation_beta",
"saturation_lam",
"adstock_alpha",
"gamma_control",
"gamma_fourier",
],
compact=True,
backend_kwargs={"figsize": (12, 10), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=16);

Now we sample from the posterior predictive distribution:
mmm.sample_posterior_predictive(X, extend_idata=True, combined=True)
Sampling: [y]
<xarray.Dataset> Size: 6MB Dimensions: (sample: 4000, date: 179) Coordinates: * date (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30 * sample (sample) object 32kB MultiIndex * chain (sample) int64 32kB 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 * draw (sample) int64 32kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: y (date, sample) float64 6MB 3.879 4.181 3.995 ... 4.974 5.116 4.946 Attributes: created_at: 2025-01-25T21:40:56.469125+00:00 arviz_version: 0.20.0 inference_library: pymc inference_library_version: 5.20.0
We can now plot the posterior predictive distribution for the target variable. By default, the plot_posterior_predictive method will plot the mean prediction along with a 94% and 50% HDI.
mmm.plot_posterior_predictive(original_scale=True);

But you can also remove the mean and HDI and add a gradient, which shows the range of the posterior predictive distribution.
mmm.plot_posterior_predictive(add_mean=False, add_gradient=True);

The fit looks very good (as expected)!
We can inspect the model errors:
mmm.plot_errors(original_scale=True);

We can actually extract the whole error posterior distribution for custom error analyses:
errors = mmm.get_errors(original_scale=True)
fig, ax = plt.subplots(figsize=(8, 6))
az.plot_dist(
errors, quantiles=[0.25, 0.5, 0.75], color="C3", fill_kwargs={"alpha": 0.7}, ax=ax
)
ax.axvline(x=0, color="black", linestyle="--", linewidth=1, label="zero")
ax.legend()
ax.set(title="Errors Posterior Distribution");

Next, we can decompose the posterior predictive distribution into the different components:
mmm.plot_components_contributions();

Remark: This plot shows the decomposition of the normalized target variable, obtained by dividing by its maximum value. Do not forget that internally we are scaling the variables to make the model sample more efficiently. You can recover the transformations from the API methods, e.g.
mmm.get_target_transformer()
Pipeline(steps=[('scaler', MaxAbsScaler())])
We plot in the original scale by simply passing the original_scale=True argument:
mmm.plot_components_contributions(original_scale=True);

A similar decomposition can be achieved using an area plot:
groups = {
"Base": [
"intercept",
"event_1",
"event_2",
"t",
"yearly_seasonality",
],
"Channel 1": ["x1"],
"Channel 2": ["x2"],
}
fig = mmm.plot_grouped_contribution_breakdown_over_time(
stack_groups=groups,
original_scale=True,
area_kwargs={
"color": {
"Channel 1": "C0",
"Channel 2": "C1",
"Base": "gray",
"Seasonality": "black",
},
"alpha": 0.7,
},
)
fig.suptitle("Contribution Breakdown over Time", fontsize=16);

Note that this only works if the contributions of the channel or control variable are strictly positive.
Next, we look into the absolute historical contributions of each component:
mmm.plot_waterfall_components_decomposition();

Note that we have recovered the true values for all the parameters! Well, in fact the contributions of the intercept and t are not exactly the same as in the data generating process, but their aggregate does match the true value of intercept + trend. The reason is that the true latent trend is not completely linear. One could use the time-varying intercept feature to capture this effect, as sketched below.
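A minimal sketch of that variant (assuming the time_varying_intercept option of the MMM class; we do not fit it in this notebook) could look like:
mmm_tvp = MMM(
    date_column="date_week",
    channel_columns=["x1", "x2"],
    # the time-varying intercept can absorb the non-linear trend, so we drop "t"
    control_columns=["event_1", "event_2"],
    adstock=GeometricAdstock(l_max=8),
    saturation=LogisticSaturation(),
    yearly_seasonality=2,
    time_varying_intercept=True,  # intercept alpha(t) via a Gaussian process approximation
)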
We can extract the contributions of all the input variables over time, i.e. the regression coefficients times the feature values, as follows:
get_mean_contributions_over_time_df = mmm.compute_mean_contributions_over_time(
original_scale=True
)
get_mean_contributions_over_time_df.head()
| date | x1 | x2 | event_1 | event_2 | t | yearly_seasonality | intercept |
|---|---|---|---|---|---|---|---|
| 2018-04-02 | 1.081818 | 0.0 | 0.0 | 0.0 | 0.000000 | 0.021384 | 2.948147 |
| 2018-04-09 | 0.831961 | 0.0 | 0.0 | 0.0 | 0.005136 | 0.073417 | 2.948147 |
| 2018-04-16 | 1.292488 | 0.0 | 0.0 | 0.0 | 0.010273 | 0.119255 | 2.948147 |
| 2018-04-23 | 0.790950 | 0.0 | 0.0 | 0.0 | 0.015409 | 0.153584 | 2.948147 |
| 2018-04-30 | 1.538755 | 0.0 | 0.0 | 0.0 | 0.020545 | 0.171823 | 2.948147 |
5. Media Parameters#
We can deep-dive into the media transformation parameters. We want to compare the posterior distributions against the true values.
fig = mmm.plot_channel_parameter(param_name="adstock_alpha", figsize=(9, 5))
ax = fig.axes[0]
ax.axvline(x=alpha1, color="C0", linestyle="--", label=r"$\alpha_1$")
ax.axvline(x=alpha2, color="C1", linestyle="--", label=r"$\alpha_2$")
ax.legend(loc="upper right");

fig = mmm.plot_channel_parameter(param_name="saturation_lam", figsize=(9, 5))
ax = fig.axes[0]
ax.axvline(x=lam1, color="C0", linestyle="--", label=r"$\lambda_1$")
ax.axvline(x=lam2, color="C1", linestyle="--", label=r"$\lambda_2$")
ax.legend(loc="upper right");

We indeed see that our media parameters were successfully recovered!
6. Media Deep-Dive#
First we can compute the relative contribution of each channel to the target variable. Note that we recover the true values!
fig = mmm.plot_channel_contribution_share_hdi(figsize=(7, 5))
ax = fig.axes[0]
ax.axvline(
x=contribution_share_x1,
color="C1",
linestyle="--",
label="true contribution share ($x_1$)",
)
ax.axvline(
x=contribution_share_x2,
color="C2",
linestyle="--",
label="true contribution share ($x_2$)",
)
ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=1);

Next, we can plot the relative contribution of each channel to the target variable.
First we plot the direct contribution per channel. Again, we get very close values as the ones obtained in Part I.
fig = mmm.plot_direct_contribution_curves()
[ax.set(xlabel="x") for ax in fig.axes];

Note that trying to get the delayed cumulative contribution is not that easy, as contributions from the past leak into the future. Specifically, note that we apply the saturation function to the aggregated signal; since the saturation function is non-linear, this is not the same as taking the sum of the saturated contributions. Hence, it is very hard to reverse-engineer the contribution after the carryover and saturation composition this way.
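To see why, here is a small check (illustrative only) using the same logistic saturation formula as in the data generating process: saturating an aggregated signal is not the same as aggregating saturated signals.
def saturate(x, lam=4.0):
    # same logistic saturation formula used above
    return (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))

a, b = 0.3, 0.5
print(saturate(a + b))            # ~0.92: saturation applied to the aggregated signal
print(saturate(a) + saturate(b))  # ~1.30: sum of the individually saturated signals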
A more transparent alternative is to evaluate the channel contribution at different spend share levels for the complete training period. Concretely, if we denote by \(\delta\) the input channel data percentage level, so that for \(\delta = 1\) we have the model input spend data and for \(\delta = 1.5\) we have a \(50\%\) increase in the spend, then we can compute the channel contribution at a grid of \(\delta\)-values and plot the results:
mmm.plot_channel_contributions_grid(start=0, stop=1.5, num=12);

This plot does account for carryover (adstock) and saturation effects.
We see that when we have no spend, the contribution is zero (assuming there was no spend in the past, otherwise the carryover effect would be non-zero).
Observe that these grid values serve as inputs for an optimization step.
We can also plot the same contribution using the x-axis as the total channel input (e.g. total spend in EUR).
mmm.plot_channel_contributions_grid(start=0, stop=1.5, num=12, absolute_xrange=True);

7. Contribution Recovery#
Next, we can plot the direct contribution of each channel to the target variable over time.
channels_contribution_original_scale = mmm.compute_channel_contribution_original_scale()
channels_contribution_original_scale_hdi = az.hdi(
ary=channels_contribution_original_scale
)
fig, ax = plt.subplots(
nrows=2, figsize=(15, 8), ncols=1, sharex=True, sharey=False, layout="constrained"
)
for i, x in enumerate(["x1", "x2"]):
# Estimate true contribution in the original scale from the data generating process
sns.lineplot(
x=df["date_week"],
y=amplitude * betas[i] * df[f"{x}_adstock_saturated"],
color="black",
label=f"{x} true contribution",
ax=ax[i],
)
# HDI estimated contribution in the original scale
ax[i].fill_between(
x=df["date_week"],
y1=channels_contribution_original_scale_hdi.sel(channel=x)["x"][:, 0],
y2=channels_contribution_original_scale_hdi.sel(channel=x)["x"][:, 1],
color=f"C{i}",
label=rf"{x} $94\%$ HDI contribution",
alpha=0.4,
)
# Mean estimated contribution in the original scale
sns.lineplot(
x=df["date_week"],
y=get_mean_contributions_over_time_df[x].to_numpy(),
color=f"C{i}",
label=f"{x} posterior mean contribution",
alpha=0.8,
ax=ax[i],
)
ax[i].legend(loc="center left", bbox_to_anchor=(1, 0.5))
ax[i].set(title=f"Channel {x}")

The results look great! We therefore successfully recovered the true values from the data generation process. We have also seen how easy it is to use the MMM class to fit media mix models! It takes over the model specification and the media transformations, while having all the flexibility of pymc!
8. ROAS#
Finally, we can compute the (approximate) ROAS posterior distribution for each channel.
channel_contribution_original_scale = mmm.compute_channel_contribution_original_scale()
spend_sum = X[["x1", "x2"]].sum().to_numpy()
roas_samples = (
channel_contribution_original_scale.sum(dim="date")
/ spend_sum[np.newaxis, np.newaxis, :]
)
fig, axes = plt.subplots(
nrows=2, ncols=1, figsize=(12, 7), sharex=True, sharey=False, layout="constrained"
)
az.plot_posterior(roas_samples, ref_val=[roas_1, roas_2], ax=axes)
axes[0].set(title="Channel $x_{1}$")
axes[1].set(title="Channel $x_{2}$", xlabel="ROAS")
fig.suptitle("ROAS Posterior Distributions", fontsize=18, fontweight="bold", y=1.06);

We see that the ROAS posterior distributions are centered around the true values! We also see that, even considering the uncertainty, channel \(x_{1}\) is more efficient than channel \(x_{2}\).
It is also useful to compare the ROAS and the contribution share. In the next plot we show these two inferred estimates per channel.
# Get the contribution share samples
share_samples = mmm.get_channel_contributions_share_samples()
fig, ax = plt.subplots(figsize=(9, 7))
for i, channel in enumerate(["x1", "x2"]):
# Contribution share mean and hdi
share_mean = share_samples.sel(channel=channel).mean().to_numpy()
share_hdi = az.hdi(share_samples.sel(channel=channel))["x"].to_numpy()
# ROAS mean and hdi
roas_mean = roas_samples.sel(channel=channel).mean().to_numpy()
roas_hdi = az.hdi(roas_samples.sel(channel=channel))["x"].to_numpy()
# Plot the contribution share hdi
ax.vlines(share_mean, roas_hdi[0], roas_hdi[1], color=f"C{i}", alpha=0.8)
# Plot the ROAS hdi
ax.hlines(roas_mean, share_hdi[0], share_hdi[1], color=f"C{i}", alpha=0.8)
# Plot the means
ax.scatter(
share_mean,
roas_mean,
# Size of the scatter points is proportional to the spend share
s=spend_share[channel] * 100,
color=f"C{i}",
label=channel,
)
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{x:.0%}"))
ax.legend(loc="upper left", title="Channel", title_fontsize=12)
ax.set(
title="Channel Contribution Share vs ROAS",
xlabel="Contribution Share",
ylabel="ROAS",
);

This plot is very effective at summarizing channel efficiency. In this example, it turns out that the most efficient channel \(x_1\) also has a higher contribution share than the less efficient channel \(x_2\).
9. Out of Sample Predictions#
Out of sample predictions are done with the predict and sample_posterior_predictive methods. These include:
sample_posterior_predictive: Get the full posterior predictive distribution
predict: Get the mean of the posterior predictive distribution
These methods take new data, X, and some additional kwargs for new predictions. Namely,
include_last_observations: boolean flag in order to carry adstock effects from the last observations in the training dataset
The new data needs to have all the features that are specified in the model. There is no need to worry about:
input scaling of channel spends or control features
creating Fourier transformations on the date_column
inverse scaling back to target domain
That will be done automatically!
last_date = X["date_week"].max()
# New dates starting from last in dataset
n_new = 5
new_dates = pd.date_range(start=last_date, periods=1 + n_new, freq="W-MON")[1:]
X_out_of_sample = pd.DataFrame(
{
"date_week": new_dates,
}
)
# Same channel spends as last day
X_out_of_sample["x1"] = X["x1"].iloc[-1]
X_out_of_sample["x2"] = X["x2"].iloc[-1]
# Other features
X_out_of_sample["event_1"] = 0
X_out_of_sample["event_2"] = 0
X_out_of_sample["t"] = range(len(X), len(X) + n_new)
X_out_of_sample
|   | date_week | x1 | x2 | event_1 | event_2 | t |
|---|---|---|---|---|---|---|
| 0 | 2021-09-06 | 0.438857 | 0.0 | 0 | 0 | 179 |
| 1 | 2021-09-13 | 0.438857 | 0.0 | 0 | 0 | 180 |
| 2 | 2021-09-20 | 0.438857 | 0.0 | 0 | 0 | 181 |
| 3 | 2021-09-27 | 0.438857 | 0.0 | 0 | 0 | 182 |
| 4 | 2021-10-04 | 0.438857 | 0.0 | 0 | 0 | 183 |
Call the desired method to get the new samples! The new coordinates will be taken from the new dates.
X_out_of_sample.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5 entries, 0 to 4
Data columns (total 6 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 date_week 5 non-null datetime64[ns]
1 x1 5 non-null float64
2 x2 5 non-null float64
3 event_1 5 non-null int64
4 event_2 5 non-null int64
5 t 5 non-null int64
dtypes: datetime64[ns](1), float64(2), int64(3)
memory usage: 372.0 bytes
y_out_of_sample = mmm.sample_posterior_predictive(X_out_of_sample, extend_idata=False)
y_out_of_sample
<xarray.Dataset> Size: 256kB Dimensions: (sample: 4000, date: 5) Coordinates: * date (date) datetime64[ns] 40B 2021-09-06 2021-09-13 ... 2021-10-04 * sample (sample) object 32kB MultiIndex * chain (sample) int64 32kB 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3 * draw (sample) int64 32kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: y (date, sample) float64 160kB 5.145 5.333 4.512 ... 5.915 6.603 Attributes: created_at: 2025-01-25T21:41:16.704856+00:00 arviz_version: 0.20.0 inference_library: pymc inference_library_version: 5.20.0
NOTE: If the method is being called multiple times, set the extend_idata argument to False in order to not overwrite the observed_data in the InferenceData.
The new predictions are transformed back to the original scale of the target by default. That can be seen below:
def plot_in_sample(X, y, ax, n_points: int = 15):
(
y.to_frame()
.set_index(X["date_week"])
.iloc[-n_points:]
.plot(ax=ax, marker="o", color="black", label="actuals")
)
return ax
def plot_out_of_sample(X_out_of_sample, y_out_of_sample, ax, color, label):
y_out_of_sample_groupby = y_out_of_sample["y"].to_series().groupby("date")
lower, upper = quantiles = [0.025, 0.975]
conf = y_out_of_sample_groupby.quantile(quantiles).unstack()
ax.fill_between(
X_out_of_sample["date_week"].dt.to_pydatetime(),
conf[lower],
conf[upper],
alpha=0.25,
color=color,
label=f"{label} interval",
)
mean = y_out_of_sample_groupby.mean()
mean.plot(ax=ax, marker="o", label=label, color=color, linestyle="--")
ax.set(ylabel="Original Target Scale", title="Out of sample predictions for MMM")
return ax
_, ax = plt.subplots()
plot_in_sample(X, y, ax=ax)
plot_out_of_sample(
X_out_of_sample, y_out_of_sample, ax=ax, label="out of sample", color="C0"
)
ax.legend(loc="upper left");

If the out of sample data is a direct continuation of the original data, consider setting the include_last_observations argument to True in order to carry over the effects from the last channel spends in the training set.
The predictions are higher since the channel contributions from the final spends still have an impact that eventually subsides.
y_out_of_sample_with_adstock = mmm.sample_posterior_predictive(
X_out_of_sample, extend_idata=False, include_last_observations=True
)
_, ax = plt.subplots()
plot_in_sample(X, y, ax=ax)
plot_out_of_sample(
X_out_of_sample, y_out_of_sample, ax=ax, label="out of sample", color="C0"
)
plot_out_of_sample(
X_out_of_sample,
y_out_of_sample_with_adstock,
ax=ax,
label="adstock out of sample",
color="C1",
)
ax.legend();

Finally we can use the model to understand the expected sales for different media spend scenarios considering the adstock and saturation effects learned from the data.
spends = [0.3, 0.5, 1, 2]
fig, axes = plt.subplots(
nrows=len(spends),
ncols=1,
figsize=(11, 9),
sharex=True,
sharey=True,
layout="constrained",
)
axes = axes.flatten()
for ax, spend in zip(axes, spends, strict=True):
mmm.plot_new_spend_contributions(spend_amount=spend, progressbar=False, ax=ax)
fig.suptitle("New Spend Contribution Simulations", fontsize=18, fontweight="bold");

We clearly see that since \(x_1\) has a higher adstock parameter \(\alpha\) than \(x_2\), for new spend on a single date (i.e. one_time) \(x_1\) has larger delayed contributions than \(x_2\).
10. Save Model#
After your model is trained, you can quickly save it using the save method. For more information about model deployment see Model deployment.
# mmm.save("model.nc")
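If you do run the cell above, a minimal sketch of loading the model back for later use could look like this (assuming the load class method of MMM and the illustrative file name model.nc):
# Load the saved model (including its inference data) and reuse the reporting methods
loaded_mmm = MMM.load("model.nc")
loaded_mmm.plot_posterior_predictive(original_scale=True);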
%load_ext watermark
%watermark -n -u -v -iv -w -p pymc_marketing,pytensor
Last updated: Sat Jan 25 2025
Python implementation: CPython
Python version : 3.12.8
IPython version : 8.31.0
pymc_marketing: 0.10.0
pytensor : 2.26.4
matplotlib : 3.10.0
graphviz : 0.20.3
pymc : 5.20.0
pymc_marketing: 0.10.0
pandas : 2.2.3
numpy : 1.26.4
seaborn : 0.13.2
arviz : 0.20.0
Watermark: 2.5.0