"""CachedLifecycle — cell-level caching as setup/teardown.

On setup: hash the cell + consult the lazy store.

  - HIT  → restore defs into globals, return `Skip` with the cached
           return value so the Evaluator short-circuits the body.
  - MISS → pre-flight `cell.refs`: if any ref in `glbls` is an
           `UnhashableStub` (a stale placeholder left over from an
           upstream cached producer whose value resisted serialization),
           invalidate the producer's manifest and raise
           `MarimoCancelCellError(cells_to_rerun=producers ∪ self)`.
           The runner catches the signal and hands it to
           `Scheduler.requeue_for_rerun`; the body never runs this turn.
           If no producing cell can be resolved for the stub, nothing is
           requeued (a self-only requeue would trip again forever) and
           the body runs. Otherwise, fall through and let the body run.

On teardown: backfill on successful miss; drop the attempt on a raised
body. No body-trip handling — the pre-flight at setup catches it
strictly earlier, against the *refs* rather than reacting to a partial
body's runtime access.
"""

# Loader backing cell-level caching, resolved through the persistent-
# loader registry so improvements to the lazy per-def loader arrive
# transparently.
DEFAULT_CELL_LOADER: LoaderKey = "lazy"


def _is_unhashable_stub(value: object) -> bool:
    """Duck-typed check for serialization-failure placeholders.

    Loaders that cannot serialize a def restore a stub carrying the
    class-level `__marimo_unhashable__` marker. Detecting the protocol
    attribute (rather than importing the stub class) keeps the runtime
    lifecycle and the serialization toolkit independently mergeable.

    The marker is looked up on `type(value)`, so an attribute set on an
    instance does not count, and only the literal `True` matches (a
    merely truthy marker such as `1` does not).
    """
    return getattr(type(value), "__marimo_unhashable__", False) is True


