register_tensor_transform

pymc_marketing.prior.register_tensor_transform(name, transform)

Register a tensor transform function so that it can be referenced by name through the `transform` argument of the Prior class.

Parameters:
name : str

The name of the transform.

transform : Callable[[pt.TensorLike], pt.TensorLike]

The function to apply to the tensor. It takes a tensor-like input and returns a tensor-like output.

Examples

Register a custom transform function, then refer to it by name when creating a Prior.

from pymc_marketing.prior import (
    Prior,
    register_tensor_transform,
)

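def custom_transform(x):
    # Element-wise square of the input tensor
    return x ** 2

# Make the function available under the name "square"
register_tensor_transform("square", custom_transform)

# Reference the registered transform by name
custom_distribution = Prior("Normal", transform="square")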