Other NUTS Samplers

In this notebook we show how to fit a CLV model with alternative NUTS samplers. These samplers can be significantly faster than PyMC's default NUTS sampler, and the JAX-based ones (numpyro and blackjax) can also sample on a GPU.

Note

You need to install these packages (numpyro, blackjax, and nutpie) in your Python environment.
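Before calling fit, it can help to check which of the optional samplers are importable. Below is a minimal sketch that uses only the standard library's `importlib.util.find_spec`. The helper name `available_nuts_samplers` is hypothetical, and the sketch assumes that the import names `numpyro`, `blackjax`, and `nutpie` match the installed packages.

```python
import importlib.util

# Import names of the optional NUTS backends used in this notebook.
OPTIONAL_SAMPLERS = ("numpyro", "blackjax", "nutpie")


def available_nuts_samplers(candidates=OPTIONAL_SAMPLERS):
    """Map each sampler name to whether its module can be found.

    Hypothetical helper: it only checks that the top-level module is
    findable. It neither imports the module nor verifies its version.
    """
    return {name: importlib.util.find_spec(name) is not None for name in candidates}


if __name__ == "__main__":
    for name, ok in available_nuts_samplers().items():
        print(f"{name:9s}: {'installed' if ok else 'missing'}")
```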

Tip

You can pass the exact same nuts_sampler argument to the MMM models.

For the purpose of illustration, we will use the same data and model as in the other CLV notebooks.

import arviz as az
import matplotlib.pyplot as plt
from lifetimes.datasets import load_cdnow_summary

from pymc_marketing import clv

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

df = (
    load_cdnow_summary(index_col=[0])
    .reset_index()
    .rename(columns={"ID": "customer_id"})
)
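A quick sanity check confirms that the renamed summary table has the columns the model reads. This is a minimal sketch that assumes `clv.BetaGeoModel` expects the columns `customer_id`, `frequency`, `recency`, and `T`, which are the names produced by the cell above. Check the API reference for your pymc-marketing version. The helper `missing_clv_columns` is hypothetical.

```python
# Columns assumed to be required by clv.BetaGeoModel (an assumption; verify
# against the documentation of your pymc-marketing version).
REQUIRED_COLUMNS = ("customer_id", "frequency", "recency", "T")


def missing_clv_columns(columns, required=REQUIRED_COLUMNS):
    """Return the required column names absent from ``columns``, in order."""
    present = set(columns)
    return [col for col in required if col not in present]
```

With the data above, `missing_clv_columns(df.columns)` should return an empty list. If the `rename` step were skipped, it would report `customer_id` as missing, because the index column is called `ID`.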

We can pass the keyword argument nuts_sampler to the fit method of the CLV model to specify which NUTS sampler to use. Any additional keyword arguments are forwarded to the pymc.sample method through the model builder layer. For example, we can use the numpyro sampler as follows:

sampler_kwargs = {
    "draws": 2_000,
    "target_accept": 0.9,
    "chains": 5,
    "random_seed": 42,
}

model = clv.BetaGeoModel(data=df)
idata_numpyro = model.fit(nuts_sampler="numpyro", **sampler_kwargs)
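The keyword forwarding described above can be illustrated with a toy class. `ToyModelBuilder` and its injected `sample_fn` are hypothetical and are not the pymc-marketing implementation. They only show how `nuts_sampler` and the remaining settings (`draws`, `chains`, `target_accept`, `random_seed`) can travel unchanged to a `pymc.sample`-like call. With the settings above, each parameter ends up with 2,000 × 5 = 10,000 retained posterior draws across all chains.

```python
from typing import Any, Callable


class ToyModelBuilder:
    """Hypothetical stand-in for a model builder's ``fit`` method."""

    def __init__(self, sample_fn: Callable[..., Any]):
        # In a real model builder this role is played by pymc.sample;
        # here the sampling function is injected so it can be replaced.
        self._sample = sample_fn

    def fit(self, nuts_sampler: str = "pymc", **sampler_kwargs: Any) -> Any:
        # The sampler choice and every extra keyword are passed through as-is.
        return self._sample(nuts_sampler=nuts_sampler, **sampler_kwargs)


def total_posterior_draws(draws: int, chains: int) -> int:
    """Number of retained posterior draws per parameter across all chains."""
    return draws * chains


if __name__ == "__main__":
    captured = {}

    def fake_sample(**kwargs):
        # Record what would have reached the sampler.
        captured.update(kwargs)
        return "idata-placeholder"

    builder = ToyModelBuilder(fake_sample)
    builder.fit(nuts_sampler="numpyro", draws=2_000, chains=5, target_accept=0.9, random_seed=42)
    print(captured)
    print(total_posterior_draws(2_000, 5))  # 10000
```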

Similarly, we can use the blackjax sampler as:

idata_blackjax = model.fit(nuts_sampler="blackjax", **sampler_kwargs)

Finally, we can use nutpie, a Rust implementation of NUTS:

idata_nutpie = model.fit(nuts_sampler="nutpie", **sampler_kwargs)
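Because the three calls differ only in the `nuts_sampler` string, they can also be run in a loop that records wall-clock time per backend. This is a minimal sketch that assumes a `fit` callable with the same keyword interface as `model.fit` above. The helper `fit_with_each_sampler` is hypothetical. The measured times may include one-off compilation overhead, so treat them as a rough indication rather than a benchmark.

```python
import time
from typing import Any, Callable, Dict, Iterable, Tuple


def fit_with_each_sampler(
    fit: Callable[..., Any],
    samplers: Iterable[str] = ("numpyro", "blackjax", "nutpie"),
    **sampler_kwargs: Any,
) -> Dict[str, Tuple[Any, float]]:
    """Fit once per sampler and return {sampler: (result, elapsed_seconds)}.

    Elapsed time is wall-clock time from time.perf_counter and may include
    compilation overhead on the first call.
    """
    results = {}
    for name in samplers:
        start = time.perf_counter()
        result = fit(nuts_sampler=name, **sampler_kwargs)
        results[name] = (result, time.perf_counter() - start)
    return results
```

Under these assumptions, `runs = fit_with_each_sampler(model.fit, **sampler_kwargs)` would give `runs["nutpie"][0]` as the nutpie result and `runs["nutpie"][1]` as its elapsed time in seconds.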

The posterior distributions obtained with the three samplers are almost identical:

fig, axes = plt.subplots(
    nrows=2, ncols=2, figsize=(12, 8), sharex=False, sharey=False, layout="constrained"
)

axes = axes.ravel()

for i, var_name in enumerate(["a", "b", "alpha", "r"]):
    for j, (idata, label) in enumerate(
        zip(
            [idata_blackjax, idata_nutpie, idata_numpyro],
            ["blackjax", "nutpie", "numpyro"], strict=False,
        )
    ):
        az.plot_posterior(
            data=idata,
            var_names=[var_name],
            color=f"C{j}",
            point_estimate=None,
            hdi_prob="hide",
            label=label,
            ax=axes[i],
        )

fig.suptitle(
    "Posterior distributions of model parameters",
    fontsize=18,
    fontweight="bold",
    y=1.05,
);
[Figure: Posterior distributions of model parameters a, b, alpha, and r for the blackjax, nutpie, and numpyro samplers]
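Beyond the visual comparison, the agreement between samplers can be quantified by comparing posterior means relative to the posterior standard deviation. Below is a minimal numpy-only sketch. The helper `max_standardized_mean_gap` is hypothetical. It assumes that each sampler's draws for one parameter have been flattened into a 1-D array, for example with `idata.posterior["a"].values.ravel()`. The draws in the usage block are synthetic and are not results from this notebook.

```python
import numpy as np


def max_standardized_mean_gap(draws_by_sampler):
    """Largest pairwise gap in posterior means, in units of pooled posterior sd.

    ``draws_by_sampler`` maps a sampler name to the draws of a single
    parameter. Smaller values mean closer agreement between samplers.
    """
    arrays = [np.asarray(d, dtype=float).ravel() for d in draws_by_sampler.values()]
    means = np.array([a.mean() for a in arrays])
    # Pool the per-sampler variances so the gap is scale-free.
    pooled_sd = np.sqrt(np.mean([a.var(ddof=1) for a in arrays]))
    return float((means.max() - means.min()) / pooled_sd)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Synthetic draws standing in for three samplers' posteriors of one parameter.
    fake = {name: rng.normal(0.8, 0.1, size=10_000) for name in ("blackjax", "nutpie", "numpyro")}
    print(max_standardized_mean_gap(fake))
```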
%load_ext watermark
%watermark -n -u -v -iv -w -p blackjax,numpyro,nutpie,pymc
Last updated: Sat Mar 09 2024

Python implementation: CPython
Python version       : 3.11.3
IPython version      : 8.20.0

blackjax: 0.0.0
numpyro : 0.14.0
nutpie  : 0.9.2
pymc    : 5.10.4

arviz         : 0.15.1
matplotlib    : 3.7.1
pymc_marketing: 0.4.0

Watermark: 2.4.3