{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Model deployment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One of the main goals of PyMC-Marketing is to facilitate the deployment of its models." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is achieved by building our models on top of [ModelBuilder](https://www.pymc.io/projects/examples/en/latest/howto/model_builder.html), a brand new [PyMC experimental](https://www.pymc.io/projects/experimental/en/latest/) feature that offers a scikit-learn-like API and makes PyMC models easy to deploy.\n", "\n", "PyMC-Marketing models inherit two easy-to-use methods, `save` and `load`, which can be used once the model has been fitted. All models can be configured with two standard dictionaries, `model_config` and `sampler_config`, which are serialized by `save` and restored by `load`, allowing the same model to be reused across workflows." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will illustrate this functionality with the example model described in the [MMM Example Notebook](https://www.pymc-marketing.io/en/stable/notebooks/mmm/mmm_example.html). For the sake of generality, we omit most technical details here." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "import numpy as np\n", "import pandas as pd\n", "\n", "from pymc_marketing.mmm import DelayedSaturatedMMM" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "seed = sum(map(ord, \"mmm\"))\n", "rng = np.random.default_rng(seed=seed)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's load the dataset:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
date_weekyx1x2event_1event_2dayofyeart
02018-04-023984.6622370.3185800.00.00.0920
12018-04-093762.8717940.1123880.00.00.0991
22018-04-164466.9673880.2924000.00.00.01062
32018-04-233864.2193730.0713990.00.00.01133
42018-04-304441.6252780.3867450.00.00.01204
\n", "
" ], "text/plain": [ " date_week y x1 x2 event_1 event_2 dayofyear t\n", "0 2018-04-02 3984.662237 0.318580 0.0 0.0 0.0 92 0\n", "1 2018-04-09 3762.871794 0.112388 0.0 0.0 0.0 99 1\n", "2 2018-04-16 4466.967388 0.292400 0.0 0.0 0.0 106 2\n", "3 2018-04-23 3864.219373 0.071399 0.0 0.0 0.0 113 3\n", "4 2018-04-30 4441.625278 0.386745 0.0 0.0 0.0 120 4" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "url = \"https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv\"\n", "df = pd.read_csv(url)\n", "\n", "columns_to_keep = [\n", " \"date_week\",\n", " \"y\",\n", " \"x1\",\n", " \"x2\",\n", " \"event_1\",\n", " \"event_2\",\n", " \"dayofyear\",\n", "]\n", "\n", "data = df[columns_to_keep].copy()\n", "data[\"t\"] = np.arange(df.shape[0])\n", "data.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our model needs only a small subset of the raw dataset. Many of the original features were used to generate the others, and since the target variable `y` has already been computed, the cell above keeps only the columns we need and adds a linear trend feature `t`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model and sampling configuration" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model configuration" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We first illustrate the use of `model_config` to define custom priors within the model.\n", "\n", "Because there are potentially many variables that can be configured, each model provides a `default_model_config` attribute. It shows which settings are available by default, so you only need to define the ones you want to change.\n", "\n", "To inspect this dictionary, we first need to create a dummy model instance."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'intercept': {'mu': 0, 'sigma': 2},\n", " 'beta_channel': {'sigma': 2, 'dims': ('channel',)},\n", " 'alpha': {'alpha': 1, 'beta': 3, 'dims': ('channel',)},\n", " 'lam': {'alpha': 3, 'beta': 1, 'dims': ('channel',)},\n", " 'sigma': {'sigma': 2},\n", " 'gamma_control': {'mu': 0, 'sigma': 2, 'dims': ('control',)},\n", " 'mu': {'dims': ('date',)},\n", " 'likelihood': {'dims': ('date',)},\n", " 'gamma_fourier': {'mu': 0, 'b': 1, 'dims': 'fourier_mode'}}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dummy_model = DelayedSaturatedMMM(\n", " date_column=\"date_week\",\n", " channel_columns=[\"x1\", \"x2\"],\n", " control_columns=[\n", " \"event_1\",\n", " \"event_2\",\n", " \"t\",\n", " ],\n", " adstock_max_lag=8,\n", " yearly_seasonality=2,\n", ")\n", "dummy_model.default_model_config" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can change the parameters of the prior distribution of each term.\n", "In this case, we simply replace the `sigma` of the `beta_channel` prior with a custom value that scales each channel's prior with its share of the total spend:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([2.1775326 , 1.14026088])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n_channels = 2\n", "\n", "# Share of the total spend allocated to each channel\n", "total_spend_per_channel = data[[\"x1\", \"x2\"]].sum(axis=0)\n", "spend_share = total_spend_per_channel / total_spend_per_channel.sum()\n", "\n", "# The scale necessary to make a HalfNormal distribution have unit variance\n", "HALFNORMAL_SCALE = 1 / np.sqrt(1 - 2 / np.pi)\n", "prior_sigma = HALFNORMAL_SCALE * n_channels * spend_share.to_numpy()\n", "prior_sigma" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "my_model_config = {\"beta_channel\": {\"sigma\": prior_sigma, \"dims\": (\"channel\",)}}" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "As mentioned in the original notebook: _\"For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the `DelayedSaturatedMMM` class has some default priors that you can use as a starting point.\"_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sampling configuration" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The second feature we can customize is `sampler_config`. Like `model_config`, it is a dictionary that gets saved with the model; it holds the keyword arguments you would usually pass to `fit()`. Creating your own `sampler_config` is optional: `DelayedSaturatedMMM.default_sampler_config` is empty because the default sampling parameters are usually sufficient as a starting point." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dummy_model.default_sampler_config" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "my_sampler_config = {\n", " \"tune\": 1000,\n", " \"draws\": 1000,\n", " \"chains\": 4,\n", " \"target_accept\": 0.95,\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's finally assemble our model!"
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "mmm = DelayedSaturatedMMM(\n", " model_config=my_model_config,\n", " sampler_config=my_sampler_config,\n", " date_column=\"date_week\",\n", " channel_columns=[\"x1\", \"x2\"],\n", " control_columns=[\n", " \"event_1\",\n", " \"event_2\",\n", " \"t\",\n", " ],\n", " adstock_max_lag=8,\n", " yearly_seasonality=2,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can confirm that our settings are being used:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'sigma': array([2.1775326 , 1.14026088]), 'dims': ('channel',)}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mmm.model_config[\"beta_channel\"]" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'tune': 1000, 'draws': 1000, 'chains': 4, 'target_accept': 0.95}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mmm.sampler_config" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model fitting" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that we did not pass the dataset to the class constructor itself. Instead, the data is passed to `fit()`, mimicking the `scikit-learn` API and making it easier to get started with PyMC-Marketing models." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Split the features X from the target y\n", "X = data.drop(\"y\", axis=1)\n", "y = data[\"y\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "All that's left now is to fit the model.\n", "\n", "As shown below, you can still pass sampler kwargs (here, `random_seed`) directly to the `fit()` method. However, only the kwargs passed through `sampler_config` are saved and reused after loading the model."
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Auto-assigning NUTS sampler...\n", "Initializing NUTS using jitter+adapt_diag...\n", "Multiprocess sampling (4 chains in 4 jobs)\n", "NUTS: [intercept, beta_channel, alpha, lam, sigma, gamma_control, gamma_fourier]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [8000/8000 01:11<00:00 Sampling 4 chains, 0 divergences]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 71 seconds.\n", "Sampling: [alpha, beta_channel, gamma_control, gamma_fourier, intercept, lam, likelihood, sigma]\n", "Sampling: [likelihood]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [4000/4000 00:00<00:00]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", "
\n", "
arviz.InferenceData
\n", "
\n", "
    \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
             "                                fourier_mode: 4, channel: 2, date: 179)\n",
             "Coordinates:\n",
             "  * chain                      (chain) int64 0 1 2 3\n",
             "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
             "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
             "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
             "  * channel                    (channel) <U2 'x1' 'x2'\n",
             "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
             "Data variables: (12/13)\n",
             "    intercept                  (chain, draw) float64 0.3852 0.3278 ... 0.311\n",
             "    gamma_control              (chain, draw, control) float64 0.2353 ... 0.00...\n",
             "    gamma_fourier              (chain, draw, fourier_mode) float64 0.004958 ....\n",
             "    beta_channel               (chain, draw, channel) float64 0.3975 ... 0.3032\n",
             "    alpha                      (chain, draw, channel) float64 0.4327 ... 0.2364\n",
             "    lam                        (chain, draw, channel) float64 2.864 ... 2.357\n",
             "    ...                         ...\n",
             "    channel_adstock            (chain, draw, date, channel) float64 0.1816 .....\n",
             "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.2543 .....\n",
             "    channel_contributions      (chain, draw, date, channel) float64 0.1011 .....\n",
             "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
             "    fourier_contributions      (chain, draw, date, fourier_mode) float64 0.00...\n",
             "    mu                         (chain, draw, date) float64 0.4922 ... 0.6067\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:52.592010\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0\n",
             "    sampling_time:              71.22048568725586\n",
             "    tuning_steps:               1000

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
             "Coordinates:\n",
             "  * chain       (chain) int64 0 1 2 3\n",
             "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
             "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "Data variables:\n",
             "    likelihood  (chain, draw, date) float64 0.4729 0.4781 ... 0.5132 0.5917\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:55.266426\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:                (chain: 4, draw: 1000)\n",
             "Coordinates:\n",
             "  * chain                  (chain) int64 0 1 2 3\n",
             "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
             "Data variables: (12/17)\n",
             "    process_time_diff      (chain, draw) float64 0.02484 0.03027 ... 0.04276\n",
             "    lp                     (chain, draw) float64 351.2 354.5 ... 353.0 353.4\n",
             "    max_energy_error       (chain, draw) float64 -0.03674 0.03495 ... 0.1954\n",
             "    step_size              (chain, draw) float64 0.07388 0.07388 ... 0.06563\n",
             "    reached_max_treedepth  (chain, draw) bool False False False ... False False\n",
             "    perf_counter_diff      (chain, draw) float64 0.02507 0.03056 ... 0.04421\n",
             "    ...                     ...\n",
             "    index_in_trajectory    (chain, draw) int64 -31 33 32 -28 43 ... 44 -9 38 29\n",
             "    energy                 (chain, draw) float64 -344.9 -345.2 ... -345.0 -346.9\n",
             "    largest_eigval         (chain, draw) float64 nan nan nan nan ... nan nan nan\n",
             "    n_steps                (chain, draw) float64 63.0 63.0 63.0 ... 63.0 127.0\n",
             "    tree_depth             (chain, draw) int64 6 6 6 6 6 6 6 6 ... 6 6 6 6 6 6 7\n",
             "    diverging              (chain, draw) bool False False False ... False False\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:52.608151\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0\n",
             "    sampling_time:              71.22048568725586\n",
             "    tuning_steps:               1000

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:                    (chain: 1, draw: 500, channel: 2, control: 3,\n",
             "                                fourier_mode: 4, date: 179)\n",
             "Coordinates:\n",
             "  * chain                      (chain) int64 0\n",
             "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
             "  * channel                    (channel) <U2 'x1' 'x2'\n",
             "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
             "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
             "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
             "Data variables: (12/13)\n",
             "    beta_channel               (chain, draw, channel) float64 1.418 ... 1.259\n",
             "    gamma_control              (chain, draw, control) float64 3.665 ... -1.201\n",
             "    gamma_fourier              (chain, draw, fourier_mode) float64 0.08728 .....\n",
             "    channel_contributions      (chain, draw, date, channel) float64 0.2905 .....\n",
             "    intercept                  (chain, draw) float64 -3.807 -1.606 ... -0.598\n",
             "    lam                        (chain, draw, channel) float64 1.372 ... 1.678\n",
             "    ...                         ...\n",
             "    mu                         (chain, draw, date) float64 -3.19 ... -212.0\n",
             "    alpha                      (chain, draw, channel) float64 0.0521 ... 0.7527\n",
             "    channel_adstock            (chain, draw, date, channel) float64 0.303 ......\n",
             "    control_contributions      (chain, draw, date, control) float64 0.0 ... -...\n",
             "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.2049 .....\n",
             "    sigma                      (chain, draw) float64 2.361 0.3904 ... 2.132\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:54.644498\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
             "Coordinates:\n",
             "  * chain       (chain) int64 0\n",
             "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
             "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "Data variables:\n",
             "    likelihood  (chain, draw, date) float64 -7.646 -1.949 ... -210.5 -215.1\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:54.650068\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:     (date: 179)\n",
             "Coordinates:\n",
             "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "Data variables:\n",
             "    likelihood  (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:52.612711\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
             "Coordinates:\n",
             "  * date          (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "  * channel       (channel) <U2 'x1' 'x2'\n",
             "  * control       (control) <U7 'event_1' 'event_2' 't'\n",
             "  * fourier_mode  (fourier_mode) <U11 'sin_order_1' ... 'cos_order_2'\n",
             "Data variables:\n",
             "    channel_data  (date, channel) float64 0.3196 0.0 0.1128 ... 0.0 0.4403 0.0\n",
             "    target        (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
             "    control_data  (date, control) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0\n",
             "    fourier_data  (date, fourier_mode) float64 0.9999 -0.01183 ... -0.4547\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:52.614490\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:    (index: 179)\n",
             "Coordinates:\n",
             "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
             "Data variables:\n",
             "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
             "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
             "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
             "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
             "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
             "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
             "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
\n", "
\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", "\t> observed_data\n", "\t> constant_data\n", "\t> fit_data" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mmm.fit(X=X, y=y, random_seed=rng)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `fit()` method automatically builds the model using the priors from `model_config`, and assigns the created model to our instance. You can access it as a normal attribute." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "pymc.model.Model" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(mmm.model)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "clusterdate (179) x channel (2)\n", "\n", "date (179) x channel (2)\n", "\n", "\n", "clusterdate (179)\n", "\n", "date (179)\n", "\n", "\n", "clusterchannel (2)\n", "\n", "channel (2)\n", "\n", "\n", "clusterdate (179) x control (3)\n", "\n", "date (179) x control (3)\n", "\n", "\n", "clustercontrol (3)\n", "\n", "control (3)\n", "\n", "\n", "clusterdate (179) x fourier_mode (4)\n", "\n", "date (179) x fourier_mode (4)\n", "\n", "\n", "clusterfourier_mode (4)\n", "\n", "fourier_mode (4)\n", "\n", "\n", "\n", "channel_adstock_saturated\n", "\n", "channel_adstock_saturated\n", "~\n", "Deterministic\n", "\n", "\n", "\n", "channel_contributions\n", "\n", "channel_contributions\n", "~\n", "Deterministic\n", "\n", "\n", "\n", "channel_adstock_saturated->channel_contributions\n", "\n", "\n", "\n", "\n", "\n", "channel_adstock\n", "\n", "channel_adstock\n", "~\n", "Deterministic\n", "\n", "\n", "\n", "channel_adstock->channel_adstock_saturated\n", "\n", "\n", "\n", "\n", "\n", "mu\n", "\n", 
"mu\n", "~\n", "Deterministic\n", "\n", "\n", "\n", "channel_contributions->mu\n", "\n", "\n", "\n", "\n", "\n", "channel_data\n", "\n", "channel_data\n", "~\n", "MutableData\n", "\n", "\n", "\n", "channel_data->channel_adstock\n", "\n", "\n", "\n", "\n", "\n", "likelihood\n", "\n", "likelihood\n", "~\n", "Normal\n", "\n", "\n", "\n", "mu->likelihood\n", "\n", "\n", "\n", "\n", "\n", "target\n", "\n", "target\n", "~\n", "MutableData\n", "\n", "\n", "\n", "likelihood->target\n", "\n", "\n", "\n", "\n", "\n", "sigma\n", "\n", "sigma\n", "~\n", "HalfNormal\n", "\n", "\n", "\n", "sigma->likelihood\n", "\n", "\n", "\n", "\n", "\n", "intercept\n", "\n", "intercept\n", "~\n", "Normal\n", "\n", "\n", "\n", "intercept->mu\n", "\n", "\n", "\n", "\n", "\n", "lam\n", "\n", "lam\n", "~\n", "Gamma\n", "\n", "\n", "\n", "lam->channel_adstock_saturated\n", "\n", "\n", "\n", "\n", "\n", "beta_channel\n", "\n", "beta_channel\n", "~\n", "HalfNormal\n", "\n", "\n", "\n", "beta_channel->channel_contributions\n", "\n", "\n", "\n", "\n", "\n", "alpha\n", "\n", "alpha\n", "~\n", "Beta\n", "\n", "\n", "\n", "alpha->channel_adstock\n", "\n", "\n", "\n", "\n", "\n", "control_contributions\n", "\n", "control_contributions\n", "~\n", "Deterministic\n", "\n", "\n", "\n", "control_contributions->mu\n", "\n", "\n", "\n", "\n", "\n", "control_data\n", "\n", "control_data\n", "~\n", "MutableData\n", "\n", "\n", "\n", "control_data->control_contributions\n", "\n", "\n", "\n", "\n", "\n", "gamma_control\n", "\n", "gamma_control\n", "~\n", "Normal\n", "\n", "\n", "\n", "gamma_control->control_contributions\n", "\n", "\n", "\n", "\n", "\n", "fourier_contributions\n", "\n", "fourier_contributions\n", "~\n", "Deterministic\n", "\n", "\n", "\n", "fourier_contributions->mu\n", "\n", "\n", "\n", "\n", "\n", "fourier_data\n", "\n", "fourier_data\n", "~\n", "MutableData\n", "\n", "\n", "\n", "fourier_data->fourier_contributions\n", "\n", "\n", "\n", "\n", "\n", "gamma_fourier\n", "\n", "gamma_fourier\n", 
"~\n", "Laplace\n", "\n", "\n", "\n", "gamma_fourier->fourier_contributions\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mmm.graphviz()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The posterior trace can be accessed through the `fit_result` attribute:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset>\n",
       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
       "                                fourier_mode: 4, channel: 2, date: 179)\n",
       "Coordinates:\n",
       "  * chain                      (chain) int64 0 1 2 3\n",
       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
       "  * channel                    (channel) <U2 'x1' 'x2'\n",
       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
       "Data variables: (12/13)\n",
       "    intercept                  (chain, draw) float64 0.3852 0.3278 ... 0.311\n",
       "    gamma_control              (chain, draw, control) float64 0.2353 ... 0.00...\n",
       "    gamma_fourier              (chain, draw, fourier_mode) float64 0.004958 ....\n",
       "    beta_channel               (chain, draw, channel) float64 0.3975 ... 0.3032\n",
       "    alpha                      (chain, draw, channel) float64 0.4327 ... 0.2364\n",
       "    lam                        (chain, draw, channel) float64 2.864 ... 2.357\n",
       "    ...                         ...\n",
       "    channel_adstock            (chain, draw, date, channel) float64 0.1816 .....\n",
       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.2543 .....\n",
       "    channel_contributions      (chain, draw, date, channel) float64 0.1011 .....\n",
       "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 0.00...\n",
       "    mu                         (chain, draw, date) float64 0.4922 ... 0.6067\n",
       "Attributes:\n",
       "    created_at:                 2023-08-03T12:44:52.592010\n",
       "    arviz_version:              0.15.1\n",
       "    inference_library:          pymc\n",
       "    inference_library_version:  5.7.0\n",
       "    sampling_time:              71.22048568725586\n",
       "    tuning_steps:               1000
" ], "text/plain": [ "\n", "Dimensions: (chain: 4, draw: 1000, control: 3,\n", " fourier_mode: 4, channel: 2, date: 179)\n", "Coordinates:\n", " * chain (chain) int64 0 1 2 3\n", " * draw (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n", " * control (control) \n", "
\n", "
arviz.InferenceData
\n", "
\n", "
    \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
             "                                fourier_mode: 4, channel: 2, date: 179)\n",
             "Coordinates:\n",
             "  * chain                      (chain) int64 0 1 2 3\n",
             "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
             "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
             "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
             "  * channel                    (channel) <U2 'x1' 'x2'\n",
             "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
             "Data variables: (12/13)\n",
             "    intercept                  (chain, draw) float64 0.3852 0.3278 ... 0.311\n",
             "    gamma_control              (chain, draw, control) float64 0.2353 ... 0.00...\n",
             "    gamma_fourier              (chain, draw, fourier_mode) float64 0.004958 ....\n",
             "    beta_channel               (chain, draw, channel) float64 0.3975 ... 0.3032\n",
             "    alpha                      (chain, draw, channel) float64 0.4327 ... 0.2364\n",
             "    lam                        (chain, draw, channel) float64 2.864 ... 2.357\n",
             "    ...                         ...\n",
             "    channel_adstock            (chain, draw, date, channel) float64 0.1816 .....\n",
             "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.2543 .....\n",
             "    channel_contributions      (chain, draw, date, channel) float64 0.1011 .....\n",
             "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
             "    fourier_contributions      (chain, draw, date, fourier_mode) float64 0.00...\n",
             "    mu                         (chain, draw, date) float64 0.4922 ... 0.6067\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:52.592010\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0\n",
             "    sampling_time:              71.22048568725586\n",
             "    tuning_steps:               1000

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
             "Coordinates:\n",
             "  * chain       (chain) int64 0 1 2 3\n",
             "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
             "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "Data variables:\n",
             "    likelihood  (chain, draw, date) float64 0.4729 0.4781 ... 0.5132 0.5917\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:55.266426\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:                (chain: 4, draw: 1000)\n",
             "Coordinates:\n",
             "  * chain                  (chain) int64 0 1 2 3\n",
             "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
             "Data variables: (12/17)\n",
             "    process_time_diff      (chain, draw) float64 0.02484 0.03027 ... 0.04276\n",
             "    lp                     (chain, draw) float64 351.2 354.5 ... 353.0 353.4\n",
             "    max_energy_error       (chain, draw) float64 -0.03674 0.03495 ... 0.1954\n",
             "    step_size              (chain, draw) float64 0.07388 0.07388 ... 0.06563\n",
             "    reached_max_treedepth  (chain, draw) bool False False False ... False False\n",
             "    perf_counter_diff      (chain, draw) float64 0.02507 0.03056 ... 0.04421\n",
             "    ...                     ...\n",
             "    index_in_trajectory    (chain, draw) int64 -31 33 32 -28 43 ... 44 -9 38 29\n",
             "    energy                 (chain, draw) float64 -344.9 -345.2 ... -345.0 -346.9\n",
             "    largest_eigval         (chain, draw) float64 nan nan nan nan ... nan nan nan\n",
             "    n_steps                (chain, draw) float64 63.0 63.0 63.0 ... 63.0 127.0\n",
             "    tree_depth             (chain, draw) int64 6 6 6 6 6 6 6 6 ... 6 6 6 6 6 6 7\n",
             "    diverging              (chain, draw) bool False False False ... False False\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:52.608151\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0\n",
             "    sampling_time:              71.22048568725586\n",
             "    tuning_steps:               1000

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:                    (chain: 1, draw: 500, channel: 2, control: 3,\n",
             "                                fourier_mode: 4, date: 179)\n",
             "Coordinates:\n",
             "  * chain                      (chain) int64 0\n",
             "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
             "  * channel                    (channel) <U2 'x1' 'x2'\n",
             "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
             "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
             "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
             "Data variables: (12/13)\n",
             "    beta_channel               (chain, draw, channel) float64 1.418 ... 1.259\n",
             "    gamma_control              (chain, draw, control) float64 3.665 ... -1.201\n",
             "    gamma_fourier              (chain, draw, fourier_mode) float64 0.08728 .....\n",
             "    channel_contributions      (chain, draw, date, channel) float64 0.2905 .....\n",
             "    intercept                  (chain, draw) float64 -3.807 -1.606 ... -0.598\n",
             "    lam                        (chain, draw, channel) float64 1.372 ... 1.678\n",
             "    ...                         ...\n",
             "    mu                         (chain, draw, date) float64 -3.19 ... -212.0\n",
             "    alpha                      (chain, draw, channel) float64 0.0521 ... 0.7527\n",
             "    channel_adstock            (chain, draw, date, channel) float64 0.303 ......\n",
             "    control_contributions      (chain, draw, date, control) float64 0.0 ... -...\n",
             "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.2049 .....\n",
             "    sigma                      (chain, draw) float64 2.361 0.3904 ... 2.132\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:54.644498\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
             "Coordinates:\n",
             "  * chain       (chain) int64 0\n",
             "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
             "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "Data variables:\n",
             "    likelihood  (chain, draw, date) float64 -7.646 -1.949 ... -210.5 -215.1\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:54.650068\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:     (date: 179)\n",
             "Coordinates:\n",
             "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "Data variables:\n",
             "    likelihood  (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:52.612711\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
             "Coordinates:\n",
             "  * date          (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "  * channel       (channel) <U2 'x1' 'x2'\n",
             "  * control       (control) <U7 'event_1' 'event_2' 't'\n",
             "  * fourier_mode  (fourier_mode) <U11 'sin_order_1' ... 'cos_order_2'\n",
             "Data variables:\n",
             "    channel_data  (date, channel) float64 0.3196 0.0 0.1128 ... 0.0 0.4403 0.0\n",
             "    target        (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
             "    control_data  (date, control) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0\n",
             "    fourier_data  (date, fourier_mode) float64 0.9999 -0.01183 ... -0.4547\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:52.614490\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:    (index: 179)\n",
             "Coordinates:\n",
             "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
             "Data variables:\n",
             "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
             "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
             "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
             "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
             "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
             "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
             "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
\n", " \n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", "\t> observed_data\n", "\t> constant_data\n", "\t> fit_data" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mmm.idata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Saving and loading a fitted model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "All the arguments passed to the model on initialization (such as `model_config` and `sampler_config`) are stored in `idata.attrs`. The `save()` method later writes these attributes, together with the fitted results and the training data (the `fit_data` group), to a netCDF file. You can read more about this format in the [ArviZ documentation](https://python.arviz.org/en/stable/getting_started/XarrayforArviZ.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `save` and `load` methods only require a path specifying where the model should be saved to or loaded from." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "mmm.save('my_saved_model.nc')" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ricardo/miniconda3/envs/pymc-marketing/lib/python3.11/site-packages/arviz/data/inference_data.py:152: UserWarning: fit_data group is not defined in the InferenceData scheme\n", " warnings.warn(\n" ] } ], "source": [ "loaded_model = DelayedSaturatedMMM.load('my_saved_model.nc')" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'sigma': array([2.1775326 , 1.14026088]), 'dims': ('channel',)}" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loaded_model.model_config[\"beta_channel\"]" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", 
"\n", "\n", "clusterdate (179) x channel (2)\n", "\n", "date (179) x channel (2)\n", "\n", "\n", "clusterdate (179)\n", "\n", "date (179)\n", "\n", "\n", "clusterchannel (2)\n", "\n", "channel (2)\n", "\n", "\n", "clusterdate (179) x control (3)\n", "\n", "date (179) x control (3)\n", "\n", "\n", "clustercontrol (3)\n", "\n", "control (3)\n", "\n", "\n", "clusterdate (179) x fourier_mode (4)\n", "\n", "date (179) x fourier_mode (4)\n", "\n", "\n", "clusterfourier_mode (4)\n", "\n", "fourier_mode (4)\n", "\n", "\n", "\n", "channel_adstock_saturated\n", "\n", "channel_adstock_saturated\n", "~\n", "Deterministic\n", "\n", "\n", "\n", "channel_contributions\n", "\n", "channel_contributions\n", "~\n", "Deterministic\n", "\n", "\n", "\n", "channel_adstock_saturated->channel_contributions\n", "\n", "\n", "\n", "\n", "\n", "channel_adstock\n", "\n", "channel_adstock\n", "~\n", "Deterministic\n", "\n", "\n", "\n", "channel_adstock->channel_adstock_saturated\n", "\n", "\n", "\n", "\n", "\n", "mu\n", "\n", "mu\n", "~\n", "Deterministic\n", "\n", "\n", "\n", "channel_contributions->mu\n", "\n", "\n", "\n", "\n", "\n", "channel_data\n", "\n", "channel_data\n", "~\n", "MutableData\n", "\n", "\n", "\n", "channel_data->channel_adstock\n", "\n", "\n", "\n", "\n", "\n", "likelihood\n", "\n", "likelihood\n", "~\n", "Normal\n", "\n", "\n", "\n", "mu->likelihood\n", "\n", "\n", "\n", "\n", "\n", "target\n", "\n", "target\n", "~\n", "MutableData\n", "\n", "\n", "\n", "likelihood->target\n", "\n", "\n", "\n", "\n", "\n", "sigma\n", "\n", "sigma\n", "~\n", "HalfNormal\n", "\n", "\n", "\n", "sigma->likelihood\n", "\n", "\n", "\n", "\n", "\n", "intercept\n", "\n", "intercept\n", "~\n", "Normal\n", "\n", "\n", "\n", "intercept->mu\n", "\n", "\n", "\n", "\n", "\n", "lam\n", "\n", "lam\n", "~\n", "Gamma\n", "\n", "\n", "\n", "lam->channel_adstock_saturated\n", "\n", "\n", "\n", "\n", "\n", "beta_channel\n", "\n", "beta_channel\n", "~\n", "HalfNormal\n", "\n", "\n", "\n", 
"beta_channel->channel_contributions\n", "\n", "\n", "\n", "\n", "\n", "alpha\n", "\n", "alpha\n", "~\n", "Beta\n", "\n", "\n", "\n", "alpha->channel_adstock\n", "\n", "\n", "\n", "\n", "\n", "control_contributions\n", "\n", "control_contributions\n", "~\n", "Deterministic\n", "\n", "\n", "\n", "control_contributions->mu\n", "\n", "\n", "\n", "\n", "\n", "control_data\n", "\n", "control_data\n", "~\n", "MutableData\n", "\n", "\n", "\n", "control_data->control_contributions\n", "\n", "\n", "\n", "\n", "\n", "gamma_control\n", "\n", "gamma_control\n", "~\n", "Normal\n", "\n", "\n", "\n", "gamma_control->control_contributions\n", "\n", "\n", "\n", "\n", "\n", "fourier_contributions\n", "\n", "fourier_contributions\n", "~\n", "Deterministic\n", "\n", "\n", "\n", "fourier_contributions->mu\n", "\n", "\n", "\n", "\n", "\n", "fourier_data\n", "\n", "fourier_data\n", "~\n", "MutableData\n", "\n", "\n", "\n", "fourier_data->fourier_contributions\n", "\n", "\n", "\n", "\n", "\n", "gamma_fourier\n", "\n", "gamma_fourier\n", "~\n", "Laplace\n", "\n", "\n", "\n", "gamma_fourier->fourier_contributions\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loaded_model.graphviz()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
arviz.InferenceData
\n", "
\n", "
    \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
             "                                fourier_mode: 4, channel: 2, date: 179)\n",
             "Coordinates:\n",
             "  * chain                      (chain) int64 0 1 2 3\n",
             "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
             "  * control                    (control) object 'event_1' 'event_2' 't'\n",
             "  * fourier_mode               (fourier_mode) object 'sin_order_1' ... 'cos_o...\n",
             "  * channel                    (channel) object 'x1' 'x2'\n",
             "  * date                       (date) object '2018-04-02' ... '2021-08-30'\n",
             "Data variables: (12/13)\n",
             "    intercept                  (chain, draw) float64 ...\n",
             "    gamma_control              (chain, draw, control) float64 ...\n",
             "    gamma_fourier              (chain, draw, fourier_mode) float64 ...\n",
             "    beta_channel               (chain, draw, channel) float64 ...\n",
             "    alpha                      (chain, draw, channel) float64 ...\n",
             "    lam                        (chain, draw, channel) float64 ...\n",
             "    ...                         ...\n",
             "    channel_adstock            (chain, draw, date, channel) float64 ...\n",
             "    channel_adstock_saturated  (chain, draw, date, channel) float64 ...\n",
             "    channel_contributions      (chain, draw, date, channel) float64 ...\n",
             "    control_contributions      (chain, draw, date, control) float64 ...\n",
             "    fourier_contributions      (chain, draw, date, fourier_mode) float64 ...\n",
             "    mu                         (chain, draw, date) float64 ...\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:52.592010\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0\n",
             "    sampling_time:              71.22048568725586\n",
             "    tuning_steps:               1000

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
             "Coordinates:\n",
             "  * chain       (chain) int64 0 1 2 3\n",
             "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
             "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "Data variables:\n",
             "    likelihood  (chain, draw, date) float64 ...\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:55.266426\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:                (chain: 4, draw: 1000)\n",
             "Coordinates:\n",
             "  * chain                  (chain) int64 0 1 2 3\n",
             "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
             "Data variables: (12/17)\n",
             "    process_time_diff      (chain, draw) float64 ...\n",
             "    lp                     (chain, draw) float64 ...\n",
             "    max_energy_error       (chain, draw) float64 ...\n",
             "    step_size              (chain, draw) float64 ...\n",
             "    reached_max_treedepth  (chain, draw) bool ...\n",
             "    perf_counter_diff      (chain, draw) float64 ...\n",
             "    ...                     ...\n",
             "    index_in_trajectory    (chain, draw) int64 ...\n",
             "    energy                 (chain, draw) float64 ...\n",
             "    largest_eigval         (chain, draw) float64 ...\n",
             "    n_steps                (chain, draw) float64 ...\n",
             "    tree_depth             (chain, draw) int64 ...\n",
             "    diverging              (chain, draw) bool ...\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:52.608151\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0\n",
             "    sampling_time:              71.22048568725586\n",
             "    tuning_steps:               1000

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:                    (chain: 1, draw: 500, channel: 2, control: 3,\n",
             "                                fourier_mode: 4, date: 179)\n",
             "Coordinates:\n",
             "  * chain                      (chain) int64 0\n",
             "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
             "  * channel                    (channel) object 'x1' 'x2'\n",
             "  * control                    (control) object 'event_1' 'event_2' 't'\n",
             "  * fourier_mode               (fourier_mode) object 'sin_order_1' ... 'cos_o...\n",
             "  * date                       (date) object '2018-04-02' ... '2021-08-30'\n",
             "Data variables: (12/13)\n",
             "    beta_channel               (chain, draw, channel) float64 ...\n",
             "    gamma_control              (chain, draw, control) float64 ...\n",
             "    gamma_fourier              (chain, draw, fourier_mode) float64 ...\n",
             "    channel_contributions      (chain, draw, date, channel) float64 ...\n",
             "    intercept                  (chain, draw) float64 ...\n",
             "    lam                        (chain, draw, channel) float64 ...\n",
             "    ...                         ...\n",
             "    mu                         (chain, draw, date) float64 ...\n",
             "    alpha                      (chain, draw, channel) float64 ...\n",
             "    channel_adstock            (chain, draw, date, channel) float64 ...\n",
             "    control_contributions      (chain, draw, date, control) float64 ...\n",
             "    channel_adstock_saturated  (chain, draw, date, channel) float64 ...\n",
             "    sigma                      (chain, draw) float64 ...\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:54.644498\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
             "Coordinates:\n",
             "  * chain       (chain) int64 0\n",
             "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
             "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "Data variables:\n",
             "    likelihood  (chain, draw, date) float64 ...\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:54.650068\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:     (date: 179)\n",
             "Coordinates:\n",
             "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "Data variables:\n",
             "    likelihood  (date) float64 ...\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:52.612711\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
             "Coordinates:\n",
             "  * date          (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "  * channel       (channel) object 'x1' 'x2'\n",
             "  * control       (control) object 'event_1' 'event_2' 't'\n",
             "  * fourier_mode  (fourier_mode) object 'sin_order_1' ... 'cos_order_2'\n",
             "Data variables:\n",
             "    channel_data  (date, channel) float64 ...\n",
             "    target        (date) float64 ...\n",
             "    control_data  (date, control) float64 ...\n",
             "    fourier_data  (date, fourier_mode) float64 ...\n",
             "Attributes:\n",
             "    created_at:                 2023-08-03T12:44:52.614490\n",
             "    arviz_version:              0.15.1\n",
             "    inference_library:          pymc\n",
             "    inference_library_version:  5.7.0

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:    (index: 179)\n",
             "Coordinates:\n",
             "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
             "Data variables:\n",
             "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
             "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
             "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
             "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
             "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
             "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
             "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
             "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
\n", "
\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", "\t> observed_data\n", "\t> constant_data\n", "\t> fit_data" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loaded_model.idata" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A loaded model is ready to be used for sampling and prediction, making use of the previous fitting results and data if needed." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Sampling: [likelihood]\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n", " \n", " 100.00% [4000/4000 00:01<00:00]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.Dataset>\n",
       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
       "Coordinates:\n",
       "  * chain       (chain) int64 0 1 2 3\n",
       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
       "Data variables:\n",
       "    likelihood  (chain, draw, date) float64 0.5057 0.4536 ... 0.5621 0.5581\n",
       "Attributes:\n",
       "    created_at:                 2023-08-03T12:45:03.555645\n",
       "    arviz_version:              0.15.1\n",
       "    inference_library:          pymc\n",
       "    inference_library_version:  5.7.0
" ], "text/plain": [ "\n", "Dimensions: (chain: 4, draw: 1000, date: 179)\n", "Coordinates:\n", " * chain (chain) int64 0 1 2 3\n", " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n", " * date (date) " ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "az.plot_ppc(loaded_model.idata);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Other models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Although this introduction uses `DelayedSaturatedMMM`, all other PyMC-Marketing models (both MMM and CLV) provide the same functionality." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The PyMC-Marketing functionality described here is intended to make it easier to share models across data science teams without requiring everyone involved to have extensive modelling expertise. We are still iterating on our API and would love to hear feedback from our users!" ] } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "pymc-marketing", "language": "python", "name": "pymc-marketing" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 5 }