Otros muestreadores NUTS#
En este cuaderno mostramos cómo ajustar un modelo de CLV con otros muestreadores NUTS. Estos muestreadores alternativos pueden ser significativamente más rápidos y también muestrear en la GPU.
Nota
Necesita instalar estos paquetes en su entorno de Python.
Truco
Puede pasar el mismo argumento nuts_sampler exactamente a los modelos MMM.
Truco
El soporte de GPU solo funciona con selectos muestreadores en PyMC que utilizan el backend de JAX. Estos muestreadores incluyen numpyro, blackjax y nutpie.
Asegúrese de que la GPU esté registrada, siga las instrucciones aquí.
A efectos de ilustración, utilizaremos los mismos datos y modelo que en los otros cuadernos de CLV.
import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
from pymc_marketing import clv
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
df = pd.read_csv(
"https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/clv_quickstart.csv"
)
df["customer_id"] = range(len(df))
Podemos pasar el argumento de palabra clave nuts_sampler al método fit del modelo CLV para especificar el muestreador NUTS a utilizar. Además, podemos pasar argumentos de palabra clave adicionales que se transmitirán al método pymc.sample a través de la capa del constructor del modelo. Por ejemplo, podemos utilizar el muestreador numpyro de la siguiente manera:
sampler_kwargs = {
"draws": 2_000,
"target_accept": 0.9,
"chains": 5,
"random_seed": 42,
}
model = clv.BetaGeoModel(data=df)
idata_numpyro = model.fit(nuts_sampler="numpyro", **sampler_kwargs)
De manera similar, podemos utilizar el muestreador blackjax de la siguiente manera:
idata_blackjax = model.fit(nuts_sampler="blackjax", **sampler_kwargs)
Finalmente, podemos utilizar nutpie, que es una implementación en Rust de NUTS.
idata_nutpie = model.fit(nuts_sampler="nutpie", **sampler_kwargs)
Los resultados de los muestreadores son casi idénticos:
%load_ext watermark
%watermark -n -u -v -iv -w -p blackjax,numpyro,nutpie,pymc
Last updated: Sat Mar 09 2024
Python implementation: CPython
Python version : 3.11.3
IPython version : 8.20.0
blackjax: 0.0.0
numpyro : 0.14.0
nutpie : 0.9.2
pymc : 5.10.4
arviz : 0.15.1
matplotlib : 3.7.1
pymc_marketing: 0.4.0
Watermark: 2.4.3