Model deployment#

One of the main goals of PyMC-Marketing is to facilitate the deployment of its models.

This is achieved by building our models on top of ModelBuilder, a brand new PyMC experimental feature that offers a scikit-learn-like API and makes PyMC models easy to deploy.

PyMC-Marketing models inherit two easy-to-use methods, save and load, that can be used after the model has been fitted. All models can be configured with two standard dictionaries, model_config and sampler_config, which are serialized during save and restored on load, allowing model reuse across workflows.

We will illustrate this functionality with the example model described in the MMM Example Notebook. For the sake of generality, we omit most technical details here.

import arviz as az
import numpy as np
import pandas as pd

from pymc_marketing.mmm import DelayedSaturatedMMM
seed = sum(map(ord, "mmm"))
rng = np.random.default_rng(seed=seed)

Let’s load the dataset:

url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
df = pd.read_csv(url)

columns_to_keep = [
    "date_week",
    "y",
    "x1",
    "x2",
    "event_1",
    "event_2",
    "dayofyear",
]

data = df[columns_to_keep].copy()
data["t"] = np.arange(df.shape[0])
data.head()
    date_week            y        x1   x2  event_1  event_2  dayofyear  t
0  2018-04-02  3984.662237  0.318580  0.0      0.0      0.0         92  0
1  2018-04-09  3762.871794  0.112388  0.0      0.0      0.0         99  1
2  2018-04-16  4466.967388  0.292400  0.0      0.0      0.0        106  2
3  2018-04-23  3864.219373  0.071399  0.0      0.0      0.0        113  3
4  2018-04-30  4441.625278  0.386745  0.0      0.0      0.0        120  4

Our model needs a much smaller dataset than the original file. Many of the original columns were only used to generate others, so now that the target variable is computed we keep only the columns listed in columns_to_keep above.

Model and sampling configuration#

Model configuration#

We first illustrate the use of model_config to define custom priors within the model.

Because there are potentially many variables that can be configured, each model provides a default_model_config attribute. This will allow you to see which settings are available by default and only define the ones you need to change.

We need to create a dummy model to be able to see the configuration dictionary.

dummy_model = DelayedSaturatedMMM(
    date_column="date_week",
    channel_columns=["x1", "x2"],
    control_columns=[
        "event_1",
        "event_2",
        "t",
    ],
    adstock_max_lag=8,
    yearly_seasonality=2,
)
dummy_model.default_model_config
{'intercept': {'mu': 0, 'sigma': 2},
 'beta_channel': {'sigma': 2, 'dims': ('channel',)},
 'alpha': {'alpha': 1, 'beta': 3, 'dims': ('channel',)},
 'lam': {'alpha': 3, 'beta': 1, 'dims': ('channel',)},
 'sigma': {'sigma': 2},
 'gamma_control': {'mu': 0, 'sigma': 2, 'dims': ('control',)},
 'mu': {'dims': ('date',)},
 'likelihood': {'dims': ('date',)},
 'gamma_fourier': {'mu': 0, 'b': 1, 'dims': 'fourier_mode'}}
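
Because only the entries to be changed need to be supplied, it can help to picture an override as a dictionary merge. The following is a minimal sketch with plain dicts. The helper merge_config is hypothetical, and whether the library combines defaults and overrides in exactly this way is an assumption.

```python
# Hypothetical helper: overlay user-supplied entries on a copy of the defaults.
# This illustrates the idea only; it is not PyMC-Marketing's implementation.
from copy import deepcopy


def merge_config(defaults: dict, overrides: dict) -> dict:
    merged = deepcopy(defaults)  # leave the defaults untouched
    for name, params in overrides.items():
        if name not in merged:
            raise KeyError(f"unknown config entry: {name!r}")
        merged[name] = {**merged[name], **params}  # per-entry override
    return merged


# A subset of the defaults printed above
defaults = {
    "intercept": {"mu": 0, "sigma": 2},
    "beta_channel": {"sigma": 2, "dims": ("channel",)},
    "sigma": {"sigma": 2},
}

custom = merge_config(defaults, {"beta_channel": {"sigma": [1.0, 0.5]}})
print(custom["beta_channel"])  # sigma replaced, dims kept
print(custom["intercept"])     # unchanged default
```

Rejecting unknown names, as this sketch does, is one way to catch typos in prior names early.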

We can change the parameters that go into the distribution of each term. In this case we’ll simply replace the sigma for beta_channel with a custom one:

n_channels = 2

total_spend_per_channel = data[["x1", "x2"]].sum(axis=0)
spend_share = total_spend_per_channel / total_spend_per_channel.sum()
spend_share

# The scale necessary to make a HalfNormal distribution have unit variance
HALFNORMAL_SCALE = 1 / np.sqrt(1 - 2 / np.pi)
prior_sigma = HALFNORMAL_SCALE * n_channels * spend_share.to_numpy()
prior_sigma
array([2.1775326 , 1.14026088])
my_model_config = {'beta_channel': {'sigma': prior_sigma, 'dims': ('channel',)}}
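
The constant HALFNORMAL_SCALE follows from the variance of a half-normal distribution, Var = sigma^2 * (1 - 2/pi). The check below uses only the standard library and simulates half-normal draws as the absolute value of normal draws. The seed and sample size are arbitrary choices made for illustration.

```python
import math
import random

# Scale that gives a half-normal distribution unit variance
scale = 1 / math.sqrt(1 - 2 / math.pi)

# Closed form: Var[HalfNormal(sigma)] = sigma**2 * (1 - 2 / pi)
analytic_variance = scale**2 * (1 - 2 / math.pi)

# Monte Carlo check: the absolute value of a normal draw is half-normal
sampler = random.Random(327)
n_draws = 100_000
draws = [abs(sampler.gauss(0.0, scale)) for _ in range(n_draws)]
mean = sum(draws) / n_draws
empirical_variance = sum((d - mean) ** 2 for d in draws) / n_draws

print(f"analytic variance:  {analytic_variance:.6f}")
print(f"empirical variance: {empirical_variance:.3f}")
```

Multiplying this scale by n_channels * spend_share, as done above, gives each channel a prior scale proportional to its share of total spend.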

As mentioned in the original notebook: “For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the DelayedSaturatedMMM class has some default priors that you can use as a starting point.”

Sampling configuration#

The second feature we can customize is sampler_config. Like model_config, it is a dictionary that gets saved; it contains the keyword arguments you would usually pass to fit(). Creating your own sampler_config is not mandatory. The default sampler_config of DelayedSaturatedMMM is empty because the default sampling parameters usually prove sufficient for a start.

dummy_model.default_sampler_config
{}
my_sampler_config = {
    'tune':1000,
    'draws':1000,
    'chains':4,
    'target_accept':0.95,
}
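
One way to think about how a stored sampler_config interacts with keyword arguments passed at fit time is sketched below. The helper resolve_sampler_kwargs is hypothetical, and the precedence shown (call-time arguments win) is an assumption made for illustration, not a statement about the library's internals.

```python
# Hypothetical sketch: combine a persisted sampler_config with one-off kwargs.
def resolve_sampler_kwargs(sampler_config: dict, **fit_kwargs) -> dict:
    # Start from the persisted settings, then let call-time kwargs override.
    resolved = dict(sampler_config)
    resolved.update(fit_kwargs)
    return resolved


sampler_config = {"tune": 1000, "draws": 1000, "chains": 4, "target_accept": 0.95}

# A one-off override, e.g. for a quick smoke test
kwargs = resolve_sampler_kwargs(sampler_config, draws=200, random_seed=327)
print(kwargs)

# The stored configuration itself is left unchanged.
print(sampler_config["draws"])
```

Keeping the stored dictionary untouched means one-off arguments never leak into what gets persisted.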

Let’s finally assemble our model!

mmm = DelayedSaturatedMMM(
    model_config = my_model_config,
    sampler_config = my_sampler_config,
    date_column="date_week",
    channel_columns=["x1", "x2"],
    control_columns=[
        "event_1",
        "event_2",
        "t",
    ],
    adstock_max_lag=8,
    yearly_seasonality=2,
)

We can confirm that our settings are being used:

mmm.model_config["beta_channel"]
{'sigma': array([2.1775326 , 1.14026088]), 'dims': ('channel',)}
mmm.sampler_config
{'tune': 1000, 'draws': 1000, 'chains': 4, 'target_accept': 0.95}
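
The yearly_seasonality=2 argument requests two Fourier orders, which is why the fit output later on this page reports four fourier_mode entries (a sine and a cosine per order). A minimal sketch of how such features can be computed from dayofyear follows. The helper name fourier_features and the period of 365.25 days are assumptions, although the first row does agree with the fourier_data values 0.9999 and -0.01183 shown in the fit output.

```python
import math


def fourier_features(day_of_year: int, n_order: int, period: float = 365.25) -> dict:
    """Return sine/cosine seasonality features for a single day of the year."""
    features = {}
    for order in range(1, n_order + 1):
        angle = 2 * math.pi * order * day_of_year / period
        features[f"sin_order_{order}"] = math.sin(angle)
        features[f"cos_order_{order}"] = math.cos(angle)
    return features


# First row of the dataset: dayofyear == 92
row = fourier_features(92, n_order=2)
for name, value in row.items():
    print(f"{name}: {value:.4f}")
```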

Model Fitting#

Note that we didn’t pass the dataset to the class constructor itself. This is done to mimic the scikit-learn API and make it easier to get started with PyMC-Marketing models.

# Split X, and y
X = data.drop('y',axis=1)
y = data['y']

All that’s left now is to fit the model. As the call below shows, you can still pass sampler kwargs directly to the fit() method. However, only the kwargs set through sampler_config will be saved and reused after loading the model.

mmm.fit(X=X, y=y, random_seed=rng)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, beta_channel, alpha, lam, sigma, gamma_control, gamma_fourier]
100.00% [8000/8000 01:11<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 71 seconds.
Sampling: [alpha, beta_channel, gamma_control, gamma_fourier, intercept, lam, likelihood, sigma]
Sampling: [likelihood]
100.00% [4000/4000 00:00<00:00]
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:                    (chain: 4, draw: 1000, control: 3,
                                      fourier_mode: 4, channel: 2, date: 179)
      Coordinates:
        * chain                      (chain) int64 0 1 2 3
        * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999
        * control                    (control) <U7 'event_1' 'event_2' 't'
        * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...
        * channel                    (channel) <U2 'x1' 'x2'
        * date                       (date) <U10 '2018-04-02' ... '2021-08-30'
      Data variables: (12/13)
          intercept                  (chain, draw) float64 0.3852 0.3278 ... 0.311
          gamma_control              (chain, draw, control) float64 0.2353 ... 0.00...
          gamma_fourier              (chain, draw, fourier_mode) float64 0.004958 ....
          beta_channel               (chain, draw, channel) float64 0.3975 ... 0.3032
          alpha                      (chain, draw, channel) float64 0.4327 ... 0.2364
          lam                        (chain, draw, channel) float64 2.864 ... 2.357
          ...                         ...
          channel_adstock            (chain, draw, date, channel) float64 0.1816 .....
          channel_adstock_saturated  (chain, draw, date, channel) float64 0.2543 .....
          channel_contributions      (chain, draw, date, channel) float64 0.1011 .....
          control_contributions      (chain, draw, date, control) float64 0.0 ... 0...
          fourier_contributions      (chain, draw, date, fourier_mode) float64 0.00...
          mu                         (chain, draw, date) float64 0.4922 ... 0.6067
      Attributes:
          created_at:                 2023-08-03T12:44:52.592010
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0
          sampling_time:              71.22048568725586
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:     (chain: 4, draw: 1000, date: 179)
      Coordinates:
        * chain       (chain) int64 0 1 2 3
        * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999
        * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'
      Data variables:
          likelihood  (chain, draw, date) float64 0.4729 0.4781 ... 0.5132 0.5917
      Attributes:
          created_at:                 2023-08-03T12:44:55.266426
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0

    • <xarray.Dataset>
      Dimensions:                (chain: 4, draw: 1000)
      Coordinates:
        * chain                  (chain) int64 0 1 2 3
        * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999
      Data variables: (12/17)
          process_time_diff      (chain, draw) float64 0.02484 0.03027 ... 0.04276
          lp                     (chain, draw) float64 351.2 354.5 ... 353.0 353.4
          max_energy_error       (chain, draw) float64 -0.03674 0.03495 ... 0.1954
          step_size              (chain, draw) float64 0.07388 0.07388 ... 0.06563
          reached_max_treedepth  (chain, draw) bool False False False ... False False
          perf_counter_diff      (chain, draw) float64 0.02507 0.03056 ... 0.04421
          ...                     ...
          index_in_trajectory    (chain, draw) int64 -31 33 32 -28 43 ... 44 -9 38 29
          energy                 (chain, draw) float64 -344.9 -345.2 ... -345.0 -346.9
          largest_eigval         (chain, draw) float64 nan nan nan nan ... nan nan nan
          n_steps                (chain, draw) float64 63.0 63.0 63.0 ... 63.0 127.0
          tree_depth             (chain, draw) int64 6 6 6 6 6 6 6 6 ... 6 6 6 6 6 6 7
          diverging              (chain, draw) bool False False False ... False False
      Attributes:
          created_at:                 2023-08-03T12:44:52.608151
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0
          sampling_time:              71.22048568725586
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:                    (chain: 1, draw: 500, channel: 2, control: 3,
                                      fourier_mode: 4, date: 179)
      Coordinates:
        * chain                      (chain) int64 0
        * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499
        * channel                    (channel) <U2 'x1' 'x2'
        * control                    (control) <U7 'event_1' 'event_2' 't'
        * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...
        * date                       (date) <U10 '2018-04-02' ... '2021-08-30'
      Data variables: (12/13)
          beta_channel               (chain, draw, channel) float64 1.418 ... 1.259
          gamma_control              (chain, draw, control) float64 3.665 ... -1.201
          gamma_fourier              (chain, draw, fourier_mode) float64 0.08728 .....
          channel_contributions      (chain, draw, date, channel) float64 0.2905 .....
          intercept                  (chain, draw) float64 -3.807 -1.606 ... -0.598
          lam                        (chain, draw, channel) float64 1.372 ... 1.678
          ...                         ...
          mu                         (chain, draw, date) float64 -3.19 ... -212.0
          alpha                      (chain, draw, channel) float64 0.0521 ... 0.7527
          channel_adstock            (chain, draw, date, channel) float64 0.303 ......
          control_contributions      (chain, draw, date, control) float64 0.0 ... -...
          channel_adstock_saturated  (chain, draw, date, channel) float64 0.2049 .....
          sigma                      (chain, draw) float64 2.361 0.3904 ... 2.132
      Attributes:
          created_at:                 2023-08-03T12:44:54.644498
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0

    • <xarray.Dataset>
      Dimensions:     (chain: 1, draw: 500, date: 179)
      Coordinates:
        * chain       (chain) int64 0
        * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'
      Data variables:
          likelihood  (chain, draw, date) float64 -7.646 -1.949 ... -210.5 -215.1
      Attributes:
          created_at:                 2023-08-03T12:44:54.650068
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0

    • <xarray.Dataset>
      Dimensions:     (date: 179)
      Coordinates:
        * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'
      Data variables:
          likelihood  (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625
      Attributes:
          created_at:                 2023-08-03T12:44:52.612711
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0

    • <xarray.Dataset>
      Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)
      Coordinates:
        * date          (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'
        * channel       (channel) <U2 'x1' 'x2'
        * control       (control) <U7 'event_1' 'event_2' 't'
        * fourier_mode  (fourier_mode) <U11 'sin_order_1' ... 'cos_order_2'
      Data variables:
          channel_data  (date, channel) float64 0.3196 0.0 0.1128 ... 0.0 0.4403 0.0
          target        (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625
          control_data  (date, control) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0
          fourier_data  (date, fourier_mode) float64 0.9999 -0.01183 ... -0.4547
      Attributes:
          created_at:                 2023-08-03T12:44:52.614490
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0

    • <xarray.Dataset>
      Dimensions:    (index: 179)
      Coordinates:
        * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178
      Data variables:
          date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'
          x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389
          x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0
          event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242
          t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178
          y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

The fit() method automatically builds the model using the priors from model_config, and assigns the created model to our instance. You can access it as a normal attribute.

type(mmm.model)
pymc.model.Model
mmm.graphviz()
[Figure: Graphviz diagram of the model graph]

The posterior trace can be accessed through the fit_result attribute:

mmm.fit_result
<xarray.Dataset>
Dimensions:                    (chain: 4, draw: 1000, control: 3,
                                fourier_mode: 4, channel: 2, date: 179)
Coordinates:
  * chain                      (chain) int64 0 1 2 3
  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999
  * control                    (control) <U7 'event_1' 'event_2' 't'
  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...
  * channel                    (channel) <U2 'x1' 'x2'
  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'
Data variables: (12/13)
    intercept                  (chain, draw) float64 0.3852 0.3278 ... 0.311
    gamma_control              (chain, draw, control) float64 0.2353 ... 0.00...
    gamma_fourier              (chain, draw, fourier_mode) float64 0.004958 ....
    beta_channel               (chain, draw, channel) float64 0.3975 ... 0.3032
    alpha                      (chain, draw, channel) float64 0.4327 ... 0.2364
    lam                        (chain, draw, channel) float64 2.864 ... 2.357
    ...                         ...
    channel_adstock            (chain, draw, date, channel) float64 0.1816 .....
    channel_adstock_saturated  (chain, draw, date, channel) float64 0.2543 .....
    channel_contributions      (chain, draw, date, channel) float64 0.1011 .....
    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...
    fourier_contributions      (chain, draw, date, fourier_mode) float64 0.00...
    mu                         (chain, draw, date) float64 0.4922 ... 0.6067
Attributes:
    created_at:                 2023-08-03T12:44:52.592010
    arviz_version:              0.15.1
    inference_library:          pymc
    inference_library_version:  5.7.0
    sampling_time:              71.22048568725586
    tuning_steps:               1000

If you wish to inspect the entire inference data, use the idata attribute. Within idata, you can find the entire dataset passed to the model under fit_data.

mmm.idata
arviz.InferenceData
    (output identical to the InferenceData shown after mmm.fit() above; its last group, fit_data, holds the dataset passed to the model)

Saving and loading a fitted model#

All the data passed to the model on initialization is stored in idata.attrs. The save() method later uses it to write both this data and all the fit data to a file in the netCDF format. You can read more about this format in the netCDF documentation.

The save and load methods only require a path that tells them where the model should be saved to and loaded from.

mmm.save('my_saved_model.nc')
loaded_model = DelayedSaturatedMMM.load('my_saved_model.nc')
/home/ricardo/miniconda3/envs/pymc-marketing/lib/python3.11/site-packages/arviz/data/inference_data.py:152: UserWarning: fit_data group is not defined in the InferenceData scheme
  warnings.warn(
loaded_model.model_config["beta_channel"]
{'sigma': array([2.1775326 , 1.14026088]), 'dims': ('channel',)}
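
The check above works because the configuration travels with the saved file. The round trip can be illustrated with a deliberately small stand-in that uses only the standard library. TinyModel is hypothetical and writes JSON rather than netCDF, so it shows the pattern (serialize the configuration on save, rebuild the object from it in a load classmethod), not the library's actual file layout.

```python
import json
import os
import tempfile


class TinyModel:
    """Hypothetical stand-in for a model that persists its configuration."""

    def __init__(self, model_config: dict, sampler_config: dict):
        self.model_config = model_config
        self.sampler_config = sampler_config

    def save(self, path: str) -> None:
        # Everything needed to rebuild the object goes into one file.
        payload = {
            "model_config": self.model_config,
            "sampler_config": self.sampler_config,
        }
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(payload, handle)

    @classmethod
    def load(cls, path: str) -> "TinyModel":
        # Rebuild a fresh instance from the stored configuration.
        with open(path, encoding="utf-8") as handle:
            payload = json.load(handle)
        return cls(payload["model_config"], payload["sampler_config"])


original = TinyModel(
    model_config={"beta_channel": {"sigma": [2.1775326, 1.14026088]}},
    sampler_config={"tune": 1000, "draws": 1000, "chains": 4},
)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_model.json")
    original.save(path)
    restored = TinyModel.load(path)

print(restored.model_config == original.model_config)
print(restored.sampler_config == original.sampler_config)
```

Note that JSON turns tuples into lists, which is one reason a real implementation needs extra care with entries such as dims.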
loaded_model.graphviz()
[Figure: Graphviz diagram of the loaded model graph, the same image as rendered for mmm before saving]
loaded_model.idata
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:                    (chain: 4, draw: 1000, control: 3,
                                      fourier_mode: 4, channel: 2, date: 179)
      Coordinates:
        * chain                      (chain) int64 0 1 2 3
        * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999
        * control                    (control) object 'event_1' 'event_2' 't'
        * fourier_mode               (fourier_mode) object 'sin_order_1' ... 'cos_o...
        * channel                    (channel) object 'x1' 'x2'
        * date                       (date) object '2018-04-02' ... '2021-08-30'
      Data variables: (12/13)
          intercept                  (chain, draw) float64 ...
          gamma_control              (chain, draw, control) float64 ...
          gamma_fourier              (chain, draw, fourier_mode) float64 ...
          beta_channel               (chain, draw, channel) float64 ...
          alpha                      (chain, draw, channel) float64 ...
          lam                        (chain, draw, channel) float64 ...
          ...                         ...
          channel_adstock            (chain, draw, date, channel) float64 ...
          channel_adstock_saturated  (chain, draw, date, channel) float64 ...
          channel_contributions      (chain, draw, date, channel) float64 ...
          control_contributions      (chain, draw, date, control) float64 ...
          fourier_contributions      (chain, draw, date, fourier_mode) float64 ...
          mu                         (chain, draw, date) float64 ...
      Attributes:
          created_at:                 2023-08-03T12:44:52.592010
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0
          sampling_time:              71.22048568725586
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:     (chain: 4, draw: 1000, date: 179)
      Coordinates:
        * chain       (chain) int64 0 1 2 3
        * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999
        * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'
      Data variables:
          likelihood  (chain, draw, date) float64 ...
      Attributes:
          created_at:                 2023-08-03T12:44:55.266426
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0

    • <xarray.Dataset>
      Dimensions:                (chain: 4, draw: 1000)
      Coordinates:
        * chain                  (chain) int64 0 1 2 3
        * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999
      Data variables: (12/17)
          process_time_diff      (chain, draw) float64 ...
          lp                     (chain, draw) float64 ...
          max_energy_error       (chain, draw) float64 ...
          step_size              (chain, draw) float64 ...
          reached_max_treedepth  (chain, draw) bool ...
          perf_counter_diff      (chain, draw) float64 ...
          ...                     ...
          index_in_trajectory    (chain, draw) int64 ...
          energy                 (chain, draw) float64 ...
          largest_eigval         (chain, draw) float64 ...
          n_steps                (chain, draw) float64 ...
          tree_depth             (chain, draw) int64 ...
          diverging              (chain, draw) bool ...
      Attributes:
          created_at:                 2023-08-03T12:44:52.608151
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0
          sampling_time:              71.22048568725586
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:                    (chain: 1, draw: 500, channel: 2, control: 3,
                                      fourier_mode: 4, date: 179)
      Coordinates:
        * chain                      (chain) int64 0
        * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499
        * channel                    (channel) object 'x1' 'x2'
        * control                    (control) object 'event_1' 'event_2' 't'
        * fourier_mode               (fourier_mode) object 'sin_order_1' ... 'cos_o...
        * date                       (date) object '2018-04-02' ... '2021-08-30'
      Data variables: (12/13)
          beta_channel               (chain, draw, channel) float64 ...
          gamma_control              (chain, draw, control) float64 ...
          gamma_fourier              (chain, draw, fourier_mode) float64 ...
          channel_contributions      (chain, draw, date, channel) float64 ...
          intercept                  (chain, draw) float64 ...
          lam                        (chain, draw, channel) float64 ...
          ...                         ...
          mu                         (chain, draw, date) float64 ...
          alpha                      (chain, draw, channel) float64 ...
          channel_adstock            (chain, draw, date, channel) float64 ...
          control_contributions      (chain, draw, date, control) float64 ...
          channel_adstock_saturated  (chain, draw, date, channel) float64 ...
          sigma                      (chain, draw) float64 ...
      Attributes:
          created_at:                 2023-08-03T12:44:54.644498
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0

    • <xarray.Dataset>
      Dimensions:     (chain: 1, draw: 500, date: 179)
      Coordinates:
        * chain       (chain) int64 0
        * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
        * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'
      Data variables:
          likelihood  (chain, draw, date) float64 ...
      Attributes:
          created_at:                 2023-08-03T12:44:54.650068
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0

    • <xarray.Dataset>
      Dimensions:     (date: 179)
      Coordinates:
        * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'
      Data variables:
          likelihood  (date) float64 ...
      Attributes:
          created_at:                 2023-08-03T12:44:52.612711
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0

    • <xarray.Dataset>
      Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)
      Coordinates:
        * date          (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'
        * channel       (channel) object 'x1' 'x2'
        * control       (control) object 'event_1' 'event_2' 't'
        * fourier_mode  (fourier_mode) object 'sin_order_1' ... 'cos_order_2'
      Data variables:
          channel_data  (date, channel) float64 ...
          target        (date) float64 ...
          control_data  (date, control) float64 ...
          fourier_data  (date, fourier_mode) float64 ...
      Attributes:
          created_at:                 2023-08-03T12:44:52.614490
          arviz_version:              0.15.1
          inference_library:          pymc
          inference_library_version:  5.7.0

    • <xarray.Dataset>
      Dimensions:    (index: 179)
      Coordinates:
        * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178
      Data variables:
          date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'
          x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389
          x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0
          event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0
          dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242
          t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178
          y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

A loaded model is ready for sampling and prediction, and it can reuse the previous fitting results and data where needed.

loaded_model.sample_posterior_predictive(X, extend_idata=True, combined=False, random_seed=rng)
Sampling: [likelihood]
<xarray.Dataset>
Dimensions:     (chain: 4, draw: 1000, date: 179)
Coordinates:
  * chain       (chain) int64 0 1 2 3
  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999
  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'
Data variables:
    likelihood  (chain, draw, date) float64 0.5057 0.4536 ... 0.5621 0.5581
Attributes:
    created_at:                 2023-08-03T12:45:03.555645
    arviz_version:              0.15.1
    inference_library:          pymc
    inference_library_version:  5.7.0
az.plot_ppc(loaded_model.idata);
../../_images/5a4764013baa45143c87d9758826991a1a6fce6f690af0b07ac3b8b177ce78ae.png
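
Beyond the visual check, the posterior predictive draws can be reduced to a per-date summary. The following is a minimal sketch using only NumPy, not part of the PyMC-Marketing API. The array `draws` is synthetic and merely mimics the `(chain, draw, date)` layout of the `likelihood` variable shown above. In practice one would presumably take the values from the `posterior_predictive` group of `loaded_model.idata`, which is an assumption about where `extend_idata=True` stores them.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in (assumption) with the same layout as the output above:
# 4 chains, 1000 draws per chain, 179 dates.
draws = rng.normal(loc=0.5, scale=0.05, size=(4, 1000, 179))

# Pool chains and draws so that each column holds every sample for one date.
pooled = draws.reshape(-1, draws.shape[-1])

# Per-date posterior predictive mean and a central 94% interval.
pp_mean = pooled.mean(axis=0)
pp_lower, pp_upper = np.quantile(pooled, [0.03, 0.97], axis=0)
```

Each of `pp_mean`, `pp_lower` and `pp_upper` has one entry per date, so they can be plotted against the `date` coordinate or exported alongside the input data.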

Other models#

Although this introduction uses DelayedSaturatedMMM, all other PyMC-Marketing models (MMM and CLV) provide the same save and load functionality.
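
Because the interface is shared, deployment code can be written once against `save` and `load` rather than against a specific model class. The sketch below illustrates the idea under stated assumptions: `save_and_reload` is a hypothetical helper, and `ToyModel` is a hypothetical stand-in that only imitates the interface (an instance method `save` and a class-level `load`) with a JSON file. Neither is part of PyMC-Marketing, and real models persist far more than these two dictionaries.

```python
import json
import tempfile
from pathlib import Path


def save_and_reload(model, fname):
    """Save ``model`` to ``fname`` and load a fresh copy through its own class."""
    model.save(fname)
    return type(model).load(fname)


class ToyModel:
    """Hypothetical stand-in for the save/load interface; not a PyMC-Marketing class."""

    def __init__(self, model_config=None, sampler_config=None):
        self.model_config = model_config or {}
        self.sampler_config = sampler_config or {}

    def save(self, fname):
        # Persist both configuration dictionaries to a single JSON file.
        payload = {
            "model_config": self.model_config,
            "sampler_config": self.sampler_config,
        }
        Path(fname).write_text(json.dumps(payload))

    @classmethod
    def load(cls, fname):
        # Rebuild a new instance from the stored configuration.
        return cls(**json.loads(Path(fname).read_text()))


with tempfile.TemporaryDirectory() as tmp:
    original = ToyModel(
        model_config={"intercept": {"mu": 0, "sigma": 2}},
        sampler_config={"chains": 4},
    )
    restored = save_and_reload(original, Path(tmp) / "toy_model.json")
```

The helper never names the concrete class, so the same call would apply to any object that follows this save/load convention.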

Summary#

The PyMC-Marketing functionality described here is intended to make it easy to share models among data science teams without requiring extensive technical modelling knowledge from everyone involved. We are still iterating on our API and would love to hear more feedback from our users!