MMM con una línea base de medios variable en el tiempo#

Introducción#

En el ámbito del Modelado de Mezcla de Marketing (MMM), comprender el impacto de diversas actividades de marketing en una variable objetivo y otros indicadores clave de rendimiento es crucial. Los modelos de regresión tradicionales a menudo pasan por alto las dinámicas temporales de las actividades de marketing, lo que puede llevar a percepciones sesgadas o incompletas. Este cuaderno tiene como objetivo mostrar la diferencia entre un modelo de regresión convencional que no tiene en cuenta la variación temporal y un modelo más sofisticado que incorpora el tiempo como un componente clave a través de un proceso gaussiano.

El objetivo es determinar la contribución de cada actividad de marketing a la variable objetivo general o resultado deseado. Este proceso generalmente implica dos transformaciones críticas:

  1. Función de Saturación: Esta función modela los rendimientos decrecientes de las entradas de marketing. A medida que se asignan más recursos a un canal específico, el beneficio incremental tiende a disminuir.

  2. Función Adstock: Esta función captura el efecto de arrastre de las actividades de marketing a lo largo del tiempo, reconociendo que el impacto de un esfuerzo de marketing se extiende más allá del período inmediato en el que ocurre.

El enfoque estándar en MMM aplica estas transformaciones a las entradas de marketing, resultando en una contribución al resultado.

Modelo MMM Dependiente del Tiempo#

En escenarios del mundo real, la efectividad de las actividades de marketing no es estática, sino que varía con el tiempo debido a factores como las acciones competitivas y la dinámica del mercado. Para tener en cuenta esto, introducimos un componente dependiente del tiempo en el marco de MMM utilizando un Proceso Gaussiano, específicamente un GP de Espacio de Hilbert. Esto nos permite capturar la variación temporal latente oculta de las contribuciones de marketing.

Especificación del modelo#

En pymc-marketing proporcionamos una API para la especificación de un modelo de mezcla de medios bayesiano (MMM) siguiendo Jin, Yuxue, et al. “Métodos bayesianos para la modelización de mezcla de medios con efectos de arrastre y forma.” (2017) como modelo base. Concretamente, dado una variable objetivo de serie temporal \(y_{t}\) (por ejemplo, ventas o conversiones), variables de medios \(x_{m, t}\) (por ejemplo, impresiones, clics o costos) y un conjunto de covariables de control \(z_{c, t}\) (por ejemplo, días festivos, eventos especiales), consideramos un modelo lineal de la forma

\[::\]

donde \(\alpha\) es la intersección, \(f\) es una función de transformación de medios y \(\varepsilon_{t}\) es el término de error que asumimos que está distribuido normalmente. La función \(f\) codifica la contribución de los medios en la variable objetivo. Típicamente consideramos dos tipos de transformación: adstock (carry-over) y efectos de saturación.

Cuando time_media_varying se establece en True, capturamos un único proceso latente que multiplica todos los canales. Suponemos que todos los canales comparten las mismas fluctuaciones dependientes del tiempo, en contraste con las implementaciones donde cada canal tiene un proceso latente independiente. El modelo modificado se puede representar como:

\[::\]

donde \(\lambda_{t}\) es el componente variable en el tiempo modelado como un proceso latente. Esta variación compartida dependiente del tiempo \(\lambda_{t}\) nos permite capturar los efectos temporales generales que influyen en todos los canales de medios simultáneamente.

Objetivo#

Este cuaderno hará:

  1. Ilustre la formulación de un modelo MMM estándar sin variación temporal.

  2. Extienda el modelo para incluir un componente temporal utilizando HSGP.

  3. Compare los resultados y las percepciones derivadas de ambos modelos, destacando la importancia de incorporar la variación temporal para capturar el verdadero impacto de las actividades de marketing.

Al final de este cuaderno, tendrá una comprensión completa de las ventajas de utilizar modelos MMM dependientes del tiempo para capturar la naturaleza dinámica de la efectividad del marketing, lo que conducirá a conocimientos más precisos y accionables.

Conocimientos Previos#

El cuaderno asume que el lector tiene conocimiento de las funcionalidades esenciales de PyMC-Marketing. Si no está familiarizado, el «Cuaderno de Ejemplo de MMM» sirve como un excelente punto de partida, ofreciendo una introducción completa a los modelos de mezcla de medios en este contexto.


Parte I: Proceso de Generación de Datos#

In Part I of this notebook we focus on the data generating process. We want to construct the target variable \(y_{t}\) (sales) by adding each of the components described in the Business Problem section.

Note: Model components are built out of this notebook, we’ll show how they combine together to generate the target.

Preparar el cuaderno#

import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pymc_extras.prior import Prior

from pymc_marketing.mmm import GeometricAdstock, MichaelisMentenSaturation
from pymc_marketing.mmm.multidimensional import MMM
from pymc_marketing.paths import data_dir

warnings.filterwarnings("ignore")

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

1. Date Range#

Primero establecemos un rango de tiempo para nuestros datos. Consideramos un poco más de 2 años de datos con granularidad semanal.

# Creating variables

