MMM with time-varying media baseline#

Introduction#

In the domain of Marketing Mix Modeling (MMM), understanding the impact of various marketing activities on a target variable and other key performance indicators is crucial. Traditional regression models often neglect the temporal dynamics of marketing activities, potentially leading to biased or incomplete insights. This notebook aims to showcase the difference between a conventional regression model that does not account for time variation and a more sophisticated model that incorporates time as a key component through a Gaussian process.

The objective is to determine the contribution of each marketing activity to the overall target variable or desired outcome. This process typically involves two critical transformations:

  1. Saturation Function: This function models the diminishing returns of marketing inputs. As more resources are allocated to a specific channel, the incremental benefit tends to decrease.

  2. Adstock Function: This function captures the carryover effect of marketing activities over time, recognizing that the impact of a marketing effort extends beyond the immediate period in which it occurs.

The standard approach in MMM applies these transformations to the marketing inputs, resulting in a contribution to the outcome.
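
To make these two transformations concrete, here is a minimal NumPy sketch (the function names and parameter values are illustrative; later in the notebook we use the GeometricAdstock and MichaelisMentenSaturation classes from PyMC-Marketing, which implement the same ideas):

import numpy as np


def geometric_adstock(x, alpha=0.5, l_max=8):
    """Carry over a geometrically decaying fraction of past spend, up to l_max periods back."""
    weights = alpha ** np.arange(l_max)
    return np.array(
        [
            np.dot(x[max(0, t - l_max + 1) : t + 1][::-1], weights[: min(t + 1, l_max)])
            for t in range(len(x))
        ]
    )


def michaelis_menten_saturation(x, alpha=3.0, lam=0.5):
    """Diminishing returns: the contribution approaches alpha as (adstocked) spend grows."""
    return alpha * x / (lam + x)


spend = np.linspace(0.0, 1.0, num=10)
contribution = michaelis_menten_saturation(geometric_adstock(spend))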

Time-Dependent MMM Model#

In real-world scenarios, the effectiveness of marketing activities is not static but varies over time due to factors such as competitive actions and market dynamics. To account for this, we introduce a time-dependent component into the MMM framework using a Gaussian process, specifically a Hilbert Space GP (HSGP). This allows us to capture the hidden latent temporal variation of the marketing contributions.

Model Specification#

In pymc-marketing we provide an API for a Bayesian media mix model (MMM) specification following Jin, Yuxue, et al. “Bayesian methods for media mix modeling with carryover and shape effects.” (2017) as a base model. Concretely, given a time series target variable \(y_{t}\) (e.g. sales or conversions), media variables \(x_{m, t}\) (e.g. impressions, clicks or costs) and a set of control covariates \(z_{c, t}\) (e.g. holidays, special events) we consider a linear model of the form

\[ y_{t} = \alpha + \sum_{m=1}^{M}\beta_{m}f(x_{m, t}) + \sum_{c=1}^{C}\gamma_{c}z_{c, t} + \varepsilon_{t}, \]

where \(\alpha\) is the intercept, \(f\) is a media transformation function and \(\varepsilon_{t}\) is the error term which we assume is normally distributed. The function \(f\) encodes the contribution of media on the target variable. Typically we consider two types of transformation: adstock (carry-over) and saturation effects.

When time_varying_media is set to True, we capture a single latent process that multiplies all channels. We assume all channels share the same time-dependent fluctuations, in contrast with implementations where each channel has an independent latent process. The modified model can be represented as:

\[ y_{t} = \alpha + \lambda_{t} \cdot \sum_{m=1}^{M}\beta_{m}f(x_{m, t}) + \sum_{c=1}^{C}\gamma_{c}z_{c, t} + \varepsilon_{t}, \]

where \(\lambda_{t}\) is the time-varying component modeled as a latent process. This shared time-dependent variation \(\lambda_{t}\) allows us to capture the overall temporal effects that influence all media channels simultaneously.
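
To make this more tangible, the sketch below shows how such a latent multiplier, centered around 1, could be expressed with PyMC's Hilbert Space GP approximation. This is only an illustration with made-up hyperparameters; the exact parameterization used internally by PyMC-Marketing may differ.

import numpy as np
import pymc as pm

with pm.Model() as latent_multiplier_sketch:
    t = np.arange(179)[:, None]  # weekly time index (2D input for the GP)
    eta = pm.Exponential("eta", lam=1.0)  # amplitude of the temporal fluctuations
    ls = pm.InverseGamma("ls", mu=50, sigma=20)  # lengthscale in weeks (illustrative prior)
    cov_func = eta**2 * pm.gp.cov.Matern52(input_dim=1, ls=ls)
    gp = pm.gp.HSGP(m=[200], c=1.5, cov_func=cov_func)  # Hilbert Space approximation
    f = gp.prior("f", X=t)
    lambda_t = pm.Deterministic("lambda_t", 1 + f)  # multiplier fluctuating around 1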

Objective#

This notebook will:

  1. Illustrate the formulation of a standard MMM model without time variation.

  2. Extend the model to include a time component using HSGP.

  3. Compare the results and insights derived from both models, highlighting the importance of incorporating time variation in capturing the true impact of marketing activities.

By the end of this notebook, you will have a comprehensive understanding of the advantages of using time-dependent MMM models in capturing the dynamic nature of marketing effectiveness, leading to more accurate and actionable insights.

Prerequisite Knowledge#

The notebook assumes the reader has knowledge of the essential functionalities of PyMC-Marketing. If one is unfamiliar, the “MMM Example Notebook” serves as an excellent starting point, offering a comprehensive introduction to media mix models in this context.


Part I: Data Generation Process#

In Part I of this notebook we focus on the data generating process. That is, we want to construct the target variable \(y_{t}\) (sales) by adding each of the components described in the introduction.

Prepare Notebook#

import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns

from pymc_marketing.mmm import MMM, GeometricAdstock, MichaelisMentenSaturation
from pymc_marketing.prior import Prior

warnings.filterwarnings("ignore")

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

1. Date Range#

First we set a time range for our data. We consider roughly three and a half years of data at weekly granularity.

# Creating variables

seed: int = sum(map(ord, "Time Media Contributions are amazing"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

# date range
min_date = pd.to_datetime("2018-04-01")
max_date = pd.to_datetime("2021-09-01")

df = pd.DataFrame(
    data={"date_week": pd.date_range(start=min_date, end=max_date, freq="W-MON")}
).assign(
    year=lambda x: x["date_week"].dt.year,
    month=lambda x: x["date_week"].dt.month,
    dayofyear=lambda x: x["date_week"].dt.dayofyear,
)

n = df.shape[0]
print(f"Number of observations: {n}")
Number of observations: 179

2. Media Costs Data#

Now we generate synthetic data from two channels \(x_1\) and \(x_2\). We refer to it as the raw signal, as it will be the input to the modeling phase. We expect the contribution of each channel to be different, based on the carryover and saturation parameters.

x1 = rng.uniform(low=0.0, high=1.0, size=n)
df["x1"] = np.where(x1 > 0.9, x1, x1 / 2)

x2 = rng.uniform(low=0.0, high=1.0, size=n)
df["x2"] = np.where(x2 > 0.8, x2, 0)
fig, ax = plt.subplots(
    nrows=2, ncols=1, figsize=(10, 7), sharex=True, sharey=True, layout="constrained"
)
sns.lineplot(x="date_week", y="x1", data=df, color="C0", ax=ax[0])
sns.lineplot(x="date_week", y="x2", data=df, color="C1", ax=ax[1])
ax[1].set(xlabel="date_week")
fig.suptitle("Media Costs Data", fontsize=16);

Remark: By design, \(x_{1}\) should resemble a typical paid social channel and \(x_{2}\) an offline (e.g. TV) spend time series.

3. Control Variables#

We add two events where there was a remarkable peak in our target variable. We assume they are independent and not seasonal (e.g. launch of a particular product).

df["event_1"] = (df["date_week"] == "2019-05-13").astype(float)
df["event_2"] = (df["date_week"] == "2020-09-14").astype(float)

4. Temporal Hidden Latent Process#

To illustrate the impact of time-varying media performance in our model, we generate a synthetic signal that modifies the base contribution. This signal, hidden_latent_media_fluctuation, is designed to simulate the natural fluctuations in media performance over time.

df["hidden_latent_media_fluctuation"] = (
    np.cos(0.5 * np.pi / 60 * np.arange(n)) / 2 + 1
) * 1

fig, ax = plt.subplots(
    nrows=1, ncols=1, figsize=(10, 7), sharex=True, sharey=True, layout="constrained"
)
sns.lineplot(x="date_week", y="hidden_latent_media_fluctuation", data=df, color="C0")
ax.set(xlabel="date_week")
fig.suptitle("Media performance change", fontsize=16);

By centering the signal around 1, we maintain the base contribution as the average effect while allowing for periodic increases and decreases: the cosine oscillates between 0.5 and 1.5 with a period of 240 weeks, so over the 179 observed weeks it dips slowly and partially recovers. This approach mirrors real-world scenarios where marketing effectiveness can vary, but the overall trend remains consistent.

This synthetic signal is essential for demonstrating the efficacy of our time-dependent MMM model, which should recover this signal as much as possible.

5. Target Variable#

Finally, we need to create our target variable. To do so, we'll use the PyMC do operator to specify the true parameter values that govern the causal relationships in the model.

Doing this, we'll draw a simulated target variable (sales) \(y\), which we assume is a linear combination of all components in the model. We also add some Gaussian noise.

adstock_max_lag = 8
yearly_seasonality = 2

dummy_mmm = MMM(
    date_column="date_week",
    channel_columns=["x1", "x2"],
    control_columns=["event_1", "event_2"],
    yearly_seasonality=yearly_seasonality,
    adstock=GeometricAdstock(l_max=adstock_max_lag),
    saturation=MichaelisMentenSaturation(),
    time_varying_media=True,
)
df["init_target"] = 0
dummy_mmm.build_model(df.drop(columns=["init_target"]), df["init_target"])

Tip

After building your dummy model you can list all the variables and parameters in it using named_vars. Alternatively, you can plot the model graph.
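
For example, the following lists every named variable in the dummy model (the exact names may change between PyMC-Marketing versions, so treat them as illustrative):

# Inspect the variable names registered in the dummy model
list(dummy_mmm.model.named_vars.keys())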

# Model to graphiz
pm.model_to_graphviz(dummy_mmm.model)

Here are the true parameter values used in our model:

  • Intercept: 6.0

  • Adstock Alpha: [0.5, 0.4] (for two different media channels)

  • Saturation Alpha: [3, 5] (for two different media channels)

  • Saturation Lambda: [0.3, 0.5] (for two different media channels)

  • Media Temporal Latent Multiplier: The time-varying signal hidden_latent_media_fluctuation from our dataset

  • Gamma Fourier: [2.5, -0.5, 1.5, 2.5] (coefficients for Fourier terms)

  • Y Sigma: 0.25 (Noise)

  • Gamma Control: [-3.5, 6.25] (coefficients for the control events event_1 and event_2)

By specifying these true parameter values, we create a realistic simulated target variable that encapsulates the complexity of our media mix model. This approach allows us to effectively test and validate the performance of our time-dependent MMM model.

# Real values
real_alpha = [3, 5]
real_lam = [0.3, 0.5]

true_params = {
    "intercept": 6.0,
    "adstock_alpha": np.array([0.5, 0.4]),
    "saturation_alpha": np.array(real_alpha),
    "saturation_lam": np.array(real_lam),
    "media_temporal_latent_multiplier": df["hidden_latent_media_fluctuation"],
    "gamma_fourier": np.array(
        [
            2.5,
            -0.5,
            1.5,
            2.5,
        ]
    ),
    "y_sigma": 0.25,
    "gamma_control": np.array([-3.5, 6.25]),
}
true_model = pm.do(
    dummy_mmm.model,
    true_params,
)

Let’s unpack this a little bit. The do-function takes a pymc.Model object and a dict of parameter values. It then returns a new model where the original random variables (RVs) have been converted to constant nodes taking on the specified values.
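
To see the mechanism in isolation, here is a toy example (the variable names are made up purely for illustration):

with pm.Model() as toy_model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    obs = pm.Normal("obs", mu=mu, sigma=1)

# "mu" becomes a constant node with value 2.0 in the returned model
toy_do = pm.do(toy_model, {"mu": 2.0})

# Draws of obs are now generated as if mu were fixed at 2.0
pm.draw(toy_do["obs"], draws=3, random_seed=rng)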

Let’s start by drawing our intercept using the pm.draw function from PyMC.

df["intercept"] = pm.draw(true_model.intercept, random_seed=rng)
plt.plot(df["intercept"])
plt.title("Intercept Over Time")
plt.xlabel("date_week")
plt.ylabel("Sales (thousands)");

As you can see, the intercept is aligned with the value we injected, remaining constant at 6. But how does the total channel contribution look after the transformations?

df["baseline_channel_contributions"] = pm.draw(
    true_model.baseline_channel_contributions.sum(axis=-1), random_seed=rng
)
df["channel_contributions"] = pm.draw(
    true_model.channel_contributions.sum(axis=-1), random_seed=rng
)

fig, ax = plt.subplots(2, 1, figsize=(12, 6), sharex=True)

ax[0].plot(df["baseline_channel_contributions"], color="purple", linestyle="--")
ax[0].set_title("Baseline Channel Contributions")
ax[0].set_xlabel("date_week")
ax[0].set_ylabel("Sales (thousands)")

ax[1].plot(df["channel_contributions"], color="purple")
ax[1].set_title("Channel Contributions")
ax[1].set_xlabel("date_week")
ax[1].set_ylabel("Sales (thousands)");

Baseline Channel Contributions

The top plot, titled “Baseline Channel Contributions,” shows the contributions of the media channels before considering the time-varying effects. The values are generated by summing the baseline channel contributions drawn from the true model.

Channel Contributions with Time Variation

The bottom plot, titled “Channel Contributions,” displays the media channel contributions after incorporating the time-varying media performance signal. These contributions reflect the impact of the latent temporal process, represented by hidden_latent_media_fluctuation, which modifies the baseline contributions. This modification captures the natural fluctuations in media performance over time, as influenced by various marketing dynamics.
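
As a quick sanity check, and under the assumption that the time-varying contributions are simply the baseline contributions scaled by the latent multiplier, the ratio of the two summed series should trace the hidden signal we injected:

# Implied multiplier = time-varying contributions / baseline contributions
implied_multiplier = df["channel_contributions"] / df["baseline_channel_contributions"]

fig, ax = plt.subplots(figsize=(10, 4), layout="constrained")
ax.plot(df["date_week"], implied_multiplier, color="C0", label="Implied multiplier")
ax.plot(
    df["date_week"],
    df["hidden_latent_media_fluctuation"],
    color="black",
    linestyle="--",
    label="hidden_latent_media_fluctuation",
)
ax.set(xlabel="date_week", title="Implied vs. true latent multiplier")
ax.legend();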

df["x1_contribution"] = pm.draw(true_model.channel_contributions, random_seed=rng)[:, 0]
df["x2_contribution"] = pm.draw(true_model.channel_contributions, random_seed=rng)[:, 1]

fig, ax = plt.subplots(
    nrows=2, ncols=1, figsize=(10, 7), sharex=True, sharey=True, layout="constrained"
)
sns.lineplot(x="date_week", y="x1_contribution", data=df, color="C0", ax=ax[0])
sns.lineplot(x="date_week", y="x2_contribution", data=df, color="C1", ax=ax[1])
ax[1].set(xlabel="date_week")
fig.suptitle("Media Contribution per Channel", fontsize=16);

6. Seasonality & Control Components#

We can also observe the contribution of our control events, as well as the seasonality added when making the true model.

df["yearly_seasonality_contribution"] = pm.draw(
    true_model.yearly_seasonality_contribution, random_seed=rng
)

fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(12, 8), sharex=True)

ax1.plot(df["yearly_seasonality_contribution"])
ax1.set_title("Yearly Seasonality Contribution")
ax1.set_xlabel("date_week")
ax1.set_ylabel("Sales (thousands)")


df["control_contributions"] = pm.draw(
    true_model.control_contributions, random_seed=rng
).sum(axis=-1)

ax2.plot(df["control_contributions"])
ax2.set_title("Control Contributions")
ax2.set_xlabel("date_week")
ax2.set_ylabel("Sales (thousands)");

Finally, we can visualize the true target given all the previous components!

df["y"] = pm.draw(true_model.y, random_seed=rng)

plt.plot(df["y"], color="black")
plt.title("Target Variable (Sales)")
plt.xlabel("date_week")
plt.ylabel("Sales (thousands)");

Now, with everything in place, we separate our datasets: df keeps the ground truth generated by the true model, while a new dataset called data contains all the necessary columns but no information about the true relationships, similar to what we would observe in real life.

data = df[["date_week", "x1", "x2", "event_1", "event_2", "y"]].copy()

X = data.drop("y", axis=1)
y = data["y"]

As discussed previously, we first fit a model without time-varying coefficients to see how much it deviates from reality. For this we create an MMM object with all the parameters needed to build the model, which should estimate the relationships of the true model.

basic_mmm = MMM(
    date_column="date_week",
    channel_columns=["x1", "x2"],
    control_columns=["event_1", "event_2"],
    yearly_seasonality=yearly_seasonality,
    adstock=GeometricAdstock(l_max=adstock_max_lag),
    saturation=MichaelisMentenSaturation(),
)

basic_mmm.fit(
    X=X,
    y=y,
    target_accept=0.90,
    draws=4000,
    tune=2000,
    chains=4,
    nuts_sampler="numpyro",
    random_seed=rng,
)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
There were 35 divergences after tuning. Increase `target_accept` or reparameterize.
arviz.InferenceData
    • <xarray.Dataset> Size: 231MB
      Dimensions:                          (chain: 4, draw: 4000, control: 2,
                                            fourier_mode: 4, channel: 2, date: 179)
      Coordinates:
        * chain                            (chain) int64 32B 0 1 2 3
        * draw                             (draw) int64 32kB 0 1 2 ... 3997 3998 3999
        * control                          (control) <U7 56B 'event_1' 'event_2'
        * fourier_mode                     (fourier_mode) <U11 176B 'sin_order_1' ....
        * channel                          (channel) <U2 16B 'x1' 'x2'
        * date                             (date) datetime64[ns] 1kB 2018-04-02 ......
      Data variables:
          intercept                        (chain, draw) float64 128kB 0.4198 ... 0...
          gamma_control                    (chain, draw, control) float64 256kB -0....
          gamma_fourier                    (chain, draw, fourier_mode) float64 512kB ...
          adstock_alpha                    (chain, draw, channel) float64 256kB 0.6...
          saturation_alpha                 (chain, draw, channel) float64 256kB 0.5...
          saturation_lam                   (chain, draw, channel) float64 256kB 2.4...
          y_sigma                          (chain, draw) float64 128kB 0.06081 ... ...
          channel_contributions            (chain, draw, date, channel) float64 46MB ...
          control_contributions            (chain, draw, date, control) float64 46MB ...
          fourier_contributions            (chain, draw, date, fourier_mode) float64 92MB ...
          yearly_seasonality_contribution  (chain, draw, date) float64 23MB -0.0069...
          mu                               (chain, draw, date) float64 23MB 0.4359 ...
      Attributes:
          created_at:     2024-07-01T15:28:35.626884
          arviz_version:  0.17.1

    • <xarray.Dataset> Size: 816kB
      Dimensions:          (chain: 4, draw: 4000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 32kB 0 1 2 3 4 5 ... 3995 3996 3997 3998 3999
      Data variables:
          acceptance_rate  (chain, draw) float64 128kB 0.9893 0.9954 ... 0.9757 0.9665
          step_size        (chain, draw) float64 128kB 0.04815 0.04815 ... 0.06408
          diverging        (chain, draw) bool 16kB False False False ... False False
          energy           (chain, draw) float64 128kB -219.6 -216.3 ... -226.0 -221.7
          n_steps          (chain, draw) int64 128kB 63 63 63 63 127 ... 63 63 63 63
          tree_depth       (chain, draw) int64 128kB 6 6 6 6 7 6 6 6 ... 6 6 6 6 6 6 6
          lp               (chain, draw) float64 128kB -225.0 -226.8 ... -230.5 -227.7
      Attributes:
          created_at:     2024-07-01T15:28:35.649848
          arviz_version:  0.17.1

    • <xarray.Dataset> Size: 3kB
      Dimensions:  (date: 179)
      Coordinates:
        * date     (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
      Data variables:
          y        (date) float64 1kB 0.522 0.5735 0.5612 ... 0.5669 0.5082 0.4131
      Attributes:
          created_at:                 2024-07-01T15:28:35.655105
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.14.0
          sampling_time:              42.398164

    • <xarray.Dataset> Size: 13kB
      Dimensions:       (date: 179, channel: 2, control: 2, fourier_mode: 4)
      Coordinates:
        * date          (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
        * channel       (channel) <U2 16B 'x1' 'x2'
        * control       (control) <U7 56B 'event_1' 'event_2'
        * fourier_mode  (fourier_mode) <U11 176B 'sin_order_1' ... 'cos_order_2'
      Data variables:
          channel_data  (date, channel) float64 3kB 0.2957 0.0 0.9413 ... 0.1273 0.0
          control_data  (date, control) float64 3kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0
          fourier_data  (date, fourier_mode) float64 6kB 0.9999 -0.01183 ... -0.4547
      Attributes:
          created_at:                 2024-07-01T15:28:35.663986
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.14.0
          sampling_time:              42.398164

    • <xarray.Dataset> Size: 10kB
      Dimensions:    (index: 179)
      Coordinates:
        * index      (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178
      Data variables:
          date_week  (index) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
          x1         (index) float64 1kB 0.2948 0.9383 0.1397 ... 0.9225 0.9364 0.1269
          x2         (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          event_1    (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          event_2    (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          y          (index) float64 1kB 7.764 8.529 8.347 7.955 ... 8.431 7.558 6.143

As we can see, the model ran into divergences! 🤯

The occurrence of divergences in our Bayesian MMM highlights the strengths and robustness of the Bayesian framework in hypothesis testing and model validation. Bayesian models are structural and adhere to certain assumptions about the data-generating process. When these assumptions are violated or the model structure does not fit the data well, divergences and sampling problems can arise.

This characteristic makes the Bayesian approach a powerful tool for:

  • Hypothesis Testing: By defining clear structural relationships and assumptions, Bayesian models can help test and validate hypotheses about the underlying processes in the data.

  • Model Validation: Divergences and sampling issues serve as indicators that the model may not be correctly specified, prompting further investigation and refinement.

  • Understanding Complex Systems: Bayesian methods allow for the incorporation of prior knowledge and the testing of various structural assumptions, making them well-suited for understanding complex, real-world systems.

In this particular case, we have a strong suspicion of why the model had divergences: the internal structure of our world model (the MMM) neglects time, even though time is an important factor (we know this because we carried out the data generation process ourselves).


Despite that, let’s take a look at the data that we were able to recover through this basic model.

If we decompose the posterior predictive distribution into the different components, everything becomes clear:

basic_mmm.plot_components_contributions();

Some contributions end up having more units than the target value, forcing the model to compensate and resulting in an incorrect decomposition of our marketing activities.

For example, our time series ends up with a long tail of probable values for the marketing contributions, this tail being up to 3X greater than the maximum value of our target.

def plot_posterior(
    posterior, figsize=(15, 8), path_color="blue", hist_color="blue", **kwargs
):
    """Plot the posterior distribution of a stochastic process.

    Parameters
    ----------
    posterior : xarray.DataArray
        The posterior distribution with shape (chain, draw, date).
    figsize : tuple
        Size of the figure.
    path_color : str
        Color of the paths in the time series plot.
    hist_color : str
        Color of the histogram.
    **kwargs
        Additional keyword arguments to pass to the plotting functions.

    """
    # Calculate the expected value (mean) across all draws and chains for each date
    expected_value = posterior.mean(dim=("draw", "chain"))

    # Create a figure and a grid of subplots
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(1, 2, width_ratios=[3, 1])

    # Time series plot
    ax1 = fig.add_subplot(gs[0])
    n_chains, n_draws = posterior.shape[0], posterior.shape[1]
    for chain in range(n_chains):
        for draw in range(0, n_draws, 10):  # Plot every 10th draw for performance
            ax1.plot(
                posterior.date,
                posterior[chain, draw],
                color=path_color,
                alpha=0.05,
                linewidth=0.4,
            )

    # Plot expected value with a distinct color
    ax1.plot(
        posterior.date,
        expected_value,
        color="black",
        linestyle="--",
        linewidth=2,
        label="Expected Value",
    )
    ax1.set_title("Posterior Predictive")
    ax1.set_xlabel("Date")
    ax1.set_ylabel("Value")
    ax1.grid(True)
    ax1.legend()

    # KDE plot instead of histogram
    ax2 = fig.add_subplot(gs[1])
    final_values = posterior[:, :, -1].values.flatten()
    sns.kdeplot(
        y=final_values, ax=ax2, color=hist_color, fill=True, alpha=0.4, **kwargs
    )

    # Plot expected value line in KDE plot
    ax2.axhline(
        y=expected_value[-1].values.mean(), color="black", linestyle="--", linewidth=2
    )
    ax2.set_title("Distribution at T")
    ax2.set_xlabel("Density")
    ax2.set_yticklabels([])  # Hide y tick labels to avoid duplication
    ax2.grid(True)

    plt.tight_layout()
    return fig


plot_posterior(
    posterior=basic_mmm.fit_result["channel_contributions"].sum(dim="channel")
);

But why are the contributions overestimated? The contributions are poorly estimated because the parameters of our transformations are also poorly estimated. For example, the parameters controlling the maximum effectiveness of each channel (the alpha of the saturation function) are much higher than the real ones for both channels. Note that the model scales the target internally (max-abs scaling), so to compare the posterior against the true values we divide the true saturation alphas by df["y"].max().

fig = basic_mmm.plot_channel_parameter(param_name="saturation_alpha", figsize=(9, 5))
ax = fig.axes[0]
ax.axvline(
    x=(real_alpha[0] / df.y.max()), color="C0", linestyle="--", label=r"$\alpha_1$"
)
ax.axvline(
    x=(real_alpha[1] / df.y.max()), color="C1", linestyle="--", label=r"$\alpha_2$"
)
ax.legend(loc="upper right");

What would change if we now consider time as a factor in our model?

We can do this by adding the time_varying_media parameter to the initialization of our model and setting it to True.

mmm = MMM(
    date_column="date_week",
    channel_columns=["x1", "x2"],
    control_columns=["event_1", "event_2"],
    yearly_seasonality=yearly_seasonality,
    adstock=GeometricAdstock(l_max=adstock_max_lag),
    saturation=MichaelisMentenSaturation(),
    time_varying_media=True,
)

Note: By doing this, our model config will now have a new key, media_tvp_config, with the parameters that control the priors of our HSGP.

mmm.model_config["media_tvp_config"]
{'m': 200,
 'L': None,
 'eta_lam': 1,
 'ls_mu': None,
 'ls_sigma': 10,
 'cov_func': None}
mmm.fit(
    X=X,
    y=y,
    target_accept=0.95,
    draws=4000,
    tune=2000,
    chains=4,
    nuts_sampler="numpyro",
    random_seed=rng,
)
arviz.InferenceData
    • <xarray.Dataset> Size: 326MB
      Dimensions:                                      (chain: 4, draw: 4000, m: 200,
                                                        control: 2, fourier_mode: 4,
                                                        channel: 2, date: 179)
      Coordinates:
        * chain                                        (chain) int64 32B 0 1 2 3
        * draw                                         (draw) int64 32kB 0 1 ... 3999
        * m                                            (m) int64 2kB 0 1 2 ... 198 199
        * control                                      (control) <U7 56B 'event_1' ...
        * fourier_mode                                 (fourier_mode) <U11 176B 'si...
        * channel                                      (channel) <U2 16B 'x1' 'x2'
        * date                                         (date) datetime64[ns] 1kB 20...
      Data variables: (12/17)
          intercept                                    (chain, draw) float64 128kB ...
          media_temporal_latent_multiplier_hsgp_coefs  (chain, draw, m) float64 26MB ...
          gamma_control                                (chain, draw, control) float64 256kB ...
          gamma_fourier                                (chain, draw, fourier_mode) float64 512kB ...
          adstock_alpha                                (chain, draw, channel) float64 256kB ...
          saturation_alpha                             (chain, draw, channel) float64 256kB ...
          ...                                           ...
          media_temporal_latent_multiplier             (chain, draw, date) float64 23MB ...
          channel_contributions                        (chain, draw, date, channel) float64 46MB ...
          control_contributions                        (chain, draw, date, control) float64 46MB ...
          fourier_contributions                        (chain, draw, date, fourier_mode) float64 92MB ...
          yearly_seasonality_contribution              (chain, draw, date) float64 23MB ...
          mu                                           (chain, draw, date) float64 23MB ...
      Attributes:
          created_at:     2024-07-01T15:32:08.125144
          arviz_version:  0.17.1

    • <xarray.Dataset> Size: 816kB
      Dimensions:          (chain: 4, draw: 4000)
      Coordinates:
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 32kB 0 1 2 3 4 5 ... 3995 3996 3997 3998 3999
      Data variables:
          acceptance_rate  (chain, draw) float64 128kB 1.0 0.991 ... 0.9909 0.9319
          step_size        (chain, draw) float64 128kB 0.03353 0.03353 ... 0.03417
          diverging        (chain, draw) bool 16kB False False False ... False False
          energy           (chain, draw) float64 128kB -37.37 -59.82 ... -69.04 -81.03
          n_steps          (chain, draw) int64 128kB 127 127 127 127 ... 127 127 127
          tree_depth       (chain, draw) int64 128kB 7 7 7 7 7 7 7 7 ... 7 7 7 7 7 7 7
          lp               (chain, draw) float64 128kB -163.9 -162.1 ... -179.3 -183.4
      Attributes:
          created_at:     2024-07-01T15:32:08.150129
          arviz_version:  0.17.1

    • <xarray.Dataset> Size: 3kB
      Dimensions:  (date: 179)
      Coordinates:
        * date     (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
      Data variables:
          y        (date) float64 1kB 0.522 0.5735 0.5612 ... 0.5669 0.5082 0.4131
      Attributes:
          created_at:                 2024-07-01T15:32:08.155469
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.14.0
          sampling_time:              91.780659

    • <xarray.Dataset> Size: 14kB
      Dimensions:       (date: 179, channel: 2, control: 2, fourier_mode: 4)
      Coordinates:
        * date          (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
        * channel       (channel) <U2 16B 'x1' 'x2'
        * control       (control) <U7 56B 'event_1' 'event_2'
        * fourier_mode  (fourier_mode) <U11 176B 'sin_order_1' ... 'cos_order_2'
      Data variables:
          channel_data  (date, channel) float64 3kB 0.2957 0.0 0.9413 ... 0.1273 0.0
          time_index    (date) int32 716B 0 1 2 3 4 5 6 ... 173 174 175 176 177 178
          control_data  (date, control) float64 3kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0
          fourier_data  (date, fourier_mode) float64 6kB 0.9999 -0.01183 ... -0.4547
      Attributes:
          created_at:                 2024-07-01T15:32:08.166136
          arviz_version:              0.17.1
          inference_library:          numpyro
          inference_library_version:  0.14.0
          sampling_time:              91.780659

    • <xarray.Dataset> Size: 10kB
      Dimensions:    (index: 179)
      Coordinates:
        * index      (index) int64 1kB 0 1 2 3 4 5 6 7 ... 172 173 174 175 176 177 178
      Data variables:
          date_week  (index) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
          x1         (index) float64 1kB 0.2948 0.9383 0.1397 ... 0.9225 0.9364 0.1269
          x2         (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          event_1    (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          event_2    (index) float64 1kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          y          (index) float64 1kB 7.764 8.529 8.347 7.955 ... 8.431 7.558 6.143

All divergences disappeared; this is a good sign! 🚀

Let’s check our samples!

az.summary(
    data=mmm.fit_result,
    var_names=[
        "intercept",
        "y_sigma",
        "gamma_control",
        "gamma_fourier",
    ],
)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
intercept 0.406 0.008 0.391 0.419 0.0 0.0 15031.0 11911.0 1.0
y_sigma 0.016 0.001 0.015 0.018 0.0 0.0 21062.0 11466.0 1.0
gamma_control[event_1] -0.247 0.017 -0.279 -0.216 0.0 0.0 27809.0 11307.0 1.0
gamma_control[event_2] 0.425 0.017 0.395 0.458 0.0 0.0 24666.0 11448.0 1.0
gamma_fourier[sin_order_1] 0.168 0.002 0.163 0.172 0.0 0.0 19360.0 13393.0 1.0
gamma_fourier[cos_order_1] -0.035 0.002 -0.040 -0.031 0.0 0.0 24944.0 13189.0 1.0
gamma_fourier[sin_order_2] 0.103 0.002 0.099 0.106 0.0 0.0 29689.0 12210.0 1.0
gamma_fourier[cos_order_2] 0.169 0.002 0.166 0.173 0.0 0.0 27382.0 12700.0 1.0
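
We can also put the remaining true parameters on the same scale as these estimates. Assuming the default max-abs scaling of the target (the same assumption used above when rescaling the true saturation parameters), dividing the true values by df["y"].max() makes them directly comparable to the posterior summaries:

# True values expressed in the model's scaled target units
print(true_params["intercept"] / df["y"].max())
print(true_params["gamma_control"] / df["y"].max())
print(true_params["gamma_fourier"] / df["y"].max())
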
_ = az.plot_trace(
    data=mmm.fit_result,
    var_names=[
        "intercept",
        "y_sigma",
        "gamma_control",
        "gamma_fourier",
    ],
    compact=True,
    backend_kwargs={"figsize": (12, 10), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=16);
az.summary(
    data=mmm.fit_result,
    var_names=["adstock_alpha", "saturation_lam", "saturation_alpha"],
)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
adstock_alpha[x1] 0.502 0.037 0.434 0.571 0.000 0.000 15550.0 11970.0 1.0
adstock_alpha[x2] 0.342 0.029 0.286 0.396 0.000 0.000 8569.0 10571.0 1.0
saturation_lam[x1] 0.323 0.092 0.176 0.492 0.001 0.001 10213.0 8291.0 1.0
saturation_lam[x2] 0.359 0.083 0.214 0.515 0.001 0.001 8211.0 10240.0 1.0
saturation_alpha[x1] 0.185 0.018 0.153 0.217 0.000 0.000 10595.0 8103.0 1.0
saturation_alpha[x2] 0.251 0.026 0.207 0.300 0.000 0.000 8030.0 9912.0 1.0
_ = az.plot_trace(
    data=mmm.fit_result,
    var_names=["adstock_alpha", "saturation_lam", "saturation_alpha"],
    compact=True,
    backend_kwargs={"figsize": (12, 10), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=16);
az.summary(
    data=mmm.fit_result,
    var_names=[
        "media_temporal_latent_multiplier_eta",
        "media_temporal_latent_multiplier_ls",
        "media_temporal_latent_multiplier_hsgp_coefs",
    ],
)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
media_temporal_latent_multiplier_eta 1.780 0.573 0.884 2.850 0.007 0.005 6167.0 9314.0 1.0
media_temporal_latent_multiplier_ls 102.411 9.329 84.448 119.532 0.065 0.047 21144.0 12357.0 1.0
media_temporal_latent_multiplier_hsgp_coefs[0] 0.267 0.328 -0.279 0.906 0.003 0.002 11228.0 10144.0 1.0
media_temporal_latent_multiplier_hsgp_coefs[1] 0.882 0.235 0.458 1.314 0.003 0.002 5679.0 8818.0 1.0
media_temporal_latent_multiplier_hsgp_coefs[2] 1.815 0.498 0.911 2.754 0.006 0.005 6212.0 8490.0 1.0
... ... ... ... ... ... ... ... ... ...
media_temporal_latent_multiplier_hsgp_coefs[195] -0.007 0.995 -1.889 1.816 0.006 0.009 28075.0 11523.0 1.0
media_temporal_latent_multiplier_hsgp_coefs[196] 0.015 1.015 -1.938 1.864 0.007 0.009 23927.0 11588.0 1.0
media_temporal_latent_multiplier_hsgp_coefs[197] -0.006 1.002 -1.800 1.908 0.006 0.009 27312.0 12634.0 1.0
media_temporal_latent_multiplier_hsgp_coefs[198] 0.002 0.994 -1.821 1.912 0.006 0.009 28122.0 12062.0 1.0
media_temporal_latent_multiplier_hsgp_coefs[199] -0.007 0.999 -1.858 1.906 0.006 0.009 30812.0 11862.0 1.0

202 rows × 9 columns

_ = az.plot_trace(
    data=mmm.fit_result,
    var_names=[
        "media_temporal_latent_multiplier_eta",
        "media_temporal_latent_multiplier_ls",
        "media_temporal_latent_multiplier_hsgp_coefs",
    ],
    compact=True,
    backend_kwargs={"figsize": (12, 10), "layout": "constrained"},
)
plt.gcf().suptitle("Model Trace", fontsize=16);

Everything seems fine for now; nothing raises red flags when analyzing our trace. But what about the decomposition?

mmm.plot_components_contributions();

The decomposition looks much better now 🔥 It seems that we are estimating each parameter better, and there are no obvious trade-offs between components!

Let’s see how well the original parameters have been recovered.

fig = mmm.plot_channel_parameter(param_name="saturation_alpha", figsize=(9, 5))
ax = fig.axes[0]
ax.axvline(
    x=(real_alpha[0] / df.y.max()), color="C0", linestyle="--", label=r"$\alpha_1$"
)
ax.axvline(
    x=(real_alpha[1] / df.y.max()), color="C1", linestyle="--", label=r"$\alpha_2$"
)
ax.legend(loc="upper right");
fig = mmm.plot_channel_parameter(param_name="saturation_lam", figsize=(9, 5))
ax = fig.axes[0]
ax.axvline(
    x=(real_lam[0] / df.x1.max()), color="C0", linestyle="--", label=r"$\lambda_1$"
)
ax.axvline(
    x=(real_lam[1] / df.x2.max()), color="C1", linestyle="--", label=r"$\lambda_2$"
)
ax.legend(loc="upper right");

The parameters of the saturation function seem to be recovered practically in their entirety for both channels! This is great 🎉

Let’s see how much of the true variation we manage to recover. We can analyze the variable media_temporal_latent_multiplier and compare it against the original signal used in the data generating process.

media_latent_factor = mmm.fit_result["media_temporal_latent_multiplier"].quantile(
    [0.025, 0.50, 0.975], dim=["chain", "draw"]
)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 10))
sns.lineplot(
    x=mmm.fit_result.coords["date"],
    y=media_latent_factor.sel(quantile=0.5),
    label="Predicted",
    color="blue",
)

sns.lineplot(
    x=mmm.fit_result.coords["date"],
    y=df["hidden_latent_media_fluctuation"],
    label="Real",
    color="Black",
    linestyle="--",
)


ax.fill_between(
    mmm.fit_result.coords["date"],
    media_latent_factor.sel(quantile=0.025),
    media_latent_factor.sel(quantile=0.975),
    alpha=0.3,
)
ax.set_title("HSGP")
ax.set_xlabel("Date")
ax.set_ylabel("Latent Factor")
ax.tick_params(axis="x", rotation=45)
ax.legend()
plt.show()

Incredible 🚀 We recovered the latent process almost perfectly. Although it appears slightly overestimated, it is quite close to the real signal!

recover_channel_contributions = mmm.fit_result["channel_contributions"].quantile(
    [0.025, 0.50, 0.975], dim=["chain", "draw"]
)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(15, 10))
sns.lineplot(
    x=mmm.fit_result.coords["date"],
    y=recover_channel_contributions.sel(quantile=0.5).sum(axis=-1),
    label="Posterior Predictive Contribution",
    color="purple",
)

sns.lineplot(
    x=mmm.fit_result.coords["date"],
    y=df["channel_contributions"] / df["y"].max(),
    label="Real",
    color="purple",
    linestyle="--",
)


ax.fill_between(
    mmm.fit_result.coords["date"],
    recover_channel_contributions.sel(quantile=0.025).sum(axis=-1),
    recover_channel_contributions.sel(quantile=0.975).sum(axis=-1),
    alpha=0.3,
)
ax.set_title("Recover contribution")
ax.set_xlabel("Date")
ax.set_ylabel("Sales")
ax.tick_params(axis="x", rotation=45)
ax.legend()
plt.show()

This is reflected when comparing the recovered contribution against the original: they are practically identical!

fig, ax = plt.subplots(
    nrows=2, ncols=1, figsize=(10, 7), sharex=True, sharey=True, layout="constrained"
)

sns.lineplot(
    x="date_week",
    y="x1_contribution",
    data=df,
    color="C0",
    ax=ax[0],
    label="Real Contribution x1",
)
ax[0].fill_between(
    mmm.fit_result.coords["date"],
    recover_channel_contributions.sel(quantile=0.025).sel(channel="x1") * df.y.max(),
    recover_channel_contributions.sel(quantile=0.975).sel(channel="x1") * df.y.max(),
    alpha=0.3,
    color="C0",
    label="Posterior Contribution x1",
)
ax[0].legend()

sns.lineplot(
    x="date_week",
    y="x2_contribution",
    data=df,
    color="C1",
    ax=ax[1],
    label="Real Contribution x2",
)
ax[1].fill_between(
    mmm.fit_result.coords["date"],
    recover_channel_contributions.sel(quantile=0.025).sel(channel="x2") * df.y.max(),
    recover_channel_contributions.sel(quantile=0.975).sel(channel="x2") * df.y.max(),
    alpha=0.3,
    color="C1",
    label="Posterior Contribution x2",
)

ax[1].set(xlabel="weeks")
fig.suptitle("Media Contribution per Channel", fontsize=16)
ax[1].legend();

Contributions per channel were also recovered correctly, unlike our first model!

Insights#

The Bayesian approach not only facilitates hypothesis testing and model validation but also provides a structured way to incorporate prior knowledge and test various assumptions about the data-generating process. The occurrence of divergences, as observed in our initial model fitting, underscores the importance of model specification and alignment with the underlying data structure. These divergences serve as a diagnostic tool, guiding further refinement and improvement of the model.

In summary, using PyMC-Marketing to build time-aware MMM models enables marketers to gain deeper insights and achieve a more accurate understanding of the impact of their efforts. This methodology enhances the ability to make data-driven decisions, optimize marketing strategies, and ultimately drive better business outcomes.

Conclusion#

Throughout this notebook, we have explored the implementation of a Bayesian Marketing Mix Model (MMM) using PyMC, comparing the performance and insights gained from models with and without a time component. The key takeaway from our analysis is the significant advantage of incorporating time-varying factors into MMM.

Uncovering Real Causal Relationships#

By integrating a time component, we can uncover the true causal relationships between our target variable (such as sales) and our marketing efforts. The traditional approach, which neglects temporal dynamics, often fails to capture the complex and fluctuating nature of real-world marketing performance. In contrast, the time-dependent model provides a more accurate and nuanced understanding of how marketing activities influence outcomes over time.

Advantages of PyMC-Marketing#

PyMC-Marketing offers powerful tools to implement these advanced methodologies. Its features, including the handling of different adstock effects, saturation effects, and Hilbert Space Gaussian Processes (HSGP) for modeling time-varying components, allow for more precise and reliable modeling of marketing data.

We encourage practitioners to leverage these advanced techniques and the capabilities of PyMC-Marketing to improve their marketing analytics and gain a competitive edge in their strategic planning.

Bonus#

This notebook simulated a very simple variation; the true time-dependent latent processes hidden in your data may be more complex. In that case, you will need priors to guide your model toward the real data generating process.

One way to achieve this is by modifying the model configuration.

custom_config = {
    "intercept": Prior("HalfNormal", sigma=0.5),
    "saturation_alpha": Prior(
        "Gamma", mu=np.array([0.3, 0.4]), sigma=np.array([0.2, 0.2]), dims="channel"
    ),
    "saturation_lam": Prior("Beta", alpha=4, beta=4, dims="channel"),
}

media_tvp_config = {
    "media_tvp_config": {
        "m": 50,
        "L": 30,
        "eta_lam": 3,
        "ls_mu": 5,
        "ls_sigma": 5,
        "cov_func": None,
    }
}

custom_config = {**mmm.model_config, **custom_config, **media_tvp_config}
custom_config
{'intercept': Prior("HalfNormal", sigma=0.5),
 'likelihood': Prior("Normal", sigma=Prior("HalfNormal", sigma=2), dims="date"),
 'gamma_control': Prior("Normal", mu=0, sigma=2, dims="control"),
 'gamma_fourier': Prior("Laplace", mu=0, b=1, dims="fourier_mode"),
 'media_tvp_config': {'m': 50,
  'L': 30,
  'eta_lam': 3,
  'ls_mu': 5,
  'ls_sigma': 5,
  'cov_func': None},
 'adstock_alpha': Prior("Beta", alpha=1, beta=3, dims="channel"),
 'saturation_alpha': Prior("Gamma", mu=[0.3 0.4], sigma=[0.2 0.2], dims="channel"),
 'saturation_lam': Prior("Beta", alpha=4, beta=4, dims="channel")}
mmm_calibrated = MMM(
    date_column="date_week",
    channel_columns=["x1", "x2"],
    control_columns=["event_1", "event_2"],
    yearly_seasonality=yearly_seasonality,
    adstock=GeometricAdstock(l_max=adstock_max_lag),
    saturation=MichaelisMentenSaturation(),
    time_varying_media=True,
    model_config=custom_config,
)
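
With the customized configuration passed to the model, fitting would proceed exactly as before; for example:

# Fit the calibrated model (same call pattern as earlier in the notebook)
mmm_calibrated.fit(
    X=X,
    y=y,
    target_accept=0.95,
    draws=4000,
    tune=2000,
    chains=4,
    nuts_sampler="numpyro",
    random_seed=rng,
)
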
%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor
Last updated: Mon Jul 01 2024

Python implementation: CPython
Python version       : 3.10.13
IPython version      : 8.22.2

pytensor: 2.20.0

numpy     : 1.26.4
pymc      : 5.14.0
seaborn   : 0.13.2
arviz     : 0.17.1
matplotlib: 3.8.3
pandas    : 2.2.1

Watermark: 2.4.3