MMM Example Notebook#
In this notebook we work out a simulated example to showcase the Media Mix Model (MMM) API from pymc-marketing. This package provides a pymc implementation of the MMM presented in the paper Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017). We work with synthetic data as we want to do parameter recovery to better understand the model assumptions. That is, we explicitly set values for our adstock and saturation parameters (see model specification below) and recover them back from the model. The data generation process is an adaptation of the blog post «Media Effect Estimation with PyMC: Adstock, Saturation & Diminishing Returns» by Juan Orduz.
Business Problem#
Before diving into the data, let's first define the business problem we are trying to solve. We are a marketing agency and want to optimize a client's marketing budget. We have access to the following data:
Sales data: weekly sales of the client.
Media spend data: weekly spend on different media channels (e.g. TV, radio, online, etc.). In this example we consider 2 media channels: \(x_{1}\) and \(x_{2}\).
Domain knowledge:
We know that there has been a positive sales trend which we believe comes from a strong economic growth.
We also know that there is a yearly seasonality effect.
In addition, we were told that there were two outliers in the data during the weeks 2019-05-13 and 2020-09-14.
With this information we can draw a Directed Acyclic Graph (DAG), or graphical model, of how we believe our variables are related. In other words, we represent how we believe our system is causally related.
In this example, we consider a simple system where:
Marketing: Represents the actions generated by \(x_{1}\) and \(x_{2}\).
Special events: Outliers on specific dates, possibly due to special occasions.
Exogenous variables: Variables determined by external factors and not determined within the model (e.g. the country's economic growth or weather conditions that drive seasonal behavior).
Understanding this ecosystem is essential to create a model that reveals the true causal signals and allows us to optimize our advertising budget. But, what do we mean by optimize the marketing budget? We want to find the optimal media mix that maximizes sales.
Now, given the DAG described above, we understand that there is a causal relationship between marketing and sales, but what is the nature of that relationship? In this case, we assume that the relationship is not linear; for example, a \(10\%\) increase in channel \(x_{1}\) spend does not necessarily translate into a \(10\%\) increase in sales. This can be explained by two phenomena:
On one hand, there is a carry-over effect. Meaning, the effect of spend on sales is not instantaneous but accumulates over time.
On the other hand, there is a saturation effect. That is, the effect of spend on sales is not linear but saturates at some point.
The equation implemented to describe the DAG above is the one presented in Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017), with the additional causal assumption that media effects are exclusively positive. Concretely, given a time series target variable \(y_{t}\) (e.g. sales or conversions), media variables \(x_{m, t}\) (e.g. impressions, clicks or costs) and a set of control covariates \(z_{c, t}\) (e.g. holidays, special events), we consider a linear model of the form
\[y_{t} = \alpha + \sum_{m=1}^{M}\beta_{m}f(x_{m, t}) + \sum_{c=1}^{C}\gamma_{c}z_{c, t} + \varepsilon_{t},\]
where \(\alpha\) is the intercept, \(\beta_{m}\) and \(\gamma_{c}\) are the media and control coefficients, \(f\) is a media transformation function and \(\varepsilon_{t}\) is the error term, which we assume is normally distributed. The function \(f\) encodes the positive contribution of media to the target variable. Typically we consider two types of transformation: adstock (carry-over) and saturation effects.
In PyMC-Marketing, we offer an API for a Bayesian Media Mix Model (MMM) with several specifications. In this example, we implement geometric adstock and logistic saturation as the chosen transformations for the structural causal equation discussed above.
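To make the two transformations concrete, here is a minimal NumPy sketch of a normalized geometric adstock followed by a logistic saturation, applied to a single spend pulse. The helper names and the toy spend series are illustrative assumptions. This is not the pymc-marketing implementation, only the functional form commonly used for these transformations.

```python
import numpy as np


def geometric_adstock_np(x, alpha, l_max):
    """Illustrative geometric adstock: weighted sum of current and past spend."""
    weights = alpha ** np.arange(l_max)  # alpha^0, alpha^1, ..., alpha^(l_max - 1)
    weights = weights / weights.sum()  # normalize so the weights sum to one
    # Full convolution, truncated to the original length (no look-ahead).
    return np.convolve(x, weights)[: len(x)]


def logistic_saturation_np(x, lam):
    """Illustrative logistic saturation: maps [0, inf) to [0, 1)."""
    return (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))


# Hypothetical spend: one pulse in the third week, nothing afterwards.
spend = np.array([0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
adstocked = geometric_adstock_np(spend, alpha=0.4, l_max=8)
effect = logistic_saturation_np(adstocked, lam=4.0)

print(np.round(adstocked, 3))
print(np.round(effect, 3))
```

The adstocked series stays positive in the weeks after the pulse (carry-over), and the saturated effect is bounded above by one no matter how large the spend is.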
Tip
The MMM model in pymc-marketing provides additional features on top of this base model:
Experiment Calibration: We have the option to add empirical experiments (lift tests) to calibrate the model using custom likelihood functions. See Lift Test Calibration.
Time-varying Intercept: Capture time-varying baseline contributions in your model (using modern and efficient Gaussian process approximation methods). That is, we allow the intercept term \(\alpha = \alpha(t)\) to vary over time. See mmm_tvp_example.
Budget Optimization: Allocate your marketing budget based on the parameters recovered by the model, finding the spend distribution that maximizes the contribution given a limited budget. See Budget Allocation with PyMC-Marketing.
References:#
PyMC Labs Blog:
Johns, Michael and Wang, Zhenyu. «A Bayesian Approach to Media Mix Modeling»
Part I: Data Generation Process#
In Part I of this notebook we focus on the data generation process. That is, we want to construct the target variable \(y_{t}\) (sales) by adding each of the components described in the Business Problem section.
Prepare Notebook#
import warnings
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns
from pymc_extras.prior import Prior
from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.multidimensional import MMM
from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation
warnings.filterwarnings("ignore", category=FutureWarning)
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
Generate Data#
1. Date Range#
First we set a time range for our data. We consider about three and a half years of data (179 weeks) at weekly granularity.
seed: int = sum(map(ord, "mmm"))
rng: np.random.Generator = np.random.default_rng(seed=seed)
# date range
min_date = pd.to_datetime("2018-04-01")
max_date = pd.to_datetime("2021-09-01")
df = pd.DataFrame(
data={"date_week": pd.date_range(start=min_date, end=max_date, freq="W-MON")}
).assign(
year=lambda x: x["date_week"].dt.year,
month=lambda x: x["date_week"].dt.month,
dayofyear=lambda x: x["date_week"].dt.dayofyear,
)
n = df.shape[0]
print(f"Number of observations: {n}")
Number of observations: 179
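As a quick sanity check on the length of the series, the following sketch rebuilds the same weekly index and measures the span it covers. It uses the same boundaries and frequency as the cell above; the variable names are illustrative.

```python
import pandas as pd

# Same boundaries and weekly (Monday) frequency as in the cell above.
weeks = pd.date_range(start="2018-04-01", end="2021-09-01", freq="W-MON")

n_weeks = len(weeks)
span_years = (weeks[-1] - weeks[0]).days / 365.25

print(f"First week: {weeks[0].date()}, last week: {weeks[-1].date()}")
print(f"{n_weeks} weekly observations spanning {span_years:.2f} years")
```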
2. Media Costs Data#
Now we generate synthetic data for two channels, \(x_1\) and \(x_2\). We refer to it as the raw signal, as it is going to be the input at the modeling phase. We expect the contribution of each channel to be different, depending on the carry-over and saturation parameters.
Raw Signal
# media data
x1 = rng.uniform(low=0.0, high=1.0, size=n)
df["x1"] = np.where(x1 > 0.9, x1, x1 / 2)
x2 = rng.uniform(low=0.0, high=1.0, size=n)
df["x2"] = np.where(x2 > 0.8, x2, 0)
fig, ax = plt.subplots(
nrows=2, ncols=1, figsize=(10, 7), sharex=True, sharey=True, layout="constrained"
)
sns.lineplot(x="date_week", y="x1", data=df, color="C0", ax=ax[0])
sns.lineplot(x="date_week", y="x2", data=df, color="C1", ax=ax[1])
ax[1].set(xlabel="date")
fig.suptitle("Media Costs Data", fontsize=18, fontweight="bold");
Remark: By design, \(x_{1}\) should resemble a typical paid social channel and \(x_{2}\) an offline (e.g. TV) spend time series.
Effect Signal
Next, we pass the raw signal through the two transformations: first the geometric adstock (carry-over effect) and then the logistic saturation. Note that we set the parameters ourselves, but we will recover them from the model.
Let's start with the adstock transformation. We set the adstock parameter \(0 < \alpha < 1\) to \(0.4\) and \(0.2\) for \(x_1\) and \(x_2\) respectively. We set a maximum lag effect of \(8\) weeks.
# apply geometric adstock transformation
alpha1: float = 0.4
alpha2: float = 0.2
df["x1_adstock"] = geometric_adstock(
x=df["x1"].to_xarray(), alpha=alpha1, l_max=8, normalize=True, dim="index"
).eval()
df["x2_adstock"] = geometric_adstock(
x=df["x2"].to_xarray(), alpha=alpha2, l_max=8, normalize=True, dim="index"
).eval()
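To get a feel for what these values of \(\alpha\) imply, the sketch below computes the lag weights for both channels, assuming normalized geometric weights \(\alpha^{l}\) for lags \(l = 0, \dots, 7\) (the role of `normalize=True`). The helper name is hypothetical and the code is an illustration, not the library internals.

```python
import numpy as np


def normalized_geometric_weights(alpha, l_max):
    """Hypothetical helper: geometric lag weights rescaled to sum to one."""
    weights = alpha ** np.arange(l_max)
    return weights / weights.sum()


w1 = normalized_geometric_weights(alpha=0.4, l_max=8)
w2 = normalized_geometric_weights(alpha=0.2, l_max=8)

# Share of the total effect realized within the first two weeks (lags 0 and 1).
print(f"x1: {w1[:2].sum():.3f}")
print(f"x2: {w2[:2].sum():.3f}")
```

A smaller \(\alpha\) concentrates the effect in the first weeks, so \(x_2\) decays faster than \(x_1\).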
Next, we compose the resulting adstock signals with the logistic saturation function. We set the parameter \(\lambda > 0\) to be \(4\) and \(3\) for \(x_1\) and \(x_2\) respectively.
# apply saturation transformation
lam1: float = 4.0
lam2: float = 3.0
df["x1_adstock_saturated"] = logistic_saturation(
x=df["x1_adstock"].to_xarray(), lam=lam1
).eval()
df["x2_adstock_saturated"] = logistic_saturation(
x=df["x2_adstock"].to_xarray(), lam=lam2
).eval()
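One way to read \(\lambda\) is through the input level at which a channel reaches half of its maximum effect. Assuming the logistic saturation has the form \((1 - e^{-\lambda x}) / (1 + e^{-\lambda x})\), setting it equal to \(0.5\) gives \(x = \ln(3) / \lambda\). The short check below uses a hypothetical scalar helper to confirm this for both channels.

```python
import math


def logistic_saturation_scalar(x, lam):
    """Illustrative scalar version of the logistic saturation curve."""
    return (1 - math.exp(-lam * x)) / (1 + math.exp(-lam * x))


for name, lam in [("x1", 4.0), ("x2", 3.0)]:
    half_point = math.log(3) / lam  # adstocked spend giving a saturated value of 0.5
    value = logistic_saturation_scalar(half_point, lam)
    print(f"{name}: lam={lam}, half-saturation at x={half_point:.3f} (value={value:.2f})")
```

A larger \(\lambda\) means the channel saturates at a lower spend level.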
We can now visualize the effect signal for each channel after each transformation:
fig, ax = plt.subplots(
nrows=3, ncols=2, figsize=(16, 9), sharex=True, sharey=False, layout="constrained"
)
sns.lineplot(x="date_week", y="x1", data=df, color="C0", ax=ax[0, 0])
sns.lineplot(x="date_week", y="x2", data=df, color="C1", ax=ax[0, 1])
sns.lineplot(x="date_week", y="x1_adstock", data=df, color="C0", ax=ax[1, 0])
sns.lineplot(x="date_week", y="x2_adstock", data=df, color="C1", ax=ax[1, 1])
sns.lineplot(x="date_week", y="x1_adstock_saturated", data=df, color="C0", ax=ax[2, 0])
sns.lineplot(x="date_week", y="x2_adstock_saturated", data=df, color="C1", ax=ax[2, 1])
fig.suptitle("Media Costs Data - Transformed", fontsize=18, fontweight="bold");
3. Trend & Seasonal Components#
Now we add synthetic trend and seasonal components to the effect signal.
df["trend"] = (np.linspace(start=0.0, stop=50, num=n) + 10) ** (1 / 4) - 1
df["cs"] = -np.sin(2 * 2 * np.pi * df["dayofyear"] / 365.5)
df["cc"] = np.cos(1 * 2 * np.pi * df["dayofyear"] / 365.5)
df["seasonality"] = 0.5 * (df["cs"] + df["cc"])
fig, ax = plt.subplots()
sns.lineplot(x="date_week", y="trend", color="C2", label="trend", data=df, ax=ax)
sns.lineplot(
x="date_week", y="seasonality", color="C3", label="seasonality", data=df, ax=ax
)
ax.legend(loc="upper left")
ax.set(xlabel="date", ylabel=None)
ax.set_title("Trend & Seasonality Components", fontsize=18, fontweight="bold");
4. Control Variables#
We add two events where there was a noticeable peak in our target variable. We assume they are independent and not seasonal (e.g. the launch of a particular product).
df["event_1"] = (df["date_week"] == "2019-05-13").astype(float)
df["event_2"] = (df["date_week"] == "2020-09-14").astype(float)
5. Target Variable#
Finally, we define the target variable (sales) \(y\). We assume it is a linear combination of the effect signal, the trend and the seasonal components, plus the two events and an intercept. We also add some Gaussian noise.
df["intercept"] = 2.0
df["epsilon"] = rng.normal(loc=0.0, scale=0.25, size=n)
amplitude = 1
beta_1 = 3.0
beta_2 = 2.0
betas = [beta_1, beta_2]
df["y"] = amplitude * (
df["intercept"]
+ df["trend"]
+ df["seasonality"]
+ 1.5 * df["event_1"]
+ 2.5 * df["event_2"]
+ beta_1 * df["x1_adstock_saturated"]
+ beta_2 * df["x2_adstock_saturated"]
+ df["epsilon"]
)
fig, ax = plt.subplots()
sns.lineplot(x="date_week", y="y", color="black", data=df, ax=ax)
ax.set(xlabel="date", ylabel="y (thousands)")
ax.set_title("Sales (Target Variable)", fontsize=18, fontweight="bold");
We can visualize the true component contributions over the historical period:
fig, ax = plt.subplots()
contributions = [
df["intercept"].sum(),
(beta_1 * df["x1_adstock_saturated"]).sum(),
(beta_2 * df["x2_adstock_saturated"]).sum(),
1.5 * df["event_1"].sum(),
2.5 * df["event_2"].sum(),
df["trend"].sum(),
df["seasonality"].sum(),
]
ax.bar(
["intercept", "x1", "x2", "event_1", "event_2", "trend", "seasonality"],
contributions,
color=["C0" if x >= 0 else "C3" for x in contributions],
alpha=0.8,
)
ax.bar_label(
ax.containers[0],
fmt="{:,.2f}",
label_type="edge",
padding=2,
fontsize=15,
fontweight="bold",
)
ax.set(ylabel="Sales (thousands)")
ax.set_title("Sales Attribution", fontsize=18, fontweight="bold");
We would like to recover these values from the model.
6. Media Contribution Interpretation#
From the data generation process we can compute the relative contribution of each channel to the target variable. We will recover these values from the model.
contribution_share_x1: float = (beta_1 * df["x1_adstock_saturated"]).sum() / (
beta_1 * df["x1_adstock_saturated"] + beta_2 * df["x2_adstock_saturated"]
).sum()
contribution_share_x2: float = (beta_2 * df["x2_adstock_saturated"]).sum() / (
beta_1 * df["x1_adstock_saturated"] + beta_2 * df["x2_adstock_saturated"]
).sum()
print(f"Contribution Share of x1: {contribution_share_x1:.2f}")
print(f"Contribution Share of x2: {contribution_share_x2:.2f}")
Contribution Share of x1: 0.81
Contribution Share of x2: 0.19
We can obtain the contribution plots for each channel, where we clearly see the effect of the adstock and saturation transformations.
fig, ax = plt.subplots(
nrows=2, ncols=1, figsize=(12, 8), sharex=True, sharey=False, layout="constrained"
)
for i, x in enumerate(["x1", "x2"]):
sns.scatterplot(
x=df[x],
y=amplitude * betas[i] * df[f"{x}_adstock_saturated"],
color=f"C{i}",
ax=ax[i],
)
ax[i].set(
title=f"$x_{i + 1}$ contribution",
ylabel=f"$\\beta_{i + 1} \\cdot x_{i + 1}$ adstocked & saturated",
xlabel="x",
)
This plot shows some interesting aspects of the media contribution:
The adstock effect is reflected in the non-zero contribution of the channel even when the spend is zero.
One can clearly see the saturation effect as the contribution growth (slope) decreases as the spend increases.
As we will see in Part II of this notebook, we will recover these plots from the model!
We see that channel \(x_{1}\) has a higher contribution than \(x_{2}\). This could be explained by the fact that there was more spend on channel \(x_{1}\) than on channel \(x_{2}\):
fig, ax = plt.subplots(figsize=(7, 5))
df[["x1", "x2"]].sum().plot(kind="bar", color=["C0", "C1"], ax=ax)
ax.set(title="Total Media Spend", xlabel="Media Channel", ylabel="Costs (thousands)");
However, usually one is interested not only in the contribution itself but in the Return on Ad Spend (ROAS), that is, the contribution divided by the cost. We can compute the ROAS for each channel as follows:
roas_1 = (amplitude * beta_1 * df["x1_adstock_saturated"]).sum() / df["x1"].sum()
roas_2 = (amplitude * beta_2 * df["x2_adstock_saturated"]).sum() / df["x2"].sum()
fig, ax = plt.subplots(figsize=(7, 5))
(
pd.Series(data=[roas_1, roas_2], index=["x1", "x2"]).plot(
kind="bar", color=["C0", "C1"]
)
)
ax.set(title="ROAS (Approximation)", xlabel="Media Channel", ylabel="ROAS");
That is, channel \(x_{1}\) seems to be more efficient than channel \(x_{2}\).
Note
We recommend reading Section 4.1 in Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017) for a detailed explanation of the ROAS (and mROAS). In particular:
If we transform our target variable \(y\) (e.g. with a log transformation), one needs to be careful with the ROAS computation, as setting the spend to zero does not commute with the transformation.
One has to be careful with the adstock effect and include a carry-over period that fully accounts for the effect of the spend. The ROAS estimation above is an approximation.
7. Data Output#
Of course, we will not have all of these features in our real data. Let's filter down to the features we will use for modeling:
columns_to_keep = [
"date_week",
"y",
"x1",
"x2",
"event_1",
"event_2",
"dayofyear",
]
data = df[columns_to_keep].copy()
data.head()
| | date_week | y | x1 | x2 | event_1 | event_2 | dayofyear |
|---|---|---|---|---|---|---|---|
| 0 | 2018-04-02 | 3.984662 | 0.318580 | 0.0 | 0.0 | 0.0 | 92 |
| 1 | 2018-04-09 | 3.762872 | 0.112388 | 0.0 | 0.0 | 0.0 | 99 |
| 2 | 2018-04-16 | 4.466967 | 0.292400 | 0.0 | 0.0 | 0.0 | 106 |
| 3 | 2018-04-23 | 3.864219 | 0.071399 | 0.0 | 0.0 | 0.0 | 113 |
| 4 | 2018-04-30 | 4.441625 | 0.386745 | 0.0 | 0.0 | 0.0 | 120 |
Part II: Modeling#
In this second part, we focus on the modeling process. We will use the data generated in Part I.
1. Feature Engineering#
Assuming we did an EDA and have a good understanding of the data (we did not do it here as we generated the data ourselves, but please never skip the EDA!), we can start building our model. One thing we immediately see is the seasonality and the trend component. We can generate features ourselves as control variables, for example using a uniformly increasing straight line to model the trend component. In addition, we include dummy variables to encode the event_1 and event_2 contributions.
For the seasonality component we use Fourier modes (similar to Prophet). We do not need to add the Fourier modes by hand, as they are handled by the model API through the yearly_seasonality argument (see below). We use 2 modes for the seasonality component.
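For intuition, this is roughly what building Fourier modes by hand would look like. The sketch assumes a period of 365.25 days and a sine/cosine pair per order; the helper name, the period and the column ordering are assumptions and may differ from what the model API does internally.

```python
import numpy as np


def fourier_features(dayofyear, n_order, period=365.25):
    """Illustrative Fourier basis: a sine and a cosine column per order."""
    t = np.asarray(dayofyear, dtype=float)
    columns = []
    for k in range(1, n_order + 1):
        angle = 2 * np.pi * k * t / period
        columns.append(np.sin(angle))
        columns.append(np.cos(angle))
    return np.column_stack(columns)


# Day-of-year values of the first five weeks in the data set.
features = fourier_features([92, 99, 106, 113, 120], n_order=2)
print(features.shape)  # two orders give four columns
```

With 2 modes we get four columns, which matches the four `fourier_mode` coefficients that appear in the model output further down.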
# trend feature
data["t"] = range(n)
data.head()
| | date_week | y | x1 | x2 | event_1 | event_2 | dayofyear | t |
|---|---|---|---|---|---|---|---|---|
| 0 | 2018-04-02 | 3.984662 | 0.318580 | 0.0 | 0.0 | 0.0 | 92 | 0 |
| 1 | 2018-04-09 | 3.762872 | 0.112388 | 0.0 | 0.0 | 0.0 | 99 | 1 |
| 2 | 2018-04-16 | 4.466967 | 0.292400 | 0.0 | 0.0 | 0.0 | 106 | 2 |
| 3 | 2018-04-23 | 3.864219 | 0.071399 | 0.0 | 0.0 | 0.0 | 113 | 3 |
| 4 | 2018-04-30 | 4.441625 | 0.386745 | 0.0 | 0.0 | 0.0 | 120 | 4 |
2. Model Specification#
We can specify the model structure using the MMM class. This class handles a lot of internal boilerplate code for us, such as scaling the data (see details below), and provides handy diagnostics and reporting plots. One great feature is that we can specify the channel prior distributions ourselves, which is a fundamental component of the Bayesian workflow as we can incorporate our prior knowledge into the model. This is one of the most important advantages of using a Bayesian approach. Let’s see how we can do it.
As we do not know much more about the channels, we start with a simple heuristic:
The channel contributions should be positive, so we can, for example, use a HalfNormal distribution as prior. We need to set the sigma parameter per channel. The higher the sigma, the more “freedom” the channel has to fit the data. To specify sigma we can use the following point.
Before seeing the data, we expect channels where we spend the most to have more attributed sales. This is a very reasonable assumption (note that we are not imposing anything at the level of efficiency!).
How can we incorporate this heuristic into the model? To begin with, note that the MMM class scales the target and input variables through a MaxAbsScaler transformer from scikit-learn, so it is important to specify the priors in the scaled space (i.e. between 0 and 1). One way to do it is to use the spend share as the sigma parameter of the HalfNormal distribution. We can also add a scaling factor to account for the support of the distribution.
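The following sketch illustrates what max-abs scaling does and why priors need to be expressed in the scaled space. All numbers are hypothetical toy values and the scaling is done by hand with NumPy rather than through the model class.

```python
import numpy as np

# Hypothetical weekly sales and spend for one channel.
y_toy = np.array([3.9, 4.4, 8.0, 5.1])
x_toy = np.array([0.3, 0.0, 0.9, 0.4])

y_scale = np.abs(y_toy).max()
x_scale = np.abs(x_toy).max()

y_scaled = y_toy / y_scale  # largest absolute value becomes 1
x_scaled = x_toy / x_scale

# A contribution estimated in the scaled space maps back by multiplying by y_scale.
contribution_scaled = 0.25
contribution_original = contribution_scaled * y_scale
print(y_scaled.max(), x_scaled.max(), contribution_original)
```

A coefficient of order one in the scaled space therefore corresponds to a contribution of the order of the maximum observed sales in the original space.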
First, let's compute the share of spend per channel:
total_spend_per_channel = data[["x1", "x2"]].sum(axis=0)
spend_share = total_spend_per_channel / total_spend_per_channel.sum()
spend_share
x1 0.65632
x2 0.34368
dtype: float64
Next, we specify the sigma parameter per channel:
n_channels = 2
prior_sigma = n_channels * spend_share.to_numpy()
prior_sigma.tolist()
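To see what this choice implies, recall that a HalfNormal distribution with scale \(\sigma\) has mean \(\sigma\sqrt{2/\pi}\). The sketch below plugs in the spend shares printed above; it is a back-of-the-envelope calculation, not part of the model code.

```python
import math

# Spend shares as printed above.
spend_share = {"x1": 0.65632, "x2": 0.34368}
n_channels = len(spend_share)

for channel, share in spend_share.items():
    sigma = n_channels * share
    prior_mean = sigma * math.sqrt(2 / math.pi)  # mean of a HalfNormal(sigma)
    print(f"{channel}: sigma={sigma:.3f}, prior mean of beta={prior_mean:.3f}")
```

Multiplying by the number of channels keeps the average sigma equal to one, so the prior only redistributes weight between channels according to spend.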
The MMM class follows the scikit-learn convention, so we need to split our data into X (predictors) and y (target value):
X = data.drop("y", axis=1)
y = data["y"]
You can use the optional model_config parameter to apply your own priors to the model. Each entry in model_config has a key that corresponds to a registered distribution name in our model. The value of each key describes the input parameters of that specific distribution (a Prior object in the examples below).
If you are unsure how to define your own priors, you can use the default_model_config property of MMM to see the required structure.
dummy_model = MMM(
date_column="",
channel_columns=[""],
adstock=GeometricAdstock(l_max=4),
saturation=LogisticSaturation(),
)
dummy_model.default_model_config
{'intercept': Prior("Normal", mu=0, sigma=2, dims=()),
'likelihood': Prior("Normal", sigma=Prior("HalfNormal", sigma=2, dims=()), dims="date"),
'gamma_control': Prior("Normal", mu=0, sigma=2, dims="control"),
'gamma_fourier': Prior("Laplace", mu=0, b=1, dims="fourier_mode"),
'adstock_alpha': Prior("Beta", alpha=1, beta=3, dims="channel"),
'saturation_lam': Prior("Gamma", alpha=3, beta=1, dims="channel"),
'saturation_beta': Prior("HalfNormal", sigma=2, dims="channel")}
You can change only the priors you want; there is no need to modify all of them, unless you want to!
my_model_config = {
"intercept": Prior("Normal", mu=0.5, sigma=0.2),
"saturation_beta": Prior("HalfNormal", sigma=prior_sigma, dims="channel"),
"gamma_control": Prior("Normal", mu=0, sigma=0.05, dims="control"),
"gamma_fourier": Prior("Laplace", mu=0, b=0.2, dims="fourier_mode"),
"likelihood": Prior("Normal", sigma=Prior("HalfNormal", sigma=6)),
}
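Conceptually, a partial configuration behaves like a dictionary update on top of the defaults: entries you do not list keep their default prior. The sketch below uses plain strings as stand-ins for Prior objects. How the library merges the two dictionaries internally is an assumption here, so treat this as an illustration of the idea only.

```python
# Stand-ins for the Prior objects, abbreviated from the defaults printed above.
defaults = {
    "intercept": "Normal(mu=0, sigma=2)",
    "adstock_alpha": "Beta(alpha=1, beta=3)",
    "saturation_lam": "Gamma(alpha=3, beta=1)",
    "saturation_beta": "HalfNormal(sigma=2)",
}

# Only the entries we want to change.
overrides = {"intercept": "Normal(mu=0.5, sigma=0.2)"}

# Later keys win, so overrides replace defaults and the rest is kept.
effective = {**defaults, **overrides}

for name, prior in effective.items():
    print(f"{name}: {prior}")
```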
Note: For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the MMM class has some default priors that you can use as a starting point.
The sampler configuration lets you specify a set of parameters that are passed to fit in the same way as keyword arguments. It does not disable the fit kwargs but extends them, which makes the configuration customizable and persistable. By default, the sampler_config for MMM is empty. If you would like to use it, you can define it as shown below:
my_sampler_config = {"progressbar": True}
Now we are ready to use the MMM class to define the model.
mmm = MMM(
model_config=my_model_config,
sampler_config=my_sampler_config,
date_column="date_week",
adstock=GeometricAdstock(l_max=8),
saturation=LogisticSaturation(),
channel_columns=["x1", "x2"],
control_columns=["event_1", "event_2", "t"],
yearly_seasonality=2,
)
# Build the model and add contribution variables in original scale
mmm.build_model(X, y)
mmm.add_original_scale_contribution_variable(
var=[
"channel_contribution",
"control_contribution",
"intercept_contribution",
"yearly_seasonality_contribution",
"y",
]
)
pm.model_to_graphviz(mmm.model)
/Users/juanitorduz/micromamba/envs/pymc-marketing-dev/lib/python3.13/site-packages/pymc_extras/prior.py:822: UserWarning: Implicit conversion of array-like parameter sigma to DataArray with dims ('channel',). Use DataArray with explicit dims to avoid this warning
return _param_value_with_dims(param, value, dims=self.dims)
Note how the MMM class handled the media transformations.
To evaluate the model priors we can look at the prior predictive plot:
# Generate prior predictive samples
mmm.sample_prior_predictive(X, y, samples=2_000)
fig, axes = mmm.plot.prior_predictive()
The prior predictive plot shows that the priors are not too informative.
Note that the prior predictive plot is not in the original scale. The reason is that we handle scaling of the media variables and the target variable in the model class. Scaling is important for the model to sample efficiently. We will go deeper into this topic later. For now, we can show how to reproduce the plot in the original scale:
# Custom plot for prior predictive checks
fig, ax = plt.subplots()
for i, hdi_prob in enumerate([0.94, 0.5]):
az.plot_hdi(
x=mmm.model.coords["date"],
y=mmm.idata["prior"]["y_original_scale"].unstack().transpose(..., "date"),
smooth=False,
color="C0",
hdi_prob=hdi_prob,
fill_kwargs={"alpha": 0.3 + i * 0.1, "label": f"{hdi_prob:.0%} HDI"},
ax=ax,
)
sns.lineplot(data=df, x="date_week", y="y", color="black", label="Observed", ax=ax)
ax.legend(loc="upper left")
ax.set(xlabel="date", ylabel="y")
ax.set_title("Prior Predictive Checks", fontsize=18, fontweight="bold");
3. Model Fitting#
Now we can fit the model:
Tip
You can use other NUTS samplers to fit the model, just as with any PyMC model. You only need to make sure the corresponding packages are installed in your local environment. See Other NUTS Samplers.
%%time
mmm.fit(
X=X,
y=y,
chains=4,
tune=1_500,
draws=1_000,
target_accept=0.9,
random_seed=rng,
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept_contribution, adstock_alpha, saturation_lam, saturation_beta, gamma_control, gamma_fourier, y_sigma]
Sampling 4 chains for 1_500 tune and 1_000 draw iterations (6_000 + 4_000 draws total) took 14 seconds.
CPU times: user 3.32 s, sys: 425 ms, total: 3.74 s
Wall time: 17.2 s
Output (InferenceData, summarized by group):
posterior: chain: 4, draw: 1000, channel: 2, control: 3, fourier_mode: 4, date: 179
sample_stats: chain: 4, draw: 1000
prior: chain: 1, draw: 2000, date: 179, channel: 2, fourier_mode: 4, control: 3
prior_predictive: chain: 1, draw: 2000, date: 179
observed_data: date: 179
constant_data: channel: 2, date: 179, control: 3 (channel_scale: 0.9967, 0.9944; target_scale: 8.312)
fit_data: date_week: 179
Attributes: arviz_version 0.23.4, inference_library pymc 5.28.1, pymc_marketing_version 0.18.2, tuning_steps 1500
You can access the pymc model as mmm.model.
type(mmm.model)
pymc.model.core.Model
print(f"Model was trained using the {mmm.saturation.__class__.__name__} function")
print(f"and the {mmm.adstock.__class__.__name__} function")
Model was trained using the LogisticSaturation function
and the GeometricAdstock function
We can easily see the explicit model structure:
mmm.graphviz()
Note: You may notice that the graph here is an explicit version of our initial drawing (the DAG), where we can now see all the different components included in each node, along with their dimensionality. This graph is another way of looking at the same causal assumptions made when building the Bayesian generative model.
Tip
There is another handy method to get a more detailed summary of the model structure:
mmm.table()
| Variable | Expression | Dimensions |
|---|---|---|
| channel_scale | = Data | channel[2] |
| target_scale | = Data | |
| channel_data | = Data | date[179] × channel[2] |
| target_data | = Data | date[179] |
| control_data | = Data | date[179] × control[3] |
| dayofyear | = Data | date[179] |
| intercept_contribution | ~ Normal(0.5, 0.2) | |
| adstock_alpha | ~ Beta(2, 1, 3) | channel[2] |
| saturation_lam | ~ Gamma(2, 3, f()) | channel[2] |
| saturation_beta | ~ HalfNormal(0, <constant>) | channel[2] |
| gamma_control | ~ Normal(3, 0, 0.05) | control[3] |
| gamma_fourier | ~ Laplace(4, 0, 0.2) | fourier_mode[4] |
| y_sigma | ~ HalfNormal(0, 6) | |
| | Parameter count = 15 | |
| channel_contribution | = f() | date[179] × channel[2] |
| total_media_contribution_original_scale | = f() | |
| control_contribution | = f() | date[179] × control[3] |
| fourier_contribution | = f() | date[179] × fourier_mode[4] |
| yearly_seasonality_contribution | = f() | date[179] |
| channel_contribution_original_scale | = f() | date[179] × channel[2] |
| control_contribution_original_scale | = f() | date[179] × control[3] |
| intercept_contribution_original_scale | = f() | |
| yearly_seasonality_contribution_original_scale | = f() | date[179] |
| y_original_scale | = f() | date[179] |
| y | ~ Normal(f(), f()) | date[179] |
4. Model Diagnostics#
A good place to start assessing model quality is to check whether the model had any divergences:
# Number of diverging samples
mmm.idata["sample_stats"]["diverging"].sum().item()
We got none! 🙌
The fit_result attribute contains the pymc trace object.
mmm.fit_result
<xarray.Dataset> Size: 98MB
Dimensions: (chain: 4, draw: 1000,
channel: 2, control: 3,
fourier_mode: 4, date: 179)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 8kB 0 1 ... 999
* channel (channel) <U2 16B 'x1' 'x2'
* control (control) <U7 84B 'event_...
* fourier_mode (fourier_mode) <U5 80B 's...
* date (date) datetime64[ns] 1kB ...
Data variables: (12/17)
intercept_contribution (chain, draw) float64 32kB ...
adstock_alpha (chain, draw, channel) float64 64kB ...
saturation_lam (chain, draw, channel) float64 64kB ...
saturation_beta (chain, draw, channel) float64 64kB ...
gamma_control (chain, draw, control) float64 96kB ...
gamma_fourier (chain, draw, fourier_mode) float64 128kB ...
... ...
yearly_seasonality_contribution (chain, draw, date) float64 6MB ...
channel_contribution_original_scale (chain, draw, date, channel) float64 11MB ...
control_contribution_original_scale (chain, draw, date, control) float64 17MB ...
intercept_contribution_original_scale (chain, draw) float64 32kB ...
yearly_seasonality_contribution_original_scale (chain, draw, date) float64 6MB ...
y_original_scale (chain, draw, date) float64 6MB ...
Attributes:
created_at: 2026-03-18T18:55:07.498726+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
sampling_time: 13.935986280441284
tuning_steps: 1500
pymc_marketing_version: 0.18.2
Therefore, we can use all the pymc machinery to run model diagnostics. First, let's look at the summary of the trace:
az.summary(
data=mmm.fit_result,
var_names=[
"adstock_alpha",
"gamma_control",
"gamma_fourier",
"intercept_contribution",
"saturation_beta",
"saturation_lam",
"y_sigma",
],
)
| | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| adstock_alpha[x1] | 0.402 | 0.032 | 0.345 | 0.463 | 0.001 | 0.001 | 2489.0 | 2528.0 | 1.0 |
| adstock_alpha[x2] | 0.187 | 0.040 | 0.113 | 0.266 | 0.001 | 0.001 | 2466.0 | 2742.0 | 1.0 |
| gamma_control[event_1] | 0.176 | 0.027 | 0.126 | 0.230 | 0.000 | 0.000 | 4024.0 | 3039.0 | 1.0 |
| gamma_control[event_2] | 0.231 | 0.028 | 0.175 | 0.281 | 0.000 | 0.000 | 3807.0 | 3169.0 | 1.0 |
| gamma_control[t] | 0.001 | 0.000 | 0.001 | 0.001 | 0.000 | 0.000 | 2916.0 | 2655.0 | 1.0 |
| gamma_fourier[sin_1] | 0.003 | 0.003 | -0.004 | 0.009 | 0.000 | 0.000 | 3893.0 | 2753.0 | 1.0 |
| gamma_fourier[sin_2] | -0.058 | 0.004 | -0.064 | -0.051 | 0.000 | 0.000 | 4266.0 | 3090.0 | 1.0 |
| gamma_fourier[cos_1] | 0.062 | 0.003 | 0.056 | 0.069 | 0.000 | 0.000 | 5107.0 | 2645.0 | 1.0 |
| gamma_fourier[cos_2] | 0.001 | 0.004 | -0.006 | 0.008 | 0.000 | 0.000 | 3875.0 | 2875.0 | 1.0 |
| intercept_contribution | 0.355 | 0.013 | 0.331 | 0.381 | 0.000 | 0.000 | 2177.0 | 2481.0 | 1.0 |
| saturation_beta[x1] | 0.362 | 0.020 | 0.325 | 0.401 | 0.000 | 0.000 | 1994.0 | 2179.0 | 1.0 |
| saturation_beta[x2] | 0.265 | 0.073 | 0.192 | 0.368 | 0.002 | 0.005 | 1986.0 | 1564.0 | 1.0 |
| saturation_lam[x1] | 3.945 | 0.384 | 3.226 | 4.684 | 0.008 | 0.006 | 2539.0 | 2107.0 | 1.0 |
| saturation_lam[x2] | 3.175 | 1.188 | 1.210 | 5.405 | 0.026 | 0.029 | 1902.0 | 1675.0 | 1.0 |
| y_sigma | 0.031 | 0.002 | 0.028 | 0.035 | 0.000 | 0.000 | 3524.0 | 2912.0 | 1.0 |
Note that the estimated parameters for \(\alpha\) and \(\lambda\) are very close to the ones we set in the data generation process! Let's plot the trace for the parameters:
_ = az.plot_trace(
data=mmm.fit_result,
var_names=[
"adstock_alpha",
"gamma_control",
"gamma_fourier",
"intercept_contribution",
"saturation_beta",
"saturation_lam",
"y_sigma",
],
compact=True,
backend_kwargs={"figsize": (12, 10), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=18, fontweight="bold");
Overall, we see good chain mixing.
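Chain mixing can also be checked numerically: the r_hat column of the summary compares between-chain and within-chain variability. The following is a minimal sketch of the classic split R-hat in plain Python, run on synthetic chains. It is a simplified illustration and not ArviZ's rank-normalized implementation; `split_rhat` is a hypothetical helper name.

```python
import random
import statistics


def split_rhat(chains):
    """Classic (non rank-normalized) split R-hat for equally long chains."""
    # Split every chain in half so that within-chain drift is also detected.
    half = len(chains[0]) // 2
    halves = [c[:half] for c in chains] + [c[half : 2 * half] for c in chains]
    chain_means = [statistics.fmean(h) for h in halves]
    # W: average within-chain variance; B: between-chain variance scaled by n.
    w = statistics.fmean([statistics.variance(h) for h in halves])
    b = half * statistics.variance(chain_means)
    var_plus = (half - 1) / half * w + b / half
    return (var_plus / w) ** 0.5


rng = random.Random(0)
# Four well-mixed chains drawn from the same distribution (synthetic data).
good = [[rng.gauss(0.4, 0.03) for _ in range(1000)] for _ in range(4)]
# Same chains, but one is shifted to mimic a chain stuck in another region.
bad = good[:3] + [[x + 0.2 for x in good[3]]]

print(f"well mixed: {split_rhat(good):.3f}")
print(f"one shifted chain: {split_rhat(bad):.3f}")
```

Values close to 1 indicate that the chains agree with each other, while the shifted chain pushes the statistic well above 1.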
Now we sample from the posterior predictive distribution. That is, we sample from the posterior distribution to get predictions for the target variable.
mmm.sample_posterior_predictive(X=X, random_seed=rng)
Sampling: [y]
<xarray.Dataset> Size: 12MB
Dimensions: (date: 179, sample: 4000)
Coordinates:
* date (date) datetime64[ns] 1kB 2018-04-02 ... 2021-08-30
* sample (sample) object 32kB MultiIndex
* chain (sample) int64 32kB 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3
* draw (sample) int64 32kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
y (date, sample) float64 6MB 0.4584 0.4841 ... 0.5806 0.5661
y_original_scale (date, sample) float64 6MB 3.81 4.024 4.64 ... 4.826 4.706
Attributes:
created_at: 2026-03-18T18:55:13.766748+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
We can now plot the posterior predictive distribution for the target variable. By default, the mmm.plot.posterior_predictive method plots the mean prediction along with a \(94\%\) HDI.
fig, axes = mmm.plot.posterior_predictive(var=["y_original_scale"], hdi_prob=0.94)
sns.lineplot(
data=df, x="date_week", y="y", color="black", label="Observed", ax=axes[0][0]
);
The fit looks very good (as expected)!
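For readers unfamiliar with the HDI (highest density interval), here is a minimal sketch of how such an interval can be computed from samples: it is the narrowest interval that contains the requested share of the draws. The helper `hdi` and the synthetic draws are assumptions made for illustration; in practice ArviZ provides this computation.

```python
import math
import random


def hdi(samples, prob=0.94):
    """Narrowest interval containing roughly `prob` of the samples."""
    ordered = sorted(samples)
    n = len(ordered)
    # Number of steps between the lower and upper endpoint of each candidate.
    span = math.floor(prob * n)
    # Slide a window of fixed coverage and keep the narrowest one.
    widths = [ordered[i + span] - ordered[i] for i in range(n - span)]
    start = widths.index(min(widths))
    return ordered[start], ordered[start + span]


rng = random.Random(42)
# Synthetic, right-skewed "posterior predictive" draws for a single date.
draws = [rng.lognormvariate(1.4, 0.25) for _ in range(4000)]
low, high = hdi(draws, prob=0.94)
inside = sum(low <= d <= high for d in draws) / len(draws)
print(f"94% HDI: [{low:.2f}, {high:.2f}], coverage: {inside:.3f}")
```

Unlike an equal-tailed interval, the HDI shifts toward the dense region of a skewed distribution, which is why it is a common default for summarizing posteriors.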
We can inspect the model errors:
We do not see any pattern in the errors, which is a good sign.
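One way to check the errors numerically is to compute the residuals as the observations minus the posterior predictive mean and look at their average and their lag-1 autocorrelation. The sketch below uses synthetic stand-ins for the observed series and the predictive draws; all names and numbers are illustrative assumptions, not values from the fitted model.

```python
import random
import statistics

rng = random.Random(1)
n_dates, n_samples = 50, 200
# Synthetic stand-ins: a "true" signal, noisy observations and predictive draws.
signal = [4.0 + 0.01 * t for t in range(n_dates)]
observed = [s + rng.gauss(0, 0.25) for s in signal]
predictive = [[s + rng.gauss(0, 0.25) for _ in range(n_samples)] for s in signal]

# Residual per date: observation minus posterior predictive mean.
pred_mean = [statistics.fmean(draws) for draws in predictive]
residuals = [obs - mean for obs, mean in zip(observed, pred_mean)]


def lag1_autocorr(x):
    """Lag-1 autocorrelation; values near zero suggest no leftover time pattern."""
    center = statistics.fmean(x)
    centered = [v - center for v in x]
    num = sum(a * b for a, b in zip(centered[:-1], centered[1:]))
    return num / sum(c * c for c in centered)


print(f"mean residual: {statistics.fmean(residuals):.3f}")
print(f"lag-1 autocorrelation: {lag1_autocorr(residuals):.3f}")
```

A mean residual far from zero or a strong autocorrelation would hint at a missing component such as trend or seasonality.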
Next, we can decompose the posterior predictive distribution into the different components. We start by looking at the channel contributions:
# Component contributions (scaled space)
fig, axes = mmm.plot.contributions_over_time(
var=["channel_contribution"], hdi_prob=0.94
)
We can plot the contributions in the original scale:
# Component contributions (original scale)
mmm.plot.contributions_over_time(
var=["channel_contribution_original_scale"],
hdi_prob=0.94,
);
Note
The scalers attribute contains the scaling information for the target variable and the media variables.
These are simple numbers (stored in an xarray.Dataset) that we can use to scale the variables back to the original scale.
mmm.scalers
<xarray.Dataset> Size: 40B
Dimensions: (channel: 2)
Coordinates:
* channel (channel) object 16B 'x1' 'x2'
Data variables:
_channel (channel) float64 16B 0.9967 0.9944
_target float64 8B 8.312
Let’s check that the scaling is correct:
# Channel contributions (x1)
np.testing.assert_allclose(
mmm.idata["posterior"]["channel_contribution"].sel(channel="x1")
* mmm.scalers["_target"],
mmm.idata["posterior"]["channel_contribution_original_scale"].sel(channel="x1"),
)
# Channel contributions (x2)
np.testing.assert_allclose(
mmm.idata["posterior"]["channel_contribution"].sel(channel="x2")
* mmm.scalers["_target"],
mmm.idata["posterior"]["channel_contribution_original_scale"].sel(channel="x2"),
)
# Intercept contribution
np.testing.assert_allclose(
mmm.idata["posterior"]["intercept_contribution"] * mmm.scalers["_target"],
mmm.idata["posterior"]["intercept_contribution_original_scale"],
)
We can now plot all the contributions in the original scale:
# Component contributions (original scale)
fig, axes = mmm.plot.contributions_over_time(
var=[
"channel_contribution_original_scale",
"control_contribution_original_scale",
"intercept_contribution_original_scale",
"yearly_seasonality_contribution_original_scale",
],
dims={"channel": ["x1", "x2"]},
hdi_prob=0.94,
)
axes = axes.flatten()
for ax in axes:
    legend = ax.get_legend()
    legend.set_bbox_to_anchor((0.5, -0.1))
We can combine these plots as:
# Component contributions (original scale)
fig, ax = mmm.plot.contributions_over_time(
var=[
"channel_contribution_original_scale",
"control_contribution_original_scale",
"intercept_contribution_original_scale",
"yearly_seasonality_contribution_original_scale",
],
dims={"channel": ["x1", "x2"]},
combine_dims=True,
hdi_prob=0.94,
figsize=(12, 7),
)
legend = ax[0, 0].get_legend()
legend.set_bbox_to_anchor((0.8, -0.12))
The following code shows how to manually generate the aggregated channel contribution against the other components:
A similar decomposition can be achieved using an area plot:
Here the base means the sum of the intercept, control and seasonal components. Note that this only works if the contributions of the channel or control variable are strictly positive.
Next, we look into the absolute historical contributions of each component as a waterfall plot. This type of visualization is very useful to present to a non-technical audience and decision makers.
Note that we have recovered the true values for all the parameters! Well, in fact the contributions of the intercept and t are not exactly the same as in the data generating process, but the aggregate does match the true values of intercept + trend. The reason is that the true latent trend is not completely linear. One could use the time-varying intercept feature to capture this effect.
We can extract the mean contributions over time directly from the model:
| | date | x1 | x2 | event_1 | event_2 | t | yearly_seasonality | intercept |
|---|---|---|---|---|---|---|---|---|
| 0 | 2018-04-02 | 1.079970 | 0.000000 | 0.0 | 0.0 | 0.000000 | 0.021160 | 2.950151 |
| 1 | 2018-04-09 | 0.830757 | 0.000000 | 0.0 | 0.0 | 0.005126 | 0.073151 | 2.950151 |
| 2 | 2018-04-16 | 1.290704 | 0.000000 | 0.0 | 0.0 | 0.010251 | 0.118963 | 2.950151 |
| 3 | 2018-04-23 | 0.790082 | 0.000000 | 0.0 | 0.0 | 0.015377 | 0.153282 | 2.950151 |
| 4 | 2018-04-30 | 1.536806 | 0.000000 | 0.0 | 0.0 | 0.020502 | 0.171528 | 2.950151 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 174 | 2021-08-02 | 0.335762 | 0.003322 | 0.0 | 0.0 | 0.891853 | -0.875931 | 2.950151 |
| 175 | 2021-08-09 | 0.710576 | 1.603175 | 0.0 | 0.0 | 0.896979 | -0.886478 | 2.950151 |
| 176 | 2021-08-16 | 0.875334 | 0.407119 | 0.0 | 0.0 | 0.902105 | -0.864161 | 2.950151 |
| 177 | 2021-08-23 | 1.270923 | 0.077905 | 0.0 | 0.0 | 0.907230 | -0.808582 | 2.950151 |
| 178 | 2021-08-30 | 1.812030 | 0.015387 | 0.0 | 0.0 | 0.912356 | -0.721022 | 2.950151 |
179 rows × 8 columns
5. Media Parameters#
We can take a deeper look at the media transformation parameters. We want to compare the posterior distributions with the true values.
fig, ax = plt.subplots(
nrows=2,
ncols=1,
sharex=True,
sharey=True,
figsize=(12, 7),
layout="constrained",
)
az.plot_posterior(
mmm.idata["posterior"],
var_names=["adstock_alpha"],
ref_val={
"adstock_alpha": [
{"channel": "x1", "ref_val": alpha1},
{"channel": "x2", "ref_val": alpha2},
],
},
ax=ax,
)
fig.suptitle("Adstock Alpha Posterior", fontsize=18, fontweight="bold");
fig, ax = plt.subplots(
nrows=2,
ncols=1,
sharex=True,
sharey=True,
figsize=(12, 7),
layout="constrained",
)
az.plot_posterior(
mmm.idata["posterior"],
var_names=["saturation_lam"],
ref_val={
"saturation_lam": [
{"channel": "x1", "ref_val": lam1},
{"channel": "x2", "ref_val": lam2},
],
},
ax=ax,
)
fig.suptitle("Saturation Lambda Posterior", fontsize=18, fontweight="bold");
We indeed see that our media parameters were successfully recovered!
6. Media Deep-Dive#
First, we can compute the relative contribution of each channel to the target variable. Note that we recover the true values!
fig, ax = mmm.plot.channel_contribution_share_hdi(figsize=(10, 6))
ax.axvline(
x=contribution_share_x1,
color="C1",
linestyle="--",
label="true contribution share ($x_1$)",
)
ax.axvline(
x=contribution_share_x2,
color="C2",
linestyle="--",
label="true contribution share ($x_2$)",
)
ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=1);
Next, we can plot the relative contribution of each channel to the target variable.
First, we plot the direct contribution per channel. Again, we get values very close to the ones obtained in Part I.
fig, axes = mmm.plot.saturation_scatterplot(original_scale=True)
[ax.set(xlabel="x") for ax in axes.flatten()];
Note that getting the delayed cumulative contribution is not easy, as contributions from the past leak into the future. Specifically, we apply the saturation function to the adstocked aggregate. Because the saturation function is non-linear, this is not the same as taking the sum of the saturated contributions. Hence, it is very hard to reverse engineer the contribution after composing carryover and saturation this way.
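A small worked example makes the non-additivity concrete. The sketch below assumes a simplified, unnormalized geometric adstock and the logistic saturation \((1 - e^{-\lambda x}) / (1 + e^{-\lambda x})\); the helper names, spend values and parameters are illustrative assumptions rather than the library's internals.

```python
import math


def geometric_adstock(spend, alpha, l_max=8):
    """Carry over a fraction `alpha` of each past period's spend (unnormalized)."""
    return [
        sum(alpha**lag * spend[t - lag] for lag in range(min(l_max, t + 1)))
        for t in range(len(spend))
    ]


def logistic_saturation(x, lam):
    """Logistic saturation: 0 at x = 0 and approaching 1 for large x."""
    return (1 - math.exp(-lam * x)) / (1 + math.exp(-lam * x))


# Two weeks of spend; values and parameters are illustrative only.
spend = [0.5, 0.3]
alpha, lam = 0.4, 4.0

adstocked = geometric_adstock(spend, alpha)
# What the model does in week 2: saturate the carried-over aggregate.
saturated_aggregate = logistic_saturation(adstocked[1], lam)
# A naive per-spend attribution: saturate each piece separately, then add.
sum_of_saturated = logistic_saturation(alpha * spend[0], lam) + logistic_saturation(
    spend[1], lam
)

print(f"saturation of the aggregate: {saturated_aggregate:.4f}")
print(f"sum of saturated pieces:     {sum_of_saturated:.4f}")
```

The two numbers differ, so the week-2 contribution cannot be split cleanly into a part caused by week-1 spend and a part caused by week-2 spend.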
A more transparent alternative is to evaluate the channel contribution at different spend levels for the complete training period. Concretely, let \(\delta\) (the sweep factor) denote the level of the input channel data relative to the observed spend, so that \(\delta = 1\) corresponds to the model input spend data and \(\delta = 1.5\) to a \(50\%\) increase in spend. We can then compute the channel contribution on a grid of \(\delta\)-values and plot the results:
# Run sensitivity analysis sweep
sweeps = np.linspace(0, 1.5, 12)
mmm.sensitivity.run_sweep(
sweep_values=sweeps,
var_input="channel_data",
var_names="channel_contribution_original_scale",
extend_idata=True,
)
# Plot sensitivity analysis
ax = mmm.plot.sensitivity_analysis(
xlabel="Sweep multiplicative",
ylabel="Total contribution over training period",
hue_dim="channel",
x_sweep_axis="relative",
)
ax.axvline(1.0, color="black", linestyle="--", linewidth=1);
Here the black dashed line represents the case where the spend is at the historical level.
This plot takes into account the carryover (adstock) and saturation effects.
We see that when we have no spend, the contribution is zero (assuming there was no spend in the past, otherwise the carryover effect would be non-zero).
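To make the sweep idea concrete, here is a minimal sketch for a single channel and a single set of parameter values: for each sweep factor, the spend is rescaled, passed through adstock and saturation, and summed over time. The function `total_contribution`, the simplified unnormalized adstock, and all numbers are assumptions for illustration; the model performs the equivalent computation for every posterior draw.

```python
import math


def total_contribution(spend, alpha, lam, beta, l_max=8):
    """Sum over time of beta * saturation(adstock(spend))."""
    total = 0.0
    for t in range(len(spend)):
        # Geometric adstock: past spend decays by a factor `alpha` per period.
        adstocked = sum(
            alpha**lag * spend[t - lag] for lag in range(min(l_max, t + 1))
        )
        # Logistic saturation, then scale by the channel coefficient.
        saturated = (1 - math.exp(-lam * adstocked)) / (1 + math.exp(-lam * adstocked))
        total += beta * saturated
    return total


# Illustrative weekly spend and parameters (not the fitted values).
spend = [0.32, 0.11, 0.29, 0.0, 0.28, 0.44]
alpha, lam, beta = 0.4, 4.0, 3.0

# Twelve sweep factors between 0 and 1.5, mirroring np.linspace(0, 1.5, 12).
deltas = [1.5 * i / 11 for i in range(12)]
sweep = [
    total_contribution([delta * s for s in spend], alpha, lam, beta)
    for delta in deltas
]

for delta, value in zip(deltas, sweep):
    print(f"delta={delta:.2f} -> total contribution={value:.3f}")
```

The contribution starts at zero and grows more slowly than the spend, which is the diminishing-returns shape seen in the sensitivity plot.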
Note that these grid values serve as inputs for an optimization step.
We can also plot the same contribution using the total channel input as the x-axis (e.g., total spend in EUR).
# Plot sensitivity analysis with absolute x-axis
ax = mmm.plot.sensitivity_analysis(
xlabel="Sweep absolute spend",
ylabel="Total contribution over training period",
hue_dim="channel",
x_sweep_axis="absolute",
)
for i, channel in enumerate(["x1", "x2"]):
    ax.axvline(
        X[channel].sum(),
        color=f"C{i}",
        linestyle="--",
        label=f"historical total spend ({channel})",
    )
ax.legend(loc="upper left");
All of these visualizations are very useful for understanding the contribution of each channel to the target variable and the impact of saturation and adstock. For more details on how to interpret these plots, please refer to the tutorial Understanding Media Saturation in Marketing Mix Models.
7. Contribution Recovery#
Next, we can plot the direct contribution of each channel to the target variable over time.
# Component contributions (original scale)
fig, axes = mmm.plot.contributions_over_time(
var=["channel_contribution_original_scale"],
hdi_prob=0.94,
)
axes = axes.flatten()
for i, x in enumerate(["x1", "x2"]):
    # Estimate true contribution in the original scale from the data generating process
    sns.lineplot(
        x=df["date_week"],
        y=amplitude * betas[i] * df[f"{x}_adstock_saturated"],
        color="black",
        label=f"{x} true contribution",
        linestyle="--",
        alpha=0.5,
        ax=axes[i],
    )
[ax.legend(loc="upper left") for ax in axes]
fig.suptitle("Contribution Recovery", fontsize=18, fontweight="bold");
The results look great! We therefore successfully recovered the true values from the data generation process. We have also seen how easy it is to use the MMM class to fit media mix models! It takes over the model specification and the media transformations, while having all the flexibility of pymc!
8. ROAS#
Finally, we can compute the (approximate) posterior distribution of the ROAS for each channel.
roas = mmm.incrementality.contribution_over_spend(frequency="all_time").rename("roas")
fig, axes = plt.subplots(
nrows=2, ncols=1, figsize=(12, 7), sharex=True, sharey=False, layout="constrained"
)
az.plot_posterior(roas, ref_val=[roas_1, roas_2], ax=axes)
axes[0].set(title="Channel $x_{1}$")
axes[1].set(title="Channel $x_{2}$", xlabel="ROAS")
fig.suptitle("ROAS Posterior Distributions", fontsize=18, fontweight="bold", y=1.06);
We see that the ROAS posterior distributions are centered around the true values. We also see that, even after accounting for uncertainty, channel \(x_{1}\) is more efficient than channel \(x_{2}\).
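As a reminder of what the quantity means, ROAS here is the contribution generated per unit of spend. The sketch below computes one ROAS value per posterior draw by dividing the total channel contribution by the total spend; the spend values and the contribution draws are hypothetical stand-ins, not output of the fitted model.

```python
import random
import statistics

rng = random.Random(3)
# Hypothetical inputs: observed weekly spend and posterior draws of the
# total channel contribution in the original (sales) scale.
weekly_spend = [0.32, 0.11, 0.29, 0.28, 0.44]
total_spend = sum(weekly_spend)
contribution_draws = [rng.gauss(4.2, 0.3) for _ in range(2000)]

# One ROAS value per posterior draw: contribution divided by spend.
roas_draws = [c / total_spend for c in contribution_draws]
roas_mean = statistics.fmean(roas_draws)
roas_sd = statistics.stdev(roas_draws)

print(f"ROAS mean: {roas_mean:.2f}, sd: {roas_sd:.2f}")
```

Because the division is done draw by draw, the uncertainty in the contribution carries over directly to the ROAS distribution.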
It is also useful to compare the ROAS and the contribution share. In the next plot we plot these two inferred estimates per channel.
This plot is very effective at summarizing channel efficiency. In this example, it turns out that the most efficient channel \(x_1\) has a higher contribution share than the less efficient channel \(x_2\).
9. Out of Sample Predictions#
Out-of-sample predictions are made with the predict and posterior predictive methods. These include:
sample_posterior_predictive: Get the full posterior predictive distribution
predict: Get the mean of the posterior predictive distribution
These methods take new data, X, and some additional kwargs for new predictions. Specifically,
include_last_observations: boolean to carry over the adstock effects from the last observations in the training dataset
The new data needs to have all the features specified in the model. There is no need to worry about:
input channel spend scaling
creating the Fourier transformations on the date_column
inverse scaling back to the target domain
This is done automatically! However, please note that control variables are NOT scaled automatically: if needed, you must scale them before passing the data to the model.
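If a control variable does need scaling, the scale should be learned from the training data and then reused unchanged for new data. The following is a minimal sketch under that assumption, using max-abs scaling on a hypothetical control variable; neither the variable nor the choice of scaler comes from this notebook.

```python
# Hypothetical control variable (e.g. a price index) for training and new dates.
train_control = [102.0, 98.5, 110.2, 105.1]
new_control = [108.3, 111.0]

# Fit the scale on the training data only, then reuse it for the new data.
scale = max(abs(v) for v in train_control)
train_scaled = [v / scale for v in train_control]
new_scaled = [v / scale for v in new_control]

print("train:", [round(v, 3) for v in train_scaled])
print("new:  ", [round(v, 3) for v in new_scaled])
```

New values may fall slightly outside the training range, which is expected; recomputing the scale on the new data would silently change the meaning of the fitted coefficient.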
last_date = X["date_week"].max()
# New dates starting from last in dataset
n_new = 5
new_dates = pd.date_range(start=last_date, periods=1 + n_new, freq="W-MON")[1:]
X_out_of_sample = pd.DataFrame(
{
"date_week": new_dates,
}
)
# Same channel spends as last day
X_out_of_sample["x1"] = X["x1"].iloc[-1]
X_out_of_sample["x2"] = X["x2"].iloc[-1]
# Other features
X_out_of_sample["event_1"] = 0
X_out_of_sample["event_2"] = 0
X_out_of_sample["t"] = range(len(X), len(X) + n_new)
X_out_of_sample
| | date_week | x1 | x2 | event_1 | event_2 | t |
|---|---|---|---|---|---|---|
| 0 | 2021-09-06 | 0.438857 | 0.0 | 0 | 0 | 179 |
| 1 | 2021-09-13 | 0.438857 | 0.0 | 0 | 0 | 180 |
| 2 | 2021-09-20 | 0.438857 | 0.0 | 0 | 0 | 181 |
| 3 | 2021-09-27 | 0.438857 | 0.0 | 0 | 0 | 182 |
| 4 | 2021-10-04 | 0.438857 | 0.0 | 0 | 0 | 183 |
Call the desired method to get the new samples! The new coordinates will be the new dates.
X_out_of_sample.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5 entries, 0 to 4
Data columns (total 6 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 date_week 5 non-null datetime64[ns]
1 x1 5 non-null float64
2 x2 5 non-null float64
3 event_1 5 non-null int64
4 event_2 5 non-null int64
5 t 5 non-null int64
dtypes: datetime64[ns](1), float64(2), int64(3)
memory usage: 372.0 bytes
y_out_of_sample = mmm.sample_posterior_predictive(X_out_of_sample, extend_idata=False)
y_out_of_sample
Sampling: [y]
<xarray.Dataset> Size: 416kB
Dimensions: (date: 5, sample: 4000)
Coordinates:
* date (date) datetime64[ns] 40B 2021-09-06 ... 2021-10-04
* sample (sample) object 32kB MultiIndex
* chain (sample) int64 32kB 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3
* draw (sample) int64 32kB 0 1 2 3 4 5 ... 995 996 997 998 999
Data variables:
y (date, sample) float64 160kB 0.5375 0.5461 ... 0.7077
y_original_scale (date, sample) float64 160kB 4.468 4.54 ... 6.267 5.883
Attributes:
created_at: 2026-03-18T18:55:23.418168+00:00
arviz_version: 0.23.4
inference_library: pymc
inference_library_version: 5.28.1
Note
If the method is being called multiple times, set the extend_idata argument to False in order to not overwrite the observed_data in the InferenceData.
By default, the new predictions are transformed back to the original scale of the target variable. This can be seen below:
def plot_in_sample(X, y, ax, n_points: int = 15):
    sns.lineplot(
        x=X["date_week"][-n_points:],
        y=y[-n_points:],
        marker="o",
        markersize=7,
        color="black",
        label="actuals",
        ax=ax,
    )
    return ax
def plot_out_of_sample(X_out_of_sample, y_out_of_sample, ax, color, label):
    y_out_original_scale = (
        y_out_of_sample["y_original_scale"].unstack().transpose(..., "date")
    )
    az.plot_hdi(
        X_out_of_sample["date_week"].dt.to_pydatetime(),
        y_out_original_scale,
        smooth=False,
        fill_kwargs={"alpha": 0.25, "color": color},
        ax=ax,
    )
    mean = y_out_original_scale.mean(dim=("chain", "draw"))
    mean.plot(ax=ax, marker="o", markersize=7, label=label, color=color, linestyle="--")
    ax.set(ylabel="Original Target Scale")
    ax.set_title("Out of sample predictions for MMM", fontsize=18, fontweight="bold")
    return ax
_, ax = plt.subplots()
plot_in_sample(X, y, ax=ax)
plot_out_of_sample(
X_out_of_sample, y_out_of_sample, ax=ax, label="out of sample", color="C0"
)
ax.legend(loc="upper left");
If the out-of-sample data continues directly on from the original data, consider setting include_last_observations to True in order to carry over the effects of the last channel spends in the training set.
The predictions are higher since the channel contributions from the final spends still have an impact that eventually subsides.
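The effect of carrying over the last observations can be illustrated with a simplified, unnormalized geometric adstock. In the sketch below, the "warm" series is obtained by prepending the final training spends before transforming and dropping them afterwards, while the "cold" series ignores them. The helper, the spend values and the retention rate are illustrative assumptions, not the library's internal code.

```python
def geometric_adstock(spend, alpha, l_max=8):
    """Unnormalized geometric adstock with a maximum lag of `l_max` periods."""
    return [
        sum(alpha**lag * spend[t - lag] for lag in range(min(l_max, t + 1)))
        for t in range(len(spend))
    ]


alpha = 0.4  # illustrative retention rate, not the fitted value
last_train_spend = [0.29, 0.28, 0.44]  # hypothetical final training weeks
new_spend = [0.44] * 5  # same spend repeated over the forecast window

# Cold start: the forecast window ignores everything before it.
cold = geometric_adstock(new_spend, alpha)
# Warm start: prepend the last observations, transform, then drop them again.
warm = geometric_adstock(last_train_spend + new_spend, alpha)[len(last_train_spend):]

for week, (c, w) in enumerate(zip(cold, warm), start=1):
    print(f"week {week}: cold={c:.3f} warm={w:.3f} carryover={w - c:.3f}")
```

The carryover term is largest in the first forecast week and decays geometrically, which matches the behavior described above.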
y_out_of_sample_with_adstock = mmm.sample_posterior_predictive(
X_out_of_sample, extend_idata=False, include_last_observations=True
)
Sampling: [y]
10. Save Model#
After your model is trained, you can quickly save it using the save method. For more information about model deployment, see Model Deployment.
mmm.save("model.nc", engine="h5netcdf")
%load_ext watermark
%watermark -n -u -v -iv -w -p pymc_marketing,pytensor
Last updated: Wed, 18 Mar 2026
Python implementation: CPython
Python version : 3.13.12
IPython version : 9.11.0
pymc_marketing: 0.18.2
pytensor : 2.38.2
arviz : 0.23.4
graphviz : 0.21
matplotlib : 3.10.8
numpy : 2.4.2
pandas : 2.3.3
pymc : 5.28.1
pymc_extras : 0.9.3
pymc_marketing: 0.18.2
seaborn : 0.13.2
Watermark: 2.6.0