import numpy as np
import pandas as pd
import pymc as pm

from pymc_marketing.clv.distributions import continuous_contractual, continuous_non_contractual

from scipy.special import expit, hyp2f1

import xarray as xr

import matplotlib.pyplot as plt
# True population-level parameters used to simulate the synthetic customers.
a, b, alpha, r = 0.8, 2.5, 3, 4

rng = np.random.default_rng(seed=34)

df = pd.read_csv("data/clv_quickstart.csv")
# Per-customer "T" column as a NumPy array (presumably the length of each
# customer's observation window -- confirm against the dataset docs).
T = df["T"].to_numpy()
def rng_fn(rng, a, b, r, alpha, T, T0, size):
    """Simulate per-customer data with heterogeneous dropout and purchase rates.

    Each customer gets a dropout probability drawn from ``Beta(a, b)`` and a
    purchase rate drawn from ``Gamma(shape=r, scale=1 / alpha)``. These are
    forwarded, together with ``T`` and ``T0``, to
    ``continuous_contractual.rng_fn``, whose result is returned unchanged.
    Callers in this notebook read columns 0, 1 and 2 of the last axis as
    recency, frequency and a churn flag. That layout is presumably defined by
    ``continuous_contractual``; verify there.
    """
    # Keep the beta draw before the gamma draw: swapping them would change
    # which random numbers each quantity receives from the shared generator.
    dropout_prob = rng.beta(a, b, size=size)
    purchase_rate = rng.gamma(r, 1 / alpha, size=size)

    return continuous_contractual.rng_fn(
        rng, purchase_rate, dropout_prob, T, T0, size=size
    )
# Draw one simulated record per customer; observation windows start at T0 = 0.
data = rng_fn(rng, a, b, r, alpha, T, T0=0, size=len(T))
recency, frequency = data[..., 0], data[..., 1]
# Column 2 is inverted here, so it appears to flag customers who have churned.
alive = 1 - data[..., 2]
def conditional_probability_alive_reference(frequency, recency, T):
    """Closed-form probability that each customer is still alive.

    Uses the module-level parameters ``a``, ``b``, ``alpha`` and ``r``. With
    ``x = frequency`` and ``t_x = recency`` it computes::

        1 / (1 + a / (b + x - 1) * ((alpha + T) / (alpha + t_x)) ** (r + x))

    evaluated in log space through ``expit``. Customers with ``x == 0`` get
    exactly 1.0. The result is always at least 1-D.
    """
    x = frequency
    log_window_ratio = np.log((alpha + T) / (alpha + recency))
    # max(x, 1) only affects x == 0 rows, which are replaced by 1.0 below.
    log_odds_churned = (r + x) * log_window_ratio + np.log(
        a / (b + np.maximum(x, 1) - 1)
    )
    p_alive = expit(-log_odds_churned)
    return np.atleast_1d(np.where(x == 0, 1.0, p_alive))
# Histogram of closed-form P(alive) for every simulated customer, computed with
# the true module-level parameters; the trailing ';' hides the return value in Jupyter.
plt.hist(conditional_probability_alive_reference(frequency, recency, T), bins=40);
../../../_images/4647b7752b50ca457542ec590df4ef04ec75f07399ba164d76f9aa05f46b258e.png
with pm.Model() as model:
    # Narrow Normals centred on the true parameters. No observed data is
    # attached, so pm.sample draws from these distributions directly and the
    # resulting trace serves as a stand-in "posterior" for the demos below.
    pm.Normal("a", mu=a, sigma=0.1)
    pm.Normal("b", mu=b, sigma=0.1)
    pm.Normal("alpha", mu=alpha, sigma=0.1)
    pm.Normal("r", mu=r, sigma=0.1)

    trace = pm.sample(chains=2)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 4 jobs)
