SaturationTransformation

class pymc_marketing.mmm.components.saturation.SaturationTransformation(priors=None, prefix=None)

Base class for all saturation transformations.

To use a custom saturation transformation, subclass it and define:

  • function: the function that maps the input x to contributions

  • default_priors: the default distribution for each parameter of function

By subclassing this class, lift test integration comes for free.
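The two requirements above imply that every argument of function other than x needs a matching key in default_priors. The stdlib sketch below checks that consistency; missing_default_priors is a hypothetical helper for illustration only (not part of pymc-marketing), and plain strings stand in for Prior objects.

```python
import inspect


def missing_default_priors(function, default_priors):
    """Return the parameters of ``function`` (other than x) lacking a default prior.

    Hypothetical helper: it only illustrates the subclassing contract.
    """
    # Every argument other than the input ``x`` is treated as a model parameter.
    parameters = [
        name for name in inspect.signature(function).parameters if name != "x"
    ]
    return [name for name in parameters if name not in default_priors]


def infinite_returns(x, b):
    return b * x


# Strings stand in for Prior objects to keep the sketch dependency-free.
print(missing_default_priors(infinite_returns, {"b": "HalfNormal"}))  # []
print(missing_default_priors(infinite_returns, {}))  # ['b']
```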

Examples

Make a non-saturating saturation transformation

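from pymc_marketing.mmm import SaturationTransformation
from pymc_marketing.prior import Prior

def infinite_returns(x, b):
    # Linear response: contributions grow without bound as x increases
    return b * x

class InfiniteReturns(SaturationTransformation):
    lookup_name = "infinite_returns"  # name that identifies this transformation
    function = infinite_returns  # maps x (and parameter b) to contributions
    default_priors = {"b": Prior("HalfNormal")}  # one prior per parameter of function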
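InfiniteReturns is non-saturating because it is linear in x: doubling the input always doubles the contribution. The NumPy sketch below contrasts it with a saturating curve. The Michaelis-Menten-style form alpha * x / (lam + x) and all parameter values are assumptions chosen only for contrast.

```python
import numpy as np


def infinite_returns(x, b):
    return b * x


def michaelis_menten_like(x, alpha, lam):
    # Approaches alpha as x grows, so each extra unit of x adds less.
    return alpha * x / (lam + x)


x = np.array([1.0, 2.0, 4.0, 8.0])
linear = infinite_returns(x, b=0.5)
saturating = michaelis_menten_like(x, alpha=2.0, lam=1.0)

# Ratio of the contribution at 2x to the contribution at x:
# constant 2 for the linear curve, shrinking toward 1 for the saturating one.
linear_ratios = linear[1:] / linear[:-1]
saturating_ratios = saturating[1:] / saturating[:-1]
print(linear_ratios)
print(saturating_ratios)
```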

Use the plotting methods to visualize the transformation and its priors:

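import matplotlib.pyplot as plt
import numpy as np

saturation = InfiniteReturns()

rng = np.random.default_rng(0)

# Draw the parameters (here, b) from their default priors
prior = saturation.sample_prior(random_seed=rng)
# Evaluate the transformation over a grid of x values for each draw
curve = saturation.sample_curve(prior)
# Plot the HDI of the curve together with a few sampled curves
saturation.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()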
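Conceptually, sampling a curve means drawing parameters from their priors, evaluating function on a grid of x values for each draw, and summarizing across draws. The NumPy sketch below mimics that for InfiniteReturns. It is not the library's implementation: it assumes a HalfNormal prior with sigma=1 on b, a grid from 0 to 1, and uses an equal-tailed 94% interval as a simple stand-in for the HDI.

```python
import numpy as np

rng = np.random.default_rng(0)

# Prior draws for b: a HalfNormal(sigma=1) draw is the absolute value of a standard normal.
b_draws = np.abs(rng.normal(loc=0.0, scale=1.0, size=500))

# Evaluate the linear curve on a grid of x values (assumed range 0 to 1).
x_grid = np.linspace(0.0, 1.0, 50)
curves = b_draws[:, None] * x_grid[None, :]  # shape (draws, grid points)

# Equal-tailed 94% interval per grid point (a simple stand-in for an HDI).
lower, upper = np.percentile(curves, [3.0, 97.0], axis=0)
mean_curve = curves.mean(axis=0)

print(curves.shape)  # (500, 50)
```

An HDI and an equal-tailed interval coincide only for symmetric distributions, so the band from the library's plot_curve_hdi may differ from this sketch.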

Methods

  • SaturationTransformation.__init__([priors, ...])

  • SaturationTransformation.apply(x[, dims]): Apply the transformation to x. Must be called within a model context.

  • SaturationTransformation.plot_curve(curve[, ...]): Plot the HDI of the curve and samples from it.

  • SaturationTransformation.plot_curve_hdi(curve): Plot the HDI of the curve.

  • SaturationTransformation.plot_curve_samples(curve): Plot samples from the curve.

  • SaturationTransformation.sample_curve([...]): Sample the curve of the saturation transformation given parameters.

  • SaturationTransformation.sample_prior([coords]): Sample the priors for the transformation.

  • SaturationTransformation.set_dims_for_all_priors(dims): Set the dims for all priors.

  • SaturationTransformation.to_dict(): Convert the transformation to a dictionary.

  • SaturationTransformation.update_priors(priors): Update the priors of the function after initialization.
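Inside a larger model, each function parameter is exposed under a model variable name. The plain-Python sketch below shows how priors keyed by variable name could be mapped back onto function parameters after initialization. The "{prefix}_{parameter}" naming scheme, the "saturation" prefix, and the update_function_priors helper are assumptions for illustration; strings stand in for Prior objects.

```python
def variable_mapping(prefix, parameters):
    # Assumed naming scheme: "{prefix}_{parameter}".
    return {name: f"{prefix}_{name}" for name in parameters}


def update_function_priors(function_priors, prefix, new_priors):
    """Return a copy of function_priors updated from variable-name keys.

    Hypothetical helper: keys in ``new_priors`` are model variable names,
    keys in the result are function parameter names.
    """
    mapping = variable_mapping(prefix, function_priors)
    updated = dict(function_priors)
    for parameter, variable in mapping.items():
        if variable in new_priors:
            updated[parameter] = new_priors[variable]
    return updated


function_priors = {"alpha": "HalfNormal", "lam": "Gamma"}
result = update_function_priors(
    function_priors, "saturation", {"saturation_lam": "HalfNormal"}
)
print(result)  # {'alpha': 'HalfNormal', 'lam': 'HalfNormal'}
```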

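Attributes

  • combined_dims: The combined dims for all the parameters.

  • function_priors: The priors for the function.

  • model_config: Mapping from variable name to prior for the model.

  • prefix: Prefix used for the variable names in the model.

  • variable_mapping: Mapping from parameter name to variable name in the model.

  • default_priors: Default distribution for each parameter of function.

  • function: Function that maps the input x to contributions.

  • lookup_name: Name that identifies the transformation.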
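Parameters can carry dims (for example a per-channel prior), and combined_dims summarizes them across all parameters. The plain-Python sketch below illustrates the idea behind combined_dims and set_dims_for_all_priors using dims tuples instead of Prior objects. Both helpers are hypothetical, and the first-seen ordering of the combined dims is a choice made for this sketch, not documented library behavior.

```python
def combined_dims(prior_dims):
    """Union of dims across parameters, keeping first-seen order.

    Hypothetical helper: ``prior_dims`` maps parameter name -> dims tuple.
    """
    combined = []
    for dims in prior_dims.values():
        for dim in dims:
            if dim not in combined:
                combined.append(dim)
    return tuple(combined)


def set_dims_for_all(prior_dims, dims):
    # Give every parameter the same dims, mirroring the idea behind
    # set_dims_for_all_priors.
    return {name: tuple(dims) for name in prior_dims}


prior_dims = {"alpha": ("channel",), "lam": ("geo", "channel")}
print(combined_dims(prior_dims))  # ('channel', 'geo')
print(combined_dims(set_dims_for_all(prior_dims, ("channel",))))  # ('channel',)
```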