import arviz as az 
import matplotlib.pyplot as plt
import pandas as pd
import pymc as pm
import pymc_extras as pmx
import seaborn as sns

from pymc_marketing import clv

Cargar datos#

# CDNOW sample dataset (cdnow_transactions.csv) hosted in the pymc-marketing repo.
url_cdnow = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/cdnow_transactions.csv"

# Network read: load the raw CSV directly from GitHub into a DataFrame.
raw_data = pd.read_csv(filepath_or_buffer=url_cdnow)

Agregaciones RFM#

# Collapse the raw transaction log into one row per customer holding the
# frequency / recency / T columns that the CLV models consume.
rfm_data = clv.rfm_summary(
    raw_data,
    datetime_col="date",
    datetime_format="%Y%m%d",  # dates are stored as YYYYMMDD values
    customer_id_col="id",
    time_unit="W",  # presumably weekly periods (T == 78.0 in the preview)
)

# Check dtypes / null counts, then preview the first few customers.
rfm_data.info()
rfm_data.head()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2357 entries, 0 to 2356
Data columns (total 4 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   customer_id  2357 non-null   int64  
 1   frequency    2357 non-null   float64
 2   recency      2357 non-null   float64
 3   T            2357 non-null   float64
dtypes: float64(3), int64(1)
memory usage: 73.8 KB
customer_id frequency recency T
0 1 3.0 49.0 78.0
1 2 1.0 2.0 78.0
2 3 0.0 0.0 78.0
3 4 0.0 0.0 78.0
4 5 0.0 0.0 78.0

Modelo Pareto/NBD#

Ajuste MAP#

# Pareto/NBD model that will be fitted by MAP.
pnbd_map = clv.ParetoNBDModel(data=rfm_data)
# Build the PyMC graph now so `pnbd_map.model` exists for prior predictive checks.
pnbd_map.build_model()
# Display the model's priors and likelihood.
pnbd_map
Pareto/NBD
            alpha ~ Weibull(2, 10)
             beta ~ Weibull(2, 10)
                r ~ Weibull(2, 1)
                s ~ Weibull(2, 1)
recency_frequency ~ ParetoNBD(r, alpha, s, beta, <constant>)
with pnbd_map.model:
    # One prior draw is enough for a quick visual sanity check of the priors.
    prior_idata = pm.sample_prior_predictive(random_seed=45, draws=1)

# Observed purchase frequency per customer.
obs_freq = prior_idata.observed_data["recency_frequency"].sel(obs_var="frequency")
# Frequencies simulated from the prior, taken from the first chain and draw.
# Index by dimension name (not position) so the selection does not silently
# depend on the order of the array's dimensions.
ppc_freq = (
    prior_idata.prior_predictive["recency_frequency"]
    .sel(obs_var="frequency")
    .isel(chain=0, draw=0)
)
Sampling: [alpha, beta, r, recency_frequency, s]
# Maximum a posteriori fit; "map" is what `ParetoNBDModel.fit()` does by
# default, written out here for clarity.
pnbd_map.fit(fit_method="map")
# Point estimates, kept as reference lines for the comparison plots.
map_fit = pnbd_map.fit_summary()

# Observed frequencies from the fitted model's InferenceData.
obs_freq = pnbd_map.idata.observed_data["recency_frequency"].sel(obs_var="frequency")
# Frequencies drawn via `distribution_new_customer_recency_frequency` on the
# same RFM table; keep a single draw (chain 0, draw 0).
ppc_freq = pnbd_map.distribution_new_customer_recency_frequency(
    rfm_data,
    random_seed=42,
).sel(obs_var="frequency", chain=0, draw=0)


Sampling: [recency_frequency]


Ajuste DEMZ#

# Posterior sampling with DEMetropolisZ (fit_method="demz"); pointwise
# log-likelihood values are stored in the resulting InferenceData.
pnbd_full = clv.ParetoNBDModel(data=rfm_data)
pnbd_full.fit(
    fit_method="demz",
    tune=2500,
    draws=3000,
    idata_kwargs={"log_likelihood": True},
)
Multiprocess sampling (4 chains in 4 jobs)
DEMetropolisZ: [alpha, beta, r, s]


Sampling 4 chains for 2_500 tune and 3_000 draw iterations (10_000 + 12_000 draws total) took 6 seconds.
arviz.InferenceData
    • <xarray.Dataset> Size: 408kB
      Dimensions:  (chain: 4, draw: 3000)
      Coordinates:
        * chain    (chain) int64 32B 0 1 2 3
        * draw     (draw) int64 24kB 0 1 2 3 4 5 6 ... 2994 2995 2996 2997 2998 2999
      Data variables:
          alpha    (chain, draw) float64 96kB 14.23 14.23 14.23 ... 16.13 16.13 16.6
          beta     (chain, draw) float64 96kB 13.44 13.44 13.44 ... 9.076 9.076 11.54
          r        (chain, draw) float64 96kB 0.5573 0.5573 0.5573 ... 0.6239 0.6777
          s        (chain, draw) float64 96kB 0.4499 0.4499 0.4499 ... 0.3464 0.4332
      Attributes:
          created_at:                 2025-03-02T15:41:34.816242+00:00
          arviz_version:              0.18.0
          inference_library:          pymc
          inference_library_version:  5.20.0
          sampling_time:              6.329453945159912
          tuning_steps:               2500

    • <xarray.Dataset> Size: 226MB
      Dimensions:            (chain: 4, draw: 3000, customer_id: 2357)
      Coordinates:
        * chain              (chain) int64 32B 0 1 2 3
        * draw               (draw) int64 24kB 0 1 2 3 4 ... 2995 2996 2997 2998 2999
        * customer_id        (customer_id) int64 19kB 1 2 3 4 ... 2354 2355 2356 2357
      Data variables:
          recency_frequency  (chain, draw, customer_id) float64 226MB -14.34 ... -0...
      Attributes:
          created_at:                 2025-03-02T15:41:40.869292+00:00
          arviz_version:              0.18.0
          inference_library:          pymc
          inference_library_version:  5.20.0

    • <xarray.Dataset> Size: 324kB
      Dimensions:   (chain: 4, draw: 3000)
      Coordinates:
        * chain     (chain) int64 32B 0 1 2 3
        * draw      (draw) int64 24kB 0 1 2 3 4 5 6 ... 2994 2995 2996 2997 2998 2999
      Data variables:
          accept    (chain, draw) float64 96kB 1.131 0.0002704 ... 0.006623 2.132
          accepted  (chain, draw) bool 12kB True False False ... False False True
          lambda    (chain, draw) float64 96kB 0.8415 0.8415 0.8415 ... 0.8415 0.8415
          scaling   (chain, draw) float64 96kB 0.0002542 0.0002542 ... 0.0003835
      Attributes:
          created_at:                 2025-03-02T15:41:34.820112+00:00
          arviz_version:              0.18.0
          inference_library:          pymc
          inference_library_version:  5.20.0
          sampling_time:              6.329453945159912
          tuning_steps:               2500

    • <xarray.Dataset> Size: 57kB
      Dimensions:            (customer_id: 2357, obs_var: 2)
      Coordinates:
        * customer_id        (customer_id) int64 19kB 1 2 3 4 ... 2354 2355 2356 2357
        * obs_var            (obs_var) <U9 72B 'recency' 'frequency'
      Data variables:
          recency_frequency  (customer_id, obs_var) float64 38kB 49.0 3.0 ... 0.0 0.0
      Attributes:
          created_at:                 2025-03-02T15:41:34.821948+00:00
          arviz_version:              0.18.0
          inference_library:          pymc
          inference_library_version:  5.20.0

    • <xarray.Dataset> Size: 94kB
      Dimensions:      (index: 2357)
      Coordinates:
        * index        (index) int64 19kB 0 1 2 3 4 5 ... 2352 2353 2354 2355 2356
      Data variables:
          customer_id  (index) int64 19kB 1 2 3 4 5 6 ... 2353 2354 2355 2356 2357
          frequency    (index) float64 19kB 3.0 1.0 0.0 0.0 0.0 ... 5.0 1.0 6.0 0.0
          recency      (index) float64 19kB 49.0 2.0 0.0 0.0 ... 24.0 44.0 62.0 0.0
          T            (index) float64 19kB 78.0 78.0 78.0 78.0 ... 66.0 66.0 66.0

# Posterior summary of the DEMZ fit: mean, sd, HDI, MCSE, ESS and r_hat.
pnbd_full.fit_summary()
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha 15.614 1.050 13.668 17.537 0.036 0.026 862.0 996.0 1.0
beta 12.750 3.691 6.362 19.946 0.135 0.096 739.0 1323.0 1.0
r 0.619 0.046 0.542 0.712 0.002 0.001 855.0 898.0 1.0
s 0.431 0.060 0.320 0.544 0.002 0.001 824.0 1251.0 1.0

Ajuste con Nutpie#

pnbd_nutpie = clv.ParetoNBDModel(data=rfm_data)

# NUTS through nutpie's default (numba) backend currently fails to compile
# this model: PyTensor raises
#   TypeError: The fgraph of ScalarLoop must be exclusively composed of
#   scalar operations.
# Catch exactly that failure so the remaining notebook cells still run, and
# re-raise any other TypeError untouched.
# NOTE(review): nutpie's JAX backend, e.g.
#   nuts_sampler_kwargs={"backend": "jax", "gradient_backend": "jax"},
# may sidestep the numba limitation -- untested here.
try:
    idata_nutpie = pnbd_nutpie.fit(fit_method="mcmc", nuts_sampler="nutpie")
except TypeError as err:
    if "ScalarLoop" not in str(err):
        raise
    idata_nutpie = None
    print(f"nutpie (numba backend) could not compile the Pareto/NBD model: {err}")
/Users/coltallen/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/pytensorf.py:1066: FutureWarning: compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC
  warnings.warn(
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[17], line 4
      1 pnbd_nutpie = clv.ParetoNBDModel(data=rfm_data)
      2 bgnbd_nutpie = clv.BetaGeoModel(data=rfm_data)
----> 4 idata_nutpie = pnbd_nutpie.fit(fit_method = 'mcmc', nuts_sampler="nutpie")

File ~/Projects/pymc-marketing/pymc_marketing/clv/models/pareto_nbd.py:350, in ParetoNBDModel.fit(self, fit_method, **kwargs)
    345 with warnings.catch_warnings():
    346     warnings.simplefilter(
    347         action="ignore",
    348         category=UserWarning,
    349     )
--> 350     return super().fit(fit_method, **kwargs)

File ~/Projects/pymc-marketing/pymc_marketing/clv/models/basic.py:135, in CLVModel.fit(self, fit_method, **kwargs)
    133 match fit_method:
    134     case "mcmc":
--> 135         idata = self._fit_mcmc(**kwargs)
    136     case "map":
    137         idata = self._fit_MAP(**kwargs)

File ~/Projects/pymc-marketing/pymc_marketing/clv/models/basic.py:164, in CLVModel._fit_mcmc(self, **kwargs)
    162     sampler_config = self.sampler_config.copy()
    163 sampler_config.update(**kwargs)
--> 164 return pm.sample(**sampler_config, model=self.model)

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/sampling/mcmc.py:781, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    776         raise ValueError(
    777             "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
    778         )
    780     with joined_blas_limiter():
--> 781         return _sample_external_nuts(
    782             sampler=nuts_sampler,
    783             draws=draws,
    784             tune=tune,
    785             chains=chains,
    786             target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    787             random_seed=random_seed,
    788             initvals=initvals,
    789             model=model,
    790             var_names=var_names,
    791             progressbar=progressbar,
    792             idata_kwargs=idata_kwargs,
    793             compute_convergence_checks=compute_convergence_checks,
    794             nuts_sampler_kwargs=nuts_sampler_kwargs,
    795             **kwargs,
    796         )
    798 if exclusive_nuts and not provided_steps:
    799     # Special path for NUTS initialization
    800     if "nuts" in kwargs:

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/sampling/mcmc.py:337, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, idata_kwargs, compute_convergence_checks, nuts_sampler_kwargs, **kwargs)
    335     if kwarg in nuts_sampler_kwargs:
    336         compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)