NUTS: [a, b, alpha, r]
100.00% [4000/4000 00:02<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 2 seconds.
# Display the InferenceData returned by pm.sample (posterior + sampler stats).
trace
arviz.InferenceData
    • <xarray.Dataset>
      Dimensions:  (chain: 2, draw: 1000)
      Coordinates:
        * chain    (chain) int64 0 1
        * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
      Data variables:
          a        (chain, draw) float64 0.7667 0.7566 0.8314 ... 0.8305 0.8132 0.878
          b        (chain, draw) float64 2.511 2.467 2.436 2.358 ... 2.574 2.474 2.428
          alpha    (chain, draw) float64 2.999 2.943 3.153 3.078 ... 2.939 3.032 3.082
          r        (chain, draw) float64 4.187 4.003 3.988 3.943 ... 4.045 4.019 4.012
      Attributes:
          created_at:                 2022-12-15T08:47:51.405208
          arviz_version:              0.14.0
          inference_library:          pymc
          inference_library_version:  5.0.0
          sampling_time:              2.083648920059204
          tuning_steps:               1000

    • <xarray.Dataset>
      Dimensions:                (chain: 2, draw: 1000)
      Coordinates:
        * chain                  (chain) int64 0 1
        * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999
      Data variables: (12/17)
          lp                     (chain, draw) float64 3.73 5.223 4.106 ... 5.424 4.63
          largest_eigval         (chain, draw) float64 nan nan nan nan ... nan nan nan
          process_time_diff      (chain, draw) float64 0.0005532 ... 0.0006748
          tree_depth             (chain, draw) int64 2 2 2 2 2 2 2 2 ... 2 2 2 2 2 2 2
          diverging              (chain, draw) bool False False False ... False False
          max_energy_error       (chain, draw) float64 0.4905 -0.3013 ... 0.2168
          ...                     ...
          smallest_eigval        (chain, draw) float64 nan nan nan nan ... nan nan nan
          energy                 (chain, draw) float64 -1.44 -3.08 ... -4.797 -4.601
          energy_error           (chain, draw) float64 -0.02631 -0.3013 ... 0.2168
          step_size              (chain, draw) float64 1.229 1.229 ... 0.7527 0.7527
          perf_counter_start     (chain, draw) float64 2.217e+03 ... 2.217e+03
          step_size_bar          (chain, draw) float64 1.03 1.03 ... 0.9716 0.9716
      Attributes:
          created_at:                 2022-12-15T08:47:51.418327
          arviz_version:              0.14.0
          inference_library:          pymc
          inference_library_version:  5.0.0
          sampling_time:              2.083648920059204
          tuning_steps:               1000

dims = ("customer_id",)
coords = {"customer_id": range(len(frequency))}

# Re-bind the three per-customer arrays as labelled DataArrays that share the
# same customer_id coordinate, so later xarray arithmetic aligns them by label.
frequency, recency, T = (
    xr.DataArray(values, dims=dims, coords=coords)
    for values in (frequency, recency, T)
)
def conditional_probability_alive(frequency, recency, T, trace):
    """Posterior distribution of the probability that each customer is alive.

    Same closed form as ``conditional_probability_alive_reference``, but the
    parameters ``a``, ``b``, ``alpha`` and ``r`` are read from
    ``trace.posterior``. The result therefore carries the posterior's
    dimensions (e.g. ``chain``/``draw``) in addition to ``customer_id``.

    Parameters
    ----------
    frequency, recency, T : array-like
        Per-customer summaries of equal length. Each is wrapped as an
        ``xr.DataArray`` on a ``customer_id`` dimension indexed ``0..n-1``.
    trace : object with a ``posterior`` attribute
        ``trace.posterior["a"]``, ``["b"]``, ``["alpha"]`` and ``["r"]``
        must be DataArrays, as in the InferenceData returned by ``pm.sample``.

    Returns
    -------
    xr.DataArray
        P(alive) broadcast over the posterior dimensions and ``customer_id``.
        The value is exactly 1.0 wherever ``frequency == 0``.
    """
    dims = ("customer_id",)
    coords = {"customer_id": range(len(frequency))}

    def _to_customer_array(values):
        # Shared customer_id coordinate so the three inputs align by position.
        return xr.DataArray(data=values, coords=coords, dims=dims)

    frequency = _to_customer_array(frequency)
    recency = _to_customer_array(recency)
    T = _to_customer_array(T)

    posterior = trace.posterior
    a = posterior["a"]
    b = posterior["b"]
    alpha = posterior["alpha"]
    r = posterior["r"]

    # log_div is the log-odds of having churned, so P(alive) = expit(-log_div).
    # np.maximum(frequency, 1) keeps the denominator at b (not b - 1) for
    # frequency == 0; those entries are overwritten with 1.0 below anyway.
    log_div = (r + frequency) * np.log((alpha + T) / (alpha + recency)) + np.log(
        a / (b + np.maximum(frequency, 1) - 1)
    )

    return xr.where(frequency == 0, 1.0, expit(-log_div))
# Histogram of each customer's posterior-mean P(alive), averaged over all chains
# and draws; the trailing ';' hides the return value in Jupyter.
plt.hist(conditional_probability_alive(frequency, recency, T, trace).mean(("draw", "chain")), bins=40);
../../../_images/1900ae23f3d07812130dd40a9951a36aa20fdfd6d36c0e15d5b50921cdd7922f.png
# Sanity check: the sampled "a" should average close to its true value (0.8).
trace.posterior["a"].mean()
<xarray.DataArray 'a' ()>
array(0.80128725)
def expected_number_of_purchases(t, frequency, recency, T):
    """Expected number of purchases over the next ``t`` time units per customer.

    Uses the module-level parameters ``a``, ``b``, ``alpha`` and ``r``. With
    ``x = frequency`` and ``t_x = recency`` it evaluates::

        (a + b + x - 1) / (a - 1)
        * [1 - ((alpha + T) / (alpha + T + t)) ** (r + x)
               * 2F1(r + x, b + x; a + b + x - 1; t / (alpha + T + t))]
        / [1 + 1{x > 0} * a / (b + x - 1) * ((alpha + T) / (alpha + t_x)) ** (r + x)]

    This matches the BG/NBD conditional expectation (Fader, Hardie & Lee, 2005).

    Parameters
    ----------
    t : scalar or array-like
        Forecast horizon(s). With xarray inputs, ``t`` broadcasts against the
        per-customer arrays by dimension name.
    frequency, recency, T : array-like
        Per-customer summaries (NumPy arrays or xarray DataArrays).

    Returns
    -------
    The broadcast result of the inputs. It is an ``xr.DataArray`` when the
    inputs are DataArrays.
    """
    numerator = 1 - ((alpha + T) / (alpha + T + t)) ** (r + frequency) * hyp2f1(
        r + frequency,
        b + frequency,
        a + b + frequency - 1,
        t / (alpha + T + t),
    )
    numerator *= (a + b + frequency - 1) / (a - 1)

    # The indicator (frequency > 0) zeroes the correction term for customers
    # with no repeat purchases. Using max(frequency, 1) inside the ratio keeps
    # that masked branch from evaluating a / (b - 1). That value is inf when
    # b == 1, and 0 * inf would turn the whole result into NaN.
    denominator = 1 + (frequency > 0) * (a / (b + np.maximum(frequency, 1) - 1)) * (
        (alpha + T) / (alpha + recency)
    ) ** (r + frequency)

    return numerator / denominator
def to_xarray(*arrays):
    """Wrap equal-length sequences as DataArrays on a shared ``customer_id`` dim.

    The coordinate is ``range(n)``, where ``n`` is the length of the first
    input.

    * With one input, a single ``xr.DataArray`` is returned.
    * With several inputs, a ValueError is raised if any length differs from
      ``n``. Otherwise a lazy generator yielding one DataArray per input is
      returned, suitable for tuple unpacking.

    Calling it with no arguments raises IndexError (``arrays[0]``).
    """
    n_customers = len(arrays[0])
    dims = ("customer_id",)
    coords = {"customer_id": range(n_customers)}

    def wrap(values):
        return xr.DataArray(data=values, coords=coords, dims=dims)

    if len(arrays) == 1:
        return wrap(arrays[0])

    if not all(len(values) == n_customers for values in arrays):
        raise ValueError("The size of input arrays must be the same.")

    return (wrap(values) for values in arrays)

frequency, recency, T = to_xarray(frequency, recency, T)

# Horizons t = 20, 22, ..., 38 live on their own "times" dimension, so the
# call below broadcasts to a (customer_id, times) grid.
t = xr.DataArray(
    data=range(20, 40, 2),
    dims=("times",),
    coords={"times": range(10)},
)
expected_number_of_purchases(t, frequency, recency, T)
<xarray.DataArray (customer_id: 2357, times: 10)>
array([[1.24383802e+01, 1.34879227e+01, 1.45105754e+01, ...,
        1.92688909e+01, 2.01577586e+01, 2.10280203e+01],
       [1.40509154e-07, 1.52093799e-07, 1.63348148e-07, ...,
        2.15314049e-07, 2.24954275e-07, 2.34373811e-07],
       [8.33117474e-05, 9.01011987e-05, 9.66872364e-05, ...,
        1.26980810e-04, 1.32580942e-04, 1.38047343e-04],
       ...,
       [1.81728791e-05, 1.96313117e-05, 2.10438465e-05, ...,
        2.75171216e-05, 2.87100928e-05, 2.98736204e-05],
       [3.65527293e-04, 3.93375244e-04, 4.20185039e-04, ...,
        5.41215146e-04, 5.63223769e-04, 5.84608832e-04],
       [2.65302833e-05, 2.86050981e-05, 3.06086185e-05, ...,
        3.97219388e-05, 4.13903545e-05, 4.30145687e-05]])
Coordinates:
  * customer_id  (customer_id) int64 0 1 2 3 4 5 ... 2352 2353 2354 2355 2356
  * times        (times) int64 0 1 2 3 4 5 6 7 8 9
# Ten hypothetical customers, each evaluated at its own horizon t (20, 22, ..., 38).
# t shares the "customer_id" dimension, so values pair up element-wise instead
# of broadcasting to a grid. The multi-array to_xarray helper defined above is
# reused rather than rebinding the name to a one-argument lambda (PEP 8 E731).
# Rebinding would break any later multi-array call. The helper also checks that
# all four inputs have the same length.
t, frequency, recency, T = to_xarray(
    range(20, 40, 2),
    [1, 3, 5, 7, 9] * 2,
    [20, 30] * 5,
    [25, 35] * 5,
)
expected_number_of_purchases(t, frequency, recency, T)
<xarray.DataArray (customer_id: 10)>
array([1.42554638, 2.17059763, 3.29959668, 4.06284244, 4.72380438,
       1.73223875, 3.1645582 , 4.08454744, 5.20603817, 6.34220761])
Coordinates:
  * customer_id  (customer_id) int64 0 1 2 3 4 5 6 7 8 9