BG/NBD Model#
In this notebook we show how to fit a BG/NBD model in PyMC-Marketing and compare the results with the lifetimes
package (no longer maintained). The model is presented in: Fader, P. S., Hardie, B. G., & Lee, K. L. (2005). “Counting your customers” the easy way: An alternative to the Pareto/NBD model. Marketing Science, 24(2), 275-284.
Prepare Notebook#
import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
from fastprogress.fastprogress import progress_bar
from lifetimes import BetaGeoFitter
from pymc_marketing import clv
# Plotting configuration
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
Read Data#
We use the CDNOW
dataset (see lifetimes quick-start).
data_path = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/clv_quickstart.csv"
df = pd.read_csv(data_path)
df.head()
|   | frequency | recency | T     | monetary_value |
|---|-----------|---------|-------|----------------|
| 0 | 2         | 30.43   | 38.86 | 22.35          |
| 1 | 1         | 1.71    | 38.86 | 11.77          |
| 2 | 0         | 0.00    | 38.86 | 0.00           |
| 3 | 0         | 0.00    | 38.86 | 0.00           |
| 4 | 0         | 0.00    | 38.86 | 0.00           |
Recall from the lifetimes
documentation the following definitions:
frequency
represents the number of repeat purchases the customer has made. This means that it’s one less than the total number of purchases. This is actually slightly wrong. It’s the count of time periods the customer had a purchase in. So if using days as units, then it’s the count of days the customer had a purchase on.
T
represents the age of the customer in whatever time units chosen (weekly, in the above dataset). This is equal to the duration between a customer’s first purchase and the end of the period under study.
recency
represents the age of the customer when they made their most recent purchases. This is equal to the duration between a customer’s first purchase and their latest purchase. (Thus if they have made only 1 purchase, the recency is 0.)
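To make these definitions concrete, here is a minimal pure-Python sketch (`rfm_summary` is a hypothetical helper for illustration, not part of any library) that derives the three quantities from a customer's purchase times:

```python
def rfm_summary(purchase_times, observation_end):
    """Compute (frequency, recency, T) from a sorted list of purchase
    times, all expressed in the same unit (e.g. weeks)."""
    first = purchase_times[0]
    # frequency: repeat purchases, i.e. one less than total purchases
    frequency = len(purchase_times) - 1
    # recency: time between first and most recent purchase
    recency = purchase_times[-1] - first
    # T: time between first purchase and the end of the study period
    T = observation_end - first
    return frequency, recency, T


# Customer with purchases at weeks 0, 4 and 10, observed until week 38.86
print(rfm_summary([0.0, 4.0, 10.0], 38.86))  # (2, 10.0, 38.86)
# A single purchase gives frequency 0 and recency 0
print(rfm_summary([5.0], 38.86))  # (0, 0.0, 33.86)
```

Note that, as the lifetimes docs point out above, the library actually counts purchase *periods* rather than raw purchases; this sketch uses the simpler "one less than total purchases" reading.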
Tip
We rename the index column to customer_id, as this column is required by the model.
data = (
df.reset_index()
.rename(columns={"index": "customer_id"})
.drop(columns="monetary_value")
)
Model Specification#
The BG/NBD model is a probabilistic model that describes the buying behavior of customers in a non-contractual setting. It is based on the following assumptions for each customer:
Frequency Process#
While active, the time between transactions is exponentially distributed with transaction rate \(\lambda\), i.e.,
\[f(t_{j}|t_{j-1}; \lambda) = \lambda \exp(-\lambda (t_{j} - t_{j - 1})), \quad t_{j} \geq t_{j - 1} \geq 0\]
Heterogeneity in \(\lambda\) follows a gamma distribution with pdf
\[f(\lambda|r, \alpha) = \frac{\alpha^{r}\lambda^{r - 1}\exp(-\lambda \alpha)}{\Gamma(r)}, \quad \lambda > 0\]
Dropout Process#
After any transaction, a customer becomes inactive with probability \(p\).
Heterogeneity in \(p\) follows a beta distribution with pdf
\[f(p|a, b) = \frac{\Gamma(a + b)}{\Gamma(a) \Gamma(b)} p^{a - 1}(1 - p)^{b - 1}, \quad 0 \leq p \leq 1\]
The transaction rate \(\lambda\) and the dropout probability \(p\) vary independently across customers.
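These assumptions translate directly into a small generative simulation, sketched below for intuition (the function `simulate_customer` is hypothetical, not library code): draw \(\lambda\) from the gamma distribution, \(p\) from the beta distribution, then generate exponential inter-purchase times until the customer drops out or the observation window ends.

```python
import random


def simulate_customer(r, alpha, a, b, T, rng):
    """Simulate one BG/NBD customer; returns repeat-purchase times in (0, T]."""
    # Heterogeneity: per-customer transaction rate and dropout probability.
    # random.gammavariate takes (shape, scale), so scale = 1 / alpha.
    lam = rng.gammavariate(r, 1.0 / alpha)
    p = rng.betavariate(a, b)
    t, purchases = 0.0, []
    while True:
        t += rng.expovariate(lam)  # exponential inter-purchase time
        if t > T:
            break
        purchases.append(t)
        if rng.random() < p:  # customer becomes inactive after this purchase
            break
    return purchases


rng = random.Random(42)
sims = [
    simulate_customer(r=0.24, alpha=4.4, a=0.8, b=2.4, T=39.0, rng=rng)
    for _ in range(1_000)
]
print(sum(len(s) for s in sims) / len(sims))  # average repeat purchases
```

The parameter values here are rough, hand-picked numbers in the vicinity of the estimates obtained later in this notebook.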
Instead of estimating \(\lambda\) and \(p\) for each individual customer, we model a randomly chosen customer, i.e. we work with the distribution of these parameters across the customer base. Hence, we are interested in finding the posterior distribution of the parameters \(r\), \(\alpha\), \(a\), and \(b\).
Model Fitting#
Estimating these parameters is straightforward in PyMC-Marketing. We instantiate the model by passing the data:
model_mcmc = clv.BetaGeoModel(data=data)
We can build the model so that we can see the model specification:
model_mcmc.build_model()
print(model_mcmc)
BG/NBD
a ~ HalfFlat()
b ~ HalfFlat()
alpha ~ HalfFlat()
r ~ HalfFlat()
likelihood ~ Potential(f(r, alpha, b, a))
We can now fit the model. The default sampler in PyMC-Marketing is the No-U-Turn Sampler (NUTS). We use \(4\) chains and \(1000\) draws per chain (the defaults), together with a fixed random seed for reproducibility.
Note
It is not necessary to build the model before fitting it. We can fit the model directly.
sample_kwargs = {
    "draws": 1_000,
    "chains": 4,
    "target_accept": 0.9,
    "random_seed": 42,
}

idata_mcmc = model_mcmc.fit(**sample_kwargs)
idata_mcmc
<xarray.Dataset> Size: 136kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999 Data variables: a (chain, draw) float64 32kB 0.8328 0.9293 0.9555 ... 1.042 1.112 b (chain, draw) float64 32kB 2.824 2.756 2.953 ... 3.807 4.156 3.535 alpha (chain, draw) float64 32kB 4.397 4.364 4.305 ... 4.039 4.173 4.47 r (chain, draw) float64 32kB 0.2511 0.2528 0.2424 ... 0.2196 0.2261 Attributes: created_at: 2024-04-05T07:20:26.323594 arviz_version: 0.17.1 inference_library: pymc inference_library_version: 5.11.0 sampling_time: 30.46035885810852 tuning_steps: 1000
<xarray.Dataset> Size: 496kB Dimensions: (chain: 4, draw: 1000) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999 Data variables: (12/17) tree_depth (chain, draw) int64 32kB 4 2 4 4 2 4 ... 4 1 4 4 3 3 perf_counter_start (chain, draw) float64 32kB 1.8e+04 ... 1.801e+04 energy (chain, draw) float64 32kB 9.584e+03 ... 9.589e+03 index_in_trajectory (chain, draw) int64 32kB -2 -2 3 -3 -2 ... 0 6 -7 1 3 perf_counter_diff (chain, draw) float64 32kB 0.009088 ... 0.004634 step_size (chain, draw) float64 32kB 0.3874 0.3874 ... 0.4099 ... ... smallest_eigval (chain, draw) float64 32kB nan nan nan ... nan nan diverging (chain, draw) bool 4kB False False ... False False max_energy_error (chain, draw) float64 32kB 0.8383 0.3148 ... -1.02 lp (chain, draw) float64 32kB -9.583e+03 ... -9.585e+03 reached_max_treedepth (chain, draw) bool 4kB False False ... False False acceptance_rate (chain, draw) float64 32kB 0.7259 0.9069 ... 0.9577 Attributes: created_at: 2024-04-05T07:20:26.344549 arviz_version: 0.17.1 inference_library: pymc inference_library_version: 5.11.0 sampling_time: 30.46035885810852 tuning_steps: 1000
<xarray.Dataset> Size: 94kB Dimensions: (index: 2357) Coordinates: * index (index) int64 19kB 0 1 2 3 4 5 ... 2352 2353 2354 2355 2356 Data variables: customer_id (index) int64 19kB 0 1 2 3 4 5 ... 2352 2353 2354 2355 2356 frequency (index) int64 19kB 2 1 0 0 0 7 1 0 2 0 ... 7 1 2 0 0 0 5 0 4 0 recency (index) float64 19kB 30.43 1.71 0.0 0.0 ... 24.29 0.0 26.57 0.0 T (index) float64 19kB 38.86 38.86 38.86 38.86 ... 27.0 27.0 27.0
We can look into the summary table:
model_mcmc.fit_summary()
|       | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| a     | 0.975 | 0.282 | 0.528  | 1.463   | 0.007     | 0.005   | 1747.0   | 1655.0   | 1.0   |
| b     | 3.192 | 1.184 | 1.556  | 5.320   | 0.030     | 0.021   | 1768.0   | 1716.0   | 1.0   |
| alpha | 4.479 | 0.382 | 3.761  | 5.213   | 0.009     | 0.006   | 1795.0   | 1809.0   | 1.0   |
| r     | 0.244 | 0.013 | 0.220  | 0.267   | 0.000     | 0.000   | 1742.0   | 1876.0   | 1.0   |
We see that the r_hat values are close to \(1\), which indicates convergence.
We can also plot posterior distributions of the parameters and the rank plots:
axes = az.plot_trace(
data=model_mcmc.idata,
compact=True,
kind="rank_bars",
backend_kwargs={"figsize": (12, 9), "layout": "constrained"},
)
plt.gcf().suptitle("BG/NBD Model Trace", fontsize=18, fontweight="bold");
Using MAP fit#
CLV models such as BetaGeoModel can also provide maximum a posteriori (MAP) estimates, using a numerical optimizer (L-BFGS-B) from scipy.optimize under the hood.
model_map = clv.BetaGeoModel(data=data)
idata_map = model_map.fit(fit_method="map")
idata_map
<xarray.Dataset> Size: 48B Dimensions: (chain: 1, draw: 1) Coordinates: * chain (chain) int64 8B 0 * draw (draw) int64 8B 0 Data variables: a (chain, draw) float64 8B 0.793 b (chain, draw) float64 8B 2.426 alpha (chain, draw) float64 8B 4.414 r (chain, draw) float64 8B 0.2426 Attributes: created_at: 2024-04-05T07:20:31.004335 arviz_version: 0.17.1 inference_library: pymc inference_library_version: 5.11.0
<xarray.Dataset> Size: 94kB Dimensions: (index: 2357) Coordinates: * index (index) int64 19kB 0 1 2 3 4 5 ... 2352 2353 2354 2355 2356 Data variables: customer_id (index) int64 19kB 0 1 2 3 4 5 ... 2352 2353 2354 2355 2356 frequency (index) int64 19kB 2 1 0 0 0 7 1 0 2 0 ... 7 1 2 0 0 0 5 0 4 0 recency (index) float64 19kB 30.43 1.71 0.0 0.0 ... 24.29 0.0 26.57 0.0 T (index) float64 19kB 38.86 38.86 38.86 38.86 ... 27.0 27.0 27.0
This time we get point estimates for the parameters.
map_summary = model_map.fit_summary()
map_summary
a 0.793
b 2.426
alpha 4.414
r 0.243
Name: value, dtype: float64
Comparing with the lifetimes package#
For the sake of comparison, we also fit the model using the lifetimes package.
bgf = BetaGeoFitter()
bgf.fit(
frequency=data["frequency"].values,
recency=data["recency"].values,
T=data["T"].values,
)
bgf.summary
|       | coef     | se(coef) | lower 95% bound | upper 95% bound |
|-------|----------|----------|-----------------|-----------------|
| r     | 0.242593 | 0.012557 | 0.217981        | 0.267205        |
| alpha | 4.413532 | 0.378221 | 3.672218        | 5.154846        |
| a     | 0.792886 | 0.185719 | 0.428877        | 1.156895        |
| b     | 2.425752 | 0.705345 | 1.043276        | 3.808229        |
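Because the default priors are flat, the MAP estimates should coincide (up to optimizer tolerance) with the maximum-likelihood estimates from lifetimes. Plugging in the values reported above confirms the agreement:

```python
# Values copied from the MAP summary and the lifetimes summary above
map_estimates = {"r": 0.243, "alpha": 4.414, "a": 0.793, "b": 2.426}
mle_estimates = {"r": 0.242593, "alpha": 4.413532, "a": 0.792886, "b": 2.425752}

for name, mle in mle_estimates.items():
    # Relative difference between MAP and MLE point estimates
    rel_diff = abs(map_estimates[name] - mle) / mle
    print(f"{name:<6} relative difference: {rel_diff:.3%}")
```

All four parameters agree to well under one percent, which is exactly what we expect when the prior is non-informative.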
fig, axes = plt.subplots(
nrows=2, ncols=2, figsize=(12, 9), sharex=False, sharey=False, layout="constrained"
)
axes = axes.flatten()
for i, var_name in enumerate(["r", "alpha", "a", "b"]):
ax = axes[i]
az.plot_posterior(
model_mcmc.idata.posterior[var_name].values.flatten(),
color="C0",
point_estimate="mean",
ax=ax,
label="MCMC",
)
ax.axvline(x=map_summary[var_name], color="C1", linestyle="--", label="MAP")
ax.axvline(
x=bgf.summary["coef"][var_name], color="C2", linestyle="--", label="lifetimes"
)
ax.legend(loc="upper right")
ax.set_title(var_name)
plt.gcf().suptitle("BG/NBD Model Parameters", fontsize=18, fontweight="bold");
Some Applications#
Now that we have fitted the model, we can use it to make predictions. For example, we can compute the expected number of future purchases and the probability that a customer is still alive as a function of time (steps).
Expected Number of Purchases#
Let us take a sample of users:
example_customer_ids = [1, 6, 10, 18, 45, 1412]
data_small = data.query("customer_id.isin(@example_customer_ids)")
data_small.head(6)
|      | customer_id | frequency | recency | T     |
|------|-------------|-----------|---------|-------|
| 1    | 1           | 1         | 1.71    | 38.86 |
| 6    | 6           | 1         | 5.00    | 38.86 |
| 10   | 10          | 5         | 24.43   | 38.86 |
| 18   | 18          | 3         | 28.29   | 38.71 |
| 45   | 45          | 12        | 34.43   | 38.57 |
| 1412 | 1412        | 14        | 30.29   | 31.57 |
Observe that the last two customers are frequent buyers as compared to the others.
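Before using the model's built-in method, it is instructive to look at the closed-form *unconditional* expectation from Fader, Hardie & Lee (2005), i.e. the expected number of repeat purchases in \((0, t]\) for a randomly chosen customer. This sketch assumes scipy is available and uses point estimates, so it ignores posterior uncertainty; the function name is our own:

```python
from scipy.special import hyp2f1  # Gaussian hypergeometric function 2F1


def expected_purchases(t, r, alpha, a, b):
    """Unconditional E[X(t)] for the BG/NBD model, eq. (9) in
    Fader, Hardie & Lee (2005)."""
    z = t / (alpha + t)
    return (
        (a + b - 1)
        / (a - 1)
        * (1 - (alpha / (alpha + t)) ** r * hyp2f1(r, b, a + b - 1, z))
    )


# Point estimates from the MAP fit above
params = {"r": 0.243, "alpha": 4.414, "a": 0.793, "b": 2.426}
for t in (10, 39, 90):
    print(f"E[X({t})] = {expected_purchases(t, **params):.3f}")
```

The model's `expected_num_purchases` method below computes the *conditional* analogue for each observed customer, averaged over the full posterior.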
steps = 90
expected_num_purchases_steps = xr.concat(
objs=[
model_mcmc.expected_num_purchases(
customer_id=data_small["customer_id"],
frequency=data_small["frequency"],
recency=data_small["recency"],
T=data_small["T"],
t=t,
)
for t in progress_bar(range(steps))
],
dim="t",
).transpose(..., "t")
We can plot the expected number of purchases for the next \(90\) periods:
fig, axes = plt.subplots(
nrows=len(example_customer_ids),
ncols=1,
figsize=(12, 15),
sharex=True,
sharey=True,
layout="constrained",
)
axes = axes.flatten()
for i, customer_id in enumerate(example_customer_ids):
ax = axes[i]
customer_expected_num_purchases_steps = expected_num_purchases_steps.sel(
customer_id=customer_id
)
az.plot_hdi(
range(steps),
customer_expected_num_purchases_steps,
hdi_prob=0.94,
color="C0",
fill_kwargs={"alpha": 0.3, "label": "$94 \\%$ HDI"},
ax=ax,
)
az.plot_hdi(
range(steps),
customer_expected_num_purchases_steps,
hdi_prob=0.5,
color="C0",
fill_kwargs={"alpha": 0.5, "label": "$50 \\%$ HDI"},
ax=ax,
)
ax.plot(
range(steps),
customer_expected_num_purchases_steps.mean(dim=("chain", "draw")),
color="C0",
label="posterior mean",
)
ax.legend(loc="upper left")
    ax.set(title=f"Customer {customer_id}", ylabel="Expected Number of Purchases")
axes[-1].set(xlabel="steps")
plt.gcf().suptitle("Expected Number of Purchases", fontsize=18, fontweight="bold");
Note that the frequent buyers are expected to make more purchases in the future.
Probability of a Customer Being Alive#
We now look into the probability of a customer being alive for the next \(90\) periods:
steps = 90
expected_probability_alive_steps = xr.concat(
objs=[
model_mcmc.expected_probability_alive(
customer_id=data_small["customer_id"],
frequency=data_small["frequency"],
recency=data_small["recency"],
            T=data_small["T"] + t,  # extend the observation period by t time units
)
for t in progress_bar(range(steps))
],
dim="t",
).transpose(..., "t")
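The quantity computed above also has a simple closed form. As a cross-check, here is a sketch (our own function name, using MAP point estimates, so without posterior uncertainty) of \(P(\text{alive} \mid x, t_x, T)\) for the BG/NBD model:

```python
def probability_alive(x, t_x, T, r, alpha, a, b):
    """Closed-form P(alive | frequency=x, recency=t_x, T) for the BG/NBD model."""
    if x == 0:
        return 1.0  # dropout can only occur after a repeat purchase
    return 1.0 / (
        1.0 + (a / (b + x - 1)) * ((alpha + T) / (alpha + t_x)) ** (r + x)
    )


# Point estimates from the MAP fit above
params = {"r": 0.243, "alpha": 4.414, "a": 0.793, "b": 2.426}
# Customer 45: 12 repeat purchases, last one at week 34.43, observed 38.57 weeks
print(probability_alive(x=12, t_x=34.43, T=38.57, **params))
# The same customer, had we observed 90 more weeks with no purchases
print(probability_alive(x=12, t_x=34.43, T=38.57 + 90, **params))
```

The second value is dramatically smaller: a long silence from a previously frequent buyer is strong evidence of dropout, which is exactly the pattern visible in the plots below.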
fig, axes = plt.subplots(
nrows=len(example_customer_ids),
ncols=1,
figsize=(12, 15),
sharex=True,
sharey=True,
layout="constrained",
)
axes = axes.flatten()
for i, customer_id in enumerate(example_customer_ids):
ax = axes[i]
customer_expected_probability_alive_steps = expected_probability_alive_steps.sel(
customer_id=customer_id
)
az.plot_hdi(
range(steps),
customer_expected_probability_alive_steps,
hdi_prob=0.94,
color="C1",
fill_kwargs={"alpha": 0.3, "label": "$94 \\%$ HDI"},
ax=ax,
)
az.plot_hdi(
range(steps),
customer_expected_probability_alive_steps,
hdi_prob=0.5,
color="C1",
fill_kwargs={"alpha": 0.5, "label": "$50 \\%$ HDI"},
ax=ax,
)
ax.plot(
range(steps),
customer_expected_probability_alive_steps.mean(dim=("chain", "draw")),
color="C1",
label="posterior mean",
)
ax.legend(loc="upper right")
ax.set(title=f"Customer {customer_id}", ylabel="Probability Alive", ylim=(0, 1))
axes[-1].set(xlabel="steps")
plt.gcf().suptitle(
"Expected Probability Alive over Time", fontsize=18, fontweight="bold"
);
Tip
Here are some general remarks:
These plots assume no future purchases occur during the forecast window.
The decay rate differs across customers, as it depends on each customer's purchase history.
The probability of being alive decreases monotonically, since we assume no change in the other inputs.
These probabilities are always non-negative, as expected.
Warning
For the frequent buyers, the probability of being alive drops very fast because we assume no future purchases. It is very important to keep this in mind when interpreting the results.
%load_ext watermark
%watermark -n -u -v -iv -w -p pymc,pytensor
Last updated: Fri Apr 05 2024
Python implementation: CPython
Python version : 3.11.8
IPython version : 8.22.2
pymc : 5.11.0
pytensor: 2.18.6
pandas : 2.2.1
arviz : 0.17.1
xarray : 2024.2.0
pymc_marketing: 0.4.2
matplotlib : 3.8.3
Watermark: 2.4.3