--> 337 compiled_model = nutpie.compile_pymc_model(
    338     model,
    339     **compile_kwargs,
    340 )
    341 t_start = time.time()
    342 idata = nutpie.sample(
    343     compiled_model,
    344     draws=draws,
   (...)
    350     **nuts_sampler_kwargs,
    351 )

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/nutpie/compile_pymc.py:391, in compile_pymc_model(model, backend, gradient_backend, **kwargs)
    388     backend = "numba"
    390 if backend.lower() == "numba":
--> 391     return _compile_pymc_model_numba(model, **kwargs)
    392 elif backend.lower() == "jax":
    393     return _compile_pymc_model_jax(
    394         model, gradient_backend=gradient_backend, **kwargs
    395     )

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/nutpie/compile_pymc.py:176, in _compile_pymc_model_numba(model, **kwargs)
    162     raise ImportError(
    163         "Numba is not installed in the current environment. "
    164         "Please install it with something like "
    165         "'mamba install -c conda-forge numba' "
    166         "and restart your kernel in case you are in an interactive session."
    167     )
    168 import numba
    170 (
    171     n_dim,
    172     n_expanded,
    173     logp_fn_pt,
    174     expand_fn_pt,
    175     shape_info,
--> 176 ) = _make_functions(model, mode="NUMBA", compute_grad=True, join_expanded=True)
    178 expand_fn = expand_fn_pt.vm.jit_fn
    179 logp_fn = logp_fn_pt.vm.jit_fn

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/nutpie/compile_pymc.py:497, in _make_functions(model, mode, compute_grad, join_expanded)
    495     (logp, grad) = pytensor.clone_replace([logp, grad], replacements)
    496     with model:
