Otros muestreadores NUTS#

En este cuaderno mostramos cómo ajustar un modelo de CLV con otros muestreadores NUTS. Estos muestreadores alternativos pueden ser significativamente más rápidos y también muestrear en la GPU.

Nota

Necesita instalar estos paquetes en su entorno de Python.

Truco

Puede pasar el mismo argumento nuts_sampler exactamente a los modelos MMM.

Truco

El soporte de GPU solo funciona con selectos muestreadores en PyMC que utilizan el backend de JAX. Estos muestreadores incluyen numpyro, blackjax y nutpie.

Asegúrese de que la GPU esté registrada, siga las instrucciones aquí.

A efectos de ilustración, utilizaremos los mismos datos y modelo que en los otros cuadernos de CLV.

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd

from pymc_marketing import clv

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
df = pd.read_csv(
    "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/clv_quickstart.csv"
)
df["customer_id"] = range(len(df))

Podemos pasar el argumento de palabra clave nuts_sampler al método fit del modelo CLV para especificar el muestreador NUTS a utilizar. Además, podemos pasar argumentos de palabra clave adicionales que se transmitirán al método pymc.sample a través de la capa del constructor del modelo. Por ejemplo, podemos utilizar el muestreador numpyro de la siguiente manera:

sampler_kwargs = {
    "draws": 2_000,
    "target_accept": 0.9,
    "chains": 5,
    "random_seed": 42,
}

model = clv.BetaGeoModel(data=df)
idata_numpyro = model.fit(nuts_sampler="numpyro", **sampler_kwargs)

De manera similar, podemos utilizar el muestreador blackjax de la siguiente manera:

idata_blackjax = model.fit(nuts_sampler="blackjax", **sampler_kwargs)

Finalmente, podemos utilizar nutpie, que es una implementación en Rust de NUTS.

idata_nutpie = model.fit(nuts_sampler="nutpie", **sampler_kwargs)

Los resultados de los muestreadores son casi idénticos:

Hide code cell source

fig, axes = plt.subplots(
    nrows=2, ncols=2, figsize=(12, 8), sharex=False, sharey=False, layout="constrained"
)

axes = axes.ravel()

for i, var_name in enumerate(["a", "b", "alpha", "r"]):
    for j, (idata, label) in enumerate(
        zip(
            [idata_blackjax, idata_nutpie, idata_numpyro],
            ["blackjax", "nutpie", "numpyro"],
            strict=False,
        )
    ):
        az.plot_posterior(
            data=idata,
            var_names=[var_name],
            color=f"C{j}",
            point_estimate=None,
            hdi_prob="hide",
            label=label,
            ax=axes[i],
        )

fig.suptitle(
    "Posterior istributions of model parameters",
    fontsize=18,
    fontweight="bold",
    y=1.05,
);
../../_images/7af5a613b6f6ca1b2167b16dd6b14f2784d8baa78c9ac89f47916d4e2b30e4d1.png
%load_ext watermark
%watermark -n -u -v -iv -w -p blackjax,numpyro,nutpie,pymc
Last updated: Sat Mar 09 2024

Python implementation: CPython
Python version       : 3.11.3
IPython version      : 8.20.0

blackjax: 0.0.0
numpyro : 0.14.0
nutpie  : 0.9.2
pymc    : 5.10.4

arviz         : 0.15.1
matplotlib    : 3.7.1
pymc_marketing: 0.4.0

Watermark: 2.4.3