import numpy as np
import nutpie
import pandas as pd
from lifetimes import BetaGeoBetaBinomFitter
import pymc as pm
from pymc_marketing.clv import BetaGeoBetaBinomModel
from pymc_extras.prior import Prior
# Load the donations dataset used throughout this notebook. Per the preview
# further down, it has one row per customer with columns customer_id,
# frequency, recency and T.
bgbb_donations_url = (
    "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/"
    "main/data/bgbb_donations.csv"
)
data = pd.read_csv(bgbb_donations_url)
data.head()
# Put an improper HalfFlat prior on each of the four BG/BB parameters. With
# flat priors the MAP estimate maximizes the likelihood alone, so the fit
# should line up with the lifetimes estimates compared further below.
model_config = {
    param: Prior("HalfFlat") for param in ("alpha", "beta", "gamma", "delta")
}
model = BetaGeoBetaBinomModel(data=data, model_config=model_config)
model.build_model()
model
BG/BB
alpha ~ HalfFlat()
beta ~ HalfFlat()
gamma ~ HalfFlat()
delta ~ HalfFlat()
recency_frequency ~ BetaGeoBetaBinom(alpha, beta, gamma, delta, <constant>)
# Render the model graph, then fit point estimates by MAP optimization.
pm.model_to_graphviz(model=model.model)
model.fit(fit_method="map")
arviz.InferenceData
-
<xarray.Dataset> Size: 48B Dimensions: (chain: 1, draw: 1) Coordinates: * chain (chain) int64 8B 0 * draw (draw) int64 8B 0 Data variables: alpha (chain, draw) float64 8B 1.204 beta (chain, draw) float64 8B 0.7497 delta (chain, draw) float64 8B 2.784 gamma (chain, draw) float64 8B 0.6568 Attributes: created_at: 2024-09-12T21:32:31.746938+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.13.0 -
<xarray.Dataset> Size: 267kB Dimensions: (customer_id: 11104, obs_var: 2) Coordinates: * customer_id (customer_id) int64 89kB 0 1 2 3 ... 11101 11102 11103 * obs_var (obs_var) <U9 72B 'recency' 'frequency' Data variables: recency_frequency (customer_id, obs_var) float64 178kB 0.0 0.0 ... 6.0 6.0 Attributes: created_at: 2024-09-12T21:32:31.749978+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.13.0 -
<xarray.Dataset> Size: 444kB Dimensions: (index: 11104) Coordinates: * index (index) int64 89kB 0 1 2 3 4 ... 11099 11100 11101 11102 11103 Data variables: customer_id (index) int64 89kB 0 1 2 3 4 ... 11099 11100 11101 11102 11103 frequency (index) int64 89kB 0 0 0 0 0 0 0 0 0 0 ... 6 6 6 6 6 6 6 6 6 6 recency (index) int64 89kB 0 0 0 0 0 0 0 0 0 0 ... 6 6 6 6 6 6 6 6 6 6 T (index) int64 89kB 6 6 6 6 6 6 6 6 6 6 ... 6 6 6 6 6 6 6 6 6 6
# Show the fitted MAP point estimates (alpha, beta, delta, gamma) as a table.
model.fit_summary()
alpha 1.204
beta 0.750
delta 2.784
gamma 0.657
Name: value, dtype: float64
Compare with the `lifetimes` library
# Cross-check: fit the same data with lifetimes' BetaGeoBetaBinomFitter.
# fit() takes frequency, recency and the number of periods (the T column)
# positionally, in that order.
bgbb_inputs = [data[col].to_numpy() for col in ("frequency", "recency", "T")]
bgbb = BetaGeoBetaBinomFitter().fit(*bgbb_inputs)
bgbb
<lifetimes.BetaGeoBetaBinomFitter: fitted with 22 subjects, alpha: 1.20, beta: 0.75, delta: 2.78, gamma: 0.66>
# Second model using the library's default priors. The sampling log below
# lists kappa_purchase, phi_purchase, kappa_dropout and phi_dropout as the
# sampled hyperparameters.
mcmc_model = BetaGeoBetaBinomModel(data=data)
mcmc_model.build_model()
pm.model_to_graphviz(mcmc_model.model)
# Seed the prior-predictive draws so re-running the notebook reproduces
# the same output.
with mcmc_model.model:
    prior_idata = pm.sample_prior_predictive(random_seed=42)