--> 497         logp_fn_pt = compile_pymc((joined,), (logp, grad), mode=mode)
    498 else:
    499     (logp,) = pytensor.clone_replace([logp], replacements)

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/pytensorf.py:1070, in compile_pymc(*args, **kwargs)
   1065 def compile_pymc(*args, **kwargs):
   1066     warnings.warn(
   1067         "compile_pymc was renamed to compile. Old name will be removed in a future release of PyMC",
   1068         FutureWarning,
   1069     )
-> 1070     return compile(*args, **kwargs)

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pymc/pytensorf.py:1055, in compile(inputs, outputs, random_seed, mode, **kwargs)
   1053 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
   1054 mode = Mode(linker=mode.linker, optimizer=opt_qry)
-> 1055 pytensor_function = pytensor.function(
   1056     inputs,
   1057     outputs,
   1058     updates={**rng_updates, **kwargs.pop("updates", {})},
   1059     mode=mode,
   1060     **kwargs,
   1061 )
   1062 return pytensor_function

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/compile/function/__init__.py:318, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
    312     fn = orig_function(
    313         inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
    314     )
    315 else:
    316     # note: pfunc will also call orig_function -- orig_function is
    317     #      a choke point that all compilation must pass through
--> 318     fn = pfunc(
    319         params=inputs,
    320         outputs=outputs,
    321         mode=mode,
    322         updates=updates,
    323         givens=givens,
    324         no_default_updates=no_default_updates,
    325         accept_inplace=accept_inplace,
    326         name=name,
    327         rebuild_strict=rebuild_strict,
    328         allow_input_downcast=allow_input_downcast,
    329         on_unused_input=on_unused_input,
    330         profile=profile,
    331         output_keys=output_keys,
    332     )
    333 return fn

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/compile/function/pfunc.py:465, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
    451     profile = ProfileStats(message=profile)
    453 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    454     params,
    455     outputs,
   (...)
    462     fgraph=fgraph,
    463 )
