Other NUTS Samplers
In this notebook we show how to fit a CLV model with alternative NUTS samplers. These samplers can be significantly faster than PyMC's default NUTS implementation, and the JAX-based ones (numpyro and blackjax) can also run on a GPU.
Note
You need to install these packages (numpyro, blackjax and nutpie) in your Python environment.
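The sketch below checks which of the optional sampler packages can be imported before you call fit. It assumes the import names are numpyro, blackjax and nutpie, matching the sampler names used in this notebook. available_samplers is a hypothetical helper and is not part of pymc-marketing.

```python
import importlib.util

# Import names of the alternative NUTS backends used in this notebook (assumed).
OPTIONAL_SAMPLERS = ("numpyro", "blackjax", "nutpie")


def available_samplers(candidates=OPTIONAL_SAMPLERS):
    """Return the subset of `candidates` whose top-level module can be found.

    Hypothetical helper: it only checks that each module is importable,
    not that it is compatible with the installed PyMC version.
    """
    return [name for name in candidates if importlib.util.find_spec(name) is not None]


print(available_samplers())
```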
Tip
You can pass the exact same nuts_sampler argument to the MMM models.
For the purpose of illustration, we will use the same data and model as in the other CLV notebooks.
import arviz as az
import matplotlib.pyplot as plt
from lifetimes.datasets import load_cdnow_summary
from pymc_marketing import clv
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
df = (
    load_cdnow_summary(index_col=[0])
    .reset_index()
    .rename(columns={"ID": "customer_id"})
)
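The rename to customer_id matters because the CLV models identify customers by that column name. The following pre-fit check is a minimal sketch. It assumes that BetaGeoModel needs the columns customer_id, frequency, recency and T; the CDNOW summary provides the last three. Check the pymc-marketing documentation for your installed version. check_clv_columns is a hypothetical helper.

```python
import pandas as pd

# Assumed column requirements for BetaGeoModel (verify against your pymc-marketing version).
REQUIRED_COLUMNS = ("customer_id", "frequency", "recency", "T")


def check_clv_columns(data: pd.DataFrame, required=REQUIRED_COLUMNS) -> None:
    """Raise a ValueError naming any required columns that are missing from `data`."""
    missing = [col for col in required if col not in data.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")


# Usage with the data above: check_clv_columns(df)
```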
We can pass the keyword argument nuts_sampler to the fit method of a CLV model to choose the NUTS sampler. Any other keyword arguments are forwarded to pymc.sample through the model builder layer. For example, we can use the numpyro sampler as:
sampler_kwargs = {
    "draws": 2_000,
    "target_accept": 0.9,
    "chains": 5,
    "random_seed": 42,
}
model = clv.BetaGeoModel(data=df)
idata_numpyro = model.fit(nuts_sampler="numpyro", **sampler_kwargs)
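The next sketch makes the model builder forwarding concrete with a simplified, hypothetical version of the pattern. Here fit merges stored defaults with the call-time keyword arguments and passes everything, including nuts_sampler, to a sampling function. ToyModelBuilder and fake_sample are illustrative stand-ins, not pymc-marketing's actual classes. In the real library the final call is pymc.sample, which itself accepts a nuts_sampler argument.

```python
from typing import Any, Callable, Optional


def fake_sample(**kwargs: Any) -> dict:
    """Stand-in for pymc.sample: returns the arguments it received."""
    return kwargs


class ToyModelBuilder:
    """Hypothetical, simplified model-builder pattern (not the real pymc-marketing code)."""

    def __init__(
        self,
        sample_fn: Callable[..., Any] = fake_sample,
        sampler_config: Optional[dict] = None,
    ):
        self.sample_fn = sample_fn
        # Defaults stored on the model; keyword arguments passed to fit take precedence.
        self.sampler_config = dict(sampler_config or {})

    def fit(self, nuts_sampler: str = "pymc", **kwargs: Any) -> Any:
        config = {**self.sampler_config, **kwargs, "nuts_sampler": nuts_sampler}
        return self.sample_fn(**config)


toy = ToyModelBuilder(sampler_config={"draws": 1_000, "chains": 4})
print(toy.fit(nuts_sampler="numpyro", draws=2_000, target_accept=0.9))
# {'draws': 2000, 'chains': 4, 'target_accept': 0.9, 'nuts_sampler': 'numpyro'}
```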
Similarly, we can use the blackjax sampler as:
idata_blackjax = model.fit(nuts_sampler="blackjax", **sampler_kwargs)
Finally, we can use nutpie, a Rust implementation of NUTS:
idata_nutpie = model.fit(nuts_sampler="nutpie", **sampler_kwargs)
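You can check the claim that these samplers can be faster on your own hardware. The sketch below assumes that wall-clock time per sampler is what you want to compare. time_samplers is a hypothetical helper that calls any fit-like callable once per backend. The first call to a JAX-based sampler includes compilation time, so a single run can overstate its cost.

```python
import time
from typing import Any, Callable, Dict, Iterable, Tuple


def time_samplers(
    fit: Callable[..., Any], samplers: Iterable[str], **fit_kwargs: Any
) -> Dict[str, Tuple[Any, float]]:
    """Call fit(nuts_sampler=name, **fit_kwargs) for each sampler.

    Returns a mapping from sampler name to (fit result, elapsed wall-clock seconds).
    """
    results: Dict[str, Tuple[Any, float]] = {}
    for name in samplers:
        start = time.perf_counter()
        output = fit(nuts_sampler=name, **fit_kwargs)
        results[name] = (output, time.perf_counter() - start)
    return results


# Usage with the model above (not executed here):
# timings = time_samplers(model.fit, ["numpyro", "blackjax", "nutpie"], **sampler_kwargs)
# print({name: round(seconds, 1) for name, (_, seconds) in timings.items()})
```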
The posterior distributions from the three samplers are almost identical:
fig, axes = plt.subplots(
    nrows=2, ncols=2, figsize=(12, 8), sharex=False, sharey=False, layout="constrained"
)
axes = axes.ravel()
# One panel per model parameter, with one posterior curve per sampler.
for i, var_name in enumerate(["a", "b", "alpha", "r"]):
    for j, (idata, label) in enumerate(
        zip(
            [idata_blackjax, idata_nutpie, idata_numpyro],
            ["blackjax", "nutpie", "numpyro"],
            strict=False,
        )
    ):
        az.plot_posterior(
            data=idata,
            var_names=[var_name],
            color=f"C{j}",
            point_estimate=None,
            hdi_prob="hide",
            label=label,
            ax=axes[i],
        )
fig.suptitle(
    "Posterior distributions of model parameters",
    fontsize=18,
    fontweight="bold",
    y=1.05,
);
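The plots give a visual check. The sketch below adds a numeric check by comparing posterior means across samplers. It assumes each result exposes idata.posterior[var].values, as ArviZ InferenceData objects do. compare_posterior_means is a hypothetical helper. az.summary is a fuller alternative that also reports standard deviations, HDIs and convergence diagnostics.

```python
from typing import Any, Dict, Iterable

import numpy as np


def compare_posterior_means(
    idatas: Dict[str, Any], var_names: Iterable[str]
) -> Dict[str, Dict[str, float]]:
    """For each variable, return its posterior mean under each sampler and the largest gap between them."""
    table: Dict[str, Dict[str, float]] = {}
    for var in var_names:
        # Average over all chains and draws in the posterior group.
        means = {label: float(np.mean(idata.posterior[var].values)) for label, idata in idatas.items()}
        table[var] = {**means, "max_abs_diff": max(means.values()) - min(means.values())}
    return table


# Usage with the fits above (not executed here):
# compare_posterior_means(
#     {"blackjax": idata_blackjax, "nutpie": idata_nutpie, "numpyro": idata_numpyro},
#     ["a", "b", "alpha", "r"],
# )
```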
%load_ext watermark
%watermark -n -u -v -iv -w -p blackjax,numpyro,nutpie,pymc
Last updated: Sat Mar 09 2024
Python implementation: CPython
Python version : 3.11.3
IPython version : 8.20.0
blackjax: 0.0.0
numpyro : 0.14.0
nutpie : 0.9.2
pymc : 5.10.4
arviz : 0.15.1
matplotlib : 3.7.1
pymc_marketing: 0.4.0
Watermark: 2.4.3