Time-Slice Cross-Validation and Parameter Stability#

In this notebook, we illustrate how to perform time-slice cross-validation for a media mix model (MMM). This is an important step for evaluating the stability and quality of the model. We look not only at the out-of-sample predictions but also at the stability of the model parameters.

The imports and configuration in the next section set up everything used in the rest of this notebook.

We assume familiarity with training a model using the PyMC-Marketing library. The data generation and training processes are covered in detail in a different notebook; readers unfamiliar with these procedures should refer to the "MMM Example Notebook".

Prepare the Notebook#

import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from pymc_marketing.mmm.time_slice_cross_validation import TimeSliceCrossValidator
from pymc_marketing.paths import data_dir

warnings.simplefilter(action="ignore", category=FutureWarning)

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"


%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
seed: int = sum(map(ord, "mmm"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

Loading Data#

Here we load our geo-level dataset, which is then used in the time-slice cross-validation steps.

data_path = data_dir / "multidimensional_mock_data.csv"
data_df = pd.read_csv(data_path, parse_dates=["date"], index_col=0)
data_df.head()
         date            y          x1   x2  event_1  event_2  dayofyear  t    geo
0  2018-04-02  3984.662237  159.290009  0.0      0.0      0.0         92  0  geo_a
1  2018-04-09  3762.871794   56.194238  0.0      0.0      0.0         99  1  geo_a
2  2018-04-16  4466.967388  146.200133  0.0      0.0      0.0        106  2  geo_a
3  2018-04-23  3864.219373   35.699276  0.0      0.0      0.0        113  3  geo_a
4  2018-04-30  4441.625278  193.372577  0.0      0.0      0.0        120  4  geo_a
X = data_df.drop(columns=["y"])
y = data_df["y"]

Specify the Time-Slice Cross-Validation Strategy#

The main idea of time-slice cross-validation is to fit the model on one time slice of the data and then evaluate it on the next time slice. We repeat this process for each time slice of the data. Because we want to simulate a production-like setting in which the training data grows over time, we let the size of the training slice grow with each iteration.
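The expanding-window idea can be sketched in plain Python. This is only an illustration under stated assumptions, not the implementation used by TimeSliceCrossValidator: the helper `expanding_window_splits` is hypothetical, its parameter names simply mirror the ones used later in this notebook, and the weekly calendar is generated here rather than read from the dataset.

```python
from datetime import date, timedelta


def expanding_window_splits(dates, n_init, forecast_horizon, step_size=1):
    """Yield (train_dates, test_dates) pairs for an expanding training window."""
    unique_dates = sorted(set(dates))
    train_end = n_init
    # Stop once there are not enough dates left for a full forecast horizon.
    while train_end + forecast_horizon <= len(unique_dates):
        train = unique_dates[:train_end]
        test = unique_dates[train_end : train_end + forecast_horizon]
        yield train, test
        train_end += step_size


# Hypothetical weekly calendar: 179 Mondays starting on 2018-04-02.
weekly_dates = [date(2018, 4, 2) + timedelta(weeks=i) for i in range(179)]
splits = list(expanding_window_splits(weekly_dates, n_init=163, forecast_horizon=12))

for i, (train, test) in enumerate(splits):
    print(f"Iteration {i}: train until {train[-1]}, test {test[0]} to {test[-1]}")
```

Each iteration trains on every date up to a cutoff and tests on the dates immediately after it, so the test window never overlaps the training window and the training window only ever grows.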

Data leakage

It is very important to avoid data leakage when performing time-slice cross-validation. This means the model must not see any data from the future during training. This also includes any data preprocessing steps!

For example, we need to compute the cost share for each training time slice independently if we want to avoid data leakage. Other sources of data leakage include using a global feature for the trend component. In our case we simply use an increasing variable t, so we are safe, since it only increases by one at each time step.
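To make the preprocessing point concrete, here is a minimal sketch of computing a cost share with and without leakage. The spend values are randomly generated and the helper `cost_share` is hypothetical; the only point is that the statistic must be computed from the training rows of the current slice.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical weekly spend for two channels (rows: dates, columns: channels).
spend = rng.uniform(0.0, 200.0, size=(179, 2))
n_train = 163  # dates available to the model in this slice


def cost_share(spend_slice):
    """Share of total spend per channel, computed only from the given rows."""
    totals = spend_slice.sum(axis=0)
    return totals / totals.sum()


# Leak-free: only the training rows enter the statistic.
share_train = cost_share(spend[:n_train])

# Leaky: the statistic also sees the rows the model is supposed to forecast.
share_leaky = cost_share(spend)

print("train-only share:", share_train.round(3))
print("full-data share :", share_leaky.round(3))
```

The two results differ, which is exactly the information that would leak from the test period into training if the share were computed once on the full dataset.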

Run the Time-Slice Cross-Validation Loop#

Depending on the business requirements, we need to decide the initial number of observations to use for fitting the model (n_init) and the forecast horizon (forecast_horizon). For this example, we use the first 163 dates to fit the model and then predict the next 12 dates. Since the data is weekly, the forecast horizon corresponds to about 3 months.

# Initialize cross-validator
cv = TimeSliceCrossValidator(
    n_init=163,
    forecast_horizon=12,
    date_column="date",
    step_size=1,
)
# We can check how many splits we will have.
# As a reference, with step_size=1 the count in this example matches
# n_splits = n_dates - n_init - forecast_horizon + 1,
# where n_dates is the number of unique dates (179 - 163 - 12 + 1 = 5).
n_splits = cv.get_n_splits(X, y)
print(f"Number of splits: {n_splits}")
Number of splits: 5

Let's run it!

For more details on build_mmm_from_yaml, consult the PyMC-Marketing documentation on Model Deployment.

Alternatively, load a model that has been saved to MLflow via pymc_marketing.mlflow.log_inference_data or has been autologged to MLflow via pymc_marketing.mlflow.autolog(log_mmm=True), from the PyMC-Marketing MLflow module.

results = cv.run(
    X,
    y,
    # You can also pass sampler_config here to speed things up
    sampler_config={
        "tune": 1_000,
        "draws": 1_000,
        "chains": 4,
        "random_seed": seed,
        "target_accept": 0.90,
        "nuts_sampler": "numpyro",
    },
    yaml_path=data_dir / "config_files" / "multi_dimensional_example_model.yml",
)

Sampling: [y]

Sampling: [y]

Sampling: [y]

Sampling: [y]

Sampling: [y]
# We can view the cross-validation results!
# The results object is an instance of ArviZ InferenceData
results
arviz.InferenceData
    • <xarray.Dataset> Size: 700MB
      Dimensions:                                  (cv: 5, chain: 4, draw: 1000,
                                                    channel: 2, changepoint: 5,
                                                    geo: 2, control: 2,
                                                    fourier_mode: 4, date: 167)
      Coordinates:
        * cv                                       (cv) object 40B 'Iteration 0' .....
        * chain                                    (chain) int64 32B 0 1 2 3
        * draw                                     (draw) int64 8kB 0 1 2 ... 998 999
        * channel                                  (channel) <U2 16B 'x1' 'x2'
        * changepoint                              (changepoint) int64 40B 0 1 2 3 4
        * geo                                      (geo) <U5 40B 'geo_a' 'geo_b'
        * control                                  (control) <U7 56B 'event_1' 'eve...
        * fourier_mode                             (fourier_mode) <U5 80B 'sin_1' ....
        * date                                     (date) datetime64[ns] 1kB 2018-0...
      Data variables: (12/20)
          adstock_alpha                            (cv, chain, draw, channel) float64 320kB ...
          delta                                    (cv, chain, draw, changepoint, geo) float64 2MB ...
          delta_b                                  (cv, chain, draw) float64 160kB ...
          gamma_control                            (cv, chain, draw, control) float64 320kB ...
          gamma_fourier                            (cv, chain, draw, geo, fourier_mode) float64 1MB ...
          gamma_fourier_b                          (cv, chain, draw) float64 160kB ...
          ...                                       ...
          fourier_contribution                     (cv, chain, draw, date, geo, fourier_mode) float64 214MB ...
          intercept_contribution_original_scale    (cv, chain, draw, geo) float64 320kB ...
          total_media_contribution_original_scale  (cv, chain, draw) float64 160kB ...
          trend_effect_contribution                (cv, chain, draw, date, geo) float64 53MB ...
          y_original_scale                         (cv, chain, draw, date, geo) float64 53MB ...
          yearly_seasonality_contribution          (cv, chain, draw, date, geo) float64 53MB ...
      Attributes:
          created_at:                 2025-12-20T13:44:59.248841+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              7.597176
          tuning_steps:               1000
          pymc_marketing_version:     0.17.1

    • <xarray.Dataset> Size: 115MB
      Dimensions:           (cv: 5, chain: 4, draw: 1000, date: 179, geo: 2)
      Coordinates:
        * cv                (cv) object 40B 'Iteration 0' ... 'Iteration 4'
        * chain             (chain) int64 32B 0 1 2 3
        * draw              (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * date              (date) datetime64[ns] 1kB 2018-04-02 ... 2021-08-30
        * geo               (geo) <U5 40B 'geo_a' 'geo_b'
      Data variables:
          y                 (cv, chain, draw, date, geo) float64 57MB 0.4308 ... 0....
          y_original_scale  (cv, chain, draw, date, geo) float64 57MB 3.581e+03 ......
      Attributes:
          created_at:                 2025-12-20T13:45:01.510622+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 988kB
      Dimensions:          (cv: 5, chain: 4, draw: 1000)
      Coordinates:
        * cv               (cv) object 40B 'Iteration 0' ... 'Iteration 4'
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (cv, chain, draw) float64 160kB 0.9825 0.989 ... 0.9515 1.0
          diverging        (cv, chain, draw) bool 20kB False False ... False False
          energy           (cv, chain, draw) float64 160kB -506.9 -507.3 ... -528.7
          lp               (cv, chain, draw) float64 160kB -527.1 -523.7 ... -544.4
          n_steps          (cv, chain, draw) int64 160kB 63 63 63 63 ... 63 63 63 63
          step_size        (cv, chain, draw) float64 160kB 0.04739 0.04739 ... 0.05671
          tree_depth       (cv, chain, draw) int64 160kB 6 6 6 6 6 6 6 ... 7 6 6 6 6 6
      Attributes:
          created_at:     2025-12-20T13:44:59.257358+00:00
          arviz_version:  0.22.0

    • <xarray.Dataset> Size: 231MB
      Dimensions:                                         (cv: 5, chain: 1,
                                                           draw: 1000, date: 179,
                                                           geo: 2, control: 2,
                                                           fourier_mode: 4,
                                                           channel: 2, changepoint: 5)
      Coordinates:
        * cv                                              (cv) object 40B 'Iteratio...
        * chain                                           (chain) int64 8B 0
        * draw                                            (draw) int64 8kB 0 1 ... 999
        * date                                            (date) datetime64[ns] 1kB ...
        * geo                                             (geo) <U5 40B 'geo_a' 'ge...
        * control                                         (control) <U7 56B 'event_...
        * fourier_mode                                    (fourier_mode) <U5 80B 's...
        * channel                                         (channel) <U2 16B 'x1' 'x2'
        * changepoint                                     (changepoint) int64 40B 0...
      Data variables: (12/22)
          y_original_scale                                (cv, chain, draw, date, geo) float64 14MB ...
          intercept_contribution                          (cv, chain, draw, geo) float64 80kB ...
          y_sigma                                         (cv, chain, draw) float64 40kB ...
          control_contribution                            (cv, chain, draw, date, geo, control) float64 29MB ...
          yearly_seasonality_contribution_original_scale  (cv, chain, draw, date, geo) float64 14MB ...
          total_media_contribution_original_scale         (cv, chain, draw) float64 40kB ...
          ...                                              ...
          control_contribution_original_scale             (cv, chain, draw, date, geo, control) float64 29MB ...
          channel_contribution                            (cv, chain, draw, date, geo, channel) float64 29MB ...
          yearly_seasonality_contribution                 (cv, chain, draw, date, geo) float64 14MB ...
          intercept_contribution_original_scale           (cv, chain, draw, geo) float64 80kB ...
          delta                                           (cv, chain, draw, changepoint, geo) float64 400kB ...
          delta_b                                         (cv, chain, draw) float64 40kB ...
      Attributes:
          created_at:                 2025-07-26T08:20:31.433730+00:00
          arviz_version:              0.21.0
          inference_library:          pymc
          inference_library_version:  5.25.1
          pymc_marketing_version:     0.15.1

    • <xarray.Dataset> Size: 14MB
      Dimensions:  (cv: 5, chain: 1, draw: 1000, date: 179, geo: 2)
      Coordinates:
        * cv       (cv) object 40B 'Iteration 0' 'Iteration 1' ... 'Iteration 4'
        * chain    (chain) int64 8B 0
        * draw     (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * date     (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
        * geo      (geo) <U5 40B 'geo_a' 'geo_b'
      Data variables:
          y        (cv, chain, draw, date, geo) float64 14MB 2.658 2.098 ... 2.466
      Attributes:
          created_at:                 2025-07-26T08:20:31.438500+00:00
          arviz_version:              0.21.0
          inference_library:          pymc
          inference_library_version:  5.25.1
          pymc_marketing_version:     0.15.1

    • <xarray.Dataset> Size: 15kB
      Dimensions:  (cv: 5, date: 167, geo: 2)
      Coordinates:
        * cv       (cv) object 40B 'Iteration 0' 'Iteration 1' ... 'Iteration 4'
        * date     (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-06-07
        * geo      (geo) <U5 40B 'geo_a' 'geo_b'
      Data variables:
          y        (cv, date, geo) float64 13kB 0.4794 0.5206 0.4527 ... 0.6063 0.5798
      Attributes:
          created_at:                 2025-12-20T13:44:59.258392+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              7.597176
          tuning_steps:               1000

    • <xarray.Dataset> Size: 82kB
      Dimensions:        (cv: 5, date: 167, geo: 2, channel: 2, control: 2)
      Coordinates:
        * cv             (cv) object 40B 'Iteration 0' 'Iteration 1' ... 'Iteration 4'
        * date           (date) datetime64[ns] 1kB 2018-04-02 ... 2021-06-07
        * geo            (geo) <U5 40B 'geo_a' 'geo_b'
        * channel        (channel) <U2 16B 'x1' 'x2'
        * control        (control) <U7 56B 'event_1' 'event_2'
      Data variables:
          channel_data   (cv, date, geo, channel) float64 27kB 159.3 0.0 ... 72.29 0.0
          channel_scale  (cv, geo, channel) float64 160B 498.3 497.2 ... 498.3 497.2
          control_data   (cv, date, geo, control) float64 27kB 0.0 0.0 0.0 ... 0.0 0.0
          dayofyear      (cv, date) float64 7kB 92.0 99.0 106.0 ... 144.0 151.0 158.0
          target_data    (cv, date, geo) float64 13kB 3.985e+03 ... 4.894e+03
          target_scale   (cv, geo) float64 80B 8.312e+03 8.441e+03 ... 8.441e+03
          trend_t        (cv, date) float64 7kB 0.0 7.0 14.0 ... 1.155e+03 1.162e+03
      Attributes:
          created_at:                 2025-12-20T13:44:59.260742+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              7.597176
          tuning_steps:               1000

    • <xarray.Dataset> Size: 95kB
      Dimensions:    (cv: 5, date: 167, geo: 2)
      Coordinates:
        * cv         (cv) object 40B 'Iteration 0' 'Iteration 1' ... 'Iteration 4'
        * date       (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-06-07
        * geo        (geo) object 16B 'geo_a' 'geo_b'
      Data variables:
          x1         (cv, date, geo) float64 13kB 159.3 159.3 56.19 ... 72.29 72.29
          x2         (cv, date, geo) float64 13kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          event_1    (cv, date, geo) float64 13kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          event_2    (cv, date, geo) float64 13kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          dayofyear  (cv, date, geo) float64 13kB 92.0 92.0 99.0 ... 151.0 158.0 158.0
          t          (cv, date, geo) float64 13kB 0.0 0.0 1.0 ... 165.0 166.0 166.0
          y          (cv, date, geo) float64 13kB 3.985e+03 4.395e+03 ... 4.894e+03

    • <xarray.Dataset> Size: 88kB
      Dimensions:        (cv: 5, date: 179, geo: 2, channel: 2, control: 2)
      Coordinates:
        * cv             (cv) object 40B 'Iteration 0' 'Iteration 1' ... 'Iteration 4'
        * date           (date) datetime64[ns] 1kB 2018-04-02 ... 2021-08-30
        * geo            (geo) <U5 40B 'geo_a' 'geo_b'
        * channel        (channel) <U2 16B 'x1' 'x2'
        * control        (control) <U7 56B 'event_1' 'event_2'
      Data variables:
          channel_data   (cv, date, geo, channel) float64 29kB 159.3 0.0 ... 219.4 0.0
          channel_scale  (cv, geo, channel) float64 160B 498.3 497.2 ... 498.3 497.2
          control_data   (cv, date, geo, control) float64 29kB 0.0 0.0 0.0 ... 0.0 0.0
          dayofyear      (cv, date) float64 7kB 92.0 99.0 106.0 ... 228.0 235.0 242.0
          target_data    (cv, date, geo) float64 14kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0
          target_scale   (cv, geo) float64 80B 8.312e+03 8.441e+03 ... 8.441e+03
          trend_t        (cv, date) float64 7kB 0.0 7.0 14.0 ... 1.239e+03 1.246e+03
      Attributes:
          created_at:                 2025-12-20T13:45:01.515695+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 80B
      Dimensions:   (cv: 5)
      Coordinates:
        * cv        (cv) object 40B 'Iteration 0' 'Iteration 1' ... 'Iteration 4'
      Data variables:
          metadata  (cv) object 40B {'X_train':           date          x1         ...

Model Diagnostics#

First, we check whether the model has any divergences (the analysis can be extended with further model diagnostics).

# Let's check if there are any divergences
diverging_count = int(results.sample_stats["diverging"].values.sum())
print("Diverging transitions:", diverging_count)
Diverging transitions: 0

We have no divergences in the model 😃!
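The total count hides which iteration a divergence came from. As a small sketch, the count can be broken down per cross-validation iteration. The array below is a hypothetical stand-in with the same (cv, chain, draw) shape as the diverging variable shown in the output above, with a few divergences injected for illustration. On the real object, the xarray equivalent would be `results.sample_stats["diverging"].sum(dim=("chain", "draw"))`.

```python
import numpy as np

# Hypothetical stand-in for results.sample_stats["diverging"].values,
# with shape (cv, chain, draw).
diverging = np.zeros((5, 4, 1000), dtype=bool)
diverging[2, 1, :3] = True  # pretend iteration 2 had three divergences

# Count divergent transitions per cross-validation iteration.
per_iteration = diverging.sum(axis=(1, 2))
for i, count in enumerate(per_iteration):
    print(f"Iteration {i}: {count} divergences")

# Iterations that would deserve a closer look.
flagged = np.flatnonzero(per_iteration > 0)
```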

Evaluate Parameter Stability#

Next, we examine the stability of the model parameters. For a good model, these should not change abruptly over time.

  • Adstock Alpha

cv.plot.param_stability(
    results=results,
    parameter=["adstock_alpha"],
    dims={"geo": ["geo_a"]},
);
  • Saturation Beta

cv.plot.param_stability(
    results,
    parameter=["saturation_beta"],
    dims={"geo": ["geo_a", "geo_b"]},
);
  • Saturation Lambda

cv.plot.param_stability(
    results,
    parameter=["saturation_lam"],
    # dims={"geo": ["geo_a", "geo_b"]}
);

The parameters appear to be stable over time. This suggests that the ROAS estimates will not change abruptly over time.
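Besides inspecting the plots, stability can be summarized numerically. The sketch below is one possible approach and not part of the PyMC-Marketing API: it uses randomly generated draws as a hypothetical stand-in for one posterior parameter with dimensions (cv, chain, draw), and it reports a central quantile interval rather than an HDI.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical posterior draws of one parameter, shape (cv, chain, draw).
draws = rng.beta(2.0, 5.0, size=(5, 4, 1000))

# Flatten chains and draws so each row holds all samples of one iteration.
flat = draws.reshape(draws.shape[0], -1)

means = flat.mean(axis=1)
# Central 94% interval per iteration (a quantile interval, not an HDI).
lower, upper = np.quantile(flat, [0.03, 0.97], axis=1)

# Shift of the posterior mean between consecutive iterations,
# measured in units of the earlier iteration's posterior standard deviation.
shift = np.abs(np.diff(means)) / flat.std(axis=1)[:-1]

for i, (m, lo, hi) in enumerate(zip(means, lower, upper)):
    print(f"Iteration {i}: mean={m:.3f}, 94% interval=[{lo:.3f}, {hi:.3f}]")
print("max standardized shift:", shift.max().round(3))
```

A standardized shift that stays well below one posterior standard deviation is consistent with a stable parameter; what threshold counts as "abrupt" is a judgment call that depends on the application.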

Evaluate Out-of-Sample Predictions#

Finally, we evaluate the out-of-sample predictions. To start, we can simply plot the posterior predictive distributions for each iteration, for both the training data and the test data.

# Plot model predictions across time slices
cv.plot.cv_predictions(
    results,
    # dims={"geo": ["geo_b"]} # to plot specific dimensions only
);
../../_images/b4f2289c40830567b358077cabb7f85f2b246f57b758d4904e9504ea7cb7584c.png

Overall, the out-of-sample predictions look very good 🚀!

We can quantify the model performance using the Continuous Ranked Probability Score (CRPS).

The CRPS is a scoring function that compares a single ground-truth value against a cumulative distribution function (CDF). It can be used as a metric to evaluate a model's performance when the target variable is continuous and the model predicts the distribution of the target; examples include Bayesian regression and Bayesian time series models.

For a good explanation of the CRPS, see this blog post.
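To make the definition concrete, the CRPS of a forecast distribution F at an observed value y can be written as E|X - y| - 0.5 * E|X - X'|, where X and X' are independent draws from F. The sketch below estimates both expectations from posterior predictive samples. The function name `crps_from_samples` is hypothetical, and this is not necessarily how the library computes the metric.

```python
import numpy as np


def crps_from_samples(y_true, y_samples):
    """Sample-based CRPS estimate for a single observation.

    Uses CRPS(F, y) = E|X - y| - 0.5 * E|X - X'| with X, X' drawn from F.
    """
    y_samples = np.asarray(y_samples, dtype=float)
    # Average distance between the samples and the observed value.
    accuracy = np.abs(y_samples - y_true).mean()
    # Average distance between all pairs of samples (spread of the forecast).
    spread = np.abs(y_samples[:, None] - y_samples[None, :]).mean()
    return accuracy - 0.5 * spread


rng = np.random.default_rng(0)
samples = rng.normal(loc=0.0, scale=1.0, size=2_000)

print(crps_from_samples(0.0, samples))  # forecast centered on the truth
print(crps_from_samples(3.0, samples))  # forecast far from the truth
```

Lower values are better. When every sample is identical, the spread term vanishes and the score reduces to the absolute error, which is why the CRPS is often described as a generalization of the mean absolute error to probabilistic forecasts.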

PyMC-Marketing provides the crps function (pymc_marketing.metrics.crps) to compute this metric. We can use it to compute the CRPS for each iteration.

# Compute the CRPS score for each iteration and plot!
cv.plot.cv_crps(
    results,
    # dims={"geo": ["geo_b"]} # to plot specific dimensions only
);
../../_images/828546b94a677d35803f1dbe7b85050f9f017c941cde8ea07f9b655a7dce6bbf.png

Even though the visual results are excellent, we see that the CRPS decreases slightly for the training data while it increases for the test data as we increase the size of the training data. This is an indication that we are overfitting the model to the training data. Some strategies to address this issue include regularization techniques and revisiting the model specification. This should be an iterative process.
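The pattern described above can also be checked numerically once the per-iteration scores are available as arrays. The values below are made up purely to illustrate the check and are not the scores from this notebook.

```python
import numpy as np

# Made-up CRPS values per iteration, only to illustrate the check.
crps_train = np.array([0.050, 0.049, 0.049, 0.048, 0.048])
crps_test = np.array([0.060, 0.062, 0.063, 0.066, 0.068])

iterations = np.arange(crps_train.size)

# Least-squares slope of each curve over the iterations.
slope_train = np.polyfit(iterations, crps_train, deg=1)[0]
slope_test = np.polyfit(iterations, crps_test, deg=1)[0]

# Gap between test and train error at each iteration.
gap = crps_test - crps_train

# A falling train score together with a rising test score hints at overfitting.
widening = slope_train < 0 and slope_test > 0
print(f"train slope={slope_train:.4f}, test slope={slope_test:.4f}, widening={widening}")
```

With only a handful of iterations the slopes are noisy, so a check like this is best read as a prompt to look more closely rather than as a formal test.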

%load_ext watermark
%watermark -n -u -v -iv -w -p pymc_marketing,pytensor,numpyro
Last updated: Sat Dec 20 2025

Python implementation: CPython
Python version       : 3.13.11
IPython version      : 9.8.0

pymc_marketing: 0.17.1
pytensor      : 2.35.1
numpyro       : 0.19.0

arviz         : 0.22.0
matplotlib    : 3.10.8
numpy         : 2.3.5
pandas        : 2.3.3
pymc_marketing: 0.17.1

Watermark: 2.5.1