--> 465 return orig_function(
    466     inputs,
    467     cloned_outputs,
    468     mode,
    469     accept_inplace=accept_inplace,
    470     name=name,
    471     profile=profile,
    472     on_unused_input=on_unused_input,
    473     output_keys=output_keys,
    474     fgraph=fgraph,
    475 )

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/compile/function/types.py:1769, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
   1757     m = Maker(
   1758         inputs,
   1759         outputs,
   (...)
   1766         fgraph=fgraph,
   1767     )
   1768     with config.change_flags(compute_test_value="off"):
-> 1769         fn = m.create(defaults)
   1770 finally:
   1771     if profile and fn:

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/compile/function/types.py:1661, in FunctionMaker.create(self, input_storage, storage_map)
   1658 start_import_time = pytensor.link.c.cmodule.import_time
   1660 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1661     _fn, _i, _o = self.linker.make_thunk(
   1662         input_storage=input_storage_lists, storage_map=storage_map
   1663     )
   1665 end_linker = time.perf_counter()
   1667 linker_time = end_linker - start_linker

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/link/basic.py:245, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
    238 def make_thunk(
    239     self,
    240     input_storage: Optional["InputStorageType"] = None,
   (...)
    243     **kwargs,
    244 ) -> tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 245     return self.make_all(
    246         input_storage=input_storage,
    247         output_storage=output_storage,
    248         storage_map=storage_map,
    249     )[:3]

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/link/basic.py:695, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
    692 for k in storage_map:
    693     compute_map[k] = [k.owner is None]