class CachedLifecycle:
    """Skip cell exec on cache hit; backfill cell results on miss success.

    Inner wrap relative to `StrictLifecycle`: when both are configured, the
    Evaluator runs Strict.setup → Cached.setup → body → Cached.teardown →
    Strict.teardown. Caching sees a sanitized scope (Strict already ran),
    and Strict's restore happens after Cached's backfill (so the cache
    captures the cell's real defs).
    """

    name = "cached"

    def __init__(
        self,
        graph: DirectedGraph,
        pin_modules: bool = True,
        loader: LoaderKey = DEFAULT_CELL_LOADER,
    ) -> None:
        """Bind the lifecycle to a dataflow graph and build its loader.

        Args:
            graph: Graph used for hashing and for resolving which cells
                define a stubbed ref.
            pin_modules: Forwarded to `cache_attempt_from_hash`.
            loader: Key into `PERSISTENT_LOADERS` selecting the loader
                type. The loader instance itself is always created with
                `name="lazy"`.
        """
        self._graph = graph
        self._pin_modules = pin_modules
        # The persistent loaders are all BasePersistenceLoader subclasses
        # (which carry `.store`); the registry is typed as the `Loader` base.
        self._loader = cast(
            BasePersistenceLoader,
            resolve_loader(PERSISTENT_LOADERS[loader])(name="lazy"),
        )
        # Per-cell state — populated in setup, consumed in teardown.
        self._attempts: dict[CellId_t, Cache] = {}
        self._exec_starts: dict[CellId_t, float] = {}
        # Per-cell manifest path, recorded on hit/save. Consumed when
        # this cell's pre-flight invalidates an upstream producer.
        self._manifest_keys: dict[CellId_t, str] = {}

    def setup(self, cell: CellImpl, glbls: MutableGlobals) -> Skip | None:
        """Consult the cache for `cell`; skip the body on a usable hit.

        Returns:
            A `Skip` carrying the cached return value on a usable hit,
            or `None` when the body must run.

        Raises:
            MarimoCancelCellError: from the pre-flight, when a ref of
                this cell is an unhashable stub whose producer must
                re-run first.
        """
        cell_id = cell.cell_id

        attempt = cache_attempt_from_hash(
            cell.mod,
            self._graph,
            cell_id,
            glbls,
            loader=self._loader,
            pin_modules=self._pin_modules,
        )
        self._attempts[cell_id] = attempt

        if attempt.hit:
            try:
                attempt.restore(glbls)
            except Exception as e:
                LOGGER.warning("Cache restore failed for %s: %s", cell_id, e)
                self._attempts.pop(cell_id, None)
                # Fall through to miss-path execution.
            else:
                if self._restored_ui_defs(attempt, glbls):
                    # A restored UIElement carries a fresh object id while
                    # the cached output HTML embeds the saving session's —
                    # the frontend and kernel would disagree and events go
                    # nowhere. UI construction is cheap and inherently
                    # session state: run the cell live instead.
                    LOGGER.debug(
                        "Cache hit for %s defines UI elements; running "
                        "live to register them with this session",
                        cell_id,
                    )
                    self._attempts.pop(cell_id, None)
                    # Fall through to miss-path execution.
                else:
                    self._manifest_keys[cell_id] = str(
                        self._loader.build_path(attempt.key)
                    )
                    return Skip(
                        result=RunResult(
                            output=attempt.meta.get("return"), exception=None
                        )
                    )

        # Pre-flight refs against UnhashableStubs left in scope by upstream
        # cached producers. Raises MarimoCancelCellError when a stub ref
        # has a producer to re-run — the body never runs this turn.
        self._preflight_refs(cell, glbls)

        self._exec_starts[cell_id] = time.time()
        return None

    def teardown(
        self,
        cell: CellImpl,
        glbls: MutableGlobals,
        run_result: RunResult,
    ) -> None:
        """Backfill the cache after a successful miss; otherwise no-op.

        Always clears this cell's per-run state. Nothing is saved when no
        attempt is on record, when the attempt was a hit, or when the
        body raised.
        """
        cell_id = cell.cell_id
        attempt = self._attempts.pop(cell_id, None)
        started_at = self._exec_starts.pop(cell_id, None)

        if attempt is None:
            return
        if attempt.hit:
            return
        if run_result.exception is not None:
            return

        # Explicit None check: testing the float for truthiness and
        # subtracting two separate `time.time()` calls could record a
        # slightly negative runtime when the start time is missing.
        runtime = 0.0 if started_at is None else time.time() - started_at
        try:
            attempt.update(
                {**glbls},
                meta={
                    "return": run_result.output,
                    "runtime": runtime,
                },
                preserve_pointers=False,
            )
            if self._loader.save_cache(attempt):
                self._manifest_keys[cell_id] = str(
                    self._loader.build_path(attempt.key)
                )
        except BaseException as e:
            # Best-effort: save failures (incl. CacheException, which
            # extends BaseException) must never break the teardown chain.
            LOGGER.warning("Cache save failed for %s: %s", cell_id, e)

    @staticmethod
    def _restored_ui_defs(attempt: Cache, glbls: MutableGlobals) -> bool:
        """True if any def restored from the cache is a live UIElement."""
        from marimo._plugins.ui._core.ui_element import UIElement

        return any(
            isinstance(glbls.get(name), UIElement) for name in attempt.defs
        )

    def _preflight_refs(self, cell: CellImpl, glbls: MutableGlobals) -> None:
        """Detect UnhashableStub residues in refs; requeue producers.

        Walks `cell.refs` and checks each name in `glbls` for an
        unhashable stub — a placeholder left behind by an upstream cached
        cell whose def couldn't be serialized. If any are found, this
        cell's attempt is dropped so teardown won't backfill. Then:

          - when at least one producing cell resolves, each producer's
            recorded manifest is invalidated and `MarimoCancelCellError`
            is raised with `cells_to_rerun = producers ∪ {this cell}`;
          - when none resolves, a warning is logged and the method
            returns so the body runs. Requeuing only this cell would
            leave the stub in scope and trip again on every retry.

        Cheap top-level scan: only directly-stub refs trip. Stubs
        embedded inside other values are not detected here — those would
        surface during body execution as a `MarimoUnhashableCacheError`
        from the stub's `__call__`.

        Raises:
            MarimoCancelCellError: when a stub ref has a producer.
        """
        cell_id = cell.cell_id
        # `.get` already yields None for names missing from scope.
        stub_vars = [
            ref for ref in cell.refs if _is_unhashable_stub(glbls.get(ref))
        ]
        if not stub_vars:
            return

        producers: set[CellId_t] = set()
        for var_name in stub_vars:
            try:
                producers.update(self._graph.get_defining_cells(var_name))
            except KeyError:
                # The graph knows no cell defining this name.
                continue
        producers.discard(cell_id)

        # Drop our own attempt — teardown must not backfill against a
        # scope that still holds (or held) a stub.
        self._attempts.pop(cell_id, None)
        self._exec_starts.pop(cell_id, None)

        if not producers:
            LOGGER.warning(
                "Pre-flight for %s found stub refs %s but no producing "
                "cell; running the body without a requeue",
                cell_id,
                stub_vars,
            )
            return

        # NOTE(review): `_invalidate` only clears manifests recorded by
        # this lifecycle instance; a producer whose manifest was never
        # recorded here may hit its cache again on rerun — confirm.
        for producer_id in producers:
            self._invalidate(producer_id)

        LOGGER.info(
            "Pre-flight requeue for %s: stub refs %s; producers %s",
            cell_id,
            stub_vars,
            producers,
        )
        raise MarimoCancelCellError(cells_to_rerun=producers | {cell_id})

    def _invalidate(self, cell_id: CellId_t) -> None:
        """Delete the recorded manifest for `cell_id` (if any).

        The recorded key is forgotten even when clearing the store
        fails; the failure is logged, not raised.
        """
        key = self._manifest_keys.pop(cell_id, None)
        if key is None:
            return
        try:
            self._loader.store.clear(key)
        except Exception as e:
            LOGGER.warning("Manifest clear failed for %s: %s", key, e)
