linear_trend

Linear trend using change points.
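One common way to build a linear trend with change points is a continuous piecewise-linear function. It has a base slope `k` and an intercept `m`, plus a slope change `delta_j` that switches on after each change point `s_j`. This is the formulation popularized by Prophet. It is assumed here for illustration and is not taken from the `LinearTrend` source, whose exact parameterization may differ. The function name below is hypothetical.

```python
import numpy as np


def piecewise_linear_trend(t, k, m, changepoints, delta):
    """Continuous piecewise-linear trend (illustrative sketch, not library code).

    t            : 1-D array of (scaled) time values
    k, m         : base slope and intercept
    changepoints : 1-D array of change-point locations s_j
    delta        : 1-D array of slope changes, one per change point
    """
    t = np.asarray(t, dtype=float)
    s = np.asarray(changepoints, dtype=float)
    # Hinge basis: max(0, t - s_j) is zero before s_j and grows linearly after it,
    # so delta_j changes the slope from s_j onward without introducing a jump.
    hinge = np.maximum(0.0, t[:, None] - s[None, :])
    return k * t + m + hinge @ np.asarray(delta, dtype=float)
```

For example, `piecewise_linear_trend([0.0, 0.5, 1.0], k=1.0, m=0.0, changepoints=[0.5], delta=[-2.0])` returns `[0.0, 0.5, 0.0]`. The slope is 1 up to the change point and 1 - 2 = -1 after it.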

Examples

Define a linear trend with 8 change points:

```python
from pymc_marketing.mmm import LinearTrend

# 8 change points at which the slope of the trend is allowed to change
trend = LinearTrend(n_changepoints=8)
```
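With 8 change points, the trend can be thought of as a design matrix with one hinge column per change point. The sketch below assumes the change points are spread evenly over time scaled to [0, 1], strictly inside the interval. That placement is an assumption, so check the `LinearTrend` API reference for how the library actually positions them. Both helper functions are hypothetical.

```python
import numpy as np


def interior_changepoints(n_changepoints: int) -> np.ndarray:
    # Hypothetical placement: n evenly spaced knots strictly between 0 and 1.
    return np.linspace(0.0, 1.0, n_changepoints + 2)[1:-1]


def hinge_design_matrix(t, changepoints):
    # Column j is max(0, t - s_j): zero before the j-th knot, linear after it.
    t = np.asarray(t, dtype=float)
    return np.maximum(0.0, t[:, None] - changepoints[None, :])


t = np.linspace(0.0, 1.0, 101)   # scaled time grid
s = interior_changepoints(8)     # 1/9, 2/9, ..., 8/9
X = hinge_design_matrix(t, s)    # shape (101, 8)
```

Keeping the knots away from t = 0 avoids a first hinge column that duplicates the base slope term.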

Sample the trend parameters from their priors, then compute the trend curve for each draw:

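```python
import numpy as np

# Seed derived from a string so the example is reproducible
seed = sum(map(ord, "Linear Trend"))
rng = np.random.default_rng(seed)

# Draws of the trend parameters from their priors
prior = trend.sample_prior(random_seed=rng)
# Trend curve evaluated for each prior draw
curve = trend.sample_curve(prior)
```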
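The two steps above, sampling parameters and then evaluating curves, can be mimicked in plain NumPy to see what is happening under the hood. This is a minimal sketch under stated assumptions. The priors (normal base slope and intercept, Laplace slope changes as in Prophet), their scales, and the change-point placement are illustrative choices, not pymc-marketing defaults. The helper names are hypothetical.

```python
import numpy as np


def sample_trend_draws(rng, n_draws, n_changepoints, delta_scale=0.25):
    # Illustrative priors only (NOT the library's defaults):
    # k, m ~ Normal(0, 1); delta_j ~ Laplace(0, delta_scale).
    return {
        "k": rng.normal(0.0, 1.0, size=n_draws),
        "m": rng.normal(0.0, 1.0, size=n_draws),
        "delta": rng.laplace(0.0, delta_scale, size=(n_draws, n_changepoints)),
    }


def evaluate_trend_curves(draws, t, changepoints):
    # Hinge basis shared by all draws: shape (n_time, n_changepoints).
    hinge = np.maximum(0.0, t[:, None] - changepoints[None, :])
    # Result has shape (n_draws, n_time): one trend curve per draw.
    return (
        draws["k"][:, None] * t[None, :]
        + draws["m"][:, None]
        + draws["delta"] @ hinge.T
    )


rng = np.random.default_rng(sum(map(ord, "Linear Trend")))
t = np.linspace(0.0, 1.0, 52)                        # scaled time grid
changepoints = np.linspace(0.0, 1.0, 8 + 2)[1:-1]    # assumed placement
draws = sample_trend_draws(rng, n_draws=100, n_changepoints=8)
curves = evaluate_trend_curves(draws, t, changepoints)  # shape (100, 52)
```

Because every hinge column is zero at t = 0, each curve starts at its intercept draw `m`.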

Plot the curve samples:

```python
# Discard the figure and label the first axes
_, axes = trend.plot_curve(curve, random_seed=rng)
ax = axes[0]
ax.set(
    xlabel="Time",
    ylabel="Trend",
    title=f"Linear Trend with {trend.n_changepoints} Change Points",
)
```
[Figure: LinearTrend prior]
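A common way to summarize many sampled curves in a single plot is an interval band per time point. The sketch below computes an equal-tailed band with NumPy quantiles. The 94% level, the use of quantiles instead of an HDI, and the stand-in straight-line samples are all assumptions for illustration, and `central_band` is a hypothetical helper.

```python
import numpy as np


def central_band(curves, prob=0.94):
    # curves: (n_draws, n_time). Equal-tailed interval at each time point.
    alpha = (1.0 - prob) / 2.0
    lower, upper = np.quantile(curves, [alpha, 1.0 - alpha], axis=0)
    median = np.median(curves, axis=0)
    return lower, median, upper


rng = np.random.default_rng(0)
t = np.linspace(0.0, 1.0, 52)
# Stand-in samples: straight lines through the origin with random slopes.
curves = rng.normal(0.0, 1.0, size=(200, 1)) * t[None, :]
lower, median, upper = central_band(curves)
```

With matplotlib, `ax.fill_between(t, lower, upper, alpha=0.3)` followed by `ax.plot(t, median)` would draw the band and its center line.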

Classes

LinearTrend(**data)

LinearTrend class: a linear trend with change points.