--> 695 thunks, nodes, jit_fn = self.create_jitable_thunk(
    696     compute_map, nodes, input_storage, output_storage, storage_map
    697 )
    699 [fn] = thunks
    700 fn.jit_fn = jit_fn

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
    644 # This is a bit hackish, but we only return one of the output nodes
    645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
    648     self.fgraph,
    649     order=order,
    650     input_storage=input_storage,
    651     output_storage=output_storage,
    652     storage_map=storage_map,
    653 )
    655 thunk_inputs = self.create_thunk_inputs(storage_map)
    656 thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/link/numba/linker.py:10, in NumbaLinker.fgraph_convert(self, fgraph, **kwargs)
      7 def fgraph_convert(self, fgraph, **kwargs):
      8     from pytensor.link.numba.dispatch import numba_funcify
---> 10     return numba_funcify(fgraph, **kwargs)

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/link/numba/dispatch/basic.py:463, in numba_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    456 @numba_funcify.register(FunctionGraph)
    457 def numba_funcify_FunctionGraph(
    458     fgraph,
   (...)
    461     **kwargs,
    462 ):
--> 463     return fgraph_to_python(
    464         fgraph,
    465         numba_funcify,
    466         type_conversion_fn=numba_typify,
    467         fgraph_name=fgraph_name,
    468         **kwargs,
    469     )

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs)
    734 body_assigns = []
    735 for node in order:
--> 736     compiled_func = op_conversion_fn(
    737         node.op, node=node, storage_map=storage_map, **kwargs
    738     )
    740     # Create a local alias with a unique name
    741     local_compiled_func_name = unique_name(compiled_func)

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/link/numba/dispatch/elemwise.py:357, in numba_funcify_Elemwise(op, node, **kwargs)
    355 if not isinstance(op.scalar_op, Composite):
    356     scalar_inputs = [scalar(dtype=input.dtype) for input in node.inputs]
--> 357     scalar_node = op.scalar_op.make_node(*scalar_inputs)
    359 scalar_op_fn = numba_funcify(
    360     op.scalar_op,
    361     node=scalar_node,
   (...)
    364     **kwargs,
    365 )
    367 nin = len(node.inputs)

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/scalar/loop.py:180, in ScalarLoop.make_node(self, n_steps, *inputs)
    178 cloned_constant = cloned_inputs[len(cloned_update) :]
    179 # This will fail if the cloned init have a different dtype than the cloned_update
--> 180 op = ScalarLoop(
    181     init=cloned_init,
    182     update=cloned_update,
    183     constant=cloned_constant,
    184     until=cloned_until,
    185     name=self.name,
    186 )
    187 node = op.make_node(n_steps, *inputs)
    188 return node

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/scalar/loop.py:69, in ScalarLoop.__init__(self, init, update, constant, until, name)
     66     inputs, outputs = clone([*init, *constant], update)
     68 self.is_while = bool(until)