Sampling: [kappa_dropout, kappa_purchase, phi_dropout, phi_purchase, recency_frequency]
# Display the InferenceData from the prior-predictive call above: parameter
# draws, simulated recency_frequency values, and the observed data.
prior_idata
arviz.InferenceData
-
<xarray.Dataset> Size: 36kB Dimensions: (chain: 1, draw: 500) Coordinates: * chain (chain) int64 8B 0 * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 Data variables: alpha (chain, draw) float64 4kB 4.591 0.6012 ... 0.4558 1.681 beta (chain, draw) float64 4kB 0.2142 4.733 ... 6.893 0.1661 delta (chain, draw) float64 4kB 1.696 1.264 229.4 ... 0.1884 1.227 gamma (chain, draw) float64 4kB 0.8744 0.3572 ... 0.9208 0.9806 kappa_dropout (chain, draw) float64 4kB 2.571 1.621 263.1 ... 1.109 2.207 kappa_purchase (chain, draw) float64 4kB 4.805 5.335 3.053 ... 7.348 1.847 phi_dropout (chain, draw) float64 4kB 0.3401 0.2203 ... 0.8302 0.4443 phi_purchase (chain, draw) float64 4kB 0.9554 0.1127 ... 0.06203 0.9101 Attributes: created_at: 2024-09-12T21:43:04.658373+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.13.0 -
<xarray.Dataset> Size: 89MB Dimensions: (chain: 1, draw: 500, customer_id: 11104, obs_var: 2) Coordinates: * chain (chain) int64 8B 0 * draw (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499 * customer_id (customer_id) int64 89kB 0 1 2 3 ... 11101 11102 11103 * obs_var (obs_var) <U9 72B 'recency' 'frequency' Data variables: recency_frequency (chain, draw, customer_id, obs_var) float64 89MB 1.0 .... Attributes: created_at: 2024-09-12T21:43:04.660348+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.13.0 -
<xarray.Dataset> Size: 267kB Dimensions: (customer_id: 11104, obs_var: 2) Coordinates: * customer_id (customer_id) int64 89kB 0 1 2 3 ... 11101 11102 11103 * obs_var (obs_var) <U9 72B 'recency' 'frequency' Data variables: recency_frequency (customer_id, obs_var) float64 178kB 0.0 0.0 ... 6.0 6.0 Attributes: created_at: 2024-09-12T21:43:04.661115+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.13.0
nutpie is approximately 3 times faster than PyMC's default NUTS sampler.
import warnings

# Fit the full model with the nutpie NUTS sampler.
# Compiling the model for nutpie emits repeated NumbaWarnings of the form
# 'Cannot cache compiled function "..." as it uses dynamic globals', meaning
# the compiled function is not cached to disk. Silence only that message, and
# only for this call, so other warnings still surface.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message="Cannot cache compiled function")
    # Seed the sampler so the posterior draws are reproducible.
    idata = mcmc_model.fit(nuts_sampler="nutpie", random_seed=42)
idata
/Users/coltallen/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/nutpie/compile_pymc.py:554: NumbaWarning: Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals (such as ctypes pointers and large global arrays)
return inner(x)
/Users/coltallen/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/nutpie/compile_pymc.py:554: NumbaWarning: Cannot cache compiled function "scan" as it uses dynamic globals (such as ctypes pointers and large global arrays)
return inner(x)
/Users/coltallen/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/nutpie/compile_pymc.py:554: NumbaWarning: Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals (such as ctypes pointers and large global arrays)
return inner(x)
Sampler Progress
Total Chains: 4
Active Chains: 0
Finished Chains: 4
Sampling for 6 minutes
Estimated Time to Completion: now
| Progress | Draws | Divergences | Step Size | Gradients/Draw |
|---|---|---|---|---|
| 2000 | 0 | 0.68 | 7 | |
| 2000 | 0 | 0.69 | 3 | |
| 2000 | 0 | 0.66 | 3 | |
| 2000 | 0 | 0.64 | 7 |
arviz.InferenceData
-
<xarray.Dataset> Size: 392kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 ... 996 997 998 999 Data variables: alpha (chain, draw) float64 32kB 1.32 1.196 ... 1.15 beta (chain, draw) float64 32kB 0.7781 ... 0.7317 delta (chain, draw) float64 32kB 2.402 2.853 ... 2.966 gamma (chain, draw) float64 32kB 0.6269 ... 0.6927 kappa_dropout (chain, draw) float64 32kB 3.029 3.539 ... 3.658 kappa_dropout_interval__ (chain, draw) float64 32kB 0.7075 ... 0.9777 kappa_purchase (chain, draw) float64 32kB 2.098 1.946 ... 1.882 kappa_purchase_interval__ (chain, draw) float64 32kB 0.09346 ... -0.1257 phi_dropout (chain, draw) float64 32kB 0.207 0.194 ... 0.1893 phi_dropout_interval__ (chain, draw) float64 32kB -1.343 ... -1.454 phi_purchase (chain, draw) float64 32kB 0.6291 ... 0.6112 phi_purchase_interval__ (chain, draw) float64 32kB 0.5284 ... 0.4523 Attributes: created_at: 2024-09-12T21:27:00.667319+00:00 arviz_version: 0.18.0 inference_library: nutpie inference_library_version: 0.13.2 sampling_time: 371.8816521167755 -
<xarray.Dataset> Size: 336kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999 Data variables: depth (chain, draw) uint64 32kB 3 3 3 3 3 3 ... 2 3 3 2 2 3 diverging (chain, draw) bool 4kB False False ... False False energy (chain, draw) float64 32kB 3.324e+04 ... 3.323e+04 energy_error (chain, draw) float64 32kB 0.637 -0.1158 ... 0.0765 index_in_trajectory (chain, draw) int64 32kB -5 4 3 4 -5 ... 4 -3 -3 -1 5 logp (chain, draw) float64 32kB -3.324e+04 ... -3.323e+04 maxdepth_reached (chain, draw) bool 4kB False False ... False False mean_tree_accept (chain, draw) float64 32kB 0.9861 0.724 ... 1.0 0.9373 mean_tree_accept_sym (chain, draw) float64 32kB 0.9656 0.8221 ... 0.9368 n_steps (chain, draw) uint64 32kB 7 11 15 11 7 ... 3 7 11 3 7 step_size (chain, draw) float64 32kB 0.6836 0.6836 ... 0.6418 step_size_bar (chain, draw) float64 32kB 0.6836 0.6836 ... 0.6418 Attributes: created_at: 2024-09-12T21:27:00.659509+00:00 arviz_version: 0.18.0 -
<xarray.Dataset> Size: 267kB Dimensions: (customer_id: 11104, obs_var: 2) Coordinates: * customer_id (customer_id) int64 89kB 0 1 2 3 ... 11101 11102 11103 * obs_var (obs_var) <U9 72B 'recency' 'frequency' Data variables: recency_frequency (customer_id, obs_var) float64 178kB 0.0 0.0 ... 6.0 6.0 Attributes: created_at: 2024-09-12T21:27:00.666769+00:00 arviz_version: 0.18.0 inference_library: pymc inference_library_version: 5.13.0 -
<xarray.Dataset> Size: 444kB Dimensions: (index: 11104) Coordinates: * index (index) int64 89kB 0 1 2 3 4 ... 11099 11100 11101 11102 11103 Data variables: customer_id (index) int64 89kB 0 1 2 3 4 ... 11099 11100 11101 11102 11103 frequency (index) int64 89kB 0 0 0 0 0 0 0 0 0 0 ... 6 6 6 6 6 6 6 6 6 6 recency (index) int64 89kB 0 0 0 0 0 0 0 0 0 0 ... 6 6 6 6 6 6 6 6 6 6 T (index) int64 89kB 6 6 6 6 6 6 6 6 6 6 ... 6 6 6 6 6 6 6 6 6 6 -
<xarray.Dataset> Size: 392kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 ... 996 997 998 999 Data variables: alpha (chain, draw) float64 32kB 1.217 1.217 ... 1.171 beta (chain, draw) float64 32kB 0.2809 ... 0.7511 delta (chain, draw) float64 32kB 1.244 1.244 ... 3.435 gamma (chain, draw) float64 32kB 1.931 1.931 ... 0.7753 kappa_dropout (chain, draw) float64 32kB 3.175 3.175 ... 4.211 kappa_dropout_interval__ (chain, draw) float64 32kB 0.7772 ... 1.166 kappa_purchase (chain, draw) float64 32kB 1.498 1.498 ... 1.922 kappa_purchase_interval__ (chain, draw) float64 32kB -0.6972 ... -0.08079 phi_dropout (chain, draw) float64 32kB 0.6081 ... 0.1841 phi_dropout_interval__ (chain, draw) float64 32kB 0.4393 ... -1.489 phi_purchase (chain, draw) float64 32kB 0.8125 ... 0.6093 phi_purchase_interval__ (chain, draw) float64 32kB 1.466 1.466 ... 0.4444 Attributes: created_at: 2024-09-12T21:27:00.655861+00:00 arviz_version: 0.18.0 -
<xarray.Dataset> Size: 336kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999 Data variables: depth (chain, draw) uint64 32kB 0 0 2 1 0 0 ... 3 2 3 2 3 3 diverging (chain, draw) bool 4kB True True False ... False False energy (chain, draw) float64 32kB 4.085e+04 ... 3.324e+04 energy_error (chain, draw) float64 32kB 0.0 0.0 ... -0.0498 0.6378 index_in_trajectory (chain, draw) int64 32kB 0 0 2 -1 0 ... -3 -4 -3 4 -5 logp (chain, draw) float64 32kB -4.084e+04 ... -3.323e+04 maxdepth_reached (chain, draw) bool 4kB False False ... False False mean_tree_accept (chain, draw) float64 32kB 0.0 0.0 ... 0.9126 0.867 mean_tree_accept_sym (chain, draw) float64 32kB 0.0 0.0 ... 0.9539 0.918 n_steps (chain, draw) uint64 32kB 0 1 1 3 1 1 ... 15 7 11 3 15 step_size (chain, draw) float64 32kB 3.2 7.472 ... 0.6007 0.6479 step_size_bar (chain, draw) float64 32kB 3.2 7.472 ... 0.6418 0.6418 Attributes: created_at: 2024-09-12T21:27:00.663215+00:00 arviz_version: 0.18.0