From 0e890b0aafa9930cde57a6d0174c1f4ca473b50b Mon Sep 17 00:00:00 2001
From: eclipse1605
Date: Sat, 20 Dec 2025 16:49:57 +0530
Subject: [PATCH] Move warmup/draw bookkeeping to the sampling driver

---
 pymc/backends/mcbackend.py        |  3 ---
 pymc/sampling/mcmc.py             | 16 ++++------------
 pymc/step_methods/compound.py     | 10 ++++------
 pymc/step_methods/hmc/base_hmc.py |  1 -
 pymc/step_methods/hmc/hmc.py      |  1 -
 pymc/step_methods/hmc/nuts.py     |  1 -
 pymc/step_methods/metropolis.py   | 31 ++++++-------------------------
 pymc/step_methods/slicer.py       |  7 ++-----
 tests/backends/test_mcbackend.py  |  8 +++-----
 tests/sampling/test_mcmc.py       | 27 +++++++++++++++++++++++++++
 10 files changed, 46 insertions(+), 59 deletions(-)

diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py
index d02a6dbebb..e89ac19cf2 100644
--- a/pymc/backends/mcbackend.py
+++ b/pymc/backends/mcbackend.py
@@ -34,7 +34,6 @@
     BlockedStep,
     CompoundStep,
     StatsBijection,
-    check_step_emits_tune,
     flat_statname,
     flatten_steps,
 )
@@ -210,8 +209,6 @@ def make_runmeta_and_point_fn(
 ) -> tuple[mcb.RunMeta, PointFunc]:
     variables, point_fn = get_variables_and_point_fn(model, initial_point)
 
-    check_step_emits_tune(step)
-
     # In PyMC the sampler stats are grouped by the sampler.
     sample_stats = []
     steps = flatten_steps(step)
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index de341c68cd..949235bc76 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -1043,18 +1043,10 @@ def _sample_return(
     else:
         traces, length = _choose_chains(traces, 0)
         mtrace = MultiTrace(traces)[:length]
-        # count the number of tune/draw iterations that happened
-        # ideally via the "tune" statistic, but not all samplers record it!
-        if "tune" in mtrace.stat_names:
-            # Get the tune stat directly from chain 0, sampler 0
-            stat = mtrace._straces[0].get_sampler_stats("tune", sampler_idx=0)
-            stat = tuple(stat)
-            n_tune = stat.count(True)
-            n_draws = stat.count(False)
-        else:
-            # these may be wrong when KeyboardInterrupt happened, but they're better than nothing
-            n_tune = min(tune, len(mtrace))
-            n_draws = max(0, len(mtrace) - n_tune)
+        # Count the number of tune/draw iterations that happened.
+        # The warmup/draw boundary is owned by the sampling driver.
+        n_tune = min(tune, len(mtrace))
+        n_draws = max(0, len(mtrace) - n_tune)
 
     if discard_tuned_samples:
         mtrace = mtrace[n_tune:]
diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py
index a9cae903f0..5984446c94 100644
--- a/pymc/step_methods/compound.py
+++ b/pymc/step_methods/compound.py
@@ -92,6 +92,10 @@ def infer_warn_stats_info(
                 sds[sname] = (dtype, None)
     elif sds:
         stats_dtypes.append({sname: dtype for sname, (dtype, _) in sds.items()})
+
+    # Downstream code expects one stats "slot" per step method, so fall back to an empty dict.
+ if not stats_dtypes: + stats_dtypes.append({}) return stats_dtypes, sds @@ -352,12 +356,6 @@ def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]: def check_step_emits_tune(step: CompoundStep | BlockedStep): - if isinstance(step, BlockedStep) and "tune" not in step.stats_dtypes_shapes: - raise TypeError(f"{type(step)} does not emit the required 'tune' stat.") - elif isinstance(step, CompoundStep): - for sstep in step.methods: - if "tune" not in sstep.stats_dtypes_shapes: - raise TypeError(f"{type(sstep)} does not emit the required 'tune' stat.") return diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index 297b095e23..c3e6d75e5c 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -273,7 +273,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: self.iter_count += 1 stats: dict[str, Any] = { - "tune": self.tune, "diverging": diverging, "divergences": self.divergences, "perf_counter_diff": perf_end - perf_start, diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index 1697341bc8..57fd5219b1 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -53,7 +53,6 @@ class HamiltonianMC(BaseHMC): stats_dtypes_shapes = { "step_size": (np.float64, []), "n_steps": (np.int64, []), - "tune": (bool, []), "step_size_bar": (np.float64, []), "accept": (np.float64, []), "diverging": (bool, []), diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index c927d57e31..f674e852ee 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -110,7 +110,6 @@ class NUTS(BaseHMC): stats_dtypes_shapes = { "depth": (np.int64, []), "step_size": (np.float64, []), - "tune": (bool, []), "mean_tree_accept": (np.float64, []), "step_size_bar": (np.float64, []), "tree_size": (np.float64, []), diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index c042bc1f3d..b371f6dd48 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -146,7 +146,6 @@ class Metropolis(ArrayStepShared): stats_dtypes_shapes = { "accept": (np.float64, []), "accepted": (np.float64, []), - "tune": (bool, []), "scaling": (np.float64, []), } @@ -316,7 +315,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: self.steps_until_tune -= 1 stats = { - "tune": self.tune, "scaling": np.mean(self.scaling), "accept": np.mean(np.exp(self.accept_rate_iter)), "accepted": np.mean(self.accepted_iter), @@ -331,7 +329,6 @@ def competence(var, has_grad): @staticmethod def _progressbar_config(n_chains=1): columns = [ - TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)), TextColumn( "{task.fields[accept_rate]:0.2f}", table_column=Column("Accept Rate", ratio=1) @@ -339,7 +336,6 @@ def _progressbar_config(n_chains=1): ] stats = { - "tune": [True] * n_chains, "scaling": [0] * n_chains, "accept_rate": [0.0] * n_chains, } @@ -351,7 +347,7 @@ def _make_progressbar_update_functions(): def update_stats(step_stats): return { "accept_rate" if key == "accept" else key: step_stats[key] - for key in ("tune", "accept", "scaling") + for key in ("accept", "scaling") } return (update_stats,) @@ -448,7 +444,6 @@ class BinaryMetropolis(ArrayStep): stats_dtypes_shapes = { "accept": (np.float64, []), - "tune": (bool, []), "p_jump": (np.float64, []), } @@ -505,7 +500,6 @@ def astep(self, apoint: RaveledVars, *args) -> 
tuple[RaveledVars, StatsType]: self.accepted += accepted stats = { - "tune": self.tune, "accept": np.exp(accept), "p_jump": p_jump, } @@ -574,9 +568,7 @@ class BinaryGibbsMetropolis(ArrayStep): name = "binary_gibbs_metropolis" - stats_dtypes_shapes = { - "tune": (bool, []), - } + stats_dtypes_shapes = {} _state_class = BinaryGibbsMetropolisState @@ -594,8 +586,6 @@ def __init__( ): model = pm.modelcontext(model) - # Doesn't actually tune, but it's required to emit a sampler stat - # that indicates whether a draw was done in a tuning phase. self.tune = True # transition probabilities self.transit_p = transit_p @@ -649,10 +639,7 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: if accepted: logp_curr = logp_prop - stats = { - "tune": self.tune, - } - return q, [stats] + return q, [{}] @staticmethod def competence(var): @@ -695,9 +682,7 @@ class CategoricalGibbsMetropolis(ArrayStep): name = "categorical_gibbs_metropolis" - stats_dtypes_shapes = { - "tune": (bool, []), - } + stats_dtypes_shapes = {} _state_class = CategoricalGibbsMetropolisState @@ -793,7 +778,7 @@ def astep_unif(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType logp_curr = logp_prop # This step doesn't have any tunable parameters - return q, [{"tune": False}] + return q, [{}] def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: logp = args[0] @@ -811,7 +796,7 @@ def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType logp_curr = self.metropolis_proportional(q, logp, logp_curr, dim, k) # This step doesn't have any tunable parameters - return q, [{"tune": False}] + return q, [{}] def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]: raise NotImplementedError() @@ -919,7 +904,6 @@ class DEMetropolis(PopulationArrayStepShared): stats_dtypes_shapes = { "accept": (np.float64, []), "accepted": (bool, []), - "tune": (bool, []), "scaling": (np.float64, []), "lambda": (np.float64, []), } @@ -1011,7 +995,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: self.steps_until_tune -= 1 stats = { - "tune": self.tune, "scaling": self.scaling, "lambda": self.lamb, "accept": np.exp(accept), @@ -1090,7 +1073,6 @@ class DEMetropolisZ(ArrayStepShared): stats_dtypes_shapes = { "accept": (np.float64, []), "accepted": (bool, []), - "tune": (bool, []), "scaling": (np.float64, []), "lambda": (np.float64, []), } @@ -1213,7 +1195,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: self.steps_until_tune -= 1 stats = { - "tune": self.tune, "scaling": np.mean(self.scaling), "lambda": self.lamb, "accept": np.exp(accept), diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 180ac1c882..5ea92fc916 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -72,7 +72,6 @@ class Slice(ArrayStepShared): name = "slice" default_blocked = False stats_dtypes_shapes = { - "tune": (bool, []), "nstep_out": (int, []), "nstep_in": (int, []), } @@ -184,7 +183,6 @@ def astep(self, apoint: RaveledVars) -> tuple[RaveledVars, StatsType]: self.n_tunes += 1 stats = { - "tune": self.tune, "nstep_out": nstep_out, "nstep_in": nstep_in, } @@ -202,18 +200,17 @@ def competence(var, has_grad): @staticmethod def _progressbar_config(n_chains=1): columns = [ - TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), TextColumn("{task.fields[nstep_out]}", table_column=Column("Steps out", ratio=1)), TextColumn("{task.fields[nstep_in]}", table_column=Column("Steps in", 
ratio=1)), ] - stats = {"tune": [True] * n_chains, "nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains} + stats = {"nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains} return columns, stats @staticmethod def _make_progressbar_update_functions(): def update_stats(step_stats): - return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}} + return {key: step_stats[key] for key in {"nstep_out", "nstep_in"}} return (update_stats,) diff --git a/tests/backends/test_mcbackend.py b/tests/backends/test_mcbackend.py index e72731af6b..64ad927454 100644 --- a/tests/backends/test_mcbackend.py +++ b/tests/backends/test_mcbackend.py @@ -293,13 +293,11 @@ def test_return_multitrace(self, simple_model, discard_warmup): return_inferencedata=False, ) assert isinstance(mtrace, pm.backends.base.MultiTrace) - tune = mtrace._straces[0].get_sampler_stats("tune") - assert isinstance(tune, np.ndarray) + # warmup is tracked by the sampling driver if discard_warmup: - assert tune.shape == (7, 3) + assert len(mtrace) == 7 else: - assert tune.shape == (12, 3) - pass + assert len(mtrace) == 12 @pytest.mark.parametrize("cores", [1, 3]) def test_return_inferencedata(self, simple_model, cores): diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 090b76130b..fcacad7a95 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -414,6 +414,33 @@ def test_sample_return_lengths(self): assert idata.posterior.sizes["draw"] == 100 assert idata.posterior.sizes["chain"] == 3 + def test_categorical_gibbs_respects_driver_tune_boundary(self): + with pm.Model(): + pm.Categorical("x", p=np.array([0.2, 0.3, 0.5])) + sample_kwargs = { + "tune": 5, + "draws": 7, + "chains": 1, + "cores": 1, + "return_inferencedata": False, + "compute_convergence_checks": False, + "progressbar": False, + "random_seed": 123, + } + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) + mtrace = pm.sample(discard_tuned_samples=True, **sample_kwargs) + assert len(mtrace) == 7 + assert mtrace.report.n_tune == 5 + assert mtrace.report.n_draws == 7 + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) + with pytest.warns(UserWarning, match="will be included"): + mtrace_warmup = pm.sample(discard_tuned_samples=False, **sample_kwargs) + assert len(mtrace_warmup) == 12 + assert mtrace_warmup.report.n_tune == 5 + assert mtrace_warmup.report.n_draws == 7 + @pytest.mark.parametrize("cores", [1, 2]) def test_logs_sampler_warnings(self, caplog, cores): """Asserts that "warning" sampler stats are logged during sampling."""
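
Illustrative sketch: with the per-draw "tune" stat removed, the warmup/draw split is
derived only from the requested tune count and the recorded trace length, exactly as in
the _sample_return hunk above. The helper below is hypothetical (split_warmup is not a
PyMC function) and only restates that arithmetic, including the interrupted-run case:

def split_warmup(tune: int, trace_length: int) -> tuple[int, int]:
    # Clamp the warmup count to the number of iterations that actually happened;
    # everything past the warmup boundary counts as a posterior draw.
    n_tune = min(tune, trace_length)
    n_draws = max(0, trace_length - n_tune)
    return n_tune, n_draws


# Completed run: 5 warmup iterations + 7 draws recorded.
assert split_warmup(tune=5, trace_length=12) == (5, 7)
# Interrupted during warmup: only 3 iterations recorded, all counted as warmup.
assert split_warmup(tune=5, trace_length=3) == (3, 0)
# Interrupted during the draw phase: 5 warmup iterations + 2 draws recorded.
assert split_warmup(tune=5, trace_length=7) == (5, 2)

This is the same fallback _sample_return previously used for samplers that recorded no
"tune" stat; the new test pins the boundary down for CategoricalGibbsMetropolis, which
used to report "tune": False for every iteration.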