---> 69 self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
     70 self._validate_updates(self.inputs, self.outputs)
     72 self.inputs_type = tuple(input.type for input in self.inputs)

File ~/miniconda3/envs/pymc-marketing-dev/lib/python3.10/site-packages/pytensor/scalar/basic.py:3992, in ScalarInnerGraphOp._cleanup_graph(self, inputs, outputs)
   3990 for node in fgraph.apply_nodes:
   3991     if not isinstance(node.op, ScalarOp):
-> 3992         raise TypeError(
   3993             f"The fgraph of {self.__class__.__name__} must be exclusively "
   3994             "composed of scalar operations."
   3995         )
   3997 # Run MergeOptimization to avoid duplicated nodes
   3998 MergeOptimizer().rewrite(fgraph)

TypeError: The fgraph of ScalarLoop must be exclusively composed of scalar operations.

Ajuste ADVI#

# Variational fit with ADVI. `n` and `obj_n_mc` are presumably forwarded to
# PyMC's variational API (iteration count and Monte Carlo samples used for
# the objective).
pnbd_advi = clv.ParetoNBDModel(data=rfm_data)
pnbd_advi.fit(
    fit_method="advi",
    n=12500,
    obj_n_mc=15,
    idata_kwargs={"log_likelihood": True},
)

Finished [100%]: Average Loss = 16,486
arviz.InferenceData
    • <xarray.Dataset> Size: 20kB
      Dimensions:  (chain: 1, draw: 500)
      Coordinates:
        * chain    (chain) int64 8B 0
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
      Data variables:
          alpha    (chain, draw) float64 4kB 15.74 15.65 14.86 ... 16.01 15.87 15.23
          beta     (chain, draw) float64 4kB 12.93 12.05 12.89 ... 14.93 11.61 11.7
          r        (chain, draw) float64 4kB 0.6118 0.5966 0.5927 ... 0.6028 0.6444
          s        (chain, draw) float64 4kB 0.4536 0.4336 0.4428 ... 0.4636 0.4284
      Attributes:
          created_at:                 2025-01-15T18:22:23.502516+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.20.0

    • <xarray.Dataset> Size: 57kB
      Dimensions:            (customer_id: 2357, obs_var: 2)
      Coordinates:
        * customer_id        (customer_id) int64 19kB 1 2 3 4 ... 2354 2355 2356 2357
        * obs_var            (obs_var) <U9 72B 'recency' 'frequency'
      Data variables:
          recency_frequency  (customer_id, obs_var) float64 38kB 49.0 3.0 ... 0.0 0.0
      Attributes:
          created_at:                 2025-01-15T18:22:23.509273+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.20.0

    • <xarray.Dataset> Size: 94kB
      Dimensions:      (index: 2357)
      Coordinates:
        * index        (index) int64 19kB 0 1 2 3 4 5 ... 2352 2353 2354 2355 2356
      Data variables:
          customer_id  (index) int64 19kB 1 2 3 4 5 6 ... 2353 2354 2355 2356 2357
          frequency    (index) float64 19kB 3.0 1.0 0.0 0.0 0.0 ... 5.0 1.0 6.0 0.0
          recency      (index) float64 19kB 49.0 2.0 0.0 0.0 ... 24.0 44.0 62.0 0.0
          T            (index) float64 19kB 78.0 78.0 78.0 78.0 ... 66.0 66.0 66.0

