Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 27 additions & 3 deletions dimos/core/coordination/module_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from collections import defaultdict
from collections.abc import Mapping, MutableMapping
import importlib
import inspect
import shutil
import sys
import threading
Expand All @@ -30,7 +31,7 @@
from dimos.core.module import ModuleBase, ModuleSpec
from dimos.core.resource import Resource
from dimos.core.transport import LCMTransport, PubSubTransport, pLCMTransport
from dimos.spec.utils import spec_annotation_compliance, spec_structural_compliance
from dimos.spec.utils import is_spec, spec_annotation_compliance, spec_structural_compliance
from dimos.utils.generic import short_id
from dimos.utils.logging_config import setup_logger
from dimos.utils.safe_thread_map import safe_thread_map
Expand Down Expand Up @@ -767,19 +768,24 @@ def _connect_module_refs(
) -> None:
from dimos.core.coordination.blueprints import DisabledModuleProxy
from dimos.core.module import is_module_type
from dimos.spec.utils import is_spec
from dimos.core.rpc_client import AsyncSpecProxy

mod_and_mod_ref_to_proxy = {
(module, name): replacement
for (module, name), replacement in blueprint.remapping_map.items()
if is_spec(replacement) or is_module_type(replacement)
}

# Track the consumer's declared spec for each ref so we can wrap the proxy
# below if the spec contains async-declared methods.
declared_spec: dict[tuple[type[ModuleBase], str], Any] = {}

disabled_ref_proxies: dict[tuple[type[ModuleBase], str], DisabledModuleProxy] = {}
disabled_set = set(blueprint.disabled_modules_tuple)

for bp in blueprint.active_blueprints:
for module_ref in bp.module_refs:
declared_spec[bp.module, module_ref.name] = module_ref.spec
spec = mod_and_mod_ref_to_proxy.get((bp.module, module_ref.name), module_ref.spec)

if is_module_type(spec):
Expand All @@ -798,7 +804,10 @@ def _connect_module_refs(

for (base_module, ref_name), target_module in mod_and_mod_ref_to_proxy.items():
base_instance = module_coordinator.get_instance(base_module)
target_instance = module_coordinator.get_instance(target_module) # type: ignore[arg-type]
target_instance: Any = module_coordinator.get_instance(target_module) # type: ignore[arg-type]
async_methods = _async_methods_of_spec(declared_spec.get((base_module, ref_name)))
if async_methods:
target_instance = AsyncSpecProxy(target_instance, async_methods)
setattr(base_instance, ref_name, target_instance)
base_instance.set_module_ref(ref_name, target_instance)
module_coordinator._resolved_module_refs[base_module, ref_name] = cast(
Expand All @@ -811,6 +820,21 @@ def _connect_module_refs(
base_instance.set_module_ref(ref_name, cast("Any", proxy))


def _async_methods_of_spec(spec: Any) -> frozenset[str]:
if not is_spec(spec):
return frozenset()
names: set[str] = set()
for cls in spec.__mro__:
if cls is object:
continue
for attr_name, value in vars(cls).items():
if attr_name.startswith("_"):
continue
if inspect.iscoroutinefunction(value):
names.add(attr_name)
return frozenset(names)


def _log_blueprint_graph(blueprint: Blueprint, module_coordinator: ModuleCoordinator) -> None:
"""Log the module graph to Rerun if a RerunBridgeModule is active."""
from dimos.visualization.rerun.bridge import RerunBridgeModule
Expand Down
44 changes: 39 additions & 5 deletions dimos/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,56 @@

from __future__ import annotations

import asyncio
import functools
import inspect
from typing import (
TYPE_CHECKING,
Any,
ParamSpec,
TypeVar,
cast,
)

if TYPE_CHECKING:
from collections.abc import Callable

T = TypeVar("T")

from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def rpc(fn: Callable[P, R]) -> Callable[P, R]:
fn.__rpc__ = True # type: ignore[attr-defined]
return fn
"""Mark a method as an RPC body callable across modules.

Sync methods are tagged in place. Async methods get a sync dispatcher that
runs the coroutine on `self._loop`:

* Caller is on self._loop (another async @rpc, a handle_*, or a
process_observable callback): returns the coroutine so the caller can
`await` it normally.
* Caller is on any other thread (RPC dispatcher, sync test, sync @rpc on
the same module): schedules the coroutine onto self._loop and blocks
until done.
"""
if not inspect.iscoroutinefunction(fn):
fn.__rpc__ = True # type: ignore[attr-defined]
return fn

@functools.wraps(fn)
def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
loop = self._loop
if loop is None:
raise RuntimeError("async @rpc method called outside a running module loop")
try:
running = asyncio.get_running_loop()
except RuntimeError:
running = None
if running is loop:
return fn(self, *args, **kwargs) # type: ignore[call-arg]
future = asyncio.run_coroutine_threadsafe(fn(self, *args, **kwargs), loop) # type: ignore[call-arg, arg-type]
return future.result()

wrapper.__rpc__ = True # type: ignore[attr-defined]
wrapper.aio = fn # type: ignore[attr-defined]
return cast("Callable[P, R]", wrapper)
Loading
Loading