ModelSamplerEstimator.estimate_num_steps_sampling
- ModelSamplerEstimator.estimate_num_steps_sampling(model, *, tune=None, draws=None, seed=None, nuts_kwargs=None, mcmc_kwargs=None)
Estimate total number of NUTS steps during warmup + sampling using NumPyro.
- Parameters:
- model
Model: PyMC model to estimate steps for, using a JAX/NumPyro NUTS kernel.
- tune
int|None, optional: Warmup iterations. Defaults to the estimator setting if None.
- draws
int|None, optional: Sampling iterations. Defaults to the estimator setting if None.
- seed
int|None, optional: Random seed for the JAX run. Defaults to the estimator setting if None.
- nuts_kwargs
dict|None, optional: Additional keyword arguments passed to numpyro.infer.NUTS. If not provided, the estimator's default_nuts_kwargs are used. Provided values override the defaults.
- mcmc_kwargs
dict|None, optional: Additional keyword arguments passed to numpyro.infer.MCMC (excluding num_warmup and num_samples, which are set by tune/draws). If not provided, the estimator's default_mcmc_kwargs are used. Provided values override the defaults.
- Returns:
int: Total number of leapfrog steps across warmup + sampling.
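The override rules for nuts_kwargs and mcmc_kwargs, and the meaning of the returned total, can be illustrated with a small pure-Python sketch. It is not the estimator's implementation: the helper names (merge_kwargs, build_mcmc_kwargs, total_steps) and the default values are hypothetical, and it assumes that "provided values override the defaults" means a per-key merge rather than replacing the whole dictionary.

```python
def merge_kwargs(defaults, provided):
    """Return a copy of defaults updated with provided values (provided wins)."""
    merged = dict(defaults or {})
    merged.update(provided or {})
    return merged


def build_mcmc_kwargs(defaults, provided, *, tune, draws):
    """Merge MCMC kwargs, then set num_warmup/num_samples from tune/draws."""
    merged = merge_kwargs(defaults, provided)
    # tune and draws control these two keys, so any supplied values are replaced.
    merged["num_warmup"] = tune
    merged["num_samples"] = draws
    return merged


def total_steps(warmup_steps, sampling_steps):
    """Sum per-iteration leapfrog step counts over warmup and sampling."""
    return int(sum(warmup_steps) + sum(sampling_steps))


# Hypothetical defaults, for illustration only.
default_nuts_kwargs = {"target_accept_prob": 0.8, "max_tree_depth": 10}
default_mcmc_kwargs = {"num_chains": 1, "progress_bar": False}

nuts_kwargs = merge_kwargs(default_nuts_kwargs, {"target_accept_prob": 0.9})
mcmc_kwargs = build_mcmc_kwargs(default_mcmc_kwargs, None, tune=1000, draws=500)

print(nuts_kwargs)
print(mcmc_kwargs)
print(total_steps([3, 7], [15, 7]))
```

In this sketch a caller who passes only target_accept_prob keeps the default max_tree_depth, and num_warmup/num_samples always follow tune/draws regardless of what mcmc_kwargs contains.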