+ pin_modules=bool( + user_config.get("runtime", {}).get("pin_modules", True) + ), + ) + ) self._evaluator = Evaluator( executor=resolve_executor(), lifecycles=lifecycles ) @@ -417,6 +440,26 @@ async def evaluate_interruptible(self, cell: CellImpl) -> RunResult: # rather than escaping to the broad except below. return RunResult(output=None, exception=exc) + @staticmethod + def _log_internal_error() -> None: + """Defensive: an unexpected escape from the Evaluator or a bug in + `_finalize_run_result` would otherwise tear down the runner loop. + Log a report and let the caller degrade to an empty RunResult.""" + LOGGER.error( + """marimo encountered an internal error. + + marimo finished executing a cell, but did not produce + a run result. + + Please copy this message and paste it in a GitHub issue: + + https://github.com/marimo-team/marimo/issues + + Any additional context of what caused this error, such + as sample code to reproduce, will help us debug. + """ + ) + async def run(self, cell_id: CellId_t) -> RunResult: """Run a cell.""" if self.debugger is not None: @@ -430,25 +473,22 @@ async def run(self, cell_id: CellId_t) -> RunResult: # effects are applied below in `_finalize_run_result`. try: raw_result = await self.evaluate_interruptible(cell) - run_result = self._finalize_run_result(raw_result, cell_id) except BaseException: - # Defensive: an unexpected escape from the Evaluator or a bug - # in `_finalize_run_result` would otherwise tear down the - # runner loop. Degrade gracefully with an empty RunResult. - LOGGER.error( - """marimo encountered an internal error. - - marimo finished executing a cell, but did not produce - a run result. + self._log_internal_error() + raw_result = RunResult(output=None, exception=None) - Please copy this message and paste it in a GitHub issue: + # Soft-cancel control signal from a lifecycle (e.g. CachedLifecycle + # tripping on a stale UnhashableStub ref): propagate to `run_all` + # so it can requeue the producing cells. 
Lifecycle teardown already + # ran inside the Evaluator; this is not a cell error, so it is not + # classified or recorded here. + if isinstance(raw_result.exception, MarimoCancelCellError): + raise raw_result.exception - https://github.com/marimo-team/marimo/issues - - Any additional context of what caused this error, such - as sample code to reproduce, will help us debug. - """ - ) + try: + run_result = self._finalize_run_result(raw_result, cell_id) + except BaseException: + self._log_internal_error() run_result = RunResult(output=None, exception=None) # Mark as interrupted if the cell raised a MarimoInterrupt @@ -807,4 +847,24 @@ async def _dispatch_runnable( cell.set_run_result_status("cancelled") cell.set_runtime_state("idle") continue - await self._run_one(cell_id, pre_exec_ctx, post_exec_ctx) + try: + await self._run_one(cell_id, pre_exec_ctx, post_exec_ctx) + except MarimoCancelCellError as e: + # Soft-cancel: a lifecycle (e.g. CachedLifecycle hitting + # a stale UnhashableStub) asked to re-run producer cells + # so they emit real values. Requeue them at the head of + # the queue; this cell retries after they run. Not a + # cell error — don't classify or record it. + LOGGER.debug( + "Soft-cancel for %s; requeuing %s", + cell_id, + e.cells_to_rerun, + ) + # The pre-execution hook set this cell's runtime_state to + # "running"; the soft-cancel raised out of `_run_one` + # before the post-execution hook could reset it. Mark + # every requeued cell "queued" so none lingers in + # "running" while its producers re-execute. 
def discard(self, cell_id: CellId_t) -> None:
    """Un-cancel a cell (as both a raiser and a descendant).

    Used by `Scheduler.requeue_for_rerun` to clear cancellation
    before re-running a cell.

    Removes `cell_id` from the flat membership set, drops any entry it
    owns as a raising cell, and strips it from every remaining raiser's
    descendant set. Those descendant sets are edited in place. Cells
    that were recorded only under the dropped raiser entry are left in
    the membership set. Unknown ids are ignored.
    """
    by_raiser = self._by_raising_cell
    # Drop the entry this cell owns as a raiser first, so the sweep
    # below only visits the entries that remain.
    by_raiser.pop(cell_id, None)
    self._all.discard(cell_id)
    for tracked in by_raiser.values():
        tracked.discard(cell_id)
