ModelSamplerEstimator.estimate_model_eval_time

ModelSamplerEstimator.estimate_model_eval_time(model, num_evaluations=None)

Estimate the average time, in seconds, of one joint evaluation of the model's log-probability (logp) and its gradient (dlogp), compiled and run with JAX.

Parameters:
model : Model

PyMC model whose logp and gradient functions are JIT-compiled with JAX and then evaluated.

num_evaluations : int | None, optional

Number of repeated evaluations to average over. If None, the count is chosen so that the timing run takes roughly 5 seconds in total, which is long enough to give a stable estimate.

Returns:
float

Average evaluation time in seconds.
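The averaging and the automatic choice of the evaluation count can be illustrated with a small, library-independent sketch. It assumes only a zero-argument callable `fn`; the names `estimate_eval_time`, `target_seconds`, and `calibration_calls` are hypothetical and are not part of the documented API, and the actual method may warm up and calibrate differently.

```python
import time
from typing import Callable, Optional


def estimate_eval_time(
    fn: Callable[[], object],
    num_evaluations: Optional[int] = None,
    target_seconds: float = 5.0,
    calibration_calls: int = 10,
) -> float:
    """Average wall-clock seconds per call of ``fn`` (hypothetical helper)."""
    # Warm-up: the first call of a jitted function also pays compilation cost,
    # so it is excluded from the measurement.
    fn()

    if num_evaluations is None:
        # Estimate the per-call cost from a short batch, then choose a count
        # so that the main timing loop lasts roughly ``target_seconds``.
        start = time.perf_counter()
        for _ in range(calibration_calls):
            fn()
        per_call = (time.perf_counter() - start) / calibration_calls
        num_evaluations = max(1, int(target_seconds / max(per_call, 1e-9)))

    # Main measurement: time all evaluations together, then average.
    start = time.perf_counter()
    for _ in range(num_evaluations):
        fn()
    return (time.perf_counter() - start) / num_evaluations
```

To time a JAX function this way, the callable should wait for its result, e.g. `lambda: jax.block_until_ready(logp_dlogp(x))`, where `logp_dlogp` (a hypothetical jitted function returning logp and its gradient) and `x` are placeholders. JAX dispatches work asynchronously, so without the wait the timer can stop before the computation has finished.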