# ADVI draws come back as a single chain, so arviz warns about the shape and
# r_hat is reported as NaN.
pnbd_advi.fit_summary()
arviz - WARNING - Shape validation failed: input_shape: (1, 500), minimum_shape: (chains=2, draws=4)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha 15.547 0.567 14.514 16.599 0.030 0.021 366.0 453.0 NaN
beta 12.531 1.160 10.681 14.933 0.054 0.038 458.0 372.0 NaN
r 0.612 0.019 0.575 0.646 0.001 0.001 448.0 353.0 NaN
s 0.429 0.021 0.396 0.472 0.001 0.001 502.0 446.0 NaN
# Loss history of the ADVI optimisation (`approx.hist`), on a log y-axis.
ax = sns.lineplot(pnbd_advi.approx.hist)
ax.set_yscale("log")
../../../_images/91c7b4cf415cae8c376314382f24087f0dfb480b5c9e4a2235c8cdc631e2d54e.png

Ajuste ADVI de rango completo (full-rank)#

# Full-rank ADVI. Unlike the ADVI fit (n=12500, obj_n_mc=15), this uses the
# default iteration count and only obj_n_mc=5, which should be kept in mind
# when comparing the two approximations.
pnbd_fullrank = clv.ParetoNBDModel(data=rfm_data)
pnbd_fullrank.fit(
    fit_method="fullrank_advi",
    obj_n_mc=5,
    idata_kwargs={"log_likelihood": True},
)

