ModelSamplerEstimator.estimate_num_steps_sampling

ModelSamplerEstimator.estimate_num_steps_sampling(model, *, tune=None, draws=None, seed=None, nuts_kwargs=None, mcmc_kwargs=None)

Estimate the total number of NUTS leapfrog steps taken during warmup plus sampling, using NumPyro.

Parameters:
model : Model

PyMC model whose step count is estimated by running it with a JAX/NumPyro NUTS kernel.

tune : int | None, optional

Number of warmup iterations. If None, the estimator’s own setting is used.

draws : int | None, optional

Number of sampling iterations. If None, the estimator’s own setting is used.

seed : int | None, optional

Random seed for the JAX run. If None, the estimator’s own setting is used.

nuts_kwargs : dict | None, optional

Additional keyword arguments passed to numpyro.infer.NUTS. They are merged over the estimator’s default_nuts_kwargs, with provided values taking precedence; if None, default_nuts_kwargs are used unchanged.

mcmc_kwargs : dict | None, optional

Additional keyword arguments passed to numpyro.infer.MCMC, excluding num_warmup and num_samples, which are set from tune and draws. They are merged over the estimator’s default_mcmc_kwargs, with provided values taking precedence; if None, default_mcmc_kwargs are used unchanged.
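The parameter descriptions imply a simple resolution rule: a None argument falls back to the estimator's stored setting, and user-supplied kwargs are laid over the defaults. The following is a minimal pure-Python sketch of that rule, not the estimator's actual code. `EstimatorDefaults` and `resolve_settings` are hypothetical names, and overwriting any user-supplied num_warmup/num_samples with tune/draws (rather than raising an error) is an assumption.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class EstimatorDefaults:
    # Hypothetical stand-in for the settings stored on the estimator instance.
    tune: int
    draws: int
    seed: int
    default_nuts_kwargs: dict[str, Any] = field(default_factory=dict)
    default_mcmc_kwargs: dict[str, Any] = field(default_factory=dict)


def resolve_settings(
    d: EstimatorDefaults,
    tune: Optional[int] = None,
    draws: Optional[int] = None,
    seed: Optional[int] = None,
    nuts_kwargs: Optional[dict[str, Any]] = None,
    mcmc_kwargs: Optional[dict[str, Any]] = None,
):
    # None means "use the estimator setting"; `is None` keeps explicit zeros.
    tune = d.tune if tune is None else tune
    draws = d.draws if draws is None else draws
    seed = d.seed if seed is None else seed

    # New dicts: user values override defaults, and the defaults are never mutated.
    nuts = {**d.default_nuts_kwargs, **(nuts_kwargs or {})}
    mcmc = {**d.default_mcmc_kwargs, **(mcmc_kwargs or {})}

    # Assumption: warmup/sample counts always come from tune/draws.
    mcmc["num_warmup"] = tune
    mcmc["num_samples"] = draws
    return tune, draws, seed, nuts, mcmc
```

Checking `is None` instead of using `or` keeps an explicit `tune=0` rather than silently replacing it with the default.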

Returns:
int

Total number of leapfrog steps, summed over warmup and sampling.