seed: int = sum(map(ord, "Time Media Contributions are amazing"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

# date range
min_date = pd.to_datetime("2018-04-01")
max_date = pd.to_datetime("2021-09-01")

df = pd.read_csv(data_dir / "media_tvp_data.csv", index_col=0)
df["date_week"] = pd.to_datetime(df["date_week"])

n = df.shape[0]
print(f"Number of observations: {n}")
Number of observations: 179

2. Media Costs Data#

Ahora generamos datos sintéticos a partir de dos canales \(x_1\) y \(x_2\). Nos referimos a ello como la señal en bruto, ya que será la entrada en la fase de modelado. Esperamos que la contribución de cada canal sea diferente, basada en los parámetros de arrastre y saturación.

fig, ax = plt.subplots(
    nrows=2, ncols=1, figsize=(10, 7), sharex=True, sharey=True, layout="constrained"
)
sns.lineplot(x="date_week", y="x1", data=df, color="C0", ax=ax[0])
sns.lineplot(x="date_week", y="x2", data=df, color="C1", ax=ax[1])
ax[1].set(xlabel="date_week")
fig.suptitle("Media Costs Data", fontsize=16);

Observación: Por diseño, \(x_{1}\) debería parecerse a un canal social pagado típico y \(x_{2}\) a una serie temporal de gastos offline (por ejemplo, televisión).

4. Control Variables#

Agregamos dos eventos donde hubo un pico notable en nuestra variable objetivo. Asumimos que son independientes y no estacionales (por ejemplo, el lanzamiento de un producto en particular).

fig, ax = plt.subplots(
    nrows=2, ncols=1, figsize=(10, 7), sharex=True, sharey=True, layout="constrained"
)
sns.lineplot(x="date_week", y="event_1", data=df, color="C0", ax=ax[0])
ax[0].set_title("Event 1")
sns.lineplot(x="date_week", y="event_2", data=df, color="C1", ax=ax[1])
ax[1].set_title("Event 2")
ax[1].set(xlabel="date_week")
fig.suptitle("Control Events", fontsize=16);

5. Temporal Hidden Latent Process#

Para ilustrar el impacto del rendimiento variable de los medios a lo largo del tiempo en nuestro modelo, generamos una señal sintética que modifica la contribución base. Esta señal, hidden_latent_media_fluctuation, está diseñada para simular las fluctuaciones naturales en el rendimiento de los medios a lo largo del tiempo.

fig, ax = plt.subplots(
    nrows=1, ncols=1, figsize=(10, 7), sharex=True, sharey=True, layout="constrained"
)
sns.lineplot(x="date_week", y="hidden_latent_media_fluctuation", data=df, color="C0")
ax.set(xlabel="date_week")
fig.suptitle("Media performance change", fontsize=16);

Al centrar la señal alrededor de 1, mantendremos la contribución base como el efecto promedio mientras permitimos aumentos y disminuciones periódicas. Este enfoque refleja escenarios del mundo real donde la efectividad del marketing puede variar, pero la tendencia general se mantiene constante.

Esta señal sintética es esencial para demostrar la eficacia de nuestro modelo MMM dependiente del tiempo, que debería recuperar esta señal tanto como sea posible.

6. Target Variable#

Finalmente, necesitamos crear nuestra variable objetivo. Para hacerlo, utilizaremos el operador do de PyMC operator para especificar algunos valores de parámetros verdaderos que rigen las relaciones causales en el modelo.

Haciendo esto, generaremos una variable objetivo simulada (ventas) \(y\), que asumimos es una combinación lineal de todos los componentes en el modelo. También añadimos algo de ruido gaussiano.

# Real values
real_alpha = [3, 5]
real_lam = [0.3, 0.5]

adstock_max_lag = 8
yearly_seasonality = 2

true_params = {
    "intercept_contribution": 6.0,
    "adstock_alpha": np.array([0.5, 0.4]),
    "saturation_alpha": np.array(real_alpha),
    "saturation_lam": np.array(real_lam),
    "media_temporal_latent_multiplier": df["hidden_latent_media_fluctuation"],
    "gamma_fourier": np.array([2.5, -0.5, 1.5, 2.5]),
    "y_sigma": 0.25,
    "gamma_control": np.array([-3.5, 6.25]),
}

Using the grid, the do operator, with a dummy model from the MMM class, we can build the true model. We’ll no go into those details, but we can unpack this a little bit.

The do-function takes a pymc.Model object and a dict of parameter values. It then returns a new model where the original random variables (RVs) have been converted to constant nodes taking on the specified values.

Meaning if we pick the node for the intercept, the value should match our grid.

plt.plot(df["intercept"])
plt.title("Intercept Over Time")
plt.xlabel("date_week")
plt.ylabel("Sales (thousands)");

Como puede ver, la intersección está alineada con los datos añadidos previamente, teniendo un valor constante de 6. Pero, ¿cómo se ve nuestra contribución total después de la transformación?

fig, ax = plt.subplots(2, 1, figsize=(12, 6), sharex=True)

ax[0].plot(df["baseline_channel_contribution"], color="purple", linestyle="--")
ax[0].set_title("Baseline Channel Contribution")
ax[0].set_xlabel("date_week")
ax[0].set_ylabel("Sales (thousands)")

ax[1].plot(df["channel_contribution"], color="purple")
ax[1].set_title("Channel Contribution")
ax[1].set_xlabel("date_week")
ax[1].set_ylabel("Sales (thousands)");

Contribuciones del Canal Base

El gráfico de la izquierda, titulado «Contribuciones de Canales Base», muestra las contribuciones de los canales de medios antes de considerar los efectos variables en el tiempo. Los valores se generan al sumar las contribuciones de canales base extraídas del modelo verdadero.

Contribuciones del Canal con Variación Temporal

El gráfico de la derecha, titulado «Contribuciones del Canal», muestra las contribuciones de los canales de medios después de incorporar la señal de rendimiento de medios variable en el tiempo. Estas contribuciones reflejan el impacto del proceso temporal latente, representado por hidden_latent_media_fluctuation, que modifica las contribuciones base. Esta modificación captura las fluctuaciones naturales en el rendimiento de los medios a lo largo del tiempo, influenciadas por diversas dinámicas de marketing.

fig, ax = plt.subplots(
    nrows=2, ncols=1, figsize=(10, 7), sharex=True, sharey=True, layout="constrained"
)
sns.lineplot(x="date_week", y="x1_contribution", data=df, color="C0", ax=ax[0])
sns.lineplot(x="date_week", y="x2_contribution", data=df, color="C1", ax=ax[1])
ax[1].set(xlabel="date_week")
fig.suptitle("Media Contribution per Channel", fontsize=16);

7. Trend & Seasonal Components#

También podemos observar la contribución de nuestros eventos de control, así como la estacionalidad añadida al crear el modelo verdadero.

fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(12, 8), sharex=True)

ax1.plot(df["yearly_seasonality_contribution"])
ax1.set_title("Yearly Seasonality Contribution")
ax1.set_xlabel("date_week")
ax1.set_ylabel("Sales (thousands)")

ax2.plot(df["control_contribution"])
ax2.set_title("Control Contribution")
ax2.set_xlabel("date_week")
ax2.set_ylabel("Sales (thousands)");

Finalmente, ¡podemos visualizar el verdadero objetivo dado todos los componentes anteriores!

plt.plot(df["y"], color="black")
plt.title("Target Variable (Sales)")
plt.xlabel("date_week")
plt.ylabel("Sales (thousands)");

Ahora que todo está en su lugar, vamos a separar nuestro conjunto de datos para dejar los datos reales estimados por el modelo verdadero dentro de df y crearemos un nuevo conjunto de datos llamado data que tendrá todas las columnas necesarias pero no tendrá ninguna información sobre las relaciones verdaderas. Similar a como sucedería en la vida real.

data = df[["date_week", "x1", "x2", "event_1", "event_2", "y"]].copy()

X = data.drop("y", axis=1)
y = data["y"]

Como discutimos anteriormente, queremos comparar un modelo sin coeficientes de variante para ver cuánto se desvía de la realidad. Para esto, crearemos el objeto MMM que recibirá todos los parámetros necesarios para construir nuestro modelo, el cual debería estimar las relaciones del modelo verdadero.

basic_mmm = MMM(
    date_column="date_week",
    target_column="y",
    channel_columns=["x1", "x2"],
    control_columns=["event_1", "event_2"],
    yearly_seasonality=yearly_seasonality,
    adstock=GeometricAdstock(l_max=adstock_max_lag).set_dims_for_all_priors("channel"),
    saturation=MichaelisMentenSaturation().set_dims_for_all_priors("channel"),
)

basic_mmm.fit(
    X=X,
    y=y,
    target_accept=0.92,
    draws=500,
    random_seed=rng,
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept_contribution, adstock_alpha, saturation_alpha, saturation_lam, gamma_control, gamma_fourier, y_sigma]

Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 8 seconds.
There were 5 divergences after tuning. Increase `target_accept` or reparameterize.

arviz.InferenceData
    • <xarray.Dataset> Size: 26MB
      Dimensions:                                  (chain: 4, draw: 500, channel: 2,
                                                    control: 2, fourier_mode: 4,
                                                    date: 179)
      Coordinates:
        * chain                                    (chain) int64 32B 0 1 2 3
        * draw                                     (draw) int64 4kB 0 1 2 ... 498 499
        * channel                                  (channel) <U2 16B 'x1' 'x2'
        * control                                  (control) <U7 56B 'event_1' 'eve...
        * fourier_mode                             (fourier_mode) <U5 80B 'sin_1' ....
        * date                                     (date) datetime64[ns] 1kB 2018-0...
      Data variables:
          adstock_alpha                            (chain, draw, channel) float64 32kB ...
          gamma_control                            (chain, draw, control) float64 32kB ...
          gamma_fourier                            (chain, draw, fourier_mode) float64 64kB ...
          intercept_contribution                   (chain, draw) float64 16kB 0.358...
          saturation_alpha                         (chain, draw, channel) float64 32kB ...
          saturation_lam                           (chain, draw, channel) float64 32kB ...
          y_sigma                                  (chain, draw) float64 16kB 0.058...
          channel_contribution                     (chain, draw, date, channel) float64 6MB ...
          control_contribution                     (chain, draw, date, control) float64 6MB ...
          fourier_contribution                     (chain, draw, date, fourier_mode) float64 11MB ...
          total_media_contribution_original_scale  (chain, draw) float64 16kB 489.3...
          yearly_seasonality_contribution          (chain, draw, date) float64 3MB ...
      Attributes:
          created_at:                 2026-02-03T20:55:06.894079+00:00
          arviz_version:              0.23.0
          inference_library:          pymc
          inference_library_version:  5.27.0
          sampling_time:              8.357609987258911
          tuning_steps:               1000
          pymc_marketing_version:     0.17.1

    • <xarray.Dataset> Size: 264kB
      Dimensions:                (chain: 4, draw: 500)
      Coordinates:
        * chain                  (chain) int64 32B 0 1 2 3
        * draw                   (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
      Data variables: (12/18)
          acceptance_rate        (chain, draw) float64 16kB 0.9822 0.9998 ... 0.9899
          divergences            (chain, draw) int64 16kB 0 0 0 0 0 0 ... 0 0 0 0 0 0
          diverging              (chain, draw) bool 2kB False False ... False False
          energy                 (chain, draw) float64 16kB -216.4 -220.5 ... -218.7
          energy_error           (chain, draw) float64 16kB -0.3042 ... 0.04869
          index_in_trajectory    (chain, draw) int64 16kB 46 -20 -50 -9 ... -20 16 -32
          ...                     ...
          process_time_diff      (chain, draw) float64 16kB 0.004583 ... 0.004648
          reached_max_treedepth  (chain, draw) bool 2kB False False ... False False
          smallest_eigval        (chain, draw) float64 16kB nan nan nan ... nan nan
          step_size              (chain, draw) float64 16kB 0.09049 ... 0.07296
          step_size_bar          (chain, draw) float64 16kB 0.07765 ... 0.09081
          tree_depth             (chain, draw) int64 16kB 6 5 6 5 5 6 ... 6 5 6 5 6 6
      Attributes:
          created_at:                 2026-02-03T20:55:06.902207+00:00
          arviz_version:              0.23.0
          inference_library:          pymc
          inference_library_version:  5.27.0
          sampling_time:              8.357609987258911
          tuning_steps:               1000

    • <xarray.Dataset> Size: 3kB
      Dimensions:  (date: 179)
      Coordinates:
        * date     (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
      Data variables:
          y        (date) float64 1kB 0.5317 0.6001 0.6019 ... 0.3418 0.2982 0.2264
      Attributes:
          created_at:                 2026-02-03T20:55:06.904767+00:00
          arviz_version:              0.23.0
          inference_library:          pymc
          inference_library_version:  5.27.0

    • <xarray.Dataset> Size: 9kB
      Dimensions:        (date: 179, channel: 2, control: 2)
      Coordinates:
        * date           (date) datetime64[ns] 1kB 2018-04-02 ... 2021-08-30
        * channel        (channel) <U2 16B 'x1' 'x2'
        * control        (control) <U7 56B 'event_1' 'event_2'
      Data variables:
          channel_data   (date, channel) float64 3kB 0.2948 0.0 0.9383 ... 0.1269 0.0
          channel_scale  (channel) float64 16B 0.9968 0.9927
          control_data   (date, control) float64 3kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0
          dayofyear      (date) int32 716B 92 99 106 113 120 ... 214 221 228 235 242
          target_data    (date) float64 1kB 7.784 8.785 8.812 ... 5.004 4.366 3.315
          target_scale   float64 8B 14.64
      Attributes:
          created_at:                 2026-02-03T20:55:06.906425+00:00
          arviz_version:              0.23.0
          inference_library:          pymc
          inference_library_version:  5.27.0

    • <xarray.Dataset> Size: 9kB
      Dimensions:    (date_week: 179)
      Coordinates:
        * date_week  (date_week) datetime64[ns] 1kB 2018-04-02 ... 2021-08-30
      Data variables:
          x1         (date_week) float64 1kB 0.2948 0.9383 0.1397 ... 0.9364 0.1269
          x2         (date_week) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          event_1    (date_week) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          event_2    (date_week) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          y          (date_week) float64 1kB 7.784 8.785 8.812 ... 5.004 4.366 3.315

Como podemos ver, ¡el modelo encontró divergencias!🤯

La ocurrencia de divergencias en nuestro MMM bayesiano destaca las fortalezas y la robustez del marco bayesiano en la prueba de hipótesis y la validación de modelos. Los modelos bayesianos son estructurales y se adhieren a ciertas suposiciones sobre el proceso generador de datos. Cuando estas suposiciones se violan o la estructura del modelo no se ajusta bien a los datos, pueden surgir divergencias y problemas de muestreo.

Esta característica convierte el enfoque bayesiano en una herramienta poderosa para:

  • Pruebas de Hipótesis: Al definir relaciones estructurales y supuestos claros, los modelos bayesianos pueden ayudar a probar y validar hipótesis sobre los procesos subyacentes en los datos.

  • Validación del Modelo: Las divergencias y los problemas de muestreo sirven como indicadores de que el modelo puede no estar correctamente especificado, lo que provoca una investigación y refinamiento adicionales.

  • Comprensión de Sistemas Complejos: Los métodos bayesianos permiten la incorporación de conocimientos previos y la prueba de diversas suposiciones estructurales, lo que los hace especialmente adecuados para comprender sistemas complejos del mundo real.

En este caso particular, podemos sospechar perfectamente por qué el modelo tuvo divergencias. La estructura interna de nuestro modelo del mundo (MMM) está desestimando el tiempo cuando este es un factor importante (sabemos esto porque hemos llevado a cabo el debido proceso de generación de datos).


A pesar de eso, echemos un vistazo a los datos que pudimos recuperar a través de este modelo básico.

Si descomponemos la distribución predictiva posterior en los diferentes componentes, todo queda claro:

basic_mmm.plot.contributions_over_time(
    var=[
        "control_contribution",
        "channel_contribution",
        "yearly_seasonality_contribution",
    ],
    combine_dims=True,
    figsize=(16, 8),
)
(<Figure size 1600x800 with 1 Axes>,
 array([[<Axes: title={'center': 'Time Series Contributions'}, xlabel='Date', ylabel='Posterior Value'>]],
       dtype=object))
../../_images/8a1073da66b80e577925cc19b36a1b15b4aa663992a9579a72c0bde8ab8b1903.png

Algunas contribuciones terminan teniendo más unidades que el valor objetivo, lo que obliga al modelo a compensar. Esto resulta en una descomposición incorrecta de nuestras actividades de marketing.

Por ejemplo, nuestra serie temporal termina con una larga cola de valores probables para las contribuciones de marketing, siendo esta cola hasta 3X mayor que el valor máximo de nuestro objetivo.

def plot_posterior(
    posterior, figsize=(15, 8), path_color="blue", hist_color="blue", **kwargs
):
    """Plot the posterior distribution of a stochastic process.

    Parameters
    ----------
    posterior : xarray.DataArray
        The posterior distribution with shape (draw, chain, date).
    figsize : tuple
        Size of the figure.
    path_color : str
        Color of the paths in the time series plot.
    hist_color : str
        Color of the histogram.
    **kwargs
        Additional keyword arguments to pass to the plotting functions.

    """
    # Calculate the expected value (mean) across all draws and chains for each date
    expected_value = posterior.mean(dim=("draw", "chain"))

    # Create a figure and a grid of subplots
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(1, 2, width_ratios=[3, 1])

    # Time series plot
    ax1 = fig.add_subplot(gs[0])
    for chain in range(posterior.shape[1]):
        for draw in range(
            0, posterior.shape[0], 10
        ):  # Plot every 10th draw for performance
            ax1.plot(
                posterior.date,
                posterior[draw, chain],
                color=path_color,
                alpha=0.05,
                linewidth=0.4,
            )

    # Plot expected value with a distinct color
    ax1.plot(
        posterior.date,
        expected_value,
        color="black",
        linestyle="--",
        linewidth=2,
        label="Expected Value",
    )
    ax1.set_title("Posterior Predictive")
    ax1.set_xlabel("Date")
    ax1.set_ylabel("Value")
    ax1.grid(True)
    ax1.legend()

    # KDE plot instead of histogram
    ax2 = fig.add_subplot(gs[1])
    final_values = posterior[:, :, -1].values.flatten()
    sns.kdeplot(
        y=final_values, ax=ax2, color=hist_color, fill=True, alpha=0.4, **kwargs
    )

    # Plot expected value line in KDE plot
    ax2.axhline(
        y=expected_value[-1].values.mean(), color="black", linestyle="--", linewidth=2
    )
    ax2.set_title("Distribution at T")
    ax2.set_xlabel("Density")
    ax2.set_yticklabels([])  # Hide y tick labels to avoid duplication
    ax2.grid(True)

    plt.tight_layout()
    return fig


plot_posterior(
    posterior=basic_mmm.fit_result["channel_contribution"].sum(dim="channel")
);

Pero, ¿por qué se sobreestiman las contribuciones? Las contribuciones están mal estimadas porque los parámetros de nuestras transformaciones también están mal estimados. Por ejemplo, los parámetros que controlan la efectividad máxima (en la función de saturación) de cada canal son mucho más altos que los reales para ambos canales.

fig = basic_mmm.plot.channel_parameter(param_name="saturation_alpha", figsize=(9, 5))
ax = fig.axes[0]
ax.axvline(
    x=(real_alpha[0] / df.y.max()), color="C0", linestyle="--", label=r"$\alpha_1$"
)
ax.axvline(
    x=(real_alpha[1] / df.y.max()), color="C1", linestyle="--", label=r"$\alpha_2$"
)
ax.legend(loc="upper right");

¿Qué cambiaría si ahora consideramos el tiempo como un factor en nuestro modelo?

Ahora podemos hacer esto añadiendo el siguiente parámetro a la inicialización de nuestro modelo time_varying_media y cambiándolo a True.

from pymc_marketing.hsgp_kwargs import HSGPKwargs

hsgp_kwargs = HSGPKwargs(
    ls_mu=11.0,  # InverseGamma lengthscale prior mean
    ls_sigma=5.0,  # InverseGamma lengthscale prior sigma
)
mmm = MMM(
    date_column="date_week",
    target_column="y",
    channel_columns=["x1", "x2"],
    control_columns=["event_1", "event_2"],
    yearly_seasonality=yearly_seasonality,
    adstock=GeometricAdstock(l_max=adstock_max_lag).set_dims_for_all_priors("channel"),
    saturation=MichaelisMentenSaturation().set_dims_for_all_priors("channel"),
    time_varying_media=True,  # 1. Enable the feature
    model_config={"media_tvp_config": hsgp_kwargs},
)

Nota

Al hacer esto, ahora nuestra configuración del modelo tendrá una nueva clave media_tvp_config con los parámetros que controlan los priors de nuestro HSGP.

mmm.model_config["media_tvp_config"]
HSGPKwargs(m=200, L=None, eta_lam=1, ls_mu=11.0, ls_sigma=5.0, cov_func=None)
mmm.fit(
    X=X,
    y=y,
    target_accept=0.92,
    draws=500,
    random_seed=rng,
)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept_contribution, adstock_alpha, saturation_alpha, saturation_lam, media_temporal_latent_multiplier_raw_eta, media_temporal_latent_multiplier_raw_ls, media_temporal_latent_multiplier_raw_hsgp_coefs_offset, gamma_control, gamma_fourier, y_sigma]

Sampling 4 chains for 1_000 tune and 500 draw iterations (4_000 + 2_000 draws total) took 54 seconds.
There were 8 divergences after tuning. Increase `target_accept` or reparameterize.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details

arviz.InferenceData
    • <xarray.Dataset> Size: 44MB
      Dimensions:                                                 (chain: 4,
                                                                   draw: 500,
                                                                   channel: 2,
                                                                   control: 2,
                                                                   fourier_mode: 4,
                                                                   media_temporal_latent_multiplier_raw_m: 200,
                                                                   date: 179)
      Coordinates:
        * chain                                                   (chain) int64 32B ...
        * draw                                                    (draw) int64 4kB ...
        * channel                                                 (channel) <U2 16B ...
        * control                                                 (control) <U7 56B ...
        * fourier_mode                                            (fourier_mode) <U5 80B ...
        * media_temporal_latent_multiplier_raw_m                  (media_temporal_latent_multiplier_raw_m) int64 2kB ...
        * date                                                    (date) datetime64[ns] 1kB ...
      Data variables: (12/20)
          adstock_alpha                                           (chain, draw, channel) float64 32kB ...
          gamma_control                                           (chain, draw, control) float64 32kB ...
          gamma_fourier                                           (chain, draw, fourier_mode) float64 64kB ...
          intercept_contribution                                  (chain, draw) float64 16kB ...
          media_temporal_latent_multiplier_raw_eta                (chain, draw) float64 16kB ...
          media_temporal_latent_multiplier_raw_hsgp_coefs_offset  (chain, draw, media_temporal_latent_multiplier_raw_m) float64 3MB ...
          ...                                                      ...
          media_temporal_latent_multiplier                        (chain, draw, date) float64 3MB ...
          media_temporal_latent_multiplier_f_mean                 (chain, draw) float64 16kB ...
          media_temporal_latent_multiplier_raw                    (chain, draw, date) float64 3MB ...
          media_temporal_latent_multiplier_raw_hsgp_coefs         (chain, draw, media_temporal_latent_multiplier_raw_m) float64 3MB ...
          total_media_contribution_original_scale                 (chain, draw) float64 16kB ...
          yearly_seasonality_contribution                         (chain, draw, date) float64 3MB ...
      Attributes:
          created_at:                 2026-02-03T20:56:17.140137+00:00
          arviz_version:              0.23.0
          inference_library:          pymc
          inference_library_version:  5.27.0
          sampling_time:              54.1484808921814
          tuning_steps:               1000
          pymc_marketing_version:     0.17.1

    • <xarray.Dataset> Size: 264kB
      Dimensions:                (chain: 4, draw: 500)
      Coordinates:
        * chain                  (chain) int64 32B 0 1 2 3
        * draw                   (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
      Data variables: (12/18)
          acceptance_rate        (chain, draw) float64 16kB 0.9151 0.9345 ... 0.9902
          divergences            (chain, draw) int64 16kB 0 0 0 0 0 0 ... 2 2 2 2 2 2
          diverging              (chain, draw) bool 2kB False False ... False False
          energy                 (chain, draw) float64 16kB -41.24 -41.99 ... -58.52
          energy_error           (chain, draw) float64 16kB 0.08577 ... 0.002545
          index_in_trajectory    (chain, draw) int64 16kB 84 -20 47 ... 47 -91 -102
          ...                     ...
          process_time_diff      (chain, draw) float64 16kB 0.02821 ... 0.02846
          reached_max_treedepth  (chain, draw) bool 2kB False False ... False False
          smallest_eigval        (chain, draw) float64 16kB nan nan nan ... nan nan
          step_size              (chain, draw) float64 16kB 0.02866 ... 0.03422
          step_size_bar          (chain, draw) float64 16kB 0.03508 ... 0.03673
          tree_depth             (chain, draw) int64 16kB 7 7 7 7 7 7 ... 7 7 7 7 7 7
      Attributes:
          created_at:                 2026-02-03T20:56:17.151157+00:00
          arviz_version:              0.23.0
          inference_library:          pymc
          inference_library_version:  5.27.0
          sampling_time:              54.1484808921814
          tuning_steps:               1000

    • <xarray.Dataset> Size: 3kB
      Dimensions:  (date: 179)
      Coordinates:
        * date     (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
      Data variables:
          y        (date) float64 1kB 0.5317 0.6001 0.6019 ... 0.3418 0.2982 0.2264
      Attributes:
          created_at:                 2026-02-03T20:56:17.154293+00:00
          arviz_version:              0.23.0
          inference_library:          pymc
          inference_library_version:  5.27.0

    • <xarray.Dataset> Size: 10kB
      Dimensions:        (date: 179, channel: 2, control: 2)
      Coordinates:
        * date           (date) datetime64[ns] 1kB 2018-04-02 ... 2021-08-30
        * channel        (channel) <U2 16B 'x1' 'x2'
        * control        (control) <U7 56B 'event_1' 'event_2'
      Data variables:
          channel_data   (date, channel) float64 3kB 0.2948 0.0 0.9383 ... 0.1269 0.0
          channel_scale  (channel) float64 16B 0.9968 0.9927
          control_data   (date, control) float64 3kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0
          dayofyear      (date) int32 716B 92 99 106 113 120 ... 214 221 228 235 242
          target_data    (date) float64 1kB 7.784 8.785 8.812 ... 5.004 4.366 3.315
          target_scale   float64 8B 14.64
          time_index     (date) int32 716B 0 1 2 3 4 5 6 ... 173 174 175 176 177 178
      Attributes:
          created_at:                 2026-02-03T20:56:17.156463+00:00
          arviz_version:              0.23.0
          inference_library:          pymc
          inference_library_version:  5.27.0

    • <xarray.Dataset> Size: 9kB
      Dimensions:    (date_week: 179)
      Coordinates:
        * date_week  (date_week) datetime64[ns] 1kB 2018-04-02 ... 2021-08-30
      Data variables:
          x1         (date_week) float64 1kB 0.2948 0.9383 0.1397 ... 0.9364 0.1269
          x2         (date_week) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          event_1    (date_week) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          event_2    (date_week) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          y          (date_week) float64 1kB 7.784 8.785 8.812 ... 5.004 4.366 3.315

We got less divergences, this is a good sign! 🚀

¡Revisemos nuestras muestras!

az.summary(
    data=mmm.fit_result,
    var_names=[
        "intercept_contribution",
        "y_sigma",
        "gamma_control",
        "gamma_fourier",
    ],
)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
intercept_contribution 0.413 0.007 0.399 0.427 0.0 0.0 1988.0 1471.0 1.0
y_sigma 0.017 0.001 0.015 0.018 0.0 0.0 2915.0 1413.0 1.0
gamma_control[event_1] -0.251 0.018 -0.283 -0.216 0.0 0.0 2957.0 1418.0 1.0
gamma_control[event_2] 0.431 0.018 0.396 0.461 0.0 0.0 3626.0 1549.0 1.0
gamma_fourier[sin_1] 0.169 0.003 0.163 0.175 0.0 0.0 1692.0 1071.0 1.0
gamma_fourier[sin_2] -0.032 0.002 -0.036 -0.028 0.0 0.0 3111.0 1390.0 1.0
gamma_fourier[cos_1] 0.100 0.003 0.095 0.106 0.0 0.0 2112.0 1164.0 1.0
gamma_fourier[cos_2] 0.172 0.002 0.168 0.176 0.0 0.0 2836.0 1752.0 1.0
_ = az.plot_trace(
    data=mmm.fit_result,
    var_names=[
        "intercept_contribution",
        "y_sigma",
        "gamma_control",
        "gamma_fourier",
    ],
    compact=True,
    backend_kwargs={"figsize": (12, 10), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=16);
az.summary(
    data=mmm.fit_result,
    var_names=["adstock_alpha", "saturation_lam", "saturation_alpha"],
)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
adstock_alpha[x1] 0.500 0.038 0.423 0.569 0.001 0.001 1976.0 1653.0 1.0
adstock_alpha[x2] 0.345 0.029 0.287 0.395 0.001 0.001 1426.0 1444.0 1.0
saturation_lam[x1] 0.333 0.091 0.182 0.502 0.002 0.003 1582.0 1193.0 1.0
saturation_lam[x2] 0.366 0.085 0.202 0.519 0.002 0.002 1357.0 1282.0 1.0
saturation_alpha[x1] 0.188 0.018 0.158 0.222 0.000 0.001 1539.0 1158.0 1.0
saturation_alpha[x2] 0.256 0.026 0.211 0.307 0.001 0.001 1316.0 1180.0 1.0
_ = az.plot_trace(
    data=mmm.fit_result,
    var_names=["adstock_alpha", "saturation_lam", "saturation_alpha"],
    compact=True,
    backend_kwargs={"figsize": (12, 10), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=16);
az.summary(
    data=mmm.fit_result,
    var_names=[
        "media_temporal_latent_multiplier_raw_eta",
        "media_temporal_latent_multiplier_raw_ls",
        "media_temporal_latent_multiplier_raw_hsgp_coefs",
    ],
)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
media_temporal_latent_multiplier_raw_eta 0.643 0.364 0.242 1.201 0.018 0.041 944.0 542.0 1.01
media_temporal_latent_multiplier_raw_ls 52.133 17.904 25.169 86.925 0.666 0.668 792.0 482.0 1.01
media_temporal_latent_multiplier_raw_hsgp_coefs[0] 0.140 8.410 -14.409 12.980 0.210 0.545 1996.0 1173.0 1.00
media_temporal_latent_multiplier_raw_hsgp_coefs[1] 6.197 1.938 2.813 9.697 0.048 0.069 1749.0 1378.0 1.00
media_temporal_latent_multiplier_raw_hsgp_coefs[2] 5.361 2.909 0.956 10.262 0.070 0.160 1836.0 1317.0 1.00
... ... ... ... ... ... ... ... ... ...
media_temporal_latent_multiplier_raw_hsgp_coefs[195] -0.000 0.000 -0.000 0.000 0.000 0.000 2914.0 1565.0 1.00
media_temporal_latent_multiplier_raw_hsgp_coefs[196] 0.000 0.000 -0.000 0.000 0.000 0.000 3106.0 1518.0 1.01
media_temporal_latent_multiplier_raw_hsgp_coefs[197] -0.000 0.000 -0.000 0.000 0.000 0.000 4220.0 1427.0 1.00
media_temporal_latent_multiplier_raw_hsgp_coefs[198] -0.000 0.000 -0.000 0.000 0.000 0.000 3287.0 1454.0 1.00
media_temporal_latent_multiplier_raw_hsgp_coefs[199] 0.000 0.000 -0.000 0.000 0.000 0.000 3258.0 1591.0 1.01

202 rows × 9 columns

_ = az.plot_trace(
    data=mmm.fit_result,
    var_names=[
        "media_temporal_latent_multiplier_raw_eta",
        "media_temporal_latent_multiplier_raw_ls",
        "media_temporal_latent_multiplier_raw_hsgp_coefs",
    ],
    compact=True,
    backend_kwargs={"figsize": (12, 10), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=16);

Todo parece estar bien por ahora, no hay nada que genere alertas al analizar nuestro rastro. Pero, ¿qué pasa con la descomposición?

mmm.plot.contributions_over_time(
    var=[
        "control_contribution",
        "channel_contribution",
        "yearly_seasonality_contribution",
    ],
    combine_dims=True,
    figsize=(16, 8),
);

La descomposición se ve mucho mejor ahora 🔥 Parece que estamos estimando cada parámetro de manera más precisa, ¡y no hay compromisos obvios entre los componentes!

Veamos qué tan bien se han logrado recuperar los parámetros originales.

fig = mmm.plot.channel_parameter(param_name="saturation_alpha", figsize=(9, 5))
ax = fig.axes[0]
ax.axvline(
    x=(real_alpha[0] / df.y.max()), color="C0", linestyle="--", label=r"$\alpha_1$"
)
ax.axvline(
    x=(real_alpha[1] / df.y.max()), color="C1", linestyle="--", label=r"$\alpha_2$"
)
ax.legend(loc="upper right");
fig = mmm.plot.channel_parameter(param_name="saturation_lam", figsize=(9, 5))
ax = fig.axes[0]
ax.axvline(
    x=(real_lam[0] / df.x1.max()), color="C0", linestyle="--", label=r"$\lambda_1$"
)
ax.axvline(
    x=(real_lam[1] / df.x2.max()), color="C1", linestyle="--", label=r"$\lambda_2$"
)
ax.legend(loc="upper right");

¡Los parámetros de la función de saturación parecen haberse recuperado prácticamente en su totalidad para ambos canales! ¡Esto es genial 🎉!

Veamos cuánto logramos recuperar de la verdadera variación. Podemos analizar la variable media_temporal_latent_multiplier y compararla con la variable original utilizada en el proceso original.

media_latent_factor = mmm.fit_result["media_temporal_latent_multiplier"].quantile(
    [0.025, 0.50, 0.975], dim=["chain", "draw"]
)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 10))
sns.lineplot(
    x=mmm.fit_result.coords["date"],
    y=media_latent_factor.sel(quantile=0.5),
    label="Predicted",
    color="blue",
)

sns.lineplot(
    x=mmm.fit_result.coords["date"],
    y=df["hidden_latent_media_fluctuation"],
    label="Real",
    color="Black",
    linestyle="--",
)


ax.fill_between(
    mmm.fit_result.coords["date"],
    media_latent_factor.sel(quantile=0.025),
    media_latent_factor.sel(quantile=0.975),
    alpha=0.3,
)
ax.set_title("HSGP")
ax.set_xlabel("Date")
ax.set_ylabel("Latent Factor")
ax.tick_params(axis="x", rotation=45)
ax.legend()
plt.show()

Increíble 🚀 hemos recuperado el proceso latente casi a la perfección. Aunque parece un poco sobreestimado, ¡está bastante cerca de la realidad!

recover_channel_contribution = mmm.fit_result["channel_contribution"].quantile(
    [0.025, 0.50, 0.975], dim=["chain", "draw"]
)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 10))
sns.lineplot(
    x=mmm.fit_result.coords["date"],
    y=recover_channel_contribution.sel(quantile=0.5).sum(axis=-1),
    label="Posterior Predictive Contribution",
    color="purple",
)

sns.lineplot(
    x=mmm.fit_result.coords["date"],
    y=df["channel_contribution"] / df["y"].max(),
    label="Real",
    color="purple",
    linestyle="--",
)


ax.fill_between(
    mmm.fit_result.coords["date"],
    recover_channel_contribution.sel(quantile=0.025).sum(axis=-1),
    recover_channel_contribution.sel(quantile=0.975).sum(axis=-1),
    alpha=0.3,
)
ax.set_title("Recover contribution")
ax.set_xlabel("Date")
ax.set_ylabel("Sales")
ax.tick_params(axis="x", rotation=45)
ax.legend()
plt.show()

Esto se refleja al comparar la contribución recuperada con la original. ¡Podemos ver que son exactamente iguales!

Podemos comparar ahora la distribución de las contribuciones por canal entre nuestros dos modelos.

basic_recover_channel_contribution = basic_mmm.fit_result[
    "channel_contribution"
].quantile([0.025, 0.50, 0.975], dim=["chain", "draw"])


fig, ax = plt.subplots(
    nrows=2, ncols=1, figsize=(15, 9), sharex=True, sharey=True, layout="constrained"
)

sns.lineplot(
    x="date_week",
    y="x1_contribution",
    data=df,
    color="C0",
    ax=ax[0],
    label="Real Contribution x1",
)
ax[0].fill_between(
    basic_mmm.fit_result.coords["date"],
    basic_recover_channel_contribution.sel(quantile=0.025).sel(channel="x1")
    * df.y.max(),
    basic_recover_channel_contribution.sel(quantile=0.975).sel(channel="x1")
    * df.y.max(),
    alpha=0.4,
    color="C5",
    label="Posterior Contribution x1 (Basic)",
)
ax[0].fill_between(
    mmm.fit_result.coords["date"],
    recover_channel_contribution.sel(quantile=0.025).sel(channel="x1") * df.y.max(),
    recover_channel_contribution.sel(quantile=0.975).sel(channel="x1") * df.y.max(),
    alpha=0.4,
    color="C0",
    label="Posterior Contribution x1 (Time-varying)",
)
ax[0].legend(bbox_to_anchor=(0.5, -0.18), loc="upper center", ncols=3)

sns.lineplot(
    x="date_week",
    y="x2_contribution",
    data=df,
    color="C1",
    ax=ax[1],
    label="Real Contribution x2",
)
ax[1].fill_between(
    basic_mmm.fit_result.coords["date"],
    basic_recover_channel_contribution.sel(quantile=0.025).sel(channel="x2")
    * df.y.max(),
    basic_recover_channel_contribution.sel(quantile=0.975).sel(channel="x2")
    * df.y.max(),
    alpha=0.4,
    color="C3",
    label="Posterior Contribution x2 (Basic)",
)
ax[1].fill_between(
    mmm.fit_result.coords["date"],
    recover_channel_contribution.sel(quantile=0.025).sel(channel="x2") * df.y.max(),
    recover_channel_contribution.sel(quantile=0.975).sel(channel="x2") * df.y.max(),
    alpha=0.4,
    color="C1",
    label="Posterior Contribution x2 (Time-varying)",
)

ax[1].set(xlabel="weeks")
fig.suptitle("Media Contribution per Channel", fontsize=16)
ax[1].legend(bbox_to_anchor=(0.5, -0.18), loc="upper center", ncols=3);

Las contribuciones por canal también se recuperaron correctamente, ¡a diferencia de nuestro primer modelo! De hecho, vemos cómo el modelo base intenta (¡y falla!) capturar la verdadera variación. La razón es clara: falta de flexibilidad.

Perspectivas#

El enfoque bayesiano no solo facilita la prueba de hipótesis y la validación de modelos, sino que también proporciona una manera estructurada de incorporar conocimientos previos y probar diversas suposiciones sobre el proceso de generación de datos. La ocurrencia de divergencias, como se observó en nuestro ajuste inicial del modelo, subraya la importancia de la especificación del modelo y su alineación con la estructura subyacente de los datos. Estas divergencias sirven como una herramienta de diagnóstico, guiando el posterior refinamiento y mejora del modelo.

En resumen, utilizar PyMC-Marketing para construir modelos MMM conscientes del tiempo permite a los mercadólogos obtener una comprensión más profunda y precisa del impacto de sus esfuerzos. Esta metodología mejora la capacidad de tomar decisiones basadas en datos, optimizar estrategias de marketing y, en última instancia, impulsar mejores resultados comerciales.

Conclusión#

A lo largo de este cuaderno, hemos explorado la implementación de un Modelo de Mezcla de Marketing Bayesiano (MMM) utilizando PyMC, comparando el rendimiento y los conocimientos obtenidos de modelos con y sin un componente temporal. La conclusión clave de nuestro análisis es la ventaja significativa de incorporar factores que varían en el tiempo en el MMM.

Descubriendo Relaciones Causales Reales#

Al integrar un componente temporal, podemos descubrir las verdaderas relaciones causales entre nuestra variable objetivo (como las ventas) y nuestros esfuerzos de marketing. El enfoque tradicional, que descuida las dinámicas temporales, a menudo no logra capturar la naturaleza compleja y fluctuante del rendimiento del marketing en el mundo real. En contraste, el modelo dependiente del tiempo proporciona una comprensión más precisa y matizada de cómo las actividades de marketing influyen en los resultados a lo largo del tiempo.

Ventajas de PyMC-Marketing#

PyMC-Marketing ofrece herramientas poderosas para implementar estas metodologías avanzadas. Las nuevas características y funcionalidades, que incluyen el manejo de diferentes efectos de adstock, efectos de saturación y procesos gaussianos en el espacio de Hilbert (HSGP) para modelar componentes que varían en el tiempo, permiten un modelado más preciso y confiable de los datos de marketing.

Animamos a los profesionales a aprovechar estas técnicas avanzadas y las capacidades de PyMC-Marketing para mejorar su análisis de marketing y obtener una ventaja competitiva en su planificación estratégica.

Bono#

Este cuaderno simuló una variación muy simple, es posible que los verdaderos procesos latentes dependientes del tiempo ocultos en sus datos sean más complejos, por lo tanto, necesitará utilizar priors para guiar su modelo a encontrar los datos reales.

Una forma de lograr esto es mediante la modificación de la configuración del modelo.

custom_config = {
    "intercept": Prior("HalfNormal", sigma=0.5),
    "saturation_alpha": Prior(
        "Gamma", mu=np.array([0.3, 0.4]), sigma=np.array([0.2, 0.2]), dims="channel"
    ),
    "saturation_lam": Prior("Beta", alpha=4, beta=4, dims="channel"),
}

media_tvp_config = {
    "media_tvp_config": {
        "m": 50,
        "L": 30,
        "eta_lam": 3,
        "ls_mu": 5,
        "ls_sigma": 5,
        "cov_func": None,
    }
}

custom_config = {**mmm.model_config, **custom_config, **media_tvp_config}
custom_config
{'intercept': Prior("HalfNormal", sigma=0.5),
 'likelihood': Prior("Normal", sigma=Prior("HalfNormal", sigma=2), dims="date"),
 'gamma_control': Prior("Normal", mu=0, sigma=2, dims="control"),
 'gamma_fourier': Prior("Laplace", mu=0, b=1, dims="fourier_mode"),
 'media_tvp_config': {'m': 50,
  'L': 30,
  'eta_lam': 3,
  'ls_mu': 5,
  'ls_sigma': 5,
  'cov_func': None},
 'adstock_alpha': Prior("Beta", alpha=1, beta=3, dims="channel"),
 'saturation_alpha': Prior("Gamma", mu=[0.3 0.4], sigma=[0.2 0.2], dims="channel"),
 'saturation_lam': Prior("Beta", alpha=4, beta=4, dims="channel")}
mmm_calibrated = MMM(
    date_column="date_week",
    target_column="y",
    channel_columns=["x1", "x2"],
    control_columns=["event_1", "event_2"],
    yearly_seasonality=yearly_seasonality,
    adstock=GeometricAdstock(l_max=adstock_max_lag),
    saturation=MichaelisMentenSaturation(),
    time_varying_media=True,
    model_config=custom_config,
)
%load_ext watermark
%watermark -n -u -v -iv -w -p pymc,pymc_marketing,pytensor,numpyro
Last updated: Tue, 03 Feb 2026

Python implementation: CPython
Python version       : 3.13.11
IPython version      : 9.9.0

pymc          : 5.27.0
pymc_marketing: 0.17.1
pytensor      : 2.36.3
numpyro       : 0.19.0

arviz         : 0.23.0
matplotlib    : 3.10.8
numpy         : 2.3.5
pandas        : 2.3.3
pymc_extras   : 0.7.0
pymc_marketing: 0.17.1
seaborn       : 0.13.2

Watermark: 2.6.0