Finished [100%]: Average Loss = 16,530
arviz.InferenceData
    • <xarray.Dataset> Size: 20kB
      Dimensions:  (chain: 1, draw: 500)
      Coordinates:
        * chain    (chain) int64 8B 0
        * draw     (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
      Data variables:
          alpha    (chain, draw) float64 4kB 17.41 12.11 15.68 ... 13.78 13.21 25.11
          beta     (chain, draw) float64 4kB 27.5 19.56 20.12 ... 4.865 3.775 9.99
          r        (chain, draw) float64 4kB 0.5248 0.3552 0.6851 ... 0.7622 0.7358
          s        (chain, draw) float64 4kB 0.5267 0.3196 0.5446 ... 0.2275 0.3081
      Attributes:
          created_at:                 2025-01-15T19:11:38.027406+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.20.0

    • <xarray.Dataset> Size: 57kB
      Dimensions:            (customer_id: 2357, obs_var: 2)
      Coordinates:
        * customer_id        (customer_id) int64 19kB 1 2 3 4 ... 2354 2355 2356 2357
        * obs_var            (obs_var) <U9 72B 'recency' 'frequency'
      Data variables:
          recency_frequency  (customer_id, obs_var) float64 38kB 49.0 3.0 ... 0.0 0.0
      Attributes:
          created_at:                 2025-01-15T19:11:38.033900+00:00
          arviz_version:              0.20.0
          inference_library:          pymc
          inference_library_version:  5.20.0

    • <xarray.Dataset> Size: 94kB
      Dimensions:      (index: 2357)
      Coordinates:
        * index        (index) int64 19kB 0 1 2 3 4 5 ... 2352 2353 2354 2355 2356
      Data variables:
          customer_id  (index) int64 19kB 1 2 3 4 5 6 ... 2353 2354 2355 2356 2357
          frequency    (index) float64 19kB 3.0 1.0 0.0 0.0 0.0 ... 5.0 1.0 6.0 0.0
          recency      (index) float64 19kB 49.0 2.0 0.0 0.0 ... 24.0 44.0 62.0 0.0
          T            (index) float64 19kB 78.0 78.0 78.0 78.0 ... 66.0 66.0 66.0

# Single-chain variational draws again: expect the arviz shape warning and NaN r_hat.
pnbd_fullrank.fit_summary()
arviz - WARNING - Shape validation failed: input_shape: (1, 500), minimum_shape: (chains=2, draws=4)
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
alpha 15.207 5.127 6.969 26.097 0.220 0.159 552.0 428.0 NaN
beta 14.941 10.196 2.099 33.381 0.503 0.394 467.0 363.0 NaN
r 0.592 0.209 0.253 0.973 0.010 0.007 391.0 409.0 NaN
s 0.417 0.192 0.143 0.786 0.009 0.007 458.0 462.0 NaN
# Loss history of the full-rank ADVI optimisation (`approx.hist`), on a log
# y-axis to match the ADVI convergence plot.
ax = sns.lineplot(pnbd_fullrank.approx.hist)
ax.set_yscale("log")
<Axes: >
../../../_images/c7e634de704f6c9fe294ce0a9f95fa8a01b0d7067c07f876e85207bcb3b9a910.png

Comparación visual#

# One panel per parameter: DEMZ and ADVI posteriors overlaid, with the MAP
# point estimate as a dashed reference line.
fig, axes = plt.subplots(
    nrows=2, ncols=2, figsize=(12, 7), sharex=False, sharey=False, layout="constrained"
)

# (fitted model, matplotlib colour, legend label) for each posterior drawn.
posterior_fits = [
    (pnbd_full, "C0", "DEMZ"),
    (pnbd_advi, "C1", "ADVI"),
]

for ax, var_name in zip(axes.flatten(), ["r", "alpha", "s", "beta"]):
    for fitted_model, color, label in posterior_fits:
        az.plot_posterior(
            fitted_model.idata.posterior[var_name].values.flatten(),
            color=color,
            point_estimate="mean",
            ax=ax,
            label=label,
        )
    ax.axvline(x=map_fit[var_name], color="C3", linestyle="--", label="MAP")
    ax.legend(loc="upper right")
    ax.set_title(var_name)

fig.suptitle("Pareto/NBD Model Parameters - DEMZ vs ADVI fits", fontsize=18, fontweight="bold");
../../../_images/eab1b0dfb159a79a01ec987bcfd926ccc70a9a257b4a9e44e6b7ba183949f475.png
# One panel per parameter: DEMZ and full-rank ADVI posteriors overlaid, with
# the MAP point estimate as a dashed reference line.
fig, axes = plt.subplots(
    nrows=2, ncols=2, figsize=(12, 7), sharex=False, sharey=False, layout="constrained"
)

# (fitted model, matplotlib colour, legend label) for each posterior drawn.
posterior_fits = [
    (pnbd_full, "C0", "DEMZ"),
    (pnbd_fullrank, "C2", "FULLRANK_ADVI"),
]

for ax, var_name in zip(axes.flatten(), ["r", "alpha", "s", "beta"]):
    for fitted_model, color, label in posterior_fits:
        az.plot_posterior(
            fitted_model.idata.posterior[var_name].values.flatten(),
            color=color,
            point_estimate="mean",
            ax=ax,
            label=label,
        )
    ax.axvline(x=map_fit[var_name], color="C3", linestyle="--", label="MAP")
    ax.legend(loc="upper right")
    ax.set_title(var_name)

fig.suptitle("Pareto/NBD Model Parameters - DEMZ vs FULLRANK fits", fontsize=18, fontweight="bold");
../../../_images/eeb21844e208ebf25fa50ad1b73941e49d7db0508e4335a4623f5e4e825405e8.png

Observaciones:

  • El ajuste ADVI de rango completo (full-rank) es bastante deficiente: sus posteriores son mucho más anchas que las de DEMZ.

  • El ajuste ADVI coincide con el de MCMC (DEMZ), aunque proporciona estimaciones más estrechas (desviaciones estándar entre un 46 % y un 69 % menores).

Diferencias relativas en las estimaciones de parámetros#

# Percentage difference of the ADVI posterior mean / sd relative to DEMZ:
#   100 * (DEMZ - ADVI) / DEMZ
# fit_summary() recomputes the summary on every call (it re-emits the arviz
# warnings), so compute each model's summary exactly once.
demz_stats = pnbd_full.fit_summary()[["mean", "sd"]]
advi_stats = pnbd_advi.fit_summary()[["mean", "sd"]]

(100 * (demz_stats - advi_stats) / demz_stats).rename(
    columns={
        "mean": "pcnt_relative_diff_param_mean",
        "sd": "pcnt_relative_diff_param_sd",
    }
)
arviz - WARNING - Shape validation failed: input_shape: (1, 500), minimum_shape: (chains=2, draws=4)
pcnt_relative_diff_param_mean pcnt_relative_diff_param_sd
alpha 0.429102 46.000000
beta 1.717647 68.572203
r 1.130856 58.695652
s 0.464037 65.000000