Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions pymc/backends/mcbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
BlockedStep,
CompoundStep,
StatsBijection,
check_step_emits_tune,
flat_statname,
flatten_steps,
)
Expand Down Expand Up @@ -210,8 +209,6 @@ def make_runmeta_and_point_fn(
) -> tuple[mcb.RunMeta, PointFunc]:
variables, point_fn = get_variables_and_point_fn(model, initial_point)

check_step_emits_tune(step)
Copy link
Member

@ricardoV94 ricardoV94 Dec 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@michaelosthege what's the requirement here? Can we change things on mcbackend to not expect tune info attached to the steps?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we cannot change that: tune info must be attached to a step. See my comment in #7997 (comment).


# In PyMC the sampler stats are grouped by the sampler.
sample_stats = []
steps = flatten_steps(step)
Expand Down
16 changes: 4 additions & 12 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,18 +1043,10 @@ def _sample_return(
else:
traces, length = _choose_chains(traces, 0)
mtrace = MultiTrace(traces)[:length]
# count the number of tune/draw iterations that happened
# ideally via the "tune" statistic, but not all samplers record it!
if "tune" in mtrace.stat_names:
# Get the tune stat directly from chain 0, sampler 0
stat = mtrace._straces[0].get_sampler_stats("tune", sampler_idx=0)
stat = tuple(stat)
n_tune = stat.count(True)
n_draws = stat.count(False)
else:
# these may be wrong when KeyboardInterrupt happened, but they're better than nothing
n_tune = min(tune, len(mtrace))
n_draws = max(0, len(mtrace) - n_tune)
# Count the number of tune/draw iterations that happened.
# The warmup/draw boundary is owned by the sampling driver.
n_tune = min(tune, len(mtrace))
n_draws = max(0, len(mtrace) - n_tune)

if discard_tuned_samples:
mtrace = mtrace[n_tune:]
Expand Down
10 changes: 4 additions & 6 deletions pymc/step_methods/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def infer_warn_stats_info(
sds[sname] = (dtype, None)
elif sds:
stats_dtypes.append({sname: dtype for sname, (dtype, _) in sds.items()})

# Even when a step method does not emit any stats, downstream components
# still assume one stats "slot" per step method. Represent that with a
# single empty dict.
if not stats_dtypes:
stats_dtypes.append({})
return stats_dtypes, sds


Expand Down Expand Up @@ -352,12 +356,6 @@ def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]:


def check_step_emits_tune(step: CompoundStep | BlockedStep):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this ends up working without tune, remove this function as well

if isinstance(step, BlockedStep) and "tune" not in step.stats_dtypes_shapes:
raise TypeError(f"{type(step)} does not emit the required 'tune' stat.")
elif isinstance(step, CompoundStep):
for sstep in step.methods:
if "tune" not in sstep.stats_dtypes_shapes:
raise TypeError(f"{type(sstep)} does not emit the required 'tune' stat.")
return


Expand Down
1 change: 0 additions & 1 deletion pymc/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
self.iter_count += 1

stats: dict[str, Any] = {
"tune": self.tune,
"diverging": diverging,
"divergences": self.divergences,
"perf_counter_diff": perf_end - perf_start,
Expand Down
1 change: 0 additions & 1 deletion pymc/step_methods/hmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ class HamiltonianMC(BaseHMC):
stats_dtypes_shapes = {
"step_size": (np.float64, []),
"n_steps": (np.int64, []),
"tune": (bool, []),
"step_size_bar": (np.float64, []),
"accept": (np.float64, []),
"diverging": (bool, []),
Expand Down
1 change: 0 additions & 1 deletion pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ class NUTS(BaseHMC):
stats_dtypes_shapes = {
"depth": (np.int64, []),
"step_size": (np.float64, []),
"tune": (bool, []),
"mean_tree_accept": (np.float64, []),
"step_size_bar": (np.float64, []),
"tree_size": (np.float64, []),
Expand Down
31 changes: 6 additions & 25 deletions pymc/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ class Metropolis(ArrayStepShared):
stats_dtypes_shapes = {
"accept": (np.float64, []),
"accepted": (np.float64, []),
"tune": (bool, []),
"scaling": (np.float64, []),
}

Expand Down Expand Up @@ -316,7 +315,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
self.steps_until_tune -= 1

stats = {
"tune": self.tune,
"scaling": np.mean(self.scaling),
"accept": np.mean(np.exp(self.accept_rate_iter)),
"accepted": np.mean(self.accepted_iter),
Expand All @@ -331,15 +329,13 @@ def competence(var, has_grad):
@staticmethod
def _progressbar_config(n_chains=1):
columns = [
TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll need a work-around for this, as we still want to show whether we're in tuning or not in the progressbar

TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)),
TextColumn(
"{task.fields[accept_rate]:0.2f}", table_column=Column("Accept Rate", ratio=1)
),
]

stats = {
"tune": [True] * n_chains,
"scaling": [0] * n_chains,
"accept_rate": [0.0] * n_chains,
}
Expand All @@ -351,7 +347,7 @@ def _make_progressbar_update_functions():
def update_stats(step_stats):
return {
"accept_rate" if key == "accept" else key: step_stats[key]
for key in ("tune", "accept", "scaling")
for key in ("accept", "scaling")
}

return (update_stats,)
Expand Down Expand Up @@ -448,7 +444,6 @@ class BinaryMetropolis(ArrayStep):

stats_dtypes_shapes = {
"accept": (np.float64, []),
"tune": (bool, []),
"p_jump": (np.float64, []),
}

Expand Down Expand Up @@ -505,7 +500,6 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
self.accepted += accepted

stats = {
"tune": self.tune,
"accept": np.exp(accept),
"p_jump": p_jump,
}
Expand Down Expand Up @@ -574,9 +568,7 @@ class BinaryGibbsMetropolis(ArrayStep):

name = "binary_gibbs_metropolis"

stats_dtypes_shapes = {
"tune": (bool, []),
}
stats_dtypes_shapes = {}

_state_class = BinaryGibbsMetropolisState

Expand All @@ -594,8 +586,6 @@ def __init__(
):
model = pm.modelcontext(model)

# Doesn't actually tune, but it's required to emit a sampler stat
# that indicates whether a draw was done in a tuning phase.
self.tune = True
# transition probabilities
self.transit_p = transit_p
Expand Down Expand Up @@ -649,10 +639,7 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
if accepted:
logp_curr = logp_prop

stats = {
"tune": self.tune,
}
return q, [stats]
return q, [{}]

@staticmethod
def competence(var):
Expand Down Expand Up @@ -695,9 +682,7 @@ class CategoricalGibbsMetropolis(ArrayStep):

name = "categorical_gibbs_metropolis"

stats_dtypes_shapes = {
"tune": (bool, []),
}
stats_dtypes_shapes = {}

_state_class = CategoricalGibbsMetropolisState

Expand Down Expand Up @@ -793,7 +778,7 @@ def astep_unif(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType
logp_curr = logp_prop

# This step doesn't have any tunable parameters
return q, [{"tune": False}]
return q, [{}]

def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
logp = args[0]
Expand All @@ -811,7 +796,7 @@ def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType
logp_curr = self.metropolis_proportional(q, logp, logp_curr, dim, k)

# This step doesn't have any tunable parameters
return q, [{"tune": False}]
return q, [{}]

def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
raise NotImplementedError()
Expand Down Expand Up @@ -919,7 +904,6 @@ class DEMetropolis(PopulationArrayStepShared):
stats_dtypes_shapes = {
"accept": (np.float64, []),
"accepted": (bool, []),
"tune": (bool, []),
"scaling": (np.float64, []),
"lambda": (np.float64, []),
}
Expand Down Expand Up @@ -1011,7 +995,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
self.steps_until_tune -= 1

stats = {
"tune": self.tune,
"scaling": self.scaling,
"lambda": self.lamb,
"accept": np.exp(accept),
Expand Down Expand Up @@ -1090,7 +1073,6 @@ class DEMetropolisZ(ArrayStepShared):
stats_dtypes_shapes = {
"accept": (np.float64, []),
"accepted": (bool, []),
"tune": (bool, []),
"scaling": (np.float64, []),
"lambda": (np.float64, []),
}
Expand Down Expand Up @@ -1213,7 +1195,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
self.steps_until_tune -= 1

stats = {
"tune": self.tune,
"scaling": np.mean(self.scaling),
"lambda": self.lamb,
"accept": np.exp(accept),
Expand Down
7 changes: 2 additions & 5 deletions pymc/step_methods/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class Slice(ArrayStepShared):
name = "slice"
default_blocked = False
stats_dtypes_shapes = {
"tune": (bool, []),
"nstep_out": (int, []),
"nstep_in": (int, []),
}
Expand Down Expand Up @@ -184,7 +183,6 @@ def astep(self, apoint: RaveledVars) -> tuple[RaveledVars, StatsType]:
self.n_tunes += 1

stats = {
"tune": self.tune,
"nstep_out": nstep_out,
"nstep_in": nstep_in,
}
Expand All @@ -202,18 +200,17 @@ def competence(var, has_grad):
@staticmethod
def _progressbar_config(n_chains=1):
columns = [
TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
TextColumn("{task.fields[nstep_out]}", table_column=Column("Steps out", ratio=1)),
TextColumn("{task.fields[nstep_in]}", table_column=Column("Steps in", ratio=1)),
]

stats = {"tune": [True] * n_chains, "nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains}
stats = {"nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains}

return columns, stats

@staticmethod
def _make_progressbar_update_functions():
def update_stats(step_stats):
return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}}
return {key: step_stats[key] for key in {"nstep_out", "nstep_in"}}

return (update_stats,)
8 changes: 3 additions & 5 deletions tests/backends/test_mcbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,11 @@ def test_return_multitrace(self, simple_model, discard_warmup):
return_inferencedata=False,
)
assert isinstance(mtrace, pm.backends.base.MultiTrace)
tune = mtrace._straces[0].get_sampler_stats("tune")
assert isinstance(tune, np.ndarray)
# Warmup length is tracked by the sampling driver, not via per-draw sampler stats.
if discard_warmup:
assert tune.shape == (7, 3)
assert len(mtrace) == 7
else:
assert tune.shape == (12, 3)
pass
assert len(mtrace) == 12

@pytest.mark.parametrize("cores", [1, 3])
def test_return_inferencedata(self, simple_model, cores):
Expand Down
27 changes: 27 additions & 0 deletions tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,33 @@ def test_sample_return_lengths(self):
assert idata.posterior.sizes["draw"] == 100
assert idata.posterior.sizes["chain"] == 3

def test_categorical_gibbs_respects_driver_tune_boundary(self):
# Regression test for the removal of the per-draw "tune" sampler stat:
# CategoricalGibbsMetropolis no longer emits it, so the tune/draw split
# reported on the trace must come from the sampling driver's `tune`
# argument alone (n_tune = min(tune, len(trace))).
# NOTE(review): indentation was flattened in this capture — presumably
# everything below runs inside the model context; confirm against the
# original test file.
with pm.Model():
pm.Categorical("x", p=np.array([0.2, 0.3, 0.5]))
# Small, fixed-seed, single-chain run so the expected lengths are exact.
sample_kwargs = {
"tune": 5,
"draws": 7,
"chains": 1,
"cores": 1,
"return_inferencedata": False,
"compute_convergence_checks": False,
"progressbar": False,
"random_seed": 123,
}
with warnings.catch_warnings():
# The tiny draw count triggers a "number of samples" UserWarning
# that is irrelevant to what this test checks.
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
mtrace = pm.sample(discard_tuned_samples=True, **sample_kwargs)
# With warmup discarded, only the 7 post-tune draws remain, yet the
# report still records the driver-side tune/draw split.
assert len(mtrace) == 7
assert mtrace.report.n_tune == 5
assert mtrace.report.n_draws == 7
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
# Keeping warmup draws emits a "will be included" UserWarning.
with pytest.warns(UserWarning, match="will be included"):
mtrace_warmup = pm.sample(discard_tuned_samples=False, **sample_kwargs)
# With warmup kept, the trace holds tune + draws = 5 + 7 = 12 iterations,
# while the reported split is unchanged.
assert len(mtrace_warmup) == 12
assert mtrace_warmup.report.n_tune == 5
assert mtrace_warmup.report.n_draws == 7

@pytest.mark.parametrize("cores", [1, 2])
def test_logs_sampler_warnings(self, caplog, cores):
"""Asserts that "warning" sampler stats are logged during sampling."""
Expand Down
Loading