diff --git a/dimos/constants.py b/dimos/constants.py index b5c2e63620..d849f4aaf3 100644 --- a/dimos/constants.py +++ b/dimos/constants.py @@ -51,3 +51,5 @@ # Default timeout (seconds) for thread.join() during shutdown. DEFAULT_THREAD_JOIN_TIMEOUT = 2.0 + +DEFAULT_BUILD_NATIVE = False diff --git a/dimos/core/coordination/python_worker.py b/dimos/core/coordination/python_worker.py index 3c434a982e..a956c72d06 100644 --- a/dimos/core/coordination/python_worker.py +++ b/dimos/core/coordination/python_worker.py @@ -18,6 +18,7 @@ import multiprocessing from multiprocessing.connection import Connection import os +import signal import sys import threading import traceback @@ -337,12 +338,11 @@ class _WorkerState: def _worker_entrypoint(conn: Connection, worker_id: int) -> None: apply_library_config() + signal.signal(signal.SIGINT, signal.SIG_IGN) state = _WorkerState(instances={}, worker_id=worker_id) try: _worker_loop(conn, state) - except KeyboardInterrupt: - logger.info("Worker got KeyboardInterrupt.", worker_id=worker_id) except Exception as e: logger.error(f"Worker process error: {e}", exc_info=True) finally: @@ -361,12 +361,6 @@ def _worker_entrypoint(conn: Connection, worker_id: int) -> None: worker_id=worker_id, module_id=module_id, ) - except KeyboardInterrupt: - logger.warning( - "KeyboardInterrupt during worker stop", - module=type(instance).__name__, - worker_id=worker_id, - ) except Exception: logger.error("Error during worker shutdown", exc_info=True) @@ -433,7 +427,7 @@ def _worker_loop(conn: Connection, state: _WorkerState) -> None: if not conn.poll(timeout=0.1): continue request = conn.recv() - except (EOFError, KeyboardInterrupt): + except EOFError: break try: diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py index ccf5b0644c..7263b591ce 100644 --- a/dimos/core/global_config.py +++ b/dimos/core/global_config.py @@ -17,6 +17,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict +from dimos.constants import DEFAULT_BUILD_NATIVE from dimos.models.vl.types import VlModelName ViewerBackend: TypeAlias = Literal["rerun", "rerun-web", "rerun-connect", "foxglove", "none"] @@ -52,6 +53,7 @@ class GlobalConfig(BaseSettings): nerf_speed: float = 1.0 planner_robot_speed: float | None = None mcp_port: int = 9990 + build_native: bool = DEFAULT_BUILD_NATIVE dtop: bool = False obstacle_avoidance: bool = True detection_model: VlModelName = "moondream" diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 4e2ec0c699..5b6fecf3cc 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -41,8 +41,8 @@ class MyCppModule(NativeModule): from __future__ import annotations -import collections import enum +import functools import inspect import json import os @@ -51,15 +51,31 @@ class MyCppModule(NativeModule): import subprocess import sys import threading +import time from typing import IO, Any from pydantic import Field from dimos.constants import DEFAULT_THREAD_JOIN_TIMEOUT from dimos.core.core import rpc +from dimos.core.global_config import global_config from dimos.core.module import Module, ModuleConfig from dimos.utils.logging_config import setup_logger +if sys.platform.startswith("linux"): + import ctypes + from ctypes.util import find_library + + _LIBC = ctypes.CDLL(find_library("c"), use_errno=True) + + def _set_process_to_die_when_parent_dies() -> None: + _PR_SET_PDEATHSIG = 1 + if _LIBC.prctl(_PR_SET_PDEATHSIG, signal.SIGTERM) != 0: + err = ctypes.get_errno() + raise OSError(err, f"_set_process_to_die_when_parent_dies failed: {os.strerror(err)}") +else: + _set_process_to_die_when_parent_dies = None # type: ignore[assignment] + if sys.version_info < (3, 13): from typing_extensions import TypeVar else: @@ -74,22 +90,21 @@ class LogFormat(enum.Enum): class NativeModuleConfig(ModuleConfig): - """Configuration for a native (C/C++) subprocess module.""" - executable: str build_command: str | None = None cwd: str | None = None extra_args: list[str] = Field(default_factory=list) extra_env: dict[str, str] = Field(default_factory=dict) - shutdown_timeout: float = 10.0 + shutdown_timeout: float = DEFAULT_THREAD_JOIN_TIMEOUT log_format: LogFormat = LogFormat.TEXT + auto_build: bool = False # New version of Native Modules read json configs from stdin # Enable this to read from stdin instead of cli args stdin_config: bool = False - # Override in subclasses to exclude fields from CLI arg generation cli_exclude: frozenset[str] = frozenset() + cli_name_override: dict[str, str] = Field(default_factory=dict) def to_config_dict(self) -> dict[str, Any]: """ @@ -101,9 +116,6 @@ def to_config_dict(self) -> dict[str, Any]: } def to_cli_args(self) -> list[str]: - """ - Auto-convert subclass config fields to CLI args. - """ ignore_fields = {f for f in NativeModuleConfig.model_fields} args: list[str] = [] for f in self.__class__.model_fields: @@ -114,12 +126,13 @@ def to_cli_args(self) -> list[str]: val = getattr(self, f) if val is None: continue + cli_name = self.cli_name_override.get(f, f) if isinstance(val, bool): - args.extend([f"--{f}", str(val).lower()]) + args.extend([f"--{cli_name}", str(val).lower()]) elif isinstance(val, list): - args.extend([f"--{f}", ",".join(str(v) for v in val)]) + args.extend([f"--{cli_name}", ",".join(str(v) for v in val)]) else: - args.extend([f"--{f}", str(val)]) + args.extend([f"--{cli_name}", str(val)]) return args @@ -127,35 +140,36 @@ def to_cli_args(self) -> list[str]: class NativeModule(Module): - """Module that wraps a native executable as a managed subprocess. - - Subclass this, declare In/Out ports, and annotate ``config`` with a - :class:`NativeModuleConfig` subclass pointing at the executable. - - On ``start()``, the binary is launched with CLI args:: - - -- ... - - The native process should parse these args and pub/sub on the given - LCM topics directly. On ``stop()``, the process receives SIGTERM. - """ - config: NativeModuleConfig _process: subprocess.Popen[bytes] | None = None _watchdog: threading.Thread | None = None _stopping: bool = False - _last_stderr_lines: collections.deque[str] + _stop_lock: threading.Lock + + @functools.cached_property + def _mod_label(self) -> str: + exe = Path(self.config.executable).name if self.config.executable else "?" + return f"{type(self).__name__}({exe})" def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - self._last_stderr_lines = collections.deque(maxlen=50) - self._resolve_paths() + self._stop_lock = threading.Lock() + + if self.config.cwd is not None and not Path(self.config.cwd).is_absolute(): + base_dir = Path(inspect.getfile(type(self))).resolve().parent + self.config.cwd = str(base_dir / self.config.cwd) + if not Path(self.config.executable).is_absolute() and self.config.cwd is not None: + self.config.executable = str(Path(self.config.cwd) / self.config.executable) @rpc def start(self) -> None: if self._process is not None and self._process.poll() is None: - logger.warning("Native process already running", pid=self._process.pid) + logger.warning( + "Native process already running", + module=self._mod_label, + pid=self._process.pid, + ) return self._maybe_build() @@ -171,13 +185,13 @@ def start(self) -> None: env = {**os.environ, **self.config.extra_env} cwd = self.config.cwd or str(Path(self.config.executable).resolve().parent) - module_name = type(self).__name__ logger.info( - f"Starting native process: {module_name}", - module=module_name, + "Starting native process", + module=self._mod_label, cmd=" ".join(cmd), cwd=cwd, ) + self._process = subprocess.Popen( cmd, env=env, @@ -185,6 +199,8 @@ def start(self) -> None: stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, + start_new_session=True, + preexec_fn=_set_process_to_die_when_parent_dies, ) assert self._process.stdin is not None if self.config.stdin_config: @@ -195,117 +211,154 @@ def start(self) -> None: self._process.stdin.write(stdin_blob) self._process.stdin.close() logger.info( - f"Native process started: {module_name}", - module=module_name, + "Native process started", + module=self._mod_label, pid=self._process.pid, ) - self._stopping = False - self._watchdog = threading.Thread(target=self._watch_process, daemon=True) - self._watchdog.start() + watchdog = threading.Thread( + target=self._watch_process, + daemon=True, + name=f"native-watchdog-{self._mod_label}", + ) + with self._stop_lock: + self._stopping = False + self._watchdog = watchdog + watchdog.start() @rpc def stop(self) -> None: - self._stopping = True - if self._process is not None and self._process.poll() is None: - logger.info("Stopping native process", pid=self._process.pid) - self._process.send_signal(signal.SIGTERM) + # Capture refs under lock, but signal/wait/join outside it to avoid + # deadlocking with the watchdog's own stop() call. + with self._stop_lock: + if self._stopping: + return + self._stopping = True + proc = self._process + watchdog = self._watchdog + + if proc is not None and proc.poll() is None: + logger.info( + "Stopping native process", + module=self._mod_label, + pid=proc.pid, + ) + proc.send_signal(signal.SIGTERM) try: - self._process.wait(timeout=self.config.shutdown_timeout) + proc.wait(timeout=self.config.shutdown_timeout) except subprocess.TimeoutExpired: logger.warning( - "Native process did not exit, sending SIGKILL", pid=self._process.pid + "Native process did not exit, sending SIGKILL", + module=self._mod_label, + pid=proc.pid, ) - self._process.kill() - self._process.wait(timeout=5) - if self._watchdog is not None and self._watchdog is not threading.current_thread(): - self._watchdog.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) - self._watchdog = None - self._process = None + proc.kill() + try: + proc.wait(timeout=self.config.shutdown_timeout) + except subprocess.TimeoutExpired: + logger.error( + "Native process not reapable after SIGKILL", + module=self._mod_label, + pid=proc.pid, + ) + + if watchdog is not None and watchdog is not threading.current_thread(): + watchdog.join(timeout=self.config.shutdown_timeout) + + with self._stop_lock: + self._watchdog = None + self._process = None + super().stop() def _watch_process(self) -> None: - """Block until the native process exits; trigger stop() if it crashed.""" - if self._process is None: + proc = self._process + if proc is None: return + pid = proc.pid - stdout_t = self._start_reader(self._process.stdout, "info") - stderr_t = self._start_reader(self._process.stderr, "warning") - rc = self._process.wait() - stdout_t.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) - stderr_t.join(timeout=DEFAULT_THREAD_JOIN_TIMEOUT) + stdout_t = self._start_reader(proc.stdout, "info", pid) + stderr_t = self._start_reader(proc.stderr, "warning", pid) + rc = proc.wait() + stdout_t.join(timeout=self.config.shutdown_timeout) + stderr_t.join(timeout=self.config.shutdown_timeout) if self._stopping: + logger.info( + "Native process exited (expected)", + module=self._mod_label, + pid=pid, + returncode=rc, + ) return - module_name = type(self).__name__ - exe_name = Path(self.config.executable).name if self.config.executable else "unknown" - - # Use buffered stderr lines from the reader thread for the crash report. - last_stderr = "\n".join(self._last_stderr_lines) - logger.error( - f"Native process crashed: {module_name} ({exe_name})", - module=module_name, - executable=exe_name, - pid=self._process.pid, + "Native process died unexpectedly", + module=self._mod_label, + pid=pid, returncode=rc, - last_stderr=last_stderr[:500] if last_stderr else None, ) self.stop() - def _start_reader(self, stream: IO[bytes] | None, level: str) -> threading.Thread: - """Spawn a daemon thread that pipes a subprocess stream through the logger.""" - t = threading.Thread(target=self._read_log_stream, args=(stream, level), daemon=True) + def _start_reader( + self, + stream: IO[bytes] | None, + level: str, + pid: int, + ) -> threading.Thread: + t = threading.Thread( + target=self._read_log_stream, + args=(stream, level, pid), + daemon=True, + name=f"native-reader-{level}-{self._mod_label}", + ) t.start() return t - def _read_log_stream(self, stream: IO[bytes] | None, level: str) -> None: + def _read_log_stream( + self, + stream: IO[bytes] | None, + level: str, + pid: int, + ) -> None: if stream is None: return log_fn = getattr(logger, level) - is_stderr = level == "warning" for raw in stream: line = raw.decode("utf-8", errors="replace").rstrip() if not line: continue - if is_stderr: - self._last_stderr_lines.append(line) if self.config.log_format == LogFormat.JSON: try: data = json.loads(line) event = data.pop("event", line) - log_fn(event, **data) + log_fn(event, module=self._mod_label, pid=pid, **data) continue except (json.JSONDecodeError, TypeError): - logger.warning("malformed JSON from native module", raw=line) - log_fn(line, pid=self._process.pid if self._process else None) + pass + log_fn(line, module=self._mod_label, pid=pid) stream.close() - def _resolve_paths(self) -> None: - """Resolve relative ``cwd`` and ``executable`` against the subclass's source file.""" - if self.config.cwd is not None and not Path(self.config.cwd).is_absolute(): - source_file = inspect.getfile(type(self)) - base_dir = Path(source_file).resolve().parent - self.config.cwd = str(base_dir / self.config.cwd) - if not Path(self.config.executable).is_absolute() and self.config.cwd is not None: - self.config.executable = str(Path(self.config.cwd) / self.config.executable) - def _maybe_build(self) -> None: - """Run ``build_command`` if the executable does not exist.""" exe = Path(self.config.executable) - if exe.exists(): - return + if self.config.build_command is None: - raise FileNotFoundError( - f"Executable not found: {exe}. " - "Set build_command in config to auto-build, or build it manually." - ) + if not exe.exists(): + raise FileNotFoundError( + f"[{self._mod_label}] Executable not found: {exe}. " + "Set build_command in config to auto-build, or build it manually." + ) + return + + if exe.exists() and not self.config.auto_build and not global_config.build_native: + return + logger.info( - "Executable not found, running build", + "Building native module", executable=str(exe), build_command=self.config.build_command, ) + build_start = time.perf_counter() proc = subprocess.Popen( self.config.build_command, shell=True, @@ -315,27 +368,36 @@ def _maybe_build(self) -> None: stderr=subprocess.PIPE, ) stdout, stderr = proc.communicate() - for line in stdout.decode("utf-8", errors="replace").splitlines(): + build_elapsed = time.perf_counter() - build_start + + stdout_lines = stdout.decode("utf-8", errors="replace").splitlines() + stderr_lines = stderr.decode("utf-8", errors="replace").splitlines() + + for line in stdout_lines: if line.strip(): - logger.info(line) - for line in stderr.decode("utf-8", errors="replace").splitlines(): + logger.info(line, module=self._mod_label) + for line in stderr_lines: if line.strip(): - logger.warning(line) + logger.warning(line, module=self._mod_label) + if proc.returncode != 0: - stderr_tail = stderr.decode("utf-8", errors="replace").strip()[-1000:] raise RuntimeError( - f"Build command failed (exit {proc.returncode}): {self.config.build_command}\n" - f"stderr: {stderr_tail}" + f"[{self._mod_label}] Build command failed after {build_elapsed:.2f}s " + f"(exit {proc.returncode}): {self.config.build_command}" ) if not exe.exists(): raise FileNotFoundError( - f"Build command succeeded but executable still not found: {exe}\n" - f"Build output may have been written to a different path. " - f"Check that build_command produces the executable at the expected location." + f"[{self._mod_label}] Build command succeeded but executable still not found: {exe}" ) + logger.info( + "Build command completed", + module=self._mod_label, + executable=str(exe), + duration_sec=round(build_elapsed, 3), + ) + def _collect_topics(self) -> dict[str, str]: - """Extract LCM topic strings from blueprint-assigned stream transports.""" topics: dict[str, str] = {} for name in list(self.inputs) + list(self.outputs): stream = getattr(self, name, None) diff --git a/dimos/core/test_native_module.py b/dimos/core/test_native_module.py index c34ae0a3cc..bb11868b56 100644 --- a/dimos/core/test_native_module.py +++ b/dimos/core/test_native_module.py @@ -28,7 +28,7 @@ from dimos.core.coordination.module_coordinator import ModuleCoordinator from dimos.core.core import rpc from dimos.core.module import Module -from dimos.core.native_module import LogFormat, NativeModule, NativeModuleConfig +from dimos.core.native_module import NativeModule, NativeModuleConfig from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport from dimos.msgs.geometry_msgs.Twist import Twist @@ -60,7 +60,6 @@ def read_json_file(path: str) -> dict[str, str]: class StubNativeConfig(NativeModuleConfig): executable: str = _ECHO - log_format: LogFormat = LogFormat.TEXT output_file: str | None = None die_after: float | None = None some_param: float = 1.5