LinearTrend
- class pymc_marketing.mmm.linear_trend.LinearTrend(**data)
LinearTrend class.
Linear trend component using change points. The trend is defined as:
\[f(t) = k + \sum_{m=1}^{M} \delta_m I(t > s_m)\]
where:
- \(k\) is the base intercept (present only when include_intercept is True),
- \(\delta_m\) is the change in the trend at change point \(m\),
- \(I\) is the indicator function,
- \(s_m\) is the location of the \(m\)-th change point.
The change points are evenly spaced strictly between 0 and \(\max(t)\):
\[s_m = \frac{m}{M+1} \max(t), \quad m = 1, \dots, M,\]
where \(M\) is the number of change points.
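The two definitions above can be evaluated directly. The NumPy sketch below implements the formulas exactly as written on this page; it is not the library's implementation (which builds PyMC variables), and the helper names `changepoints` and `evaluate_trend` are hypothetical.

```python
import numpy as np


def changepoints(t, n_changepoints):
    """Change point locations s_m = m / (M + 1) * max(t) for m = 1..M."""
    m = np.arange(1, n_changepoints + 1)
    return m / (n_changepoints + 1) * np.max(t)


def evaluate_trend(t, k, delta):
    """Evaluate f(t) = k + sum_m delta_m * I(t > s_m) at every time in t."""
    t = np.asarray(t, dtype=float)
    delta = np.asarray(delta, dtype=float)
    s = changepoints(t, len(delta))
    # A[i, m] is 1.0 once time t_i has passed change point s_m, else 0.0.
    A = (t[:, None] > s[None, :]).astype(float)
    return k + A @ delta


t = np.linspace(0.0, 4.0, 5)  # t = 0, 1, 2, 3, 4
print(changepoints(t, 1))  # single change point at 4 / 2 = 2
print(evaluate_trend(t, k=1.0, delta=[0.5]))  # shifts by 0.5 for t > 2
```

The indicator matrix `A` has one row per time point and one column per change point, so the whole sum over \(m\) reduces to a single matrix-vector product.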
The priors for the trend parameters are:
- \(k \sim \text{Normal}(0, 0.05)\)
- \(\delta_m \sim \text{Laplace}(0, 0.25)\)
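To see what these priors imply for the curve, one can simulate prior trends with NumPy. This is a sketch under stated assumptions: the second argument of Normal is read as a standard deviation and that of Laplace as the scale \(b\), both passed to NumPy's `scale`; \(k\) is included even though the class adds it only when include_intercept is True; the time grid and the number of draws are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(42)
n_draws, n_changepoints = 500, 10
t = np.linspace(0.0, 3.0, 156)  # illustrative grid: ~3 years of weekly points

# k ~ Normal(0, 0.05), assuming 0.05 is the standard deviation.
k = rng.normal(loc=0.0, scale=0.05, size=n_draws)
# delta_m ~ Laplace(0, 0.25), assuming 0.25 is the Laplace scale b.
delta = rng.laplace(loc=0.0, scale=0.25, size=(n_draws, n_changepoints))

# Change points s_m = m / (M + 1) * max(t) and indicators I(t > s_m).
s = np.arange(1, n_changepoints + 1) / (n_changepoints + 1) * t.max()
A = (t[:, None] > s[None, :]).astype(float)  # shape (T, M)

# One prior curve per draw: f = k + A @ delta, giving shape (n_draws, T).
curves = k[:, None] + delta @ A.T

# Spread of the prior trend at the end of the series.
print(np.percentile(curves[:, -1], [3, 50, 97]))
```

At \(t = 0\) no change point is active, so every curve starts at its own \(k\); by the last time point all \(M\) changes have accumulated.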
- Parameters:
  - priors : dict[str, Prior], optional
    Dictionary with the priors for the trend parameters. The dictionary must have a ‘delta’ key. If include_intercept is True, the ‘k’ key is also required. By default None, in which case the default priors are used.
  - dims : Dims, optional
    Dimensions of the parameters, by default None (no extra dimensions).
  - n_changepoints : int, optional
    Number of changepoints, by default 10.
  - include_intercept : bool, optional
    Include an intercept in the trend, by default False.
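The key requirements for priors can be expressed as a small check. The helper below is hypothetical (it is not part of pymc_marketing) and only encodes the rule described above; plain strings stand in for Prior objects.

```python
def missing_prior_keys(priors, include_intercept=False):
    """Return the keys that `priors` lacks under the documented rule.

    'delta' is always required; 'k' is required only when
    include_intercept is True. An empty set means the dict is complete.
    """
    required = {"delta", "k"} if include_intercept else {"delta"}
    return required - set(priors)


# Strings stand in for Prior objects in this illustration.
print(missing_prior_keys({"delta": "laplace"}))  # set()
print(missing_prior_keys({"delta": "laplace"}, include_intercept=True))  # {'k'}
```

Passing priors=None needs no such check, since the default priors are used in that case.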
References
- Adapted from the MBrouns/timeseers package:
Examples
Linear trend with 10 changepoints:
from pymc_marketing.mmm import LinearTrend

trend = LinearTrend(n_changepoints=10)
Use the trend in a model:
import pymc as pm
import numpy as np
import pandas as pd

n_years = 3
n_dates = 52 * n_years
first_date = "2020-01-01"
dates = pd.date_range(first_date, periods=n_dates, freq="W-MON")
dayofyear = dates.dayofyear.to_numpy()
t = (dates - dates[0]).days.to_numpy()
t = t / 365.25

coords = {"date": dates}
with pm.Model(coords=coords) as model:
    intercept = pm.Normal("intercept", mu=0, sigma=1)
    mu = intercept + trend.apply(t)

    sigma = pm.Gamma("sigma", mu=0.1, sigma=0.025)

    pm.Normal("obs", mu=mu, sigma=sigma, dims="date")
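Applying the change-point formula to this example's time axis shows where the 10 change points would fall on the calendar. This pandas sketch assumes the locations follow \(s_m = \frac{m}{M+1}\max(t)\) exactly as documented on this page; it does not call the library.

```python
import numpy as np
import pandas as pd

# Same weekly time axis as the model example above.
dates = pd.date_range("2020-01-01", periods=52 * 3, freq="W-MON")
t = (dates - dates[0]).days.to_numpy() / 365.25  # years since first date

# Change points from s_m = m / (M + 1) * max(t), with M = 10.
n_changepoints = 10
s = np.arange(1, n_changepoints + 1) / (n_changepoints + 1) * t.max()

# Map each change point (in years) back to a calendar date.
changepoint_dates = dates[0] + pd.to_timedelta(s * 365.25, unit="D")
print(changepoint_dates.date)
```

Because \(t\) is measured in years, the same conversion factor (365.25 days) is used in both directions.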
Hierarchical LinearTrend via a hierarchical prior:
from pymc_marketing.prior import Prior

hierarchical_delta = Prior(
    "Laplace",
    mu=Prior("Normal", dims="changepoint"),
    b=Prior("HalfNormal", dims="changepoint"),
    dims=("changepoint", "geo"),
)
priors = dict(delta=hierarchical_delta)

hierarchical_trend = LinearTrend(
    priors=priors,
    n_changepoints=10,
    dims="geo",
)
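The hierarchical prior reads as a two-level generative process: per-change-point hyperparameters, then per-geo change sizes. The NumPy sketch below draws it once, assuming that the unparameterised Normal and HalfNormal use unit scale (an assumption about Prior's defaults that this page does not state).

```python
import numpy as np

rng = np.random.default_rng(0)
n_changepoints = 10
geos = ["A", "B"]

# Hyperpriors shared across geos, one value per change point
# (dims="changepoint"). Assumption: the unparameterised Normal and
# HalfNormal use unit scale, i.e. Normal(0, 1) and HalfNormal(1).
mu = rng.normal(0.0, 1.0, size=n_changepoints)
b = np.abs(rng.normal(0.0, 1.0, size=n_changepoints))  # HalfNormal(1) draw

# delta has dims ("changepoint", "geo"): each geo gets its own change sizes,
# with location and scale shared per change point through mu and b.
delta = rng.laplace(
    loc=mu[:, None], scale=b[:, None], size=(n_changepoints, len(geos))
)
print(delta.shape)  # (10, 2)
```

The geos' change sizes at a given change point are drawn from one common Laplace distribution, which is what ties the per-geo trends together.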
Sample the hierarchical trend:
seed = sum(map(ord, "Hierarchical LinearTrend"))
rng = np.random.default_rng(seed)

coords = {"geo": ["A", "B"]}
prior = hierarchical_trend.sample_prior(
    coords=coords,
    random_seed=rng,
)
curve = hierarchical_trend.sample_curve(prior)
Plot the curve HDI and samples:
sample_kwargs = {"n": 3, "rng": rng}
fig, axes = hierarchical_trend.plot_curve(
    curve,
    sample_kwargs=sample_kwargs,
)
fig.suptitle("Hierarchical Linear Trend")
axes[0].set(ylabel="Trend", xlabel="Time")
axes[1].set(xlabel="Time")
Methods
- LinearTrend.__init__(**data): Create a new model by parsing and validating input data from keyword arguments.
- Create the linear trend for the given x values.
- LinearTrend.construct([_fields_set])
- LinearTrend.copy(*[, include, exclude, ...]): Returns a copy of the model.
- LinearTrend.dict(*[, include, exclude, ...])
- LinearTrend.from_orm(obj)
- LinearTrend.json(*[, include, exclude, ...])
- LinearTrend.model_construct([_fields_set]): Creates a new instance of the Model class with validated data.
- LinearTrend.model_copy(*[, update, deep])
- LinearTrend.model_dump(*[, mode, include, ...])
- LinearTrend.model_dump_json(*[, indent, ...])
- LinearTrend.model_json_schema([by_alias, ...]): Generates a JSON schema for a model class.
- Compute the class name for parametrizations of generic classes.
- LinearTrend.model_post_init(context, /): Override this method to perform additional initialization after __init__ and model_construct.
- LinearTrend.model_rebuild(*[, force, ...]): Try to rebuild the pydantic-core schema for the model.
- LinearTrend.model_validate(obj, *[, strict, ...]): Validate a pydantic model instance.
- LinearTrend.model_validate_json(json_data, *)
- LinearTrend.model_validate_strings(obj, *[, ...]): Validate the given object with string data against the Pydantic model.
- LinearTrend.parse_file(path, *[, ...])
- LinearTrend.parse_raw(b, *[, content_type, ...])
- LinearTrend.plot_curve(curve[, ...]): Plot the curve samples from the trend.
- LinearTrend.sample_curve(parameters[, max_value]): Sample the curve given parameters.
- LinearTrend.sample_prior([coords]): Sample the prior for the parameters used in the trend.
- LinearTrend.schema([by_alias, ref_template])
- LinearTrend.schema_json(*[, by_alias, ...])
- LinearTrend.update_forward_refs(**localns)
- LinearTrend.validate(value)
Attributes
- default_priors: Default priors for the trend parameters.
- model_computed_fields
- model_config: Configuration for the model, should be a dictionary conforming to ConfigDict (pydantic.config.ConfigDict).
- model_extra: Get extra fields set during validation.
- model_fields
- model_fields_set: Returns the set of fields that have been explicitly set on this model instance.
- priors
- dims
- n_changepoints
- include_intercept