def requeue_for_rerun(self, cells: set[CellId_t]) -> None:
    """Soft-cancel: put `cells` back at the head of the queue.

    Called when a lifecycle raises
    `MarimoCancelCellError(cells_to_rerun=...)` — e.g. `CachedLifecycle`
    signaling that producers need their bodies to actually run.

    Each cell is un-cancelled and moved to the front of the run queue in
    **topological order** (producers before consumers), so producers
    emit real values before the consumer that tripped retries. `cells`
    is a set and carries no ordering of its own; a consumer requeued
    ahead of its producer would trip on the same stale value again.
    """
    queue = self._cells_to_run
    topo_order = dataflow.topological_sort(self._graph, cells)
    # `appendleft` reverses insertion order, so walk the sorted cells
    # back-to-front to leave the topologically-first cell at the head.
    for cell_id in reversed(topo_order):
        self._cancelled.discard(cell_id)
        # Move rather than duplicate: a cell already queued is dropped
        # from its old position before being prepended. `remove()`
        # raises on a missing item, hence the membership guard.
        if cell_id in queue:
            queue.remove(cell_id)
        queue.appendleft(cell_id)
+""" + +from __future__ import annotations + +import copy +import dataclasses +from typing import TYPE_CHECKING + +import pytest + +from marimo._runtime.exceptions import ( + MarimoCancelCellError, +) +from marimo._runtime.executor.lifecycles.cached import CachedLifecycle +from marimo._save.loaders.lazy import LazyLoader + +try: + # Ships with the stub serialization toolkit; the lifecycle detects + # stubs through the __marimo_unhashable__ protocol attribute and has + # no hard dependency on the class. + from marimo._save.stubs.lazy_stub import UnhashableStub +except ImportError: # pragma: no cover + UnhashableStub = None # type: ignore[assignment] + +# The end-to-end tripwire tests additionally need the lazy loader that +# *produces* UnhashableStub on serialization failure. +try: + from marimo._save.loaders.lazy import LazyStore as _LazyStore +except ImportError: # pragma: no cover + _LazyStore = None # type: ignore[assignment] + +requires_stub_loader = pytest.mark.skipif( + UnhashableStub is None or _LazyStore is None, + reason="needs the stub toolkit + per-def lazy store", +) + + +@dataclasses.dataclass +class _MarkerStub: + """Minimal stand-in carrying the unhashable-stub protocol marker.""" + + # Class-level protocol marker (no annotation → not a dataclass field). 
+ __marimo_unhashable__ = True + + var_name: str + + +if TYPE_CHECKING: + from pathlib import Path + + from tests.conftest import ExecReqProvider, MockedKernel + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def cache_dir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + """Redirect FileStore's default save path to a tmp dir for the test.""" + cache_path = tmp_path / "cache" + cache_path.mkdir() + + def _default_save_path(_self: object) -> Path: + return cache_path + + monkeypatch.setattr( + "marimo._save.stores.file.FileStore._default_save_path", + _default_save_path, + ) + return cache_path + + +@pytest.fixture +def tracked_loaders(monkeypatch: pytest.MonkeyPatch) -> list[LazyLoader]: + """Capture every LazyLoader instance constructed during the test. + + Lifecycles live on each Runner instance and are GC'd between runs, so + there's no kernel-level handle on the loader. Tracking via __init__ + lets tests call .flush() on every loader to drain background save + threads deterministically (instead of sleeping). + """ + instances: list[LazyLoader] = [] + original_init = LazyLoader.__init__ + + def _tracking_init( + self: LazyLoader, + *args: object, + **kwargs: object, + ) -> None: + original_init(self, *args, **kwargs) # type: ignore[arg-type] + instances.append(self) + + monkeypatch.setattr(LazyLoader, "__init__", _tracking_init) + return instances + + +@pytest.fixture +def caching_kernel( + mocked_kernel: MockedKernel, + cache_dir: Path, # noqa: ARG001 — needed for side effect + tracked_loaders: list[LazyLoader], # noqa: ARG001 — needed for side effect +) -> MockedKernel: + """A kernel with cache_cells enabled and a tmp cache dir.""" + # Deep-copy so we don't mutate the shared DEFAULT_CONFIG dict. 
+ mocked_kernel.k.user_config = copy.deepcopy(mocked_kernel.k.user_config) + mocked_kernel.k.user_config["runtime"]["cache_cells"] = True + return mocked_kernel + + +# --------------------------------------------------------------------------- +# CachedLifecycle._preflight_refs — stub-ref detection routes to requeue +# --------------------------------------------------------------------------- + + +class TestCachedLifecyclePreflight: + def test_stub_ref_invalidates_producer_and_raises( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """When a consumer's transitive ref resolves to an UnhashableStub + in scope, pre-flight invalidates the producer's recorded manifest + and raises MarimoCancelCellError with cells_to_rerun populated, so + run_all can requeue the producer (plus this cell). + """ + from unittest.mock import MagicMock + + graph = MagicMock() + graph.get_defining_cells.return_value = {"producer"} + + life = CachedLifecycle(graph) + producer_manifest = "lazy/E_producer.jsonl" + life._manifest_keys["producer"] = producer_manifest + + clear_calls: list[str] = [] + + def _spy_clear(key: str) -> bool: + clear_calls.append(key) + return True + + monkeypatch.setattr(life._loader.store, "clear", _spy_clear) + + cell = _FakeCell("consumer", refs={"f"}) + glbls = {"f": _MarkerStub("f")} + + with pytest.raises(MarimoCancelCellError) as ei: + life._preflight_refs(cell, glbls) # type: ignore[arg-type] + + assert clear_calls == [producer_manifest] + assert {"producer", "consumer"} <= ei.value.cells_to_rerun + + def test_no_stub_refs_is_noop(self) -> None: + """Pre-flight returns cleanly when no ref is an UnhashableStub.""" + from unittest.mock import MagicMock + + life = CachedLifecycle(MagicMock()) + cell = _FakeCell("consumer", refs={"x"}) + # No exception expected. 
+ life._preflight_refs(cell, {"x": 123}) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# Integration tests — full kernel with cache_cells enabled +# --------------------------------------------------------------------------- + + +class TestCachedLifecycleIntegration: + async def test_basic_hit_miss_cycle( + self, + caching_kernel: MockedKernel, + exec_req: ExecReqProvider, + tracked_loaders: list[LazyLoader], + ) -> None: + """First run misses + executes; second run with same code hits.""" + k = caching_kernel.k + er = exec_req.get(code="x = 1 + 2") + + await k.run([er]) + assert k.globals["x"] == 3 + + for loader in tracked_loaders: + loader.flush() + + loaders_before_second = list(tracked_loaders) + await k.run([er]) + assert k.globals["x"] == 3 + + new_loaders = [ + ld for ld in tracked_loaders if ld not in loaders_before_second + ] + assert new_loaders, "Expected a fresh LazyLoader for the second run" + assert any(ld._hits > 0 for ld in new_loaders), ( + "Expected the second run's LazyLoader to record a cache hit" + ) + + @requires_stub_loader + async def test_unhashable_own_def_does_not_auto_rerun( + self, + caching_kernel: MockedKernel, + exec_req: ExecReqProvider, + tracked_loaders: list[LazyLoader], + ) -> None: + """Cell whose own def is a lambda: cache hit on next session, + body skipped, marker in scope. Not auto-rerun (that would defeat + caching for cells where downstream never needs the real value). + """ + k = caching_kernel.k + er = exec_req.get(code="f = lambda x: x + 1") + + await k.run([er]) + assert callable(k.globals["f"]) + assert k.globals["f"](2) == 3 + + for loader in tracked_loaders: + loader.flush() + + # Simulate fresh session. + k.globals.pop("f", None) + + loaders_before_second = list(tracked_loaders) + await k.run([er]) + new_loaders = [ + ld for ld in tracked_loaders if ld not in loaders_before_second + ] + + # Body skipped — `f` in scope is the UnhashableStub marker. 
+ assert isinstance(k.globals.get("f"), UnhashableStub) + assert any(ld._hits > 0 for ld in new_loaders) + + async def test_failed_run_not_cached( + self, + caching_kernel: MockedKernel, + exec_req: ExecReqProvider, + tracked_loaders: list[LazyLoader], + ) -> None: + k = caching_kernel.k + er = exec_req.get(code="raise RuntimeError('boom')") + + await k.run([er]) + + for loader in tracked_loaders: + loader.flush() + + loaders_before_second = list(tracked_loaders) + await k.run([er]) + + new_loaders = [ + ld for ld in tracked_loaders if ld not in loaders_before_second + ] + assert new_loaders + assert all(ld._hits == 0 for ld in new_loaders) + + @requires_stub_loader + async def test_consumer_calling_lambda_recovers( + self, + caching_kernel: MockedKernel, + exec_req: ExecReqProvider, + tracked_loaders: list[LazyLoader], + ) -> None: + """Producer A defines a lambda; consumer B references it directly. + After fresh-kernel reset, A hits cache (stub in scope), B's hash + differs and misses, B's pre-flight sees the stub in its refs → + invalidates A and requeues. A re-runs with the real lambda; B + retries; `g == 15`. + """ + k = caching_kernel.k + producer = exec_req.get(code="f = lambda x: x + 10") + consumer = exec_req.get(code="g = f(5)") + + await k.run([producer, consumer]) + assert k.globals["g"] == 15 + + for loader in tracked_loaders: + loader.flush() + + # Simulate fresh session. 
+ k.globals.pop("f", None) + k.globals.pop("g", None) + + await k.run([producer, consumer]) + assert k.globals["g"] == 15 + assert callable(k.globals["f"]) + assert not isinstance(k.globals["f"], UnhashableStub) + assert not isinstance(k.globals["g"], UnhashableStub) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _FakeCell: + def __init__(self, cell_id: str, refs: set[str] | None = None) -> None: + self.cell_id = cell_id + self.refs = refs or set() + self.defs: set[str] = set() + self.mod = None diff --git a/tests/_runtime/test_scheduler.py b/tests/_runtime/test_scheduler.py index f6b52ad0563..4736d78ae74 100644 --- a/tests/_runtime/test_scheduler.py +++ b/tests/_runtime/test_scheduler.py @@ -64,6 +64,44 @@ def fake_closure(graph: object, roots: set[CellId_t]) -> set[CellId_t]: cell_mock.set_run_result_status.assert_called_with("cancelled") +def test_requeue_for_rerun_moves_producer_ahead_of_queued_consumer( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A producer already queued *behind* the consumer is moved to the head + so it runs before the consumer's retry — not left stranded behind it + (which would re-trip on the stale value forever).""" + p, c, y = CellId_t("P"), CellId_t("C"), CellId_t("Y") + + def fake_topo(graph: object, cells: set[CellId_t]) -> list[CellId_t]: + del graph + return [cid for cid in (p, c) if cid in cells] # producer first + + monkeypatch.setattr("marimo._runtime.dataflow.topological_sort", fake_topo) + # Consumer C is the current (popped) cell, not in the queue; producer P + # is already queued, behind Y. 
+ sched = SequentialScheduler([y, p], graph=_empty_graph()) + sched.requeue_for_rerun({p, c}) + assert list(sched.cells_to_run) == [p, c, y] + + +def test_requeue_for_rerun_no_duplicate( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A requeued cell already in the queue is moved, not duplicated.""" + p, c = CellId_t("P"), CellId_t("C") + + def fake_topo(graph: object, cells: set[CellId_t]) -> list[CellId_t]: + del graph + return [cid for cid in (p, c) if cid in cells] + + monkeypatch.setattr("marimo._runtime.dataflow.topological_sort", fake_topo) + sched = SequentialScheduler([c], graph=_empty_graph()) + sched.requeue_for_rerun({p, c}) + queued = list(sched.cells_to_run) + assert queued == [p, c] + assert queued.count(c) == 1 + + def test_batch_yields_singletons() -> None: sched = SequentialScheduler([], graph=_empty_graph()) cells = [CellId_t("a"), CellId_t("b"), CellId_t("c")]