diff --git a/dimos/core/coordination/module_coordinator.py b/dimos/core/coordination/module_coordinator.py
index f6c82a84ff..9ac46e3c20 100644
--- a/dimos/core/coordination/module_coordinator.py
+++ b/dimos/core/coordination/module_coordinator.py
@@ -17,6 +17,7 @@
 from collections import defaultdict
 from collections.abc import Mapping, MutableMapping
 import importlib
+import inspect
 import shutil
 import sys
 import threading
@@ -30,7 +31,7 @@
 from dimos.core.module import ModuleBase, ModuleSpec
 from dimos.core.resource import Resource
 from dimos.core.transport import LCMTransport, PubSubTransport, pLCMTransport
-from dimos.spec.utils import spec_annotation_compliance, spec_structural_compliance
+from dimos.spec.utils import is_spec, spec_annotation_compliance, spec_structural_compliance
 from dimos.utils.generic import short_id
 from dimos.utils.logging_config import setup_logger
 from dimos.utils.safe_thread_map import safe_thread_map
@@ -767,7 +768,7 @@ def _connect_module_refs(
 ) -> None:
     from dimos.core.coordination.blueprints import DisabledModuleProxy
     from dimos.core.module import is_module_type
-    from dimos.spec.utils import is_spec
+    from dimos.core.rpc_client import AsyncSpecProxy
 
     mod_and_mod_ref_to_proxy = {
         (module, name): replacement
@@ -775,11 +776,16 @@ def _connect_module_refs(
         if is_spec(replacement) or is_module_type(replacement)
     }
 
+    # Track the consumer's declared spec for each ref so we can wrap the proxy
+    # below if the spec contains async-declared methods.
+    declared_spec: dict[tuple[type[ModuleBase], str], Any] = {}
+
     disabled_ref_proxies: dict[tuple[type[ModuleBase], str], DisabledModuleProxy] = {}
     disabled_set = set(blueprint.disabled_modules_tuple)
     for bp in blueprint.active_blueprints:
         for module_ref in bp.module_refs:
+            declared_spec[bp.module, module_ref.name] = module_ref.spec
             spec = mod_and_mod_ref_to_proxy.get((bp.module, module_ref.name), module_ref.spec)
 
             if is_module_type(spec):
@@ -798,7 +804,10 @@ def _connect_module_refs(
 
     for (base_module, ref_name), target_module in mod_and_mod_ref_to_proxy.items():
         base_instance = module_coordinator.get_instance(base_module)
-        target_instance = module_coordinator.get_instance(target_module)  # type: ignore[arg-type]
+        target_instance: Any = module_coordinator.get_instance(target_module)  # type: ignore[arg-type]
+        async_methods = _async_methods_of_spec(declared_spec.get((base_module, ref_name)))
+        if async_methods:
+            target_instance = AsyncSpecProxy(target_instance, async_methods)
         setattr(base_instance, ref_name, target_instance)
         base_instance.set_module_ref(ref_name, target_instance)
         module_coordinator._resolved_module_refs[base_module, ref_name] = cast(
@@ -811,6 +820,21 @@ def _connect_module_refs(
             base_instance.set_module_ref(ref_name, cast("Any", proxy))
 
 
+def _async_methods_of_spec(spec: Any) -> frozenset[str]:
+    if not is_spec(spec):
+        return frozenset()
+    names: set[str] = set()
+    for cls in spec.__mro__:
+        if cls is object:
+            continue
+        for attr_name, value in vars(cls).items():
+            if attr_name.startswith("_"):
+                continue
+            if inspect.iscoroutinefunction(value):
+                names.add(attr_name)
+    return frozenset(names)
+
+
 def _log_blueprint_graph(blueprint: Blueprint, module_coordinator: ModuleCoordinator) -> None:
     """Log the module graph to Rerun if a RerunBridgeModule is active."""
     from dimos.visualization.rerun.bridge import RerunBridgeModule
diff --git a/dimos/core/core.py b/dimos/core/core.py
index 1fac36d250..1c31415e73 100644
--- a/dimos/core/core.py
+++ b/dimos/core/core.py
@@ -15,22 +15,56 @@
 from __future__ import annotations
 
+import asyncio
+import functools
+import inspect
 from typing import (
     TYPE_CHECKING,
+    Any,
+    ParamSpec,
     TypeVar,
+    cast,
 )
 
 if TYPE_CHECKING:
     from collections.abc import Callable
 
 T = TypeVar("T")
-
-from typing import ParamSpec, TypeVar
-
 P = ParamSpec("P")
 R = TypeVar("R")
 
 
 def rpc(fn: Callable[P, R]) -> Callable[P, R]:
-    fn.__rpc__ = True  # type: ignore[attr-defined]
-    return fn
+    """Mark a method as an RPC callable across modules.
+
+    Sync methods are tagged in place. Async methods get a sync dispatcher that
+    runs the coroutine on `self._loop`:
+
+    * Caller is on self._loop (another async @rpc, a handle_*, or a
+      process_observable callback): returns the coroutine so the caller can
+      `await` it normally.
+    * Caller is on any other thread (RPC dispatcher, sync test, sync @rpc on
+      the same module): schedules the coroutine onto self._loop and blocks
+      until done.
+    """
+    if not inspect.iscoroutinefunction(fn):
+        fn.__rpc__ = True  # type: ignore[attr-defined]
+        return fn
+
+    @functools.wraps(fn)
+    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
+        loop = self._loop
+        if loop is None:
+            raise RuntimeError("async @rpc method called outside a running module loop")
+        try:
+            running = asyncio.get_running_loop()
+        except RuntimeError:
+            running = None
+        if running is loop:
+            return fn(self, *args, **kwargs)  # type: ignore[call-arg]
+        future = asyncio.run_coroutine_threadsafe(fn(self, *args, **kwargs), loop)  # type: ignore[call-arg, arg-type]
+        return future.result()
+
+    wrapper.__rpc__ = True  # type: ignore[attr-defined]
+    wrapper.aio = fn  # type: ignore[attr-defined]
+    return cast("Callable[P, R]", wrapper)
diff --git a/dimos/core/module.py b/dimos/core/module.py
index e44742b857..fb79bb280f 100644
--- a/dimos/core/module.py
+++ b/dimos/core/module.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import asyncio
-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable
 from dataclasses import dataclass
 from functools import partial
 import inspect
@@ -31,6 +31,7 @@
 )
 from pydantic import Field
+from reactivex.disposable import CompositeDisposable, Disposable
 
 from dimos.core.core import T, rpc
 from dimos.core.global_config import GlobalConfig, global_config
@@ -45,8 +46,14 @@
 from dimos.protocol.tf.tf import LCMTF, TFSpec
 from dimos.utils import colors
 from dimos.utils.generic import classproperty
+from dimos.utils.logging_config import setup_logger
+
+logger = setup_logger()
 
 if TYPE_CHECKING:
+    from reactivex import Observable
+    from reactivex.abc import DisposableBase
+
     from dimos.core.coordination.blueprints import Blueprint
     from dimos.core.introspection.module.info import ModuleInfo
     from dimos.core.rpc_client import RPCClient
@@ -71,6 +78,7 @@ def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]:
     except RuntimeError:
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
+        loop.set_task_factory(_logging_task_factory)
         thr = threading.Thread(target=loop.run_forever, daemon=True)
         thr.start()
@@ -112,6 +120,7 @@ class ModuleBase(Configurable, CompositeResource):
     _module_closed: bool = False
     _module_closed_lock: threading.Lock
     _loop_thread_timeout: float = 2.0
+    _main_gen: AsyncGenerator[None, None] | None = None
 
     def __init__(self, config_args: dict[str, Any]) -> None:
         super().__init__(**config_args)
@@ -150,10 +159,12 @@ def build(self) -> None:
 
     @rpc
     def start(self) -> None:
-        pass
+        self._start_main()
+        self._auto_bind_handlers()
 
     @rpc
     def stop(self) -> None:
+        self._stop_main()
         super().stop()
         self._close_module()
@@ -203,6 +214,7 @@ def __getstate__(self):  # type: ignore[no-untyped-def]
         state.pop("_loop_thread", None)
         state.pop("_rpc", None)
         state.pop("_tf", None)
+        state.pop("_main_gen", None)
         return state
 
     def __setstate__(self, state) -> None:  # type: ignore[no-untyped-def]
@@ -214,6 +226,7 @@ def __setstate__(self, state) -> None:  # type: ignore[no-untyped-def]
         self._loop_thread = None
         self._rpc = None
         self._tf = None
+        self._main_gen = None
 
     @property
     def tf(self):  # type: ignore[no-untyped-def]
@@ -396,6 +409,227 @@ def get_skills(self) -> list[SkillInfo]:
         )
         return skills
 
+    def spawn(self, coro: Any) -> Any:
+        """
+        Schedule a coroutine on self._loop from any thread.
+
+        Use this instead of bare `asyncio.run_coroutine_threadsafe(coro,
+        self._loop)` when scheduling a long-running async task from a sync
+        context like start().
+
+        Unhandled exceptions are routed to the module logger instead of being
+        silently stored in the returned Future, which is the common pitfall when
+        nothing ever reads `.result()`.
+        """
+
+        loop = self._loop
+        if loop is None or not loop.is_running():
+            raise RuntimeError(f"{type(self).__name__}._loop is not running")
+        future = asyncio.run_coroutine_threadsafe(coro, loop)
+        future.add_done_callback(self._log_async_handler_error)
+        return future
+
+    def process_observable(
+        self,
+        observable: "Observable[Any]",
+        async_cb: Callable[[Any], Any],
+    ) -> "DisposableBase":
+        """Subscribe `async_cb` (an async function) to `observable`, dispatching
+        each emitted value onto self._loop. Invocations are serialized through a
+        per-subscription dispatcher task with LATEST coalescing. The subscription
+        is registered for cleanup on stop()."""
+        if not inspect.iscoroutinefunction(async_cb):
+            raise TypeError("process_observable requires an `async def` callback")
+        on_msg, dispatcher_disp = self._make_async_dispatch(async_cb)
+        sub = observable.subscribe(on_msg)
+        return self.register_disposable(CompositeDisposable(sub, dispatcher_disp))
+
+    def _start_main(self) -> None:
+        """
+        If the subclass defines `async def main(self)` as an async generator
+        with exactly one `yield`, run everything before the `yield` as part of
+        start().
+        """
+        main_fn = getattr(type(self), "main", None)
+        if main_fn is None:
+            return
+        if not inspect.isasyncgenfunction(main_fn):
+            raise TypeError(
+                f"{type(self).__name__}.main must be an `async def` with exactly "
+                "one `yield` (an async generator function)"
+            )
+        loop = self._loop
+        if loop is None or not loop.is_running():
+            raise RuntimeError(f"{type(self).__name__}._loop is not running")
+        gen = main_fn(self)
+        try:
+            asyncio.run_coroutine_threadsafe(gen.__anext__(), loop).result()
+        except StopAsyncIteration:
+            raise RuntimeError(
+                f"{type(self).__name__}.main must contain exactly one `yield` (found none)"
+            ) from None
+        except BaseException:
+            try:
+                asyncio.run_coroutine_threadsafe(gen.aclose(), loop).result()
+            except BaseException:
+                pass
+            raise
+        self._main_gen = gen
+
+    def _stop_main(self) -> None:
+        """Resume `main` past its yield so the teardown section runs."""
+        gen = self._main_gen
+        if gen is None:
+            return
+        self._main_gen = None
+        loop = self._loop
+        if loop is None or not loop.is_running():
+            return
+        try:
+            asyncio.run_coroutine_threadsafe(gen.__anext__(), loop).result()
+        except StopAsyncIteration:
+            return
+        except BaseException as e:
+            # Do not fail teardown if main raises. Log and continue with best
+            # effort to close the module.
+            logger.exception(
+                f"Error during {type(self).__name__}.main teardown: {type(e).__name__}: {e}"
+            )
+            return
+        # No StopAsyncIteration means main yielded a second time.
+        try:
+            asyncio.run_coroutine_threadsafe(gen.aclose(), loop).result()
+        except BaseException:
+            pass
+        logger.error(
+            f"{type(self).__name__}.main yielded more than once; "
+            "expected exactly one yield (setup, then teardown)"
+        )
+
+    def _auto_bind_handlers(self) -> None:
+        """
+        For each declared `x: In[T]`, if `async def handle_x` exists, subscribe it
+        via process_observable so it runs on self._loop.
+        """
+        # Validate every handler before subscribing any of them.
+        bindings: list[tuple[Any, Callable[[Any], Any]]] = []
+        for input_name, in_stream in self.inputs.items():
+            handler = getattr(self, f"handle_{input_name}", None)
+            if handler is None:
+                continue
+            # Async @rpc wraps the coroutine fn in a sync dispatcher. Unwrap it
+            # so we subscribe the raw coroutine fn instead of the wrapper (which
+            # would block on run_coroutine_threadsafe from the rx thread).
+            if hasattr(handler, "aio"):
+                handler = handler.aio.__get__(self, type(self))
+            if not inspect.iscoroutinefunction(handler):
+                raise TypeError(
+                    f"{type(self).__name__}.handle_{input_name} must be `async def` "
+                    "(use a manual self.<input>.subscribe(...) for sync handlers)"
+                )
+            bindings.append((in_stream, handler))
+
+        for in_stream, handler in bindings:
+            # process_observable runs each handler through a per-subscription
+            # dispatcher task on self._loop that serializes invocations and
+            # keeps only the latest unprocessed message. We subscribe to
+            # pure_observable() because the dispatcher already provides
+            # backpressure.
+            self.process_observable(in_stream.pure_observable(), handler)
+
+    def _make_async_dispatch(
+        self, async_handler: Callable[[Any], Any]
+    ) -> tuple[Callable[[Any], None], "DisposableBase"]:
+        """Build a sync callback that delivers `msg` into a single-slot LATEST
+        mailbox drained by a dedicated dispatcher task on `self._loop`.
+
+        Guarantees:
+        - The handler is invoked at most one-at-a-time (no interleaving across
+          awaits).
+        - If messages arrive faster than the handler can process them,
+          intermediate messages are dropped and only the most recent unprocessed
+          message is kept (LATEST policy).
+        - The returned Disposable cancels the dispatcher task.
+        """
+        loop = self._loop
+        if loop is None or not loop.is_running():
+            raise RuntimeError(f"{type(self).__name__}._loop is not running")
+
+        async def _bootstrap() -> tuple[asyncio.Event, dict[str, Any], asyncio.Task[None]]:
+            event = asyncio.Event()
+            slot: dict[str, Any] = {"value": None, "has_value": False}
+
+            async def dispatcher() -> None:
+                try:
+                    while True:
+                        await event.wait()
+                        event.clear()
+                        if not slot["has_value"]:
+                            continue
+                        msg = slot["value"]
+                        slot["value"] = None
+                        slot["has_value"] = False
+                        try:
+                            await async_handler(msg)
+                        except asyncio.CancelledError:
+                            raise
+                        except BaseException as e:
+                            self._log_async_handler_exception(e)
+                except asyncio.CancelledError:
+                    return
+
+            return event, slot, asyncio.create_task(dispatcher())
+
+        event, slot, task = asyncio.run_coroutine_threadsafe(_bootstrap(), loop).result(timeout=5.0)
+
+        def on_msg(msg: Any) -> None:
+            loop_now = self._loop
+            if loop_now is None or not loop_now.is_running():
+                return
+
+            def _set() -> None:
+                slot["value"] = msg
+                slot["has_value"] = True
+                event.set()
+
+            loop_now.call_soon_threadsafe(_set)
+
+        disposed = False
+
+        def _dispose() -> None:
+            nonlocal disposed
+            if disposed:
+                return
+            disposed = True
+            loop_now = self._loop
+            if loop_now is not None and loop_now.is_running():
+                loop_now.call_soon_threadsafe(task.cancel)
+
+        return on_msg, Disposable(_dispose)
+
+    def _log_async_handler_exception(self, e: BaseException) -> None:
+        if isinstance(e, asyncio.CancelledError):
+            return  # task cancelled during shutdown
+        # A coroutine interacting with a stopped loop surfaces as
+        # RuntimeError ("Event loop is closed", "no running event loop",
+        # etc.). Only swallow that when the loop is actually gone. Anything
+        # else (including RuntimeError raised by user code while the loop is
+        # healthy) is a real bug worth logging.
+        loop = self._loop
+        if isinstance(e, RuntimeError) and (loop is None or not loop.is_running()):
+            return
+        # Include exception type+message in the event string so it is
+        # visible on consoles whose formatters strip exc_info/traceback.
+        logger.exception(
+            f"Unhandled error in async task on {type(self).__name__}._loop: {type(e).__name__}: {e}"
+        )
+
+    def _log_async_handler_error(self, fut: Any) -> None:
+        try:
+            fut.result()
+        except BaseException as e:
+            self._log_async_handler_exception(e)
+
 
 class Module(ModuleBase):
     def __init_subclass__(cls, **kwargs: Any) -> None:
@@ -472,3 +706,33 @@ def is_module_type(value: Any) -> bool:
         return inspect.isclass(value) and issubclass(value, Module)
     except Exception:
         return False
+
+
+def _logging_task_factory(
+    loop: asyncio.AbstractEventLoop, coro: Any, **kwargs: Any
+) -> asyncio.Task[Any]:
+    """
+    Adds a done callback to log unhandled exceptions from any task created on
+    the loop.
+ """ + task = asyncio.Task(coro, loop=loop, **kwargs) + task.add_done_callback(_log_task_exception) + return task + + +def _log_task_exception(task: asyncio.Task[Any]) -> None: + if task.cancelled(): + return + try: + exc = task.exception() + except asyncio.InvalidStateError: + return + if exc is None or isinstance(exc, (asyncio.CancelledError, StopAsyncIteration)): + return + # Calling task.exception() above marks the exception as retrieved, so + # asyncio's GC-time logger won't fire. We must log here. + name = task.get_name() + logger.error( + f"Unhandled exception in async task {name!r}: {type(exc).__name__}: {exc}", + exc_info=exc, + ) diff --git a/dimos/core/rpc_client.py b/dimos/core/rpc_client.py index d3a645d418..42407f315e 100644 --- a/dimos/core/rpc_client.py +++ b/dimos/core/rpc_client.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from collections.abc import Callable from typing import TYPE_CHECKING, Any, Protocol @@ -162,6 +163,50 @@ def __getattr__(self, name: str): # type: ignore[no-untyped-def] return result +class AsyncSpecProxy: + """Wraps an RPCClient (or compatible proxy) so methods declared `async def` + on the consumer's Spec are exposed as awaitables on the proxy. + + A consumer that types `ref: SomeSpec` where `SomeSpec` declares `async def + foo` will see `self.ref.foo(x)` return an awaitable. The underlying RPC call + is still synchronous over the wire. The caller's event loop stays unblocked + while the response round-trips. + + It's picklable so `set_module_ref`` can ship it across to the worker process. + """ + + def __init__(self, inner: Any, async_methods: frozenset[str]) -> None: + # Use object.__setattr__ for clarity; we don't override __setattr__ + # but this mirrors how DisabledModuleProxy guards its internals. + object.__setattr__(self, "_inner", inner) + object.__setattr__(self, "_async_methods", async_methods) + + def __getattr__(self, name: str) -> Any: + inner = object.__getattribute__(self, "_inner") + attr = getattr(inner, name) + async_methods = object.__getattribute__(self, "_async_methods") + if name not in async_methods or not callable(attr): + return attr + + def async_call(*args: Any, **kwargs: Any) -> Any: + async def _run() -> Any: + running = asyncio.get_running_loop() + return await running.run_in_executor(None, lambda: attr(*args, **kwargs)) + + return _run() + + return async_call + + def __reduce__(self) -> Any: + return ( + AsyncSpecProxy, + ( + object.__getattribute__(self, "_inner"), + object.__getattribute__(self, "_async_methods"), + ), + ) + + if TYPE_CHECKING: from dimos.core.module import Module diff --git a/dimos/core/test_async_module_dispatch_serialization.py b/dimos/core/test_async_module_dispatch_serialization.py new file mode 100644 index 0000000000..e392cc6602 --- /dev/null +++ b/dimos/core/test_async_module_dispatch_serialization.py @@ -0,0 +1,207 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import asyncio
+import itertools
+from queue import Empty, Queue
+import time
+from typing import Any
+
+import pytest
+
+from dimos.core.coordination.module_coordinator import ModuleCoordinator
+from dimos.core.module import Module
+from dimos.core.stream import In, Out
+from dimos.core.transport import pLCMTransport
+
+
+class BurstModule(Module):
+    """Slow handler that records a (value, start, end) tuple per invocation."""
+
+    a: In[int]
+    record: Out[dict]
+
+    async def handle_a(self, value: int) -> None:
+        start = time.monotonic()
+        await asyncio.sleep(0.05)
+        end = time.monotonic()
+        self.record.publish({"value": value, "start": start, "end": end})
+
+
+@pytest.fixture
+def start_burst_module():
+    blueprint = BurstModule.blueprint()
+    coordinator = ModuleCoordinator.build(blueprint)
+    yield
+    coordinator.stop()
+
+
+@pytest.fixture
+def burst_a_transport():
+    tr = pLCMTransport("/a")
+    tr.start()
+    yield tr
+    tr.stop()
+
+
+@pytest.fixture
+def burst_record_transport():
+    tr = pLCMTransport("/record")
+    tr.start()
+    yield tr
+    tr.stop()
+
+
+def _drain(queue: Queue, settle_timeout: float = 0.5) -> list[Any]:
+    items: list[Any] = []
+    while True:
+        try:
+            items.append(queue.get(timeout=settle_timeout))
+        except Empty:
+            return items
+
+
+@pytest.mark.slow
+def test_bursts_are_coalesced_and_handler_is_serialized(
+    start_burst_module, burst_a_transport, burst_record_transport
+):
+    """Publishing 100 messages in a tight loop while the handler sleeps 50ms
+    must (a) coalesce (the handler is invoked far fewer than 100 times),
+    (b) eventually deliver the most recently published value, and
+    (c) never run two handler invocations concurrently."""
+    queue: Queue = Queue()
+    burst_record_transport.subscribe(queue.put)
+
+    n = 100
+    for i in range(n):
+        burst_a_transport.publish(i)
+
+    records = _drain(queue, settle_timeout=2.0)
+
+    # Coalescing actually happened.
+    assert 0 < len(records) < n, f"expected coalescing, got {len(records)} records"
+
+    # The most recently published value eventually reaches the handler.
+    assert records[-1]["value"] == n - 1, (
+        f"last record {records[-1]['value']} should equal final published value {n - 1}"
+    )
+
+    # No two recorded [start, end] intervals overlap (handler is serial).
+    intervals = sorted((r["start"], r["end"]) for r in records)
+    for (_, prev_end), (next_start, _) in itertools.pairwise(intervals):
+        assert next_start >= prev_end, (
+            f"overlapping handler intervals: prev_end={prev_end}, next_start={next_start}"
+        )
+
+
+class InterleaveModule(Module):
+    """Handler that yields between writing and reading a per-instance marker."""
+
+    a: In[int]
+    record: Out[dict]
+
+    _marker: int = -1
+
+    async def handle_a(self, value: int) -> None:
+        self._marker = value
+        # Yield to the loop. Without serialization, another invocation could
+        # run here and overwrite _marker before this coroutine resumes.
+        await asyncio.sleep(0)
+        self.record.publish({"value": value, "marker": self._marker})
+
+
+@pytest.fixture
+def start_interleave_module():
+    blueprint = InterleaveModule.blueprint()
+    coordinator = ModuleCoordinator.build(blueprint)
+    yield
+    coordinator.stop()
+
+
+@pytest.fixture
+def interleave_a_transport():
+    tr = pLCMTransport("/a")
+    tr.start()
+    yield tr
+    tr.stop()
+
+
+@pytest.fixture
+def interleave_record_transport():
+    tr = pLCMTransport("/record")
+    tr.start()
+    yield tr
+    tr.stop()
+
+
+@pytest.mark.slow
+def test_handler_does_not_interleave_across_awaits(
+    start_interleave_module, interleave_a_transport, interleave_record_transport
+):
+    """No invocation of `handle_a` may observe a marker written by a different
+    invocation (the dispatcher must serialize handler execution across `await`
+    points)."""
+    queue: Queue = Queue()
+    interleave_record_transport.subscribe(queue.put)
+
+    for i in range(50):
+        interleave_a_transport.publish(i)
+
+    records = _drain(queue, settle_timeout=1.0)
+    assert records, "expected at least one record"
+
+    for r in records:
+        assert r["value"] == r["marker"], (
+            f"marker {r['marker']} differs from value {r['value']} — "
+            "another handler invocation overwrote our state mid-handler"
+        )
+
+
+class CleanupModule(Module):
+    """Handler that sleeps for a long time so we can stop the coordinator
+    while the handler is mid-await."""
+
+    a: In[int]
+
+    async def handle_a(self, value: int) -> None:
+        await asyncio.sleep(5.0)
+
+
+@pytest.fixture
+def cleanup_a_transport():
+    tr = pLCMTransport("/a")
+    tr.start()
+    yield tr
+    tr.stop()
+
+
+@pytest.mark.slow
+def test_stop_cancels_in_flight_handler(cleanup_a_transport):
+    """Stopping the coordinator while a handler is awaiting must complete
+    quickly. The dispatcher cancels the handler instead of waiting for it."""
+    blueprint = CleanupModule.blueprint()
+    coordinator = ModuleCoordinator.build(blueprint)
+    try:
+        cleanup_a_transport.publish(1)
+        time.sleep(0.1)  # let the handler enter its sleep
+        start = time.monotonic()
+        coordinator.stop()
+        elapsed = time.monotonic() - start
+    except BaseException:
+        coordinator.stop()
+        raise
+
+    # Without cancellation, stop() would either hang or be bounded only by the
+    # 5s asyncio.sleep. The dispatcher cancels the task synchronously.
+    assert elapsed < 3.0, f"stop() took {elapsed:.2f}s (handler not cancelled)"
diff --git a/dimos/core/test_async_module_handles.py b/dimos/core/test_async_module_handles.py
new file mode 100644
index 0000000000..aa24883d8b
--- /dev/null
+++ b/dimos/core/test_async_module_handles.py
@@ -0,0 +1,63 @@
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from queue import Queue
+
+import pytest
+
+from dimos.core.coordination.module_coordinator import ModuleCoordinator
+from dimos.core.module import Module
+from dimos.core.stream import In, Out
+from dimos.core.transport import pLCMTransport
+
+
+class DoubleModule(Module):
+    a: In[int]
+    double_a: Out[int]
+
+    async def handle_a(self, a: int) -> None:
+        self.double_a.publish(a * 2)
+
+
+@pytest.fixture
+def start_double_module():
+    blueprint = DoubleModule.blueprint()
+    coordinator = ModuleCoordinator.build(blueprint)
+    yield
+    coordinator.stop()
+
+
+@pytest.fixture
+def a_transport():
+    a_tr = pLCMTransport("/a")
+    a_tr.start()
+    yield a_tr
+    a_tr.stop()
+
+
+@pytest.fixture
+def double_a_transport():
+    double_a_tr = pLCMTransport("/double_a")
+    double_a_tr.start()
+    yield double_a_tr
+    double_a_tr.stop()
+
+
+@pytest.mark.slow
+def test_async_module_handles(start_double_module, a_transport, double_a_transport):
+    queue = Queue()
+    double_a_transport.subscribe(queue.put)
+    a_transport.publish(42)
+    doubled = queue.get(timeout=0.1)
+    assert doubled == 84
diff --git a/dimos/core/test_async_module_main.py b/dimos/core/test_async_module_main.py
new file mode 100644
index 0000000000..384f735e13
--- /dev/null
+++ b/dimos/core/test_async_module_main.py
@@ -0,0 +1,254 @@
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+from collections.abc import AsyncIterator, Iterator
+import logging
+from queue import Queue
+from typing import Any
+
+import pytest
+
+from dimos.core.coordination.module_coordinator import ModuleCoordinator
+from dimos.core.module import Module
+from dimos.core.stream import In, Out
+from dimos.core.transport import pLCMTransport
+
+
+@pytest.fixture
+def module_log_records() -> Iterator[list[logging.LogRecord]]:
+    records: list[logging.LogRecord] = []
+
+    class _ListHandler(logging.Handler):
+        def emit(self, record: logging.LogRecord) -> None:
+            records.append(record)
+
+    handler = _ListHandler(level=logging.DEBUG)
+    target = logging.getLogger("dimos/core/module.py")
+    target.addHandler(handler)
+    try:
+        yield records
+    finally:
+        target.removeHandler(handler)
+
+
+class _Resource:
+    """Tiny stand-in for an external resource a module might own."""
+
+    def __init__(self) -> None:
+        self.started = False
+        self.stop_count = 0
+
+    def start(self) -> None:
+        self.started = True
+
+    def stop(self) -> None:
+        self.stop_count += 1
+
+
+class HappyMain(Module):
+    """Records setup/teardown order and verifies main runs on _loop."""
+
+    events: list[str]
+    resource: _Resource
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.events = []
+        self.resource = _Resource()
+
+    async def main(self) -> AsyncIterator[None]:
+        assert asyncio.get_running_loop() is self._loop
+        self.resource.start()
+        self.events.append("setup")
+        yield
+        self.events.append("teardown")
+        self.resource.stop()
+
+
+def test_main_setup_runs_before_start_returns():
+    m = HappyMain()
+    assert m.events == []
+    m.start()
+    try:
+        assert m.events == ["setup"]
+        assert m.resource.started is True
+    finally:
+        m.stop()
+
+
+def test_main_teardown_runs_during_stop():
+    m = HappyMain()
+    m.start()
+    m.stop()
+    assert m.events == ["setup", "teardown"]
+    assert m.resource.stop_count == 1
+
+
+def test_main_teardown_runs_only_once():
+    m = HappyMain()
+    m.start()
+    m.stop()
+    # Calling stop() again should be a no-op for main (already torn down).
+    m.stop()
+    assert m.resource.stop_count == 1
+
+
+class NoYieldMain(Module):
+    async def main(self) -> AsyncIterator[None]:
+        # Lexically contains yield (so isasyncgenfunction is True), but runtime
+        # never reaches it -> StopAsyncIteration on first __anext__.
+        if False:
+            yield
+
+
+def test_main_with_zero_runtime_yields_raises():
+    m = NoYieldMain()
+    with pytest.raises(RuntimeError, match="exactly one `yield`.*found none"):
+        m.start()
+    # Even though start failed, stop should still be safe to call.
+    m.stop()
+
+
+class NotAGeneratorMain(Module):
+    async def main(self) -> None:
+        return None
+
+
+def test_main_that_is_not_an_async_generator_raises():
+    m = NotAGeneratorMain()
+    with pytest.raises(TypeError, match="must be an `async def` with exactly one"):
+        m.start()
+    m.stop()
+
+
+class TwoYieldsMain(Module):
+    teardown_count: int
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.teardown_count = 0
+
+    async def main(self) -> AsyncIterator[None]:
+        yield
+        yield
+        self.teardown_count += 1
+
+
+def test_main_with_two_yields_logs_and_continues_stop(module_log_records):
+    m = TwoYieldsMain()
+    m.start()
+    m.stop()  # Must not raise.
+    assert any("yielded more than once" in rec.getMessage() for rec in module_log_records)
+    # The second-section code after the second yield should NOT have run
+    # because we close the generator instead of running it through.
+    assert m.teardown_count == 0
+
+
+class TeardownErrorMain(Module):
+    teardown_attempted: bool
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.teardown_attempted = False
+
+    async def main(self) -> AsyncIterator[None]:
+        yield
+        self.teardown_attempted = True
+        raise RuntimeError("teardown failure")
+
+
+def test_main_teardown_error_is_logged_not_raised(module_log_records):
+    m = TeardownErrorMain()
+    m.start()
+    m.stop()  # Must not raise.
+    assert m.teardown_attempted is True
+    assert any(
+        "teardown" in rec.getMessage() and rec.levelno == logging.ERROR
+        for rec in module_log_records
+    )
+
+
+class SetupErrorMain(Module):
+    teardown_ran: bool
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.teardown_ran = False
+
+    async def main(self) -> AsyncIterator[None]:
+        raise RuntimeError("setup failure")
+        yield  # pragma: no cover (unreachable, but makes this an async gen)
+        self.teardown_ran = True  # pragma: no cover
+
+
+def test_main_setup_error_propagates_from_start():
+    m = SetupErrorMain()
+    with pytest.raises(RuntimeError, match="setup failure"):
+        m.start()
+    # Generator was never stored, so stop() should not try to drive teardown.
+    m.stop()
+    assert m.teardown_ran is False
+
+
+class MainAndHandlerModule(Module):
+    a: In[int]
+    out: Out[int]
+
+    multiplier: int = 0
+    setup_ran: bool = False
+    teardown_ran: bool = False
+
+    async def main(self) -> AsyncIterator[None]:
+        self.multiplier = 7
+        self.setup_ran = True
+        yield
+        self.teardown_ran = True
+        self.multiplier = 0
+
+    async def handle_a(self, value: int) -> None:
+        self.out.publish(value * self.multiplier)
+
+
+@pytest.fixture
+def start_main_handler_module():
+    blueprint = MainAndHandlerModule.blueprint()
+    coordinator = ModuleCoordinator.build(blueprint)
+    yield
+    coordinator.stop()
+
+
+@pytest.fixture
+def a_transport():
+    a_tr = pLCMTransport("/a")
+    a_tr.start()
+    yield a_tr
+    a_tr.stop()
+
+
+@pytest.fixture
+def out_transport():
+    out_tr = pLCMTransport("/out")
+    out_tr.start()
+    yield out_tr
+    out_tr.stop()
+
+
+@pytest.mark.slow
+def test_main_and_handle_together(start_main_handler_module, a_transport, out_transport):
+    queue: Queue[int] = Queue()
+    out_transport.subscribe(queue.put)
+    a_transport.publish(6)
+    result = queue.get(timeout=4)
+    assert result == 42  # 6 * 7
diff --git a/dimos/core/test_async_module_process_observable.py b/dimos/core/test_async_module_process_observable.py
new file mode 100644
index 0000000000..50c83f10cd
--- /dev/null
+++ b/dimos/core/test_async_module_process_observable.py
@@ -0,0 +1,87 @@
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from queue import Queue
+import string
+
+import pytest
+import reactivex as rx
+from reactivex import operators as ops
+
+from dimos.core.coordination.module_coordinator import ModuleCoordinator
+from dimos.core.core import rpc
+from dimos.core.module import Module
+from dimos.core.stream import Out
+from dimos.core.transport import pLCMTransport
+
+
+class StartModule(Module):
+    uppercase: Out[str]
+
+    @rpc
+    def start(self) -> None:
+        super().start()
+
+        observable = rx.interval(0.1).pipe(
+            ops.take(len(string.ascii_lowercase)),
+            ops.map(lambda i: string.ascii_lowercase[i]),
+        )
+
+        self.process_observable(observable, self.handle_letter)
+
+    async def handle_letter(self, letter: str) -> None:
+        self.uppercase.publish(letter.upper())
+
+
+@pytest.fixture
+def start_module():
+    blueprint = StartModule.blueprint()
+    coordinator = ModuleCoordinator.build(blueprint)
+    yield
+    coordinator.stop()
+
+
+@pytest.fixture
+def get_collected_letters():
+    uppercase_transport = pLCMTransport("/uppercase")
+    uppercase_transport.start()
+    queue = Queue()
+    uppercase_transport.subscribe(queue.put)
+
+    def _get_collected_letters() -> str:
+        return "".join([queue.get(timeout=4) for _ in range(26)])
+
+    yield _get_collected_letters
+
+    uppercase_transport.stop()
+
+
+@pytest.mark.slow
+def test_async_module_process_observable(get_collected_letters, start_module):
+    """
+    Tests that process_observable correctly processes items from an observable
+    in an async manner.
+
+    Most of the logic is in get_collected_letters, because we need to set up
+    the subscription to the result before starting the module. This is because
+    the module emits from the start method.
+
+    The strict equality below also locks down the serial-delivery contract: the
+    per-subscription dispatcher must invoke `handle_letter` once per item in the
+    order they were emitted (the source emits at 100ms intervals, slower than the
+    near-zero handler runtime, so no LATEST coalescing should occur).
+    """
+    collected = get_collected_letters()
+    assert len(collected) == 26
+    assert collected == "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
diff --git a/dimos/core/test_async_module_rpc.py b/dimos/core/test_async_module_rpc.py
new file mode 100644
index 0000000000..5778bd42eb
--- /dev/null
+++ b/dimos/core/test_async_module_rpc.py
@@ -0,0 +1,82 @@
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+from queue import Queue
+from typing import Protocol
+
+import pytest
+
+from dimos.core.coordination.blueprints import autoconnect
+from dimos.core.coordination.module_coordinator import ModuleCoordinator
+from dimos.core.core import rpc
+from dimos.core.module import Module
+from dimos.core.stream import In, Out
+from dimos.core.transport import pLCMTransport
+from dimos.spec.utils import Spec
+
+
+class MakeCube(Module):
+    @rpc
+    async def make_cube(self, x: int) -> int:
+        await asyncio.sleep(0.001)  # Just so it's actually async.
+        return x * x * x
+
+
+class MakeCubeSpec(Spec, Protocol):
+    async def make_cube(self, x: int) -> int: ...
+
+
+class StartModule(Module):
+    a: In[int]
+    cube_a: Out[int]
+    _cuber: MakeCubeSpec
+
+    @rpc
+    async def handle_a(self, x: int) -> None:
+        cube = await self._cuber.make_cube(x)
+        self.cube_a.publish(cube)
+
+
+@pytest.fixture
+def start_cube_module():
+    blueprint = autoconnect(StartModule.blueprint(), MakeCube.blueprint())
+    coordinator = ModuleCoordinator.build(blueprint)
+    yield
+    coordinator.stop()
+
+
+@pytest.fixture
+def a_transport():
+    a_tr = pLCMTransport("/a")
+    a_tr.start()
+    yield a_tr
+    a_tr.stop()
+
+
+@pytest.fixture
+def cube_a_transport():
+    cube_a_tr = pLCMTransport("/cube_a")
+    cube_a_tr.start()
+    yield cube_a_tr
+    cube_a_tr.stop()
+
+
+@pytest.mark.slow
+def test_async_module_rpc(start_cube_module, a_transport, cube_a_transport):
+    queue = Queue()
+    cube_a_transport.subscribe(queue.put)
+    a_transport.publish(3)
+    cubed = queue.get(timeout=0.1)
+    assert cubed == 27
diff --git a/dimos/core/test_async_module_rpc_sync_to_async.py b/dimos/core/test_async_module_rpc_sync_to_async.py
new file mode 100644
index 0000000000..a054bcc868
--- /dev/null
+++ b/dimos/core/test_async_module_rpc_sync_to_async.py
@@ -0,0 +1,113 @@
+# Copyright 2026 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from queue import Queue
+from typing import Protocol
+
+import pytest
+
+from dimos.core.coordination.blueprints import autoconnect
+from dimos.core.coordination.module_coordinator import ModuleCoordinator
+from dimos.core.core import rpc
+from dimos.core.module import Module
+from dimos.core.stream import In, Out
+from dimos.core.transport import pLCMTransport
+from dimos.spec.utils import Spec
+
+
+class ModuleA(Module):
+    @rpc
+    def a(self, x: int) -> int:
+        return x * 1000
+
+
+class ModuleB(Module):
+    @rpc
+    async def b(self, x: int) -> int:
+        return x * 100
+
+
+class ASpec(Spec, Protocol):
+    def a(self, x: int) -> int: ...
+
+
+class BSpec(Spec, Protocol):
+    # ModuleB.b is async but we use sync in the spec
+    def b(self, x: int) -> int: ...
+
+
+class ModuleAB(Module):
+    _a: ASpec
+    _b: BSpec
+
+    @rpc
+    def ab(self, x: int) -> int:
+        return self._a.a(x) + self._b.b(x)
+
+
+class ABSpec(Spec, Protocol):
+    # ModuleAB.ab is sync but we use async in the spec
+    async def ab(self, x: int) -> int: ...
+
+
+class StartModule(Module):
+    in_value: In[int]
+    out_value: Out[int]
+    _ab: ABSpec
+
+    @rpc
+    async def handle_in_value(self, x: int) -> None:
+        ret = await self._ab.ab(x)
+        self.out_value.publish(ret)
+
+
+@pytest.fixture
+def start_module():
+    blueprint = autoconnect(
+        StartModule.blueprint(),
+        ModuleA.blueprint(),
+        ModuleB.blueprint(),
+        ModuleAB.blueprint(),
+    )
+    coordinator = ModuleCoordinator.build(blueprint)
+    yield
+    coordinator.stop()
+
+
+@pytest.fixture
+def in_transport():
+    ret = pLCMTransport("/in_value")
+    ret.start()
+    yield ret
+    ret.stop()
+
+
+@pytest.fixture
+def out_transport():
+    ret = pLCMTransport("/out_value")
+    ret.start()
+    yield ret
+    ret.stop()
+
+
+@pytest.mark.slow
+def test_async_module_rpc_sync_to_async(start_module, in_transport, out_transport):
+    """
+    Test that you can call a synchronous RPC from an asynchronous RPC and vice versa.
+    """
+    queue = Queue()
+    out_transport.subscribe(queue.put)
+    in_transport.publish(4)
+    result = queue.get(timeout=0.1)
+    assert result == 4400
diff --git a/docs/usage/modules.md b/docs/usage/modules.md
index 8c16fa8561..0a258634d9 100644
--- a/docs/usage/modules.md
+++ b/docs/usage/modules.md
@@ -1,4 +1,3 @@
-
 # DimOS Modules
 
 Modules are subsystems on a robot that operate autonomously and communicate with other subsystems using standardized messages.
@@ -188,6 +187,151 @@ coordinator.load_module(CameraModule)
 coordinator.restart_module(CameraModule)
 ```
 
+## Async modules (lock-free state)
+
+Each module owns a per-instance asyncio loop running on a daemon thread (`self._loop`). You can write a module using only `async def` methods so that all of its code runs on that one thread and shared state needs no locks. The module's auto-bound input handlers, async `@rpc` methods, and `process_observable` callbacks all run on `self._loop`, and each handler subscription is serialized through a dedicated dispatcher task.
+
+### Auto-bound input handlers
+
+For every declared `x: In[T]`, if the module defines `async def handle_x(self, msg: T)`, the handler is automatically subscribed at `start()` and dispatched onto `self._loop`. Subscriptions are cleaned up at `stop()`.
+
+```python
+from dimos.core.module import Module
+from dimos.core.stream import In, Out
+from dimos.msgs.geometry_msgs.PointStamped import PointStamped
+from dimos.msgs.geometry_msgs.Twist import Twist
+
+
+class MovementManager(Module):
+    clicked_point: In[PointStamped]
+    nav_cmd_vel: In[Twist]
+    tele_cmd_vel: In[Twist]
+
+    cmd_vel: Out[Twist]
+    goal: Out[PointStamped]
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        # No lock needed. `_teleop_active` is only mutated on `self._loop`.
+        self._teleop_active = False
+
+    async def handle_clicked_point(self, msg: PointStamped) -> None:
+        self.goal.publish(msg)
+
+    async def handle_nav_cmd_vel(self, msg: Twist) -> None:
+        if not self._teleop_active:
+            self.cmd_vel.publish(msg)
+
+    async def handle_tele_cmd_vel(self, msg: Twist) -> None:
+        self._teleop_active = True
+        self.cmd_vel.publish(msg)
+```
+
+Each handler runs in a per-handler dispatcher task on `self._loop`. Handlers are serialized: only one invocation of `handle_x` runs at a time. If messages arrive faster than the handler can process them, intermediate messages are dropped — only the most recent unprocessed message is kept (LATEST policy). The handler is guaranteed to eventually run with the most recently published value.
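+
+A minimal sketch of the LATEST policy in action (the `SlowConsumer` module and its stream names are illustrative; the `BurstModule` test in this change exercises the same behavior):
+
+```python
+import asyncio
+
+from dimos.core.module import Module
+from dimos.core.stream import In, Out
+
+
+class SlowConsumer(Module):
+    frame: In[int]
+    processed: Out[int]
+
+    async def handle_frame(self, value: int) -> None:
+        # ~50ms of work per message: publishing 100 frames in a tight loop
+        # coalesces to far fewer invocations, but the final frame is always
+        # processed eventually.
+        await asyncio.sleep(0.05)
+        self.processed.publish(value)
+```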
+
+### Async `@rpc` methods
+
+`@rpc` works on both sync and `async def` methods. When applied to an async method, the call site dispatches automatically:
+
+- From another thread (the RPC dispatcher, sync test code, a sync `@rpc` on the same module), the call blocks until the coroutine completes on `self._loop`.
+- From inside the loop (another async `@rpc`, a `handle_*`, or a `process_observable` callback), it returns the coroutine so the caller can `await` it.
+
+```python
+class NameModule(Module):
+    @rpc
+    async def say_hello(self, name: str) -> str:
+        return f"Hello {name}, from {self._my_name}"
+
+    @rpc
+    async def set_my_name(self, new_name: str) -> None:
+        self._my_name = new_name
+```
+
+Async and sync `@rpc` methods are interchangeable for cross-module linking. Both are discovered via `Module.rpcs` and served through the same RPC machinery. A module ref or RPC client doesn't care whether the underlying method is sync or async.
+
+When the consumer types a module ref using a Spec that declares `async def`, the proxy automatically exposes those methods as awaitables: `await self._name_module.say_hello(name)`.
+
+```python
+class NameSpec(Spec, Protocol):
+    async def say_hello(self, name: str) -> str: ...
+    async def set_my_name(self, new_name: str) -> None: ...
+
+
+class StartModule(Module):
+    _name_module: NameSpec
+
+    async def code(self) -> None:
+        await self._name_module.set_my_name("John")
+        print(await self._name_module.say_hello("Bill"))
+```
+
+`NameModule` is async, but if you need to call it from a sync module, you just declare a sync Spec instead:
+
+```python
+class SyncNameSpec(Spec, Protocol):
+    def say_hello(self, name: str) -> str: ...
+    def set_my_name(self, new_name: str) -> None: ...
+```
+
+This still matches `NameModule`. You call it synchronously from your module, and the method body still runs on `NameModule`'s own `self._loop`.
+
+The reverse is also true: you can call a sync module from async code.
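+
+For example, a sync module can hold the async `NameModule` behind the sync Spec above (a minimal sketch; `SyncCaller` is hypothetical and follows the same pattern as `test_async_module_rpc_sync_to_async.py` in this change):
+
+```python
+class SyncCaller(Module):
+    _name_module: SyncNameSpec
+
+    @rpc
+    def greet(self) -> str:
+        # Blocks this (non-loop) thread while the coroutine runs on
+        # NameModule's own self._loop, then returns the result.
+        return self._name_module.say_hello("Ada")
+```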
+
+### `spawn`: schedule a long-running coroutine from sync code
+
+When you need to start a long-running async task from `start()` (e.g., a timer loop), use `self.spawn(coro)` instead of `asyncio.run_coroutine_threadsafe(coro, self._loop)`. The helper wires up a done-callback that surfaces unhandled exceptions to the module logger. A bare `run_coroutine_threadsafe` silently stores the exception on the returned Future, where it disappears unless the user remembers to read `.result()`.
+
+```python
+@rpc
+def start(self) -> None:
+    super().start()
+    self._timer_future = self.spawn(self._timer_loop())
+
+async def _timer_loop(self) -> None:
+    while True:
+        await asyncio.sleep(1.0)
+        ...
+
+@rpc
+def stop(self) -> None:
+    if self._timer_future is not None:
+        self._timer_future.cancel()
+    super().stop()
+```
+
+### `process_observable`: async subscriptions to arbitrary observables
+
+Sometimes you have RxPY observables whose emissions you need to process on `self._loop`. You can do this with `self.process_observable(observable, async_handler)`.
+
+```python
+@rpc
+def start(self) -> None:
+    super().start()
+    fast = self.foo.observable().pipe(ops.filter(lambda v: v > threshold))
+    self.process_observable(fast, self._on_fast_foo)
+
+async def _on_fast_foo(self, v: int) -> None:
+    ...
+```
+
+### `main()`: combined setup/teardown
+
+When a module owns a resource that needs construction at startup *and* explicit cleanup at shutdown, define `async def main(self)` as an **async generator with exactly one `yield`**. Code before the `yield` runs at `start()`; code after the `yield` runs at `stop()`.
+
+```python
+class PersonFollowSkillContainer(Module):
+    async def main(self) -> AsyncIterator[None]:
+        # setup
+        self._vl_model = create("qwen")
+
+        yield
+
+        # teardown
+        self._vl_model.stop()
+```
+
+Compared to splitting the same work across `__init__` / `start()` / `stop()`, `main()` keeps the construction and destruction of each resource visually adjacent.
+
 ## Blueprints
 
 A blueprint is a predefined structure of interconnected modules. You can include blueprints or modules in new blueprints.
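 
 For example (a minimal sketch composing the test modules from this change; as used in those tests, `autoconnect` matches Spec-typed module refs and same-named streams across the given blueprints):
 
 ```python
 from dimos.core.coordination.blueprints import autoconnect
 from dimos.core.coordination.module_coordinator import ModuleCoordinator
 
 # Compose two module blueprints into one and build a running coordinator.
 blueprint = autoconnect(StartModule.blueprint(), MakeCube.blueprint())
 coordinator = ModuleCoordinator.build(blueprint)
 try:
     ...  # publish inputs, call RPCs, etc.
 finally:
     coordinator.stop()
 ```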