Time-Slice-Cross-Validation and Parameter Stability#

In this notebook we illustrate how to perform time-slice cross validation for a media mix model. This is an important step for evaluating the quality and stability of the model: we look not only at out of sample predictions but also at the stability of the model parameters.

Prepare Notebook#

import warnings
from dataclasses import dataclass

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

from pymc_marketing.metrics import crps
from pymc_marketing.mmm import (
    MMM,
    GeometricAdstock,
    LogisticSaturation,
)
from pymc_marketing.mmm.utils import apply_sklearn_transformer_across_dim
from pymc_marketing.paths import data_dir
from pymc_marketing.prior import Prior

warnings.simplefilter(action="ignore", category=FutureWarning)

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"


%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
seed: int = sum(map(ord, "mmm"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

Read Data#

We use the same data as in the MMM Example Notebook.

data_path = data_dir / "mmm_example.csv"

data_df = pd.read_csv(data_path, parse_dates=["date_week"])

data_df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 179 entries, 0 to 178
Data columns (total 8 columns):
 #   Column     Non-Null Count  Dtype         
---  ------     --------------  -----         
 0   date_week  179 non-null    datetime64[ns]
 1   y          179 non-null    float64       
 2   x1         179 non-null    float64       
 3   x2         179 non-null    float64       
 4   event_1    179 non-null    float64       
 5   event_2    179 non-null    float64       
 6   dayofyear  179 non-null    int64         
 7   t          179 non-null    int64         
dtypes: datetime64[ns](1), float64(5), int64(2)
memory usage: 11.3 KB

Specify Time-Slice-Cross-Validation Strategy#

The main idea of time-slice cross validation is to fit the model on a time slice of the data and then evaluate it on the following time slice. We repeat this process for each slice of the data. As we want to simulate a production-like environment where the training data grows over time, we let the training slice grow with each iteration, as sketched below.
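To make the splitting scheme concrete, here is a toy sketch of the expanding-window indices (the sizes here are arbitrary and differ from the ones used later in this notebook):

# Toy illustration of the expanding-window scheme (sizes are arbitrary).
n_obs = 20  # total number of observations
n_init_toy = 10  # initial training size
horizon_toy = 4  # test window size

for iteration in range(n_obs - n_init_toy - horizon_toy + 1):
    train_end = n_init_toy + iteration
    print(
        f"iteration {iteration}: "
        f"train = [0, {train_end - 1}], "
        f"test = [{train_end}, {train_end + horizon_toy - 1}]"
    )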

Following the strategy of the example notebook, we use the cost share of each channel to set the prior standard deviation of the saturation beta parameters. We need to compute this share for each training time slice independently.

Data Leakage

It is very important to avoid data leakage when performing time-slice cross validation. This means that the model should not see any data from the future during training. This also includes any data pre-processing steps!

For example, as mentioned above, we need to compute the cost share for each training time slice independently if we want to avoid data leakage. Other sources of data leakage include using a trend feature computed from the full dataset. In our case, we simply use an increasing variable t, so we are safe: it just increases by one for each time step.

We wrap the main steps of the training procedure in a set of functions.

def compute_sigma_from_costs(
    X: pd.DataFrame, channel_columns: list[str]
) -> list[float]:
    """Compute the prior standard deviation of the beta parameters from the costs share of each channel."""
    n_channels = len(channel_columns)
    total_spend_per_channel = X[channel_columns].sum(axis=0)
    spend_share = total_spend_per_channel / total_spend_per_channel.sum()
    prior_sigma = n_channels * spend_share.to_numpy()
    return prior_sigma.tolist()


def get_mmm(X: pd.DataFrame, channel_columns: list[str]) -> MMM:
    """Specify the model."""
    prior_sigma = compute_sigma_from_costs(X, channel_columns)

    model_config = {
        "intercept": Prior("Normal", mu=0, sigma=0.5),
        "saturation_beta": Prior("HalfNormal", sigma=prior_sigma, dims="channel"),
        "gamma_control": Prior("Normal", mu=0, sigma=0.05, dims="control"),
        "likelihood": Prior("Normal", sigma=Prior("Exponential", lam=1 / 10)),
    }

    return MMM(
        adstock=GeometricAdstock(l_max=8),
        saturation=LogisticSaturation(),
        date_column="date_week",
        channel_columns=channel_columns,
        control_columns=[
            "event_1",
            "event_2",
            "t",
        ],
        yearly_seasonality=2,
        model_config=model_config,
    )


def fit_mmm(
    mmm: MMM, X: pd.DataFrame, y: pd.Series, random_seed: np.random.Generator
) -> MMM:
    """Fit the model."""
    fit_kwargs = {
        "tune": 1_000,
        "chains": 4,
        "draws": 1_000,
        "nuts_sampler": "numpyro",
        "random_seed": random_seed,
    }
    _ = mmm.fit(X, y, progressbar=False, **fit_kwargs)
    _ = mmm.sample_posterior_predictive(
        X, extend_idata=True, combined=True, progressbar=False, random_seed=random_seed
    )
    return mmm

For the sake of convenience, we define a data container to store the results of the time-slice cross validation step.

@dataclass
class TimeSliceCrossValidationResult:
    """Container for the results of the time slice cross validation."""

    X_train: pd.DataFrame
    y_train: pd.Series
    X_test: pd.DataFrame
    y_test: pd.Series
    mmm: MMM
    y_pred_test: pd.Series

Finally, we define the main function that performs the time-slice cross validation step by calling the functions defined above.

def time_slice_cross_validation_step(
    X_train: pd.DataFrame,
    y_train: pd.Series,
    X_test: pd.DataFrame,
    y_test: pd.Series,
    random_seed: np.random.Generator,
) -> TimeSliceCrossValidationResult:
    """Time-slice cross validation step.

    We fit the model on the training data and generate predictions for the test data.

    Parameters
    ----------
    X_train: pd.DataFrame
        Training data.
    y_train: pd.Series
        Training target.
    X_test: pd.DataFrame
        Test data.
    y_test: pd.Series
        Test target.
    random_seed: np.random.Generator
        Random number generator passed to the sampler for reproducibility.

    Returns
    -------
    TimeSliceCrossValidationResult
        Results of the time slice cross validation step.
    """
    mmm = get_mmm(X_train, channel_columns=["x1", "x2"])
    mmm = fit_mmm(mmm, X_train, y_train, random_seed)

    y_pred_test = mmm.sample_posterior_predictive(
        X_test,
        include_last_observations=True,
        original_scale=True,
        extend_idata=False,
        progressbar=False,
        random_seed=random_seed,
    )

    return TimeSliceCrossValidationResult(
        X_train=X_train,
        y_train=y_train,
        X_test=X_test,
        y_test=y_test,
        mmm=mmm,
        y_pred_test=y_pred_test,
    )

We are now ready to run the time-slice cross validation loop 💪!

Run Time-Slice-Cross-Validation Loop#

Depending on the business requirements, we need to decide the initial number of observations to use for fitting the model (n_init) and the forecast horizon (forecast_horizon). For this example, we use the first 158 observations to fit the model and then predict the next 12 observations (roughly 3 months of weekly data). With 179 observations in total, this gives n_iterations = 179 - 158 - 12 + 1 = 10 cross-validation steps.

X = data_df[["date_week", "x1", "x2", "event_1", "event_2", "t"]]
y = data_df["y"]
n_init = 158
forecast_horizon = 12
n_iterations = y.size - n_init - forecast_horizon + 1

Let’s run it!

results: list[TimeSliceCrossValidationResult] = []


for iteration in tqdm(range(n_iterations)):
    # Split data into train and test
    train_test_split = iteration + n_init
    X_train = X.iloc[:train_test_split]
    y_train = y.iloc[:train_test_split]
    X_test = X.iloc[train_test_split : train_test_split + forecast_horizon]
    y_test = y.iloc[train_test_split : train_test_split + forecast_horizon]

    step = time_slice_cross_validation_step(
        X_train=X_train,
        y_train=y_train,
        X_test=X_test,
        y_test=y_test,
        random_seed=rng,
    )
    results.append(step)
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]
Sampling: [y]

Model Diagnostics#

First, we evaluate whether we have any divergences in the model (the analysis can be extended with further model diagnostics, as sketched below).

sum([result.mmm.idata["sample_stats"]["diverging"].sum().item() for result in results])
0

We have no divergences in the model 😃!
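We could also extend the diagnostics, for example by checking that the R-hat statistics of the key parameters stay close to one for every iteration. Here is a minimal sketch, not part of the original analysis:

# Maximum R-hat of the key parameters per iteration
# (values close to 1 indicate good convergence).
rhat_max_per_iteration = [
    az.rhat(
        result.mmm.idata["posterior"],
        var_names=["adstock_alpha", "saturation_beta", "saturation_lam"],
    )
    .to_array()
    .max()
    .item()
    for result in results
]

rhat_max_per_iteration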

Evaluate Parameter Stability#

Next, we look at the stability of the model parameters. For a good model, these should not change abruptly over time.

  • Adstock Alpha

fig, ax = plt.subplots(figsize=(9, 6))

az.plot_forest(
    data=[result.mmm.idata["posterior"] for result in results],
    model_names=[f"Iteration {i}" for i in range(n_iterations)],
    var_names=["adstock_alpha"],
    combined=True,
    ax=ax,
)
fig.suptitle("Adstock Alpha", fontsize=18, fontweight="bold", y=1.06);
  • Saturation Beta

fig, ax = plt.subplots(figsize=(9, 6))

az.plot_forest(
    data=[result.mmm.idata["posterior"] for result in results],
    model_names=[f"Iteration {i}" for i in range(n_iterations)],
    var_names=["saturation_beta"],
    combined=True,
    ax=ax,
)
fig.suptitle("Saturation Beta", fontsize=18, fontweight="bold", y=1.06);
  • Saturation Lambda

fig, ax = plt.subplots(figsize=(9, 6))

az.plot_forest(
    data=[result.mmm.idata["posterior"] for result in results],
    model_names=[f"Iteration {i}" for i in range(n_iterations)],
    var_names=["saturation_lam"],
    combined=True,
    ax=ax,
)
fig.suptitle("Saturation Lambda", fontsize=18, fontweight="bold", y=1.06);

The parameters seem to be stable over time. This implies that the estimated ROAS will not change abruptly over time.
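If we want a more quantitative summary of this stability (a minimal sketch, not part of the original analysis), we can collect the posterior means of a parameter, for example adstock_alpha, across iterations and inspect how much they drift:

# Posterior mean of adstock_alpha per channel, one row per cross-validation iteration.
adstock_alpha_means = pd.DataFrame(
    [
        result.mmm.idata["posterior"]["adstock_alpha"]
        .mean(dim=("chain", "draw"))
        .to_series()
        for result in results
    ]
).reset_index(drop=True)
adstock_alpha_means.index.name = "iteration"

adstock_alpha_means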

Evaluate Out of Sample Predictions#

Finally, we evaluate the out of sample predictions. To begin with, we can simply plot the posterior predictive distributions for each iteration for both the training and test data.

fig, axes = plt.subplots(
    nrows=n_iterations,
    ncols=1,
    figsize=(12, 25),
    sharex=True,
    sharey=True,
    layout="constrained",
)

axes = axes.ravel()

for i, result in enumerate(results):
    ax = axes[i]
    result.mmm.plot_posterior_predictive(original_scale=True, ax=ax)

    hdi_prob = 0.94
    test_hdi = az.hdi(result.y_pred_test["y"].to_numpy().T, hdi_prob=hdi_prob)

    ax.fill_between(
        result.X_test["date_week"],
        test_hdi[:, 0],
        test_hdi[:, 1],
        color="C1",
        label=f"{hdi_prob:.0%} HDI (test)",
        alpha=0.5,
    )
    ax.plot(X["date_week"], y, marker="o", color="black")
    ax.axvline(result.X_test["date_week"].iloc[0], color="C2", linestyle="--")
    ax.legend(loc="upper right")

axes[-1].set(xlim=(X["date_week"].iloc[n_init - 9], None))
fig.suptitle("Posterior Predictive Check", fontsize=18, fontweight="bold", y=1.02);

Overall, the out of sample predictions look very good 🚀!

We can quantify the model performance using the Continuous Ranked Probability Score (CRPS).

“The CRPS — Continuous Ranked Probability Score — is a score function that compares a single ground truth value to a Cumulative Distribution Function. It can be used as a metric to evaluate a model’s performance when the target variable is continuous and the model predicts the target’s distribution; Examples include Bayesian Regression or Bayesian Time Series models.”

For a nice explanation of the CRPS, check out this blog post.
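To build some intuition about the metric, here is a minimal sketch (with simulated draws; this is not the pymc_marketing implementation) of the sample-based estimator CRPS(F, y) ~ mean(|X - y|) - 0.5 * mean(|X - X'|), where X and X' are independent draws from the predictive distribution:

def crps_sample_estimate(y_true: float, y_pred_samples: np.ndarray) -> float:
    """Sample-based CRPS estimate for a single observation (energy form)."""
    term_1 = np.mean(np.abs(y_pred_samples - y_true))
    term_2 = 0.5 * np.mean(np.abs(y_pred_samples[:, None] - y_pred_samples[None, :]))
    return term_1 - term_2


# Example: 1_000 draws from a toy predictive distribution and one observation.
toy_draws = rng.normal(loc=10.0, scale=2.0, size=1_000)
crps_sample_estimate(y_true=11.0, y_pred_samples=toy_draws)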

In PyMC Marketing, we provide the function crps to compute this metric. We can use it to compute the CRPS score for each iteration.

crps_results_train: list[float] = [
    crps(
        y_true=result.y_train.to_numpy(),
        y_pred=az.extract(
            # Scale the predictions back to the original scale
            apply_sklearn_transformer_across_dim(
                data=result.mmm.idata.posterior_predictive["y"],
                func=result.mmm.get_target_transformer().inverse_transform,
                dim_name="date",
            )
        )["y"]
        .to_numpy()
        .T,
    )
    for result in results
]


crps_results_test: list[float] = [
    crps(
        y_true=result.y_test.to_numpy(),
        y_pred=result.y_pred_test["y"].to_numpy().T,
    )
    for result in results
]

fig, ax = plt.subplots(
    nrows=2, ncols=1, figsize=(12, 7), sharex=True, sharey=False, layout="constrained"
)

ax[0].plot(crps_results_train, marker="o", color="C0", label="train")
ax[0].set(ylabel="CRPS", title="Train CRPS")
ax[1].plot(crps_results_test, marker="o", color="C1", label="test")
ax[1].set(xlabel="Iteration", ylabel="CRPS", title="Test CRPS")
fig.suptitle("CRPS for each iteration", fontsize=18, fontweight="bold", y=1.05);

Even though the visual results look great, we see that the CRPS mildly decreases for the training data while it increases for the test data as the training set grows. This is a sign that we are overfitting the model to the training data. Strategies to overcome this issue include using regularization techniques and re-evaluating the model specification; this should be an iterative process.
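For instance (a hypothetical sketch, not a change applied in this notebook), one Bayesian form of regularization is to tighten the priors passed through model_config, e.g. shrinking the intercept and the control coefficients more aggressively:

# Hypothetical tighter priors as a simple form of regularization
# (values are illustrative, not tuned).
regularized_model_config = {
    "intercept": Prior("Normal", mu=0, sigma=0.25),
    "gamma_control": Prior("Normal", mu=0, sigma=0.01, dims="control"),
    # The remaining priors could stay as in get_mmm above.
}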

%load_ext watermark
%watermark -n -u -v -iv -w -p pymc_marketing,pytensor,numpyro
Last updated: Sun Aug 25 2024

Python implementation: CPython
Python version       : 3.12.4
IPython version      : 8.26.0

pymc_marketing: 0.8.0
pytensor      : 2.22.1
numpyro       : 0.15.2

matplotlib: 3.9.0
pandas    : 2.2.2
arviz     : 0.19.0
numpy     : 1.26.4

Watermark: 2.4.3