From 0e48b32d9da110c81e2b29e914c87c936cf6f640 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 21 Apr 2026 15:17:53 -0700 Subject: [PATCH 01/30] scripts for bringing up the openarm robot --- .../openarm/scripts/openarm_can_up.sh | 36 ++++++ .../openarm/scripts/openarm_set_mit_mode.py | 105 ++++++++++++++++++ 2 files changed, 141 insertions(+) create mode 100755 dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh create mode 100755 dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py diff --git a/dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh b/dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh new file mode 100755 index 0000000000..d25fc41e43 --- /dev/null +++ b/dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# Bring up CAN interfaces for OpenArm. Default is classical CAN @ 1 Mbit, +# which is what most gs_usb (OpenMoko / Geschwister Schneider) USB-CAN +# adapters support. Use MODE=fd if you have a CAN-FD-capable adapter. +# Run with sudo or as root. +# +# Usage: +# sudo ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh # classical 1M, can0 and can1 +# sudo ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh can0 # single interface +# sudo MODE=fd ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh can0 # CAN-FD 1M/5M +set -euo pipefail + +BITRATE=1000000 +DBITRATE=5000000 +MODE="${MODE:-classical}" # classical | fd +IFACES_ARG="${*:-can0 can1}" +# shellcheck disable=SC2206 +IFACES=(${IFACES_ARG[@]}) + +for IF in "${IFACES[@]}"; do + if ! 
ip link show "$IF" >/dev/null 2>&1; then + echo "[skip] $IF not present" + continue + fi + ip link set "$IF" down || true + if [ "$MODE" = "classical" ]; then + echo "[up ] $IF ${BITRATE} (classical CAN)" + ip link set "$IF" type can bitrate "$BITRATE" + else + echo "[up ] $IF ${BITRATE}/${DBITRATE} fd on" + ip link set "$IF" type can bitrate "$BITRATE" dbitrate "$DBITRATE" fd on + fi + ip link set "$IF" up + ip link set "$IF" txqueuelen 1000 + ip -details link show "$IF" | grep -E "can |bitrate" || true +done diff --git a/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py b/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py new file mode 100755 index 0000000000..a361e74e4e --- /dev/null +++ b/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +"""Write CTRL_MODE = MIT (1) to one or all OpenArm motors. + +Damiao motors have a persistent CTRL_MODE register (RID=10). If a motor was +previously configured in POS_VEL (2) / VEL (3) / POS_FORCE (4) mode, it will +respond to enable/disable but IGNORE MIT control frames — exactly the +"motor doesn't move, error grows" symptom. + +This script writes CTRL_MODE=1 (MIT) via the 0x7FF broadcast-write frame +format used by enactic/openarm_can: + + ID=0x7FF data = [id_lo, id_hi, 0x55, RID=10, val[0], val[1], val[2], val[3]] + +Run once per motor after CAN bring-up. The value is persistent across power +cycles. 
+ +Usage: + # All 8 motors on can0 + python dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py --channel can0 + + # Single motor + python dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py --channel can0 --id 0x05 +""" +from __future__ import annotations + +import argparse +import struct +import sys +import time + +try: + import can +except ImportError: + sys.exit("python-can not installed") + +RID_CTRL_MODE = 10 +MIT_MODE = 1 +DEFAULT_IDS = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08] + + +def write_ctrl_mode(bus: can.BusABC, send_id: int, fd: bool) -> bool: + val = struct.pack("> 8) & 0xFF, 0x55, RID_CTRL_MODE, + val[0], val[1], val[2], val[3]]) + # Flush + while bus.recv(0.0) is not None: + pass + bus.send(can.Message(arbitration_id=0x7FF, data=data, + is_extended_id=False, is_fd=fd, bitrate_switch=fd)) + # Wait for ack on 0x7FF (per openarm_can param response) + t0 = time.monotonic() + while time.monotonic() - t0 < 0.2: + msg = bus.recv(0.2 - (time.monotonic() - t0)) + if msg is None: + break + # Reply echoes 0x55 in byte 2 of the 0x7FF channel + if len(msg.data) >= 4 and msg.data[2] in (0x33, 0x55): + rid = msg.data[3] + if rid == RID_CTRL_MODE: + echoed = struct.unpack(" int: + ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--channel", default="can0") + ap.add_argument("--fd", action="store_true", help="Use CAN-FD (default: classical CAN)") + ap.add_argument("--id", type=lambda s: int(s, 0), default=None, + help="Single send ID (default: all 8)") + args = ap.parse_args() + + fd = args.fd + ids = [args.id] if args.id is not None else DEFAULT_IDS + + # Preflight: is the interface up? 
+ try: + flags = int(open(f"/sys/class/net/{args.channel}/flags").read().strip(), 16) + except OSError: + print(f"ERROR: interface '{args.channel}' not found", file=sys.stderr) + return 1 + if not (flags & 0x1): + print(f"ERROR: SocketCAN interface '{args.channel}' is DOWN.", file=sys.stderr) + print(f" Run: sudo ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh {args.channel}", file=sys.stderr) + return 1 + + print(f"Opening {args.channel} ({'CAN-FD' if fd else 'classical'})") + bus = can.Bus(interface="socketcan", channel=args.channel, fd=fd) + try: + ok = 0 + for i in ids: + if write_ctrl_mode(bus, i, fd): + ok += 1 + time.sleep(0.05) + print(f"\n{ok}/{len(ids)} motors set to MIT mode.") + return 0 if ok == len(ids) else 2 + finally: + bus.shutdown() + + +if __name__ == "__main__": + sys.exit(main()) From 1f6ee610a343314a17eb5fce927c8f603aa7adb2 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 21 Apr 2026 15:18:14 -0700 Subject: [PATCH 02/30] adde open arm driver and adapter file --- .../hardware/manipulators/openarm/adapter.py | 484 +++++++++++++++++ dimos/hardware/manipulators/openarm/driver.py | 489 ++++++++++++++++++ 2 files changed, 973 insertions(+) create mode 100644 dimos/hardware/manipulators/openarm/adapter.py create mode 100644 dimos/hardware/manipulators/openarm/driver.py diff --git a/dimos/hardware/manipulators/openarm/adapter.py b/dimos/hardware/manipulators/openarm/adapter.py new file mode 100644 index 0000000000..6d7c03d586 --- /dev/null +++ b/dimos/hardware/manipulators/openarm/adapter.py @@ -0,0 +1,484 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenArm adapter — implements the ManipulatorAdapter protocol. + +Wraps the from-scratch Damiao MIT-mode CAN driver in :mod:`.driver`. All +physics-level work (frame packing, bus threading, motor state caching) +lives in the driver; this file just maps the dimos protocol methods to +driver calls and handles per-joint sign/offset convention. + +Units: radians, rad/s, Nm (matching the driver and the protocol). + +Default wiring matches the OpenArm v10 BOM (send IDs 1..7 + gripper 8, +motor types DM8006/DM4340/DM4310, recv = send | 0x10). See +``docs/capabilities/manipulation/openarm_integration.md``. +""" + +from __future__ import annotations + +import time +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np + +from dimos.hardware.manipulators.openarm.driver import ( + CTRL_MODE_MIT, + DamiaoMotor, + MotorType, + OpenArmBus, +) +from dimos.hardware.manipulators.spec import ( + ControlMode, + JointLimits, + ManipulatorInfo, +) + +if TYPE_CHECKING: + from dimos.hardware.manipulators.registry import AdapterRegistry + + +def _socketcan_iface_up(name: str) -> bool: + """Return True if a SocketCAN interface is present and in the UP state. + + Reads /sys directly instead of shelling out to ip(8) — no subprocess, + no sudo, works in containers. + """ + try: + flags_path = Path("/sys/class/net") / name / "flags" + if not flags_path.exists(): + return False + # IFF_UP is bit 0 of the interface flags register. 
+ return (int(flags_path.read_text().strip(), 16) & 0x1) == 0x1 + except OSError: + return False + + +# OpenArm v10 BOM — (send_id, MotorType) per joint, derived from the torque +# column of data/openarm_description/config/arm/v10/joint_limits.yaml. +_OPENARM_V10_ARM_MOTORS: list[tuple[int, MotorType]] = [ + (0x01, MotorType.DM8006), # joint1 + (0x02, MotorType.DM8006), # joint2 + (0x03, MotorType.DM4340), # joint3 + (0x04, MotorType.DM4340), # joint4 + (0x05, MotorType.DM4310), # joint5 + (0x06, MotorType.DM4310), # joint6 + (0x07, MotorType.DM4310), # joint7 +] +_OPENARM_V10_GRIPPER_MOTOR: tuple[int, MotorType] = (0x08, MotorType.DM4310) + +# Physical joint limits (measured). Joints 1 & 2 are mirrored between sides. +_V10_POS_LOWER_LEFT = [-3.45, -3.30, -1.50, -0.01, -1.50, -0.75, -1.50] +_V10_POS_UPPER_LEFT = [1.35, 0.15, 1.50, 2.40, 1.50, 0.75, 1.50] +_V10_POS_LOWER_RIGHT = [-1.35, -0.15, -1.50, -0.01, -1.50, -0.75, -1.50] +_V10_POS_UPPER_RIGHT = [3.45, 3.30, 1.50, 2.40, 1.50, 0.75, 1.50] +_V10_VEL_MAX = [16.754666, 16.754666, 5.445426, 5.445426, 20.943946, 20.943946, 20.943946] + +# Default MIT gains per joint for POSITION mode. +# kp range is [0, 500], kd range is [0, 5]. +# With gravity compensation enabled, the PD gains only handle transient +# tracking — they don't fight gravity. Lower kp = smoother, less buzz. +# High kd causes high-frequency buzz/grinding from the gearbox. +_DEFAULT_KP = [100.0, 100.0, 80.0, 80.0, 60.0, 60.0, 60.0] +_DEFAULT_KD = [1.5, 1.5, 1.0, 1.0, 0.8, 0.8, 0.8] + + +class OpenArmAdapter: + """Adapter for one OpenArm (7 DOF) on a single SocketCAN bus. + + Implements ``ManipulatorAdapter`` via duck typing — no inheritance. + + Parameters + ---------- + address: + SocketCAN channel, e.g. ``"can0"``. + dof: + Must be 7 (OpenArm is fixed-DOF). Kept as a parameter for adapter- + protocol uniformity. + side: + ``"left"`` or ``"right"``. Selects the joint limits and gravity-model URDF; no sign + flips are applied (the URDF handles left/right mirroring). 
+ fd: + CAN-FD. Defaults to False because the gs_usb adapters we have don't + support FD, and OpenArm runs fine on classical 1 Mbit CAN. + interface: + python-can interface name. Use ``"virtual"`` for unit tests. + kp / kd: + Optional per-joint overrides of the POSITION-mode MIT gains. + gravity_comp: + Enable Pinocchio-based gravity compensation feedforward. Computes + ``tau_gravity = G(q_current)`` each tick and adds it as the tau_ff + term in the MIT frame, so the PD gains only handle transient + tracking — not fighting gravity. Eliminates steady-state error. + auto_set_mit_mode: + If True (default), write ``CTRL_MODE=MIT`` to every motor during + ``connect()``. Idempotent — safe to leave on. Set False to verify + a previous write persisted across power cycles (i.e. to confirm + motors stay in MIT mode without the adapter re-setting it). + """ + + # Per-side URDFs for Pinocchio gravity model + _REPO_ROOT = Path(__file__).resolve().parents[4] + _OPENARM_PKG = _REPO_ROOT / "data" / "openarm_description" + _URDF_LEFT = _OPENARM_PKG / "urdf" / "robot" / "openarm_v10_left.urdf" + _URDF_RIGHT = _OPENARM_PKG / "urdf" / "robot" / "openarm_v10_right.urdf" + + def __init__( + self, + address: str = "can0", + dof: int = 7, + *, + side: str = "left", + fd: bool = False, + interface: str = "socketcan", + kp: list[float] | None = None, + kd: list[float] | None = None, + gravity_comp: bool = True, + auto_set_mit_mode: bool = True, + **_: Any, + ) -> None: + if dof != 7: + raise ValueError(f"OpenArmAdapter only supports 7 DOF (got {dof})") + if side not in ("left", "right"): + raise ValueError(f"side must be 'left' or 'right', got {side!r}") + self._address = address + self._dof = dof + self._side = side + self._fd = fd + self._interface = interface + self._kp = list(kp) if kp is not None else list(_DEFAULT_KP) + self._kd = list(kd) if kd is not None else list(_DEFAULT_KD) + if len(self._kp) != dof or len(self._kd) != dof: + raise ValueError("kp/kd must be length 7") + 
self._gravity_comp = gravity_comp + self._auto_set_mit_mode = auto_set_mit_mode + + self._motors = [DamiaoMotor(sid, mt) for sid, mt in _OPENARM_V10_ARM_MOTORS] + self._bus: OpenArmBus | None = None + self._control_mode: ControlMode = ControlMode.POSITION + self._enabled: bool = False + # Last successful position command — used as q_target for VELOCITY mode + self._last_cmd_q: list[float] | None = None + + # Pinocchio model for gravity compensation (loaded lazily in connect()) + self._pin_model: Any = None + self._pin_data: Any = None + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def connect(self) -> bool: + # Preflight: verify the SocketCAN interface is up before opening the bus. + # Bringing the interface up requires root privileges, so we don't do it + # here — just fail early with a helpful message. + if self._interface == "socketcan" and not _socketcan_iface_up(self._address): + print( + f"ERROR: SocketCAN interface '{self._address}' is not UP.\n" + f" Run: sudo ip link set {self._address} up type can bitrate 1000000\n" + f" (or: sudo ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh {self._address})" + ) + return False + + try: + self._bus = OpenArmBus( + channel=self._address, + motors=self._motors, + fd=self._fd, + interface=self._interface, + ) + self._bus.open() + except Exception as e: # noqa: BLE001 + print(f"ERROR: OpenArm {self._side}@{self._address} connect failed: {e}") + self._bus = None + return False + + # Ensure every motor is in MIT control mode. The write is idempotent + # (setting CTRL_MODE=MIT when it's already MIT is a no-op), so we + # write unconditionally rather than query-then-write. 
+ if self._auto_set_mit_mode: + for m in self._motors: + self._bus.write_ctrl_mode(m.send_id, CTRL_MODE_MIT) + time.sleep(0.005) + else: + print( + f"OpenArm {self._side}@{self._address}: " + "auto_set_mit_mode disabled — relying on persisted register" + ) + + # Enable motors and wait for at least one state reply from each. + self._bus.enable_all() + self._enabled = True + if not self._bus.wait_all_states(timeout=0.5): + print( + f"WARNING: OpenArm {self._side}@{self._address}: not all motors " + "reported state within 0.5s — proceeding anyway" + ) + # Seed the position anchor from current hardware pose. + self._last_cmd_q = self.read_joint_positions() + + # Load Pinocchio model for gravity compensation + if self._gravity_comp: + try: + import pinocchio + + urdf = str(self._URDF_LEFT if self._side == "left" else self._URDF_RIGHT) + self._pin_model = pinocchio.buildModelFromUrdf(urdf) + self._pin_data = self._pin_model.createData() + print(f"OpenArm {self._side}: gravity compensation enabled (nq={self._pin_model.nq})") + except Exception as e: # noqa: BLE001 + print(f"WARNING: gravity comp disabled — {e}") + self._pin_model = None + self._pin_data = None + + return True + + def disconnect(self) -> None: + if self._bus is None: + return + try: + self._bus.disable_all() + except Exception: # noqa: BLE001 + pass + self._enabled = False + self._bus.close() + self._bus = None + + def is_connected(self) -> bool: + return self._bus is not None + + # ------------------------------------------------------------------ + # Info + # ------------------------------------------------------------------ + + def get_info(self) -> ManipulatorInfo: + return ManipulatorInfo( + vendor="Enactic", + model=f"OpenArm v10 ({self._side})", + dof=self._dof, + firmware_version=None, + serial_number=None, + ) + + def get_dof(self) -> int: + return self._dof + + def get_limits(self) -> JointLimits: + if self._side == "left": + lower, upper = _V10_POS_LOWER_LEFT, _V10_POS_UPPER_LEFT + else: + 
lower, upper = _V10_POS_LOWER_RIGHT, _V10_POS_UPPER_RIGHT + return JointLimits( + position_lower=list(lower), + position_upper=list(upper), + velocity_max=list(_V10_VEL_MAX), + ) + + # ------------------------------------------------------------------ + # Mode + # ------------------------------------------------------------------ + + def set_control_mode(self, mode: ControlMode) -> bool: + # OpenArm runs exclusively in Damiao MIT register mode; we emulate + # dimos ControlModes by tuning kp/kd/q/dq/tau on each MIT frame. + # Cartesian/impedance control are outside this adapter's scope. + if mode in ( + ControlMode.POSITION, + ControlMode.SERVO_POSITION, + ControlMode.VELOCITY, + ControlMode.TORQUE, + ): + self._control_mode = mode + return True + return False + + def get_control_mode(self) -> ControlMode: + return self._control_mode + + # ------------------------------------------------------------------ + # State reads + # ------------------------------------------------------------------ + + def _states_or_raise(self) -> list[Any]: + if self._bus is None: + raise RuntimeError("OpenArmAdapter not connected") + return self._bus.get_states() + + def read_joint_positions(self) -> list[float]: + return [s.q if s is not None else 0.0 for s in self._states_or_raise()] + + def read_joint_velocities(self) -> list[float]: + return [s.dq if s is not None else 0.0 for s in self._states_or_raise()] + + def read_joint_efforts(self) -> list[float]: + return [s.tau if s is not None else 0.0 for s in self._states_or_raise()] + + def read_state(self) -> dict[str, int]: + if self._bus is None: + return {"state": 0, "mode": 0} + states = self._bus.get_states() + # report the hottest rotor temperature so callers can monitor thermal + # stress with a single scalar + t_rotor = max((s.t_rotor for s in states if s is not None), default=0) + return { + "state": 1 if self._enabled else 0, + "mode": 1, # MIT + "t_rotor_max": int(t_rotor), + } + + def read_error(self) -> tuple[int, str]: + # 
The Damiao motors don't report a structured error code in the state + # frame; over-temperature / over-torque are detected by the host from + # the normal state fields. Surface a soft thermal warning here. + if self._bus is None: + return 0, "" + states = self._bus.get_states() + t_rotor = max((s.t_rotor for s in states if s is not None), default=0) + if t_rotor >= 85: + return 1, f"rotor over-temperature ({t_rotor}°C)" + return 0, "" + + # ------------------------------------------------------------------ + # Gravity compensation + # ------------------------------------------------------------------ + + def _compute_gravity_torques(self, q: list[float]) -> list[float]: + """Compute per-joint gravity torques at configuration q using Pinocchio. + + Returns [0.0]*dof if gravity comp is disabled or model not loaded. + """ + if self._pin_model is None or self._pin_data is None: + return [0.0] * self._dof + import pinocchio + + q_arr = np.array(q, dtype=np.float64) + tau_g = pinocchio.computeGeneralizedGravity( + self._pin_model, self._pin_data, q_arr + ) + # Clamp to motor torque limits for safety + limits = [m.limits for m in self._motors] # (p_max, v_max, t_max) + return [ + float(np.clip(tau_g[i], -lim[2], lim[2])) + for i, lim in enumerate(limits) + ] + + # ------------------------------------------------------------------ + # Commands + # ------------------------------------------------------------------ + + def write_joint_positions( + self, + positions: list[float], + velocity: float = 1.0, + ) -> bool: + if self._bus is None or not self._enabled: + return False + if len(positions) != self._dof: + return False + velocity = max(0.0, min(1.0, velocity)) + # Gravity feedforward: compute tau needed to hold the arm at the + # current configuration. The PD gains handle the rest. 
+ q_current = self.read_joint_positions() + tau_ff = self._compute_gravity_torques(q_current) + commands = [ + (q, 0.0, kp * velocity, kd, tau) + for q, kp, kd, tau in zip(positions, self._kp, self._kd, tau_ff) + ] + self._bus.send_mit_many(commands) + self._last_cmd_q = list(positions) + return True + + def write_joint_velocities(self, velocities: list[float]) -> bool: + # MIT velocity tracking: kp=0, send dq directly, anchor q at the + # last-commanded position so the motor doesn't drift. + if self._bus is None or not self._enabled: + return False + if len(velocities) != self._dof: + return False + if self._last_cmd_q is None: + self._last_cmd_q = self.read_joint_positions() + anchor = self._last_cmd_q + commands = [ + (q_anchor, dq, 0.0, kd, 0.0) + for q_anchor, dq, kd in zip(anchor, velocities, self._kd) + ] + self._bus.send_mit_many(commands) + return True + + def write_stop(self) -> bool: + if self._bus is None: + return False + try: + q_now = self.read_joint_positions() + except Exception: # noqa: BLE001 + q_now = [0.0] * self._dof + tau_ff = self._compute_gravity_torques(q_now) + commands = [(q, 0.0, kp, kd, tau) + for q, kp, kd, tau in zip(q_now, self._kp, self._kd, tau_ff)] + self._bus.send_mit_many(commands) + self._last_cmd_q = q_now + return True + + def write_enable(self, enable: bool) -> bool: + if self._bus is None: + return False + if enable: + self._bus.enable_all() + else: + self._bus.disable_all() + self._enabled = enable + return True + + def read_enabled(self) -> bool: + return self._enabled + + def write_clear_errors(self) -> bool: + # Damiao motors have no separate clear-error command; re-enabling + # after a fault is the recovery path. 
+ if self._bus is None: + return False + self._bus.disable_all() + self._bus.enable_all() + self._enabled = True + return True + + # ------------------------------------------------------------------ + # Cartesian / gripper / F/T — not supported at this layer + # ------------------------------------------------------------------ + + def read_cartesian_position(self) -> dict[str, float] | None: + return None + + def write_cartesian_position( + self, pose: dict[str, float], velocity: float = 1.0 + ) -> bool: + return False + + def read_gripper_position(self) -> float | None: + return None + + def write_gripper_position(self, position: float) -> bool: + return False + + def read_force_torque(self) -> list[float] | None: + return None + + +# ── Registry hook (required for auto-discovery) ─────────────────── +def register(registry: "AdapterRegistry") -> None: + registry.register("openarm", OpenArmAdapter) + + +__all__ = ["OpenArmAdapter", "register"] diff --git a/dimos/hardware/manipulators/openarm/driver.py b/dimos/hardware/manipulators/openarm/driver.py new file mode 100644 index 0000000000..88d57aacea --- /dev/null +++ b/dimos/hardware/manipulators/openarm/driver.py @@ -0,0 +1,489 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Damiao MIT-mode CAN driver for OpenArm. + +Standalone module — zero dimos dependencies. 
Implements the subset of the +Damiao DM-J CAN protocol needed for MIT-mode position / velocity / torque +control, matching the reference C++ implementation in +`enactic/openarm_can/src/openarm/damiao_motor/dm_motor_control.cpp`. + +Unit tests use ``can.Bus(interface="virtual")`` for loopback. + +All user-facing values are SI units: + angles in radians + angular velocity in rad/s + torque in Nm + +See ``docs/capabilities/manipulation/openarm_integration.md`` for protocol +byte-level documentation and motor tables. +""" + +from __future__ import annotations + +import enum +import struct +import threading +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import can + + +# --------------------------------------------------------------------------- +# Motor tables (from enactic/openarm_can dm_motor_constants.hpp) +# --------------------------------------------------------------------------- + + +class MotorType(str, enum.Enum): + """Damiao motor types used on OpenArm. 
Values match the reference library.""" + + DM3507 = "DM3507" + DM4310 = "DM4310" + DM4310_48V = "DM4310_48V" + DM4340 = "DM4340" + DM4340_48V = "DM4340_48V" + DM6006 = "DM6006" + DM8006 = "DM8006" + DM8009 = "DM8009" + DM10010L = "DM10010L" + DM10010 = "DM10010" + DMH3510 = "DMH3510" + DMH6215 = "DMH6215" + DMG6220 = "DMG6220" + + +# (p_max [rad], v_max [rad/s], t_max [Nm]) +_MOTOR_LIMITS: dict[MotorType, tuple[float, float, float]] = { + MotorType.DM3507: (12.5, 50.0, 5.0), + MotorType.DM4310: (12.5, 30.0, 10.0), + MotorType.DM4310_48V: (12.5, 50.0, 10.0), + MotorType.DM4340: (12.5, 8.0, 28.0), + MotorType.DM4340_48V: (12.5, 10.0, 28.0), + MotorType.DM6006: (12.5, 45.0, 20.0), + MotorType.DM8006: (12.5, 45.0, 40.0), + MotorType.DM8009: (12.5, 45.0, 54.0), + MotorType.DM10010L: (12.5, 25.0, 200.0), + MotorType.DM10010: (12.5, 20.0, 200.0), + MotorType.DMH3510: (12.5, 280.0, 1.0), + MotorType.DMH6215: (12.5, 45.0, 10.0), + MotorType.DMG6220: (12.5, 45.0, 10.0), +} + +# MIT gain ranges (protocol-fixed, same for every motor type) +KP_MIN, KP_MAX = 0.0, 500.0 +KD_MIN, KD_MAX = 0.0, 5.0 + +# Broadcast/control CAN IDs +_BROADCAST_ID = 0x7FF +_CMD_ENABLE = 0xFC +_CMD_DISABLE = 0xFD +_CMD_SET_ZERO = 0xFE +_RID_CTRL_MODE = 10 +CTRL_MODE_MIT = 1 + + +# --------------------------------------------------------------------------- +# Pack / unpack helpers (pure — safe to unit test in isolation) +# --------------------------------------------------------------------------- + + +def _clamp(x: float, lo: float, hi: float) -> float: + if x < lo: + return lo + if x > hi: + return hi + return x + + +def float_to_uint(x: float, lo: float, hi: float, bits: int) -> int: + """Quantize `x` in `[lo, hi]` to a `bits`-wide unsigned int. + + Matches ``CanPacketEncoder::double_to_uint`` from the reference library. 
+ """ + x = _clamp(x, lo, hi) + span = hi - lo + return int((x - lo) / span * ((1 << bits) - 1)) + + +def uint_to_float(u: int, lo: float, hi: float, bits: int) -> float: + """Inverse of :func:`float_to_uint`.""" + span = hi - lo + return u / ((1 << bits) - 1) * span + lo + + +def pack_mit_frame( + motor_type: MotorType, + q: float, + dq: float, + kp: float, + kd: float, + tau: float, +) -> bytes: + """Pack a Damiao MIT-mode control frame into 8 bytes.""" + p_max, v_max, t_max = _MOTOR_LIMITS[motor_type] + q_u = float_to_uint(q, -p_max, p_max, 16) + dq_u = float_to_uint(dq, -v_max, v_max, 12) + kp_u = float_to_uint(kp, KP_MIN, KP_MAX, 12) + kd_u = float_to_uint(kd, KD_MIN, KD_MAX, 12) + tau_u = float_to_uint(tau, -t_max, t_max, 12) + return bytes( + [ + (q_u >> 8) & 0xFF, + q_u & 0xFF, + (dq_u >> 4) & 0xFF, + ((dq_u & 0xF) << 4) | ((kp_u >> 8) & 0xF), + kp_u & 0xFF, + (kd_u >> 4) & 0xFF, + ((kd_u & 0xF) << 4) | ((tau_u >> 8) & 0xF), + tau_u & 0xFF, + ] + ) + + +@dataclass(frozen=True) +class MotorState: + """Decoded state from a Damiao reply frame.""" + + q: float # rad + dq: float # rad/s + tau: float # Nm + t_mos: int # °C + t_rotor: int # °C + timestamp: float # monotonic seconds when received + + +def parse_state_frame(motor_type: MotorType, data: bytes) -> MotorState | None: + """Decode an 8-byte Damiao state reply. 
Returns None if too short.""" + if len(data) < 8: + return None + p_max, v_max, t_max = _MOTOR_LIMITS[motor_type] + q_u = (data[1] << 8) | data[2] + dq_u = (data[3] << 4) | (data[4] >> 4) + tau_u = ((data[4] & 0x0F) << 8) | data[5] + return MotorState( + q=uint_to_float(q_u, -p_max, p_max, 16), + dq=uint_to_float(dq_u, -v_max, v_max, 12), + tau=uint_to_float(tau_u, -t_max, t_max, 12), + t_mos=int(data[6]), + t_rotor=int(data[7]), + timestamp=time.monotonic(), + ) + + +def _pack_control_command(cmd: int) -> bytes: + return bytes([0xFF] * 7 + [cmd & 0xFF]) + + +def pack_write_param_frame(send_id: int, rid: int, value_u32: int) -> bytes: + """Broadcast parameter-write frame sent to CAN id 0x7FF.""" + val = struct.pack("> 8) & 0xFF, + 0x55, + rid & 0xFF, + val[0], + val[1], + val[2], + val[3], + ] + ) + + +# --------------------------------------------------------------------------- +# Motor descriptor +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class DamiaoMotor: + """A single motor on a shared CAN bus. + + Attributes: + send_id: CAN arbitration ID we send commands to (e.g. 0x01..0x08). + motor_type: Determines q/dq/tau scaling. + recv_id: CAN ID the motor replies on. Defaults to ``send_id | 0x10``, + which is the OpenArm convention and matches the pattern used by + ``enactic/openarm_can``. 
+ """ + + send_id: int + motor_type: MotorType + recv_id: int | None = None + + @property + def effective_recv_id(self) -> int: + return self.recv_id if self.recv_id is not None else (self.send_id | 0x10) + + @property + def limits(self) -> tuple[float, float, float]: + """``(p_max, v_max, t_max)`` in rad / rad·s⁻¹ / Nm.""" + return _MOTOR_LIMITS[self.motor_type] + + +# --------------------------------------------------------------------------- +# Bus wrapper with background receive thread +# --------------------------------------------------------------------------- + + +class OpenArmBus: + """Manage one SocketCAN bus carrying a set of Damiao motors. + + Spawns a background thread that continuously drains the bus and caches + the latest :class:`MotorState` per motor. Callers fan out commands via + :meth:`send_mit_many` at their chosen control rate and pull fresh state + via :meth:`get_states` / :meth:`get_state`. + + Parameters + ---------- + channel: + SocketCAN interface name (e.g. ``"can0"``). + motors: + Motors present on this bus. + fd: + True to use CAN-FD (adapter must support it). OpenArm ships fine on + classical CAN @ 1 Mbit — keep this False unless you know you need FD. + interface: + python-can interface name. ``"socketcan"`` for real hardware, + ``"virtual"`` for unit tests. + """ + + def __init__( + self, + channel: str, + motors: list[DamiaoMotor], + *, + fd: bool = False, + interface: str = "socketcan", + ) -> None: + if not motors: + raise ValueError("OpenArmBus needs at least one motor") + # Enforce unique IDs — silent overlap would make state routing ambiguous. 
+ send_ids = [m.send_id for m in motors] + if len(set(send_ids)) != len(send_ids): + raise ValueError(f"duplicate send_id in {send_ids}") + recv_ids = [m.effective_recv_id for m in motors] + if len(set(recv_ids)) != len(recv_ids): + raise ValueError(f"duplicate recv_id in {recv_ids}") + + self._channel = channel + self._motors = list(motors) + self._fd = fd + self._interface = interface + self._by_recv: dict[int, DamiaoMotor] = {m.effective_recv_id: m for m in motors} + + self._bus: "can.BusABC | None" = None + self._rx_thread: threading.Thread | None = None + self._rx_stop = threading.Event() + self._state_lock = threading.Lock() + self._states: dict[int, MotorState] = {} + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def open(self) -> None: + """Open the CAN bus and start the background RX thread.""" + if self._bus is not None: + return + import can # local import — python-can is optional + + self._bus = can.Bus(interface=self._interface, channel=self._channel, fd=self._fd) + self._rx_stop.clear() + self._rx_thread = threading.Thread( + target=self._rx_loop, name=f"openarm-rx-{self._channel}", daemon=True + ) + self._rx_thread.start() + + def close(self) -> None: + """Stop the RX thread and close the CAN bus.""" + self._rx_stop.set() + if self._rx_thread is not None: + self._rx_thread.join(timeout=1.0) + self._rx_thread = None + if self._bus is not None: + try: + self._bus.shutdown() + finally: + self._bus = None + + def __enter__(self) -> "OpenArmBus": + self.open() + return self + + def __exit__(self, *_exc: object) -> None: + self.close() + + # ------------------------------------------------------------------ + # Control commands + # ------------------------------------------------------------------ + + def enable_all(self) -> None: + for m in self._motors: + self._send_raw(m.send_id, _pack_control_command(_CMD_ENABLE)) + time.sleep(0.002) # 2ms 
inter-frame gap — gs_usb TX buffer is tiny + + def disable_all(self) -> None: + for m in self._motors: + self._send_raw(m.send_id, _pack_control_command(_CMD_DISABLE)) + time.sleep(0.002) + + def set_zero(self, send_id: int) -> None: + """Set current physical position as the motor's zero. + + Destructive — don't call this unless you know what you're doing. + """ + self._send_raw(send_id, _pack_control_command(_CMD_SET_ZERO)) + + def write_ctrl_mode(self, send_id: int, mode: int = CTRL_MODE_MIT) -> None: + """Write the persistent CTRL_MODE register (e.g. to switch to MIT).""" + self._send_raw( + _BROADCAST_ID, + pack_write_param_frame(send_id, _RID_CTRL_MODE, mode), + ) + + def send_mit( + self, + send_id: int, + q: float, + dq: float, + kp: float, + kd: float, + tau: float, + ) -> None: + """Send one MIT command frame to `send_id`.""" + motor = next((m for m in self._motors if m.send_id == send_id), None) + if motor is None: + raise KeyError(f"motor 0x{send_id:02X} not on bus {self._channel}") + data = pack_mit_frame(motor.motor_type, q, dq, kp, kd, tau) + self._send_raw(send_id, data) + + def send_mit_many( + self, + commands: list[tuple[float, float, float, float, float]], + ) -> None: + """Fan out one MIT frame per motor, in the order motors were registered. + + ``commands[i] = (q, dq, kp, kd, tau)`` is sent to ``self.motors[i]``. + """ + if len(commands) != len(self._motors): + raise ValueError( + f"expected {len(self._motors)} commands, got {len(commands)}" + ) + for i, (motor, cmd) in enumerate(zip(self._motors, commands)): + q, dq, kp, kd, tau = cmd + data = pack_mit_frame(motor.motor_type, q, dq, kp, kd, tau) + self._send_raw(motor.send_id, data) + # Tiny inter-frame gap to avoid TX buffer overflow on gs_usb. + # 7 frames × 0.5ms = 3.5ms total, well within a 10ms tick. 
+ if i < len(self._motors) - 1: + time.sleep(0.0005) + + # ------------------------------------------------------------------ + # State access + # ------------------------------------------------------------------ + + @property + def motors(self) -> tuple[DamiaoMotor, ...]: + return tuple(self._motors) + + def get_state(self, send_id: int) -> MotorState | None: + motor = next((m for m in self._motors if m.send_id == send_id), None) + if motor is None: + return None + with self._state_lock: + return self._states.get(motor.effective_recv_id) + + def get_states(self) -> list[MotorState | None]: + """Latest cached state in motor-registration order.""" + with self._state_lock: + return [self._states.get(m.effective_recv_id) for m in self._motors] + + def wait_all_states(self, timeout: float = 0.5) -> bool: + """Block until every motor has reported at least once (or timeout).""" + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + with self._state_lock: + if all( + m.effective_recv_id in self._states for m in self._motors + ): + return True + time.sleep(0.005) + return False + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + + def _send_raw(self, arbitration_id: int, data: bytes) -> None: + if self._bus is None: + raise RuntimeError("bus not open — call .open() first") + import can + + msg = can.Message( + arbitration_id=arbitration_id, + data=data, + is_extended_id=False, + is_fd=self._fd, + bitrate_switch=self._fd, + ) + # Retry on TX buffer full (ENOBUFS / errno 105) — the gs_usb + # adapter has a tiny kernel-side TX queue. A short backoff lets + # the kernel drain one frame before we try again. 
+ for attempt in range(4): + try: + self._bus.send(msg) + return + except can.CanOperationError as e: + if "105" in str(e) and attempt < 3: + time.sleep(0.001 * (attempt + 1)) + else: + raise + + def _rx_loop(self) -> None: + assert self._bus is not None + while not self._rx_stop.is_set(): + msg = self._bus.recv(timeout=0.05) + if msg is None: + continue + motor = self._by_recv.get(int(msg.arbitration_id)) + if motor is None: + continue + state = parse_state_frame(motor.motor_type, bytes(msg.data)) + if state is None: + continue + with self._state_lock: + self._states[motor.effective_recv_id] = state + + +__all__ = [ + "CTRL_MODE_MIT", + "DamiaoMotor", + "KD_MAX", + "KD_MIN", + "KP_MAX", + "KP_MIN", + "MotorState", + "MotorType", + "OpenArmBus", + "float_to_uint", + "pack_mit_frame", + "pack_write_param_frame", + "parse_state_frame", + "uint_to_float", +] From f1edf1b5022cd0e49c5edc23ab07f6bdcc79928d Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 21 Apr 2026 15:18:25 -0700 Subject: [PATCH 03/30] driver testing script --- .../manipulators/openarm/test_driver.py | 275 ++++++++++++++++++ 1 file changed, 275 insertions(+) create mode 100644 dimos/hardware/manipulators/openarm/test_driver.py diff --git a/dimos/hardware/manipulators/openarm/test_driver.py b/dimos/hardware/manipulators/openarm/test_driver.py new file mode 100644 index 0000000000..829f7970ef --- /dev/null +++ b/dimos/hardware/manipulators/openarm/test_driver.py @@ -0,0 +1,275 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +"""Unit tests for the Damiao MIT-mode driver — no hardware required. + +Uses ``can.Bus(interface="virtual")`` for loopback. 
+""" + +from __future__ import annotations + +import struct +import time + +import pytest + +can = pytest.importorskip("can") + +from dimos.hardware.manipulators.openarm.driver import ( + CTRL_MODE_MIT, + DamiaoMotor, + KD_MAX, + KP_MAX, + MotorType, + OpenArmBus, + float_to_uint, + pack_mit_frame, + pack_write_param_frame, + parse_state_frame, + uint_to_float, +) + + +# --------------------------------------------------------------------------- +# Pack / unpack primitives +# --------------------------------------------------------------------------- + + +def test_float_to_uint_endpoints_and_roundtrip() -> None: + # Endpoints + assert float_to_uint(-12.5, -12.5, 12.5, 16) == 0 + assert float_to_uint(12.5, -12.5, 12.5, 16) == (1 << 16) - 1 + # Midpoint is half the full range (rounded down) + mid = float_to_uint(0.0, -12.5, 12.5, 16) + assert mid in ((1 << 16) // 2 - 1, (1 << 16) // 2) + # Out-of-range clamps + assert float_to_uint(-100.0, -12.5, 12.5, 16) == 0 + assert float_to_uint(100.0, -12.5, 12.5, 16) == (1 << 16) - 1 + + +def test_roundtrip_all_gain_ranges() -> None: + # Quantization error should be tiny + for bits, lo, hi in [(16, -12.5, 12.5), (12, 0.0, KP_MAX), (12, 0.0, KD_MAX)]: + step = (hi - lo) / ((1 << bits) - 1) + for k in range(0, 1 << bits, max(1, (1 << bits) // 50)): + x = lo + k * step + u = float_to_uint(x, lo, hi, bits) + x2 = uint_to_float(u, lo, hi, bits) + assert abs(x - x2) <= step + + +def test_mit_frame_kp_kd_zero_and_pos_zero() -> None: + # q=dq=kp=kd=tau=0 → q_u = 32767 (16-bit midpoint), dq_u = 2047 (12-bit), + # tau_u = 2047. kp_u = kd_u = 0 (min of their 0-positive range). 
+ data = pack_mit_frame(MotorType.DM4310, 0.0, 0.0, 0.0, 0.0, 0.0) + assert len(data) == 8 + # Reconstruct fields from bytes + q_u = (data[0] << 8) | data[1] + dq_u = (data[2] << 4) | (data[3] >> 4) + kp_u = ((data[3] & 0xF) << 8) | data[4] + kd_u = (data[5] << 4) | (data[6] >> 4) + tau_u = ((data[6] & 0xF) << 8) | data[7] + assert kp_u == 0 + assert kd_u == 0 + # 16-bit midpoint of symmetric range + assert q_u in (32767, 32768) + assert dq_u in (2047, 2048) + assert tau_u in (2047, 2048) + + +def test_mit_frame_full_positive() -> None: + # Command at every max → every _u field saturates. + data = pack_mit_frame(MotorType.DM4310, 12.5, 30.0, 500.0, 5.0, 10.0) + q_u = (data[0] << 8) | data[1] + dq_u = (data[2] << 4) | (data[3] >> 4) + kp_u = ((data[3] & 0xF) << 8) | data[4] + kd_u = (data[5] << 4) | (data[6] >> 4) + tau_u = ((data[6] & 0xF) << 8) | data[7] + assert q_u == 0xFFFF + assert dq_u == 0xFFF + assert kp_u == 0xFFF + assert kd_u == 0xFFF + assert tau_u == 0xFFF + + +def test_parse_state_roundtrip() -> None: + # Build a synthetic reply frame with known values and verify decode. 
+ # Byte layout for state: [echo, q_hi, q_lo, dq_hi, dq_lo|tau_hi, tau_lo, t_mos, t_rotor] + motor = MotorType.DM4340 + p_max, v_max, t_max = 12.5, 8.0, 28.0 + q_u = float_to_uint(0.3, -p_max, p_max, 16) + dq_u = float_to_uint(-1.0, -v_max, v_max, 12) + tau_u = float_to_uint(2.0, -t_max, t_max, 12) + data = bytes( + [ + 0x03, + (q_u >> 8) & 0xFF, + q_u & 0xFF, + (dq_u >> 4) & 0xFF, + ((dq_u & 0xF) << 4) | ((tau_u >> 8) & 0xF), + tau_u & 0xFF, + 33, + 28, + ] + ) + state = parse_state_frame(motor, data) + assert state is not None + assert abs(state.q - 0.3) < 0.001 + assert abs(state.dq - (-1.0)) < 0.01 + assert abs(state.tau - 2.0) < 0.02 + assert state.t_mos == 33 + assert state.t_rotor == 28 + + +def test_parse_state_rejects_short_frames() -> None: + assert parse_state_frame(MotorType.DM4310, b"\x00" * 4) is None + + +def test_pack_write_param_ctrl_mode_mit() -> None: + data = pack_write_param_frame(0x05, 10, CTRL_MODE_MIT) + assert data[0] == 0x05 + assert data[1] == 0x00 + assert data[2] == 0x55 + assert data[3] == 10 + assert struct.unpack(" OpenArmBus: + return OpenArmBus(channel=channel, motors=motors, fd=False, interface="virtual") + + +def test_bus_validates_unique_ids() -> None: + with pytest.raises(ValueError, match="duplicate send_id"): + OpenArmBus( + channel="v0", + motors=[ + DamiaoMotor(0x01, MotorType.DM4310), + DamiaoMotor(0x01, MotorType.DM4310), + ], + fd=False, + interface="virtual", + ) + + +def test_bus_empty_motor_list_rejected() -> None: + with pytest.raises(ValueError): + OpenArmBus(channel="v0", motors=[], fd=False, interface="virtual") + + +def test_rx_thread_populates_state_cache() -> None: + # Two peers on the same virtual channel loop back to each other. + motors = [ + DamiaoMotor(0x01, MotorType.DM8006), + DamiaoMotor(0x05, MotorType.DM4310), + ] + bus = _make_bus("openarm-test-rx", motors) + # A raw sender on the same virtual channel injects state replies. 
+ sender = can.Bus(interface="virtual", channel="openarm-test-rx") + try: + bus.open() + # Forge a reply for motor 0x01 (recv 0x11) at q = 0.25 rad + q_u = float_to_uint(0.25, -12.5, 12.5, 16) + dq_u = float_to_uint(0.0, -45.0, 45.0, 12) + tau_u = float_to_uint(0.0, -40.0, 40.0, 12) + payload = bytes( + [ + 0x01, + (q_u >> 8) & 0xFF, + q_u & 0xFF, + (dq_u >> 4) & 0xFF, + ((dq_u & 0xF) << 4) | ((tau_u >> 8) & 0xF), + tau_u & 0xFF, + 30, + 28, + ] + ) + sender.send( + can.Message(arbitration_id=0x11, data=payload, is_extended_id=False) + ) + # Poll briefly for the RX thread to consume it + deadline = time.monotonic() + 0.5 + s = None + while s is None and time.monotonic() < deadline: + s = bus.get_state(0x01) + time.sleep(0.01) + assert s is not None, "RX thread did not pick up synthetic state reply" + assert abs(s.q - 0.25) < 0.001 + # Motor 0x05 never got a reply → state should still be None + assert bus.get_state(0x05) is None + finally: + bus.close() + sender.shutdown() + + +def test_send_mit_many_fans_out_one_per_motor() -> None: + motors = [ + DamiaoMotor(0x01, MotorType.DM8006), + DamiaoMotor(0x02, MotorType.DM8006), + DamiaoMotor(0x05, MotorType.DM4310), + ] + bus = _make_bus("openarm-test-send", motors) + listener = can.Bus(interface="virtual", channel="openarm-test-send") + try: + bus.open() + bus.send_mit_many( + [ + (0.1, 0.0, 10.0, 0.5, 0.0), + (0.2, 0.0, 10.0, 0.5, 0.0), + (0.3, 0.0, 10.0, 0.5, 0.0), + ] + ) + seen_ids: set[int] = set() + deadline = time.monotonic() + 0.5 + while len(seen_ids) < 3 and time.monotonic() < deadline: + msg = listener.recv(timeout=0.1) + if msg is not None: + seen_ids.add(int(msg.arbitration_id)) + assert seen_ids == {0x01, 0x02, 0x05} + finally: + bus.close() + listener.shutdown() + + +def test_send_mit_many_size_mismatch() -> None: + bus = _make_bus( + "openarm-test-mismatch", + [DamiaoMotor(0x01, MotorType.DM4310), DamiaoMotor(0x02, MotorType.DM4310)], + ) + try: + bus.open() + with pytest.raises(ValueError): + 
bus.send_mit_many([(0.0, 0.0, 0.0, 0.0, 0.0)]) + finally: + bus.close() + + +def test_enable_disable_frames_sent() -> None: + bus = _make_bus( + "openarm-test-enable", + [DamiaoMotor(0x01, MotorType.DM4310), DamiaoMotor(0x05, MotorType.DM4310)], + ) + listener = can.Bus(interface="virtual", channel="openarm-test-enable") + try: + bus.open() + bus.enable_all() + seen = {} + deadline = time.monotonic() + 0.3 + while len(seen) < 2 and time.monotonic() < deadline: + msg = listener.recv(timeout=0.1) + if msg is not None: + seen[int(msg.arbitration_id)] = bytes(msg.data) + assert set(seen) == {0x01, 0x05} + for data in seen.values(): + assert data == bytes([0xFF] * 7 + [0xFC]) + finally: + bus.close() + listener.shutdown() From f20d29c288a1b99832936b7639087e58877806df Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 21 Apr 2026 15:19:11 -0700 Subject: [PATCH 04/30] added openarm blueprint and catalog file --- dimos/robot/all_blueprints.py | 8 + dimos/robot/catalog/openarm.py | 147 +++++++++++ .../robot/manipulators/openarm/blueprints.py | 235 ++++++++++++++++++ 3 files changed, 390 insertions(+) create mode 100644 dimos/robot/catalog/openarm.py create mode 100644 dimos/robot/manipulators/openarm/blueprints.py diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 5289ec74de..74a45f5e8f 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -25,6 +25,10 @@ "coordinator-mobile-manip-mock": "dimos.control.blueprints.mobile:coordinator_mobile_manip_mock", "coordinator-mock": "dimos.control.blueprints.basic:coordinator_mock", "coordinator-mock-twist-base": "dimos.control.blueprints.mobile:coordinator_mock_twist_base", + "coordinator-openarm-bimanual": "dimos.robot.manipulators.openarm.blueprints:coordinator_openarm_bimanual", + "coordinator-openarm-left": "dimos.robot.manipulators.openarm.blueprints:coordinator_openarm_left", + "coordinator-openarm-mock": 
"dimos.robot.manipulators.openarm.blueprints:coordinator_openarm_mock", + "coordinator-openarm-right": "dimos.robot.manipulators.openarm.blueprints:coordinator_openarm_right", "coordinator-piper": "dimos.control.blueprints.basic:coordinator_piper", "coordinator-piper-xarm": "dimos.control.blueprints.dual:coordinator_piper_xarm", "coordinator-servo-xarm6": "dimos.control.blueprints.teleop:coordinator_servo_xarm6", @@ -49,6 +53,8 @@ "drone-agentic": "dimos.robot.drone.blueprints.agentic.drone_agentic:drone_agentic", "drone-basic": "dimos.robot.drone.blueprints.basic.drone_basic:drone_basic", "dual-xarm6-planner": "dimos.manipulation.blueprints:dual_xarm6_planner", + "keyboard-teleop-openarm": "dimos.robot.manipulators.openarm.blueprints:keyboard_teleop_openarm", + "keyboard-teleop-openarm-mock": "dimos.robot.manipulators.openarm.blueprints:keyboard_teleop_openarm_mock", "keyboard-teleop-piper": "dimos.robot.manipulators.piper.blueprints:keyboard_teleop_piper", "keyboard-teleop-xarm6": "dimos.robot.manipulators.xarm.blueprints:keyboard_teleop_xarm6", "keyboard-teleop-xarm7": "dimos.robot.manipulators.xarm.blueprints:keyboard_teleop_xarm7", @@ -56,6 +62,8 @@ "mid360-fastlio": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio", "mid360-fastlio-voxels": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_voxels", "mid360-fastlio-voxels-native": "dimos.hardware.sensors.lidar.fastlio2.fastlio_blueprints:mid360_fastlio_voxels_native", + "openarm-mock-planner-coordinator": "dimos.robot.manipulators.openarm.blueprints:openarm_mock_planner_coordinator", + "openarm-planner-coordinator": "dimos.robot.manipulators.openarm.blueprints:openarm_planner_coordinator", "teleop-phone": "dimos.teleop.phone.blueprints:teleop_phone", "teleop-phone-go2": "dimos.teleop.phone.blueprints:teleop_phone_go2", "teleop-phone-go2-fleet": "dimos.teleop.phone.blueprints:teleop_phone_go2_fleet", diff --git a/dimos/robot/catalog/openarm.py 
b/dimos/robot/catalog/openarm.py new file mode 100644 index 0000000000..91c49e123b --- /dev/null +++ b/dimos/robot/catalog/openarm.py @@ -0,0 +1,147 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenArm v10 robot configurations. + +Ships in bimanual form — two 7-DOF arms on a shared torso. We load the +full bimanual URDF and select each arm's 7 joints via ``joint_names``. +Drake handles the rest. + +The URDF is referenced from the in-tree ``data/openarm_description/`` +folder during bring-up; the plan is to migrate it to LFS at the end of +the integration (only the ``_OPENARM_MODEL_PATH`` constant below changes). +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +from dimos.robot.config import RobotConfig + +# Collision exclusion pairs — structural mesh overlaps in the OpenArm URDF. +# link5 and link7 collision meshes overlap by ~3mm at zero pose (and every +# other pose) — same pattern as R1 Pro's non-adjacent link overlap. +OPENARM_COLLISION_EXCLUSIONS: list[tuple[str, str]] = [ + ("openarm_left_link5", "openarm_left_link7"), + ("openarm_right_link5", "openarm_right_link7"), +] + +# Local path during bring-up. Swap to ``LfsPath("openarm_description/...")`` +# once the URDF is migrated to LFS. 
+_REPO_ROOT = Path(__file__).resolve().parents[3] +_OPENARM_PKG = _REPO_ROOT / "data" / "openarm_description" +_OPENARM_MODEL_PATH = _OPENARM_PKG / "urdf" / "robot" / "openarm_v10_bimanual.urdf" +# Per-side URDFs: extracted from bimanual expansion, only one arm + torso each. +# Avoids phantom-arm collisions when Drake loads both sides into one world. +_OPENARM_LEFT_MODEL = _OPENARM_PKG / "urdf" / "robot" / "openarm_v10_left.urdf" +_OPENARM_RIGHT_MODEL = _OPENARM_PKG / "urdf" / "robot" / "openarm_v10_right.urdf" + +# Pre-expanded single-arm URDF for Pinocchio FK (keyboard teleop, IK, etc.) +# Pinocchio doesn't handle xacro — this file is the expansion of the bimanual +# xacro with only one side's links kept. +OPENARM_V10_FK_MODEL = _OPENARM_PKG / "urdf" / "robot" / "openarm_v10_single.urdf" + + +def openarm_arm( + side: str = "left", + name: str | None = None, + *, + adapter_type: str = "mock", + address: str | None = None, + **overrides: Any, +) -> RobotConfig: + """Create an OpenArm v10 config for one side of the bimanual rig. + + Both sides share the bimanual URDF (built with ``bimanual:=true``) and + pick out their 7 arm joints via ``joint_names``. Drake's SetAutoRenaming + handles duplicate model names when both configs go into one planning + world. + + Args: + side: ``"left"`` or ``"right"``. + name: Robot name (defaults to ``f"{side}_arm"``). + adapter_type: ``"mock"`` for dry runs, ``"openarm"`` for real CAN. + address: SocketCAN channel (e.g. ``"can0"``) — required for real hw. + **overrides: Override any ``RobotConfig`` field. + """ + if side not in ("left", "right"): + raise ValueError(f"side must be 'left' or 'right', got {side!r}") + + resolved_name = name or f"{side}_arm" + # Pre-expanded bimanual URDF uses openarm_{side}_* naming. 
+ joint_names = [f"openarm_{side}_joint{i}" for i in range(1, 8)] + ee_link = f"openarm_{side}_link7" + + defaults: dict[str, Any] = { + "name": resolved_name, + "model_path": _OPENARM_LEFT_MODEL if side == "left" else _OPENARM_RIGHT_MODEL, + "end_effector_link": ee_link, + "adapter_type": adapter_type, + "address": address, + "joint_names": joint_names, + # URDF already prefixes joints with "left_"/"right_" in bimanual mode, + # so suppress RobotConfig's automatic "{name}_" prefix. + "joint_prefix": "", + "base_link": "openarm_body_link0", + "home_joints": [0.0] * 7, + "base_pose": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + "package_paths": {"openarm_description": _OPENARM_PKG}, + "collision_exclusion_pairs": OPENARM_COLLISION_EXCLUSIONS, + "auto_convert_meshes": True, + "max_velocity": 0.5, + "max_acceleration": 1.0, + "adapter_kwargs": {"side": side}, + } + # Merge adapter_kwargs rather than replace, so callers can add keys + # (e.g. auto_set_mit_mode) without clobbering the catalog's "side". + if "adapter_kwargs" in overrides: + defaults["adapter_kwargs"] = {**defaults["adapter_kwargs"], **overrides.pop("adapter_kwargs")} + defaults.update(overrides) + return RobotConfig(**defaults) + + +def openarm_single( + name: str = "arm", + *, + adapter_type: str = "mock", + address: str | None = None, + **overrides: Any, +) -> RobotConfig: + """Single-arm (non-bimanual) config for keyboard teleop / cartesian IK. + + Uses the pre-expanded single-arm URDF (bimanual:=false). Joint names are + ``openarm_joint1..7``. For bimanual planning use :func:`openarm_arm`. 
+ """ + defaults: dict[str, Any] = { + "name": name, + "model_path": OPENARM_V10_FK_MODEL, + "end_effector_link": "openarm_left_link7", + "adapter_type": adapter_type, + "address": address, + "joint_names": [f"openarm_left_joint{i}" for i in range(1, 8)], + "joint_prefix": "", + "base_link": "openarm_body_link0", + "home_joints": [0.0] * 7, + "package_paths": {"openarm_description": _OPENARM_PKG}, + "auto_convert_meshes": True, + "max_velocity": 0.5, + "max_acceleration": 1.0, + "adapter_kwargs": {"side": "left"}, + } + defaults.update(overrides) + return RobotConfig(**defaults) + + +__all__ = ["OPENARM_V10_FK_MODEL", "openarm_arm", "openarm_single"] diff --git a/dimos/robot/manipulators/openarm/blueprints.py b/dimos/robot/manipulators/openarm/blueprints.py new file mode 100644 index 0000000000..05b29f9e87 --- /dev/null +++ b/dimos/robot/manipulators/openarm/blueprints.py @@ -0,0 +1,235 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenArm coordinator blueprints (single-arm + bimanual). + +Usage: + dimos run coordinator-openarm-mock # Bimanual, mock adapters + dimos run coordinator-openarm-left # Single arm, real CAN + dimos run coordinator-openarm-right # Single arm, real CAN + dimos run coordinator-openarm-bimanual # Both arms, real CAN + +The CAN interface each arm uses is set by the ``LEFT_CAN``/``RIGHT_CAN`` +constants below. 
Linux enumerates gs_usb adapters in USB-discovery order +which isn't guaranteed stable, so if your arms come up "swapped" just flip +these two values and rerun — no other code changes needed. +""" + +from __future__ import annotations + +from dimos.control.coordinator import ControlCoordinator +from dimos.core.coordination.blueprints import autoconnect +from dimos.core.transport import LCMTransport +from dimos.manipulation.manipulation_module import ManipulationModule +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.robot.catalog.openarm import ( + OPENARM_V10_FK_MODEL, + openarm_arm as _openarm, + openarm_single as _openarm_single, +) +from dimos.teleop.keyboard.keyboard_teleop_module import KeyboardTeleopModule + +# ── Mock bimanual: no hardware, great for verifying wiring ───────────── +_mock_left = _openarm(side="left", name="left_arm") +_mock_right = _openarm(side="right", name="right_arm") + +coordinator_openarm_mock = ControlCoordinator.blueprint( + hardware=[_mock_left.to_hardware_component(), _mock_right.to_hardware_component()], + tasks=[ + _mock_left.to_task_config(), + _mock_right.to_task_config(), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + +# ── Single-arm hardware blueprints (first real bring-up targets) ─────── +# CAN interface each physical arm is on. Linux assigns can0/can1 in USB +# enumeration order which isn't guaranteed stable — if the arms come up +# swapped, flip these two values. +LEFT_CAN = "can1" +RIGHT_CAN = "can0" + +# Flip to False to skip the CTRL_MODE=MIT write at connect-time — useful for +# verifying the setting persists across power cycles. Leave True for normal +# operation (idempotent; ensures motors work even if they were reflashed / +# replaced / factory-reset). 
+AUTO_SET_MIT_MODE = True + +_HW_KW = dict(adapter_type="openarm", + adapter_kwargs={"auto_set_mit_mode": AUTO_SET_MIT_MODE}) +_left_hw = _openarm(side="left", name="left_arm", address=LEFT_CAN, **_HW_KW) +_right_hw = _openarm(side="right", name="right_arm", address=RIGHT_CAN, **_HW_KW) + +coordinator_openarm_left = ControlCoordinator.blueprint( + hardware=[_left_hw.to_hardware_component()], + tasks=[_left_hw.to_task_config()], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + +coordinator_openarm_right = ControlCoordinator.blueprint( + hardware=[_right_hw.to_hardware_component()], + tasks=[_right_hw.to_task_config()], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + +# ── Bimanual hardware blueprint ──────────────────────────────────────── +coordinator_openarm_bimanual = ControlCoordinator.blueprint( + hardware=[_left_hw.to_hardware_component(), _right_hw.to_hardware_component()], + tasks=[ + _left_hw.to_task_config(), + _right_hw.to_task_config(), + ], +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + + +# ── Planner + coordinator (mock): Drake plans, mock adapters execute ──── +# Great for visualizing motions in Meshcat with no hardware. 
+openarm_mock_planner_coordinator = autoconnect( + ManipulationModule.blueprint( + robots=[_mock_left.to_robot_model_config(), _mock_right.to_robot_model_config()], + planning_timeout=10.0, + enable_viz=True, + ), + ControlCoordinator.blueprint( + tick_rate=100.0, + publish_joint_state=True, + joint_state_frame_id="coordinator", + hardware=[_mock_left.to_hardware_component(), _mock_right.to_hardware_component()], + tasks=[ + _mock_left.to_task_config(), + _mock_right.to_task_config(), + ], + ), +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + +# ── Planner + coordinator (real hw): plan & execute on both arms ──────── +openarm_planner_coordinator = autoconnect( + ManipulationModule.blueprint( + robots=[_left_hw.to_robot_model_config(), _right_hw.to_robot_model_config()], + planning_timeout=10.0, + enable_viz=True, + ), + ControlCoordinator.blueprint( + tick_rate=100.0, + publish_joint_state=True, + joint_state_frame_id="coordinator", + hardware=[_left_hw.to_hardware_component(), _right_hw.to_hardware_component()], + tasks=[ + _left_hw.to_task_config(), + _right_hw.to_task_config(), + ], + ), +).transports( + { + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + + +# ── Keyboard teleop (single arm, mock) ────────────────────────────────── +# pygame keyboard UI → Cartesian IK (Drake) → mock coordinator execution, +# with Drake/Meshcat visualization. Good for testing the single-arm URDF +# and IK without touching hardware. 
+_teleop_cfg = _openarm_single(name="arm") + +keyboard_teleop_openarm_mock = autoconnect( + KeyboardTeleopModule.blueprint(model_path=str(OPENARM_V10_FK_MODEL), ee_joint_id=_teleop_cfg.dof), + ControlCoordinator.blueprint( + tick_rate=100.0, + publish_joint_state=True, + joint_state_frame_id="coordinator", + hardware=[_teleop_cfg.to_hardware_component()], + tasks=[ + _teleop_cfg.to_task_config( + task_type="cartesian_ik", + task_name="cartesian_ik_arm", + model_path=OPENARM_V10_FK_MODEL, + ee_joint_id=_teleop_cfg.dof, + ), + ], + ), + ManipulationModule.blueprint( + robots=[_teleop_cfg.to_robot_model_config()], + enable_viz=True, + ), +).transports( + { + ("cartesian_command", PoseStamped): LCMTransport( + "/coordinator/cartesian_command", PoseStamped + ), + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + +# ── Keyboard teleop (single arm, real hw on can0) ─────────────────────── +_teleop_hw_cfg = _openarm_single(name="arm", adapter_type="openarm", address=LEFT_CAN) + +keyboard_teleop_openarm = autoconnect( + KeyboardTeleopModule.blueprint(model_path=str(OPENARM_V10_FK_MODEL), ee_joint_id=_teleop_hw_cfg.dof), + ControlCoordinator.blueprint( + tick_rate=100.0, + publish_joint_state=True, + joint_state_frame_id="coordinator", + hardware=[_teleop_hw_cfg.to_hardware_component()], + tasks=[ + _teleop_hw_cfg.to_task_config( + task_type="cartesian_ik", + task_name="cartesian_ik_arm", + model_path=OPENARM_V10_FK_MODEL, + ee_joint_id=_teleop_hw_cfg.dof, + ), + ], + ), + ManipulationModule.blueprint( + robots=[_teleop_hw_cfg.to_robot_model_config()], + enable_viz=True, + ), +).transports( + { + ("cartesian_command", PoseStamped): LCMTransport( + "/coordinator/cartesian_command", PoseStamped + ), + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } +) + + +__all__ = [ + "coordinator_openarm_bimanual", + "coordinator_openarm_left", + "coordinator_openarm_mock", + "coordinator_openarm_right", + 
"keyboard_teleop_openarm", + "keyboard_teleop_openarm_mock", + "openarm_mock_planner_coordinator", + "openarm_planner_coordinator", +] From d661fd5db2473df1add808a8ef6beb4d4cc1b161 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 21 Apr 2026 15:19:24 -0700 Subject: [PATCH 05/30] debug scripts for openarm --- .../openarm/scripts/openarm_can_probe.py | 156 ++++++++ dimos/utils/workspace.py | 338 ++++++++++++++++++ 2 files changed, 494 insertions(+) create mode 100755 dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py create mode 100644 dimos/utils/workspace.py diff --git a/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py b/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py new file mode 100755 index 0000000000..036b0a587f --- /dev/null +++ b/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +"""Probe an OpenArm on a SocketCAN interface. + +Enumerates all 8 expected Damiao motors (7 arm joints + gripper) on one CAN-FD +bus, enables each, reads back one state frame, then disables. This is the +Phase-0 hardware-verification script — if this does not work, nothing +downstream will. + +Run AFTER bringing the bus up with dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh. + +Usage: + python dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py --channel can0 + python dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py --channel can1 --ids 1,2,3,4,5,6,7 +""" +from __future__ import annotations + +import argparse +import sys +import time + +try: + import can +except ImportError: + sys.exit("python-can not installed. 
Run: pip install 'python-can>=4.3'") + +# ---- Damiao motor limit tables (from enactic/openarm_can dm_motor_constants.hpp) +# [p_max rad, v_max rad/s, t_max Nm] +LIMITS: dict[str, tuple[float, float, float]] = { + "DM4310": (12.5, 30.0, 10.0), + "DM4340": (12.5, 8.0, 28.0), + "DM8006": (12.5, 45.0, 40.0), +} + +# OpenArm v10 per-joint motor assignment (derived from joint_limits.yaml effort column) +DEFAULT_MOTORS: list[tuple[int, str]] = [ + (0x01, "DM8006"), # joint1 + (0x02, "DM8006"), # joint2 + (0x03, "DM4340"), # joint3 + (0x04, "DM4340"), # joint4 + (0x05, "DM4310"), # joint5 + (0x06, "DM4310"), # joint6 + (0x07, "DM4310"), # joint7 + (0x08, "DM4310"), # gripper +] + +ENABLE = bytes([0xFF] * 7 + [0xFC]) +DISABLE = bytes([0xFF] * 7 + [0xFD]) + +FD = True # set by --classical at runtime + + +def uint_to_float(x: int, lo: float, hi: float, bits: int) -> float: + return x / ((1 << bits) - 1) * (hi - lo) + lo + + +def parse_state(motor_type: str, data: bytes) -> tuple[float, float, float, int, int] | None: + """Decode an 8-byte DM motor state reply. 
Returns (q, dq, tau, t_mos, t_rotor).""" + if len(data) < 8: + return None + p_max, v_max, t_max = LIMITS[motor_type] + q_u = (data[1] << 8) | data[2] + dq_u = (data[3] << 4) | (data[4] >> 4) + tau_u = ((data[4] & 0x0F) << 8) | data[5] + q = uint_to_float(q_u, -p_max, p_max, 16) + dq = uint_to_float(dq_u, -v_max, v_max, 12) + tau = uint_to_float(tau_u, -t_max, t_max, 12) + return q, dq, tau, data[6], data[7] + + +def probe_motor(bus: can.BusABC, send_id: int, recv_id: int, + motor_type: str, timeout: float = 0.2) -> bool: + """Enable motor, wait for state reply on recv_id, print result, disable.""" + # Flush any stale frames + while bus.recv(0.0) is not None: + pass + + bus.send(can.Message(arbitration_id=send_id, data=ENABLE, + is_extended_id=False, is_fd=FD, bitrate_switch=FD)) + t0 = time.monotonic() + while time.monotonic() - t0 < timeout: + msg = bus.recv(timeout - (time.monotonic() - t0)) + if msg is None: + break + if msg.arbitration_id != recv_id: + continue + parsed = parse_state(motor_type, bytes(msg.data)) + if parsed is None: + print(f" 0x{send_id:02X} ({motor_type}): short reply {list(msg.data)}") + bus.send(can.Message(arbitration_id=send_id, data=DISABLE, + is_extended_id=False, is_fd=FD, bitrate_switch=FD)) + return False + q, dq, tau, t_mos, t_rot = parsed + print(f" 0x{send_id:02X} ({motor_type:>6}): " + f"q={q:+.3f} rad dq={dq:+.3f} rad/s tau={tau:+.3f} Nm " + f"T_mos={t_mos}C T_rotor={t_rot}C") + bus.send(can.Message(arbitration_id=send_id, data=DISABLE, + is_extended_id=False, is_fd=FD, bitrate_switch=FD)) + return True + + print(f" 0x{send_id:02X} ({motor_type:>6}): NO REPLY on 0x{recv_id:02X} within {timeout*1e3:.0f}ms") + bus.send(can.Message(arbitration_id=send_id, data=DISABLE, + is_extended_id=False, is_fd=FD, bitrate_switch=FD)) + return False + + +def main() -> int: + ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--channel", default="can0", help="SocketCAN 
interface (default: can0)") + ap.add_argument("--fd", action="store_true", help="Use CAN-FD (requires FD-capable adapter). Default is classical CAN @ 1 Mbit, which is what most gs_usb adapters support.") + ap.add_argument("--ids", default=None, + help="Comma-separated send IDs to probe (default: 1..8)") + ap.add_argument("--timeout", type=float, default=0.2, help="Reply timeout per motor (s)") + args = ap.parse_args() + + global FD + FD = args.fd + motors = DEFAULT_MOTORS + if args.ids: + wanted = {int(x, 0) for x in args.ids.split(",")} + motors = [m for m in DEFAULT_MOTORS if m[0] in wanted] + + # Preflight: is the interface up? + try: + flags = int(open(f"/sys/class/net/{args.channel}/flags").read().strip(), 16) + iface_up = bool(flags & 0x1) + except OSError: + print(f"ERROR: interface '{args.channel}' not found", file=sys.stderr) + return 1 + if not iface_up: + print(f"ERROR: SocketCAN interface '{args.channel}' is DOWN.", file=sys.stderr) + print(f" Run: sudo ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh {args.channel}", file=sys.stderr) + return 1 + + print(f"Opening {args.channel} ({'CAN-FD' if FD else 'classical CAN'})...") + try: + bus = can.Bus(interface="socketcan", channel=args.channel, fd=FD) + except Exception as e: + print(f"ERROR opening {args.channel}: {e}", file=sys.stderr) + print(" Did you run 'sudo ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh' first?", file=sys.stderr) + return 1 + + try: + print(f"Probing {len(motors)} motor(s) on {args.channel}:") + ok = 0 + for send_id, motor_type in motors: + recv_id = send_id | 0x10 + if probe_motor(bus, send_id, recv_id, motor_type, args.timeout): + ok += 1 + print(f"\n{ok}/{len(motors)} motors replied.") + return 0 if ok == len(motors) else 2 + finally: + bus.shutdown() + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dimos/utils/workspace.py b/dimos/utils/workspace.py new file mode 100644 index 0000000000..2fcd2fba8b --- /dev/null +++ 
b/dimos/utils/workspace.py @@ -0,0 +1,338 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Reachability / manipulability workspace analysis for any serial robot. + +Samples joint configurations uniformly within URDF limits, runs FK, and +collects (a) end-effector Cartesian positions and (b) the Yoshikawa +manipulability index ``sqrt(det(J·Jᵀ))`` at each pose. Backed by Pinocchio. + +Library use: + from dimos.utils.workspace import WorkspaceMap + ws = WorkspaceMap("path/to/robot.urdf", n_samples=100_000) + result = ws.query((0.1, 0.3, 0.5)) + if result["reachable"]: + print(result["best_config"]) + +CLI: + python -m dimos.utils.workspace path/to/robot.urdf # visualize + python -m dimos.utils.workspace path/to/robot.urdf query 0.1 0.3 0.5 + python -m dimos.utils.workspace path/to/robot.urdf suggest 0.1 0.3 0.5 + python -m dimos.utils.workspace path/to/robot.urdf interactive +""" + +from __future__ import annotations + +import argparse +import sys +import time +from pathlib import Path +from typing import Any + +import numpy as np + + +class WorkspaceMap: + """Precomputed reachability + manipulability map with spatial lookup. + + Not specific to any robot — works on any serial-chain URDF that + Pinocchio can load. + """ + + def __init__( + self, + urdf_path: str | Path, + n_samples: int = 100_000, + *, + ee_joint_id: int | None = None, + seed: int = 42, + ) -> None: + """Build the workspace map. 
+ + Args: + urdf_path: Path to a URDF parseable by Pinocchio (no xacro, no + ``package://`` URIs unless the mesh dirs resolve relative to + ``urdf_path``). + n_samples: Number of random joint configurations to sample. + ee_joint_id: Pinocchio joint index to treat as the end-effector. + Defaults to the last joint in the model. + seed: RNG seed for reproducibility. + """ + import pinocchio + + self._pin = pinocchio + self.model = pinocchio.buildModelFromUrdf(str(urdf_path)) + self.data = self.model.createData() + self.ee_id = ee_joint_id if ee_joint_id is not None else (self.model.njoints - 1) + self.q_lo = self.model.lowerPositionLimit.copy() + self.q_hi = self.model.upperPositionLimit.copy() + self._sample(n_samples, seed) + + def _sample(self, n: int, seed: int) -> None: + rng = np.random.default_rng(seed) + self.positions = np.empty((n, 3)) + self.configs = np.empty((n, self.model.nq)) + self.manipulability = np.empty(n) + + for i in range(n): + q = rng.uniform(self.q_lo, self.q_hi) + self._pin.forwardKinematics(self.model, self.data, q) + self._pin.computeJointJacobians(self.model, self.data, q) + + self.positions[i] = self.data.oMi[self.ee_id].translation + self.configs[i] = q + + J = self._pin.getJointJacobian( + self.model, self.data, self.ee_id, + self._pin.ReferenceFrame.LOCAL_WORLD_ALIGNED, + ) + JJt = J[:3, :] @ J[:3, :].T + self.manipulability[i] = np.sqrt(max(0.0, np.linalg.det(JJt))) + + def query( + self, + target: np.ndarray | tuple[float, float, float], + radius: float = 0.03, + ) -> dict[str, Any]: + """Check reachability at a Cartesian target (x,y,z).""" + target = np.asarray(target, dtype=np.float64) + dists = np.linalg.norm(self.positions - target, axis=1) + mask = dists < radius + if int(mask.sum()) == 0: + for r in (0.05, 0.08, 0.12): + mask = dists < r + if mask.sum() > 0: + break + n_nearby = int(mask.sum()) + if n_nearby == 0: + nearest = int(np.argmin(dists)) + return { + "reachable": False, + "n_configs": 0, + "mean_manipulability": 0.0, 
+ "nearest_distance": float(dists[nearest]), + "nearest_position": self.positions[nearest].tolist(), + "nearest_config": self.configs[nearest].tolist(), + } + manip_nearby = self.manipulability[mask] + indices = np.where(mask)[0] + best_idx = indices[int(np.argmax(manip_nearby))] + return { + "reachable": True, + "n_configs": n_nearby, + "mean_manipulability": float(manip_nearby.mean()), + "max_manipulability": float(manip_nearby.max()), + "best_config": self.configs[best_idx].tolist(), + "best_position": self.positions[best_idx].tolist(), + "distance": float(dists[best_idx]), + } + + def stats(self) -> str: + """Human-readable workspace summary (bounds, reach, hull volume).""" + p = self.positions + lines = [ + "Workspace stats:", + f" Samples: {len(p):,}", + f" X range: [{p[:,0].min():.3f}, {p[:,0].max():.3f}] m", + f" Y range: [{p[:,1].min():.3f}, {p[:,1].max():.3f}] m", + f" Z range: [{p[:,2].min():.3f}, {p[:,2].max():.3f}] m", + f" Max reach from origin: {np.linalg.norm(p, axis=1).max():.3f} m", + f" Manipulability: [{self.manipulability.min():.4f}, {self.manipulability.max():.4f}]", + ] + try: + from scipy.spatial import ConvexHull + + hull = ConvexHull(p) + lines.append(f" Convex hull volume: {hull.volume:.4f} m³") + except Exception: # noqa: BLE001 + pass + return "\n".join(lines) + + +# ── Meshcat visualization helpers ────────────────────────────────────────── + + +def _colormap(values: np.ndarray) -> np.ndarray: + """Red→green colormap, clipped at 2nd/98th percentile for contrast.""" + v = values.copy() + lo, hi = np.percentile(v, 2), np.percentile(v, 98) + v = np.clip((v - lo) / (hi - lo + 1e-12), 0.0, 1.0) + colors = np.zeros((len(v), 3), dtype=np.uint8) + colors[:, 0] = ((1.0 - v) * 255).astype(np.uint8) + colors[:, 1] = (v * 255).astype(np.uint8) + colors[:, 2] = 25 + return colors + + +def render_cloud(meshcat: Any, ws: WorkspaceMap, path: str = "/workspace") -> None: + """Render a WorkspaceMap's EE positions to Drake's Meshcat, colored by 
manipulability.""" + from pydrake.perception import BaseField, Fields, PointCloud + + cloud = PointCloud(len(ws.positions), Fields(BaseField.kXYZs | BaseField.kRGBs)) + cloud.mutable_xyzs()[:] = ws.positions.T.astype(np.float32) + cloud.mutable_rgbs()[:] = _colormap(ws.manipulability).T + meshcat.SetObject(path, cloud, point_size=0.004) + + +def render_target( + meshcat: Any, + pos: list[float] | tuple[float, float, float], + name: str = "target", + color: tuple[float, float, float] = (1.0, 0.0, 0.0), +) -> None: + from pydrake.geometry import Rgba, Sphere + from pydrake.math import RigidTransform + + meshcat.SetObject(f"/{name}", Sphere(0.015), Rgba(*color, 0.8)) + meshcat.SetTransform(f"/{name}", RigidTransform(list(pos))) + + +# ── CLI ──────────────────────────────────────────────────────────────────── + + +def _cmd_viz(args: argparse.Namespace) -> int: + from pydrake.geometry import Meshcat + + ws = WorkspaceMap(args.urdf, args.samples, ee_joint_id=args.ee_joint_id) + print(ws.stats()) + meshcat = Meshcat() + print(f"\nMeshcat: {meshcat.web_url()}") + print("Green = dexterous, Red = near singularity\n") + render_cloud(meshcat, ws) + print("Press Ctrl-C to exit.") + try: + while True: + time.sleep(1.0) + except KeyboardInterrupt: + pass + return 0 + + +def _cmd_query(args: argparse.Namespace) -> int: + ws = WorkspaceMap(args.urdf, args.samples, ee_joint_id=args.ee_joint_id) + result = ws.query((args.x, args.y, args.z)) + print(f"\nTarget: ({args.x:.3f}, {args.y:.3f}, {args.z:.3f})") + if result["reachable"]: + print(f" REACHABLE — {result['n_configs']} configs within 3cm") + print(f" Mean manipulability: {result['mean_manipulability']:.4f}") + print(f" Best config: {[f'{q:.3f}' for q in result['best_config']]}") + p = result["best_position"] + print(f" Best EE position: ({p[0]:.3f}, {p[1]:.3f}, {p[2]:.3f})") + else: + print(" NOT REACHABLE") + p = result["nearest_position"] + print(f" Nearest reachable point: ({p[0]:.3f}, {p[1]:.3f}, {p[2]:.3f})") + print(f" 
Distance to nearest: {result['nearest_distance']:.3f} m") + print(f" Nearest config: {[f'{q:.3f}' for q in result['nearest_config']]}") + return 0 + + +def _cmd_suggest(args: argparse.Namespace) -> int: + ws = WorkspaceMap(args.urdf, args.samples, ee_joint_id=args.ee_joint_id) + target = np.array([args.x, args.y, args.z]) + dists = np.linalg.norm(ws.positions - target, axis=1) + closest = np.argsort(dists)[:20] + sorted_by_manip = sorted(closest, key=lambda i: -ws.manipulability[i]) + + print(f"\nSuggested poses near ({args.x:.3f}, {args.y:.3f}, {args.z:.3f}):") + print(f"{'#':>3} {'dist':>6} {'manip':>7} {'position':>30} joint config") + print("-" * 100) + for rank, idx in enumerate(sorted_by_manip[:10], 1): + p, q = ws.positions[idx], ws.configs[idx] + pos_str = f"({p[0]:+.3f}, {p[1]:+.3f}, {p[2]:+.3f})" + q_str = "[" + ", ".join(f"{v:.2f}" for v in q) + "]" + print(f"{rank:>3} {dists[idx]:>6.3f} {ws.manipulability[idx]:>7.4f} {pos_str:>30} {q_str}") + return 0 + + +def _cmd_interactive(args: argparse.Namespace) -> int: + from pydrake.geometry import Meshcat + + ws = WorkspaceMap(args.urdf, args.samples, ee_joint_id=args.ee_joint_id) + print(ws.stats()) + meshcat = Meshcat() + print(f"\nMeshcat: {meshcat.web_url()}") + render_cloud(meshcat, ws) + print("\nType target as 'x y z' (meters), or 'q' to quit.\n") + + while True: + try: + line = input("target> ").strip() + except (EOFError, KeyboardInterrupt): + break + if line.lower() in ("q", "quit", "exit"): + break + parts = line.split() + if len(parts) != 3: + print(" Enter three floats: x y z") + continue + try: + x, y, z = (float(p) for p in parts) + except ValueError: + print(" Invalid input") + continue + + result = ws.query((x, y, z)) + if result["reachable"]: + render_target(meshcat, (x, y, z), "target", (0.0, 1.0, 0.0)) + render_target(meshcat, result["best_position"], "best_ee", (0.0, 0.5, 1.0)) + print(f" REACHABLE — {result['n_configs']} configs, " + f"manip={result['mean_manipulability']:.4f}") + 
print(f" Joint config: {[round(q, 3) for q in result['best_config']]}") + else: + render_target(meshcat, (x, y, z), "target", (1.0, 0.0, 0.0)) + render_target(meshcat, result["nearest_position"], "nearest", (1.0, 0.5, 0.0)) + print(f" NOT REACHABLE — nearest is {result['nearest_distance']:.3f}m away") + print(f" Nearest config: {[round(q, 3) for q in result['nearest_config']]}") + return 0 + + +def main() -> int: + ap = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + ap.add_argument("urdf", type=Path, help="Path to a URDF parseable by Pinocchio") + ap.add_argument("--samples", type=int, default=100_000) + ap.add_argument( + "--ee-joint-id", + type=int, + default=None, + help="Pinocchio joint index for the end-effector (default: last joint)", + ) + + sub = ap.add_subparsers(dest="command") + sub.add_parser("viz", help="Visualize workspace in Meshcat (default)") + sub.add_parser("interactive", help="Visualize + interactive target query") + q = sub.add_parser("query", help="Query if a target is reachable") + q.add_argument("x", type=float) + q.add_argument("y", type=float) + q.add_argument("z", type=float) + s = sub.add_parser("suggest", help="Suggest reachable poses near a target") + s.add_argument("x", type=float) + s.add_argument("y", type=float) + s.add_argument("z", type=float) + + args = ap.parse_args() + cmd = args.command or "viz" + return { + "viz": _cmd_viz, + "query": _cmd_query, + "suggest": _cmd_suggest, + "interactive": _cmd_interactive, + }[cmd](args) + + +if __name__ == "__main__": + sys.exit(main()) From 415c9a6b4e41848808e1978940579be52240e686 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 21 Apr 2026 15:34:05 -0700 Subject: [PATCH 06/30] added user guide readme for openarm --- .../manipulation/openarm_integration.md | 387 ++++++++++++++++++ 1 file changed, 387 insertions(+) create mode 100644 docs/capabilities/manipulation/openarm_integration.md diff --git 
a/docs/capabilities/manipulation/openarm_integration.md b/docs/capabilities/manipulation/openarm_integration.md new file mode 100644 index 0000000000..306a44fd16 --- /dev/null +++ b/docs/capabilities/manipulation/openarm_integration.md @@ -0,0 +1,387 @@ +# OpenArm Integration + +Guide for running the **OpenArm** — an open-source bimanual 7-DOF research arm built from Damiao DM-J quasi-direct-drive motors — under the dimos manipulation + control stack. + +**If you're standing in front of the hardware and just want to run it, skip to [Quick start](#quick-start).** + +Related: +- Upstream hardware + C++ reference: [enactic/openarm_can](https://github.com/enactic/openarm_can) +- How to integrate any new arm: [adding_a_custom_arm.md](adding_a_custom_arm.md) + +--- + +## Why this integration is different + +Every other arm in dimos wraps a vendor Python SDK: + +| Arm | Transport | Python SDK | +|---|---|---| +| xArm | TCP/IP | `xarm-python-sdk` | +| Piper | CAN (via SDK) | `piper_sdk` | +| R1 Pro | Galaxea | Galaxea SDK | +| Go2 / G1 | WebRTC | Unitree SDK | +| Panda | FCI | `panda-py` | + +**OpenArm ships no Python SDK.** The only interface is raw CAN frames on the wire, speaking the Damiao MIT-mode protocol. So dimos includes a from-scratch driver that encodes/decodes the protocol directly on a SocketCAN bus. The reference implementation is the Enactic C++ library at [enactic/openarm_can](https://github.com/enactic/openarm_can) — we port the frame layout from there. 
+ +## Architecture + +``` +ManipulationModule → ControlCoordinator → OpenArmAdapter → OpenArmBus → SocketCAN → arm + (Drake plan) (100Hz tick loop) (dimos protocol) (CAN driver) +``` + +Code layout: + +``` +dimos/hardware/manipulators/openarm/ +├── driver.py # OpenArmBus, DamiaoMotor — pure CAN driver, no dimos deps +├── adapter.py # OpenArmAdapter — implements dimos ManipulatorAdapter protocol +├── test_driver.py # 13 unit tests (virtual CAN loopback, no hardware) +└── test_adapter.py # 11 unit tests (virtual CAN + mock state frames) + +dimos/robot/catalog/openarm.py # openarm_arm() and openarm_single() config factories +dimos/robot/manipulators/openarm/ +├── blueprints.py # coordinator-*, planner-*, keyboard-teleop-* blueprints +└── scripts/ # bring-up + diagnostic scripts (run manually by humans) + ├── openarm_can_up.sh # bring SocketCAN interfaces up (needs sudo) + ├── openarm_can_probe.py # enumerate & read state from all 8 motors + ├── openarm_set_mit_mode.py # one-time CTRL_MODE=MIT write per motor + └── ... (diagnostics) + +data/openarm_description/ # URDF + meshes (in-tree; may migrate to LFS) +└── urdf/robot/ + ├── openarm_v10_bimanual.urdf # both arms (14 DOF, used by coordinator) + ├── openarm_v10_left.urdf # left arm + torso (7 DOF, per-side planning) + ├── openarm_v10_right.urdf # right arm + torso (7 DOF) + └── openarm_v10_single.urdf # standalone arm (Pinocchio FK for teleop) +``` + +Workspace analysis is generic and lives in [dimos/utils/workspace.py](../../../dimos/utils/workspace.py) — works for any URDF, not just OpenArm. + +--- + +## Quick start + +You need: + +- 2× **OpenArm v10** arms, wired to USB-CAN adapters +- 2× **USB-CAN adapters** (we used gs_usb family, VID:PID `1d50:606f`, e.g. CANable 2.0). Classical CAN @ 1 Mbit is enough; CAN-FD not required +- **Python 3.12 venv with dimos installed** plus `python-can >= 4.3` and `pinocchio` +- **sudo** on first run (to bring up the CAN interfaces) + +### 1. 
Bring up the CAN buses + +```bash +sudo ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh can0 can1 +``` + +This sets both interfaces to classical CAN @ 1 Mbit with a 1000-frame TX queue (enough headroom for the 100 Hz tick loop). If only one bus is present, pass just that one: `sudo ... openarm_can_up.sh can0`. + +**Troubleshooting:** +- `Operation not permitted` → you forgot `sudo`. +- `Operation not supported` on `fd on` → your adapter doesn't support CAN-FD. The script defaults to classical, so this shouldn't happen unless you set `MODE=fd`. +- Only one `can*` interface appears → the other adapter isn't enumerating. On gs_usb boards, the **blue LED** indicates USB enumeration. If one adapter only shows red/green, swap the USB cable (many USB-C cables are charge-only). + +### 2. Verify all 16 motors are alive + +```bash +python ./dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py --channel can0 +python ./dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py --channel can1 +``` + +Expected: `8/8 motors replied` on each bus, with plausible joint positions and rotor temps around 25–30 °C. + +### 3. (First time only) Put motors in MIT mode + +Damiao motors have a persistent `CTRL_MODE` register. They ship in POS_VEL mode by default, which means they will reply to enable/state queries but **silently ignore** any MIT control frames — the "motor doesn't move, error grows" failure. The adapter writes MIT on every `connect()` by default, so this step is usually automatic. If you want to set it explicitly once: + +```bash +python ./dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py --channel can0 +python ./dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py --channel can1 +``` + +The register is persistent across power cycles, so you only need this once per motor (or after a firmware reset). + +### 4. Run a blueprint + +| Blueprint | What it does | +|---|---| +| `coordinator-openarm-mock` | Bimanual, mock adapters. 
No hardware. | +| `openarm-mock-planner-coordinator` | Drake planner + bimanual mock, Meshcat viz. Great smoke test. | +| `coordinator-openarm-left` / `coordinator-openarm-right` | Single arm, real hardware on can0 / can1. | +| `coordinator-openarm-bimanual` | Both arms, real hardware, no planner. | +| `openarm-planner-coordinator` | **Main usable blueprint** — Drake planner + both arms on real hardware. | +| `keyboard-teleop-openarm-mock` / `keyboard-teleop-openarm` | Single-arm Cartesian IK + pygame keyboard, mock / real. | + +**Safety before hot-plugging hardware:** hold the arms before starting. On connect, the adapter enables all motors and sends gravity-comp holds — the arms go slightly stiff but don't leap. Ctrl-C to cleanly disable and exit. + +First-time recommendation: mock planner to verify everything wires up, then real single-arm, then bimanual. + +```bash +# smoke test (no hardware) +dimos run openarm-mock-planner-coordinator + +# single-arm bring-up (hold the arm physically first) +dimos run coordinator-openarm-left + +# full bimanual with planner +dimos run openarm-planner-coordinator +``` + +Meshcat will appear at http://localhost:7000. + +### 5. Drive the arms from the manipulation client + +With `openarm-planner-coordinator` running in one terminal, open a second terminal and start the REPL client: + +```bash +python -i -m dimos.manipulation.planning.examples.manipulation_client +``` + +This gives you an interactive Python prompt with these functions: + +| Function | Purpose | +|---|---| +| `robots()` | List configured robots (here: `["left_arm", "right_arm"]`) | +| `joints(robot_name)` | Read current joint positions (7 floats) | +| `ee(robot_name)` | Read current end-effector pose | +| `state()` | Module state: `IDLE`, `PLANNING`, `EXECUTING`, `FAULT`, etc. 
| +| `plan([q1..q7], robot_name)` | Plan a collision-free trajectory to a joint configuration | +| `plan_pose(x, y, z, robot_name=...)` | Plan to a Cartesian EE pose (preserves current orientation) | +| `preview(robot_name)` | Animate the planned path in Meshcat without executing | +| `execute(robot_name)` | Send the planned trajectory to the coordinator | +| `home(robot_name)` | Plan + execute to home joints | +| `commands()` | Print all available functions | + +#### Example session — simple joint moves + +```python +>>> robots() +['left_arm', 'right_arm'] + +>>> joints(robot_name="left_arm") +[0.02, -0.01, -0.13, 0.15, 0.17, -0.07, 0.10] + +>>> # One-liner: plan → preview in Meshcat → execute on hardware +>>> plan([0.3, 0, 0, 0, 0, 0, 0], robot_name="left_arm") and preview(robot_name="left_arm") and execute(robot_name="left_arm") +True + +>>> joints(robot_name="left_arm") +[0.30, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00] # arm is now at the commanded pose +``` + +`plan()` returns `True` on success, `False` if planning failed (check the coordinator terminal for `COLLISION_AT_GOAL`, `INVALID_START`, `NO_SOLUTION`, etc). The `and` chaining is an idiom — if any step fails, the next one is short-circuited. + +If you ever get stuck in a `FAULT` state (e.g. an invalid plan was sent), reset the state machine: + +```python +>>> _client.reset() +'Reset to IDLE — ready for new commands' +``` + +#### Example session — bimanual + +```python +>>> # Move both arms to mirrored poses +>>> plan([0.5, 0, 0, 0, 0, 0, 0], robot_name="left_arm") and execute(robot_name="left_arm") +True +>>> plan([-0.5, 0, 0, 0, 0, 0, 0], robot_name="right_arm") and execute(robot_name="right_arm") +True +``` + +Each arm plans and executes independently — the coordinator runs both trajectories simultaneously on separate tick-loop tasks. 
+ +#### Example session — Cartesian target + +```python +>>> ee(robot_name="left_arm") # see where the EE currently is +>>> plan_pose(0.1, 0.3, 0.5, robot_name="left_arm") and preview(robot_name="left_arm") +True +>>> execute(robot_name="left_arm") +True +``` + +If you don't know which Cartesian targets are reachable, check first with the workspace tool — see [Workspace analysis](#workspace-analysis) below. `plan_pose` will fail with `NO_SOLUTION` if the IK can't find a configuration reaching the target. + +#### Adding obstacles + +```python +>>> add_box("table", 0.4, 0.0, 0.1, w=0.6, h=0.4, d=0.05) # rectangular obstacle +>>> add_sphere("ball", 0.3, 0.2, 0.4, radius=0.05) +>>> plan_pose(0.4, 0.0, 0.3, robot_name="left_arm") # now plans around it +>>> remove("table") # id returned by add_* +``` + +--- + +## Configuration + +### Which CAN bus is which arm + +Linux assigns `can0`/`can1` in USB-enumeration order, which isn't guaranteed stable across reboots or cable swaps. If the arms come up "swapped" (commanding `left_arm` moves the physical right arm), flip these two constants at the top of [blueprints.py](../../../dimos/robot/manipulators/openarm/blueprints.py): + +```python +LEFT_CAN = "can0" +RIGHT_CAN = "can1" +``` + +No other code changes are needed. + +### Gain tuning (MIT kp/kd) + +Defaults live in [adapter.py](../../../dimos/hardware/manipulators/openarm/adapter.py). Gains are per-joint because the shoulder motors (DM8006, 40 Nm) tolerate higher kp than the wrist motors (DM4310, 10 Nm): + +```python +_DEFAULT_KP = [100.0, 100.0, 80.0, 80.0, 60.0, 60.0, 60.0] +_DEFAULT_KD = [1.5, 1.5, 1.0, 1.0, 0.8, 0.8, 0.8] +``` + +Guidelines: +- `kp ∈ [0, 500]` in MIT mode. Higher kp = stiffer position tracking; too high → oscillation. +- `kd ∈ [0, 5]`. Higher kd = more damping, but values above ~2 on these gearboxes cause high-frequency buzz/grinding. 
+- Gravity compensation is on by default (`gravity_comp=True`) — the adapter uses Pinocchio to compute `G(q)` and adds it as feedforward torque. This removes the need for very high kp to fight gravity, so prefer low kp + gravity comp over high kp. + +### Physical joint limits + +The URDFs use the xacro-generated limits (which include per-side offsets for mirroring). The adapter's `get_limits()` reports the same per-side limits. If you measure tighter physical limits and want to enforce them, edit the URDFs directly — the planner will respect them. + +### Disabling auto MIT-mode write + +The adapter writes `CTRL_MODE=MIT` to every motor at `connect()`. It's idempotent (writing the same value is a no-op), so this is safe to leave on. To verify that a previous write persisted across a power cycle, flip `AUTO_SET_MIT_MODE = False` in [blueprints.py](../../../dimos/robot/manipulators/openarm/blueprints.py) and restart — the arms should still respond. + +--- + +## Motor mapping (OpenArm v10) + +Derived from the URDF's `joint_limits.yaml` (effort column) cross-checked against the Damiao torque tables. Both arms are identical. + +| Send ID | Recv ID | Joint | Motor | vMax [rad/s] | tMax [Nm] | +|---|---|---|---|---|---| +| 0x01 | 0x11 | joint1 | DM8006 | 45 | 40 | +| 0x02 | 0x12 | joint2 | DM8006 | 45 | 40 | +| 0x03 | 0x13 | joint3 | DM4340 | 8 | 28 | +| 0x04 | 0x14 | joint4 | DM4340 | 8 | 28 | +| 0x05 | 0x15 | joint5 | DM4310 | 30 | 10 | +| 0x06 | 0x16 | joint6 | DM4310 | 30 | 10 | +| 0x07 | 0x17 | joint7 | DM4310 | 30 | 10 | +| 0x08 | 0x18 | gripper | DM4310 | 30 | 10 | + +Convention: `recv_id = send_id | 0x10`. + +--- + +## Damiao protocol essentials + +Ported from `enactic/openarm_can/src/openarm/damiao_motor/dm_motor_control.cpp`. You shouldn't need these unless you're modifying the driver. + +### Enable / disable / zero-position + +Send to the motor's send_id. 
8-byte payload: + +``` +[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, CMD] + where CMD = 0xFC (enable) | 0xFD (disable) | 0xFE (zero current pose) +``` + +### MIT control frame (8 bytes) + +Bit layout: `q[16] | dq[12] | kp[12] | kd[12] | tau[12]`. Each float quantized via: + +```python +def float_to_uint(x, lo, hi, bits): + x = clamp(x, lo, hi) + return round((x - lo) / (hi - lo) * ((1 << bits) - 1)) +``` + +Gain ranges: `kp ∈ [0, 500]`, `kd ∈ [0, 5]`. Position/velocity/torque ranges come from the motor-type table above. + +Byte layout: +``` +byte0 = q_u >> 8 +byte1 = q_u & 0xFF +byte2 = dq_u >> 4 +byte3 = ((dq_u & 0xF) << 4) | ((kp_u >> 8) & 0xF) +byte4 = kp_u & 0xFF +byte5 = kd_u >> 4 +byte6 = ((kd_u & 0xF) << 4) | ((tau_u >> 8) & 0xF) +byte7 = tau_u & 0xFF +``` + +### State reply (8 bytes, on recv_id) + +Same `q | dq | tau` layout + 2 temperature bytes: + +``` +byte0 = motor_id_echo +byte1..5 = q | dq | tau (same packing as above) +byte6 = t_mos (°C) +byte7 = t_rotor (°C) +``` + +### CTRL_MODE register write + +Broadcast frame on CAN ID `0x7FF`: + +``` +data = [send_id_lo, send_id_hi, 0x55, RID=10, val[0..3]] + where val = 1 (MIT) | 2 (POS_VEL) | 3 (VEL) | 4 (POS_FORCE), little-endian uint32 +``` + +Persistent across power cycles. + +--- + +## Known gotchas + +- **`ip link ... fd on` → `Operation not supported`.** gs_usb firmware doesn't support CAN-FD. Use classical CAN @ 1 Mbit (our bringup script's default). +- **Motors reply to probes but commands do nothing.** CTRL_MODE is not MIT. The adapter now writes MIT on connect, but if you disabled that and motors got reset, run `openarm_set_mit_mode.py`. +- **`COLLISION_AT_START` during planning.** `link5` and `link7` collision meshes overlap by 3 mm at every configuration. Handled by `OPENARM_COLLISION_EXCLUSIONS` in the catalog. If you see it anyway, the exclusion pairs may not be getting applied — check that the collision filter log line appears during world build. 
+- **`INVALID_START` during planning.** Hardware encoder noise pushed a joint 1 mrad past a URDF limit. Joint4 used to be exactly `lower=0.0` which tripped this — it's now `-0.01` to give breathing room. If you see it on a different joint, widen that limit by ~10 mrad. +- **"Transmit buffer full" (ENOBUFS) at 100 Hz.** Kernel TX queue too small. The bringup script sets `txqueuelen 1000`; the driver also retries on ENOBUFS. If you still see the error, check `ip -details link show canX | grep qlen`. +- **Arms swap sides.** USB enumeration order flipped. Swap `LEFT_CAN` / `RIGHT_CAN` in [blueprints.py](../../../dimos/robot/manipulators/openarm/blueprints.py). + +--- + +## Design decisions + +- **Driver separate from adapter.** `driver.py` has zero dimos deps → unit-testable with a virtual CAN bus, reusable outside dimos. +- **MIT mode for everything.** MIT can emulate position (high kp), velocity (kp=0, nonzero kd+dq), and torque (kp=kd=0, nonzero tau). One code path. +- **Gravity compensation on by default.** Eliminates steady-state position error without needing high kp. Needs Pinocchio + the per-side URDFs. +- **One adapter per CAN bus, keyed by `address`.** Matches the Piper adapter pattern. Bimanual = two adapters with different `address` values. +- **Per-side URDFs for Drake planning.** Loading the full 14-DOF bimanual URDF twice (once per robot instance) creates phantom-arm collisions with the "other" arm frozen at zero. The per-side URDFs keep only one arm's links + the torso, avoiding the phantom collisions while matching the bimanual kinematics exactly. +- **URDF stays in-tree (`data/openarm_description/`) for now.** Can migrate to LFS later — only the path constant in the catalog changes. +- **CAN bringup stays manual (`sudo`).** Auto-bringup from `connect()` would need sudo-in-a-library or a systemd unit; the explicit script is clearer and testable. For production, add a oneshot systemd unit that runs the script at boot. 
+ +--- + +## Workspace analysis + +For figuring out which targets are reachable before planning, use the generic workspace tool: + +```bash +# Visualize the left arm's reachable workspace as a point cloud +python -m dimos.utils.workspace data/openarm_description/urdf/robot/openarm_v10_left.urdf + +# Check if a specific target is reachable +python -m dimos.utils.workspace data/openarm_description/urdf/robot/openarm_v10_left.urdf query 0.1 0.3 0.5 + +# Get a list of reachable poses near a target, ranked by manipulability +python -m dimos.utils.workspace data/openarm_description/urdf/robot/openarm_v10_left.urdf suggest 0.1 0.3 0.5 + +# Interactive: visualize + type targets to query +python -m dimos.utils.workspace data/openarm_description/urdf/robot/openarm_v10_left.urdf interactive +``` + +Points are colored by Yoshikawa manipulability index: green = dexterous, red = near singularity. Avoid planning targets in the red regions. + +--- + +## Testing + +```bash +# Unit tests (no hardware, use virtual CAN) +.venv/bin/python -m pytest dimos/hardware/manipulators/openarm/ -v +``` + +Expected: 24 passed (13 driver + 11 adapter). All tests use `can.Bus(interface="virtual")` loopback — no real hardware needed. 
From 247f8c9372eff3613c64969884078b95fcdf254f Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 21 Apr 2026 17:59:29 -0700 Subject: [PATCH 07/30] reduced docstring to minimal --- .../hardware/manipulators/openarm/adapter.py | 63 +++------------ dimos/hardware/manipulators/openarm/driver.py | 76 +++---------------- dimos/robot/catalog/openarm.py | 32 +------- .../robot/manipulators/openarm/blueprints.py | 14 +--- dimos/utils/workspace.py | 36 +-------- 5 files changed, 27 insertions(+), 194 deletions(-) diff --git a/dimos/hardware/manipulators/openarm/adapter.py b/dimos/hardware/manipulators/openarm/adapter.py index 6d7c03d586..f71c209430 100644 --- a/dimos/hardware/manipulators/openarm/adapter.py +++ b/dimos/hardware/manipulators/openarm/adapter.py @@ -12,19 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""OpenArm adapter — implements the ManipulatorAdapter protocol. - -Wraps the from-scratch Damiao MIT-mode CAN driver in :mod:`.driver`. All -physics-level work (frame packing, bus threading, motor state caching) -lives in the driver; this file just maps the dimos protocol methods to -driver calls and handles per-joint sign/offset convention. - -Units: radians, rad/s, Nm (matching the driver and the protocol). - -Default wiring matches the OpenArm v10 BOM (send IDs 1..7 + gripper 8, -motor types DM8006/DM4340/DM4310, recv = send | 0x10). See -``docs/capabilities/manipulation/openarm_integration.md``. -""" +"""OpenArm ManipulatorAdapter — wraps the Damiao MIT-mode driver. SI units.""" from __future__ import annotations @@ -51,16 +39,11 @@ def _socketcan_iface_up(name: str) -> bool: - """Return True if a SocketCAN interface is present and in the UP state. - - Reads /sys directly instead of shelling out to ip(8) — no subprocess, - no sudo, works in containers. - """ + """Return True if SocketCAN interface is present and UP. 
Reads /sys directly.""" try: flags_path = Path("/sys/class/net") / name / "flags" if not flags_path.exists(): return False - # IFF_UP is bit 0 of the interface flags register. return (int(flags_path.read_text().strip(), 16) & 0x1) == 0x1 except OSError: return False @@ -98,35 +81,14 @@ def _socketcan_iface_up(name: str) -> bool: class OpenArmAdapter: """Adapter for one OpenArm (7 DOF) on a single SocketCAN bus. - Implements ``ManipulatorAdapter`` via duck typing — no inheritance. - - Parameters - ---------- - address: - SocketCAN channel, e.g. ``"can0"``. - dof: - Must be 7 (OpenArm is fixed-DOF). Kept as a parameter for adapter- - protocol uniformity. - side: - ``"left"`` or ``"right"``. Currently only stored for logging; no sign - flips are applied (the URDF handles left/right mirroring). - fd: - CAN-FD. Defaults to False because the gs_usb adapters we have don't - support FD, and OpenArm runs fine on classical 1 Mbit CAN. - interface: - python-can interface name. Use ``"virtual"`` for unit tests. - kp / kd: - Optional per-joint overrides of the POSITION-mode MIT gains. - gravity_comp: - Enable Pinocchio-based gravity compensation feedforward. Computes - ``tau_gravity = G(q_current)`` each tick and adds it as the tau_ff - term in the MIT frame, so the PD gains only handle transient - tracking — not fighting gravity. Eliminates steady-state error. - auto_set_mit_mode: - If True (default), write ``CTRL_MODE=MIT`` to every motor during - ``connect()``. Idempotent — safe to leave on. Set False to verify - a previous write persisted across power cycles (i.e. to confirm - motors stay in MIT mode without the adapter re-setting it). + Key kwargs: + address: SocketCAN channel, e.g. 
"can0" + side: "left" or "right" (picks URDF for gravity comp) + kp / kd: per-joint MIT gains (optional override) + gravity_comp: Pinocchio G(q) feedforward torque (default True) + auto_set_mit_mode: write CTRL_MODE=MIT on connect (idempotent, default True) + fd: CAN-FD (False for gs_usb adapters, which don't support FD) + interface: python-can backend; "virtual" for unit tests """ # Per-side URDFs for Pinocchio gravity model @@ -355,10 +317,7 @@ def read_error(self) -> tuple[int, str]: # ------------------------------------------------------------------ def _compute_gravity_torques(self, q: list[float]) -> list[float]: - """Compute per-joint gravity torques at configuration q using Pinocchio. - - Returns [0.0]*dof if gravity comp is disabled or model not loaded. - """ + """Pinocchio G(q), clamped to motor torque limits. Zero if model not loaded.""" if self._pin_model is None or self._pin_data is None: return [0.0] * self._dof import pinocchio diff --git a/dimos/hardware/manipulators/openarm/driver.py b/dimos/hardware/manipulators/openarm/driver.py index 88d57aacea..6587799bd6 100644 --- a/dimos/hardware/manipulators/openarm/driver.py +++ b/dimos/hardware/manipulators/openarm/driver.py @@ -12,22 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Damiao MIT-mode CAN driver for OpenArm. +"""Damiao MIT-mode CAN driver for OpenArm. SI units throughout. -Standalone module — zero dimos dependencies. Implements the subset of the -Damiao DM-J CAN protocol needed for MIT-mode position / velocity / torque -control, matching the reference C++ implementation in -`enactic/openarm_can/src/openarm/damiao_motor/dm_motor_control.cpp`. - -Unit tests use ``can.Bus(interface="virtual")`` for loopback. 
- -All user-facing values are SI units: - angles in radians - angular velocity in rad/s - torque in Nm - -See ``docs/capabilities/manipulation/openarm_integration.md`` for protocol -byte-level documentation and motor tables. +Ported from ``enactic/openarm_can`` (C++). No dimos deps — testable with +``can.Bus(interface="virtual")``. """ from __future__ import annotations @@ -110,19 +98,12 @@ def _clamp(x: float, lo: float, hi: float) -> float: def float_to_uint(x: float, lo: float, hi: float, bits: int) -> int: - """Quantize `x` in `[lo, hi]` to a `bits`-wide unsigned int. - - Matches ``CanPacketEncoder::double_to_uint`` from the reference library. - """ x = _clamp(x, lo, hi) - span = hi - lo - return int((x - lo) / span * ((1 << bits) - 1)) + return int((x - lo) / (hi - lo) * ((1 << bits) - 1)) def uint_to_float(u: int, lo: float, hi: float, bits: int) -> float: - """Inverse of :func:`float_to_uint`.""" - span = hi - lo - return u / ((1 << bits) - 1) * span + lo + return u / ((1 << bits) - 1) * (hi - lo) + lo def pack_mit_frame( @@ -133,7 +114,6 @@ def pack_mit_frame( kd: float, tau: float, ) -> bytes: - """Pack a Damiao MIT-mode control frame into 8 bytes.""" p_max, v_max, t_max = _MOTOR_LIMITS[motor_type] q_u = float_to_uint(q, -p_max, p_max, 16) dq_u = float_to_uint(dq, -v_max, v_max, 12) @@ -212,15 +192,7 @@ def pack_write_param_frame(send_id: int, rid: int, value_u32: int) -> bytes: @dataclass(frozen=True) class DamiaoMotor: - """A single motor on a shared CAN bus. - - Attributes: - send_id: CAN arbitration ID we send commands to (e.g. 0x01..0x08). - motor_type: Determines q/dq/tau scaling. - recv_id: CAN ID the motor replies on. Defaults to ``send_id | 0x10``, - which is the OpenArm convention and matches the pattern used by - ``enactic/openarm_can``. - """ + """One Damiao motor on a CAN bus. 
recv_id defaults to send_id | 0x10.""" send_id: int motor_type: MotorType @@ -232,7 +204,6 @@ def effective_recv_id(self) -> int: @property def limits(self) -> tuple[float, float, float]: - """``(p_max, v_max, t_max)`` in rad / rad·s⁻¹ / Nm.""" return _MOTOR_LIMITS[self.motor_type] @@ -242,26 +213,7 @@ def limits(self) -> tuple[float, float, float]: class OpenArmBus: - """Manage one SocketCAN bus carrying a set of Damiao motors. - - Spawns a background thread that continuously drains the bus and caches - the latest :class:`MotorState` per motor. Callers fan out commands via - :meth:`send_mit_many` at their chosen control rate and pull fresh state - via :meth:`get_states` / :meth:`get_state`. - - Parameters - ---------- - channel: - SocketCAN interface name (e.g. ``"can0"``). - motors: - Motors present on this bus. - fd: - True to use CAN-FD (adapter must support it). OpenArm ships fine on - classical CAN @ 1 Mbit — keep this False unless you know you need FD. - interface: - python-can interface name. ``"socketcan"`` for real hardware, - ``"virtual"`` for unit tests. - """ + """One SocketCAN bus with a background RX thread caching latest state.""" def __init__( self, @@ -344,14 +296,10 @@ def disable_all(self) -> None: time.sleep(0.002) def set_zero(self, send_id: int) -> None: - """Set current physical position as the motor's zero. - - Destructive — don't call this unless you know what you're doing. - """ + """Set current physical position as the motor's zero. Destructive.""" self._send_raw(send_id, _pack_control_command(_CMD_SET_ZERO)) def write_ctrl_mode(self, send_id: int, mode: int = CTRL_MODE_MIT) -> None: - """Write the persistent CTRL_MODE register (e.g. 
to switch to MIT).""" self._send_raw( _BROADCAST_ID, pack_write_param_frame(send_id, _RID_CTRL_MODE, mode), @@ -366,7 +314,6 @@ def send_mit( kd: float, tau: float, ) -> None: - """Send one MIT command frame to `send_id`.""" motor = next((m for m in self._motors if m.send_id == send_id), None) if motor is None: raise KeyError(f"motor 0x{send_id:02X} not on bus {self._channel}") @@ -377,10 +324,7 @@ def send_mit_many( self, commands: list[tuple[float, float, float, float, float]], ) -> None: - """Fan out one MIT frame per motor, in the order motors were registered. - - ``commands[i] = (q, dq, kp, kd, tau)`` is sent to ``self.motors[i]``. - """ + """One MIT frame per motor; commands[i] → self.motors[i] = (q, dq, kp, kd, tau).""" if len(commands) != len(self._motors): raise ValueError( f"expected {len(self._motors)} commands, got {len(commands)}" @@ -410,12 +354,10 @@ def get_state(self, send_id: int) -> MotorState | None: return self._states.get(motor.effective_recv_id) def get_states(self) -> list[MotorState | None]: - """Latest cached state in motor-registration order.""" with self._state_lock: return [self._states.get(m.effective_recv_id) for m in self._motors] def wait_all_states(self, timeout: float = 0.5) -> bool: - """Block until every motor has reported at least once (or timeout).""" deadline = time.monotonic() + timeout while time.monotonic() < deadline: with self._state_lock: diff --git a/dimos/robot/catalog/openarm.py b/dimos/robot/catalog/openarm.py index 91c49e123b..afc25db838 100644 --- a/dimos/robot/catalog/openarm.py +++ b/dimos/robot/catalog/openarm.py @@ -12,16 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""OpenArm v10 robot configurations. - -Ships in bimanual form — two 7-DOF arms on a shared torso. We load the -full bimanual URDF and select each arm's 7 joints via ``joint_names``. -Drake handles the rest. 
- -The URDF is referenced from the in-tree ``data/openarm_description/`` -folder during bring-up; the plan is to migrate it to LFS at the end of -the integration (only the ``_OPENARM_MODEL_PATH`` constant below changes). -""" +"""OpenArm v10 robot configurations.""" from __future__ import annotations @@ -62,20 +53,7 @@ def openarm_arm( address: str | None = None, **overrides: Any, ) -> RobotConfig: - """Create an OpenArm v10 config for one side of the bimanual rig. - - Both sides share the bimanual URDF (built with ``bimanual:=true``) and - pick out their 7 arm joints via ``joint_names``. Drake's SetAutoRenaming - handles duplicate model names when both configs go into one planning - world. - - Args: - side: ``"left"`` or ``"right"``. - name: Robot name (defaults to ``f"{side}_arm"``). - adapter_type: ``"mock"`` for dry runs, ``"openarm"`` for real CAN. - address: SocketCAN channel (e.g. ``"can0"``) — required for real hw. - **overrides: Override any ``RobotConfig`` field. - """ + """OpenArm v10 config for one side. Uses per-side URDF (arm + torso only).""" if side not in ("left", "right"): raise ValueError(f"side must be 'left' or 'right', got {side!r}") @@ -119,11 +97,7 @@ def openarm_single( address: str | None = None, **overrides: Any, ) -> RobotConfig: - """Single-arm (non-bimanual) config for keyboard teleop / cartesian IK. - - Uses the pre-expanded single-arm URDF (bimanual:=false). Joint names are - ``openarm_joint1..7``. For bimanual planning use :func:`openarm_arm`. - """ + """Single-arm config (keyboard teleop, cartesian IK). 
Use openarm_arm() for bimanual.""" defaults: dict[str, Any] = { "name": name, "model_path": OPENARM_V10_FK_MODEL, diff --git a/dimos/robot/manipulators/openarm/blueprints.py b/dimos/robot/manipulators/openarm/blueprints.py index 05b29f9e87..1e24bcaf97 100644 --- a/dimos/robot/manipulators/openarm/blueprints.py +++ b/dimos/robot/manipulators/openarm/blueprints.py @@ -12,19 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""OpenArm coordinator blueprints (single-arm + bimanual). - -Usage: - dimos run coordinator-openarm-mock # Bimanual, mock adapters - dimos run coordinator-openarm-left # Single arm, real CAN - dimos run coordinator-openarm-right # Single arm, real CAN - dimos run coordinator-openarm-bimanual # Both arms, real CAN - -The CAN interface each arm uses is set by the ``LEFT_CAN``/``RIGHT_CAN`` -constants below. Linux enumerates gs_usb adapters in USB-discovery order -which isn't guaranteed stable, so if your arms come up "swapped" just flip -these two values and rerun — no other code changes needed. -""" +"""OpenArm blueprints. Flip LEFT_CAN / RIGHT_CAN below if arms come up swapped.""" from __future__ import annotations diff --git a/dimos/utils/workspace.py b/dimos/utils/workspace.py index 2fcd2fba8b..f8b20fd131 100644 --- a/dimos/utils/workspace.py +++ b/dimos/utils/workspace.py @@ -12,24 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Reachability / manipulability workspace analysis for any serial robot. +"""Reachability + Yoshikawa-manipulability workspace analysis via Pinocchio. -Samples joint configurations uniformly within URDF limits, runs FK, and -collects (a) end-effector Cartesian positions and (b) the Yoshikawa -manipulability index ``sqrt(det(J·Jᵀ))`` at each pose. Backed by Pinocchio. 
- -Library use: - from dimos.utils.workspace import WorkspaceMap - ws = WorkspaceMap("path/to/robot.urdf", n_samples=100_000) - result = ws.query((0.1, 0.3, 0.5)) - if result["reachable"]: - print(result["best_config"]) - -CLI: - python -m dimos.utils.workspace path/to/robot.urdf # visualize - python -m dimos.utils.workspace path/to/robot.urdf query 0.1 0.3 0.5 - python -m dimos.utils.workspace path/to/robot.urdf suggest 0.1 0.3 0.5 - python -m dimos.utils.workspace path/to/robot.urdf interactive +Run ``python -m dimos.utils.workspace [viz|query|suggest|interactive]``. """ from __future__ import annotations @@ -44,11 +29,7 @@ class WorkspaceMap: - """Precomputed reachability + manipulability map with spatial lookup. - - Not specific to any robot — works on any serial-chain URDF that - Pinocchio can load. - """ + """Sampled reachability + manipulability map. Works on any URDF Pinocchio can load.""" def __init__( self, @@ -58,17 +39,6 @@ def __init__( ee_joint_id: int | None = None, seed: int = 42, ) -> None: - """Build the workspace map. - - Args: - urdf_path: Path to a URDF parseable by Pinocchio (no xacro, no - ``package://`` URIs unless the mesh dirs resolve relative to - ``urdf_path``). - n_samples: Number of random joint configurations to sample. - ee_joint_id: Pinocchio joint index to treat as the end-effector. - Defaults to the last joint in the model. - seed: RNG seed for reproducibility. 
- """ import pinocchio self._pin = pinocchio From 91ab832610bb91cfa0c366e1361f2441dfe34e2b Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 21 Apr 2026 18:57:49 -0700 Subject: [PATCH 08/30] gravity compensation added to velocity mode --- dimos/hardware/manipulators/openarm/adapter.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dimos/hardware/manipulators/openarm/adapter.py b/dimos/hardware/manipulators/openarm/adapter.py index f71c209430..d80a8c1bf2 100644 --- a/dimos/hardware/manipulators/openarm/adapter.py +++ b/dimos/hardware/manipulators/openarm/adapter.py @@ -361,7 +361,9 @@ def write_joint_positions( def write_joint_velocities(self, velocities: list[float]) -> bool: # MIT velocity tracking: kp=0, send dq directly, anchor q at the - # last-commanded position so the motor doesn't drift. + # last-commanded position so the motor doesn't drift. Gravity + # feedforward is still needed — with kp=0 the only restoring force + # is damping, so without tau_ff the arm droops under its own weight.
if self._bus is None or not self._enabled: return False if len(velocities) != self._dof: @@ -369,9 +371,11 @@ def write_joint_velocities(self, velocities: list[float]) -> bool: if self._last_cmd_q is None: self._last_cmd_q = self.read_joint_positions() anchor = self._last_cmd_q + q_current = self.read_joint_positions() + tau_ff = self._compute_gravity_torques(q_current) commands = [ - (q_anchor, dq, 0.0, kd, 0.0) - for q_anchor, dq, kd in zip(anchor, velocities, self._kd) + (q_anchor, dq, 0.0, kd, tau) + for q_anchor, dq, kd, tau in zip(anchor, velocities, self._kd, tau_ff) ] self._bus.send_mit_many(commands) return True From ae92718c1e79657136603c4af2dd8e32f324ff8f Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 21 Apr 2026 18:58:19 -0700 Subject: [PATCH 09/30] fixed stale docstring --- .../manipulators/openarm/scripts/openarm_can_probe.py | 9 ++++----- .../manipulators/openarm/scripts/openarm_set_mit_mode.py | 9 ++++++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py b/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py index 036b0a587f..ed57f16691 100755 --- a/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py +++ b/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 """Probe an OpenArm on a SocketCAN interface. -Enumerates all 8 expected Damiao motors (7 arm joints + gripper) on one CAN-FD -bus, enables each, reads back one state frame, then disables. This is the -Phase-0 hardware-verification script — if this does not work, nothing -downstream will. +Enumerates all 8 expected Damiao motors (7 arm joints + gripper) on one CAN bus +(classical by default, use --fd for CAN-FD), enables each, reads back one state +frame, then disables. Phase-0 hardware-verification script. Run AFTER bringing the bus up with dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh. 
@@ -46,7 +45,7 @@ ENABLE = bytes([0xFF] * 7 + [0xFC]) DISABLE = bytes([0xFF] * 7 + [0xFD]) -FD = True # set by --classical at runtime +FD = False # set by --fd at runtime; defaults to classical CAN @ 1 Mbit def uint_to_float(x: int, lo: float, hi: float, bits: int) -> float: diff --git a/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py b/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py index a361e74e4e..712d23b395 100755 --- a/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py +++ b/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py @@ -15,11 +15,14 @@ cycles. Usage: - # All 8 motors on can0 - python dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py --channel can0 --classical + # All 8 motors on can0 (classical CAN @ 1 Mbit, default) + python dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py --channel can0 # Single motor - python dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py --channel can0 --id 0x05 --classical + python dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py --channel can0 --id 0x05 + + # CAN-FD (only if your adapter supports it) + python dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py --channel can0 --fd """ from __future__ import annotations From 81169ea93aa71230188b1b8bb1b490f7d9cd81ad Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 21 Apr 2026 18:59:39 -0700 Subject: [PATCH 10/30] better error catching with errno library --- dimos/hardware/manipulators/openarm/driver.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/dimos/hardware/manipulators/openarm/driver.py b/dimos/hardware/manipulators/openarm/driver.py index 6587799bd6..843e7dc64b 100644 --- a/dimos/hardware/manipulators/openarm/driver.py +++ b/dimos/hardware/manipulators/openarm/driver.py @@ -21,6 +21,7 @@ from __future__ import annotations import enum +import errno import struct import threading import time @@ 
-384,15 +385,21 @@ def _send_raw(self, arbitration_id: int, data: bytes) -> None: is_fd=self._fd, bitrate_switch=self._fd, ) - # Retry on TX buffer full (ENOBUFS / errno 105) — the gs_usb - # adapter has a tiny kernel-side TX queue. A short backoff lets - # the kernel drain one frame before we try again. + # Retry on TX buffer full (ENOBUFS) — the gs_usb adapter has a + # tiny kernel-side TX queue. A short backoff lets the kernel drain + # one frame before we try again. Check errno on the underlying + # cause rather than string-matching the python-can message. for attempt in range(4): try: self._bus.send(msg) return except can.CanOperationError as e: - if "105" in str(e) and attempt < 3: + cause = e.__cause__ or e + is_enobufs = ( + getattr(cause, "errno", None) == errno.ENOBUFS + or "105" in str(e) # fallback if errno not exposed + ) + if is_enobufs and attempt < 3: time.sleep(0.001 * (attempt + 1)) else: raise From bae29ef02fa557069e07ce49c08e8ba071f4fb81 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 21 Apr 2026 19:06:09 -0700 Subject: [PATCH 11/30] mypy fixes --- .../robot/manipulators/openarm/blueprints.py | 21 +++++++++++++------ .../openarm/scripts/openarm_set_mit_mode.py | 2 +- dimos/utils/workspace.py | 2 +- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/dimos/robot/manipulators/openarm/blueprints.py b/dimos/robot/manipulators/openarm/blueprints.py index 1e24bcaf97..13070d9fcb 100644 --- a/dimos/robot/manipulators/openarm/blueprints.py +++ b/dimos/robot/manipulators/openarm/blueprints.py @@ -30,8 +30,8 @@ from dimos.teleop.keyboard.keyboard_teleop_module import KeyboardTeleopModule # ── Mock bimanual: no hardware, great for verifying wiring ───────────── -_mock_left = _openarm(side="left", name="left_arm") -_mock_right = _openarm(side="right", name="right_arm") +_mock_left = _openarm(side="left") +_mock_right = _openarm(side="right") coordinator_openarm_mock = ControlCoordinator.blueprint( hardware=[_mock_left.to_hardware_component(), 
_mock_right.to_hardware_component()], @@ -58,10 +58,19 @@ # replaced / factory-reset). AUTO_SET_MIT_MODE = True -_HW_KW = dict(adapter_type="openarm", - adapter_kwargs={"auto_set_mit_mode": AUTO_SET_MIT_MODE}) -_left_hw = _openarm(side="left", name="left_arm", address=LEFT_CAN, **_HW_KW) -_right_hw = _openarm(side="right", name="right_arm", address=RIGHT_CAN, **_HW_KW) +_ADAPTER_KWARGS = {"auto_set_mit_mode": AUTO_SET_MIT_MODE} +_left_hw = _openarm( + side="left", + address=LEFT_CAN, + adapter_type="openarm", + adapter_kwargs=_ADAPTER_KWARGS, +) +_right_hw = _openarm( + side="right", + address=RIGHT_CAN, + adapter_type="openarm", + adapter_kwargs=_ADAPTER_KWARGS, +) coordinator_openarm_left = ControlCoordinator.blueprint( hardware=[_left_hw.to_hardware_component()], diff --git a/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py b/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py index 712d23b395..0b7f5b1cb1 100755 --- a/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py +++ b/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py @@ -60,7 +60,7 @@ def write_ctrl_mode(bus: can.BusABC, send_id: int, fd: bool) -> bool: if len(msg.data) >= 4 and msg.data[2] in (0x33, 0x55): rid = msg.data[3] if rid == RID_CTRL_MODE: - echoed = struct.unpack(" No """Render a WorkspaceMap's EE positions to Drake's Meshcat, colored by manipulability.""" from pydrake.perception import BaseField, Fields, PointCloud - cloud = PointCloud(len(ws.positions), Fields(BaseField.kXYZs | BaseField.kRGBs)) + cloud = PointCloud(len(ws.positions), Fields(int(BaseField.kXYZs) | int(BaseField.kRGBs))) cloud.mutable_xyzs()[:] = ws.positions.T.astype(np.float32) cloud.mutable_rgbs()[:] = _colormap(ws.manipulability).T meshcat.SetObject(path, cloud, point_size=0.004) From 2f6cc1e8cf100cd3c07eeb08ce8d14248c1baf88 Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Wed, 22 Apr 2026 08:36:03 +0300 Subject: [PATCH 12/30] fix(tests): fix flakey porcelain test 
(#1900) --- dimos/porcelain/remote_module_source.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/dimos/porcelain/remote_module_source.py b/dimos/porcelain/remote_module_source.py index bbf08186e0..7a33a4271e 100644 --- a/dimos/porcelain/remote_module_source.py +++ b/dimos/porcelain/remote_module_source.py @@ -17,6 +17,7 @@ import copyreg import pickle import threading +import time from types import MappingProxyType from typing import TYPE_CHECKING, Any @@ -25,6 +26,8 @@ from dimos.porcelain.module_source import ModuleSource from dimos.utils.logging_config import setup_logger +_CONNECT_RETRY_DEADLINE_S = 2.0 + if TYPE_CHECKING: from dimos.core.coordination.blueprints import Blueprint @@ -53,7 +56,7 @@ class RemoteModuleSource(ModuleSource): is_remote = True def __init__(self, host: str, port: int) -> None: - self._coord_conn = rpyc.connect(host, port, config={"sync_request_timeout": 30}) + self._coord_conn = _rpyc_connect(host, port, config={"sync_request_timeout": 30}) self._cache: dict[str, tuple[rpyc.Connection, Any]] = {} self._lock = threading.RLock() @@ -68,7 +71,7 @@ def get_rpyc_module(self, name: str) -> Any: endpoint = self._coord_conn.root.get_module_endpoint(name) host, port, module_id = endpoint[0], int(endpoint[1]), int(endpoint[2]) - conn = rpyc.connect(host, port, config={"sync_request_timeout": 30}) + conn = _rpyc_connect(host, port, config={"sync_request_timeout": 30}) module = conn.root.get_module(module_id) self._cache[name] = (conn, module) return module @@ -105,3 +108,16 @@ def close(self) -> None: self._coord_conn.close() except Exception: pass + + +def _rpyc_connect(host: str, port: int, **kwargs: Any) -> rpyc.Connection: + deadline = time.monotonic() + _CONNECT_RETRY_DEADLINE_S + delay = 0.010 + while True: + try: + return rpyc.connect(host, port, **kwargs) + except ConnectionRefusedError: + if time.monotonic() >= deadline: + raise + time.sleep(delay) + delay = min(delay * 2, 0.200) From 
388752d88c076cea634c0d599d05096cb3a40fbc Mon Sep 17 00:00:00 2001 From: RD <63036454+ruthwikdasyam@users.noreply.github.com> Date: Wed, 22 Apr 2026 20:10:30 -0700 Subject: [PATCH 13/30] feat(go2): rage mode via webrtc (#1903) --- dimos/protocol/pubsub/test_spec.py | 1 + dimos/robot/all_blueprints.py | 1 + dimos/robot/unitree/connection.py | 25 +++++++++++ ...unitree_go2_webrtc_rage_keyboard_teleop.py | 41 +++++++++++++++++ dimos/robot/unitree/go2/connection.py | 31 +++++++++++++ dimos/robot/unitree/keyboard_teleop.py | 44 ++++++++++++++----- dimos/robot/unitree/mujoco_connection.py | 3 ++ 7 files changed, 134 insertions(+), 12 deletions(-) create mode 100644 dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_rage_keyboard_teleop.py diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index 0e61132c1c..0907e662d5 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -282,6 +282,7 @@ async def consume_messages() -> None: @pytest.mark.slow +@pytest.mark.skipif_macos_bug @pytest.mark.parametrize("pubsub_context, topic, values", testdata) def test_high_volume_messages( pubsub_context: Callable[[], Any], topic: Any, values: list[Any] diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 74a45f5e8f..3c8cc80f82 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -97,6 +97,7 @@ "unitree-go2-temporal-memory": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_temporal_memory:unitree_go2_temporal_memory", "unitree-go2-vlm-stream-test": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_vlm_stream_test:unitree_go2_vlm_stream_test", "unitree-go2-webrtc-keyboard-teleop": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_webrtc_keyboard_teleop:unitree_go2_webrtc_keyboard_teleop", + "unitree-go2-webrtc-rage-keyboard-teleop": 
"dimos.robot.unitree.go2.blueprints.basic.unitree_go2_webrtc_rage_keyboard_teleop:unitree_go2_webrtc_rage_keyboard_teleop", "unity-sim": "dimos.simulation.unity.blueprint:unity_sim", "xarm-perception": "dimos.manipulation.blueprints:xarm_perception", "xarm-perception-agent": "dimos.manipulation.blueprints:xarm_perception_agent", diff --git a/dimos/robot/unitree/connection.py b/dimos/robot/unitree/connection.py index 1add8595f1..85dd01ef55 100644 --- a/dimos/robot/unitree/connection.py +++ b/dimos/robot/unitree/connection.py @@ -79,6 +79,8 @@ def to_ndarray(self, format=None): # type: ignore[no-untyped-def] class UnitreeWebRTCConnection(Resource): + _SPORT_API_ID_RAGEMODE: int = 2059 + def __init__(self, ip: str, mode: str = "ai") -> None: self.ip = ip self.mode = mode @@ -296,6 +298,29 @@ def free_walk(self) -> bool: """Activate FreeWalk locomotion mode — enables walking and velocity commands.""" return bool(self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["FreeWalk"]})) + def enable_rage_mode(self) -> bool: + """Enable Rage Mode on the Go2 via WebRTC. + Assumes the robot is already in BalanceStand. 
+ """ + rage_ok = bool( + self.publish_request( + RTC_TOPIC["SPORT_MOD"], + {"api_id": self._SPORT_API_ID_RAGEMODE, "parameter": {"data": True}}, + ) + ) + time.sleep(2.0) + + joystick_ok = bool( + self.publish_request( + RTC_TOPIC["SPORT_MOD"], + { + "api_id": SPORT_CMD["SwitchJoystick"], + "parameter": {"data": True}, + }, + ) + ) + return rage_ok and joystick_ok + def liedown(self) -> bool: return bool( self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_rage_keyboard_teleop.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_rage_keyboard_teleop.py new file mode 100644 index 0000000000..c2fd9a3b87 --- /dev/null +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_rage_keyboard_teleop.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unitree Go2 keyboard teleop via WebRTC with Rage Mode enabled. + +Same topology as unitree-go2-webrtc-keyboard-teleop, but GO2Connection is +configured with mode="rage" so FsmRageMode is toggled on after +BalanceStand at connection start.
+ +Usage: + dimos run unitree-go2-webrtc-rage-keyboard-teleop +""" + +from __future__ import annotations + +from dimos.core.coordination.blueprints import autoconnect +from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_webrtc_keyboard_teleop import ( + unitree_go2_webrtc_keyboard_teleop, +) +from dimos.robot.unitree.go2.connection import GO2Connection +from dimos.robot.unitree.keyboard_teleop import KeyboardTeleop + +unitree_go2_webrtc_rage_keyboard_teleop = autoconnect( + unitree_go2_webrtc_keyboard_teleop, + GO2Connection.blueprint(mode="rage"), + KeyboardTeleop.blueprint(linear_speed=1.25, angular_speed=1.2), +).global_config(obstacle_avoidance=True) + +__all__ = ["unitree_go2_webrtc_rage_keyboard_teleop"] diff --git a/dimos/robot/unitree/go2/connection.py b/dimos/robot/unitree/go2/connection.py index b994f0cae0..1b91cd7a27 100644 --- a/dimos/robot/unitree/go2/connection.py +++ b/dimos/robot/unitree/go2/connection.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum import logging import sys from threading import Thread @@ -56,8 +57,14 @@ logger = logging.getLogger(__name__) +class Go2Mode(str, Enum): + DEFAULT = "default" + RAGE = "rage" + + class ConnectionConfig(ModuleConfig): ip: str = Field(default_factory=lambda m: m["g"].robot_ip) + mode: Go2Mode = Go2Mode.DEFAULT class Go2ConnectionProtocol(Protocol): @@ -73,6 +80,7 @@ def standup(self) -> bool: ... def liedown(self) -> bool: ... def balance_stand(self) -> bool: ... def set_obstacle_avoidance(self, enabled: bool = True) -> None: ... + def enable_rage_mode(self) -> bool: ... def publish_request(self, topic: str, data: dict) -> dict: ... 
# type: ignore[type-arg] @@ -142,6 +150,9 @@ def balance_stand(self) -> bool: def set_obstacle_avoidance(self, enabled: bool = True) -> None: pass + def enable_rage_mode(self) -> bool: + return True + @simple_mcache def lidar_stream(self): # type: ignore[no-untyped-def] lidar_store = TimedSensorReplay(f"{self.dir_name}/lidar") # type: ignore[var-annotated] @@ -251,6 +262,10 @@ def onimage(image: Image) -> None: self.standup() time.sleep(3) self.connection.balance_stand() + + if self.config.mode == Go2Mode.RAGE: + self.connection.enable_rage_mode() + self.connection.set_obstacle_avoidance(self.config.g.obstacle_avoidance) # self.record("go2_bigoffice") @@ -317,6 +332,22 @@ def liedown(self) -> bool: """Make the robot lie down.""" return self.connection.liedown() + @rpc + def balance_stand(self) -> bool: + """Enter BalanceStand: neutral state for switching locomotion modes""" + return self.connection.balance_stand() + + @rpc + def enable_rage_mode(self) -> bool: + """Enable Rage Mode (~2.5 m/s forward velocity envelope). + Ensures BalanceStand precondition regardless of current FSM state. + """ + self.connection.balance_stand() + time.sleep(0.3) + result = self.connection.enable_rage_mode() + logger.info("Rage Mode enabled") + return result + @rpc def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: """Publish a request to the WebRTC connection. diff --git a/dimos/robot/unitree/keyboard_teleop.py b/dimos/robot/unitree/keyboard_teleop.py index 4ebf6e3cce..e3c78ecc52 100644 --- a/dimos/robot/unitree/keyboard_teleop.py +++ b/dimos/robot/unitree/keyboard_teleop.py @@ -29,11 +29,20 @@ # Force X11 driver to avoid OpenGL threading issues os.environ["SDL_VIDEODRIVER"] = "x11" +DEFAULT_LINEAR_SPEED: float = 0.5 # m/s +DEFAULT_ANGULAR_SPEED: float = 0.8 # rad/s +DEFAULT_BOOST_MULTIPLIER: float = 2.0 +DEFAULT_SLOW_MULTIPLIER: float = 0.5 + class KeyboardTeleop(Module): """Pygame-based keyboard control module. 
Outputs standard Twist messages on /cmd_vel for velocity control. + + Speed constants can be tuned at the top of this file, or overridden + per-instance by passing linear_speed / angular_speed / + boost_multiplier / slow_multiplier to the constructor. """ cmd_vel: Out[Twist] # Standard velocity commands @@ -45,9 +54,20 @@ class KeyboardTeleop(Module): _clock: pygame.time.Clock | None = None _font: pygame.font.Font | None = None - def __init__(self, **kwargs: Any) -> None: + def __init__( + self, + linear_speed: float = DEFAULT_LINEAR_SPEED, + angular_speed: float = DEFAULT_ANGULAR_SPEED, + boost_multiplier: float = DEFAULT_BOOST_MULTIPLIER, + slow_multiplier: float = DEFAULT_SLOW_MULTIPLIER, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self._stop_event = threading.Event() + self.linear_speed = linear_speed + self.angular_speed = angular_speed + self.boost_multiplier = boost_multiplier + self.slow_multiplier = slow_multiplier @rpc def start(self) -> None: @@ -115,28 +135,28 @@ def _pygame_loop(self) -> None: # Forward/backward (W/S) if pygame.K_w in self._keys_held: - twist.linear.x = 0.5 + twist.linear.x = self.linear_speed if pygame.K_s in self._keys_held: - twist.linear.x = -0.5 + twist.linear.x = -self.linear_speed # Strafe left/right (Q/E) if pygame.K_q in self._keys_held: - twist.linear.y = 0.5 + twist.linear.y = self.linear_speed if pygame.K_e in self._keys_held: - twist.linear.y = -0.5 + twist.linear.y = -self.linear_speed # Turning (A/D) if pygame.K_a in self._keys_held: - twist.angular.z = 0.8 + twist.angular.z = self.angular_speed if pygame.K_d in self._keys_held: - twist.angular.z = -0.8 + twist.angular.z = -self.angular_speed - # Apply speed modifiers (Shift = 2x, Ctrl = 0.5x) + # Apply speed modifiers (Shift = boost, Ctrl = slow) speed_multiplier = 1.0 if pygame.K_LSHIFT in self._keys_held or pygame.K_RSHIFT in self._keys_held: - speed_multiplier = 2.0 + speed_multiplier = self.boost_multiplier elif pygame.K_LCTRL in self._keys_held or 
pygame.K_RCTRL in self._keys_held: - speed_multiplier = 0.5 + speed_multiplier = self.slow_multiplier twist.linear.x *= speed_multiplier twist.linear.y *= speed_multiplier @@ -165,9 +185,9 @@ def _update_display(self, twist: Twist) -> None: # Determine active speed multiplier speed_mult_text = "" if pygame.K_LSHIFT in self._keys_held or pygame.K_RSHIFT in self._keys_held: - speed_mult_text = " [BOOST 2x]" + speed_mult_text = f" [BOOST {self.boost_multiplier:g}x]" elif pygame.K_LCTRL in self._keys_held or pygame.K_RCTRL in self._keys_held: - speed_mult_text = " [SLOW 0.5x]" + speed_mult_text = f" [SLOW {self.slow_multiplier:g}x]" texts = [ "Keyboard Teleop" + speed_mult_text, diff --git a/dimos/robot/unitree/mujoco_connection.py b/dimos/robot/unitree/mujoco_connection.py index 03d15db756..39c0904684 100644 --- a/dimos/robot/unitree/mujoco_connection.py +++ b/dimos/robot/unitree/mujoco_connection.py @@ -241,6 +241,9 @@ def balance_stand(self) -> bool: def set_obstacle_avoidance(self, enabled: bool = True) -> None: pass + def enable_rage_mode(self) -> bool: + return True + def get_video_frame(self) -> NDArray[Any] | None: if self.shm_data is None: return None From bf65b4a3a8fb4ab2967119efe6701dd9f75b575c Mon Sep 17 00:00:00 2001 From: Andrew Lauer <69774903+aclauer@users.noreply.github.com> Date: Thu, 23 Apr 2026 18:48:19 -0700 Subject: [PATCH 14/30] feat:rust native modules (#1794) Co-authored-by: leshy --- dimos/core/native_module.py | 29 +- examples/native-modules/rust/.gitignore | 1 + examples/native-modules/rust/Cargo.lock | 305 +++++++++++++ examples/native-modules/rust/Cargo.toml | 18 + .../native-modules/rust/src/native_ping.rs | 40 ++ .../native-modules/rust/src/native_pong.rs | 49 +++ examples/native-modules/rust_ping_pong.py | 76 ++++ native/rust/.gitignore | 1 + native/rust/Cargo.lock | 296 +++++++++++++ native/rust/Cargo.toml | 15 + native/rust/src/lcm.rs | 29 ++ native/rust/src/lib.rs | 10 + native/rust/src/module.rs | 416 ++++++++++++++++++ 
native/rust/src/transport.rs | 13 + 14 files changed, 1293 insertions(+), 5 deletions(-) create mode 100644 examples/native-modules/rust/.gitignore create mode 100644 examples/native-modules/rust/Cargo.lock create mode 100644 examples/native-modules/rust/Cargo.toml create mode 100644 examples/native-modules/rust/src/native_ping.rs create mode 100644 examples/native-modules/rust/src/native_pong.rs create mode 100644 examples/native-modules/rust_ping_pong.py create mode 100644 native/rust/.gitignore create mode 100644 native/rust/Cargo.lock create mode 100644 native/rust/Cargo.toml create mode 100644 native/rust/src/lcm.rs create mode 100644 native/rust/src/lib.rs create mode 100644 native/rust/src/module.rs create mode 100644 native/rust/src/transport.rs diff --git a/dimos/core/native_module.py b/dimos/core/native_module.py index 6cc918776e..4e2ec0c699 100644 --- a/dimos/core/native_module.py +++ b/dimos/core/native_module.py @@ -84,15 +84,25 @@ class NativeModuleConfig(ModuleConfig): shutdown_timeout: float = 10.0 log_format: LogFormat = LogFormat.TEXT + # New version of Native Modules read json configs from stdin + # Enable this to read from stdin instead of cli args + stdin_config: bool = False + # Override in subclasses to exclude fields from CLI arg generation cli_exclude: frozenset[str] = frozenset() - def to_cli_args(self) -> list[str]: - """Auto-convert subclass config fields to CLI args. + def to_config_dict(self) -> dict[str, Any]: + """ + Return module-specific config fields as a plain dict (for stdin JSON). + """ + ignore_fields = set(NativeModuleConfig.model_fields) + return { + k: v for k, v in self.model_dump().items() if k not in ignore_fields and v is not None + } - Iterates fields defined on the concrete subclass (not NativeModuleConfig - or its parents) and converts them to ``["--name", str(value)]`` pairs. - Skips fields whose values are ``None`` and fields in ``cli_exclude``. 
+ def to_cli_args(self) -> list[str]: + """ + Auto-convert subclass config fields to CLI args. """ ignore_fields = {f for f in NativeModuleConfig.model_fields} args: list[str] = [] @@ -172,9 +182,18 @@ def start(self) -> None: cmd, env=env, cwd=cwd, + stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) + assert self._process.stdin is not None + if self.config.stdin_config: + config_dict = self.config.to_config_dict() + stdin_blob = ( + json.dumps({"topics": topics, "config": config_dict or None}).encode() + b"\n" + ) + self._process.stdin.write(stdin_blob) + self._process.stdin.close() logger.info( f"Native process started: {module_name}", module=module_name, diff --git a/examples/native-modules/rust/.gitignore b/examples/native-modules/rust/.gitignore new file mode 100644 index 0000000000..2f7896d1d1 --- /dev/null +++ b/examples/native-modules/rust/.gitignore @@ -0,0 +1 @@ +target/ diff --git a/examples/native-modules/rust/Cargo.lock b/examples/native-modules/rust/Cargo.lock new file mode 100644 index 0000000000..420f9b0ef4 --- /dev/null +++ b/examples/native-modules/rust/Cargo.lock @@ -0,0 +1,305 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 4 + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "dimos-lcm" +version = "0.1.0" +source = "git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#fd2e7e2d28597b34dce1d92d3065796a1b722590" +dependencies = [ + "byteorder", + "socket2 0.5.10", + "tokio", +] + +[[package]] +name = "dimos-native-module" +version = "0.1.0" +dependencies = [ + "dimos-lcm", + "serde", + "serde_json", + "tokio", +] + +[[package]] +name = "dimos-native-module-examples" +version = "0.1.0" +dependencies = [ + "dimos-native-module", + "lcm-msgs", + "serde", + "tokio", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "lcm-msgs" +version = "0.1.0" +source = "git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#fd2e7e2d28597b34dce1d92d3065796a1b722590" +dependencies = [ + "byteorder", +] + +[[package]] +name = "libc" +version = "0.2.185" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "socket2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tokio" +version = "1.52.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" +dependencies = [ + "libc", + "mio", + "pin-project-lite", + "socket2 0.6.3", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/examples/native-modules/rust/Cargo.toml b/examples/native-modules/rust/Cargo.toml new file mode 100644 index 0000000000..65b9a8244f --- /dev/null +++ b/examples/native-modules/rust/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "dimos-native-module-examples" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "native_ping" +path = "src/native_ping.rs" + +[[bin]] +name = "native_pong" +path = "src/native_pong.rs" + +[dependencies] +dimos-native-module = { path = "../../../native/rust" } +lcm-msgs = { git = "https://github.com/dimensionalOS/dimos-lcm.git", branch = "rust-codegen" } +tokio = { version = "1", features = ["rt-multi-thread", "macros", "time"] } +serde = { version = "1", features = ["derive"] } diff --git a/examples/native-modules/rust/src/native_ping.rs b/examples/native-modules/rust/src/native_ping.rs new file mode 100644 index 0000000000..ab19c78b28 --- /dev/null +++ b/examples/native-modules/rust/src/native_ping.rs @@ -0,0 +1,40 @@ +// NativeModule ping example. +// +// Sends a Twist message at 5 Hz and logs each echo received on `confirm`. 
+ +use dimos_native_module::{LcmTransport, NativeModule}; +use lcm_msgs::geometry_msgs::{Twist, Vector3}; +use tokio::time::{interval, Duration}; + +#[tokio::main] +async fn main() { + let transport = LcmTransport::new() + .await + .expect("Failed to create transport"); + let (mut module, _config) = NativeModule::from_stdin::<()>(transport) + .await + .expect("Failed to read config from stdin"); + + let mut confirm = module.input("confirm", Twist::decode); + let data = module.output("data", Twist::encode); + let _handle = module.spawn(); + + let mut ticker = interval(Duration::from_millis(200)); + let mut seq = 0u64; + + loop { + tokio::select! { + _ = ticker.tick() => { + let msg = Twist { + linear: Vector3 { x: seq as f64, y: 0.0, z: 0.0 }, + angular: Vector3 { x: 0.0, y: 0.0, z: 0.0 }, + }; + data.publish(&msg).await.ok(); + seq += 1; + } + Some(echo) = confirm.recv() => { + eprintln!("ping: echo received (seq={}, sample_config={})", echo.linear.x as u64, echo.angular.z as i64); + } + } + } +} diff --git a/examples/native-modules/rust/src/native_pong.rs b/examples/native-modules/rust/src/native_pong.rs new file mode 100644 index 0000000000..74109eb4ba --- /dev/null +++ b/examples/native-modules/rust/src/native_pong.rs @@ -0,0 +1,49 @@ +// NativeModule pong example. +// +// Receives Twist messages on `data` and echoes each one back on `confirm`, +// embedding the sample_config value in the reply's angular.z field. 
+ +use dimos_native_module::{LcmTransport, NativeModule}; +use lcm_msgs::geometry_msgs::{Twist, Vector3}; +use serde::Deserialize; + +#[derive(Debug, Deserialize, Default)] +#[serde(deny_unknown_fields)] +struct PongConfig { + sample_config: i64, +} + +#[tokio::main] +async fn main() { + let transport = LcmTransport::new() + .await + .expect("Failed to create transport"); + let (mut module, config) = NativeModule::from_stdin::(transport) + .await + .expect("Failed to read config from stdin"); + + eprintln!("pong: sample_config={}", config.sample_config); + + let mut data = module.input("data", Twist::decode); + let confirm = module.output("confirm", Twist::encode); + let _handle = module.spawn(); + + eprintln!("pong ready"); + + loop { + match data.recv().await { + Some(msg) => { + let reply = Twist { + linear: msg.linear, + angular: Vector3 { + x: 0.0, + y: 0.0, + z: config.sample_config as f64, + }, + }; + confirm.publish(&reply).await.ok(); + } + None => break, + } + } +} diff --git a/examples/native-modules/rust_ping_pong.py b/examples/native-modules/rust_ping_pong.py new file mode 100644 index 0000000000..3672418050 --- /dev/null +++ b/examples/native-modules/rust_ping_pong.py @@ -0,0 +1,76 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Two Rust NativeModules for a simple ping-pong example. + +PingModule and PongModule both declare a `data` port (Twist) and a `confirm` port (Twist). 
+Ping publishes to data (received by pong), and Pong publishes to confirm (received by ping). +Topics and module configs are sent through stdin to modules. + +Run with: + python examples/native-modules/rust_ping_pong.py +""" + +from __future__ import annotations + +from pathlib import Path + +from dimos.core.coordination.blueprints import autoconnect +from dimos.core.coordination.module_coordinator import ModuleCoordinator +from dimos.core.native_module import NativeModule, NativeModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.Twist import Twist + +_RUST_DIR = Path(__file__).parent / "rust" +_EXAMPLES = _RUST_DIR / "target" / "release" + + +class PingConfig(NativeModuleConfig): + executable: str = str(_EXAMPLES / "native_ping") + build_command: str = "cargo build --release" + cwd: str = str(_RUST_DIR) + stdin_config: bool = True + + +class PongConfig(NativeModuleConfig): + executable: str = str(_EXAMPLES / "native_pong") + build_command: str = "cargo build --release" + cwd: str = str(_RUST_DIR) + stdin_config: bool = True + sample_config: int = 42 + + +class PingModule(NativeModule): + """Publishes Twist messages at 5 Hz on `data` and logs echoes from `confirm`.""" + + config: PingConfig + data: Out[Twist] + confirm: In[Twist] + + +class PongModule(NativeModule): + """Echoes every received Twist message back.""" + + config: PongConfig + data: In[Twist] + confirm: Out[Twist] + + +if __name__ == "__main__": + ModuleCoordinator.build( + autoconnect( + PingModule.blueprint(), + PongModule.blueprint(), + ).global_config(viewer="none") + ).loop() diff --git a/native/rust/.gitignore b/native/rust/.gitignore new file mode 100644 index 0000000000..2f7896d1d1 --- /dev/null +++ b/native/rust/.gitignore @@ -0,0 +1 @@ +target/ diff --git a/native/rust/Cargo.lock b/native/rust/Cargo.lock new file mode 100644 index 0000000000..45982487ec --- /dev/null +++ b/native/rust/Cargo.lock @@ -0,0 +1,296 @@ +# This file is automatically @generated by 
Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "dimos-lcm" +version = "0.1.0" +source = "git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#50538b1372d6e06fdb0399abc6a35c2aa650a72f" +dependencies = [ + "byteorder", + "socket2 0.5.10", + "tokio", +] + +[[package]] +name = "dimos-native-module" +version = "0.1.0" +dependencies = [ + "dimos-lcm", + "lcm-msgs", + "serde", + "serde_json", + "tokio", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "lcm-msgs" +version = "0.1.0" +source = "git+https://github.com/dimensionalOS/dimos-lcm.git?branch=rust-codegen#50538b1372d6e06fdb0399abc6a35c2aa650a72f" +dependencies = [ + "byteorder", +] + +[[package]] +name = "libc" +version = "0.2.185" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ff2c0fe9bc6cb6b14a0592c2ff4fa9ceb83eea9db979b0487cd054946a2b8f" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "socket2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "syn" +version = 
"2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tokio" +version = "1.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a91135f59b1cbf38c91e73cf3386fca9bb77915c45ce2771460c9d92f0f3d776" +dependencies = [ + "libc", + "mio", + "pin-project-lite", + "socket2 0.6.3", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/native/rust/Cargo.toml b/native/rust/Cargo.toml new file mode 100644 index 0000000000..e3e24a6ad3 --- /dev/null +++ b/native/rust/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "dimos-native-module" +version = "0.1.0" +edition = "2021" +description = "Rust native module SDK for dimos NativeModule framework" +license = "Apache-2.0" + +[dependencies] +dimos-lcm = { git = "https://github.com/dimensionalOS/dimos-lcm.git", branch = "rust-codegen" } +tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "time"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +[dev-dependencies] +lcm-msgs = { git = "https://github.com/dimensionalOS/dimos-lcm.git", branch = "rust-codegen" } diff --git a/native/rust/src/lcm.rs b/native/rust/src/lcm.rs new file mode 100644 index 0000000000..a4fbd027f4 --- /dev/null +++ b/native/rust/src/lcm.rs @@ -0,0 +1,29 @@ +use std::io; + +use dimos_lcm::{Lcm, LcmOptions}; + +use crate::transport::Transport; + +/// LCM UDP multicast transport. Wraps `dimos_lcm::Lcm`. 
+pub struct LcmTransport(Lcm); + +impl LcmTransport { + pub async fn new() -> io::Result { + Ok(Self(Lcm::new().await?)) + } + + pub async fn with_options(opts: LcmOptions) -> io::Result { + Ok(Self(Lcm::with_options(opts).await?)) + } +} + +impl Transport for LcmTransport { + async fn publish(&self, channel: &str, data: &[u8]) -> io::Result<()> { + self.0.publish(channel, data).await + } + + async fn recv(&mut self) -> io::Result<(String, Vec)> { + let msg = self.0.recv().await?; + Ok((msg.channel, msg.data)) + } +} diff --git a/native/rust/src/lib.rs b/native/rust/src/lib.rs new file mode 100644 index 0000000000..d98866417f --- /dev/null +++ b/native/rust/src/lib.rs @@ -0,0 +1,10 @@ +pub mod lcm; +pub mod module; +pub mod transport; + +pub use lcm::LcmTransport; +pub use module::{Input, NativeModule, NativeModuleHandle, Output}; +pub use transport::Transport; + +// Re-export LcmOptions so callers don't need to depend on dimos-lcm directly. +pub use dimos_lcm::LcmOptions; diff --git a/native/rust/src/module.rs b/native/rust/src/module.rs new file mode 100644 index 0000000000..37ac2bd5e7 --- /dev/null +++ b/native/rust/src/module.rs @@ -0,0 +1,416 @@ +use std::collections::HashMap; +use std::io::{self, BufRead}; +use tokio::sync::mpsc; + +use serde::de::DeserializeOwned; + +use crate::transport::Transport; + +const INPUT_CHANNEL_CAPACITY: usize = 16; +const PUBLISH_CHANNEL_CAPACITY: usize = 64; + +// Each input() call produces a TypedRoute that decodes its message type +// and forwards it to the right Input's mpsc channel. +trait Route: Send { + fn topic(&self) -> &str; + fn try_dispatch(&self, data: &[u8]); +} + +struct TypedRoute { + topic: String, + decode: fn(&[u8]) -> io::Result, + sender: mpsc::Sender, +} + +impl Route for TypedRoute { + fn topic(&self) -> &str { + &self.topic + } + + fn try_dispatch(&self, data: &[u8]) { + match (self.decode)(data) { + // If the input channel is full, the newest message is dropped. 
+ Ok(msg) => { + let _ = self.sender.try_send(msg); + } + Err(e) => eprintln!("dimos_module: decode error on {}: {e}", self.topic), + } + } +} +pub struct Input { + pub topic: String, + receiver: mpsc::Receiver, +} + +impl Input { + pub async fn recv(&mut self) -> Option { + self.receiver.recv().await + } +} + +pub struct Output { + pub topic: String, + encode: fn(&T) -> Vec, + sender: mpsc::Sender<(String, Vec)>, +} + +impl Output { + pub async fn publish(&self, msg: &T) -> io::Result<()> { + let data = (self.encode)(msg); + self.sender + .send((self.topic.clone(), data)) + .await + .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "background task gone")) + } +} + +/// Parse a JSON config line as written by the Python NativeModule coordinator. +/// Returns `(topics, config)`. Extracted so it can be unit-tested without stdin. +fn parse_config_json(line: &str) -> io::Result<(HashMap, C)> { + let json: serde_json::Value = serde_json::from_str(line.trim()) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + let mut topics = HashMap::new(); + if let Some(t) = json.get("topics").and_then(|v| v.as_object()) { + for (port, topic) in t { + if let Some(s) = topic.as_str() { + topics.insert(port.clone(), s.to_string()); + } + } + } + + let config: C = match json.get("config") { + None => return Err(io::Error::new( + io::ErrorKind::InvalidData, + "missing 'config' field in stdin JSON — coordinator must always send a config object", + )), + Some(v) => serde_json::from_value(v.clone()).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("failed to deserialize config: {e}"), + ) + })?, + }; + + Ok((topics, config)) +} + +/// High-level wrapper around a transport for use in dimos native modules. +/// +/// Generic over any `T: Transport`. Use `LcmTransport` for the standard LCM +/// UDP multicast transport. 
+/// +/// # Usage +/// +/// ```ignore +/// let transport = LcmTransport::new().await?; +/// let (mut module, config) = NativeModule::from_stdin::(transport).await?; +/// +/// let mut image_in = module.input("color_image", Image::decode); +/// let cmd_out = module.output("cmd_vel", Twist::encode); +/// let _handle = module.spawn(); +/// +/// loop { +/// tokio::select! { +/// Some(frame) = image_in.recv() => { cmd_out.publish(&twist).await.ok(); } +/// } +/// } +/// ``` +pub struct NativeModule { + transport: T, + routes: Vec>, + topics: HashMap, + publish_tx: mpsc::Sender<(String, Vec)>, + publish_rx: mpsc::Receiver<(String, Vec)>, +} + +impl NativeModule { + pub(crate) fn new(transport: T) -> Self { + let (publish_tx, publish_rx) = mpsc::channel(PUBLISH_CHANNEL_CAPACITY); + Self { + transport, + routes: Vec::new(), + topics: HashMap::new(), + publish_tx, + publish_rx, + } + } + + /// Parse `--port_name topic_string` pairs from argv, as injected by NativeModule. + pub async fn from_args(transport: T) -> io::Result { + let mut module = Self::new(transport); + let args: Vec = std::env::args().collect(); + let mut i = 1; + while i < args.len() { + if let Some(port) = args[i].strip_prefix("--") { + if i + 1 < args.len() && !args[i + 1].starts_with("--") { + module.topics.insert(port.to_string(), args[i + 1].clone()); + i += 2; + continue; + } + } + i += 1; + } + Ok(module) + } + + /// Read config from a single JSON line on stdin, as written by the Python NativeModule declaration. + /// + /// The JSON format is: + /// ```json + /// {"topics": {"port_name": "lcm/topic", ...}, "config": { ... }} + /// ``` + /// + /// `C` is the module-specific config type. Use `()` for modules with no configuration. 
+ pub async fn from_stdin( + transport: T, + ) -> io::Result<(Self, C)> { + let mut line = String::new(); + io::stdin().lock().read_line(&mut line)?; + + let (topics, config) = parse_config_json::(&line)?; + + let mut module = Self::new(transport); + module.topics = topics; + + let exe = std::env::current_exe() + .ok() + .and_then(|p| p.file_name().map(|n| n.to_string_lossy().into_owned())) + .unwrap_or_else(|| "unknown".to_string()); + eprintln!("[{exe}] topics received:"); + for (port, topic) in &module.topics { + eprintln!(" {port} -> {topic}"); + } + eprintln!("[{exe}] config: {config:?}"); + + Ok((module, config)) + } + + /// Manually set a topic for a port — useful for testing without a parent process. + pub fn map_topic(&mut self, port: &str, topic: &str) { + self.topics.insert(port.to_string(), topic.to_string()); + } + + fn topic_for(&self, port: &str) -> String { + self.topics + .get(port) + .cloned() + .unwrap_or_else(|| format!("/{port}")) + } + + /// Register an input port. Must be called before `spawn()`. + pub fn input( + &mut self, + port: &str, + decode: fn(&[u8]) -> io::Result, + ) -> Input { + let topic = self.topic_for(port); + let (tx, rx) = mpsc::channel(INPUT_CHANNEL_CAPACITY); + self.routes.push(Box::new(TypedRoute { + topic: topic.clone(), + decode, + sender: tx, + })); + Input { + topic, + receiver: rx, + } + } + + /// Register an output port. Must be called before `spawn()`. + pub fn output(&self, port: &str, encode: fn(&M) -> Vec) -> Output { + Output { + topic: self.topic_for(port), + encode, + sender: self.publish_tx.clone(), + } + } + + /// Start the background recv/dispatch/publish loop. + /// + /// Consumes the module — no new ports can be registered after this point. + pub fn spawn(self) -> NativeModuleHandle { + let NativeModule { + mut transport, + routes, + mut publish_rx, + .. + } = self; + + let handle = tokio::spawn(async move { + loop { + tokio::select! 
{ + result = transport.recv() => match result { + Ok((channel, data)) => { + for route in &routes { + if route.topic() == channel { + route.try_dispatch(&data); + } + } + } + Err(e) => { + eprintln!("dimos_module: recv error: {e}"); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + } + }, + Some((topic, data)) = publish_rx.recv() => { + if let Err(e) = transport.publish(&topic, &data).await { + eprintln!("dimos_module: publish error on {topic}: {e}"); + } + } + } + } + }); + + NativeModuleHandle(handle) + } +} + +pub struct NativeModuleHandle(tokio::task::JoinHandle<()>); + +impl NativeModuleHandle { + pub async fn join(self) -> Result<(), tokio::task::JoinError> { + self.0.await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::Deserialize; + + struct MockTransport; + + impl crate::transport::Transport for MockTransport { + async fn publish(&self, _channel: &str, _data: &[u8]) -> io::Result<()> { + Ok(()) + } + async fn recv(&mut self) -> io::Result<(String, Vec)> { + std::future::pending().await + } + } + + #[derive(Debug, Deserialize, Default, PartialEq)] + #[serde(deny_unknown_fields)] + struct TestConfig { + value: i64, + name: String, + } + + // --- parse_config_json --- + + #[test] + fn parses_topics_and_config() { + let json = r#"{"topics": {"data": "/foo/data", "confirm": "/foo/confirm"}, "config": {"value": 42, "name": "hello"}}"#; + let (topics, config) = parse_config_json::(json).unwrap(); + assert_eq!(topics["data"], "/foo/data"); + assert_eq!(topics["confirm"], "/foo/confirm"); + assert_eq!( + config, + TestConfig { + value: 42, + name: "hello".into() + } + ); + } + + #[test] + fn missing_config_field_returns_error() { + let json = r#"{"topics": {"data": "/foo/data"}}"#; + let result = parse_config_json::(json); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("missing 'config' field")); + } + + #[test] + fn null_config_succeeds_for_unit_type() { + let json = r#"{"topics": {}, 
"config": null}"#; + let (_topics, _config) = parse_config_json::<()>(json).unwrap(); + } + + #[test] + fn null_config_errors_when_struct_expects_fields() { + let json = r#"{"topics": {}, "config": null}"#; + let result = parse_config_json::(json); + assert!(result.is_err()); + } + + #[test] + fn empty_config_object_errors_when_struct_expects_fields() { + let json = r#"{"topics": {}, "config": {}}"#; + let result = parse_config_json::(json); + assert!(result.is_err()); + } + + #[test] + fn config_with_wrong_type_returns_error() { + let json = r#"{"topics": {}, "config": {"value": "not_a_number", "name": "x"}}"#; + let result = parse_config_json::(json); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("failed to deserialize config")); + } + + #[test] + fn missing_topics_field_gives_empty_map() { + let json = r#"{"config": {"value": 1, "name": "x"}}"#; + let (topics, _config) = parse_config_json::(json).unwrap(); + assert!(topics.is_empty()); + } + + #[test] + fn malformed_json_returns_error() { + let result = parse_config_json::<()>("not json at all"); + assert!(result.is_err()); + } + + #[test] + fn unknown_config_field_returns_error() { + let json = r#"{"topics": {}, "config": {"value": 1, "name": "x", "unexpected": true}}"#; + let result = parse_config_json::(json); + assert!(result.is_err()); + } + + // --- topic_for / map_topic --- + + #[test] + fn unmapped_port_falls_back_to_slash_port() { + let module = NativeModule::new(MockTransport); + assert_eq!(module.topic_for("cmd_vel"), "/cmd_vel"); + } + + #[test] + fn map_topic_overrides_fallback() { + let mut module = NativeModule::new(MockTransport); + module.map_topic("cmd_vel", "/robot/cmd_vel"); + assert_eq!(module.topic_for("cmd_vel"), "/robot/cmd_vel"); + } + + #[test] + fn input_uses_mapped_topic() { + let mut module = NativeModule::new(MockTransport); + module.map_topic("data", "/test/data"); + let input = module.input("data", |b| Ok(b.to_vec())); + 
assert_eq!(input.topic, "/test/data"); + } + + #[test] + fn input_falls_back_to_slash_port_when_unmapped() { + let mut module = NativeModule::new(MockTransport); + let input = module.input("data", |b| Ok(b.to_vec())); + assert_eq!(input.topic, "/data"); + } + + #[test] + fn output_uses_mapped_topic() { + let mut module = NativeModule::new(MockTransport); + module.map_topic("cmd_vel", "/robot/cmd_vel"); + let output = module.output("cmd_vel", |b: &Vec| b.clone()); + assert_eq!(output.topic, "/robot/cmd_vel"); + } +} diff --git a/native/rust/src/transport.rs b/native/rust/src/transport.rs new file mode 100644 index 0000000000..0322f52681 --- /dev/null +++ b/native/rust/src/transport.rs @@ -0,0 +1,13 @@ +use std::future::Future; +use std::io; + +/// Abstraction over the message transport used by a native module. +/// +/// New transport protocols should implement this trait. +/// `NativeModule` is generic over any transport +pub trait Transport: Send + 'static { + /// Send `data` on `channel`. + fn publish(&self, channel: &str, data: &[u8]) -> impl Future> + Send; + /// Block until the next inbound message, returning `(channel, data)`. 
+ fn recv(&mut self) -> impl Future)>> + Send; +} From bd8e63e643577bc62585883131ec7634e23c8f96 Mon Sep 17 00:00:00 2001 From: leshy Date: Fri, 24 Apr 2026 21:06:34 +0300 Subject: [PATCH 15/30] Feat/memory2 - plotting, examples, recorder module, semantic search (#1769) Co-authored-by: RD <63036454+ruthwikdasyam@users.noreply.github.com> --- .gitattributes | 1 + .gitignore | 3 + README.md | 2 +- bin/hooks/lfs_check | 18 +- data/.lfs/go2_bigoffice.db.tar.gz | 4 +- dimos/core/global_config.py | 2 +- dimos/core/test_modules.py | 331 -------------- .../camera/gstreamer/gstreamer_camera.py | 4 +- .../camera/gstreamer/gstreamer_sender.py | 4 +- dimos/hardware/sensors/fake_zed_module.py | 10 +- dimos/mapping/occupancy/visualizations.py | 95 ++++ .../pointclouds/test_occupancy_speed.py | 4 +- dimos/mapping/test_voxels.py | 6 +- dimos/mapping/voxels.py | 2 +- dimos/memory/test_embedding.py | 7 +- dimos/memory/timeseries/base.py | 78 +--- dimos/memory/timeseries/legacy.py | 87 +--- dimos/memory2/architecture.md | 11 +- dimos/memory2/backend.py | 26 +- dimos/memory2/codecs/test_codecs.py | 2 +- dimos/memory2/embeddings.md | 10 +- dimos/memory2/intro.md | 23 +- dimos/memory2/module.py | 210 ++++++++- dimos/memory2/observationstore/sqlite.py | 24 +- dimos/memory2/store/base.py | 24 +- dimos/memory2/store/sqlite.py | 4 +- dimos/memory2/store/test_null.py | 4 +- dimos/memory2/stream.py | 214 ++++++--- dimos/memory2/streaming.md | 35 +- dimos/memory2/test_blobstore_integration.py | 12 +- dimos/memory2/test_e2e.py | 85 +++- dimos/memory2/test_embedding.py | 36 +- dimos/memory2/test_materialize.py | 49 ++ dimos/memory2/test_module.py | 8 +- dimos/memory2/test_registry.py | 2 +- dimos/memory2/test_save.py | 31 +- dimos/memory2/test_store.py | 24 +- dimos/memory2/test_stream.py | 64 +-- dimos/memory2/transform.py | 267 ++++++++++- dimos/memory2/type/observation.py | 68 +-- dimos/memory2/vectorstore/base.py | 8 +- dimos/memory2/vectorstore/memory.py | 4 +- 
dimos/memory2/vectorstore/sqlite.py | 8 +- dimos/memory2/vis/color.py | 271 ++++++++++++ dimos/memory2/vis/plot/elements.py | 105 +++++ dimos/memory2/vis/plot/plot.py | 121 +++++ dimos/memory2/vis/plot/rerun.py | 27 ++ dimos/memory2/vis/plot/svg.py | 258 +++++++++++ dimos/memory2/vis/plot/test_plot.py | 354 +++++++++++++++ dimos/memory2/vis/space/elements.py | 165 +++++++ dimos/memory2/vis/space/rerun.py | 336 ++++++++++++++ dimos/memory2/vis/space/space.py | 187 ++++++++ dimos/memory2/vis/space/svg.py | 348 +++++++++++++++ dimos/memory2/vis/space/test_space.py | 350 +++++++++++++++ dimos/memory2/vis/utils.py | 72 +++ dimos/models/embedding/base.py | 3 +- dimos/models/vl/test_vlm.py | 10 +- dimos/msgs/nav_msgs/OccupancyGrid.py | 2 +- dimos/msgs/sensor_msgs/Image.py | 11 + dimos/msgs/sensor_msgs/PointCloud2.py | 12 +- dimos/msgs/sensor_msgs/test_image.py | 4 +- dimos/perception/detection/conftest.py | 8 +- dimos/perception/detection/module2D.py | 2 +- .../detection/type/detection2d/bbox.py | 42 ++ .../detection2d/test_imageDetections2D.py | 31 ++ .../type/detection3d/imageDetections3DPC.py | 36 ++ .../detection/type/imageDetections.py | 23 +- .../perception/test_spatial_memory_module.py | 14 +- dimos/protocol/pubsub/impl/test_rospubsub.py | 5 +- dimos/robot/all_blueprints.py | 5 + dimos/robot/drone/dji_video_stream.py | 4 +- dimos/robot/drone/mavlink_connection.py | 4 +- dimos/robot/drone/test_drone.py | 32 +- .../go2/blueprints/basic/unitree_go2_basic.py | 10 +- .../go2/blueprints/smart/unitree_go2.py | 26 +- dimos/robot/unitree/go2/connection.py | 23 +- dimos/robot/unitree/modular/detect.py | 8 +- dimos/robot/unitree/testing/test_tooling.py | 8 +- dimos/simulation/engines/mujoco_sim_module.py | 2 - dimos/types/test_timestamped.py | 4 +- dimos/utils/data.py | 8 +- dimos/utils/test_data.py | 4 +- dimos/utils/testing/moment.py | 4 +- dimos/utils/testing/replay.py | 302 ++++++++++++- dimos/utils/testing/test_replay.py | 20 +- dimos/visualization/rerun/bridge.py | 21 
+- dimos/visualization/rerun/init.py | 27 ++ docs/agents/testing.md | 2 +- docs/capabilities/memory/algo_comparison.md | 157 +++++++ .../capabilities/memory/assets/.gitattributes | 3 + .../capabilities/memory/assets/all_images.png | 3 + .../capabilities/memory/assets/brightness.svg | 3 + .../memory/assets/color_image.svg | 3 + docs/capabilities/memory/assets/embedding.svg | 3 + .../memory/assets/embedding_focused.svg | 3 + docs/capabilities/memory/assets/grid.png | 3 + .../memory/assets/peak_detections.svg | 3 + .../capabilities/memory/assets/peak_space.svg | 3 + docs/capabilities/memory/assets/plants.png | 3 + .../memory/assets/plants_auto.png | 3 + .../memory/assets/plants_meaningful.png | 3 + .../memory/assets/plants_peak_detections.png | 3 + .../memory/assets/plot_brightness_algo.svg | 3 + .../assets/plot_brightness_algo_delta.svg | 3 + .../memory/assets/plot_colors.svg | 3 + .../capabilities/memory/assets/plot_named.svg | 3 + .../memory/assets/plot_plantness.svg | 3 + .../assets/plot_plantness_autopeaks.svg | 3 + .../assets/plot_plantness_autopeaks2.svg | 3 + .../assets/plot_plantness_autopeaks_map.svg | 3 + .../assets/plot_plantness_brightness.svg | 3 + .../memory/assets/plot_plantness_gap_fill.svg | 3 + .../memory/assets/plot_plantness_marked.svg | 3 + .../assets/plot_plantness_significant.svg | 3 + .../memory/assets/plot_robot_data.svg | 3 + docs/capabilities/memory/assets/speed.svg | 3 + docs/capabilities/memory/demo_rerun.py | 88 ++++ docs/capabilities/memory/index.md | 210 +++++++++ docs/capabilities/memory/plot.md | 418 ++++++++++++++++++ docs/usage/cli.md | 2 +- flake.nix | 1 + pyproject.toml | 3 +- uv.lock | 58 +-- 123 files changed, 5280 insertions(+), 960 deletions(-) delete mode 100644 dimos/core/test_modules.py create mode 100644 dimos/memory2/test_materialize.py create mode 100644 dimos/memory2/vis/color.py create mode 100644 dimos/memory2/vis/plot/elements.py create mode 100644 dimos/memory2/vis/plot/plot.py create mode 100644 
dimos/memory2/vis/plot/rerun.py create mode 100644 dimos/memory2/vis/plot/svg.py create mode 100644 dimos/memory2/vis/plot/test_plot.py create mode 100644 dimos/memory2/vis/space/elements.py create mode 100644 dimos/memory2/vis/space/rerun.py create mode 100644 dimos/memory2/vis/space/space.py create mode 100644 dimos/memory2/vis/space/svg.py create mode 100644 dimos/memory2/vis/space/test_space.py create mode 100644 dimos/memory2/vis/utils.py create mode 100644 dimos/visualization/rerun/init.py create mode 100644 docs/capabilities/memory/algo_comparison.md create mode 100644 docs/capabilities/memory/assets/.gitattributes create mode 100644 docs/capabilities/memory/assets/all_images.png create mode 100644 docs/capabilities/memory/assets/brightness.svg create mode 100644 docs/capabilities/memory/assets/color_image.svg create mode 100644 docs/capabilities/memory/assets/embedding.svg create mode 100644 docs/capabilities/memory/assets/embedding_focused.svg create mode 100644 docs/capabilities/memory/assets/grid.png create mode 100644 docs/capabilities/memory/assets/peak_detections.svg create mode 100644 docs/capabilities/memory/assets/peak_space.svg create mode 100644 docs/capabilities/memory/assets/plants.png create mode 100644 docs/capabilities/memory/assets/plants_auto.png create mode 100644 docs/capabilities/memory/assets/plants_meaningful.png create mode 100644 docs/capabilities/memory/assets/plants_peak_detections.png create mode 100644 docs/capabilities/memory/assets/plot_brightness_algo.svg create mode 100644 docs/capabilities/memory/assets/plot_brightness_algo_delta.svg create mode 100644 docs/capabilities/memory/assets/plot_colors.svg create mode 100644 docs/capabilities/memory/assets/plot_named.svg create mode 100644 docs/capabilities/memory/assets/plot_plantness.svg create mode 100644 docs/capabilities/memory/assets/plot_plantness_autopeaks.svg create mode 100644 docs/capabilities/memory/assets/plot_plantness_autopeaks2.svg create mode 100644 
docs/capabilities/memory/assets/plot_plantness_autopeaks_map.svg create mode 100644 docs/capabilities/memory/assets/plot_plantness_brightness.svg create mode 100644 docs/capabilities/memory/assets/plot_plantness_gap_fill.svg create mode 100644 docs/capabilities/memory/assets/plot_plantness_marked.svg create mode 100644 docs/capabilities/memory/assets/plot_plantness_significant.svg create mode 100644 docs/capabilities/memory/assets/plot_robot_data.svg create mode 100644 docs/capabilities/memory/assets/speed.svg create mode 100644 docs/capabilities/memory/demo_rerun.py create mode 100644 docs/capabilities/memory/index.md create mode 100644 docs/capabilities/memory/plot.md diff --git a/.gitattributes b/.gitattributes index 8f05eb707f..55c93ccff2 100644 --- a/.gitattributes +++ b/.gitattributes @@ -16,3 +16,4 @@ *.mov filter=lfs diff=lfs merge=lfs -text binary *.gif filter=lfs diff=lfs merge=lfs -text binary *.foxe filter=lfs diff=lfs merge=lfs -text binary +docs/capabilities/memory/assets/** filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index 4045db012e..267aee13e4 100644 --- a/.gitignore +++ b/.gitignore @@ -77,3 +77,6 @@ CLAUDE.MD htmlcov/ .coverage .coverage.* + +# Memory2 autorecord +recording*.db diff --git a/README.md b/README.md index 6f40e8dc0e..a16c37f795 100644 --- a/README.md +++ b/README.md @@ -188,7 +188,7 @@ dimos run unitree-go2 | Run command | What it does | |-------------|-------------| | `dimos --replay run unitree-go2` | Quadruped navigation replay — SLAM, costmap, A* planning | -| `dimos --replay --replay-dir unitree_go2_office_walk2 run unitree-go2-temporal-memory` | Quadruped temporal memory replay | +| `dimos --replay --replay-db go2_bigoffice run unitree-go2-memory` | Quadruped temporal memory replay | | `dimos --simulation run unitree-go2-agentic` | Quadruped agentic + MCP server in simulation | | `dimos --simulation run unitree-g1` | Humanoid in MuJoCo simulation | | `dimos --replay run drone-basic` | Drone video + 
telemetry replay | diff --git a/bin/hooks/lfs_check b/bin/hooks/lfs_check index 0ddb847d56..3d493ec82d 100755 --- a/bin/hooks/lfs_check +++ b/bin/hooks/lfs_check @@ -8,6 +8,13 @@ NC='\033[0m' ROOT=$(git rev-parse --show-toplevel) cd $ROOT +# Glob patterns (matched against directory name) to ignore +IGNORE_GLOBS=( + ".lfs" + "*-wal" + "*-shm" +) + new_data=() # Enable nullglob to make globs expand to nothing when not matching @@ -19,8 +26,15 @@ for dir_path in data/*; do # Extract directory name dir_name=$(basename "$dir_path") - # Skip .lfs directory if it exists - [ "$dir_name" = ".lfs" ] && continue + # Skip ignored directories + skip=0 + for pat in "${IGNORE_GLOBS[@]}"; do + if [[ "$dir_name" == $pat ]]; then + skip=1 + break + fi + done + [ "$skip" = "1" ] && continue # Define compressed file path compressed_file="data/.lfs/${dir_name}.tar.gz" diff --git a/data/.lfs/go2_bigoffice.db.tar.gz b/data/.lfs/go2_bigoffice.db.tar.gz index 315610b5cb..540d7009ba 100644 --- a/data/.lfs/go2_bigoffice.db.tar.gz +++ b/data/.lfs/go2_bigoffice.db.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:142f7a7d64d3b77c97acd0d15d53e9ea28c4f558776a6bb3919a4da32c2f4d37 -size 192241937 +oid sha256:e66f5472e72f370446d8dcd802f70f3c3c07e4e083c5d6a394873877dec4c88d +size 196309743 diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py index ccf5b0644c..214401959e 100644 --- a/dimos/core/global_config.py +++ b/dimos/core/global_config.py @@ -34,7 +34,7 @@ class GlobalConfig(BaseSettings): can_port: str | None = None simulation: bool = False replay: bool = False - replay_dir: str = "go2_sf_office" + replay_db: str = "go2_bigoffice" new_memory: bool = False viewer: ViewerBackend = "rerun" n_workers: int = 2 diff --git a/dimos/core/test_modules.py b/dimos/core/test_modules.py deleted file mode 100644 index d96b58af5f..0000000000 --- a/dimos/core/test_modules.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test that all Module subclasses implement required resource management methods.""" - -import ast -import inspect -from pathlib import Path - -import pytest - -from dimos.core.module import Module - - -class ModuleVisitor(ast.NodeVisitor): - """AST visitor to find classes and their base classes.""" - - def __init__(self, filepath: str) -> None: - self.filepath = filepath - self.classes: list[ - tuple[str, list[str], set[str]] - ] = [] # (class_name, base_classes, methods) - - def visit_ClassDef(self, node: ast.ClassDef) -> None: - """Visit a class definition.""" - # Get base class names - base_classes = [] - for base in node.bases: - if isinstance(base, ast.Name): - base_classes.append(base.id) - elif isinstance(base, ast.Attribute): - # Handle cases like dimos.core.Module - parts = [] - current = base - while isinstance(current, ast.Attribute): - parts.append(current.attr) - current = current.value - if isinstance(current, ast.Name): - parts.append(current.id) - base_classes.append(".".join(reversed(parts))) - - # Get method names defined in this class - methods = set() - for item in node.body: - if isinstance(item, ast.FunctionDef): - methods.add(item.name) - - self.classes.append((node.name, base_classes, methods)) - self.generic_visit(node) - - -def get_import_aliases(tree: ast.AST) -> dict[str, str]: - """Extract import aliases from the AST.""" - aliases = {} - - for node in ast.walk(tree): - if isinstance(node, 
ast.Import): - for alias in node.names: - key = alias.asname if alias.asname else alias.name - aliases[key] = alias.name - elif isinstance(node, ast.ImportFrom): - module = node.module or "" - for alias in node.names: - key = alias.asname if alias.asname else alias.name - full_name = f"{module}.{alias.name}" if module else alias.name - aliases[key] = full_name - - return aliases - - -def is_module_subclass( - base_classes: list[str], - aliases: dict[str, str], - class_hierarchy: dict[str, list[str]] | None = None, - current_module_path: str | None = None, -) -> bool: - """Check if any base class is or resolves to dimos.core.Module or its variants (recursively).""" - target_classes = { - "Module", - "ModuleBase", - "dimos.core.Module", - "dimos.core.ModuleBase", - "dimos.core.module.Module", - "dimos.core.module.ModuleBase", - } - - def find_qualified_name(base: str, context_module: str | None = None) -> str: - """Find the qualified name for a base class, using import context if available.""" - if not class_hierarchy: - return base - - # First try exact match (already fully qualified or in hierarchy) - if base in class_hierarchy: - return base - - # Check if it's in our aliases (from imports) - if base in aliases: - resolved = aliases[base] - if resolved in class_hierarchy: - return resolved - # The resolved name might be a qualified name that exists - return resolved - - # If we have a context module and base is a simple name, - # try to find it in the same module first (for local classes) - if context_module and "." 
not in base: - same_module_qualified = f"{context_module}.{base}" - if same_module_qualified in class_hierarchy: - return same_module_qualified - - # Otherwise return the base as-is - return base - - def check_base( - base: str, visited: set[str] | None = None, context_module: str | None = None - ) -> bool: - if visited is None: - visited = set() - - # Avoid infinite recursion - if base in visited: - return False - visited.add(base) - - # Check direct match - if base in target_classes: - return True - - # Check if it's an alias - if base in aliases: - resolved = aliases[base] - if resolved in target_classes: - return True - # Continue checking with resolved name - base = resolved - - # If we have a class hierarchy, recursively check parent classes - if class_hierarchy: - # Resolve the base class name to a qualified name - qualified_name = find_qualified_name(base, context_module) - - if qualified_name in class_hierarchy: - # Check all parent classes - for parent_base in class_hierarchy[qualified_name]: - if check_base(parent_base, visited, None): # Parent lookups don't use context - return True - - return False - - for base in base_classes: - if check_base(base, context_module=current_module_path): - return True - - return False - - -def scan_file( - filepath: Path, - class_hierarchy: dict[str, list[str]] | None = None, - root_path: Path | None = None, -) -> list[tuple[str, str, bool, bool, set[str]]]: - """ - Scan a Python file for Module subclasses. 
- - Returns: - List of (class_name, filepath, has_start, has_stop, forbidden_methods) - """ - forbidden_method_names = {"acquire", "release", "open", "close", "shutdown", "clean", "cleanup"} - - try: - with open(filepath, encoding="utf-8") as f: - content = f.read() - - tree = ast.parse(content, filename=str(filepath)) - aliases = get_import_aliases(tree) - - visitor = ModuleVisitor(str(filepath)) - visitor.visit(tree) - - # Get module path for this file to properly resolve base classes - current_module_path = None - if root_path: - try: - rel_path = filepath.relative_to(root_path.parent) - module_parts = list(rel_path.parts[:-1]) - if rel_path.stem != "__init__": - module_parts.append(rel_path.stem) - current_module_path = ".".join(module_parts) - except ValueError: - pass - - results = [] - for class_name, base_classes, methods in visitor.classes: - if is_module_subclass(base_classes, aliases, class_hierarchy, current_module_path): - has_start = "start" in methods - has_stop = "stop" in methods - forbidden_found = methods & forbidden_method_names - results.append((class_name, str(filepath), has_start, has_stop, forbidden_found)) - - return results - - except (SyntaxError, UnicodeDecodeError): - # Skip files that can't be parsed - return [] - - -def build_class_hierarchy(root_path: Path) -> dict[str, list[str]]: - """Build a complete class hierarchy by scanning all Python files.""" - hierarchy = {} - - for filepath in sorted(root_path.rglob("*.py")): - # Skip __pycache__ and other irrelevant directories - if "__pycache__" in filepath.parts or ".venv" in filepath.parts: - continue - - try: - with open(filepath, encoding="utf-8") as f: - content = f.read() - - tree = ast.parse(content, filename=str(filepath)) - visitor = ModuleVisitor(str(filepath)) - visitor.visit(tree) - - # Convert filepath to module path (e.g., dimos/core/module.py -> dimos.core.module) - try: - rel_path = filepath.relative_to(root_path.parent) - except ValueError: - # If we can't get relative 
path, skip this file - continue - - # Convert path to module notation - module_parts = list(rel_path.parts[:-1]) # Exclude filename - if rel_path.stem != "__init__": - module_parts.append(rel_path.stem) # Add filename without .py - module_name = ".".join(module_parts) - - for class_name, base_classes, _ in visitor.classes: - # Use fully qualified name as key to avoid conflicts - qualified_name = f"{module_name}.{class_name}" if module_name else class_name - hierarchy[qualified_name] = base_classes - - except (SyntaxError, UnicodeDecodeError): - # Skip files that can't be parsed - continue - - return hierarchy - - -def scan_directory(root_path: Path) -> list[tuple[str, str, bool, bool, set[str]]]: - """Scan all Python files in the directory tree.""" - # First, build the complete class hierarchy - class_hierarchy = build_class_hierarchy(root_path) - - # Then scan for Module subclasses using the complete hierarchy - results = [] - - for filepath in sorted(root_path.rglob("*.py")): - # Skip __pycache__ and other irrelevant directories - if "__pycache__" in filepath.parts or ".venv" in filepath.parts: - continue - - file_results = scan_file(filepath, class_hierarchy, root_path) - results.extend(file_results) - - return results - - -def get_all_module_subclasses(): - """Find all Module subclasses in the dimos codebase.""" - # Get the dimos package directory - dimos_file = inspect.getfile(Module) - dimos_path = Path(dimos_file).parent.parent # Go up from dimos/core/module.py to dimos/ - - results = scan_directory(dimos_path) - - # Filter out test modules and base classes - filtered_results = [] - for class_name, filepath, has_start, has_stop, forbidden_methods in results: - # Skip base module classes themselves - if class_name in ("Module", "ModuleBase"): - continue - - # Skip test-only modules (those defined in test_ files) - if "test_" in Path(filepath).name: - continue - - filtered_results.append((class_name, filepath, has_start, has_stop, forbidden_methods)) - - 
return filtered_results - - -@pytest.mark.parametrize( - "class_name,filepath,has_start,has_stop,forbidden_methods", - get_all_module_subclasses(), - ids=lambda val: val[0] if isinstance(val, str) else str(val), -) -def test_module_has_start_and_stop( - class_name: str, filepath, has_start, has_stop, forbidden_methods -) -> None: - """Test that Module subclasses implement start and stop methods and don't use forbidden methods.""" - # Get relative path for better error messages - try: - rel_path = Path(filepath).relative_to(Path.cwd()) - except ValueError: - rel_path = filepath - - errors = [] - - # Check for missing required methods - if not has_start: - errors.append("missing required method: start") - if not has_stop: - errors.append("missing required method: stop") - - # Check for forbidden methods - if forbidden_methods: - forbidden_list = ", ".join(sorted(forbidden_methods)) - errors.append(f"has forbidden method(s): {forbidden_list}") - - assert not errors, f"{class_name} in {rel_path} has issues:\n - " + "\n - ".join(errors) diff --git a/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py b/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py index 5bb20321d7..055d50f960 100644 --- a/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py +++ b/dimos/hardware/sensors/camera/gstreamer/gstreamer_camera.py @@ -33,11 +33,11 @@ if "/usr/lib/python3/dist-packages" not in sys.path: sys.path.insert(0, "/usr/lib/python3/dist-packages") -import gi # type: ignore[import-not-found] +import gi # type: ignore[import-not-found,import-untyped] gi.require_version("Gst", "1.0") gi.require_version("GstApp", "1.0") -from gi.repository import GLib, Gst # type: ignore[import-not-found] +from gi.repository import GLib, Gst # type: ignore[import-not-found,import-untyped] logger = setup_logger(level=logging.INFO) diff --git a/dimos/hardware/sensors/camera/gstreamer/gstreamer_sender.py b/dimos/hardware/sensors/camera/gstreamer/gstreamer_sender.py index 
ba472b32e9..f886e922fd 100755 --- a/dimos/hardware/sensors/camera/gstreamer/gstreamer_sender.py +++ b/dimos/hardware/sensors/camera/gstreamer/gstreamer_sender.py @@ -24,11 +24,11 @@ if "/usr/lib/python3/dist-packages" not in sys.path: sys.path.insert(0, "/usr/lib/python3/dist-packages") -import gi # type: ignore[import-not-found] +import gi # type: ignore[import-not-found,import-untyped] gi.require_version("Gst", "1.0") gi.require_version("GstVideo", "1.0") -from gi.repository import GLib, Gst # type: ignore[import-not-found] +from gi.repository import GLib, Gst # type: ignore[import-not-found,import-untyped] # Initialize GStreamer Gst.init(None) diff --git a/dimos/hardware/sensors/fake_zed_module.py b/dimos/hardware/sensors/fake_zed_module.py index 4874419d21..41a431e16e 100644 --- a/dimos/hardware/sensors/fake_zed_module.py +++ b/dimos/hardware/sensors/fake_zed_module.py @@ -27,12 +27,12 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out +from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.msgs.std_msgs.Header import Header from dimos.protocol.tf.tf import TF from dimos.utils.logging_config import setup_logger -from dimos.utils.testing.replay import TimedSensorReplay logger = setup_logger(level=logging.INFO) @@ -85,7 +85,7 @@ def image_autocast(x): # type: ignore[no-untyped-def] return x return x - color_replay = TimedSensorReplay(f"{self.recording_path}/color", autocast=image_autocast) + color_replay = LegacyPickleStore(f"{self.recording_path}/color", autocast=image_autocast) return color_replay.stream() @functools.cache @@ -102,7 +102,7 @@ def depth_autocast(x): # type: ignore[no-untyped-def] return x return x - depth_replay = TimedSensorReplay(f"{self.recording_path}/depth", autocast=depth_autocast) + depth_replay = 
LegacyPickleStore(f"{self.recording_path}/depth", autocast=depth_autocast) return depth_replay.stream() @functools.cache @@ -124,7 +124,7 @@ def pose_autocast(x): # type: ignore[no-untyped-def] return x return x - pose_replay = TimedSensorReplay(f"{self.recording_path}/pose", autocast=pose_autocast) + pose_replay = LegacyPickleStore(f"{self.recording_path}/pose", autocast=pose_autocast) return pose_replay.stream() @functools.cache @@ -200,7 +200,7 @@ def camera_info_autocast(x): # type: ignore[no-untyped-def] return x return x - info_replay = TimedSensorReplay( + info_replay = LegacyPickleStore( f"{self.recording_path}/camera_info", autocast=camera_info_autocast ) return info_replay.stream() diff --git a/dimos/mapping/occupancy/visualizations.py b/dimos/mapping/occupancy/visualizations.py index 36321896be..7c344e5130 100644 --- a/dimos/mapping/occupancy/visualizations.py +++ b/dimos/mapping/occupancy/visualizations.py @@ -16,6 +16,7 @@ from typing import Literal, TypeAlias import cv2 +import matplotlib.pyplot as plt import numpy as np from numpy.typing import NDArray @@ -136,6 +137,100 @@ def _interpolate_turbo(t: float) -> tuple[int, int, int]: ) +def generate_rgba_texture( + grid: OccupancyGrid, + colormap: str | None = None, + opacity: float = 1.0, + cost_range: tuple[int, int] | None = None, + background: str | None = None, +) -> NDArray[np.uint8]: + """Generate RGBA texture for an occupancy grid. + + Args: + grid: OccupancyGrid to render. + colormap: Optional matplotlib colormap name. + opacity: Blend factor (0.0 to 1.0). Blends towards background color. + cost_range: Optional (min, max) cost range. Cells outside range use background. + background: Hex color for background (e.g. "#484981"). Default is black. + + Returns: + RGBA numpy array of shape (height, width, 4). + Note: NOT flipped - caller handles orientation. 
+ """ + if background is not None: + bg = background.lstrip("#") + bg_rgb = np.array([int(bg[i : i + 2], 16) for i in (0, 2, 4)], dtype=np.float32) + else: + bg_rgb = np.array([0, 0, 0], dtype=np.float32) + + if cost_range is not None: + in_range_mask = (grid.grid >= cost_range[0]) & (grid.grid <= cost_range[1]) + else: + in_range_mask = None + + if colormap is not None: + cmap = plt.get_cmap(colormap) + grid_float = grid.grid.astype(np.float32) + + vis = np.zeros((grid.height, grid.width, 4), dtype=np.uint8) + + free_mask = grid.grid == 0 + occupied_mask = grid.grid > 0 + + if np.any(free_mask): + fg = np.array(cmap(0.0)[:3]) * 255 + blended = fg * opacity + bg_rgb * (1 - opacity) + vis[free_mask, :3] = blended.astype(np.uint8) + vis[free_mask, 3] = 255 + + if np.any(occupied_mask): + costs = grid_float[occupied_mask] + cost_norm = 0.5 + (costs / 100) * 0.5 + fg = cmap(cost_norm)[:, :3] * 255 + blended = fg * opacity + bg_rgb * (1 - opacity) + vis[occupied_mask, :3] = blended.astype(np.uint8) + vis[occupied_mask, 3] = 255 + + unknown_mask = grid.grid == -1 + vis[unknown_mask] = 0 + + if in_range_mask is not None: + out_of_range = ~in_range_mask & (grid.grid != -1) + vis[out_of_range, :3] = bg_rgb.astype(np.uint8) + vis[out_of_range, 3] = 255 + + return vis + + # Default: Foxglove-style coloring + vis = np.zeros((grid.height, grid.width, 4), dtype=np.uint8) + + free_mask = grid.grid == 0 + occupied_mask = grid.grid > 0 + + fg_free = np.array([72, 73, 129], dtype=np.float32) + blended_free = fg_free * opacity + bg_rgb * (1 - opacity) + vis[free_mask, :3] = blended_free.astype(np.uint8) + vis[free_mask, 3] = 255 + + if np.any(occupied_mask): + costs = grid.grid[occupied_mask].astype(np.float32) + factor = (1 - costs / 100).clip(0, 1) + fg_occ = np.column_stack([72 * factor, 73 * factor, 129 * factor]) + blended_occ = fg_occ * opacity + bg_rgb * (1 - opacity) + vis[occupied_mask, :3] = blended_occ.astype(np.uint8) + vis[occupied_mask, 3] = 255 + + unknown_mask = 
grid.grid == -1 + vis[unknown_mask] = 0 + + if in_range_mask is not None: + out_of_range = ~in_range_mask & (grid.grid != -1) + vis[out_of_range, :3] = bg_rgb.astype(np.uint8) + vis[out_of_range, 3] = 255 + + return vis + + @lru_cache(maxsize=1) def _turbo_lut() -> NDArray[np.uint8]: # Pre-compute lookup table for all possible values (-1 to 100) diff --git a/dimos/mapping/pointclouds/test_occupancy_speed.py b/dimos/mapping/pointclouds/test_occupancy_speed.py index 115ee73ae0..ceed625765 100644 --- a/dimos/mapping/pointclouds/test_occupancy_speed.py +++ b/dimos/mapping/pointclouds/test_occupancy_speed.py @@ -19,6 +19,7 @@ from dimos.mapping.pointclouds.occupancy import OCCUPANCY_ALGOS from dimos.mapping.voxels import VoxelGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.cli.plot import bar from dimos.utils.data import get_data, get_data_dir from dimos.utils.testing.replay import TimedSensorReplay @@ -28,7 +29,8 @@ def test_build_map(): grid = VoxelGrid() - for _ts, frame in TimedSensorReplay("unitree_go2_bigoffice/lidar").iterate(): + replay: TimedSensorReplay[PointCloud2] = TimedSensorReplay("go2_bigoffice/lidar") + for _ts, frame in replay.iterate_ts(): grid.add_frame(frame) pickle_file = get_data_dir() / "unitree_go2_bigoffice_map.pickle" diff --git a/dimos/mapping/test_voxels.py b/dimos/mapping/test_voxels.py index fc95b4652b..3684186051 100644 --- a/dimos/mapping/test_voxels.py +++ b/dimos/mapping/test_voxels.py @@ -20,10 +20,10 @@ from dimos.core.transport import LCMTransport from dimos.mapping.voxels import VoxelGrid +from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.utils.data import get_data from dimos.utils.testing.moment import OutputMoment -from dimos.utils.testing.replay import TimedSensorReplay from dimos.utils.testing.test_moment import Go2Moment @@ -109,7 +109,7 @@ def test_carving(grid: VoxelGrid, moment1: Go2MapperMoment, moment2: 
Go2MapperMo def test_ingest_a_few(grid: VoxelGrid) -> None: data_dir = get_data("unitree_go2_office_walk2") - lidar_store = TimedSensorReplay(f"{data_dir}/lidar") + lidar_store = LegacyPickleStore(f"{data_dir}/lidar") for i in [1, 4, 8]: frame = lidar_store.find_closest_seek(i) @@ -155,7 +155,7 @@ def test_roundtrip(moment1: Go2MapperMoment, voxel_size: float, expected_points: def test_roundtrip_range_preserved(grid: VoxelGrid) -> None: """Test that input coordinate ranges are preserved in output.""" data_dir = get_data("unitree_go2_office_walk2") - lidar_store = TimedSensorReplay(f"{data_dir}/lidar") + lidar_store = LegacyPickleStore(f"{data_dir}/lidar") frame = lidar_store.find_closest_seek(1.0) assert frame is not None diff --git a/dimos/mapping/voxels.py b/dimos/mapping/voxels.py index e7a5e7ec4e..67c1347cb5 100644 --- a/dimos/mapping/voxels.py +++ b/dimos/mapping/voxels.py @@ -242,7 +242,7 @@ class VoxelGridMapperConfig(ModuleConfig): frame_id: str = "world" -class VoxelGridMapper(StreamModule): +class VoxelGridMapper(StreamModule[PointCloud2, PointCloud2]): """Accumulate lidar point clouds into a global voxel map.""" config: VoxelGridMapperConfig diff --git a/dimos/memory/test_embedding.py b/dimos/memory/test_embedding.py index 9a59ed51e1..01c76b93cf 100644 --- a/dimos/memory/test_embedding.py +++ b/dimos/memory/test_embedding.py @@ -16,17 +16,16 @@ from dimos.memory.embedding import EmbeddingMemory, SpatialEntry from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.utils.data import get_data from dimos.utils.testing.replay import TimedSensorReplay -dir_name = "unitree_go2_bigoffice" - @pytest.mark.skip def test_embed_frame() -> None: """Test embedding a single frame.""" # Load a frame from recorded data - video = TimedSensorReplay(get_data(dir_name) / "video") + from dimos.msgs.sensor_msgs.Image import Image + + video: TimedSensorReplay[Image] = TimedSensorReplay("go2_bigoffice/color_image") frame = video.find_closest_seek(10) # Create 
memory and embed diff --git a/dimos/memory/timeseries/base.py b/dimos/memory/timeseries/base.py index 2831836020..a8e8654a2d 100644 --- a/dimos/memory/timeseries/base.py +++ b/dimos/memory/timeseries/base.py @@ -21,8 +21,6 @@ import reactivex as rx from reactivex import operators as ops -from reactivex.disposable import CompositeDisposable, Disposable -from reactivex.scheduler import TimeoutScheduler if TYPE_CHECKING: from collections.abc import Iterator @@ -292,74 +290,12 @@ def stream( Uses scheduler-based timing with absolute time reference to prevent drift. """ + from dimos.utils.testing.replay import timed_playback - def subscribe( - observer: rx.abc.ObserverBase[T], - scheduler: rx.abc.SchedulerBase | None = None, - ) -> rx.abc.DisposableBase: - sched = scheduler or TimeoutScheduler() - disp = CompositeDisposable() - is_disposed = False - - iterator = self.iterate_items( + return timed_playback( + lambda: self.iterate_items( seek=seek, duration=duration, from_timestamp=from_timestamp, loop=loop - ) - - try: - first_ts, first_data = next(iterator) - except StopIteration: - observer.on_completed() - return Disposable() - - start_local_time = time.time() - start_replay_time = first_ts - - observer.on_next(first_data) - - try: - next_message: tuple[float, T] | None = next(iterator) - except StopIteration: - observer.on_completed() - return disp - - def schedule_emission(message: tuple[float, T]) -> None: - nonlocal next_message, is_disposed - - if is_disposed: - return - - msg_ts, msg_data = message - - try: - next_message = next(iterator) - except StopIteration: - next_message = None - - target_time = start_local_time + (msg_ts - start_replay_time) / speed - delay = max(0.0, target_time - time.time()) - - def emit( - _scheduler: rx.abc.SchedulerBase, _state: object - ) -> rx.abc.DisposableBase | None: - if is_disposed: - return None - observer.on_next(msg_data) - if next_message is not None: - schedule_emission(next_message) - else: - observer.on_completed() - 
return None - - sched.schedule_relative(delay, emit) - - if next_message is not None: - schedule_emission(next_message) - - def dispose() -> None: - nonlocal is_disposed - is_disposed = True - disp.dispose() - - return Disposable(dispose) - - return rx.create(subscribe) + ), + speed=speed, + detect_loop=loop, + ) diff --git a/dimos/memory/timeseries/legacy.py b/dimos/memory/timeseries/legacy.py index a98b0baddf..27194462f7 100644 --- a/dimos/memory/timeseries/legacy.py +++ b/dimos/memory/timeseries/legacy.py @@ -22,13 +22,9 @@ from pathlib import Path import pickle import re -import time from typing import Any, cast -import reactivex as rx -from reactivex.disposable import CompositeDisposable, Disposable from reactivex.observable import Observable -from reactivex.scheduler import TimeoutScheduler from dimos.memory.timeseries.base import T, TimeSeriesStore from dimos.utils.data import get_data, get_data_dir @@ -323,82 +319,11 @@ def stream( Uses stored timestamps from pickle files for timing (not data.ts). 
""" + from dimos.utils.testing.replay import timed_playback - def subscribe( - observer: rx.abc.ObserverBase[T], - scheduler: rx.abc.SchedulerBase | None = None, - ) -> rx.abc.DisposableBase: - sched = scheduler or TimeoutScheduler() - disp = CompositeDisposable() - is_disposed = False - - iterator = self.iterate_ts( + return timed_playback( + lambda: self.iterate_ts( seek=seek, duration=duration, from_timestamp=from_timestamp, loop=loop - ) - - try: - first_ts, first_data = next(iterator) - except StopIteration: - observer.on_completed() - return Disposable() - - start_local_time = time.time() - start_replay_time = first_ts - - observer.on_next(first_data) - - try: - next_message: tuple[float, T] | None = next(iterator) - except StopIteration: - observer.on_completed() - return disp - - prev_ts = first_ts - - def schedule_emission(message: tuple[float, T]) -> None: - nonlocal next_message, is_disposed, start_local_time, start_replay_time, prev_ts - - if is_disposed: - return - - ts, data = message - - # Detect loop restart: timestamp jumped backwards - if ts < prev_ts: - start_local_time = time.time() - start_replay_time = ts - prev_ts = ts - - try: - next_message = next(iterator) - except StopIteration: - next_message = None - - target_time = start_local_time + (ts - start_replay_time) / speed - delay = max(0.0, target_time - time.time()) - - def emit( - _scheduler: rx.abc.SchedulerBase, _state: object - ) -> rx.abc.DisposableBase | None: - if is_disposed: - return None - observer.on_next(data) - if next_message is not None: - schedule_emission(next_message) - else: - observer.on_completed() - return None - - sched.schedule_relative(delay, emit) - - if next_message is not None: - schedule_emission(next_message) - - def dispose() -> None: - nonlocal is_disposed - is_disposed = True - disp.dispose() - - return Disposable(dispose) - - return rx.create(subscribe) + ), + speed=speed, + ) diff --git a/dimos/memory2/architecture.md b/dimos/memory2/architecture.md index 
9dc805577f..c4a90a7085 100644 --- a/dimos/memory2/architecture.md +++ b/dimos/memory2/architecture.md @@ -83,27 +83,28 @@ images = store.stream("images") images.append(frame, ts=time.time(), pose=(x, y, z), tags={"camera": "front"}) # Query -recent = images.after(t).limit(10).fetch() -nearest = images.near(pose, radius=2.0).fetch() +recent = images.after(t).limit(10).to_list() +nearest = images.near(pose, radius=2.0).to_list() latest = images.last() # Transform (class or bare generator function) edges = images.transform(Canny()).save(store.stream("edges")) +edges.drain() # actually run the pipeline; .save() is lazy def running_avg(upstream): total, n = 0.0, 0 for obs in upstream: total += obs.data; n += 1 yield obs.derive(data=total / n) -avgs = stream.transform(running_avg).fetch() +avgs = stream.transform(running_avg).to_list() # Live for obs in images.live().transform(process): handle(obs) # Embed + search -images.transform(EmbedImages(clip)).save(store.stream("embedded")) -results = store.stream("embedded").search(query_vec, k=5).fetch() +images.transform(EmbedImages(clip)).save(store.stream("embedded")).drain() +results = store.stream("embedded").search(query_vec, k=5).to_list() ``` ## Implementations diff --git a/dimos/memory2/backend.py b/dimos/memory2/backend.py index 42843d0557..3c3b89669e 100644 --- a/dimos/memory2/backend.py +++ b/dimos/memory2/backend.py @@ -50,6 +50,7 @@ def __init__( *, metadata_store: ObservationStore[T], codec: Codec[Any], + data_type: type = object, blob_store: BlobStore | None = None, vector_store: VectorStore | None = None, notifier: Notifier[T] | None = None, @@ -58,6 +59,7 @@ def __init__( super().__init__() self.metadata_store = self.register_disposable(metadata_store) self.codec = codec + self.data_type = data_type self.blob_store = self.register_disposable(blob_store) if blob_store else None self.vector_store = self.register_disposable(vector_store) if vector_store else None self.notifier: Notifier[T] = 
self.register_disposable(notifier or SubjectNotifier()) @@ -87,9 +89,19 @@ def loader() -> Any: return loader def append(self, obs: Observation[T]) -> Observation[T]: + # Validate payload type matches stream type + if self.data_type is not object and not isinstance(obs._data, self.data_type): + raise TypeError( + f"Stream expects {self.data_type.__qualname__}, got {type(obs._data).__qualname__}" + ) + obs.data_type = self.data_type + + # Scalars are stored inline in the metadata value column — skip blob + is_scalar = isinstance(obs._data, (int, float)) + # Encode payload before any locking (avoids holding locks during IO) encoded: bytes | None = None - if self.blob_store is not None: + if self.blob_store is not None and not is_scalar: encoded = self.codec.encode(obs._data) try: @@ -97,7 +109,7 @@ def append(self, obs: Observation[T]) -> Observation[T]: row_id = self.metadata_store.insert(obs) obs.id = row_id - # Store blob + # Store blob (non-scalar data only) if encoded is not None: assert self.blob_store is not None self.blob_store.put(self.name, row_id, encoded) @@ -132,11 +144,14 @@ def iterate(self, query: StreamQuery) -> Iterator[Observation[T]]: return self._iterate_snapshot(query) def _attach_loaders(self, it: Iterator[Observation[T]]) -> Iterator[Observation[T]]: - """Attach lazy blob loaders to observations from the metadata store.""" + """Attach lazy blob loaders and data_type to observations from the metadata store.""" if self.blob_store is None: - yield from it + for obs in it: + obs.data_type = self.data_type + yield obs return for obs in it: + obs.data_type = self.data_type if obs._loader is None and isinstance(obs._data, type(_UNLOADED)): obs._loader = self._make_loader(obs.id) yield obs @@ -171,7 +186,7 @@ def _vector_search(self, query: StreamQuery) -> Iterator[Observation[T]]: vs = self.vector_store assert vs is not None and query.search_vec is not None - hits = vs.search(self.name, query.search_vec, query.search_k or 10) + hits = 
vs.search(self.name, query.search_vec, query.search_k) if not hits: return @@ -234,6 +249,7 @@ def count(self, query: StreamQuery) -> int: def serialize(self) -> dict[str, Any]: """Serialize the fully-resolved backend config to a dict.""" return { + "data_type": f"{self.data_type.__module__}.{self.data_type.__qualname__}", "codec_id": codec_id(self.codec), "eager_blobs": self.eager_blobs, "metadata_store": self.metadata_store.serialize() diff --git a/dimos/memory2/codecs/test_codecs.py b/dimos/memory2/codecs/test_codecs.py index eece78b1c3..48ab23f7c3 100644 --- a/dimos/memory2/codecs/test_codecs.py +++ b/dimos/memory2/codecs/test_codecs.py @@ -123,7 +123,7 @@ def _jpeg_case() -> Case | None: TurboJPEG() # fail fast if native lib is missing - replay = TimedSensorReplay("unitree_go2_bigoffice/video") + replay = TimedSensorReplay("go2_bigoffice/color_image") frames = [replay.find_closest_seek(float(i)) for i in range(1, 4)] codec = JpegCodec(quality=95) except (ImportError, RuntimeError): diff --git a/dimos/memory2/embeddings.md b/dimos/memory2/embeddings.md index 9028c29f9d..3e3f341c70 100644 --- a/dimos/memory2/embeddings.md +++ b/dimos/memory2/embeddings.md @@ -54,7 +54,7 @@ class Embed(Transformer[T, T]): ```python query_vec = clip.embed_text("a cat in the kitchen") -results = images.transform(Embed(clip)).search(query_vec, k=20).fetch() +results = images.transform(Embed(clip)).search(query_vec, k=20).to_list() # results[0].data → Image # results[0].embedding → np.ndarray # results[0].similarity → 0.93 @@ -64,7 +64,7 @@ results = images.transform(Embed(clip)) \ .search(query_vec, k=50) \ .after(one_hour_ago) \ .near(kitchen_pose, 5.0) \ - .fetch() + .to_list() ``` ## Backend Handles Storage Strategy @@ -109,7 +109,7 @@ for obs in logs.transform(Embed(clip.text)): unified.append(obs.data, ts=obs.ts, tags={"modality": "text"}, embedding=obs.embedding) -results = unified.search(query_vec, k=20).fetch() +results = unified.search(query_vec, k=20).to_list() # 
results[i].tags["modality"] tells you what it is ``` @@ -140,9 +140,9 @@ FTS is keyword-based, not embedding-based. Complementary, not competing: ```python # Keyword search via FTS5 logs = store.stream("logs") -logs.search_text("motor fault").fetch() +logs.search_text("motor fault").to_list() # Semantic search via embeddings log_idx = logs.transform(Embed(sentence_model)).store("log_emb") -log_idx.search(model.embed("motor problems"), k=10).fetch() +log_idx.search(model.embed("motor problems"), k=10).to_list() ``` diff --git a/dimos/memory2/intro.md b/dimos/memory2/intro.md index e88561c283..1b2153908b 100644 --- a/dimos/memory2/intro.md +++ b/dimos/memory2/intro.md @@ -54,7 +54,7 @@ Available filters: `.after(t)`, `.before(t)`, `.at(t)`, `.near(pose, radius)`, ` Terminals materialize or consume the stream: ```python session=memory ansi=false -print(logs.before(5.0).tags(level="error").fetch()) +print(logs.before(5.0).tags(level="error").to_list()) ``` @@ -62,7 +62,9 @@ print(logs.before(5.0).tags(level="error").fetch()) [Observation(id=2, ts=2.0, pose=None, tags={'level': 'error'})] ``` -Available terminals: `.fetch()`, `.first()`, `.last()`, `.count()`, `.exists()`, `.summary()`, `.get_time_range()`, `.drain()`, `.save(target)`. +Available terminals: `.to_list()`, `.first()`, `.last()`, `.count()`, `.exists()`, `.summary()`, `.get_time_range()`, `.drain()`, `.drain_thread()`. + +`.save(target)` is a lazy pass-through — pair it with a terminal (e.g. `.drain()` or `.drain_thread()`). 
## Transforms @@ -143,7 +145,7 @@ from dimos.memory2.embed import EmbedText clip = CLIPModel() -for obs in logs.transform(EmbedText(clip)).search(clip.embed_text("hardware problem"), k=3).fetch(): +for obs in logs.transform(EmbedText(clip)).search(clip.embed_text("hardware problem"), k=3).to_list(): print(f"{obs.similarity:.3f} {obs.data}") ``` @@ -157,14 +159,15 @@ for obs in logs.transform(EmbedText(clip)).search(clip.embed_text("hardware prob The embedded stream above was ephemeral — built on the fly for one query. To persist embeddings automatically as logs arrive, pipe a live stream through the transform into a stored stream: ```python skip -import threading - embedded_logs = store.stream("embedded_logs", str) -threading.Thread( - target=lambda: logs.live().transform(EmbedText(clip)).save(embedded_logs), - daemon=True, -).start() +handle = ( + logs.live() + .transform(EmbedText(clip)) + .save(embedded_logs) + .drain_thread() +) # every new log is now automatically embedded and stored -# embedded_logs.search(query, k=5).fetch() to query at any time +# embedded_logs.search(query, k=5).to_list() to query at any time +# handle.dispose() to stop ``` diff --git a/dimos/memory2/module.py b/dimos/memory2/module.py index 6eb0b2160c..be781b9ba6 100644 --- a/dimos/memory2/module.py +++ b/dimos/memory2/module.py @@ -15,28 +15,78 @@ from __future__ import annotations import inspect -from typing import Any +import logging +import os +from pathlib import Path +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from pydantic import field_validator +from reactivex.disposable import Disposable + +from dimos.agents.annotation import skill +from dimos.constants import DIMOS_PROJECT_ROOT from dimos.core.core import rpc -from dimos.core.module import Module +from dimos.core.module import Module, ModuleConfig +from dimos.memory2.embed import EmbedImages from dimos.memory2.store.null import NullStore +from dimos.memory2.store.sqlite import SqliteStore from 
dimos.memory2.stream import Stream +from dimos.memory2.transform import QualityWindow +from dimos.memory2.type.observation import EmbeddedObservation, Observation +from dimos.models.embedding.base import EmbeddingModel +from dimos.models.embedding.clip import CLIPModel +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.sensor_msgs.Image import Image + +if TYPE_CHECKING: + from reactivex.abc import DisposableBase + + from dimos.core.stream import In, Out + +logger = logging.getLogger(__name__) +T = TypeVar("T") +TIn = TypeVar("TIn") +TOut = TypeVar("TOut") + + +def stream_to_port(stream: Stream[T], out: Out[T]) -> DisposableBase: + """Forward each observation's ``data`` from *stream* to a Module ``Out`` port. + + Iteration runs on the dimos thread pool via :meth:`Stream.observable`. + """ -class StreamModule(Module): - """Module base class that wires a memory2 stream pipeline. + def _on_error(e: Exception) -> None: + logger.error("stream_to_port() pipeline error: %s", e, exc_info=True) - **Static pipeline** + return stream.observable().subscribe( + on_next=lambda obs: out.publish(obs.data), + on_error=_on_error, + ) - class VoxelGridMapper(StreamModule): + +def port_to_stream(in_: In[T], stream: Stream[T]) -> DisposableBase: + """Append each message received on a Module ``In`` port to *stream*.""" + return Disposable(in_.subscribe(stream.append)) + + +class StreamModule(Module, Generic[TIn, TOut]): + """Module base class that wires a memory2 stream pipeline + and deploys it as a dimos module + + Parameterize with the In/Out data types so the pipeline is + statically typed end-to-end:: + + class VoxelGridMapper(StreamModule[PointCloud2, PointCloud2]): pipeline = Stream().transform(VoxelMapTransformer()) lidar: In[PointCloud2] global_map: Out[PointCloud2] **Config-driven pipeline** - class VoxelGridMapper(StreamModule[VoxelGridMapperConfig]): - def pipeline(self, stream: Stream) -> Stream: + class VoxelGridMapper(StreamModule[PointCloud2, 
PointCloud2]): + config: VoxelGridMapperConfig + def pipeline(self, stream: Stream[PointCloud2]) -> Stream[PointCloud2]: return stream.transform(VoxelMap(**self.config.model_dump())) lidar: In[PointCloud2] @@ -63,24 +113,23 @@ def start(self) -> None: f"found {len(self.inputs)} In and {len(self.outputs)} Out" ) - ((in_name, inp_port),) = self.inputs.items() - ((_, out_port),) = self.outputs.items() + ((in_name, in_port_raw),) = self.inputs.items() + ((_, out_port_raw),) = self.outputs.items() + in_port = cast("In[TIn]", in_port_raw) + out_port = cast("Out[TOut]", out_port_raw) store = self.register_disposable(NullStore()) store.start() - stream: Stream[Any] = store.stream(in_name, inp_port.type) + stream: Stream[TIn] = store.stream(in_name, in_port.type) # we push input into the stream - inp_port.subscribe(lambda msg: stream.append(msg)) + self.register_disposable(port_to_stream(in_port, stream)) - live = stream.live() # and we push stream output to the output port - self._apply_pipeline(live).subscribe( - lambda obs: out_port.publish(obs.data), - ) + self.register_disposable(stream_to_port(self._apply_pipeline(stream.live()), out_port)) - def _apply_pipeline(self, stream: Stream[Any]) -> Stream[Any]: + def _apply_pipeline(self, stream: Stream[TIn]) -> Stream[TOut]: """Apply the pipeline to a live stream. Handles both static (class attr) and dynamic (method) pipelines. @@ -108,3 +157,130 @@ def _apply_pipeline(self, stream: Stream[Any]) -> Stream[Any]: @rpc def stop(self) -> None: super().stop() + + +class MemoryModuleConfig(ModuleConfig): + db_path: str | Path = "recording.db" + + @field_validator("db_path", mode="before") + @classmethod + def _resolve_path(cls, v: str | Path) -> Path: + p = Path(os.fspath(v)) + if not p.is_absolute(): + p = DIMOS_PROJECT_ROOT / p + return p + + +class RecorderConfig(MemoryModuleConfig): + overwrite: bool = True + + +class MemoryModule(Module): + """Base class for memory-related modules, like recorders and search systems. 
+ Provides a config with a db_path for the module's MemoryStore, and common start/stop logic. + + If changing the backend globally in dimos, this class will be replaced + """ + + config: MemoryModuleConfig + _store: SqliteStore | None = None + + @property + def store(self) -> SqliteStore: + if self._store is not None: + return self._store + + self._store = self.register_disposable( + SqliteStore(path=str(self.config.db_path)), + ) + self._store.start() + return self._store + + +class SemanticSearchConfig(MemoryModuleConfig): + embedding_model: type[EmbeddingModel] = CLIPModel + + +class SemanticSearch(MemoryModule): + config: SemanticSearchConfig + model: EmbeddingModel | None = None + embeddings: Stream[Any] | None = None + + @rpc + def start(self) -> None: + super().start() + + self.model = self.register_disposable(self.config.embedding_model()) + self.model.start() + + self.embeddings = self.store.stream("color_image_embedded", Image) + + # fmt: off + self.store.streams.color_image \ + .live() \ + .filter(lambda obs: obs.data.brightness > 0.1) \ + .transform(QualityWindow(lambda img: img.sharpness, window=0.5)) \ + .transform(EmbedImages(self.model, batch_size=2)) \ + .save(self.embeddings) \ + .drain_thread() + # fmt: on + + @skill + def search(self, query: str) -> PoseStamped: + from dimos.memory2.transform import peaks + + assert self.model is not None and self.embeddings is not None, ( + "SemanticSearch.search() called before start()" + ) + + query_vector = self.model.embed_text(query) + + # TODO(lesh): cluster results by peaks, then sort by time/distance + # depending on the desired weighting. + results = self.embeddings.search(query_vector) + + def _similarity(obs: Observation[Any]) -> float: + return cast("EmbeddedObservation[Any]", obs).similarity or 0.0 + + return results.transform(peaks(key=_similarity, distance=1.0)).last().pose_stamped + + +class Recorder(MemoryModule): + """Records all ``In`` ports to a memory2 SQLite database. 
+ + Subclass with the topics you want to record:: + + class MyRecorder(Recorder): + color_image: In[Image] + lidar: In[PointCloud2] + + blueprint.add(MyRecorder, db_path="session.db") + """ + + config: RecorderConfig + + @rpc + def start(self) -> None: + super().start() + + # TODO: store reset API/logic is not implemented yet. This module + # shouldn't need to know about files (SqliteStore specific), and + # .live() subs need to know how to re-sub in case of a restart of + # this module in a deployed blueprint. + db_path = Path(self.config.db_path) + if db_path.exists(): + if self.config.overwrite: + db_path.unlink() + logger.info("Deleted existing recording %s", db_path) + else: + raise FileExistsError(f"Recording already exists: {db_path}") + + if not self.inputs: + logger.warning("Recorder has no In ports — nothing to record, subclass the Recorder") + return + + for name, port in self.inputs.items(): + stream: Stream[Any] = self.store.stream(name, port.type) + self.register_disposable(port_to_stream(port, stream)) + logger.info("Recording %s (%s)", name, port.type.__name__) + logger.info("Recording %s (%s)", name, port.type.__name__) diff --git a/dimos/memory2/observationstore/sqlite.py b/dimos/memory2/observationstore/sqlite.py index 64c619066e..42f4193581 100644 --- a/dimos/memory2/observationstore/sqlite.py +++ b/dimos/memory2/observationstore/sqlite.py @@ -156,9 +156,9 @@ def _compile_query( """ prefix = "meta." 
if join_blob else "" if join_blob: - select = f'SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data FROM "{table}" AS meta JOIN "{table}_blob" AS blob ON blob.id = meta.id' + select = f'SELECT meta.id, meta.ts, meta.value, meta.pose_x, meta.pose_y, meta.pose_z, meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data FROM "{table}" AS meta JOIN "{table}_blob" AS blob ON blob.id = meta.id' else: - select = f'SELECT id, ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, json(tags) FROM "{table}"' + select = f'SELECT id, ts, value, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, json(tags) FROM "{table}"' where_parts: list[str] = [] params: list[Any] = [] @@ -281,6 +281,7 @@ def _ensure_tables(self) -> None: f'CREATE TABLE IF NOT EXISTS "{self._name}" (' " id INTEGER PRIMARY KEY AUTOINCREMENT," " ts REAL NOT NULL UNIQUE," + " value NUMERIC," " pose_x REAL, pose_y REAL, pose_z REAL," " pose_qx REAL, pose_qy REAL, pose_qz REAL, pose_qw REAL," " tags BLOB DEFAULT (jsonb('{}'))" @@ -317,14 +318,18 @@ def loader() -> Any: def _row_to_obs(self, row: tuple[Any, ...], *, has_blob: bool = False) -> Observation[T]: if has_blob: - row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json, blob_data = row + row_id, ts, value, px, py, pz, qx, qy, qz, qw, tags_json, blob_data = row else: - row_id, ts, px, py, pz, qx, qy, qz, qw, tags_json = row + row_id, ts, value, px, py, pz, qx, qy, qz, qw, tags_json = row blob_data = None pose = _reconstruct_pose(px, py, pz, qx, qy, qz, qw) tags = json.loads(tags_json) if tags_json else {} + # Scalar data stored inline in value column + if value is not None: + return Observation(id=row_id, ts=ts, pose=pose, tags=tags, _data=value) + if has_blob and blob_data is not None: assert self._codec is not None, "codec is required for data loading" data = self._codec.decode(blob_data) @@ -350,6 +355,7 @@ def 
_ensure_tag_indexes(self, tags: dict[str, Any]) -> None: def insert(self, obs: Observation[T]) -> int: pose = _decompose_pose(obs.pose) tags_json = json.dumps(obs.tags) if obs.tags else "{}" + value = obs._data if isinstance(obs._data, (int, float)) else None with self._lock: if obs.tags: @@ -360,9 +366,9 @@ def insert(self, obs: Observation[T]) -> int: px = py = pz = qx = qy = qz = qw = None # type: ignore[assignment] cur = self._conn.execute( - f'INSERT INTO "{self._name}" (ts, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags) ' - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, jsonb(?))", - (obs.ts, px, py, pz, qx, qy, qz, qw, tags_json), + f'INSERT INTO "{self._name}" (ts, value, pose_x, pose_y, pose_z, pose_qx, pose_qy, pose_qz, pose_qw, tags) ' + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, jsonb(?))", + (obs.ts, value, px, py, pz, qx, qy, qz, qw, tags_json), ) row_id = cur.lastrowid assert row_id is not None @@ -423,7 +429,7 @@ def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]: placeholders = ",".join("?" 
* len(ids)) if join: sql = ( - f"SELECT meta.id, meta.ts, meta.pose_x, meta.pose_y, meta.pose_z, " + f"SELECT meta.id, meta.ts, meta.value, meta.pose_x, meta.pose_y, meta.pose_z, " f"meta.pose_qx, meta.pose_qy, meta.pose_qz, meta.pose_qw, json(meta.tags), blob.data " f'FROM "{self._name}" AS meta ' f'JOIN "{self._name}_blob" AS blob ON blob.id = meta.id ' @@ -431,7 +437,7 @@ def fetch_by_ids(self, ids: list[int]) -> list[Observation[T]]: ) else: sql = ( - f"SELECT id, ts, pose_x, pose_y, pose_z, " + f"SELECT id, ts, value, pose_x, pose_y, pose_z, " f"pose_qx, pose_qy, pose_qz, pose_qw, json(tags) " f'FROM "{self._name}" WHERE id IN ({placeholders})' ) diff --git a/dimos/memory2/store/base.py b/dimos/memory2/store/base.py index 37ebe8ffe4..f6096bdceb 100644 --- a/dimos/memory2/store/base.py +++ b/dimos/memory2/store/base.py @@ -37,29 +37,28 @@ class StreamAccessor: __slots__ = ("_store",) def __init__(self, store: Store) -> None: - object.__setattr__(self, "_store", store) + self._store = store def __getattr__(self, name: str) -> Stream[Any]: if name.startswith("_"): raise AttributeError(name) - store: Store = object.__getattribute__(self, "_store") - if name not in store.list_streams(): - raise AttributeError(f"No stream {name!r}. Available: {store.list_streams()}") - return store.stream(name) + if name not in self._store.list_streams(): + raise AttributeError(f"No stream {name!r}. 
Available: {self._store.list_streams()}") + return self._store.stream(name) def __getitem__(self, name: str) -> Stream[Any]: - store: Store = object.__getattribute__(self, "_store") - if name not in store.list_streams(): + if name not in self._store.list_streams(): raise KeyError(name) - return store.stream(name) + return self._store.stream(name) def __dir__(self) -> list[str]: - store: Store = object.__getattribute__(self, "_store") - return store.list_streams() + return self._store.list_streams() def __repr__(self) -> str: - names = object.__getattribute__(self, "_store").list_streams() - return f"StreamAccessor({names})" + return f"StreamAccessor({self._store.list_streams()})" + + def items(self) -> list[tuple[str, Stream[Any]]]: + return [(name, self._store.stream(name)) for name in self._store.list_streams()] class StoreConfig(BaseConfig): @@ -136,6 +135,7 @@ def _create_backend( return Backend( metadata_store=obs, codec=codec, + data_type=payload_type or object, blob_store=bs, vector_store=vs, notifier=notifier, diff --git a/dimos/memory2/store/sqlite.py b/dimos/memory2/store/sqlite.py index 55cf6d5777..bbe563baec 100644 --- a/dimos/memory2/store/sqlite.py +++ b/dimos/memory2/store/sqlite.py @@ -60,10 +60,11 @@ def _open_connection(self) -> sqlite3.Connection: def _assemble_backend(self, name: str, stored: dict[str, Any]) -> Backend[Any]: """Reconstruct a Backend from a stored config dict.""" - from dimos.memory2.codecs.base import codec_from_id + from dimos.memory2.codecs.base import _resolve_payload_type, codec_from_id payload_module = stored["payload_module"] codec = codec_from_id(stored["codec_id"], payload_module) + data_type = _resolve_payload_type(payload_module) eager_blobs = stored.get("eager_blobs", False) page_size = stored.get("page_size", self.config.page_size) @@ -110,6 +111,7 @@ def _assemble_backend(self, name: str, stored: dict[str, Any]) -> Backend[Any]: backend: Backend[Any] = Backend( metadata_store=metadata_store, codec=codec, + 
data_type=data_type, blob_store=bs, vector_store=vs, notifier=notifier, diff --git a/dimos/memory2/store/test_null.py b/dimos/memory2/store/test_null.py index 3461ff3d9d..8e1e0a5780 100644 --- a/dimos/memory2/store/test_null.py +++ b/dimos/memory2/store/test_null.py @@ -40,7 +40,7 @@ def test_max_size_zero_empty_query() -> None: stream = store.stream("test", str) stream.append("data") assert stream.count() == 0 - assert stream.fetch() == [] + assert stream.to_list() == [] def test_null_store_discards_history() -> None: @@ -53,4 +53,4 @@ def test_null_store_discards_history() -> None: stream.append(3) assert stream.count() == 0 - assert stream.fetch() == [] + assert stream.to_list() == [] diff --git a/dimos/memory2/stream.py b/dimos/memory2/stream.py index 75bf6ab6a0..3a92f7708f 100644 --- a/dimos/memory2/stream.py +++ b/dimos/memory2/stream.py @@ -14,8 +14,14 @@ from __future__ import annotations +import sys import time -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, cast, overload + +if sys.version_info >= (3, 13): + from typing import TypeVar +else: + from typing_extensions import TypeVar from dimos.core.resource import CompositeResource from dimos.memory2.buffer import BackpressureBuffer, KeepLast @@ -45,10 +51,11 @@ T = TypeVar("T") R = TypeVar("R") +O = TypeVar("O", bound=Observation[Any], default=Observation[T]) logger = setup_logger() -class Stream(CompositeResource, Generic[T]): +class Stream(CompositeResource, Generic[T, O]): """Lazy, pull-based stream over observations. Every filter/transform method returns a new Stream — no computation @@ -56,7 +63,7 @@ class Stream(CompositeResource, Generic[T]): data; transform sources apply filters as Python predicates. Implements CompositeResource so subscriptions created via ``.subscribe()`` - and ``.publish()`` are tracked and disposed on ``stop()``. + are tracked and disposed on ``stop()``. 
An *unbound* stream (``Stream()``) records a chain of transforms without a real source. Use ``.chain()`` to apply it to a bound stream:: @@ -67,7 +74,7 @@ class Stream(CompositeResource, Generic[T]): def __init__( self, - source: Backend[T] | Stream[Any] | None = None, + source: Backend[T] | Stream[Any, Any] | None = None, *, transform: Transformer[Any, T] | None = None, query: StreamQuery = StreamQuery(), @@ -118,7 +125,7 @@ def is_live(self) -> bool: return self._source.is_live() return False - def __iter__(self) -> Iterator[Observation[T]]: + def __iter__(self) -> Iterator[O]: if self._source is None: raise TypeError( "Cannot iterate an unbound stream. Use .chain() to apply it to a real stream first." @@ -126,15 +133,15 @@ def __iter__(self) -> Iterator[Observation[T]]: if isinstance(self._source, Stream): return self._iter_transform() # Backend handles all query application (including live if requested) - return self._source.iterate(self._query) + return cast("Iterator[O]", self._source.iterate(self._query)) - def _iter_transform(self) -> Iterator[Observation[T]]: + def _iter_transform(self) -> Iterator[O]: """Iterate a transform source, applying query filters in Python.""" assert isinstance(self._source, Stream) and self._transform is not None - it: Iterator[Observation[T]] = self._transform(iter(self._source)) - return self._query.apply(it, live=self.is_live()) + it = self._transform(iter(self._source)) + return cast("Iterator[O]", self._query.apply(it, live=self.is_live())) - def _replace_query(self, **overrides: Any) -> Stream[T]: + def _replace_query(self, **overrides: Any) -> Stream[T, O]: q = self._query new_q = StreamQuery( filters=overrides.get("filters", q.filters), @@ -149,45 +156,64 @@ def _replace_query(self, **overrides: Any) -> Stream[T]: ) return Stream(self._source, transform=self._transform, query=new_q) - def _with_filter(self, f: Filter) -> Stream[T]: + def _with_filter(self, f: Filter) -> Stream[T, O]: return 
self._replace_query(filters=(*self._query.filters, f)) - def after(self, t: float) -> Stream[T]: + def after(self, t: float) -> Stream[T, O]: return self._with_filter(AfterFilter(t)) - def before(self, t: float) -> Stream[T]: + def before(self, t: float) -> Stream[T, O]: return self._with_filter(BeforeFilter(t)) - def time_range(self, t1: float, t2: float) -> Stream[T]: + def time_range(self, t1: float, t2: float) -> Stream[T, O]: return self._with_filter(TimeRangeFilter(t1, t2)) - def at(self, t: float, tolerance: float = 1.0) -> Stream[T]: + def at(self, t: float, tolerance: float = 1.0) -> Stream[T, O]: return self._with_filter(AtFilter(t, tolerance)) - def near(self, pose: Any, radius: float) -> Stream[T]: + def at_relative(self, t: float, tolerance: float = 1.0) -> Stream[T, O]: + """Like `at` but ``t`` is seconds from the first observation.""" + t0 = self.first().ts + return self.at(t0 + t, tolerance=tolerance) + + def near(self, pose: Any, radius: float) -> Stream[T, O]: return self._with_filter(NearFilter(pose, radius)) - def tags(self, **tags: Any) -> Stream[T]: + def tags(self, **tags: Any) -> Stream[T, O]: return self._with_filter(TagsFilter(tags)) - def order_by(self, field: str, desc: bool = False) -> Stream[T]: + def order_by(self, field: str, desc: bool = False) -> Stream[T, O]: return self._replace_query(order_field=field, order_desc=desc) - def limit(self, k: int) -> Stream[T]: + def limit(self, k: int) -> Stream[T, O]: return self._replace_query(limit_val=k) - def offset(self, n: int) -> Stream[T]: + def offset(self, n: int) -> Stream[T, O]: return self._replace_query(offset_val=n) - def search(self, query: Embedding, k: int) -> Stream[T]: - """Return top-k observations by cosine similarity to *query*. + def search(self, query: Embedding, k: int | None = None) -> Stream[T, EmbeddedObservation[T]]: + """Rank observations by cosine similarity to *query*. 
+ + Returns a stream whose observations are :class:`EmbeddedObservation` + with ``.similarity`` populated. - The backend handles the actual computation. ListObservationStore does - brute-force cosine; SqliteObservationStore pushes down to vec0. + If *k* is omitted, unbounded backends return all scored hits and + bounded backends (e.g. sqlite-vec) apply their own default cap. """ - return self._replace_query(search_vec=query, search_k=k) + new_q = StreamQuery( + filters=self._query.filters, + order_field=self._query.order_field, + order_desc=self._query.order_desc, + limit_val=self._query.limit_val, + offset_val=self._query.offset_val, + live_buffer=self._query.live_buffer, + search_vec=query, + search_k=k, + search_text=self._query.search_text, + ) + return Stream(self._source, transform=self._transform, query=new_q) - def search_text(self, text: str) -> Stream[T]: + def search_text(self, text: str) -> Stream[T, O]: """Filter observations whose data contains *text*. ListObservationStore does case-insensitive substring match; @@ -195,13 +221,41 @@ def search_text(self, text: str) -> Stream[T]: """ return self._replace_query(search_text=text) - def filter(self, pred: Callable[[Observation[T]], bool]) -> Stream[T]: + def filter(self, pred: Callable[[O], bool]) -> Stream[T, O]: """Filter by arbitrary predicate on the full Observation.""" - return self._with_filter(PredicateFilter(pred)) + return self._with_filter(PredicateFilter(cast("Callable[[Observation[Any]], bool]", pred))) - def map(self, fn: Callable[[Observation[T]], Observation[R]]) -> Stream[R]: + def tap(self, fn: Callable[[O], Any]) -> Stream[T, O]: + """Call *fn* on each observation without changing it.""" + + def _tap(upstream: Iterator[Observation[T]]) -> Iterator[Observation[T]]: + for obs in upstream: + fn(cast("O", obs)) + yield obs + + return cast("Stream[T, O]", self.transform(FnIterTransformer(_tap))) + + def scan_data(self, state: Any, fn: Callable[[Any, O], tuple[Any, R]]) -> Stream[R]: + 
"""Stateful map: ``fn(state, obs) -> (new_state, new_data)``. + + Each observation is yielded with ``.data`` replaced by ``new_data``. + """ + + def _scan(upstream: Iterator[Observation[T]]) -> Iterator[Observation[R]]: + s = state + for obs in upstream: + s, val = fn(s, cast("O", obs)) + yield obs.derive(data=val) + + return self.transform(FnIterTransformer(_scan)) + + def map(self, fn: Callable[[O], Observation[R]]) -> Stream[R]: + """Map each observation to a new observation (possibly of a new data type).""" + return self.transform(FnTransformer(lambda obs: fn(cast("O", obs)))) + + def map_data(self, fn: Callable[[O], R]) -> Stream[R]: """Transform each observation's data via callable.""" - return self.transform(FnTransformer(lambda obs: fn(obs))) + return self.transform(FnTransformer(lambda obs: obs.derive(data=fn(cast("O", obs))))) def transform( self, @@ -222,7 +276,7 @@ def detect(upstream): xf = FnIterTransformer(xf) return Stream(source=self, transform=xf, query=StreamQuery()) - def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> Stream[T]: + def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> Stream[T, O]: """Return a stream whose iteration never ends — backfill then live tail. All backends support live mode via their built-in ``Notifier``. @@ -239,30 +293,36 @@ def live(self, buffer: BackpressureBuffer[Observation[Any]] | None = None) -> St buf = buffer if buffer is not None else KeepLast() return self._replace_query(live_buffer=buf) - def save(self, target: Stream[T]) -> Stream[T]: - """Sync terminal: iterate self, append each obs to target's backend. + def save(self, target: Stream[T, O]) -> Stream[T, O]: + """Lazy pass-through that appends each observation to *target*'s backend. - Returns the target stream for continued querying. + Iteration drives both the passthrough and the appends — pick a terminal + (``.drain()`` sync, ``.drain_thread()`` background, ``.to_list()``, + ``for obs in ...``). 
""" if isinstance(target._source, Stream) or target._source is None: raise TypeError( "Cannot save to a transform/unbound stream. Target must be backend-backed." ) backend = target._source - for obs in self: - backend.append(obs) - return target - def fetch(self) -> list[Observation[T]]: + def _save(upstream: Iterator[Observation[T]]) -> Iterator[Observation[T]]: + for obs in upstream: + backend.append(obs) + yield obs + + return cast("Stream[T, O]", self.transform(FnIterTransformer(_save))) + + def to_list(self) -> list[O]: """Materialize all observations into a list.""" if self.is_live(): raise TypeError( - ".fetch() on a live stream would block forever. " + ".to_list() on a live stream would block forever. " "Use .drain() or .save(target) instead." ) return list(self) - def first(self) -> Observation[T]: + def first(self) -> O: """Return the first matching observation.""" it = iter(self.limit(1)) try: @@ -270,7 +330,7 @@ def first(self) -> Observation[T]: except StopIteration: raise LookupError("No matching observation") from None - def last(self) -> Observation[T]: + def last(self) -> O: """Return the last matching observation (by timestamp).""" return self.order_by("ts", desc=True).first() @@ -308,6 +368,19 @@ def summary(self) -> str: dur = t1 - t0 return f"{self}: {n} items, {dt0} — {dt1} ({dur:.1f}s)" + def materialize(self) -> Stream[T, O]: + """Materialize into memory and return a replayable stream. + + Useful when you need to iterate the same results multiple times + without re-running the upstream query. + """ + from dimos.memory2.store.memory import MemoryStore + + mem = MemoryStore() + target = cast("Stream[T, O]", mem.stream("materialize")) + self.save(target).drain() + return target + def drain(self) -> int: """Consume all observations, discarding results. Returns count consumed. 
@@ -319,7 +392,11 @@ def drain(self) -> int: n += 1 return n - def observable(self) -> reactivex.Observable[Observation[T]]: + def drain_thread(self) -> DisposableBase: + """Drain this stream on the dimos thread pool; returns a disposable.""" + return self.subscribe(lambda _: None) + + def observable(self) -> reactivex.Observable[O]: """Convert this stream to an RxPY Observable. Iteration is scheduled on the dimos thread pool so subscribe() never @@ -336,7 +413,7 @@ def observable(self) -> reactivex.Observable[Observation[T]]: def subscribe( self, - on_next: Callable[[Observation[T]], None] | ObserverBase[Observation[T]] | None = None, + on_next: Callable[[O], None] | ObserverBase[O] | None = None, on_error: Callable[[Exception], None] | None = None, on_completed: Callable[[], None] | None = None, ) -> DisposableBase: @@ -352,26 +429,7 @@ def subscribe( ) ) - def publish(self, out: Any) -> DisposableBase: - """Publish each observation's data to a Module ``Out`` port. - - Iteration runs on the dimos thread pool (via :meth:`subscribe`). - Returns a ``DisposableBase`` suitable for ``register_disposable()``. - - Example:: - - lidar.live().transform(VoxelMapTransformer()).publish(self.global_map) - """ - - def _on_error(e: Exception) -> None: - logger.error("Stream.publish() pipeline error: %s", e, exc_info=True) - - return self.subscribe( - on_next=lambda obs: out.publish(obs.data), - on_error=_on_error, - ) - - def chain(self, other: Stream[R]) -> Stream[R]: + def chain(self, other: Stream[R, Any]) -> Stream[R]: """Append operations from an unbound stream to this stream. 
Extracts the transform/filter chain from *other* (which must be @@ -381,7 +439,7 @@ def chain(self, other: Stream[R]) -> Stream[R]: store.stream("lidar").live().chain(pipeline) """ ops: list[tuple[Transformer[Any, Any] | None, StreamQuery]] = [] - current: Stream[Any] | None | Any = other + current: Stream[Any, Any] | None | Any = other found_root = False while isinstance(current, Stream): ops.append((current._transform, current._query)) @@ -399,7 +457,7 @@ def chain(self, other: Stream[R]) -> Stream[R]: if query.live_buffer is not None: raise TypeError("live() cannot be used on unbound streams") - result: Stream[Any] = self + result: Stream[Any, Any] = self for xf, query in reversed(ops): if xf is not None: result = result.transform(xf) @@ -413,6 +471,26 @@ def chain(self, other: Stream[R]) -> Stream[R]: result = result.order_by(query.order_field, desc=query.order_desc) return cast("Stream[R]", result) + @overload + def append( + self, + payload: T, + *, + ts: float | None = ..., + pose: Any | None = ..., + tags: dict[str, Any] | None = ..., + embedding: None = None, + ) -> Observation[T]: ... + @overload + def append( + self, + payload: T, + *, + ts: float | None = ..., + pose: Any | None = ..., + tags: dict[str, Any] | None = ..., + embedding: Embedding, + ) -> EmbeddedObservation[T]: ... def append( self, payload: T, @@ -422,7 +500,11 @@ def append( tags: dict[str, Any] | None = None, embedding: Embedding | None = None, ) -> Observation[T]: - """Append to the backing store. Only works if source is a Backend.""" + """Append to the backing store. Only works if source is a Backend. + + Returns :class:`EmbeddedObservation` when *embedding* is provided, + else a plain :class:`Observation`. + """ if isinstance(self._source, Stream) or self._source is None: raise TypeError( "Cannot append to a transform/unbound stream. Append to the source stream." 
diff --git a/dimos/memory2/streaming.md b/dimos/memory2/streaming.md index fd7f5519a1..3ddae6d438 100644 --- a/dimos/memory2/streaming.md +++ b/dimos/memory2/streaming.md @@ -49,7 +49,7 @@ stream.live().transform(xf).last() ```python # Search the stored data, not the live tail -results = stream.search(vec, k=5).fetch() +results = stream.search(vec, k=5).to_list() # First works fine (uses limit(1), no materialization) obs = stream.live().transform(xf).first() @@ -59,27 +59,32 @@ obs = stream.live().transform(xf).first() Terminals trigger iteration and return a value. They're the "go" button — nothing executes until a terminal is called. -| Method | Returns | Memory | Live behaviour | -|-----------------|---------------------|--------------------|-----------------------------------------| -| `.fetch()` | `list[Observation]` | Grows with results | TypeError without `.limit()` first | -| `.drain()` | `int` (count) | Constant | Blocks forever, memory stays flat | -| `.save(target)` | target `Stream` | Constant | Blocks forever, appends each to store | -| `.first()` | `Observation` | Constant | Returns first item, then stops | -| `.exists()` | `bool` | Constant | Returns after one item check | -| `.last()` | `Observation` | Materializes | TypeError (uses order_by internally) | -| `.count()` | `int` | Constant | TypeError on transform streams | +| Method | Returns | Memory | Live behaviour | +|-------------------|---------------------|--------------------|-----------------------------------------| +| `.to_list()` | `list[Observation]` | Grows with results | TypeError without `.limit()` first | +| `.drain()` | `int` (count) | Constant | Blocks forever, memory stays flat | +| `.drain_thread()` | `DisposableBase` | Constant | Runs on the dimos thread pool | +| `.first()` | `Observation` | Constant | Returns first item, then stops | +| `.exists()` | `bool` | Constant | Returns after one item check | +| `.last()` | `Observation` | Materializes | TypeError (uses order_by 
internally) | +| `.count()` | `int` | Constant | TypeError on transform streams | + +`.save(target)` is **not** a terminal — it's a lazy pass-through that appends each +observation to ``target``'s backend as the stream is iterated. Pair it with +``.drain()`` (sync) or ``.drain_thread()`` (background) to actually run the pipeline. ### Choosing the right terminal **Batch query** — collect results into memory: ```python -results = stream.after(t).search(vec, k=10).fetch() +results = stream.after(t).search(vec, k=10).to_list() ``` **Live ingestion** — process forever, constant memory: ```python -# Embed and store continuously -stream.live().transform(EmbedImages(clip)).save(target) +# Embed and store continuously on the dimos thread pool +handle = stream.live().transform(EmbedImages(clip)).save(target).drain_thread() +# handle is a DisposableBase — dispose() to stop # Side-effect pipeline (no storage) stream.live().transform(process).drain() @@ -93,7 +98,7 @@ has_data = stream.exists() # quick check **Bounded live** — collect a fixed number from a live stream: ```python -batch = stream.live().limit(100).fetch() # OK — limit makes it finite +batch = stream.live().limit(100).to_list() # OK — limit makes it finite ``` ### Error summary @@ -104,6 +109,6 @@ All operations that would silently hang on live streams raise `TypeError` instea |-------------------------------------|-----------------------------------------------| | `live.transform(xf).search(vec, k)` | `.search() requires finite data` | | `live.transform(xf).order_by("ts")` | `.order_by() requires finite data` | -| `live.fetch()` (without `.limit()`) | `.fetch() would collect forever` | +| `live.to_list()` (without `.limit()`) | `.to_list() would collect forever` | | `live.transform(xf).count()` | `.count() would block forever` | | `live.transform(xf).last()` | `.order_by() requires finite data` (via last) | diff --git a/dimos/memory2/test_blobstore_integration.py b/dimos/memory2/test_blobstore_integration.py 
index 6c26a635c0..3d081e6d74 100644 --- a/dimos/memory2/test_blobstore_integration.py +++ b/dimos/memory2/test_blobstore_integration.py @@ -81,7 +81,7 @@ def test_eager_preloads_data(self, bs: FileBlobStore) -> None: s.append("payload", ts=1.0) # Iterating with eager_blobs triggers load - results = s.fetch() + results = s.to_list() assert len(results) == 1 # Data should be loaded (not _UNLOADED) assert not isinstance(results[0]._data, type(_UNLOADED)) @@ -96,8 +96,8 @@ def test_per_stream_eager_override(self, store: MemoryStore) -> None: eager_stream = store.stream("eager", str, eager_blobs=True) eager_stream.append("eager-val", ts=1.0) - lazy_results = lazy_stream.fetch() - eager_results = eager_stream.fetch() + lazy_results = lazy_stream.to_list() + eager_results = eager_stream.to_list() # Lazy: data stays unloaded until accessed assert lazy_results[0].data == "lazy-val" @@ -127,7 +127,7 @@ def test_blobstore_with_vector_search(self, bs: FileBlobStore) -> None: s.append("south", ts=3.0, embedding=_emb([0, -1, 0])) # Vector search triggers lazy load via obs.derive(data=obs.data, ...) 
- results = s.search(_emb([0, 1, 0]), k=2).fetch() + results = s.search(_emb([0, 1, 0]), k=2).to_list() assert len(results) == 2 assert results[0].data == "north" assert results[0].similarity > 0.99 @@ -138,7 +138,7 @@ def test_blobstore_with_text_search(self, store: MemoryStore) -> None: s.append("temperature ok", ts=2.0) # Text search triggers lazy load via str(obs.data) - results = s.search_text("motor").fetch() + results = s.search_text("motor").to_list() assert len(results) == 1 assert results[0].data == "motor fault" @@ -148,7 +148,7 @@ def test_multiple_appends_get_unique_blobs(self, store: MemoryStore) -> None: s.append("second", ts=2.0) s.append("third", ts=3.0) - results = s.fetch() + results = s.to_list() assert [r.data for r in results] == ["first", "second", "third"] def test_fetch_preserves_metadata(self, store: MemoryStore) -> None: diff --git a/dimos/memory2/test_e2e.py b/dimos/memory2/test_e2e.py index d35663f9ec..cb0450c005 100644 --- a/dimos/memory2/test_e2e.py +++ b/dimos/memory2/test_e2e.py @@ -21,14 +21,15 @@ import pytest +from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.memory2.embed import EmbedImages from dimos.memory2.store.sqlite import SqliteStore from dimos.memory2.transform import QualityWindow from dimos.models.embedding.clip import CLIPModel from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data_dir -from dimos.utils.testing.replay import TimedSensorReplay if TYPE_CHECKING: from collections.abc import Iterator @@ -47,7 +48,7 @@ def session() -> Iterator[SqliteStore]: class PoseIndex: """Preloaded odom data with O(log n) closest-timestamp lookup.""" - def __init__(self, replay: TimedSensorReplay) -> None: + def __init__(self, replay: LegacyPickleStore[Any]) -> None: self._timestamps: list[float] = [] self._data: list[Any] = [] for ts, data in replay.iterate_ts(): @@ 
-67,39 +68,68 @@ def find_closest(self, ts: float) -> Any | None: return self._data[idx - 1] return self._data[idx] + def __iter__(self) -> Iterator[tuple[float, Any]]: + return iter(zip(self._timestamps, self._data, strict=False)) + @pytest.fixture(scope="module") -def video_replay() -> TimedSensorReplay: - return TimedSensorReplay("unitree_go2_bigoffice/video") +def video_replay() -> LegacyPickleStore[Image]: + return LegacyPickleStore("unitree_go2_bigoffice/video") @pytest.fixture(scope="module") def odom_index() -> PoseIndex: - return PoseIndex(TimedSensorReplay("unitree_go2_bigoffice/odom")) + return PoseIndex(LegacyPickleStore("unitree_go2_bigoffice/odom")) @pytest.fixture(scope="module") -def lidar_replay() -> TimedSensorReplay: - return TimedSensorReplay("unitree_go2_bigoffice/lidar") +def lidar_replay() -> LegacyPickleStore[PointCloud2]: + return LegacyPickleStore("unitree_go2_bigoffice/lidar") @pytest.mark.tool class TestImportReplay: - """Import legacy pickle replay data into a memory2 SqliteStore.""" + """Import legacy pickle replay data into a memory2 SqliteStore. + + Lidar/odom are trimmed to start at video's first ts. The Memory2ReplayAdapter + scheduler anchors each stream to its own first_ts on subscribe, so aligning + first_ts across streams keeps replay synchronized. 
+ """ + + def test_import_odom( + self, + session: SqliteStore, + odom_index: PoseIndex, + video_replay: LegacyPickleStore[Any], + ) -> None: + threshold = video_replay.first_timestamp() + with session.stream("odom", Odometry) as odom: + count = 0 + skipped = 0 + for ts, data in odom_index: + if ts < threshold: + skipped += 1 + continue + odom.append(data, ts=ts, pose=data) + count += 1 + + assert count > 0 + assert odom.count() == count + print(f"Imported {count} odom frames (skipped {skipped} before {threshold:.2f})") def test_import_video( self, session: SqliteStore, - video_replay: TimedSensorReplay, + video_replay: LegacyPickleStore[Any], odom_index: PoseIndex, ) -> None: with session.stream("color_image", Image) as video: count = 0 for ts, frame in video_replay.iterate_ts(): pose = odom_index.find_closest(ts) - print("import", frame) video.append(frame, ts=ts, pose=pose) count += 1 + print(f"import [{count}] ts={ts:.2f} {frame}") assert count > 0 assert video.count() == count @@ -108,24 +138,31 @@ def test_import_video( def test_import_lidar( self, session: SqliteStore, - lidar_replay: TimedSensorReplay, + lidar_replay: LegacyPickleStore[Any], odom_index: PoseIndex, + video_replay: LegacyPickleStore[Any], ) -> None: - # can also be explicit here - # lidar = session.stream("lidar", PointCloud2, codec=Lz4Codec(LcmCodec(PointCloud2))) + threshold = video_replay.first_timestamp() lidar = session.stream("lidar", PointCloud2, codec="lz4+lcm") count = 0 + skipped = 0 for ts, frame in lidar_replay.iterate_ts(): + if ts < threshold: + skipped += 1 + continue pose = odom_index.find_closest(ts) - print("import", frame) lidar.append(frame, ts=ts, pose=pose) count += 1 + print(f"import [{count}] ts={ts:.2f} {frame}") assert count > 0 assert lidar.count() == count - print(f"Imported {count} lidar frames") + print(f"Imported {count} lidar frames (skipped {skipped} before {threshold:.2f})") + +@pytest.mark.tool +class TestEmbed: def test_embed_and_save(self, session: 
SqliteStore, clip: CLIPModel) -> None: """Embed video frames at 1Hz and persist to an embedded stream.""" video = session.stream("color_image", Image) @@ -136,9 +173,11 @@ def test_embed_and_save(self, session: SqliteStore, clip: CLIPModel) -> None: embedded = session.stream("color_image_embedded", Image) - # Downsample to 1Hz, then embed + # Downsample to 2Hz, then embed pipeline = ( - video.transform(QualityWindow(lambda img: img.sharpness, window=1.0)) + video.filter(lambda obs: obs.data.brightness > 0.1) + .tap(print) + .transform(QualityWindow(lambda img: img.sharpness, window=0.5)) .transform(EmbedImages(clip)) .save(embedded) ) @@ -163,7 +202,7 @@ def test_query_imported_data(self, session: SqliteStore) -> None: assert first_frame.ts < last_frame.ts mid_ts = (first_frame.ts + last_frame.ts) / 2 - subset = video.time_range(first_frame.ts, mid_ts).fetch() + subset = video.time_range(first_frame.ts, mid_ts).to_list() assert 0 < len(subset) < video.count() streams = session.list_streams() @@ -207,15 +246,15 @@ def test_time_range_filter(self, session: SqliteStore) -> None: first = video.first() # Grab first 5 seconds - window = video.time_range(first.ts, first.ts + 5.0).fetch() + window = video.time_range(first.ts, first.ts + 5.0).to_list() assert len(window) > 0 assert len(window) < video.count() assert all(first.ts <= obs.ts <= first.ts + 5.0 for obs in window) def test_limit_offset_pagination(self, session: SqliteStore) -> None: video = session.stream("color_image", Image) - page1 = video.limit(10).fetch() - page2 = video.offset(10).limit(10).fetch() + page1 = video.limit(10).to_list() + page2 = video.offset(10).limit(10).to_list() assert len(page1) == 10 assert len(page2) == 10 @@ -223,7 +262,7 @@ def test_limit_offset_pagination(self, session: SqliteStore) -> None: def test_order_by_desc(self, session: SqliteStore) -> None: video = session.stream("color_image", Image) - last_10 = video.order_by("ts", desc=True).limit(10).fetch() + last_10 = 
video.order_by("ts", desc=True).limit(10).to_list() assert len(last_10) == 10 assert all(last_10[i].ts >= last_10[i + 1].ts for i in range(9)) @@ -292,7 +331,7 @@ def test_search_by_text(self, session: SqliteStore, clip: CLIPModel) -> None: embedded = session.stream("color_image_embedded", Image) query = clip.embed_text("a door") - results = embedded.search(query, k=5).fetch() + results = embedded.search(query, k=5).to_list() assert len(results) > 0 for obs in results: assert obs.similarity is not None diff --git a/dimos/memory2/test_embedding.py b/dimos/memory2/test_embedding.py index 57d66da278..caa1080ffd 100644 --- a/dimos/memory2/test_embedding.py +++ b/dimos/memory2/test_embedding.py @@ -95,7 +95,7 @@ def test_search_returns_top_k(self, memory_store) -> None: s.append("south", embedding=_emb([0, -1, 0])) s.append("west", embedding=_emb([-1, 0, 0])) - results = s.search(_emb([0, 1, 0]), k=2).fetch() + results = s.search(_emb([0, 1, 0]), k=2).to_list() assert len(results) == 2 assert results[0].data == "north" assert results[0].similarity is not None @@ -107,7 +107,7 @@ def test_search_sorted_by_similarity(self, memory_store) -> None: s.append("close", embedding=_emb([0.9, 0.1, 0])) s.append("exact", embedding=_emb([1, 0, 0])) - results = s.search(_emb([1, 0, 0]), k=3).fetch() + results = s.search(_emb([1, 0, 0]), k=3).to_list() assert results[0].data == "exact" assert results[1].data == "close" assert results[2].data == "far" @@ -119,7 +119,7 @@ def test_search_skips_non_embedded(self, memory_store) -> None: s.append("plain") # no embedding s.append("embedded", embedding=_emb([1, 0, 0])) - results = s.search(_emb([1, 0, 0]), k=10).fetch() + results = s.search(_emb([1, 0, 0]), k=10).to_list() assert len(results) == 1 assert results[0].data == "embedded" @@ -129,7 +129,7 @@ def test_search_with_filters(self, memory_store) -> None: s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) # Only the late one should pass the after filter - results = 
s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() + results = s.after(15.0).search(_emb([1, 0, 0]), k=10).to_list() assert len(results) == 1 assert results[0].data == "late" @@ -139,7 +139,7 @@ def test_search_with_limit(self, memory_store) -> None: s.append(f"item{i}", embedding=_emb([1, 0, 0])) # search k=5 then limit 2 - results = s.search(_emb([1, 0, 0]), k=5).limit(2).fetch() + results = s.search(_emb([1, 0, 0]), k=5).limit(2).to_list() assert len(results) == 2 def test_search_with_live_raises(self, memory_store) -> None: @@ -156,7 +156,7 @@ def test_search_text_substring(self, memory_store) -> None: s.append("temperature normal") s.append("motor overheating") - results = s.search_text("motor").fetch() + results = s.search_text("motor").to_list() assert len(results) == 2 assert {r.data for r in results} == {"motor fault detected", "motor overheating"} @@ -165,7 +165,7 @@ def test_search_text_case_insensitive(self, memory_store) -> None: s.append("Motor Fault") s.append("other event") - results = s.search_text("motor fault").fetch() + results = s.search_text("motor fault").to_list() assert len(results) == 1 def test_search_text_with_filters(self, memory_store) -> None: @@ -174,7 +174,7 @@ def test_search_text_with_filters(self, memory_store) -> None: s.append("motor warning", ts=20.0) s.append("motor fault", ts=30.0) - results = s.after(15.0).search_text("fault").fetch() + results = s.after(15.0).search_text("fault").to_list() assert len(results) == 1 assert results[0].ts == 30.0 @@ -182,7 +182,7 @@ def test_search_text_no_match(self, memory_store) -> None: s = memory_store.stream("logs", str) s.append("all clear") - results = s.search_text("motor").fetch() + results = s.search_text("motor").to_list() assert len(results) == 0 @@ -193,9 +193,9 @@ def test_save_preserves_embeddings(self, memory_store) -> None: emb = _emb([1, 0, 0]) src.append("item", embedding=emb) - src.save(dst) + src.save(dst).drain() - results = dst.fetch() + results = dst.to_list() assert 
len(results) == 1 assert isinstance(results[0], EmbeddedObservation) # Same vector content (different Embedding instance after re-append) @@ -207,9 +207,9 @@ def test_save_mixed_embedded_and_plain(self, memory_store) -> None: src.append("plain") src.append("embedded", embedding=_emb([0, 1, 0])) - src.save(dst) + src.save(dst).drain() - results = dst.fetch() + results = dst.to_list() assert len(results) == 2 assert type(results[0]) is Observation assert isinstance(results[1], EmbeddedObservation) @@ -248,7 +248,7 @@ def test_embed_images_produces_embedded_observations(self, memory_store) -> None s.append("img1", ts=1.0) s.append("img2", ts=2.0) - results = s.transform(EmbedImages(model)).fetch() + results = s.transform(EmbedImages(model)).to_list() assert len(results) == 2 for obs in results: assert isinstance(obs, EmbeddedObservation) @@ -263,7 +263,7 @@ def test_embed_text_produces_embedded_observations(self, memory_store) -> None: s.append("motor fault", ts=1.0) s.append("all clear", ts=2.0) - results = s.transform(EmbedText(model)).fetch() + results = s.transform(EmbedText(model)).to_list() assert len(results) == 2 for obs in results: assert isinstance(obs, EmbeddedObservation) @@ -290,7 +290,7 @@ def test_embed_then_search(self, memory_store) -> None: embedded = s.transform(EmbedText(model)) # Get the embedding for the first item, then search for similar first_emb = embedded.first().embedding - results = embedded.search(first_emb, k=3).fetch() + results = embedded.search(first_emb, k=3).to_list() assert len(results) == 3 # First result should be the exact match assert results[0].similarity is not None @@ -354,7 +354,7 @@ def test_search_uses_vector_store(self) -> None: s.append("south", embedding=_emb([0, -1, 0])) s.append("west", embedding=_emb([-1, 0, 0])) - results = s.search(_emb([0, 1, 0]), k=2).fetch() + results = s.search(_emb([0, 1, 0]), k=2).to_list() assert len(results) == 2 assert results[0].data == "north" assert results[0].similarity is not None @@ 
-371,7 +371,7 @@ def test_search_with_filters_via_vector_store(self) -> None: s.append("late", ts=20.0, embedding=_emb([1, 0, 0])) # Filter + search: only "late" passes the after filter - results = s.after(15.0).search(_emb([1, 0, 0]), k=10).fetch() + results = s.after(15.0).search(_emb([1, 0, 0]), k=10).to_list() assert len(results) == 1 assert results[0].data == "late" diff --git a/dimos/memory2/test_materialize.py b/dimos/memory2/test_materialize.py new file mode 100644 index 0000000000..7d8a3178fb --- /dev/null +++ b/dimos/memory2/test_materialize.py @@ -0,0 +1,49 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dimos.memory2.backend import Backend +from dimos.memory2.codecs.pickle import PickleCodec +from dimos.memory2.observationstore.memory import ListObservationStore +from dimos.memory2.stream import Stream +from dimos.memory2.type.observation import Observation + + +def make_stream(n: int = 5, start_ts: float = 0.0) -> Stream[int]: + backend: Backend[int] = Backend( + metadata_store=ListObservationStore[int](name="test"), codec=PickleCodec() + ) + for i in range(n): + backend.append(Observation(id=-1, ts=start_ts + i, _data=i * 10)) + return Stream(source=backend) + + +class TestMaterialize: + def test_returns_same_data(self) -> None: + materialized = make_stream(3).materialize() + assert [o.data for o in materialized] == [0, 10, 20] + + def test_replayable(self) -> None: + materialized = make_stream(3).materialize() + first = [o.data for o in materialized] + second = [o.data for o in materialized] + assert first == second == [0, 10, 20] + + def test_with_transform(self) -> None: + materialized = make_stream(3).map(lambda obs: obs.derive(data=obs.data * 2)).materialize() + assert [o.data for o in materialized] == [0, 20, 40] + assert [o.data for o in materialized] == [0, 20, 40] + + def test_queryable(self) -> None: + materialized = make_stream(5, start_ts=0.0).materialize() + assert [o.data for o in materialized.after(2.0)] == [30, 40] diff --git a/dimos/memory2/test_module.py b/dimos/memory2/test_module.py index ad2aa096e4..59bf7da879 100644 --- a/dimos/memory2/test_module.py +++ b/dimos/memory2/test_module.py @@ -45,7 +45,7 @@ def __call__(self, upstream: Iterator[Observation[int]]) -> Iterator[Observation # -- Pipeline styles ------------------------------------------------------- -class StaticStreamModule(StreamModule): +class StaticStreamModule(StreamModule[int, int]): """Pipeline as a static Stream chain on the class.""" pipeline = Stream().transform(Double()) @@ -53,7 +53,7 @@ class StaticStreamModule(StreamModule): doubled: Out[int] -class 
StaticTransformerModule(StreamModule): +class StaticTransformerModule(StreamModule[int, int]): """Pipeline as a bare Transformer on the class.""" pipeline = Double() @@ -65,12 +65,12 @@ class MethodPipelineConfig(ModuleConfig): factor: int = 2 -class MethodPipelineModule(StreamModule): +class MethodPipelineModule(StreamModule[int, int]): """Pipeline as a method with access to self.config.""" config: MethodPipelineConfig - def pipeline(self, stream: Stream) -> Stream: + def pipeline(self, stream: Stream[int]) -> Stream[int]: return stream.transform(Double(factor=self.config.factor)) numbers: In[int] diff --git a/dimos/memory2/test_registry.py b/dimos/memory2/test_registry.py index d611073075..7c1b20239e 100644 --- a/dimos/memory2/test_registry.py +++ b/dimos/memory2/test_registry.py @@ -178,7 +178,7 @@ def test_reopen_preserves_data(self, tmp_path) -> None: with SqliteStore(path=db) as store2: s2 = store2.stream("nums", int) assert s2.count() == 2 - obs = s2.fetch() + obs = s2.to_list() assert [o.data for o in obs] == [42, 99] def test_reopen_preserves_codec(self, tmp_path) -> None: diff --git a/dimos/memory2/test_save.py b/dimos/memory2/test_save.py index 8ebb12082b..eba8905974 100644 --- a/dimos/memory2/test_save.py +++ b/dimos/memory2/test_save.py @@ -59,19 +59,22 @@ def test_save_populates_target(self) -> None: source = make_stream(3) target = Stream(source=_make_backend("target")) - source.save(target) + source.save(target).drain() - results = target.fetch() + results = target.to_list() assert len(results) == 3 assert [o.data for o in results] == [0, 10, 20] - def test_save_returns_target_stream(self) -> None: + def test_save_is_lazy(self) -> None: + """save() returns a lazy passthrough — nothing written until iterated.""" source = make_stream(2) target = Stream(source=_make_backend("target")) - result = source.save(target) + pipeline = source.save(target) - assert result is target + assert target.to_list() == [], "expected target empty before iteration" + 
list(pipeline) # drive iteration + assert len(target.to_list()) == 2 def test_save_preserves_data(self) -> None: backend = _make_backend("src") @@ -79,7 +82,7 @@ def test_save_preserves_data(self) -> None: source = Stream(source=backend) target = Stream(source=_make_backend("dst")) - source.save(target) + source.save(target).drain() obs = target.first() assert obs.data == 42 @@ -92,32 +95,32 @@ def test_save_with_transform(self) -> None: doubled = source.transform(FnTransformer(lambda obs: obs.derive(data=obs.data * 2))) target = Stream(source=_make_backend("target")) - doubled.save(target) + doubled.save(target).drain() - assert [o.data for o in target.fetch()] == [0, 20, 40] + assert [o.data for o in target.to_list()] == [0, 20, 40] def test_save_rejects_transform_target(self) -> None: source = make_stream(2) base = make_stream(2) transform_stream = base.transform(FnTransformer(lambda obs: obs.derive(obs.data))) - with pytest.raises(TypeError, match="Cannot save to a transform"): + with pytest.raises(TypeError, match="Cannot save to"): source.save(transform_stream) def test_save_target_queryable(self) -> None: source = make_stream(5, start_ts=0.0) # ts: 0,1,2,3,4 target = Stream(source=_make_backend("target")) - result = source.save(target) + source.save(target).drain() - after_2 = result.after(2.0).fetch() + after_2 = target.after(2.0).to_list() assert [o.data for o in after_2] == [30, 40] def test_save_empty_source(self) -> None: source = make_stream(0) target = Stream(source=_make_backend("target")) - result = source.save(target) + source.save(target).drain() - assert result.count() == 0 - assert result.fetch() == [] + assert target.count() == 0 + assert target.to_list() == [] diff --git a/dimos/memory2/test_store.py b/dimos/memory2/test_store.py index aa525c8758..5b6d146364 100644 --- a/dimos/memory2/test_store.py +++ b/dimos/memory2/test_store.py @@ -49,7 +49,7 @@ def test_append_multiple_and_fetch(self, session: Store) -> None: s.append(2.0, ts=200.0) 
s.append(3.0, ts=300.0) - results = s.fetch() + results = s.to_list() assert len(results) == 3 assert [o.data for o in results] == [1.0, 2.0, 3.0] @@ -94,7 +94,7 @@ def test_filter_after(self, session: Store) -> None: s.append(2, ts=20.0) s.append(3, ts=30.0) - results = s.after(15.0).fetch() + results = s.after(15.0).to_list() assert [o.data for o in results] == [2, 3] def test_filter_before(self, session: Store) -> None: @@ -103,7 +103,7 @@ def test_filter_before(self, session: Store) -> None: s.append(2, ts=20.0) s.append(3, ts=30.0) - results = s.before(25.0).fetch() + results = s.before(25.0).to_list() assert [o.data for o in results] == [1, 2] def test_filter_time_range(self, session: Store) -> None: @@ -112,7 +112,7 @@ def test_filter_time_range(self, session: Store) -> None: s.append(2, ts=20.0) s.append(3, ts=30.0) - results = s.time_range(15.0, 25.0).fetch() + results = s.time_range(15.0, 25.0).to_list() assert [o.data for o in results] == [2] def test_filter_tags(self, session: Store) -> None: @@ -121,7 +121,7 @@ def test_filter_tags(self, session: Store) -> None: s.append("b", tags={"kind": "error"}) s.append("c", tags={"kind": "info"}) - results = s.tags(kind="info").fetch() + results = s.tags(kind="info").to_list() assert [o.data for o in results] == ["a", "c"] def test_limit_and_offset(self, session: Store) -> None: @@ -129,7 +129,7 @@ def test_limit_and_offset(self, session: Store) -> None: for i in range(5): s.append(i, ts=float(i)) - page = s.offset(1).limit(2).fetch() + page = s.offset(1).limit(2).to_list() assert [o.data for o in page] == [1, 2] def test_order_by_desc(self, session: Store) -> None: @@ -138,7 +138,7 @@ def test_order_by_desc(self, session: Store) -> None: s.append(2, ts=20.0) s.append(3, ts=30.0) - results = s.order_by("ts", desc=True).fetch() + results = s.order_by("ts", desc=True).to_list() assert [o.data for o in results] == [3, 2, 1] def test_separate_streams_isolated(self, session: Store) -> None: @@ -182,7 +182,7 @@ def 
_emb(v: list[float]) -> Embedding: s.append("east", embedding=_emb([1, 0, 0])) s.append("south", embedding=_emb([0, -1, 0])) - results = s.search(_emb([0, 1, 0]), k=2).fetch() + results = s.search(_emb([0, 1, 0]), k=2).to_list() assert len(results) == 2 assert results[0].data == "north" assert results[0].similarity > 0.99 @@ -194,7 +194,7 @@ def test_search_text(self, session: Store) -> None: # SqliteObservationStore blocks search_text to prevent full table scans try: - results = s.search_text("motor").fetch() + results = s.search_text("motor").to_list() except NotImplementedError: pytest.skip("search_text not supported on this backend") assert len(results) == 1 @@ -243,11 +243,11 @@ def test_sqlite_lazy_and_eager_same_values(self, sqlite_store: Store) -> None: lazy_s.append("beta", ts=2.0, tags={"k": "w"}) # Lazy read - lazy_results = lazy_s.fetch() + lazy_results = lazy_s.to_list() # Eager read — new stream handle with eager_blobs on same backend eager_s = sqlite_store.stream("vals", str, eager_blobs=True) - eager_results = eager_s.fetch() + eager_results = eager_s.to_list() assert [o.data for o in lazy_results] == [o.data for o in eager_results] assert [o.tags for o in lazy_results] == [o.tags for o in eager_results] @@ -412,7 +412,7 @@ def _emb(v: list[float]) -> Embedding: s.append("north", ts=1.0, embedding=_emb([0, 1, 0])) s.append("east", ts=2.0, embedding=_emb([1, 0, 0])) - results = s.search(_emb([0, 1, 0]), k=2).fetch() + results = s.search(_emb([0, 1, 0]), k=2).to_list() assert len(vec_spy.searches) == 1 assert results[0].data == "north" diff --git a/dimos/memory2/test_stream.py b/dimos/memory2/test_stream.py index e53cd15d9f..f0aaf70137 100644 --- a/dimos/memory2/test_stream.py +++ b/dimos/memory2/test_stream.py @@ -68,7 +68,7 @@ def test_empty_stream(self, make_stream): assert list(stream) == [] def test_fetch_materializes_to_list(self, make_stream): - result = make_stream(3).fetch() + result = make_stream(3).to_list() assert isinstance(result, list) 
assert len(result) == 3 @@ -85,27 +85,27 @@ class TestTemporalFilters: def test_after(self, make_stream): """.after(t) keeps observations with ts > t.""" - result = make_stream(5).after(2.0).fetch() + result = make_stream(5).after(2.0).to_list() assert [o.ts for o in result] == [3.0, 4.0] def test_before(self, make_stream): """.before(t) keeps observations with ts < t.""" - result = make_stream(5).before(2.0).fetch() + result = make_stream(5).before(2.0).to_list() assert [o.ts for o in result] == [0.0, 1.0] def test_time_range(self, make_stream): """.time_range(t1, t2) keeps t1 <= ts <= t2.""" - result = make_stream(5).time_range(1.0, 3.0).fetch() + result = make_stream(5).time_range(1.0, 3.0).to_list() assert [o.ts for o in result] == [1.0, 2.0, 3.0] def test_at_with_tolerance(self, make_stream): """.at(t, tolerance) keeps observations within tolerance of t.""" - result = make_stream(5).at(2.0, tolerance=0.5).fetch() + result = make_stream(5).at(2.0, tolerance=0.5).to_list() assert [o.ts for o in result] == [2.0] def test_chained_temporal_filters(self, make_stream): """Filters compose — each narrows the result.""" - result = make_stream(10).after(2.0).before(7.0).fetch() + result = make_stream(10).after(2.0).before(7.0).to_list() assert [o.ts for o in result] == [3.0, 4.0, 5.0, 6.0] @@ -118,7 +118,7 @@ def test_near_with_tuples(self, memory_session): stream.append("close", ts=1.0, pose=(1, 1, 0)) stream.append("far", ts=2.0, pose=(10, 10, 10)) - result = stream.near((0, 0, 0), radius=2.0).fetch() + result = stream.near((0, 0, 0), radius=2.0).to_list() assert [o.data for o in result] == ["origin", "close"] def test_near_excludes_no_pose(self, memory_session): @@ -126,7 +126,7 @@ def test_near_excludes_no_pose(self, memory_session): stream.append("no_pose", ts=0.0) stream.append("has_pose", ts=1.0, pose=(0, 0, 0)) - result = stream.near((0, 0, 0), radius=10.0).fetch() + result = stream.near((0, 0, 0), radius=10.0).to_list() assert [o.data for o in result] == 
["has_pose"] @@ -139,7 +139,7 @@ def test_filter_by_tag(self, memory_session): stream.append("car", ts=1.0, tags={"type": "vehicle", "wheels": 4}) stream.append("dog", ts=2.0, tags={"type": "animal", "legs": 4}) - result = stream.tags(type="animal").fetch() + result = stream.tags(type="animal").to_list() assert [o.data for o in result] == ["cat", "dog"] def test_filter_multiple_tags(self, memory_session): @@ -147,25 +147,25 @@ def test_filter_multiple_tags(self, memory_session): stream.append("a", ts=0.0, tags={"x": 1, "y": 2}) stream.append("b", ts=1.0, tags={"x": 1, "y": 3}) - result = stream.tags(x=1, y=2).fetch() + result = stream.tags(x=1, y=2).to_list() assert [o.data for o in result] == ["a"] class TestOrderLimitOffset: def test_limit(self, make_stream): - result = make_stream(10).limit(3).fetch() + result = make_stream(10).limit(3).to_list() assert len(result) == 3 def test_offset(self, make_stream): - result = make_stream(5).offset(2).fetch() + result = make_stream(5).offset(2).to_list() assert [o.data for o in result] == [20, 30, 40] def test_limit_and_offset(self, make_stream): - result = make_stream(10).offset(2).limit(3).fetch() + result = make_stream(10).offset(2).limit(3).to_list() assert [o.data for o in result] == [20, 30, 40] def test_order_by_ts_desc(self, make_stream): - result = make_stream(5).order_by("ts", desc=True).fetch() + result = make_stream(5).order_by("ts", desc=True).to_list() assert [o.ts for o in result] == [4.0, 3.0, 2.0, 1.0, 0.0] def test_first(self, make_stream): @@ -200,21 +200,21 @@ class TestFunctionalAPI: def test_filter_with_predicate(self, make_stream): """.filter() takes a predicate on the full Observation.""" - result = make_stream(5).filter(lambda obs: obs.data > 20).fetch() + result = make_stream(5).filter(lambda obs: obs.data > 20).to_list() assert [o.data for o in result] == [30, 40] def test_filter_on_metadata(self, make_stream): """Predicates can access ts, tags, pose — not just data.""" - result = 
make_stream(5).filter(lambda obs: obs.ts % 2 == 0).fetch() + result = make_stream(5).filter(lambda obs: obs.ts % 2 == 0).to_list() assert [o.ts for o in result] == [0.0, 2.0, 4.0] def test_map(self, make_stream): """.map() transforms each observation's data.""" - result = make_stream(3).map(lambda obs: obs.derive(data=obs.data * 2)).fetch() + result = make_stream(3).map(lambda obs: obs.derive(data=obs.data * 2)).to_list() assert [o.data for o in result] == [0, 20, 40] def test_map_preserves_ts(self, make_stream): - result = make_stream(3).map(lambda obs: obs.derive(data=str(obs.data))).fetch() + result = make_stream(3).map(lambda obs: obs.derive(data=str(obs.data))).to_list() assert [o.ts for o in result] == [0.0, 1.0, 2.0] assert [o.data for o in result] == ["0", "10", "20"] @@ -224,7 +224,7 @@ class TestTransformChaining: def test_single_transform(self, make_stream): xf = FnTransformer(lambda obs: obs.derive(data=obs.data + 1)) - result = make_stream(3).transform(xf).fetch() + result = make_stream(3).transform(xf).to_list() assert [o.data for o in result] == [1, 11, 21] def test_chained_transforms(self, make_stream): @@ -232,14 +232,14 @@ def test_chained_transforms(self, make_stream): add_one = FnTransformer(lambda obs: obs.derive(data=obs.data + 1)) double = FnTransformer(lambda obs: obs.derive(data=obs.data * 2)) - result = make_stream(3).transform(add_one).transform(double).fetch() + result = make_stream(3).transform(add_one).transform(double).to_list() # (0+1)*2=2, (10+1)*2=22, (20+1)*2=42 assert [o.data for o in result] == [2, 22, 42] def test_transform_can_skip(self, make_stream): """Returning None from a transformer skips that observation.""" keep_even = FnTransformer(lambda obs: obs if obs.data % 20 == 0 else None) - result = make_stream(5).transform(keep_even).fetch() + result = make_stream(5).transform(keep_even).to_list() assert [o.data for o in result] == [0, 20, 40] def test_transform_filter_transform(self, memory_session): @@ -256,7 +256,7 @@ def 
test_transform_filter_transform(self, memory_session): stream.transform(add_ten) # 11, 12, 13 .near((0, 0, 0), 5.0) # keeps pose at (0,0,0) and (1,0,0) .transform(double) # 22, 26 - .fetch() + .to_list() ) assert [o.data for o in result] == [22, 26] @@ -267,7 +267,7 @@ def double_all(upstream): for obs in upstream: yield obs.derive(data=obs.data * 2) - result = make_stream(3).transform(double_all).fetch() + result = make_stream(3).transform(double_all).to_list() assert [o.data for o in result] == [0, 20, 40] def test_generator_function_stateful(self, make_stream): @@ -279,7 +279,7 @@ def running_sum(upstream): total += obs.data yield obs.derive(data=total) - result = make_stream(3).transform(running_sum).fetch() + result = make_stream(3).transform(running_sum).to_list() # 0, 0+10=10, 10+20=30 assert [o.data for o in result] == [0, 10, 30] @@ -297,7 +297,7 @@ def test_quality_window(self, memory_session): stream.append(0.6, ts=2.2) xf = QualityWindow(quality_fn=lambda v: v, window=1.0) - result = stream.transform(xf).fetch() + result = stream.transform(xf).to_list() assert [o.data for o in result] == [0.9, 0.8, 0.6] def test_streaming_not_buffering(self, make_stream): @@ -310,7 +310,7 @@ def __call__(self, upstream): calls.append(obs.data) yield obs - result = make_stream(100).transform(CountingXf()).limit(3).fetch() + result = make_stream(100).transform(CountingXf()).limit(3).to_list() assert len(result) == 3 # The transformer should have processed at most a few more than 3 # (not all 100) due to lazy evaluation @@ -341,7 +341,7 @@ def __call__(self, upstream): stream.append(5) stream.append(10) - result = stream.chain(pipeline).fetch() + result = stream.chain(pipeline).to_list() assert [obs.data for obs in result] == [11, 21] def test_iteration_raises(self) -> None: @@ -365,7 +365,7 @@ def __call__(self, upstream): yield obs.derive(data=obs.data * 2) pipeline = Stream().transform(Double()) - result = stream.chain(pipeline).fetch() + result = 
stream.chain(pipeline).to_list() assert [obs.data for obs in result] == [20, 40, 60] def test_chain_multiple_transforms(self) -> None: @@ -386,7 +386,7 @@ def __call__(self, upstream): yield obs.derive(data=obs.data + 10) pipeline = Stream().transform(Double()).transform(AddTen()) - result = stream.chain(pipeline).fetch() + result = stream.chain(pipeline).to_list() assert result[0].data == 20 # (5 * 2) + 10 def test_chain_preserves_filters(self) -> None: @@ -399,7 +399,7 @@ def test_chain_preserves_filters(self) -> None: stream.append(30, ts=3.0) pipeline = Stream().after(1.5) - result = stream.chain(pipeline).fetch() + result = stream.chain(pipeline).to_list() assert [obs.data for obs in result] == [20, 30] def test_chain_rejects_bound_stream(self) -> None: @@ -690,7 +690,7 @@ def test_fetch_on_live_without_limit_raises(self, memory_session): live = stream.live(buffer=Unbounded()) with pytest.raises(TypeError, match="block forever"): - live.fetch() + live.to_list() def test_fetch_on_live_transform_without_limit_raises(self, memory_session): """fetch() on a live transform without limit() raises TypeError.""" @@ -699,7 +699,7 @@ def test_fetch_on_live_transform_without_limit_raises(self, memory_session): live_xf = stream.live(buffer=Unbounded()).transform(xf) with pytest.raises(TypeError, match="block forever"): - live_xf.fetch() + live_xf.to_list() def test_count_on_live_transform_raises(self, memory_session): """count() on a live transform stream raises TypeError.""" diff --git a/dimos/memory2/transform.py b/dimos/memory2/transform.py index 5754ac36e3..02c7dad9c8 100644 --- a/dimos/memory2/transform.py +++ b/dimos/memory2/transform.py @@ -15,8 +15,10 @@ from __future__ import annotations from abc import ABC, abstractmethod +import collections import inspect -from typing import TYPE_CHECKING, Any, Generic, TypeVar +import math +from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast from dimos.memory2.utils.formatting import FilterRepr @@ -105,17 
+107,272 @@ def __call__(self, upstream: Iterator[Observation[T]]) -> Iterator[Observation[R yield o.derive(data=r) -def stride(n: int) -> FnIterTransformer[T, T]: +def downsample(n: int) -> FnIterTransformer[T, T]: """Yield every *n*-th observation, skipping the rest.""" if n < 1: - raise ValueError(f"stride(n) requires n >= 1, got {n}") + raise ValueError(f"downsample(n) requires n >= 1, got {n}") - def _stride(upstream: Iterator[Observation[T]]) -> Iterator[Observation[T]]: + def _downsample(upstream: Iterator[Observation[T]]) -> Iterator[Observation[T]]: for i, obs in enumerate(upstream): if i % n == 0: yield obs - return FnIterTransformer(_stride) + return FnIterTransformer(_downsample) + + +def throttle(interval: float) -> FnIterTransformer[T, T]: + """Yield at most one observation per *interval* seconds.""" + if interval <= 0: + raise ValueError(f"throttle(interval) requires interval > 0, got {interval}") + + def _throttle(upstream: Iterator[Observation[T]]) -> Iterator[Observation[T]]: + last_ts: float | None = None + for obs in upstream: + if last_ts is None or obs.ts - last_ts >= interval: + last_ts = obs.ts + yield obs + + return FnIterTransformer(_throttle) + + +def speed() -> FnIterTransformer[Any, float]: + """Compute speed (m/s) between consecutive observations from their poses.""" + + def _speed(upstream: Iterator[Observation[Any]]) -> Iterator[Observation[float]]: + prev: Observation[Any] | None = None + for obs in upstream: + if prev is not None and obs.pose is not None and prev.pose is not None: + dx = obs.pose[0] - prev.pose[0] + dy = obs.pose[1] - prev.pose[1] + dz = obs.pose[2] - prev.pose[2] + dt = obs.ts - prev.ts + v = math.sqrt(dx * dx + dy * dy + dz * dz) / dt if dt > 0 else 0.0 + yield obs.derive(data=v) + prev = obs + + return FnIterTransformer(_speed) + + +def smooth(window: int) -> FnIterTransformer[float, float]: + """Sliding window average over obs.data (must be numeric).""" + + def _smooth(upstream: Iterator[Observation[float]]) -> 
Iterator[Observation[float]]: + buf: collections.deque[float] = collections.deque(maxlen=window) + for obs in upstream: + buf.append(obs.data) + yield obs.derive(data=sum(buf) / len(buf)) + + return FnIterTransformer(_smooth) + + +def peaks( + prominence: float = 0.02, + distance: float = 5.0, + width: float | None = 0.5, + key: Callable[[Observation[T]], float] | None = None, +) -> FnIterTransformer[T, T]: + """Yield only the local-maximum observations, gated by peak shape. + + Runs scipy.signal.find_peaks on a scalar extracted from each observation + and emits the qualifying observations in timestamp order. Each yielded + observation gets its peak's prominence stashed on ``tags["peak_prominence"]``. + + All parameters are in the natural units of the stream (seconds and + data-range units), not sample counts. Time-based parameters are + converted to sample counts internally using the median sample spacing. + + - ``prominence``: minimum topological prominence to keep. Assumes the + upstream data is roughly normalized to [0, 1]; with the default 0.02 a peak + has to stick up at least 2% of the range above its surroundings. + Pass 0.0 to return *every* local maximum with its prominence attached + — useful for plotting the distribution and picking a threshold. + - ``distance``: minimum time in seconds between detected peaks. + - ``width``: minimum peak width in seconds at 50% prominence. Filters + sub-second noise spikes. Pass ``None`` to disable. + - ``key``: callable that extracts the scalar signal from an observation. + Defaults to ``obs.data``. Use this when ``obs.data`` isn't the scalar + you want to detect peaks on (e.g. image observations with a + ``similarity`` metadata field). 
+ """ + from scipy.signal import find_peaks + + key_fn: Callable[[Observation[T]], float] = ( + key if key is not None else cast("Callable[[Observation[T]], float]", lambda obs: obs.data) + ) + + def _peaks(upstream: Iterator[Observation[T]]) -> Iterator[Observation[T]]: + items = list(upstream) + if len(items) < 3: + return + values = [key_fn(obs) for obs in items] + + # Median sample spacing — used to convert seconds → samples + # consistently for both `distance` and `width`. + spacings = sorted(items[i + 1].ts - items[i].ts for i in range(len(items) - 1)) + median_spacing = spacings[len(spacings) // 2] if spacings else 0.0 + + def seconds_to_samples(seconds: float | None) -> int | None: + if seconds is None or median_spacing <= 0: + return None + return max(1, round(seconds / median_spacing)) + + # Always pass a numeric `prominence` so scipy populates props["prominences"]. + # Passing None would skip the computation, leaving tags empty. + idx, props = find_peaks( + values, + prominence=prominence, + distance=seconds_to_samples(distance), + width=seconds_to_samples(width), + ) + proms = props["prominences"] + + for i, prom in zip(idx, proms, strict=True): + yield items[int(i)].tag(peak_prominence=float(prom)) + + return FnIterTransformer(_peaks) + + +def _median(sorted_vals: list[float]) -> float: + n = len(sorted_vals) + return sorted_vals[n // 2] if n % 2 else 0.5 * (sorted_vals[n // 2 - 1] + sorted_vals[n // 2]) + + +def _mad_threshold(values: list[float], k: float) -> tuple[float, float, float]: + """Returns (threshold, median, scale) where scale = MAD * 1.4826.""" + median = _median(sorted(values)) + scale = _median(sorted(abs(v - median) for v in values)) * 1.4826 + return median + k * scale, median, scale + + +def _otsu_threshold(values: list[float]) -> float: + """1D Otsu threshold: maximizes between-class variance over the value list.""" + sorted_vals = sorted(values) + n = len(sorted_vals) + total = sum(sorted_vals) + best_var, best_thresh = -1.0, 
sorted_vals[-1] + cum = 0.0 + for i in range(n - 1): + cum += sorted_vals[i] + count = i + 1 + w0, w1 = count / n, (n - count) / n + m0, m1 = cum / count, (total - cum) / (n - count) + var = w0 * w1 * (m0 - m1) ** 2 + if var > best_var: + best_var, best_thresh = var, 0.5 * (sorted_vals[i] + sorted_vals[i + 1]) + return best_thresh + + +def _gap_threshold(values: list[float]) -> float: + """Largest log-ratio gap between consecutive sorted values.""" + sorted_vals = sorted(v for v in values if v > 0) + n = len(sorted_vals) + if n < 2: + return sorted(values)[len(values) // 2] if values else 0.0 + best_ratio, best_idx = 0.0, n - 1 + for i in range(n - 1): + ratio = sorted_vals[i + 1] / sorted_vals[i] + if ratio > best_ratio: + best_ratio, best_idx = ratio, i + return 0.5 * (sorted_vals[best_idx] + sorted_vals[best_idx + 1]) + + +def significant( + method: Literal["mad", "otsu", "gap"] = "mad", + k: float = 3.0, + tag: str = "peak_prominence", +) -> FnIterTransformer[T, T]: + """Keep observations whose ``tags[tag]`` is an outlier in its own distribution. + + Designed to chain after :func:`peaks` so the cutoff is *derived from the + prominence distribution itself*, invariant to overall signal range. The + upstream :func:`peaks` call still does the shape gating (``distance``, + ``width``, and a small ``prominence`` floor to reject obvious noise); + :func:`significant` then picks a statistical cutoff from what survives. + + Each surviving observation gets ``tags["significance"]`` attached. + + - ``method``: + - ``"mad"``: keep values above ``median + k * 1.4826 * MAD``. Robust + default; assumes most upstream values are noise. ``significance`` + is the resulting (value - median) / scale, i.e. a robust z-score. + - ``"otsu"``: 1D Otsu — picks the threshold maximizing between-class + variance over the value distribution. Parameter-free; works when + the distribution is roughly bimodal. ``significance`` is value / + threshold. 
+ - ``"gap"``: largest ratio gap between consecutive sorted values. + Crisp when peaks are well separated from noise, brittle otherwise + (a single tiny value at the bottom of the list can dominate). + ``significance`` is value / threshold. + - ``k``: only used by ``"mad"`` (≈3 ≙ 3-sigma equivalent). + - ``tag``: which tag holds the scalar to threshold on. Defaults to + ``peak_prominence`` (set by :func:`peaks`). + """ + if method not in ("mad", "otsu", "gap"): + raise ValueError(f"unknown method {method!r}; expected 'mad', 'otsu', or 'gap'") + + def _significant(upstream: Iterator[Observation[T]]) -> Iterator[Observation[T]]: + items = list(upstream) + if len(items) < 2: + return + try: + values = [float(o.tags[tag]) for o in items] + except KeyError as e: + raise ValueError( + f"significant() requires upstream observations to be tagged with {tag!r}; " + f"chain after peaks() or set tag= to a tag that exists" + ) from e + + if method == "mad": + threshold, median, scale = _mad_threshold(values, k) + for obs, val in zip(items, values, strict=True): + if val >= threshold and scale > 0: + yield obs.tag(significance=(val - median) / scale) + else: + threshold = _otsu_threshold(values) if method == "otsu" else _gap_threshold(values) + for obs, val in zip(items, values, strict=True): + if val >= threshold: + yield obs.tag(significance=val / threshold if threshold > 0 else 0.0) + + return FnIterTransformer(_significant) + + +def smooth_time(seconds: float) -> FnIterTransformer[float, float]: + """Sliding window average over obs.data, by time. + + Averages all observations whose timestamp is within ``seconds`` of the + current observation's timestamp. Unlike ``smooth(window)`` (which uses a + fixed sample count and so depends on sampling rate), the effective window + here adapts: dense regions average more samples, sparse regions average + fewer. 
+ """ + if seconds <= 0: + raise ValueError(f"smooth_time(seconds) requires seconds > 0, got {seconds}") + + def _smooth(upstream: Iterator[Observation[float]]) -> Iterator[Observation[float]]: + buf: collections.deque[Observation[float]] = collections.deque() + for obs in upstream: + buf.append(obs) + while buf and obs.ts - buf[0].ts > seconds: + buf.popleft() + yield obs.derive(data=sum(o.data for o in buf) / len(buf)) + + return FnIterTransformer(_smooth) + + +def normalize() -> FnIterTransformer[float, float]: + """Normalize obs.data to [0, 1] range across all observations.""" + + def _normalize(upstream: Iterator[Observation[float]]) -> Iterator[Observation[float]]: + items = list(upstream) + if not items: + return + values = [obs.data for obs in items] + lo, hi = min(values), max(values) + for obs in items: + t = (obs.data - lo) / (hi - lo) if hi != lo else 0.5 + yield obs.derive(data=t) + + return FnIterTransformer(_normalize) class QualityWindow(Transformer[T, T]): diff --git a/dimos/memory2/type/observation.py b/dimos/memory2/type/observation.py index 3efefc0220..63b10a6dbf 100644 --- a/dimos/memory2/type/observation.py +++ b/dimos/memory2/type/observation.py @@ -14,9 +14,15 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields +import sys import threading -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self if TYPE_CHECKING: from collections.abc import Callable @@ -25,6 +31,7 @@ from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped T = TypeVar("T") +R = TypeVar("R") class _Unloaded: @@ -45,6 +52,7 @@ class Observation(Generic[T]): id: int ts: float + data_type: type = object pose: Any | None = None tags: dict[str, Any] = field(default_factory=dict) _data: T | _Unloaded = field(default=_UNLOADED, 
repr=False) @@ -58,7 +66,8 @@ def pose_stamped(self) -> PoseStamped: if self.pose is None: raise LookupError("No pose set on this observation") x, y, z, qx, qy, qz, qw = self.pose - return PoseStamped(ts=self.ts, position=(x, y, z), orientation=(qx, qy, qz, qw)) + ps: PoseStamped = PoseStamped(ts=self.ts, position=(x, y, z), orientation=(qx, qy, qz, qw)) + return ps @property def data(self) -> T: @@ -77,29 +86,32 @@ def data(self) -> T: return val return val - def derive(self, *, data: Any, **overrides: Any) -> Observation[Any]: - """Create a new observation preserving ts/pose/tags, replacing data. + def derive(self, *, data: R, **overrides: Any) -> Observation[R]: + """New observation with replaced ``data``; other fields carry over. - If ``embedding`` is passed, promotes the result to + Passing ``embedding`` on a plain :class:`Observation` promotes it to :class:`EmbeddedObservation`. """ - if "embedding" in overrides: - return EmbeddedObservation( - id=self.id, - ts=overrides.get("ts", self.ts), - pose=overrides.get("pose", self.pose), - tags=overrides.get("tags", self.tags), - _data=data, - embedding=overrides["embedding"], - similarity=overrides.get("similarity"), - ) - return Observation( - id=self.id, - ts=overrides.get("ts", self.ts), - pose=overrides.get("pose", self.pose), - tags=overrides.get("tags", self.tags), - _data=data, + cls: type[Observation[Any]] = ( + EmbeddedObservation + if "embedding" in overrides and not isinstance(self, EmbeddedObservation) + else type(self) + ) + kwargs: dict[str, Any] = {f.name: getattr(self, f.name) for f in fields(self)} + kwargs.update(overrides) + kwargs.update(data_type=type(data), _data=data, _loader=None, _data_lock=threading.Lock()) + return cast("Observation[R]", cls(**kwargs)) + + def tag(self, **tags: Any) -> Self: + """Return a new observation with tags merged in.""" + kwargs: dict[str, Any] = {f.name: getattr(self, f.name) for f in fields(self)} + kwargs.update( + tags={**self.tags, **tags}, + _data=_UNLOADED, 
+ _loader=lambda: self.data, + _data_lock=threading.Lock(), ) + return type(self)(**kwargs) @dataclass @@ -108,15 +120,3 @@ class EmbeddedObservation(Observation[T]): embedding: Embedding | None = None similarity: float | None = None - - def derive(self, *, data: Any, **overrides: Any) -> EmbeddedObservation[Any]: - """Preserve embedding unless explicitly replaced.""" - return EmbeddedObservation( - id=self.id, - ts=overrides.get("ts", self.ts), - pose=overrides.get("pose", self.pose), - tags=overrides.get("tags", self.tags), - _data=data, - embedding=overrides.get("embedding", self.embedding), - similarity=overrides.get("similarity", self.similarity), - ) diff --git a/dimos/memory2/vectorstore/base.py b/dimos/memory2/vectorstore/base.py index 7069c62312..d76daaa5be 100644 --- a/dimos/memory2/vectorstore/base.py +++ b/dimos/memory2/vectorstore/base.py @@ -52,8 +52,12 @@ def put(self, stream_name: str, key: int, embedding: Embedding) -> None: ... @abstractmethod - def search(self, stream_name: str, query: Embedding, k: int) -> list[tuple[int, float]]: - """Return top-k (observation_id, similarity) pairs, descending.""" + def search(self, stream_name: str, query: Embedding, k: int | None) -> list[tuple[int, float]]: + """Return top-k (observation_id, similarity) pairs, descending. + + ``k=None`` means "store default" — backends that require a bound + (e.g. sqlite-vec) pick one internally; unbounded backends return all. + """ ... 
@abstractmethod diff --git a/dimos/memory2/vectorstore/memory.py b/dimos/memory2/vectorstore/memory.py index a776a26986..5d3d5c5d6b 100644 --- a/dimos/memory2/vectorstore/memory.py +++ b/dimos/memory2/vectorstore/memory.py @@ -48,13 +48,13 @@ def stop(self) -> None: def put(self, stream_name: str, key: int, embedding: Embedding) -> None: self._vectors.setdefault(stream_name, {})[key] = embedding - def search(self, stream_name: str, query: Embedding, k: int) -> list[tuple[int, float]]: + def search(self, stream_name: str, query: Embedding, k: int | None) -> list[tuple[int, float]]: vectors = self._vectors.get(stream_name, {}) if not vectors: return [] scored = [(key, float(emb @ query)) for key, emb in vectors.items()] scored.sort(key=lambda x: x[1], reverse=True) - return scored[:k] + return scored if k is None else scored[:k] def delete(self, stream_name: str, key: int) -> None: vectors = self._vectors.get(stream_name, {}) diff --git a/dimos/memory2/vectorstore/sqlite.py b/dimos/memory2/vectorstore/sqlite.py index f362a7eb3f..bb5e9d200e 100644 --- a/dimos/memory2/vectorstore/sqlite.py +++ b/dimos/memory2/vectorstore/sqlite.py @@ -85,13 +85,17 @@ def put(self, stream_name: str, key: int, embedding: Embedding) -> None: (key, json.dumps(vec)), ) - def search(self, stream_name: str, query: Embedding, k: int) -> list[tuple[int, float]]: + # vec0 virtual tables require an explicit k in the MATCH clause; apply a + # default cap when the caller did not specify one. + _DEFAULT_K = 4096 + + def search(self, stream_name: str, query: Embedding, k: int | None) -> list[tuple[int, float]]: validate_identifier(stream_name) vec = query.to_numpy().tolist() try: rows = self._conn.execute( f'SELECT rowid, distance FROM "{stream_name}_vec" WHERE embedding MATCH ? 
AND k = ?', - (json.dumps(vec), k), + (json.dumps(vec), k if k is not None else self._DEFAULT_K), ).fetchall() except sqlite3.OperationalError as e: if "no such table" in str(e): diff --git a/dimos/memory2/vis/color.py b/dimos/memory2/vis/color.py new file mode 100644 index 0000000000..c1897f4ccf --- /dev/null +++ b/dimos/memory2/vis/color.py @@ -0,0 +1,271 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Color type and utilities for memory2 visualization. + +Canonical storage is RGBA float 0-1 (matplotlib-native, easy to blend). All +consumers convert at the boundary: + + matplotlib: Color.rgba_f() → (r, g, b, a) + SVG: Color.hex() + a → "#rrggbb" + opacity="..." + rerun: Color.rgb_u8() → (r, g, b) u8 + PIL: Color.rgba_u8() → (r, g, b, a) u8 + +Use ``Color.from_hex("#3498db")`` or a palette name (``"red"``, ``"blue"``, …) +to construct. Use ``ColorRange(cmap)`` + ``range(value)`` for cmap-deferred +colors whose min/max is learned as you add elements. +""" + +from __future__ import annotations + +from collections.abc import Iterable, Iterator +import colorsys +from dataclasses import dataclass +import functools +from typing import Any + + +@functools.lru_cache(maxsize=16) +def _cmap(name: str) -> Any: + import matplotlib.pyplot as plt + + return plt.get_cmap(name) + + +@dataclass(frozen=True, eq=False) +class Color: + """Concrete RGBA color, stored as floats in [0, 1]. 
+ + Immutable — all manipulation methods return a new instance. + """ + + r: float + g: float + b: float + a: float = 1.0 + + @classmethod + def from_hex(cls, s: str) -> Color: + """Parse ``#rrggbb`` / ``#rgb`` or a palette name (``"red"``, …).""" + if not s.startswith("#"): + try: + return _PALETTE_NAMES[s.lower()] + except KeyError: + raise ValueError( + f"Unknown color name {s!r}. Use a hex string (#rrggbb) " + f"or one of the palette names: {sorted(_PALETTE_NAMES)}" + ) from None + h = s.lstrip("#") + if len(h) == 3: + h = "".join(c * 2 for c in h) + if len(h) != 6: + raise ValueError(f"Invalid hex color {s!r}") + return cls(int(h[0:2], 16) / 255, int(h[2:4], 16) / 255, int(h[4:6], 16) / 255) + + @classmethod + def from_cmap(cls, cmap: str, t: float) -> Color: + """Sample a matplotlib colormap at ``t`` (clamped to [0, 1]).""" + r, g, b, a = _cmap(cmap)(max(0.0, min(1.0, t))) + return cls(r, g, b, a) + + @classmethod + def coerce(cls, x: Color | DeferredColor | str) -> Color: + """Normalize any accepted color form to a concrete ``Color``.""" + if isinstance(x, Color): + return x + if isinstance(x, DeferredColor): + return x.resolve() + return cls.from_hex(x) + + def hex(self) -> str: + """``#rrggbb`` — alpha is dropped; SVG pairs this with an ``opacity`` attribute.""" + return f"#{round(self.r * 255):02x}{round(self.g * 255):02x}{round(self.b * 255):02x}" + + def rgb_u8(self) -> tuple[int, int, int]: + return (round(self.r * 255), round(self.g * 255), round(self.b * 255)) + + def rgba_u8(self) -> tuple[int, int, int, int]: + return (*self.rgb_u8(), round(self.a * 255)) + + def rgba_f(self) -> tuple[float, float, float, float]: + return (self.r, self.g, self.b, self.a) + + def with_alpha(self, a: float) -> Color: + return Color(self.r, self.g, self.b, a) + + def blend(self, other: Color, t: float) -> Color: + """Linear RGB blend: ``t=0`` → self, ``t=1`` → other.""" + t = max(0.0, min(1.0, t)) + return Color( + self.r + (other.r - self.r) * t, + self.g + (other.g 
- self.g) * t, + self.b + (other.b - self.b) * t, + self.a + (other.a - self.a) * t, + ) + + def __str__(self) -> str: + return self.hex() + + def __eq__(self, other: object) -> bool: + if isinstance(other, Color): + return (self.r, self.g, self.b, self.a) == (other.r, other.g, other.b, other.a) + if isinstance(other, str): + try: + return self.hex() == Color.from_hex(other).hex() + except ValueError: + return False + return NotImplemented + + def __hash__(self) -> int: + return hash((self.r, self.g, self.b, self.a)) + + +class ColorRange: + """Tracks value min/max as you call it; returns :class:`DeferredColor` instances. + + Each ``ColorRange`` is its own aggregator — no cross-plot bleed, no global + registry. The returned ``DeferredColor`` resolves to a cmap-sampled + :class:`Color` at render time using the final min/max. + """ + + def __init__(self, cmap: str = "turbo") -> None: + self.cmap = cmap + self._lo: float | None = None + self._hi: float | None = None + + def __call__(self, value: float) -> DeferredColor: + self._lo = value if self._lo is None else min(self._lo, value) + self._hi = value if self._hi is None else max(self._hi, value) + return DeferredColor(self, value) + + +@dataclass(frozen=True) +class DeferredColor: + """A value tagged with a :class:`ColorRange`; resolves lazily to :class:`Color`.""" + + range: ColorRange + value: float + + def resolve(self) -> Color: + lo, hi = self.range._lo, self.range._hi + t = 0.5 if lo is None or hi is None or lo == hi else (self.value - lo) / (hi - lo) + return Color.from_cmap(self.range.cmap, t) + + def __str__(self) -> str: + return self.resolve().hex() + + +def resolve_deferred(elements: Iterable[Any]) -> None: + """Mutate ``el.color`` from :class:`DeferredColor` → :class:`Color` for each element.""" + for el in elements: + c = getattr(el, "color", None) + if isinstance(c, DeferredColor): + el.color = c.resolve() + + +# Named palette: 12 visually-distinct colors that share visual weight. 
+# +# Indices 0..5 are hand-curated flat-UI colors that match the defaults in +# vis/space/elements.py so a Plot embedded next to a Space drawing reads as +# the same family. Indices 6..11 are generated by gap-subdivision in HSL +# space (preserving the curated set's average L≈0.51 / S≈0.72) so they fill +# the largest hue gaps between the curated colors. Together they form 12 +# maximally-distinct hues. Beyond 12, `palette_iter` continues with a +# golden-angle hue walk that uses the same average L/S. + +blue = Color.from_hex("#3498db") +red = Color.from_hex("#e74c3c") +yellow = Color.from_hex("#f1c40f") +teal = Color.from_hex("#1abc9c") +purple = Color.from_hex("#9b59b6") +orange = Color.from_hex("#e67e22") +green = Color.from_hex("#4cdc29") +magenta = Color.from_hex("#dc2994") +indigo = Color.from_hex("#3329dc") +cyan = Color.from_hex("#29c9dc") +vermilion = Color.from_hex("#dc5b29") +amber = Color.from_hex("#dc9a29") + +PALETTE: list[Color] = [ + blue, + red, + yellow, + teal, + purple, + orange, + green, + magenta, + indigo, + cyan, + vermilion, + amber, +] + +_PALETTE_NAMES: dict[str, Color] = { + "blue": blue, + "red": red, + "yellow": yellow, + "teal": teal, + "purple": purple, + "orange": orange, + "green": green, + "magenta": magenta, + "indigo": indigo, + "cyan": cyan, + "vermilion": vermilion, + "amber": amber, +} + + +def palette_iter( + palette: list[Color] = PALETTE, + exclude: Iterable[Color | str] | None = None, +) -> Iterator[Color]: + """Yield colors forever for auto-assigning Series/Markers. + + Yields ``palette`` in order, then continues indefinitely via a + golden-angle (137.5°) hue walk anchored at the average L/S of ``palette`` + so generated colors share visual weight with the named ones. + + ``exclude`` skips already-pinned colors; accepts ``Color`` or hex/name strings. 
+ """ + excluded: set[str] = set() + for x in exclude or (): + try: + excluded.add(Color.coerce(x).hex()) + except ValueError: + pass # unknown string — nothing to exclude + + def emit(c: Color) -> bool: + return c.hex() not in excluded + + for c in palette: + if emit(c): + yield c + + if not palette: + return + + hls = [colorsys.rgb_to_hls(c.r, c.g, c.b) for c in palette] + avg_l = sum(p[1] for p in hls) / len(hls) + avg_s = sum(p[2] for p in hls) / len(hls) + # Anchor the walk on the last palette color's hue so the first + # generated color is offset 137.5° from the end of the named set. + hue = hls[-1][0] + while True: + hue = (hue + 137.5 / 360.0) % 1.0 + r, g, b = colorsys.hls_to_rgb(hue, avg_l, avg_s) + c = Color(r, g, b) + if emit(c): + yield c diff --git a/dimos/memory2/vis/plot/elements.py b/dimos/memory2/vis/plot/elements.py new file mode 100644 index 0000000000..8b5932da53 --- /dev/null +++ b/dimos/memory2/vis/plot/elements.py @@ -0,0 +1,105 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Element types for Plot (2D charts).""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum +from typing import Union + + +class Style(StrEnum): + """Line style for Series and HLine elements. + + Values match matplotlib's `linestyle` names so they pass through directly + to the renderer without translation. 
+ """ + + solid = "solid" + dashed = "dashed" + dotted = "dotted" + + +@dataclass +class Series: + """Line connecting (t, y) points. + + ``connect`` is the maximum gap (in x-axis units, typically seconds) over + which the renderer will draw a connecting line. Samples whose neighbors + are further apart get visually separated — useful when a stream has + holes (e.g. an embedding stream that skipped dark frames). Set to + ``None`` to always connect regardless of gap size. + + ``gap_fill`` controls what happens at a gap. ``None`` (default) breaks + the line entirely. A float value drops the line to that value across + the gap, producing a "valley" — set ``gap_fill=0.0`` to render holes as + drops to zero. + """ + + ts: list[float] + values: list[float] + color: str | None = None + width: float = 1.5 + label: str | None = None + axis: str | None = None + opacity: float = 1.0 + style: Style = Style.solid + connect: float | None = 2.0 + gap_fill: float | None = None + + +@dataclass +class Markers: + """Scatter dots at (t, y) points.""" + + ts: list[float] + values: list[float] + color: str | None = None + radius: float = 0.5 + label: str | None = None + axis: str | None = None + opacity: float = 1.0 + + +@dataclass +class HLine: + """Horizontal reference line.""" + + y: float + color: str = "#888888" + style: Style = Style.dashed + label: str | None = None + axis: str | None = None + opacity: float = 1.0 + + +@dataclass +class VLine: + """Vertical reference line. + + Always draws on the primary x-axis — twin axes all share the same x, so + there's no need for an ``axis`` field: the line spans the full y range + regardless of which axes owns it. 
+ """ + + x: float + color: str = "#888888" + style: Style = Style.dashed + label: str | None = None + opacity: float = 1.0 + + +PlotElement = Union[Series, Markers, HLine, VLine] diff --git a/dimos/memory2/vis/plot/plot.py b/dimos/memory2/vis/plot/plot.py new file mode 100644 index 0000000000..6235e44bda --- /dev/null +++ b/dimos/memory2/vis/plot/plot.py @@ -0,0 +1,121 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Plot: 2D chart builder for memory2 visualization.""" + +from __future__ import annotations + +from enum import StrEnum +from typing import Any + +from dimos.memory2.vis.plot.elements import HLine, Markers, PlotElement, Series, VLine + + +class TimeAxis(StrEnum): + """How the x-axis is formatted. + + - ``raw``: unix timestamps as-is (matplotlib's default numeric formatter). + - ``relative``: seconds since the first sample, e.g. ``"0s"``, ``"60s"``. + - ``absolute``: wall-clock time from the system timezone, e.g. ``"11:09:11"``. + """ + + raw = "raw" + relative = "relative" + absolute = "absolute" + + +class Plot: + """2D chart. Today treats X as time; generalized axes are follow-up work. 
+ + Elements can be added as: + - Series/Markers/HLine directly + - Stream[float] → materializes, extracts obs.ts/obs.data into Series + - list[Observation[float]] → extracts obs.ts/obs.data into Series + """ + + def __init__(self, time_axis: TimeAxis = TimeAxis.relative) -> None: + self._elements: list[PlotElement] = [] + self.time_axis = time_axis + + def add(self, element: Any, **kwargs: Any) -> Plot: + """Add a plot element with smart dispatch.""" + from dimos.memory2.stream import Stream + from dimos.memory2.type.observation import Observation + + if isinstance(element, (Series, Markers, HLine, VLine)): + self._elements.append(element) + elif isinstance(element, Stream): + self._add_from_observations(element.to_list(), **kwargs) + elif isinstance(element, list) and element and isinstance(element[0], Observation): + self._add_from_observations(element, **kwargs) + elif hasattr(element, "__iter__"): + # Try as iterable of observations + items = list(element) + if items and isinstance(items[0], Observation): + self._add_from_observations(items, **kwargs) + else: + raise TypeError(f"Plot.add() cannot handle iterable of {type(items[0]).__name__}.") + else: + raise TypeError( + f"Plot.add() does not know how to handle {type(element).__name__}. " + f"Pass Series, Markers, HLine, VLine, a Stream, or a list of Observations." + ) + + return self + + def _add_from_observations(self, obs_list: list[Any], **kwargs: Any) -> None: + """Convert observations to a Series (ts → x, data → y).""" + ts = [obs.ts for obs in obs_list] + values = [float(obs.data) for obs in obs_list] + self._elements.append(Series(ts=ts, values=values, **kwargs)) + + def to_svg(self, path: str | None = None) -> str: + """Render to SVG string. 
Optionally write to file.""" + from dimos.memory2.vis.color import resolve_deferred + from dimos.memory2.vis.plot.svg import render + + resolve_deferred(self._elements) + svg = render(self) + if path is not None: + with open(path, "w") as f: + f.write(svg) + return svg + + def to_rerun(self, app_id: str = "plot", spawn: bool = True) -> None: + """Render to Rerun viewer (placeholder — currently a no-op).""" + from dimos.memory2.vis.color import resolve_deferred + from dimos.memory2.vis.plot.rerun import render + + resolve_deferred(self._elements) + render(self, app_id=app_id, spawn=spawn) + + def _repr_svg_(self) -> str: + """Jupyter inline display.""" + return self.to_svg() + + @property + def elements(self) -> list[PlotElement]: + """Read-only access to accumulated elements.""" + return list(self._elements) + + def __len__(self) -> int: + return len(self._elements) + + def __repr__(self) -> str: + counts: dict[str, int] = {} + for el in self._elements: + name = type(el).__name__ + counts[name] = counts.get(name, 0) + 1 + parts = [f"{n}={c}" for n, c in sorted(counts.items())] + return f"Plot({', '.join(parts)})" diff --git a/dimos/memory2/vis/plot/rerun.py b/dimos/memory2/vis/plot/rerun.py new file mode 100644 index 0000000000..e5ce9365ab --- /dev/null +++ b/dimos/memory2/vis/plot/rerun.py @@ -0,0 +1,27 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Rerun renderer for Plot — placeholder, no-op until we design the real one.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dimos.memory2.vis.plot.plot import Plot + + +def render(plot: Plot, app_id: str = "plot", spawn: bool = True) -> None: + """Placeholder — does nothing. Real rerun output for Plot is future work.""" + pass diff --git a/dimos/memory2/vis/plot/svg.py b/dimos/memory2/vis/plot/svg.py new file mode 100644 index 0000000000..0249f2bcc2 --- /dev/null +++ b/dimos/memory2/vis/plot/svg.py @@ -0,0 +1,258 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Matplotlib-based SVG renderer for Plot.""" + +from __future__ import annotations + +import io +from typing import TYPE_CHECKING, Any + +import matplotlib +import matplotlib.pyplot as plt + +from dimos.memory2.vis.color import Color, palette_iter +from dimos.memory2.vis.plot.elements import HLine, Markers, Series, VLine +from dimos.memory2.vis.plot.plot import TimeAxis + +if TYPE_CHECKING: + from dimos.memory2.vis.plot.plot import Plot + +matplotlib.use("Agg") + + +def _apply_time_axis(ax: Any, plot: Plot) -> None: + """Install an x-axis tick formatter based on plot.time_axis.""" + if plot.time_axis == TimeAxis.raw: + return + + # Reference point: earliest sample across all Series/Markers. 
+ all_ts: list[float] = [] + for el in plot.elements: + if isinstance(el, (Series, Markers)) and el.ts: + all_ts.append(el.ts[0]) + if not all_ts: + return + t0 = min(all_ts) + + from matplotlib.ticker import FuncFormatter + + if plot.time_axis == TimeAxis.relative: + + def fmt(ts: float, _: int = 0) -> str: + return f"{ts - t0:.0f}s" + elif plot.time_axis == TimeAxis.absolute: + from datetime import datetime + + def fmt(ts: float, _: int = 0) -> str: + return datetime.fromtimestamp(ts).strftime("%H:%M:%S") + else: + return + + ax.xaxis.set_major_formatter(FuncFormatter(fmt)) + + +def _break_on_gaps( + ts: list[float], + values: list[float], + max_gap: float | None, + fill: float | None = None, +) -> tuple[list[float], list[float]]: + """Handle gaps in a series. Returns (ts', values') ready to plot. + + When two consecutive samples are more than ``max_gap`` apart in x: + + - ``fill is None``: insert ``(NaN, NaN)`` between them. matplotlib's + ``plot`` skips line segments touching a NaN endpoint, so the line + visually breaks across the gap. + - ``fill is not None``: insert ``(prev_t, fill)`` and ``(next_t, fill)``, + so the line drops vertically at the prev sample, runs flat at ``fill`` + across the gap, then rises vertically at the next sample. + + Returns the arrays unchanged when ``max_gap`` is ``None``. 
+ """ + if max_gap is None or len(ts) < 2: + return list(ts), list(values) + out_ts: list[float] = [ts[0]] + out_v: list[float] = [values[0]] + nan = float("nan") + for i in range(1, len(ts)): + if ts[i] - ts[i - 1] > max_gap: + if fill is None: + out_ts.append(nan) + out_v.append(nan) + else: + out_ts.append(ts[i - 1]) + out_v.append(fill) + out_ts.append(ts[i]) + out_v.append(fill) + out_ts.append(ts[i]) + out_v.append(values[i]) + return out_ts, out_v + + +def render(plot: Plot, width: float = 10, height: float = 3.5) -> str: + """Render a Plot to an SVG string via matplotlib.""" + with plt.style.context("dark_background"): + fig, ax = plt.subplots(figsize=(width, height)) + fig.patch.set_alpha(0.0) + ax.set_facecolor("#16213e") + ax.grid(True, color="#2a2a4a", linewidth=0.5) + + # Lazily create twin y-axes for any element with axis != None. + # All twins share the primary x-axis (matplotlib `ax.twinx()`). + # The first twin sits at the default right edge; each additional twin + # gets its right spine pushed outward in axes-relative coordinates so + # their tick labels form a ladder instead of stacking on top of each + # other. The figure's right margin grows below to make room. + axes: dict[str | None, Any] = {None: ax} + twin_offset_step = 0.10 + + def axis_for(name: str | None) -> Any: + if name not in axes: + twin = ax.twinx() + twin.set_facecolor("none") + # Index among twins: 0 = first (no offset), 1 = second, ... + twin_index = sum(1 for k in axes if k is not None) + if twin_index > 0: + twin.spines["right"].set_position(("axes", 1.0 + twin_offset_step * twin_index)) + axes[name] = twin + return axes[name] + + # Drive a single shared color cycle across all axes (primary + twins) + # so series on a twin don't reuse the primary's first color. Excludes + # any color the user has already pinned to a specific element so the + # auto-cycle won't double-up on it. 
+ explicit_colors = { + el.color + for el in plot.elements + if isinstance(el, (Series, Markers)) and el.color is not None + } + color_iter = palette_iter(exclude=explicit_colors) + + # Track the dominant color of each twin axis (the color of the first + # Series/Markers landed on it) so we can color-code its spine and tick + # labels after the plot loop. Primary axis stays neutral so its ticks + # read as the baseline. + axis_colors: dict[str, tuple[float, float, float, float]] = {} + + def mpl_color(c: Color, opacity: float) -> tuple[float, float, float, float]: + return c.with_alpha(c.a * opacity).rgba_f() + + for el in plot.elements: + # VLine has no axis field — it always draws on the primary. + target = ax if isinstance(el, VLine) else axis_for(el.axis) + raw: str | Color | None = el.color + if raw is None and isinstance(el, (Series, Markers)): + raw = next(color_iter) + color = Color.coerce(raw) if raw is not None else None + rgba = mpl_color(color, el.opacity) if color is not None else None + if ( + not isinstance(el, VLine) + and el.axis is not None + and el.axis not in axis_colors + and isinstance(el, (Series, Markers)) + and rgba is not None + ): + axis_colors[el.axis] = rgba + if isinstance(el, Series): + ts, values = _break_on_gaps(el.ts, el.values, el.connect, el.gap_fill) + target.plot( + ts, + values, + color=rgba, + linewidth=el.width, + label=el.label, + linestyle=el.style.value, + ) + elif isinstance(el, Markers): + target.scatter( + el.ts, + el.values, + color=rgba, + s=el.radius**2 * 10, + label=el.label, + ) + elif isinstance(el, HLine): + target.axhline( + el.y, + color=rgba, + linestyle=el.style.value, + linewidth=1, + label=el.label, + ) + elif isinstance(el, VLine): + # Always on the primary — twins share x, so visually identical. + ax.axvline( + el.x, + color=rgba, + linestyle=el.style.value, + linewidth=1, + label=el.label, + ) + + # Thin spine + tick borders (1px) on every axes — primary and twins. 
+ # Color-code each twin's right spine and y tick labels with the color + # of its first Series/Markers, so users can tell which numbers belong + # to which series. Primary axis stays neutral. + for name, axes_obj in axes.items(): + for spine in axes_obj.spines.values(): + spine.set_linewidth(1) + axes_obj.tick_params(width=1) + if name is not None and name in axis_colors: + c = axis_colors[name] + axes_obj.spines["right"].set_color(c) + axes_obj.spines["right"].set_linewidth(1) + axes_obj.tick_params(axis="y", colors=c, width=1) + + # Combine handles from all axes into a single legend. Attach it to the + # *last* axes created (the most recent twin, or the primary if there + # are no twins) so the legend paints last and isn't covered by twin + # tick labels / spines drawn afterward in matplotlib's axes draw order. + all_handles: list[Any] = [] + all_labels: list[str] = [] + for axes_obj in axes.values(): + h, l = axes_obj.get_legend_handles_labels() + all_handles.extend(h) + all_labels.extend(l) + if all_handles: + legend_host = next(reversed(axes.values())) + legend_host.legend( + all_handles, + all_labels, + facecolor="#1a1a2e", + edgecolor="#2a2a4a", + framealpha=0.9, + ) + + ax.set_xlabel("time (s)") + _apply_time_axis(ax, plot) + + # Make room on the right for offset twin spines (each extra twin past + # the first needs about `twin_offset_step` of axes-relative width). + # `tight_layout` doesn't know about offset spines and will clip them, + # so for 2+ twins we use explicit margins instead. 
+ n_twins = sum(1 for k in axes if k is not None) + if n_twins >= 2: + extras = n_twins - 1 + right_margin = max(0.6, 0.95 - twin_offset_step * extras) + fig.subplots_adjust(left=0.08, right=right_margin, top=0.95, bottom=0.18) + else: + fig.tight_layout() + + buf = io.StringIO() + fig.savefig(buf, format="svg") + plt.close(fig) + + return buf.getvalue() diff --git a/dimos/memory2/vis/plot/test_plot.py b/dimos/memory2/vis/plot/test_plot.py new file mode 100644 index 0000000000..12cb979c5a --- /dev/null +++ b/dimos/memory2/vis/plot/test_plot.py @@ -0,0 +1,354 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for Plot builder and SVG rendering.""" + +import math + +import pytest + +from dimos.memory2.type.observation import Observation +from dimos.memory2.vis.plot.elements import HLine, Markers, Series, Style, VLine +from dimos.memory2.vis.plot.plot import Plot, TimeAxis + + +class TestPlotAdd: + """Plot.add() smart dispatch.""" + + def test_add_series(self): + p = Plot() + s = Series(ts=[1, 2, 3], values=[10, 20, 30], label="speed") + p.add(s) + assert len(p) == 1 + assert p.elements[0] is s + + def test_add_markers(self): + p = Plot() + m = Markers(ts=[1, 2], values=[5, 10], color="red") + p.add(m) + assert len(p) == 1 + assert isinstance(p.elements[0], Markers) + + def test_add_hline(self): + p = Plot() + p.add(HLine(y=0.5, label="threshold")) + assert len(p) == 1 + assert p.elements[0].y == 0.5 + + def test_add_from_observation_list(self): + obs_list = [ + Observation(id=i, ts=float(i), pose=(i, 0, 0, 0, 0, 0, 1), _data=float(i * 10)) + for i in range(5) + ] + p = Plot() + p.add(obs_list, label="test", color="blue") + assert len(p) == 1 + el = p.elements[0] + assert isinstance(el, Series) + assert el.ts == [0.0, 1.0, 2.0, 3.0, 4.0] + assert el.values == [0.0, 10.0, 20.0, 30.0, 40.0] + assert el.label == "test" + assert el.color == "blue" + + def test_add_chaining(self): + p = Plot().add(Series(ts=[1, 2], values=[10, 20])).add(HLine(y=15)) + assert len(p) == 2 + + def test_add_unknown_type_raises(self): + p = Plot() + with pytest.raises(TypeError, match="does not know how to handle"): + p.add(42) + + +class TestPlotSVG: + """SVG rendering via matplotlib.""" + + def test_empty_plot(self): + svg = Plot().to_svg() + assert "" in svg + + def test_series_renders(self): + p = Plot() + p.add(Series(ts=[0, 1, 2, 3], values=[0, 1, 4, 9], label="y=x²")) + svg = p.to_svg() + assert "HH:MM:SS + assert re.search(r"\d\d:\d\d:\d\d", svg) + + def test_time_axis_raw_preserves_default_matplotlib_format(self): + # In raw mode we do nothing, so matplotlib's default numeric 
formatter + # runs. Big unix timestamps get rendered in scientific form. + p = Plot(time_axis=TimeAxis.raw) + p.add(Series(ts=[1_700_000_000, 1_700_000_060], values=[1, 2])) + svg = p.to_svg() + # Raw mode should not produce "0s" style labels. + assert "0s" not in svg + + def test_opacity_appears_in_svg(self): + # opacity=0.4 should land as opacity="0.4" on the matplotlib-rendered + # path. matplotlib emits it as either `opacity` or `stroke-opacity` + # depending on the artist; we just need to see the value in the output. + p = Plot() + p.add(Series(ts=[0, 1, 2], values=[0, 1, 2], opacity=0.4)) + svg = p.to_svg() + assert "0.4" in svg + + def test_explicit_color_excluded_from_auto_cycle(self): + # If the user pins a Series to color.red, the auto-cycle for the next + # series should skip red and yield yellow (the third color) instead + # of red — otherwise we'd get two red lines. + from dimos.memory2.vis import color + + p = Plot() + p.add(Series(ts=[0, 1], values=[0, 1])) # auto → blue (first) + p.add(Series(ts=[0, 1], values=[2, 3], color=color.red)) # explicit red + p.add(Series(ts=[0, 1], values=[4, 5])) # auto → yellow (red is excluded) + svg = p.to_svg() + # Both blue and yellow should appear, plus the explicit red. + assert color.blue.hex() in svg + assert color.red.hex() in svg + assert color.yellow.hex() in svg + # Red should appear exactly once (the explicit one, not from the cycle). 
+ assert svg.count(color.red.hex()) == 1 + + +class TestPlotRepr: + def test_repr_empty(self): + assert repr(Plot()) == "Plot()" + + def test_repr_with_elements(self): + p = Plot() + p.add(Series(ts=[0], values=[0])) + p.add(Series(ts=[0], values=[0])) + p.add(HLine(y=1)) + assert repr(p) == "Plot(HLine=1, Series=2)" + + +class TestPlotRerunStub: + """Plot.to_rerun() is currently a no-op placeholder — must not raise.""" + + def test_to_rerun_does_not_raise(self): + Plot().to_rerun() + + +class TestPalette: + """The named palette and palette_iter live in vis/color.py.""" + + def test_named_constants_exist(self): + from dimos.memory2.vis import color + + assert color.blue == "#3498db" + assert color.red == "#e74c3c" + assert color.green.hex().startswith("#") + assert color.amber.hex().startswith("#") + assert len(color.PALETTE) == 12 + assert color.PALETTE[0] == color.blue + assert color.PALETTE[6] == color.green + + def test_palette_iter_yields_palette_first(self): + from dimos.memory2.vis import color + + it = color.palette_iter() + assert [next(it) for _ in range(12)] == color.PALETTE + + def test_palette_iter_continues_past_palette(self): + from dimos.memory2.vis import color + + it = color.palette_iter() + first_thirteen = [next(it) for _ in range(13)] + # 13th color is generated, must be a valid Color and distinct from all 12 named. + assert first_thirteen[12].hex().startswith("#") + assert first_thirteen[12] not in color.PALETTE + + def test_palette_iter_excludes(self): + from dimos.memory2.vis import color + + it = color.palette_iter(exclude={color.red, color.yellow}) + first_three = [next(it) for _ in range(3)] + # Skipped red and yellow, so the first three are blue, teal, purple. 
+ assert first_three == [color.blue, color.teal, color.purple] diff --git a/dimos/memory2/vis/space/elements.py b/dimos/memory2/vis/space/elements.py new file mode 100644 index 0000000000..b181551cb3 --- /dev/null +++ b/dimos/memory2/vis/space/elements.py @@ -0,0 +1,165 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Element types for the Space drawing language. + +Each element wraps one or more dimos.msgs with rendering intent + style. +For example, Pose(posestamped) says "render this PoseStamped as a circle + +heading arrow", while Arrow(posestamped) says "render it as an arrow only." + +SVG renderer collapses to 2D (top-down XY projection, Z ignored). +Rerun renderer can use the wrapped msgs' .to_rerun() methods directly. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Union + +from dimos.memory2.vis.color import Color, DeferredColor + +ColorLike = Union[str, Color, DeferredColor] + +if TYPE_CHECKING: + from dimos.memory2.type.observation import Observation + from dimos.msgs.geometry_msgs.Point import Point as GeoPoint + from dimos.msgs.geometry_msgs.Pose import Pose as GeoPose + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + from dimos.msgs.geometry_msgs.Vector3 import Vector3 + from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid + from dimos.msgs.nav_msgs.Path import Path + from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo + from dimos.msgs.sensor_msgs.Image import Image + from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + + +@dataclass +class Pose: + """Circle + heading arrow at a pose. + + Default element for PoseStamped. + SVG: at .msg.x/.y + heading from .msg.yaw + Rerun: msg.to_rerun() (Transform3D) + msg.to_rerun_arrow() + """ + + msg: PoseStamped | GeoPose + color: ColorLike = "#1abc9c" + size: float = 0.3 + label: str | None = None + opacity: float = 1.0 + + +@dataclass +class Arrow: + """Heading arrow only (no dot). + + SVG: + arrowhead from .msg.x/.y along .msg.yaw + Rerun: msg.to_rerun_arrow() + """ + + msg: PoseStamped | GeoPose + color: ColorLike = "#e67e22" + length: float = 0.5 + opacity: float = 1.0 + + +@dataclass +class Point: + """Dot at a position. + + Default element for geometry_msgs.Point / PointStamped. + SVG: + optional label + Rerun: rr.Points3D + """ + + msg: GeoPoint | GeoPose + color: ColorLike = "#e74c3c" + radius: float = 0.05 + label: str | None = None + opacity: float = 1.0 + + +@dataclass +class Box3D: + """3D bounding box, rendered as rectangle in top-down view. + + Built from Detection3D.bbox or manually from center + size. 
+ SVG: centered at .center.x/.y with .size.x/.y + Rerun: rr.Boxes3D + """ + + center: GeoPose + size: Vector3 + color: ColorLike = "#f1c40f" + label: str | None = None + opacity: float = 1.0 + + +@dataclass +class Camera: + """Camera frustum at a pose, with optional image and intrinsics. + + SVG: FOV wedge at .pose.x/.y/.yaw (if camera_info), else dot + thumbnail + Rerun: rr.Pinhole + rr.Transform3D + optional rr.Image + """ + + pose: PoseStamped + image: Image | None = None + camera_info: CameraInfo | None = None + color: ColorLike = "#9b59b6" + label: str | None = None + opacity: float = 1.0 + + +@dataclass +class Polyline: + """Styled polyline wrapping a Path msg. + + SVG: through .msg.poses[*].x/.y + Rerun: rr.LineStrips3D + """ + + msg: Path + color: ColorLike = "#3498db" + width: float = 0.05 + opacity: float = 1.0 + + +@dataclass +class Text: + """Text annotation at a world position. + + SVG: + Rerun: rr.TextLog + """ + + position: tuple[float, float, float] + text: str + font_size: float = 12.0 + color: ColorLike = "#333333" + opacity: float = 1.0 + + +SpaceElement = Union[ + Pose, + Arrow, + Point, + Box3D, + Camera, + Polyline, + Text, + "OccupancyGrid", # pass-through, rendered as base map raster + "PointCloud2", # pass-through, rerun renders full 3D, SVG collapses to occupancy grid + "Observation[Any]", # pass-through, renderer decides presentation (covers EmbeddedObservation) +] diff --git a/dimos/memory2/vis/space/rerun.py b/dimos/memory2/vis/space/rerun.py new file mode 100644 index 0000000000..435097e77c --- /dev/null +++ b/dimos/memory2/vis/space/rerun.py @@ -0,0 +1,336 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rerun renderer for Space. Logs scene elements as 3D archetypes.""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Any + +from dimos.memory2.type.observation import Observation +from dimos.memory2.vis.color import Color +from dimos.memory2.vis.space.elements import Arrow, Box3D, Camera, Point, Polyline, Pose, Text +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + +if TYPE_CHECKING: + from dimos.memory2.vis.space.space import Space + + +def _rgba(el: Any) -> tuple[int, int, int, int]: + """Combine element color + opacity into an RGBA u8 tuple for rerun.""" + c = Color.coerce(getattr(el, "color", "#000000")) + opacity = float(getattr(el, "opacity", 1.0)) + return c.with_alpha(c.a * opacity).rgba_u8() + + +# base_link → camera_optical extrinsics (applied at render time for image observations) +_BASE_TO_OPTICAL = Transform( + translation=Vector3(0.3, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", +) + Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), + frame_id="camera_link", + child_frame_id="camera_optical", +) + + +def render(space: Space, app_id: str = "space", spawn: bool = True) -> None: + """Render a Space to a Rerun viewer.""" + import rerun as rr + import 
rerun.blueprint as rrb + + from dimos.visualization.rerun.init import rerun_init + + rerun_init(app_id, spawn=spawn) + + # Collect elements by type + points: list[Point] = [] + poses: list[Pose] = [] + arrows: list[Arrow] = [] + boxes: list[Box3D] = [] + cameras: list[Camera] = [] + polylines: list[Polyline] = [] + texts: list[Text] = [] + grids: list[OccupancyGrid] = [] + pointclouds: list[PointCloud2] = [] + observations: list[Observation[Any]] = [] + panels: list[Observation[Any]] = [] + + for el in space.elements: + if isinstance(el, Observation): + if _is_image(el.data) and el.pose is None: + panels.append(el) + else: + observations.append(el) + elif isinstance(el, Point): + points.append(el) + elif isinstance(el, Pose): + poses.append(el) + elif isinstance(el, Arrow): + arrows.append(el) + elif isinstance(el, Box3D): + boxes.append(el) + elif isinstance(el, Camera): + cameras.append(el) + elif isinstance(el, Polyline): + polylines.append(el) + elif isinstance(el, Text): + texts.append(el) + elif isinstance(el, OccupancyGrid): + grids.append(el) + elif isinstance(el, PointCloud2): + pointclouds.append(el) + + # Build and send blueprint + has_images = ( + any(c.image is not None for c in cameras) + or any(_has_image(obs) for obs in observations) + or bool(panels) + ) + views: list[Any] = [ + rrb.Spatial3DView( + origin="scene", + name="Scene", + background=rrb.Background(kind="SolidColor", color=[0, 0, 0]), + line_grid=rrb.LineGrid3D( + plane=rr.components.Plane3D.XY.with_distance(0.5), + ), + ) + ] + if has_images: + views.append(rrb.Spatial2DView(origin="scene", name="Images")) + + blueprint = rrb.Blueprint( + rrb.Horizontal(*views, column_shares=[2, 1]) if len(views) > 1 else views[0] + ) + rr.send_blueprint(blueprint) + + # Log elements + if grids: + for i, el in enumerate(grids): + rr.log(f"scene/map/{i}", el.to_rerun(), static=True) + + if pointclouds: + for i, el in enumerate(pointclouds): + rr.log(f"scene/pointcloud/{i}", el.to_rerun(), static=True) + + 
if points: + rr.log( + "scene/points", + rr.Points3D( + positions=[[p.msg.x, p.msg.y, p.msg.z] for p in points], + colors=[_rgba(p) for p in points], + radii=[max(p.radius, 0.05) for p in points], + labels=[p.label or "" for p in points] if any(p.label for p in points) else None, + ), + static=True, + ) + + if poses: + rr.log( + "scene/poses", + rr.Points3D( + positions=[[p.msg.x, p.msg.y, 0] for p in poses], + colors=[_rgba(p) for p in poses], + radii=[p.size * 0.3 for p in poses], + labels=[p.label or "" for p in poses] if any(p.label for p in poses) else None, + ), + static=True, + ) + rr.log( + "scene/poses/headings", + rr.Arrows3D( + origins=[[p.msg.x, p.msg.y, 0] for p in poses], + vectors=[ + [math.cos(p.msg.yaw) * p.size, math.sin(p.msg.yaw) * p.size, 0] for p in poses + ], + colors=[_rgba(p) for p in poses], + ), + static=True, + ) + + if arrows: + rr.log( + "scene/arrows", + rr.Arrows3D( + origins=[[a.msg.x, a.msg.y, 0] for a in arrows], + vectors=[ + [math.cos(a.msg.yaw) * a.length, math.sin(a.msg.yaw) * a.length, 0] + for a in arrows + ], + colors=[_rgba(a) for a in arrows], + ), + static=True, + ) + + if boxes: + rr.log( + "scene/boxes", + rr.Boxes3D( + centers=[[b.center.x, b.center.y, 0] for b in boxes], + half_sizes=[[b.size.x / 2, b.size.y / 2, b.size.z / 2] for b in boxes], + colors=[_rgba(b) for b in boxes], + labels=[b.label or "" for b in boxes] if any(b.label for b in boxes) else None, + ), + static=True, + ) + + for i, el in enumerate(polylines): + rr.log( + f"scene/polylines/{i}", + rr.LineStrips3D( + strips=[[[p.x, p.y, 0] for p in el.msg.poses]], + colors=[_rgba(el)], + radii=[el.width / 2], + ), + static=True, + ) + + if texts: + rr.log( + "scene/texts", + rr.Points3D( + positions=[[t.position[0], t.position[1], 0] for t in texts], + labels=[t.text for t in texts], + colors=[_rgba(t) for t in texts], + radii=[0.01] * len(texts), + ), + static=True, + ) + + for i, el in enumerate(cameras): + path = f"scene/cameras/{i}" + rr.log(path, 
el.pose.to_rerun(), static=True) + if el.camera_info: + pinhole = el.camera_info.to_rerun() + assert not isinstance(pinhole, list) + rr.log(path, pinhole, static=True) + elif el.image: + h, w = el.image.shape[:2] + focal = max(w, h) + rr.log( + path, + rr.Pinhole(focal_length=focal, principal_point=[w / 2, h / 2], resolution=[w, h]), + static=True, + ) + if el.image: + rr.log(f"{path}/image", el.image.to_rerun(), static=True) + + for i, obs in enumerate(observations): + path = f"scene/observations/{i}" + data = obs.data + img = _as_image(data) + if img is not None: + # Apply base→optical extrinsics for camera frustum rendering + world_T_optical = Transform.from_pose("world", obs.pose_stamped) + _BASE_TO_OPTICAL + rr.log(path, world_T_optical.to_pose().to_rerun(), static=True) + h, w = img.shape[:2] + focal = max(w, h) + rr.log( + path, + rr.Pinhole( + focal_length=focal, + principal_point=[w / 2, h / 2], + resolution=[w, h], + image_plane_distance=1.0, + ), + static=True, + ) + rr.log(f"{path}/image", img.to_rerun(), static=True) + elif isinstance(data, PointCloud2): + rr.log(path, obs.pose_stamped.to_rerun(), static=True) + rr.log(f"{path}/pointcloud", data.to_rerun(), static=True) + elif isinstance(data, (int, float)): + rr.log( + path, + rr.Points3D( + positions=[[obs.pose_stamped.x, obs.pose_stamped.y, 0]], + labels=[str(data)], + radii=[0.025], + ), + static=True, + ) + elif isinstance(data, str): + # Word-wrap for label + words = data.split() + lines: list[str] = [] + line: str = "" + for word in words: + if line and len(line) + len(word) + 1 > 40: + lines.append(line) + line = word + else: + line = f"{line} {word}" if line else word + if line: + lines.append(line) + label = "\n".join(lines) + x, y = obs.pose_stamped.x, obs.pose_stamped.y + # Pin: line from ground up, label at the tip + rr.log( + f"{path}/pin", + rr.LineStrips3D( + strips=[[[x, y, 1.5], [x, y, 3.0]]], + colors=[(100, 100, 100)], + radii=[0.01], + ), + static=True, + ) + rr.log( + 
f"{path}/label", + rr.Points3D( + positions=[[x, y, 3.5]], + labels=[label], + radii=[0.001], + ), + static=True, + ) + else: + rr.log( + path, + rr.Points3D(positions=[[obs.pose_stamped.x, obs.pose_stamped.y, 0]], radii=[0.05]), + static=True, + ) + + for i, obs in enumerate(panels): + img = _as_image(obs.data) + if img is not None: + rr.log(f"scene/panels/{i}", img.to_rerun(), static=True) + + +def _as_image(data: Any) -> Any | None: + """Return an Image if data is an Image or ImageDetections, else None.""" + from dimos.msgs.sensor_msgs.Image import Image + from dimos.perception.detection.type.imageDetections import ImageDetections + + if isinstance(data, Image): + return data + if isinstance(data, ImageDetections): + return data.annotated_image() + return None + + +def _is_image(data: Any) -> bool: + return _as_image(data) is not None + + +def _has_image(obs: Observation[Any]) -> bool: + return _is_image(obs.data) diff --git a/dimos/memory2/vis/space/space.py b/dimos/memory2/vis/space/space.py new file mode 100644 index 0000000000..4ce973b133 --- /dev/null +++ b/dimos/memory2/vis/space/space.py @@ -0,0 +1,187 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Space: 2D spatial rendering canvas (world frame). + +Space.add() is a smart dispatcher: it accepts element types directly +(explicit rendering mode), raw dimos msgs (auto-wrapped into default +element), or observations (smart dispatch based on data type). 
+""" + +from __future__ import annotations + +from typing import Any + +from dimos.memory2.type.observation import EmbeddedObservation, Observation +from dimos.memory2.vis.color import ColorRange, resolve_deferred +from dimos.memory2.vis.space.elements import ( + Arrow, + Box3D, + Camera, + Point, + Polyline, + Pose, + SpaceElement, + Text, +) +from dimos.msgs.geometry_msgs.Point import Point as GeoPoint +from dimos.msgs.geometry_msgs.Pose import Pose as GeoPose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path as NavPath +from dimos.msgs.protocol import DimosMsg +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.vision_msgs.Detection3D import Detection3D + + +def _autocolor_value(item: Any) -> float | None: + """Extract the scalar to colormap for an item, or None if not auto-colorable.""" + if isinstance(item, EmbeddedObservation): + return float(item.similarity or 0.0) + if isinstance(item, Observation): + if item.data_type in (float, int): + return float(item.data) + return float(item.ts) + if isinstance(item, (int, float)): + return float(item) + return None + + +class Space: + """2D spatial rendering canvas (world frame). + + Accumulates elements for spatial visualization. Elements can be added as: + - Element types directly: ``s.add(Pose(posestamped, color="red"))`` + - Raw dimos msgs with style kwargs: ``s.add(posestamped, color="red")`` + - Observations (smart dispatch): ``s.add(observation)`` + - Lists of EmbeddedObservations: ``s.add(results)`` → similarity heatmap + - Streams / iterables: ``s.add(stream)`` → materializes and adds each obs.data + """ + + def __init__(self) -> None: + self._elements: list[SpaceElement] = [] + self._autocolor_ranges: list[ColorRange] = [] + + def add(self, element: Any, **kwargs: Any) -> Space: + """Add an element with smart dispatch. + + Element types (Pose, Arrow, Point, etc.) 
are stored as-is. + Raw dimos msgs are auto-wrapped into their default element, + with ``**kwargs`` forwarded as style (color, label, etc.). + """ + + if isinstance(element, (Pose, Arrow, Point, Box3D, Camera, Polyline, Text)): + self._elements.append(element) + elif isinstance(element, DimosMsg): + self.add_dimos_msg(element, **kwargs) + elif isinstance(element, EmbeddedObservation): + self.add_embedded_observation(element, **kwargs) + elif isinstance(element, Observation): + self.add_observation(element, **kwargs) + elif hasattr(element, "__iter__"): + cmap = kwargs.pop("cmap", "turbo") + color_range = ColorRange(cmap=cmap) + self._autocolor_ranges.append(color_range) + for item in element: + v = _autocolor_value(item) + if v is not None and isinstance(item, Observation): + self._elements.append(Arrow(msg=item.pose_stamped, color=color_range(v))) + else: + self.add(item, **kwargs) + else: + raise TypeError( + f"Space.add() does not know how to handle {type(element).__name__}. " + f"Pass an element type (Pose, Arrow, Point, ...) or a dimos msg." + ) + + return self + + def add_dimos_msg(self, msg: DimosMsg, **kwargs: Any) -> None: + """Dispatch a DimosMsg to its default element type.""" + if isinstance(msg, PoseStamped): + self._elements.append(Pose(msg=msg, **kwargs)) + elif isinstance(msg, GeoPose): + self._elements.append(Pose(msg=msg, **kwargs)) + elif isinstance(msg, GeoPoint): + self._elements.append(Point(msg=msg, **kwargs)) + elif isinstance(msg, NavPath): + self._elements.append(Polyline(msg=msg, **kwargs)) + elif isinstance(msg, OccupancyGrid): + self._elements.append(msg) + elif isinstance(msg, PointCloud2): + self._elements.append(msg) + elif isinstance(msg, Detection3D): + self._elements.append( + Box3D( + center=msg.bbox.center, + size=msg.bbox.size, + label=getattr(msg, "id", None), + **kwargs, + ) + ) + else: + raise TypeError( + f"No default element for {type(msg).__name__}. " + f"Wrap it explicitly (e.g. Pose(msg), Arrow(msg))." 
+ ) + + def add_embedded_observation(self, obs: EmbeddedObservation[Any], **kwargs: Any) -> None: + """Pass through to renderer like a regular Observation.""" + self._elements.append(obs) + + def add_observation(self, obs: Observation[Any], **kwargs: Any) -> None: + """Store the observation directly; renderers decide how to display it.""" + self._elements.append(obs) + + def base_map(self, grid: OccupancyGrid) -> Space: + """Add an OccupancyGrid as the background map.""" + return self.add(grid) + + def to_svg(self, path: str | None = None) -> str: + """Render to SVG string. Optionally write to file.""" + from dimos.memory2.vis.space.svg import render + + resolve_deferred(self._elements) + svg = render(self) + if path is not None: + with open(path, "w") as f: + f.write(svg) + return svg + + def to_rerun(self, app_id: str = "space", spawn: bool = True) -> None: + """Render to Rerun viewer.""" + from dimos.memory2.vis.space.rerun import render + + resolve_deferred(self._elements) + render(self, app_id=app_id, spawn=spawn) + + def _repr_svg_(self) -> str: + """Jupyter inline display.""" + return self.to_svg() + + @property + def elements(self) -> list[SpaceElement]: + """Read-only access to accumulated elements.""" + return list(self._elements) + + def __len__(self) -> int: + return len(self._elements) + + def __repr__(self) -> str: + counts: dict[str, int] = {} + for el in self._elements: + name = type(el).__name__ + counts[name] = counts.get(name, 0) + 1 + parts = [f"{n}={c}" for n, c in sorted(counts.items())] + return f"Space({', '.join(parts)})" diff --git a/dimos/memory2/vis/space/svg.py b/dimos/memory2/vis/space/svg.py new file mode 100644 index 0000000000..e54da044da --- /dev/null +++ b/dimos/memory2/vis/space/svg.py @@ -0,0 +1,348 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SVG renderer for Space. + +Top-down XY projection (Z ignored). Renders in world coordinates with Y-flip. +The SVG viewBox is computed from actual rendered content, so all element types +automatically contribute to the viewport bounds. +""" + +from __future__ import annotations + +import base64 +from dataclasses import dataclass +import io +import math +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np +from PIL import Image as PILImage + +from dimos.mapping.occupancy.visualizations import generate_rgba_texture +from dimos.memory2.type.observation import Observation +from dimos.memory2.vis.color import Color +from dimos.memory2.vis.space.elements import ( + Arrow, + Box3D, + Camera, + Point, + Polyline, + Pose, + SpaceElement, + Text, +) +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + +if TYPE_CHECKING: + from dimos.memory2.vis.space.space import Space + + +@dataclass +class Bounds: + """Accumulates world-space bounding box during rendering.""" + + xmin: float = float("inf") + xmax: float = float("-inf") + ymin: float = float("inf") + ymax: float = float("-inf") + + def include(self, x: float, y: float) -> None: + self.xmin = min(self.xmin, x) + self.xmax = max(self.xmax, x) + self.ymin = min(self.ymin, y) + self.ymax = max(self.ymax, y) + + @property + def empty(self) -> bool: + return self.xmin > self.xmax + + @property + def width(self) -> float: + return max(self.xmax - self.xmin, 1.0) + + @property + def height(self) -> float: + return 
max(self.ymax - self.ymin, 1.0) + + +def _y(wy: float) -> float: + """Flip Y axis: world Y-up → SVG Y-down.""" + return -wy + + +def _style(el: object) -> tuple[str, float]: + """Return (hex, combined-alpha) from an element's color and opacity.""" + c = Color.coerce(getattr(el, "color", "#000000")) + opacity = float(getattr(el, "opacity", 1.0)) + return c.hex(), c.a * opacity + + +# Element renderers — all emit world-coordinate SVG and grow Bounds + + +def _render_point(el: Point, b: Bounds) -> str: + x, y = el.msg.x, _y(el.msg.y) + r = el.radius + b.include(x - r, y - r) + b.include(x + r, y + r) + fill, alpha = _style(el) + parts = [f''] + if el.label: + parts.append( + f'{_esc(el.label)}' + ) + return "\n".join(parts) + + +def _render_arrow(el: Arrow, b: Bounds) -> str: + x, y = el.msg.x, _y(el.msg.y) + yaw = el.msg.yaw + length = el.length + half_base = length * 0.4 + + # Tip of the triangle + tx = x + math.cos(yaw) * length + ty = y - math.sin(yaw) * length # sin negated for Y-flip + # Two base corners (perpendicular to yaw) + bx1 = x + math.cos(yaw + math.pi / 2) * half_base + by1 = y - math.sin(yaw + math.pi / 2) * half_base + bx2 = x + math.cos(yaw - math.pi / 2) * half_base + by2 = y - math.sin(yaw - math.pi / 2) * half_base + + for px, py in [(x, y), (tx, ty), (bx1, by1), (bx2, by2)]: + b.include(px, py) + + stroke, alpha = _style(el) + return ( + f'' + ) + + +def _render_pose(el: Pose, b: Bounds) -> str: + arrow = Arrow(msg=el.msg, length=el.size, color=el.color, opacity=el.opacity) + parts = [_render_arrow(arrow, b)] + if el.label: + x, y = el.msg.x, _y(el.msg.y) + fill, alpha = _style(el) + parts.append( + f'{_esc(el.label)}' + ) + return "\n".join(parts) + + +def _render_polyline(el: Polyline, b: Bounds) -> str: + pts = [] + for p in el.msg.poses: + x, y = p.x, _y(p.y) + b.include(x, y) + pts.append(f"{x:.4f},{y:.4f}") + stroke, alpha = _style(el) + return ( + f'' + ) + + +def _render_box3d(el: Box3D, b: Bounds) -> str: + cx, cy = el.center.x, 
el.center.y + hw, hh = el.size.x / 2, el.size.y / 2 + # Top-left in world → SVG + x = cx - hw + y = _y(cy + hh) + w = el.size.x + h = el.size.y + b.include(x, y) + b.include(x + w, y + h) + stroke, alpha = _style(el) + parts = [ + f'' + ] + if el.label: + font_size = max(h * 0.3, 0.2) + parts.append( + f'{_esc(el.label)}' + ) + return "\n".join(parts) + + +def _render_camera(el: Camera, b: Bounds) -> str: + x, y = el.pose.x, _y(el.pose.y) + yaw = el.pose.yaw + stroke, alpha = _style(el) + + if el.camera_info and el.camera_info.K[4] > 0: + fy = el.camera_info.K[4] + fov_y = 2 * math.atan(el.camera_info.height / (2 * fy)) + fov_half = fov_y / 2 + wedge_len = 0.8 + + a1 = yaw + fov_half + a2 = yaw - fov_half + x1 = x + math.cos(a1) * wedge_len + y1 = y - math.sin(a1) * wedge_len + x2 = x + math.cos(a2) * wedge_len + y2 = y - math.sin(a2) * wedge_len + + for px, py in [(x, y), (x1, y1), (x2, y2)]: + b.include(px, py) + + parts = [ + f'' + ] + else: + r = 0.15 + b.include(x - r, y - r) + b.include(x + r, y + r) + parts = [ + f'' + ] + + if el.label: + parts.append( + f'{_esc(el.label)}' + ) + return "\n".join(parts) + + +def _render_text(el: Text, b: Bounds) -> str: + x, y = el.position[0], _y(el.position[1]) + b.include(x, y) + fill, alpha = _style(el) + return ( + f'{_esc(el.text)}' + ) + + +def _render_occupancy_grid(el: OccupancyGrid, b: Bounds) -> str: + if el.grid.size == 0: + return "" + + rgba = np.flipud(generate_rgba_texture(el)) + img = PILImage.fromarray(rgba, "RGBA") + buf = io.BytesIO() + img.save(buf, format="PNG") + b64 = base64.b64encode(buf.getvalue()).decode("ascii") + + ox, oy = el.origin.x, el.origin.y + world_w = el.width * el.resolution + world_h = el.height * el.resolution + + # SVG top-left: world top-left with Y-flip + sx = ox + sy = _y(oy + world_h) + + b.include(sx, sy) + b.include(sx + world_w, sy + world_h) + + return ( + f'' + ) + + +# Dispatch + top-level render + + +def _render_element(el: SpaceElement, b: Bounds) -> str: + if 
isinstance(el, Point): + return _render_point(el, b) + elif isinstance(el, Pose): + return _render_pose(el, b) + elif isinstance(el, Arrow): + return _render_arrow(el, b) + elif isinstance(el, Polyline): + return _render_polyline(el, b) + elif isinstance(el, Box3D): + return _render_box3d(el, b) + elif isinstance(el, Camera): + return _render_camera(el, b) + elif isinstance(el, Text): + return _render_text(el, b) + elif isinstance(el, OccupancyGrid): + return _render_occupancy_grid(el, b) + elif isinstance(el, PointCloud2): + from dimos.mapping.occupancy.inflation import simple_inflate + from dimos.mapping.pointclouds.occupancy import height_cost_occupancy + + return _render_occupancy_grid(simple_inflate(height_cost_occupancy(el), 0.05), b) + elif isinstance(el, Observation): + if el.pose is None: + return "" + if el.data_type == float: + return _render_arrow(Arrow(msg=el.pose_stamped, color="#ff0000"), b) + else: + return _render_arrow(Arrow(msg=el.pose_stamped), b) + + else: + return f"" + + +def render( + space: Space, + path: str | Path | None = None, + width_px: float = 800, + padding: float = 0.5, +) -> str: + """Render a Space to an SVG string, optionally writing to *path*.""" + b = Bounds() + fragments: list[str] = [] + + for el in space.elements: + fragments.append(_render_element(el, b)) + + if b.empty: + b.include(0, 0) + b.include(1, 1) + + b.xmin -= padding + b.xmax += padding + b.ymin -= padding + b.ymax += padding + + aspect = b.height / b.width + svg_h = width_px * aspect + + parts: list[str] = [ + f'', + ] + parts.extend(fragments) + parts.append("") + svg = "\n".join(parts) + + if path is not None: + Path(path).write_text(svg) + + return svg + + +def _esc(s: str) -> str: + """Escape text for SVG XML.""" + return s.replace("&", "&").replace("<", "<").replace(">", ">") diff --git a/dimos/memory2/vis/space/test_space.py b/dimos/memory2/vis/space/test_space.py new file mode 100644 index 0000000000..5ae86af443 --- /dev/null +++ 
b/dimos/memory2/vis/space/test_space.py @@ -0,0 +1,350 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Space builder and element types.""" + +import numpy as np +import pytest + +from dimos.memory2.type.observation import EmbeddedObservation, Observation +from dimos.memory2.vis.space.elements import Arrow, Box3D, Camera, Point, Polyline, Pose, Text +from dimos.memory2.vis.space.space import Space +from dimos.msgs.geometry_msgs.Point import Point as GeoPoint +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.nav_msgs.OccupancyGrid import OccupancyGrid +from dimos.msgs.nav_msgs.Path import Path as Path +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.vision_msgs.Detection3D import Detection3D + + +class TestElementTypes: + """Element types wrap msgs with rendering intent + style.""" + + def test_pose_wraps_posestamped(self): + ps = PoseStamped(3.2, 1.5, 0.0) + p = Pose(ps, color="red", label="fridge") + assert p.msg is ps + assert p.color == "red" + assert p.label == "fridge" + + def test_arrow_wraps_posestamped(self): + ps = PoseStamped(1, 2, 0, 0, 0, 0.1, 1) + a = Arrow(ps, color="orange", length=0.8) + assert a.msg is ps + assert a.length == 0.8 + + def test_point_wraps_geopoint(self): + gp = GeoPoint(7.1, 4.3, 0) + p = Point(gp, color="green", label="bottle") + assert p.msg is gp + assert p.msg.x == pytest.approx(7.1) + + def test_point_wraps_posestamped(self): + ps = 
PoseStamped(3, 1, 0) + p = Point(ps, radius=0.5) + assert p.msg.x == pytest.approx(3.0) + + def test_box3d_from_center_size(self): + from dimos.msgs.geometry_msgs.Pose import Pose as GeoPose + from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + b = Box3D(center=GeoPose(5, 3, 0), size=Vector3(2, 1, 0.5), label="table") + assert b.center.x == pytest.approx(5.0) + assert b.size.x == pytest.approx(2.0) + assert b.label == "table" + + def test_camera_with_image(self): + ps = PoseStamped(1, 2, 0) + img = Image(np.zeros((480, 640, 3), dtype=np.uint8)) + c = Camera(pose=ps, image=img, color="purple") + assert c.pose is ps + assert c.image is img + assert c.camera_info is None + + def test_text(self): + t = Text((1, 8, 0), "exploration run #3") + assert t.text == "exploration run #3" + assert t.color == "#333333" + + +class TestSpaceExplicitElements: + """Space.add() with explicit element types stores them as-is.""" + + def test_add_pose(self): + s = Space() + ps = PoseStamped(3, 1, 0) + pose = Pose(ps, color="red") + s.add(pose) + assert len(s) == 1 + assert s.elements[0] is pose + + def test_add_multiple_types(self): + s = Space() + ps = PoseStamped(3, 1, 0) + s.add(Pose(ps, color="red")) + s.add(Arrow(ps, color="orange")) + s.add(Point(GeoPoint(1, 2, 0), label="x")) + s.add(Text((0, 0, 0), "hello")) + assert len(s) == 4 + + def test_chaining(self): + ps = PoseStamped(1, 1, 0) + s = Space().add(Pose(ps)).add(Arrow(ps)).add(Text((0, 0, 0), "hi")) + assert len(s) == 3 + + +class TestSpaceAutoWrap: + """Space.add() with raw dimos msgs auto-wraps into default element.""" + + def test_posestamped_becomes_pose(self): + s = Space() + ps = PoseStamped(3.2, 1.5, 0) + s.add(ps, color="blue", label="auto") + assert len(s) == 1 + el = s.elements[0] + assert isinstance(el, Pose) + assert el.msg is ps + assert el.color == "blue" + assert el.label == "auto" + + def test_geopoint_becomes_point(self): + s = Space() + gp = GeoPoint(7, 4, 0) + s.add(gp, color="yellow") + el = 
s.elements[0] + assert isinstance(el, Point) + assert el.msg is gp + assert el.color == "yellow" + + def test_path_becomes_polyline(self): + s = Space() + p = Path(poses=[PoseStamped(i, 0, 0) for i in range(3)]) + s.add(p, color="blue", width=0.1) + el = s.elements[0] + assert isinstance(el, Polyline) + assert el.color == "blue" + assert el.width == 0.1 + assert len(el.msg.poses) == 3 + + def test_occupancy_grid_passthrough(self): + s = Space() + grid = OccupancyGrid() + s.add(grid) + assert s.elements[0] is grid + + def test_detection3d_becomes_box3d(self): + det = Detection3D() + det.bbox.center.position.x = 5.0 + det.bbox.center.position.y = 3.0 + det.bbox.size.x = 2.0 + det.bbox.size.y = 1.0 + det.bbox.size.z = 0.5 + + s = Space() + s.add(det, color="yellow") + el = s.elements[0] + assert isinstance(el, Box3D) + assert el.center.position.x == pytest.approx(5.0) + assert el.size.x == pytest.approx(2.0) + assert el.color == "yellow" + + def test_unknown_type_raises(self): + s = Space() + with pytest.raises(TypeError, match="does not know how to handle"): + s.add(42) + + +class TestSpaceObservations: + """Space.add() smart dispatch for Observation types.""" + + def test_image_observation_stored_as_observation(self): + img = Image(np.zeros((480, 640, 3), dtype=np.uint8)) + obs = Observation(id=1, ts=1.0, pose=(3, 1, 0, 0, 0, 0, 1), _data=img) + + s = Space() + s.add(obs) + el = s.elements[0] + assert isinstance(el, Observation) + assert el.data is img + + def test_non_image_observation_stored_as_observation(self): + obs = Observation(id=2, ts=2.0, pose=(5, 2, 0, 0, 0, 0, 1), _data="some_data") + + s = Space() + s.add(obs) + el = s.elements[0] + assert isinstance(el, Observation) + assert el.data == "some_data" + + def test_posestamped_observation_stored_as_observation(self): + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped as PS + + obs = Observation(id=3, ts=3.0, pose=(1, 2, 0, 0, 0, 0, 1), _data=PS(5, 2, 0)) + + s = Space() + s.add(obs) + el = 
s.elements[0] + assert isinstance(el, Observation) + assert el.data.x == pytest.approx(5.0) + + def test_embedded_observation_stored_as_arrow(self): + obs = EmbeddedObservation( + id=0, + ts=0.0, + pose=(1, 2, 0, 0, 0, 0, 1), + _data="x", + similarity=0.8, + ) + + s = Space() + s.add(obs) + assert len(s) == 1 + el = s.elements[0] + assert isinstance(el, EmbeddedObservation) + + +class TestSpaceConvenience: + """Space convenience methods: base_map.""" + + def test_base_map(self): + grid = OccupancyGrid() + s = Space().base_map(grid) + assert len(s) == 1 + assert isinstance(s.elements[0], OccupancyGrid) + + def test_add_list_of_msgs(self): + poses = [PoseStamped(i, 0, 0) for i in range(3)] + s = Space() + s.add(poses, color="red") + assert len(s) == 3 + for el in s.elements: + assert isinstance(el, Pose) + assert el.color == "red" + + +class TestSpaceRepr: + def test_repr_empty(self): + assert repr(Space()) == "Space()" + + def test_repr_with_elements(self): + s = Space() + ps = PoseStamped(0, 0, 0) + s.add(Pose(ps)) + s.add(Pose(ps)) + s.add(Arrow(ps)) + assert repr(s) == "Space(Arrow=1, Pose=2)" + + +class TestSVGRender: + """SVG rendering produces valid SVG with expected elements.""" + + def test_empty_space(self): + svg = Space().to_svg() + assert svg.startswith("") + + def test_point_renders_circle(self): + from dimos.memory2.vis import color + + s = Space() + s.add(Point(GeoPoint(3, 4, 0), color="red", label="hi")) + svg = s.to_svg() + assert "")) + svg = s.to_svg() + assert " Observation[Image]: + """Tile images into a grid mosaic. + + Accepts Image instances, Observation/EmbeddedObservation with Image data, + or any iterable of these (including Stream). Returns a poseless + Observation[Image] tagged ``{"mosaic": True}`` — the rerun renderer + displays poseless image observations as flat 2D panels. 
+ """ + images: list[Image] = [] + for f in frames: + if isinstance(f, Image): + images.append(f) + elif isinstance(f, ImageDetections2D): + images.append(f.annotated_image(scale=4)) + elif isinstance(f, Observation) and isinstance(f.data, Image): + images.append(f.data) + elif isinstance(f, EmbeddedObservation) and isinstance(f.data, Image): + images.append(f.data) + elif isinstance(f, Observation) and isinstance(f.data, ImageDetections2D): + images.append(f.data.annotated_image(scale=4)) + elif isinstance(f, EmbeddedObservation) and isinstance(f.data, ImageDetections2D): + images.append(f.data.annotated_image(scale=4)) + else: + raise TypeError(f"Cannot extract Image from {type(f).__name__}: {f!r}") + if not images: + raise ValueError("No images to mosaic") + + aspect = images[0].width / max(images[0].height, 1) + cell_w = int(cell_height * aspect) + rows = math.ceil(len(images) / cols) + + canvas = np.zeros((rows * cell_height, cols * cell_w, 3), dtype=np.uint8) + for i, img in enumerate(images): + r, c = divmod(i, cols) + tile = cv2.resize(img.to_bgr().data, (cell_w, cell_height)) + canvas[r * cell_height : (r + 1) * cell_height, c * cell_w : (c + 1) * cell_w] = tile + + result = Image(data=canvas, format=ImageFormat.BGR) + return Observation(id=0, ts=0.0, data_type=Image, _data=result, tags={"mosaic": True}) diff --git a/dimos/models/embedding/base.py b/dimos/models/embedding/base.py index 7850978317..4f5c4d8164 100644 --- a/dimos/models/embedding/base.py +++ b/dimos/models/embedding/base.py @@ -21,6 +21,7 @@ import numpy as np import torch +from dimos.core.resource import Resource from dimos.models.base import HuggingFaceModelConfig, LocalModelConfig from dimos.types.timestamped import Timestamped @@ -87,7 +88,7 @@ def to_cpu(self) -> Embedding: return self -class EmbeddingModel(ABC): +class EmbeddingModel(Resource, ABC): """Abstract base class for embedding models supporting vision and language.""" device: str diff --git a/dimos/models/vl/test_vlm.py 
b/dimos/models/vl/test_vlm.py index f0fd3b8d5a..f8b1c281d5 100644 --- a/dimos/models/vl/test_vlm.py +++ b/dimos/models/vl/test_vlm.py @@ -228,10 +228,10 @@ def test_vlm_query_multi(model_class: "type[VlModel]", model_name: str) -> None: @pytest.mark.slow def test_vlm_query_batch(model_class: "type[VlModel]", model_name: str) -> None: """Test query_batch optimization - multiple images, same query.""" - from dimos.utils.testing.replay import TimedSensorReplay + from dimos.memory.timeseries.legacy import LegacyPickleStore - # Load 5 frames at 1-second intervals using TimedSensorReplay - replay = TimedSensorReplay[Image]("unitree_go2_office_walk2/video") + # Load 5 frames at 1-second intervals using LegacyPickleStore + replay = LegacyPickleStore[Image]("unitree_go2_office_walk2/video") images = [replay.find_closest_seek(i).to_rgb() for i in range(0, 10, 2)] print(f"\nTesting {model_name} query_batch with {len(images)} images") @@ -285,9 +285,9 @@ def test_vlm_resize( sizes: list[tuple[int, int] | None], ) -> None: """Test VLM auto_resize effect on performance.""" - from dimos.utils.testing.replay import TimedSensorReplay + from dimos.memory.timeseries.legacy import LegacyPickleStore - replay = TimedSensorReplay[Image]("unitree_go2_office_walk2/video") + replay = LegacyPickleStore[Image]("unitree_go2_office_walk2/video") image = replay.find_closest_seek(0).to_rgb() labels: list[str] = [] diff --git a/dimos/msgs/nav_msgs/OccupancyGrid.py b/dimos/msgs/nav_msgs/OccupancyGrid.py index f2d1e84cec..b468bbcac0 100644 --- a/dimos/msgs/nav_msgs/OccupancyGrid.py +++ b/dimos/msgs/nav_msgs/OccupancyGrid.py @@ -481,7 +481,7 @@ def cell_value(self, world_position: Vector3) -> int: def to_rerun( self, colormap: str | None = None, - z_offset: float = 0.01, + z_offset: float = 0, opacity: float = 1.0, cost_range: tuple[int, int] | None = None, background: str | None = None, diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index 7da0238f32..58a328d2fa 100644 
--- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -374,6 +374,17 @@ def crop(self, x: int, y: int, width: int, height: int) -> Image: return Image(data=cropped_data, format=self.format, frame_id=self.frame_id, ts=self.ts) + @property + def brightness(self) -> float: + """Return mean brightness in [0, 1]. + + Strides to ~256px on the long edge first — ~O(N/step²) cheaper than + reading every pixel, and the mean converges quickly (CLT). + """ + max_val = 65535.0 if self.format in (ImageFormat.GRAY16, ImageFormat.DEPTH16) else 255.0 + step = max(1, max(self.data.shape[:2]) // 256) + return float(self.data[::step, ::step].mean() / max_val) + @property def sharpness(self) -> float: """Return sharpness score. diff --git a/dimos/msgs/sensor_msgs/PointCloud2.py b/dimos/msgs/sensor_msgs/PointCloud2.py index b45b844bf7..2f0527cebd 100644 --- a/dimos/msgs/sensor_msgs/PointCloud2.py +++ b/dimos/msgs/sensor_msgs/PointCloud2.py @@ -92,8 +92,10 @@ def __init__( self._pcd_tensor: o3d.t.geometry.PointCloud = o3d.t.geometry.PointCloud() elif isinstance(pointcloud, o3d.t.geometry.PointCloud): self._pcd_tensor = pointcloud + elif len(pointcloud.points) == 0: + # from_legacy() warns on empty legacy clouds; build an empty tensor instead + self._pcd_tensor = o3d.t.geometry.PointCloud() else: - # Convert legacy to tensor self._pcd_tensor = o3d.t.geometry.PointCloud.from_legacy(pointcloud) self._pcd_legacy_cache: o3d.geometry.PointCloud | None = None @@ -157,6 +159,8 @@ def pointcloud(self) -> o3d.geometry.PointCloud: def pointcloud(self, value: o3d.geometry.PointCloud | o3d.t.geometry.PointCloud) -> None: if isinstance(value, o3d.t.geometry.PointCloud): self._pcd_tensor = value + elif len(value.points) == 0: + self._pcd_tensor = o3d.t.geometry.PointCloud() else: self._pcd_tensor = o3d.t.geometry.PointCloud.from_legacy(value) self._pcd_legacy_cache = None @@ -653,6 +657,7 @@ def to_rerun( mode: str = "spheres", size: float | None = None, fill_mode: str = 
"solid", + bottom_cutoff: float | None = None, **kwargs: object, ) -> Archetype: """Convert to Rerun archetype for visualization. @@ -676,6 +681,11 @@ def to_rerun( if len(points) == 0: return rr.Points3D([]) if mode != "boxes" else rr.Boxes3D(centers=[]) + if bottom_cutoff is not None: + points = points[points[:, 2] >= bottom_cutoff] + if len(points) == 0: + return rr.Points3D([]) if mode != "boxes" else rr.Boxes3D(centers=[]) + # Use class_ids for height-based colormap (viewer resolves colors via AnnotationContext) # Fall back to explicit colors when provided class_ids = None diff --git a/dimos/msgs/sensor_msgs/test_image.py b/dimos/msgs/sensor_msgs/test_image.py index 2f236613f2..502161755f 100644 --- a/dimos/msgs/sensor_msgs/test_image.py +++ b/dimos/msgs/sensor_msgs/test_image.py @@ -20,9 +20,9 @@ _IS_MACOS = sys.platform == "darwin" +from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.sensor_msgs.Image import Image, ImageFormat, sharpness_barrier from dimos.utils.data import get_data -from dimos.utils.testing.replay import TimedSensorReplay @pytest.fixture @@ -72,7 +72,7 @@ def test_opencv_conversion(img: Image) -> None: @pytest.mark.tool def test_sharpness_stream() -> None: get_data("unitree_office_walk") # Preload data for testing - video_store = TimedSensorReplay( + video_store = LegacyPickleStore( "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() ) diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index 5f8f1bc4b9..afd11cabf6 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -23,6 +23,7 @@ import pytest from dimos.core.transport import LCMTransport +from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image @@ -39,7 +40,6 @@ from dimos.robot.unitree.go2 
import connection from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data -from dimos.utils.testing.replay import TimedSensorReplay class Moment(TypedDict, total=False): @@ -80,12 +80,12 @@ def moment_provider(**kwargs) -> Moment: data_dir = "unitree_go2_lidar_corrected" get_data(data_dir) - lidar_frame_result = TimedSensorReplay(f"{data_dir}/lidar").find_closest_seek(seek) + lidar_frame_result = LegacyPickleStore(f"{data_dir}/lidar").find_closest_seek(seek) if lidar_frame_result is None: raise ValueError("No lidar frame found") lidar_frame: PointCloud2 = lidar_frame_result - image_frame = TimedSensorReplay( + image_frame = LegacyPickleStore( f"{data_dir}/video", ).find_closest(lidar_frame.ts) @@ -94,7 +94,7 @@ def moment_provider(**kwargs) -> Moment: image_frame.frame_id = "camera_optical" - odom_frame = TimedSensorReplay(f"{data_dir}/odom", autocast=Odometry.from_msg).find_closest( + odom_frame = LegacyPickleStore(f"{data_dir}/odom", autocast=Odometry.from_msg).find_closest( lidar_frame.ts ) diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index 2cc8268d24..fb07e02d3c 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -78,7 +78,7 @@ def process_image_frame(self, image: Image) -> ImageDetections2D: imageDetections = self.detector.process_image(image) if not self.config.filter: return imageDetections - return imageDetections.filter(*self.config.filter) # type: ignore[return-value] + return imageDetections.filter(*self.config.filter) @simple_mcache def sharp_image_stream(self) -> Observable[Image]: diff --git a/dimos/perception/detection/type/detection2d/bbox.py b/dimos/perception/detection/type/detection2d/bbox.py index 57f4b632bd..9ba9034722 100644 --- a/dimos/perception/detection/type/detection2d/bbox.py +++ b/dimos/perception/detection/type/detection2d/bbox.py @@ -98,6 +98,48 @@ def to_repr_dict(self) -> dict[str, Any]: "bbox": 
f"[{x1:.0f},{y1:.0f},{x2:.0f},{y2:.0f}]", } + def draw_on(self, img: Any, scale: float = 1.0) -> None: + """Draw this detection's bbox and label onto a BGR numpy array (in-place).""" + import cv2 + import numpy as np + + x1, y1, x2, y2 = map(int, self.bbox) + + h = hashlib.md5(self.name.encode()).digest()[0] + bgr = [ + int(c) + for c in cv2.applyColorMap(np.array([[h]], dtype=np.uint8), cv2.COLORMAP_HSV)[0][0] + ] + + thickness = max(1, int(2 * scale)) + cv2.rectangle(img, (x1, y1), (x2, y2), bgr, thickness) + + label = self.name + if self.confidence < 1.0: + label = f"{self.name} {self.confidence:.2f}" + font_scale = 0.5 * scale + font_thickness = max(1, int(scale)) + (tw, th), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness) + cv2.rectangle(img, (x1, y1 - th - 6), (x1 + tw + 4, y1), (0, 0, 0), -1) + cv2.rectangle(img, (x1, y1 - th - 6), (x1 + tw + 4, y1), bgr, max(1, int(scale))) + cv2.putText( + img, + label, + (x1 + 2, y1 - 4), + cv2.FONT_HERSHEY_SIMPLEX, + font_scale, + (255, 255, 255), + font_thickness, + ) + + def annotated_image(self, scale: float = 1.0) -> Image: + """Return the full image with this detection's bbox and label drawn on it.""" + img = self.image.to_opencv().copy() + self.draw_on(img, scale=scale) + from dimos.msgs.sensor_msgs.Image import Image + + return Image.from_opencv(img, ts=self.ts) + # return focused image, only on the bbox def cropped_image(self, padding: int = 20) -> Image: """Return a cropped version of the image focused on the bounding box. 
diff --git a/dimos/perception/detection/type/detection2d/test_imageDetections2D.py b/dimos/perception/detection/type/detection2d/test_imageDetections2D.py index 4897d8d034..8092e4da45 100644 --- a/dimos/perception/detection/type/detection2d/test_imageDetections2D.py +++ b/dimos/perception/detection/type/detection2d/test_imageDetections2D.py @@ -50,3 +50,34 @@ def test_from_ros_detection2d_array(get_moment_2d) -> None: print(f" Recovered bbox: {recovered_det.bbox}") print(f" Track ID: {recovered_det.track_id}") print(f" Confidence: {recovered_det.confidence:.3f}") + + +def test_filter(imageDetections2d: ImageDetections2D) -> None: + dets = imageDetections2d.detections + assert len(dets) >= 1, "fixture should provide at least one detection" + + # No predicates → keep everything. + assert imageDetections2d.filter().detections == dets + + # Single predicate that always fails → empty. + result = imageDetections2d.filter(lambda d: False) + assert result.detections == [] + assert isinstance(result, ImageDetections2D) # subclass preserved + + # Single predicate that always passes → same detections, new instance. + result = imageDetections2d.filter(lambda d: True) + assert result.detections == dets + assert result is not imageDetections2d + + # Multi-predicate cascade: only keep detections with confidence > 0.5 AND + # class_id == first detection's class_id. Cascade must AND all predicates. + target_cls = dets[0].class_id + expected = [d for d in dets if d.confidence > 0.5 and d.class_id == target_cls] + result = imageDetections2d.filter( + lambda d: d.confidence > 0.5, + lambda d: d.class_id == target_cls, + ) + assert result.detections == expected + + # Preserves image. 
+ assert result.image is imageDetections2d.image diff --git a/dimos/perception/detection/type/detection3d/imageDetections3DPC.py b/dimos/perception/detection/type/detection3d/imageDetections3DPC.py index 0fbb1a7c59..22c8c47b30 100644 --- a/dimos/perception/detection/type/detection3d/imageDetections3DPC.py +++ b/dimos/perception/detection/type/detection3d/imageDetections3DPC.py @@ -14,15 +14,51 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from lcm_msgs.foxglove_msgs import SceneUpdate # type: ignore[import-not-found] from dimos.perception.detection.type.detection3d.pointcloud import Detection3DPC from dimos.perception.detection.type.imageDetections import ImageDetections +if TYPE_CHECKING: + from dimos_lcm.sensor_msgs import CameraInfo + + from dimos.msgs.geometry_msgs.Transform import Transform + from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 + from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D + from dimos.perception.detection.type.detection3d.pointcloud_filters import PointCloudFilter + class ImageDetections3DPC(ImageDetections[Detection3DPC]): """Specialized class for 3D detections in an image.""" + @classmethod + def from_2d( + cls, + detections_2d: ImageDetections2D, + world_pointcloud: PointCloud2, + camera_info: CameraInfo, + world_to_optical_transform: Transform, + filters: list[PointCloudFilter] | None = None, + ) -> ImageDetections3DPC: + """Project every 2D detection into 3D, dropping any that yield no valid points.""" + detections_3d = [ + d3d + for det in detections_2d + if ( + d3d := Detection3DPC.from_2d( + det, + world_pointcloud, + camera_info, + world_to_optical_transform, + filters, + ) + ) + is not None + ] + return cls(image=detections_2d.image, detections=detections_3d) + def to_foxglove_scene_update(self) -> SceneUpdate: """Convert all detections to a Foxglove SceneUpdate message. 
diff --git a/dimos/perception/detection/type/imageDetections.py b/dimos/perception/detection/type/imageDetections.py index 25cd45545a..d1fac8669c 100644 --- a/dimos/perception/detection/type/imageDetections.py +++ b/dimos/perception/detection/type/imageDetections.py @@ -16,7 +16,7 @@ from functools import reduce from operator import add -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Generic, Self, TypeVar from dimos_lcm.vision_msgs import Detection2DArray @@ -61,7 +61,7 @@ def __iter__(self) -> Iterator: # type: ignore[type-arg] def __getitem__(self, index): # type: ignore[no-untyped-def] return self.detections[index] - def filter(self, *predicates: Callable[[T], bool]) -> ImageDetections[T]: + def filter(self, *predicates: Callable[[T], bool]) -> Self: """Filter detections using one or more predicate functions. Multiple predicates are applied in cascade (all must return True). @@ -70,12 +70,10 @@ def filter(self, *predicates: Callable[[T], bool]) -> ImageDetections[T]: *predicates: Functions that take a detection and return True to keep it Returns: - A new ImageDetections instance with filtered detections + A new instance of the same class with filtered detections """ - filtered_detections = self.detections - for predicate in predicates: - filtered_detections = [det for det in filtered_detections if predicate(det)] - return ImageDetections(self.image, filtered_detections) + filtered_detections = [det for det in self.detections if all(p(det) for p in predicates)] + return type(self)(self.image, filtered_detections) def to_ros_detection2d_array(self) -> Detection2DArray: return Detection2DArray( @@ -84,6 +82,17 @@ def to_ros_detection2d_array(self) -> Detection2DArray: detections=[det.to_ros_detection2d() for det in self.detections], ) + def annotated_image(self, scale: float = 1.0) -> Image: + """Return the image with all detection bboxes and labels drawn on it.""" + img = self.image.to_opencv().copy() + for det in 
self.detections: + if hasattr(det, "draw_on"): + det.draw_on(img, scale=scale) + + from dimos.msgs.sensor_msgs.Image import Image as ImageMsg + + return ImageMsg.from_opencv(img, ts=self.image.ts) + def to_foxglove_annotations(self) -> ImageAnnotations: if not self.detections: return ImageAnnotations( diff --git a/dimos/perception/test_spatial_memory_module.py b/dimos/perception/test_spatial_memory_module.py index e2fcb06f68..309dec701b 100644 --- a/dimos/perception/test_spatial_memory_module.py +++ b/dimos/perception/test_spatial_memory_module.py @@ -25,13 +25,13 @@ from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.core.transport import LCMTransport +from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.spatial_perception import SpatialMemory from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger -from dimos.utils.testing.replay import TimedSensorReplay logger = setup_logger() @@ -41,7 +41,7 @@ class VideoReplayConfig(ModuleConfig): class VideoReplayModule(Module): - """Module that replays video data from TimedSensorReplay.""" + """Module that replays video data from LegacyPickleStore.""" config: VideoReplayConfig @@ -51,8 +51,8 @@ class VideoReplayModule(Module): @rpc def start(self) -> None: """Start replaying video data.""" - # Use TimedSensorReplay to replay video frames - video_replay = TimedSensorReplay(self.config.video_path, autocast=Image.from_numpy) + # Use LegacyPickleStore to replay video frames + video_replay = LegacyPickleStore(self.config.video_path, autocast=Image.from_numpy) # Subscribe to the replay stream and publish to LCM self._subscription = ( @@ -90,8 +90,8 @@ def _publish_tf(self, odom: Odometry) -> None: @rpc def start(self) -> None: """Start replaying odometry data.""" - # 
Use TimedSensorReplay to replay odometry - odom_replay = TimedSensorReplay(self.odom_path, autocast=Odometry.from_msg) + # Use LegacyPickleStore to replay odometry + odom_replay = LegacyPickleStore(self.odom_path, autocast=Odometry.from_msg) # Subscribe to the replay stream and publish to tf self._subscription = ( @@ -128,7 +128,7 @@ def dimos(): @pytest.mark.skipif_in_ci @pytest.mark.asyncio async def test_spatial_memory_module_with_replay(dimos, tmp_path): - """Test SpatialMemory module with TimedSensorReplay inputs.""" + """Test SpatialMemory module with LegacyPickleStore inputs.""" # Get test data paths data_path = get_data("unitree_office_walk") video_path = os.path.join(data_path, "video") diff --git a/dimos/protocol/pubsub/impl/test_rospubsub.py b/dimos/protocol/pubsub/impl/test_rospubsub.py index 6f29b3591b..e26f2ac709 100644 --- a/dimos/protocol/pubsub/impl/test_rospubsub.py +++ b/dimos/protocol/pubsub/impl/test_rospubsub.py @@ -26,7 +26,6 @@ # Add msg_name to LCM PointStamped for testing nested message conversion PointStamped.msg_name = "geometry_msgs.PointStamped" -from dimos.utils.data import get_data from dimos.utils.testing.collector import CallbackCollector from dimos.utils.testing.replay import TimedSensorReplay @@ -79,10 +78,8 @@ def test_pointcloud2_pubsub(publisher, subscriber): that can't be treated like a standard message with direct field copy. Uses LCM encode/decode roundtrip to properly convert internal representation. 
""" - dir_name = get_data("unitree_go2_bigoffice") - # Load real lidar data from replay (5 seconds in) - replay = TimedSensorReplay(f"{dir_name}/lidar") + replay = TimedSensorReplay("go2_bigoffice/lidar") original = replay.find_closest_seek(5.0) assert original is not None, "Failed to load lidar data from replay" diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 3c8cc80f82..881914b129 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -91,6 +91,7 @@ "unitree-go2-coordinator": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_coordinator:unitree_go2_coordinator", "unitree-go2-detection": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_detection:unitree_go2_detection", "unitree-go2-fleet": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_fleet:unitree_go2_fleet", + "unitree-go2-memory": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2:unitree_go2_memory", "unitree-go2-ros": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_ros:unitree_go2_ros", "unitree-go2-security": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_security:unitree_go2_security", "unitree-go2-spatial": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_spatial:unitree_go2_spatial", @@ -133,6 +134,7 @@ "g1-sim-connection": "dimos.robot.unitree.g1.sim.G1SimConnection", "go2-connection": "dimos.robot.unitree.go2.connection.GO2Connection", "go2-fleet-connection": "dimos.robot.unitree.go2.fleet_connection.Go2FleetConnection", + "go2-memory": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2.Go2Memory", "google-maps-skill-container": "dimos.agents.skills.google_maps_skill_container.GoogleMapsSkillContainer", "gps-nav-skill-container": "dimos.agents.skills.gps_nav_skill.GpsNavSkillContainer", "grasp-gen-module": "dimos.manipulation.grasping.graspgen_module.GraspGenModule", @@ -146,6 +148,7 @@ "map": "dimos.robot.unitree.type.map.Map", "mcp-client": "dimos.agents.mcp.mcp_client.McpClient", "mcp-server": 
"dimos.agents.mcp.mcp_server.McpServer", + "memory-module": "dimos.memory2.module.MemoryModule", "mock-b1-connection-module": "dimos.robot.unitree.b1.connection.MockB1ConnectionModule", "module-a": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleA", "module-b": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleB", @@ -167,11 +170,13 @@ "quest-teleop-module": "dimos.teleop.quest.quest_teleop_module.QuestTeleopModule", "real-sense-camera": "dimos.hardware.sensors.camera.realsense.camera.RealSenseCamera", "receiver-module": "dimos.utils.demo_image_encoding.ReceiverModule", + "recorder": "dimos.memory2.module.Recorder", "reid-module": "dimos.perception.detection.reid.module.ReidModule", "replanning-a-star-planner": "dimos.navigation.replanning_a_star.module.ReplanningAStarPlanner", "rerun-bridge-module": "dimos.visualization.rerun.bridge.RerunBridgeModule", "ros-nav": "dimos.navigation.rosnav.ROSNav", "security-module": "dimos.experimental.security_demo.security_module.SecurityModule", + "semantic-search": "dimos.memory2.module.SemanticSearch", "simple-phone-teleop": "dimos.teleop.phone.phone_extensions.SimplePhoneTeleop", "spatial-memory": "dimos.perception.spatial_perception.SpatialMemory", "speak-skill": "dimos.agents.skills.speak_skill.SpeakSkill", diff --git a/dimos/robot/drone/dji_video_stream.py b/dimos/robot/drone/dji_video_stream.py index 60618ae712..df153192e3 100644 --- a/dimos/robot/drone/dji_video_stream.py +++ b/dimos/robot/drone/dji_video_stream.py @@ -214,7 +214,7 @@ def get_stream(self) -> Observable[Image]: # type: ignore[override] """ from reactivex import operators as ops - from dimos.utils.testing.replay import TimedSensorReplay + from dimos.memory.timeseries.legacy import LegacyPickleStore def _fix_format(img: Image) -> Image: if img.format == ImageFormat.BGR: @@ -222,7 +222,7 @@ def _fix_format(img: Image) -> Image: return img logger.info("Creating video replay stream") - video_store: Any = TimedSensorReplay("drone/video") + 
video_store: Any = LegacyPickleStore("drone/video") stream: Observable[Image] = video_store.stream().pipe(ops.map(_fix_format)) return stream diff --git a/dimos/robot/drone/mavlink_connection.py b/dimos/robot/drone/mavlink_connection.py index 7e7d233c3f..9c0e7984fb 100644 --- a/dimos/robot/drone/mavlink_connection.py +++ b/dimos/robot/drone/mavlink_connection.py @@ -1030,12 +1030,12 @@ def __init__(self, connection_string: str) -> None: # Create fake mavlink object class FakeMavlink: def __init__(self) -> None: + from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.utils.data import get_data - from dimos.utils.testing.replay import TimedSensorReplay get_data("drone") - self.replay: Any = TimedSensorReplay("drone/mavlink") + self.replay: Any = LegacyPickleStore("drone/mavlink") self.messages: list[dict[str, Any]] = [] # The stream() method returns an Observable that emits messages with timing self.replay.stream().subscribe(self.messages.append) diff --git a/dimos/robot/drone/test_drone.py b/dimos/robot/drone/test_drone.py index ea0c42ecf5..8baa177e84 100644 --- a/dimos/robot/drone/test_drone.py +++ b/dimos/robot/drone/test_drone.py @@ -194,7 +194,7 @@ class TestReplayMode(unittest.TestCase): def test_fake_mavlink_connection(self) -> None: """Test FakeMavlinkConnection replays messages correctly.""" - with patch("dimos.utils.testing.replay.TimedSensorReplay") as mock_replay: + with patch("dimos.memory.timeseries.legacy.LegacyPickleStore") as mock_replay: # Mock the replay stream MagicMock() mock_messages = [ @@ -220,7 +220,7 @@ def test_fake_mavlink_connection(self) -> None: def test_fake_video_stream_no_throttling(self) -> None: """Test FakeDJIVideoStream returns replay stream with format fix.""" - with patch("dimos.utils.testing.replay.TimedSensorReplay") as mock_replay: + with patch("dimos.memory.timeseries.legacy.LegacyPickleStore") as mock_replay: mock_stream = MagicMock() mock_replay.return_value.stream.return_value = mock_stream @@ -282,7 
+282,7 @@ def test_connection_module_replay_with_messages(self) -> None: os.environ["DRONE_CONNECTION"] = "replay" - with patch("dimos.utils.testing.replay.TimedSensorReplay") as mock_replay: + with patch("dimos.memory.timeseries.legacy.LegacyPickleStore") as mock_replay: # Set up MAVLink replay stream mavlink_messages = [ {"mavpackettype": "HEARTBEAT", "type": 2, "base_mode": 193}, @@ -334,7 +334,7 @@ def subscribe(callback) -> None: # Configure mock replay to return appropriate streams def replay_side_effect(store_name: str): - print(f"[TEST] TimedSensorReplay created for: {store_name}") + print(f"[TEST] LegacyPickleStore created for: {store_name}") mock = MagicMock() if "mavlink" in store_name: mock.stream.return_value = create_mavlink_stream() @@ -435,7 +435,7 @@ def tearDown(self) -> None: self.foxglove_patch.stop() @patch("dimos.robot.drone.drone.ModuleCoordinator") - @patch("dimos.utils.testing.replay.TimedSensorReplay") + @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") def test_full_system_with_replay(self, mock_replay, mock_coordinator_class) -> None: """Test full drone system initialization and operation with replay mode.""" # Set up mock replay data @@ -569,7 +569,7 @@ def deploy_side_effect(module_class, **kwargs): class TestDroneControlCommands(unittest.TestCase): """Test drone control commands with FakeMavlinkConnection.""" - @patch("dimos.utils.testing.replay.TimedSensorReplay") + @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_arm_disarm_commands(self, mock_get_data, mock_replay) -> None: """Test arm and disarm commands work with fake connection.""" @@ -588,7 +588,7 @@ def test_arm_disarm_commands(self, mock_get_data, mock_replay) -> None: result = conn.disarm() self.assertIsInstance(result, bool) # Should return bool without crashing - @patch("dimos.utils.testing.replay.TimedSensorReplay") + @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") 
@patch("dimos.utils.data.get_data") def test_takeoff_land_commands(self, mock_get_data, mock_replay) -> None: """Test takeoff and land commands with fake connection.""" @@ -607,7 +607,7 @@ def test_takeoff_land_commands(self, mock_get_data, mock_replay) -> None: result = conn.land() self.assertIsNotNone(result) - @patch("dimos.utils.testing.replay.TimedSensorReplay") + @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_set_mode_command(self, mock_get_data, mock_replay) -> None: """Test flight mode setting with fake connection.""" @@ -628,7 +628,7 @@ def test_set_mode_command(self, mock_get_data, mock_replay) -> None: class TestDronePerception(unittest.TestCase): """Test drone perception capabilities.""" - @patch("dimos.utils.testing.replay.TimedSensorReplay") + @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_video_stream_replay(self, mock_get_data, mock_replay) -> None: """Test video stream works with replay data.""" @@ -698,7 +698,7 @@ def piped_subscribe(callback): class TestDroneMovementAndOdometry(unittest.TestCase): """Test drone movement commands and odometry.""" - @patch("dimos.utils.testing.replay.TimedSensorReplay") + @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_movement_command_conversion(self, mock_get_data, mock_replay) -> None: """Test movement commands are properly converted from ROS to NED.""" @@ -718,7 +718,7 @@ def test_movement_command_conversion(self, mock_get_data, mock_replay) -> None: # Movement should be converted to NED internally # The fake connection doesn't actually send commands, but it should not crash - @patch("dimos.utils.testing.replay.TimedSensorReplay") + @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_odometry_from_replay(self, mock_get_data, mock_replay) -> None: """Test odometry is properly generated from 
replay messages.""" @@ -765,7 +765,7 @@ def replay_stream_subscribe(callback) -> None: self.assertIsNotNone(odom.orientation) self.assertEqual(odom.frame_id, "world") - @patch("dimos.utils.testing.replay.TimedSensorReplay") + @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_position_integration_indoor(self, mock_get_data, mock_replay) -> None: """Test position integration for indoor flight without GPS.""" @@ -810,7 +810,7 @@ def replay_stream_subscribe(callback) -> None: class TestDroneStatusAndTelemetry(unittest.TestCase): """Test drone status and telemetry reporting.""" - @patch("dimos.utils.testing.replay.TimedSensorReplay") + @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_status_extraction(self, mock_get_data, mock_replay) -> None: """Test status is properly extracted from MAVLink messages.""" @@ -855,7 +855,7 @@ def replay_stream_subscribe(callback) -> None: self.assertIn("altitude", status) self.assertIn("heading", status) - @patch("dimos.utils.testing.replay.TimedSensorReplay") + @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_telemetry_json_publishing(self, mock_get_data, mock_replay) -> None: """Test full telemetry is published as JSON.""" @@ -909,7 +909,7 @@ def replay_stream_subscribe(callback) -> None: class TestFlyToErrorHandling(unittest.TestCase): """Test fly_to() error handling paths.""" - @patch("dimos.utils.testing.replay.TimedSensorReplay") + @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_concurrency_lock(self, mock_get_data, mock_replay) -> None: """flying_to_target=True rejects concurrent fly_to() calls.""" @@ -923,7 +923,7 @@ def test_concurrency_lock(self, mock_get_data, mock_replay) -> None: result = conn.fly_to(37.0, -122.0, 10.0) self.assertIn("Already flying to target", result) - 
@patch("dimos.utils.testing.replay.TimedSensorReplay") + @patch("dimos.memory.timeseries.legacy.LegacyPickleStore") @patch("dimos.utils.data.get_data") def test_error_when_not_connected(self, mock_get_data, mock_replay) -> None: """connected=False returns error immediately.""" diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py index 9acc0bb7bb..54a2c0f7c6 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py @@ -48,6 +48,10 @@ def _convert_camera_info(camera_info: Any) -> Any: ) +def _convert_global_map(grid: Any) -> Any: + return grid.to_rerun(bottom_cutoff=0) + + def _convert_navigation_costmap(grid: Any) -> Any: return grid.to_rerun( colormap="Accent", @@ -62,7 +66,6 @@ def _static_base_link(rr: Any) -> list[Any]: rr.Boxes3D( half_sizes=[0.35, 0.155, 0.2], colors=[(0, 255, 127)], - fill_mode="wireframe", ), rr.Transform3D(parent_frame="tf#/base_link"), ] @@ -83,6 +86,9 @@ def _go2_rerun_blueprint() -> Any: line_grid=rrb.LineGrid3D( plane=rr.components.Plane3D.XY.with_distance(0.5), ), + overrides={ + "world/lidar": rrb.EntityBehavior(visible=False), + }, ), column_shares=[1, 2], ), @@ -103,6 +109,7 @@ def _go2_rerun_blueprint() -> Any: # This is unsustainable once we move to multi robot etc "visual_override": { "world/camera_info": _convert_camera_info, + "world/global_map": _convert_global_map, "world/navigation_costmap": _convert_navigation_costmap, }, "max_hz": { @@ -134,6 +141,7 @@ def _go2_rerun_blueprint() -> Any: else: with_vis = _transports_base + unitree_go2_basic = ( autoconnect( with_vis, diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py index 01f4f7bfb9..f353d995af 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py @@ 
-13,9 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + from dimos.core.coordination.blueprints import autoconnect +from dimos.core.stream import In from dimos.mapping.costmapper import CostMapper from dimos.mapping.voxels import VoxelGridMapper +from dimos.memory2.module import Recorder, RecorderConfig +from dimos.msgs.sensor_msgs.Image import Image +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( WavefrontFrontierExplorer, ) @@ -30,6 +36,22 @@ ReplanningAStarPlanner.blueprint(), WavefrontFrontierExplorer.blueprint(), PatrollingModule.blueprint(), -).global_config(n_workers=7, robot_model="unitree_go2") +).global_config(n_workers=9, robot_model="unitree_go2") + + +class Go2MemoryConfig(RecorderConfig): + db_path: str | Path = "recording_go2.db" + + +class Go2Memory(Recorder): + color_image: In[Image] + lidar: In[PointCloud2] + config: Go2MemoryConfig + + +unitree_go2_memory = autoconnect( + unitree_go2, + Go2Memory.blueprint(), +).global_config(n_workers=10) -__all__ = ["unitree_go2"] +__all__ = ["unitree_go2", "unitree_go2_memory"] diff --git a/dimos/robot/unitree/go2/connection.py b/dimos/robot/unitree/go2/connection.py index 1b91cd7a27..055f0ab591 100644 --- a/dimos/robot/unitree/go2/connection.py +++ b/dimos/robot/unitree/go2/connection.py @@ -45,7 +45,6 @@ from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.connection import UnitreeWebRTCConnection -from dimos.utils.data import get_data from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.testing.replay import TimedSensorReplay, TimedSensorStorage @@ -102,11 +101,26 @@ def _camera_info_static() -> CameraInfo: ) +# Static camera mount chain: base_link -> camera_link -> camera_optical. 
+# TODO we need a standardized way to specify this for all cameras in dimos +BASE_TO_OPTICAL: Transform = Transform( + translation=Vector3(0.3, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", +) + Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), + frame_id="camera_link", + child_frame_id="camera_optical", +) + + def make_connection(ip: str | None, cfg: GlobalConfig) -> Go2ConnectionProtocol: connection_type = cfg.unitree_connection_type if ip in ("fake", "mock", "replay") or connection_type == "replay": - dataset = cfg.replay_dir + dataset = cfg.replay_db return ReplayConnection(dataset=dataset) elif ip == "mujoco" or connection_type == "mujoco": from dimos.robot.unitree.mujoco_connection import MujocoConnection @@ -121,11 +135,10 @@ class ReplayConnection(UnitreeWebRTCConnection): # we don't want UnitreeWebRTCConnection to init def __init__( # type: ignore[no-untyped-def] self, - dataset: str = "go2_sf_office", + dataset: str = "go2_bigoffice", **kwargs, ) -> None: self.dir_name = dataset - get_data(self.dir_name) self.replay_config = { "loop": kwargs.get("loop", True), "seek": kwargs.get("seek"), @@ -181,7 +194,7 @@ def _autocast_video(x): # type: ignore[no-untyped-def] arr = x.to_ndarray(format="rgb24") if hasattr(x, "to_ndarray") else x return Image.from_numpy(arr, format=ImageFormat.RGB, frame_id="camera_optical") - video_store = TimedSensorReplay(f"{self.dir_name}/video", autocast=_autocast_video) + video_store = TimedSensorReplay(f"{self.dir_name}/color_image", autocast=_autocast_video) return video_store.stream(**self.replay_config) def move(self, twist: Twist, duration: float = 0.0) -> bool: diff --git a/dimos/robot/unitree/modular/detect.py b/dimos/robot/unitree/modular/detect.py index d6ed78d101..b521301a98 100644 --- a/dimos/robot/unitree/modular/detect.py +++ b/dimos/robot/unitree/modular/detect.py @@ -139,6 +139,7 @@ def broadcast( # type: 
ignore[no-untyped-def] def process_data(): # type: ignore[no-untyped-def] + from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.sensor_msgs.Image import Image from dimos.perception.detection.module2D import ( # type: ignore[attr-defined] Detection2DModule, @@ -146,15 +147,14 @@ def process_data(): # type: ignore[no-untyped-def] ) from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data - from dimos.utils.testing.replay import TimedSensorReplay get_data("unitree_office_walk") target = 1751591272.9654856 - lidar_store = TimedSensorReplay( + lidar_store = LegacyPickleStore( "unitree_office_walk/lidar", autocast=pointcloud2_from_webrtc_lidar ) - video_store = TimedSensorReplay("unitree_office_walk/video", autocast=Image.from_numpy) - odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + video_store = LegacyPickleStore("unitree_office_walk/video", autocast=Image.from_numpy) + odom_store = LegacyPickleStore("unitree_office_walk/odom", autocast=Odometry.from_msg) def attach_frame_id(image: Image) -> Image: image.frame_id = "camera_optical" diff --git a/dimos/robot/unitree/testing/test_tooling.py b/dimos/robot/unitree/testing/test_tooling.py index 40db01feee..c4f64c054f 100644 --- a/dimos/robot/unitree/testing/test_tooling.py +++ b/dimos/robot/unitree/testing/test_tooling.py @@ -16,17 +16,17 @@ import pytest +from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.reactive import backpressure -from dimos.utils.testing.replay import TimedSensorReplay @pytest.mark.tool def test_replay_all() -> None: - lidar_store = TimedSensorReplay("unitree/lidar", autocast=pointcloud2_from_webrtc_lidar) - odom_store = TimedSensorReplay("unitree/odom", autocast=Odometry.from_msg) - video_store = TimedSensorReplay("unitree/video") + lidar_store = 
LegacyPickleStore("unitree/lidar", autocast=pointcloud2_from_webrtc_lidar) + odom_store = LegacyPickleStore("unitree/odom", autocast=Odometry.from_msg) + video_store = LegacyPickleStore("unitree/video") backpressure(odom_store.stream()).subscribe(print) backpressure(lidar_store.stream()).subscribe(print) diff --git a/dimos/simulation/engines/mujoco_sim_module.py b/dimos/simulation/engines/mujoco_sim_module.py index 0a11218004..3d2ff927fe 100644 --- a/dimos/simulation/engines/mujoco_sim_module.py +++ b/dimos/simulation/engines/mujoco_sim_module.py @@ -112,8 +112,6 @@ class MujocoSimModule( camera_info: Out[CameraInfo] depth_camera_info: Out[CameraInfo] - default_config = MujocoSimModuleConfig - def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self._engine: MujocoEngine | None = None diff --git a/dimos/types/test_timestamped.py b/dimos/types/test_timestamped.py index 1eb892b299..3913ce9b76 100644 --- a/dimos/types/test_timestamped.py +++ b/dimos/types/test_timestamped.py @@ -20,6 +20,7 @@ from reactivex.scheduler import ThreadPoolScheduler from dimos.memory.timeseries.inmemory import InMemoryStore +from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.sensor_msgs.Image import Image from dimos.types.timestamped import ( Timestamped, @@ -30,7 +31,6 @@ ) from dimos.utils.data import get_data from dimos.utils.reactive import backpressure -from dimos.utils.testing.replay import TimedSensorReplay def test_timestamped_dt_method() -> None: @@ -297,7 +297,7 @@ def spy(image): # sensor reply of raw video frames video_raw = ( - TimedSensorReplay( + LegacyPickleStore( "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() ) .stream(speed) diff --git a/dimos/utils/data.py b/dimos/utils/data.py index 75e1fb291c..63e644909a 100644 --- a/dimos/utils/data.py +++ b/dimos/utils/data.py @@ -52,7 +52,7 @@ def _get_user_data_dir() -> Path: @cache -def _get_repo_root() -> Path: +def get_project_root() -> Path: # Check if 
running from git repo if (DIMOS_PROJECT_ROOT / ".git").exists(): return DIMOS_PROJECT_ROOT @@ -107,8 +107,8 @@ def _get_repo_root() -> Path: @cache def get_data_dir(extra_path: str | None = None) -> Path: if extra_path: - return _get_repo_root() / "data" / extra_path - return _get_repo_root() / "data" + return get_project_root() / "data" / extra_path + return get_project_root() / "data" @cache @@ -186,7 +186,7 @@ def _pull_lfs_archive(filename: str | Path) -> Path: _check_git_lfs_available() # Find repository root - repo_root = _get_repo_root() + repo_root = get_project_root() # Construct path to test data file file_path = _get_lfs_dir() / (str(filename) + ".tar.gz") diff --git a/dimos/utils/test_data.py b/dimos/utils/test_data.py index 9970fc5912..373126ec26 100644 --- a/dimos/utils/test_data.py +++ b/dimos/utils/test_data.py @@ -25,7 +25,7 @@ @pytest.mark.slow def test_pull_file() -> None: - repo_root = data._get_repo_root() + repo_root = data.get_project_root() test_file_name = "cafe.jpg" test_file_compressed = data._get_lfs_dir() / (test_file_name + ".tar.gz") test_file_decompressed = data.get_data_dir() / test_file_name @@ -81,7 +81,7 @@ def test_pull_file() -> None: @pytest.mark.slow def test_pull_dir() -> None: - repo_root = data._get_repo_root() + repo_root = data.get_project_root() test_dir_name = "ab_lidar_frames" test_dir_compressed = data._get_lfs_dir() / (test_dir_name + ".tar.gz") test_dir_decompressed = data.get_data_dir() / test_dir_name diff --git a/dimos/utils/testing/moment.py b/dimos/utils/testing/moment.py index 7ab1b07eba..0130a5b0a3 100644 --- a/dimos/utils/testing/moment.py +++ b/dimos/utils/testing/moment.py @@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar from dimos.core.resource import Resource +from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.types.timestamped import Timestamped -from dimos.utils.testing.replay import TimedSensorReplay if TYPE_CHECKING: from dimos.core.stream import 
Transport @@ -30,7 +30,7 @@ class SensorMoment(Generic[T], Resource): value: T | None = None def __init__(self, name: str, transport: Transport[T]) -> None: - self.replay: TimedSensorReplay[T] = TimedSensorReplay(name) + self.replay: LegacyPickleStore[T] = LegacyPickleStore(name) self.transport = transport def seek(self, timestamp: float) -> None: diff --git a/dimos/utils/testing/replay.py b/dimos/utils/testing/replay.py index 588b63e099..372d4167ce 100644 --- a/dimos/utils/testing/replay.py +++ b/dimos/utils/testing/replay.py @@ -12,11 +12,309 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Shim for TimedSensorReplay/TimedSensorStorage.""" +"""Shim layer exposing the legacy ``TimedSensorReplay`` API over memory2. + +``TimedSensorReplay(name, autocast=...)`` opens the memory2 SQLite database at +``{get_data_dir}/{dataset}.db`` and reads the named stream. ``name`` is expected +to be ``"<dataset>/<stream>"``. + +Callers that still need to read from legacy pickle dirs should import +``LegacyPickleStore`` directly from ``dimos.memory.timeseries.legacy``. The +write-side (``TimedSensorStorage``/``SensorStorage``) still points at +``LegacyPickleStore`` — out of scope for the memory2 migration. 
+""" + +from __future__ import annotations + +from collections.abc import Callable, Iterator +from pathlib import Path +import time +from typing import Any, Generic, TypeVar, cast + +import reactivex as rx +from reactivex.abc import DisposableBase, ObserverBase, SchedulerBase +from reactivex.disposable import CompositeDisposable, Disposable +from reactivex.observable import Observable +from reactivex.scheduler import TimeoutScheduler from dimos.memory.timeseries.legacy import LegacyPickleStore +from dimos.memory2.store.sqlite import SqliteStore +from dimos.utils.data import get_data + +T = TypeVar("T") + + +# Shared SqliteStore per .db path — ReplayConnection opens three adapters +# (lidar, odom, color_image) for the same dataset, so sharing a connection +# avoids redundant opens. +_stores: dict[str, SqliteStore] = {} + + +def _resolve_db_path(dataset: str) -> Path: + """Map a dataset name to an on-disk .db path (LFS-downloading on miss). + + - ``"go2_bigoffice"`` → ``{data_dir}/go2_bigoffice.db`` + - Absolute/relative paths are used as-is. + """ + p = Path(dataset) + if p.is_absolute() or p.exists(): + return p + return get_data(f"{dataset}.db") + + +def _get_store(dataset: str) -> SqliteStore: + db_path = _resolve_db_path(dataset) + key = str(db_path) + store = _stores.get(key) + if store is None: + store = SqliteStore(path=key) + store.start() + _stores[key] = store + return store + + +def _close_all() -> None: + """Close every cached SqliteStore. For test teardown.""" + for store in _stores.values(): + store.stop() + _stores.clear() + + +def timed_playback( + source: Callable[[], Iterator[tuple[float, T]]], + speed: float = 1.0, + detect_loop: bool = True, +) -> Observable[T]: + """Replay a ``(ts, data)`` iterator as an Observable at real-time speed. + + Anchors on the first timestamp and schedules subsequent emissions with + ``scheduler.schedule_relative`` at ``anchor + (ts - first_ts) / speed``. 
+ When ``detect_loop`` is set, a backwards-going timestamp re-anchors — use + this when the source iterator loops. + + ``source`` is a factory: called fresh on each subscription so the same + Observable can be re-subscribed without iterator collisions. + + Pending timers are tracked on a CompositeDisposable and cancelled on + subscription dispose. + """ + + def subscribe( + observer: ObserverBase[T], + scheduler: SchedulerBase | None = None, + ) -> DisposableBase: + sched = scheduler or TimeoutScheduler() + disp = CompositeDisposable() + is_disposed = False + iterator = source() + + try: + first_ts, first_data = next(iterator) + except StopIteration: + observer.on_completed() + return Disposable() + + start_local_time = time.time() + start_replay_time = first_ts + + observer.on_next(first_data) + + try: + next_message: tuple[float, T] | None = next(iterator) + except StopIteration: + observer.on_completed() + return disp + + prev_ts = first_ts + + def schedule_emission(message: tuple[float, T]) -> None: + nonlocal next_message, start_local_time, start_replay_time, prev_ts + + if is_disposed: + return + + ts, data = message + + if detect_loop and ts < prev_ts: + start_local_time = time.time() + start_replay_time = ts + prev_ts = ts + + try: + next_message = next(iterator) + except StopIteration: + next_message = None + + target_time = start_local_time + (ts - start_replay_time) / speed + delay = max(0.0, target_time - time.time()) + + def emit(_scheduler: SchedulerBase, _state: object) -> DisposableBase | None: + if is_disposed: + return None + observer.on_next(data) + if next_message is not None: + schedule_emission(next_message) + else: + observer.on_completed() + return None + + disp.add(sched.schedule_relative(delay, emit)) + + if next_message is not None: + schedule_emission(next_message) + + def dispose() -> None: + nonlocal is_disposed + is_disposed = True + disp.dispose() + + return Disposable(dispose) + + return rx.create(subscribe) + + +class 
Memory2ReplayAdapter(Generic[T]): + """Memory2-backed replacement for the legacy ``TimedSensorReplay``. + + Accepts names shaped like ``"<dataset>/<stream>"`` (e.g. + ``"go2_bigoffice/lidar"``). ``autocast`` is applied after the codec + decode, matching legacy behavior. + """ + + def __init__(self, name: str | Path, autocast: Callable[[Any], T] | None = None) -> None: + parts = str(name).split("/", 1) + if len(parts) != 2: + raise ValueError( + f"Expected '/' name, got {name!r}. " + "E.g. TimedSensorReplay('go2_bigoffice/lidar')." + ) + self._dataset, self._stream_name = parts + self._autocast = autocast + + @property + def _stream(self) -> Any: + return _get_store(self._dataset).stream(self._stream_name) + + def _decode(self, obs: Any) -> T: + data = obs.data + if self._autocast is not None: + data = self._autocast(data) + return cast("T", data) + + def iterate_ts( + self, + seek: float | None = None, + duration: float | None = None, + from_timestamp: float | None = None, + loop: bool = False, + ) -> Iterator[tuple[float, T]]: + s = self._stream + + first_ts = self.first_timestamp() + if first_ts is None: + return + + start: float | None = None + if from_timestamp is not None: + start = from_timestamp + elif seek is not None: + start = first_ts + seek + + end: float | None = None + if duration is not None: + start_ts = start if start is not None else first_ts + end = start_ts + duration + + # Time-bound stream using memory2 filters. time_range is inclusive on + # both sides; .after is exclusive. Use time_range with +inf for the + # inclusive-start semantics legacy callers rely on. 
+ if start is not None and end is not None: + bound = s.time_range(start, end) + elif start is not None: + bound = s.time_range(start, float("inf")) + elif end is not None: + bound = s.before(end) # no start → include from beginning + else: + bound = s + + while True: + emitted = False + for obs in bound: + emitted = True + yield (obs.ts, self._decode(obs)) + if not loop or not emitted: + break + + def iterate(self) -> Iterator[T]: + for _, data in self.iterate_ts(): + yield data + + def first_timestamp(self) -> float | None: + try: + return float(self._stream.first().ts) + except LookupError: + return None + + def first(self) -> T | None: + try: + return self._decode(self._stream.first()) + except LookupError: + return None + + def find_closest(self, timestamp: float, tolerance: float = 1.0) -> T | None: + try: + obs = self._stream.at(timestamp, tolerance).first() + except LookupError: + return None + return self._decode(obs) + + def find_closest_seek(self, seconds: float) -> T | None: + first_ts = self.first_timestamp() + if first_ts is None: + return None + try: + obs = self._stream.time_range(first_ts + seconds, float("inf")).first() + except LookupError: + return None + return self._decode(obs) + + def count(self) -> int: + return int(self._stream.count()) + + @property + def files(self) -> list[Path]: + """Compat stub — memory2 has no per-frame files.""" + return [] + + def load_one(self, name: int | str | Path) -> tuple[float, T]: + """Compat stub — index-based access by offset.""" + if not isinstance(name, int): + raise TypeError( + f"Memory2ReplayAdapter.load_one only supports integer offsets; got {name!r}" + ) + obs = self._stream.limit(1).offset(int(name)).first() + return (obs.ts, self._decode(obs)) + + def stream( + self, + speed: float = 1.0, + seek: float | None = None, + duration: float | None = None, + from_timestamp: float | None = None, + loop: bool = False, + ) -> Observable[T]: + """Real-time scheduled playback as an RxPY Observable.""" + 
return timed_playback( + lambda: self.iterate_ts( + seek=seek, duration=duration, from_timestamp=from_timestamp, loop=loop + ), + speed=speed, + ) + + +TimedSensorReplay = Memory2ReplayAdapter +# Write-side + non-timed read-side stay on legacy pickle. SensorReplay = LegacyPickleStore SensorStorage = LegacyPickleStore -TimedSensorReplay = LegacyPickleStore TimedSensorStorage = LegacyPickleStore diff --git a/dimos/utils/testing/test_replay.py b/dimos/utils/testing/test_replay.py index 10ace353f7..1a8a911b07 100644 --- a/dimos/utils/testing/test_replay.py +++ b/dimos/utils/testing/test_replay.py @@ -16,16 +16,16 @@ from reactivex import operators as ops +from dimos.memory.timeseries.legacy import LegacyPickleStore from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.robot.unitree.type.lidar import pointcloud2_from_webrtc_lidar from dimos.robot.unitree.type.odometry import Odometry from dimos.utils.data import get_data -from dimos.utils.testing import replay def test_timed_sensor_replay() -> None: get_data("unitree_office_walk") - odom_store = replay.TimedSensorReplay("unitree_office_walk/odom") + odom_store = LegacyPickleStore("unitree_office_walk/odom") itermsgs = [] for msg in odom_store.iterate(): @@ -51,7 +51,7 @@ def test_timed_sensor_replay() -> None: def test_iterate_ts_no_seek() -> None: """Test iterate_ts without seek (start_timestamp=None)""" - odom_store = replay.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + odom_store = LegacyPickleStore("unitree_office_walk/odom", autocast=Odometry.from_msg) # Test without seek ts_msgs = [] @@ -69,7 +69,7 @@ def test_iterate_ts_no_seek() -> None: def test_iterate_ts_with_from_timestamp() -> None: """Test iterate_ts with from_timestamp (absolute timestamp)""" - odom_store = replay.TimedSensorReplay("unitree_office_walk/odom") + odom_store = LegacyPickleStore("unitree_office_walk/odom") # First get all messages to find a good seek point all_msgs = [] @@ -97,7 +97,7 @@ def 
test_iterate_ts_with_from_timestamp() -> None: def test_iterate_ts_with_relative_seek() -> None: """Test iterate_ts with seek (relative seconds after first timestamp)""" - odom_store = replay.TimedSensorReplay("unitree_office_walk/odom") + odom_store = LegacyPickleStore("unitree_office_walk/odom") # Get first few messages to understand timing all_msgs = [] @@ -126,7 +126,7 @@ def test_iterate_ts_with_relative_seek() -> None: def test_stream_with_seek() -> None: """Test stream method with seek parameters""" - odom_store = replay.TimedSensorReplay("unitree_office_walk/odom") + odom_store = LegacyPickleStore("unitree_office_walk/odom") # Test stream with relative seek msgs_with_seek = [] @@ -152,7 +152,7 @@ def test_stream_with_seek() -> None: def test_duration_with_loop() -> None: """Test duration parameter with looping in TimedSensorReplay""" - odom_store = replay.TimedSensorReplay("unitree_office_walk/odom") + odom_store = LegacyPickleStore("unitree_office_walk/odom") # Collect timestamps from a small duration window collected_ts = [] @@ -187,7 +187,7 @@ def test_first_methods() -> None: """Test first() and first_timestamp() methods""" # Test SensorReplay.first() - lidar_replay = replay.SensorReplay("office_lidar", autocast=pointcloud2_from_webrtc_lidar) + lidar_replay = LegacyPickleStore("office_lidar", autocast=pointcloud2_from_webrtc_lidar) print("first file", lidar_replay.files[0]) # Verify the first file ends with 000.pickle using regex @@ -207,7 +207,7 @@ def test_first_methods() -> None: assert abs(first_msg.ts - first_from_iterate.ts) < 1.0 # Within 1 second tolerance # Test TimedSensorReplay.first_timestamp() - odom_store = replay.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + odom_store = LegacyPickleStore("unitree_office_walk/odom", autocast=Odometry.from_msg) first_ts = odom_store.first_timestamp() assert first_ts is not None assert isinstance(first_ts, float) @@ -224,7 +224,7 @@ def test_first_methods() -> None: def 
test_find_closest() -> None: """Test find_closest method in TimedSensorReplay""" - odom_store = replay.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + odom_store = LegacyPickleStore("unitree_office_walk/odom", autocast=Odometry.from_msg) # Get some reference timestamps timestamps = [] diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index f4a7e6f226..f2e3e51d08 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -41,13 +41,13 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig -from dimos.msgs.sensor_msgs.PointCloud2 import register_colormap_annotation from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.pubsub.patterns import Glob, pattern_matches from dimos.protocol.pubsub.spec import SubscribeAllCapable from dimos.utils.logging_config import setup_logger +from dimos.visualization.rerun.init import rerun_init -RERUN_GRPC_PORT = 9876 +RERUN_GRPC_PORT = 9877 RERUN_WEB_PORT = 9090 # TODO OUT visual annotations @@ -182,6 +182,9 @@ class Config(ModuleConfig): # Static items logged once after start. Maps entity_path -> callable(rr) returning Archetype static: dict[str, Callable[[Any], Archetype]] = field(default_factory=dict) + grpc_port: int = RERUN_GRPC_PORT + web_port: int = RERUN_WEB_PORT + # Per-entity max update rate (Hz). Entities not listed are unthrottled. # Use for heavy entities to prevent viewer backpressure. 
max_hz: dict[str, float] = field(default_factory=dict) @@ -215,6 +218,7 @@ class RerunBridgeModule(Module): config: Config + # TODO this doesn't belong here, either hardcode it or put it to rerun bridge config GV_SCALE = 100.0 # graphviz inches to rerun screen units MODULE_RADIUS = 30.0 CHANNEL_RADIUS = 20.0 @@ -301,28 +305,30 @@ def start(self) -> None: } # Initialize and spawn Rerun viewer - rr.init("dimos") + rerun_init("dimos") if self.config.viewer_mode == "native": try: import rerun_bindings rerun_bindings.spawn( - port=RERUN_GRPC_PORT, + port=self.config.grpc_port, executable_name="dimos-viewer", memory_limit=self.config.memory_limit, ) + rr.connect_grpc(f"rerun+http://127.0.0.1:{self.config.grpc_port}/proxy") except ImportError: - pass # dimos-viewer not installed + rr.spawn(connect=True, memory_limit=self.config.memory_limit) except Exception: logger.warning( "dimos-viewer found but failed to spawn, falling back to stock rerun", exc_info=True, ) - rr.spawn(connect=True, memory_limit=self.config.memory_limit) + rr.spawn(connect=True, memory_limit=self.config.memory_limit) elif self.config.viewer_mode == "web": server_uri = rr.serve_grpc() rr.serve_web_viewer(connect_to=server_uri, open_browser=False) + elif self.config.viewer_mode == "connect": rr.connect_grpc(self.config.connect_url) # "none" - just init, no viewer (connect externally) @@ -330,9 +336,6 @@ def start(self) -> None: if self.config.blueprint: rr.send_blueprint(_with_graph_tab(self.config.blueprint())) - # Register colormap for viewer-side color resolution (PointCloud2 class_ids) - register_colormap_annotation("turbo") - # Start pubsubs and subscribe to all messages for pubsub in self.config.pubsubs: logger.info(f"bridge listening on {pubsub.__class__.__name__}") diff --git a/dimos/visualization/rerun/init.py b/dimos/visualization/rerun/init.py new file mode 100644 index 0000000000..4ecc3550ac --- /dev/null +++ b/dimos/visualization/rerun/init.py @@ -0,0 +1,27 @@ +# Copyright 2026 Dimensional 
Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared Rerun initialization. Call ``rerun_init()`` instead of ``rr.init()``.""" + +from __future__ import annotations + +import rerun as rr + +from dimos.msgs.sensor_msgs.PointCloud2 import register_colormap_annotation + + +def rerun_init(app_id: str = "dimos", **kwargs: object) -> None: + """Initialize Rerun with standard defaults.""" + rr.init(app_id, **kwargs) # type: ignore[arg-type] + register_colormap_annotation("turbo") diff --git a/docs/agents/testing.md b/docs/agents/testing.md index 45614c81d2..773316b4ef 100644 --- a/docs/agents/testing.md +++ b/docs/agents/testing.md @@ -71,7 +71,7 @@ def test_query(store: SqliteStore) -> None: assert store.stream("video", Image).count() > 0 def test_search(store: SqliteStore) -> None: - results = store.stream("video", Image).limit(5).fetch() + results = store.stream("video", Image).limit(5).to_list() assert len(results) == 5 ``` diff --git a/docs/capabilities/memory/algo_comparison.md b/docs/capabilities/memory/algo_comparison.md new file mode 100644 index 0000000000..ffb1cb3dbf --- /dev/null +++ b/docs/capabilities/memory/algo_comparison.md @@ -0,0 +1,157 @@ + +Example on how we can use memory to compare two algos on a real data. 
+ +```python +import time + +from dimos.memory2.store.sqlite import SqliteStore +from dimos.memory2.transform import throttle +from dimos.memory2.vis import color +from dimos.memory2.vis.plot.elements import Style +from dimos.memory2.vis.plot.plot import Plot +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.utils.data import get_data + +store = SqliteStore(path=get_data("go2_bigoffice.db")) +images = store.streams.color_image + + +def slow_brightness(img: Image) -> float: + """Naive full-pixel mean, for reference.""" + max_val = 65535.0 if img.format in (ImageFormat.GRAY16, ImageFormat.DEPTH16) else 255.0 + return float(img.data.mean() / max_val) + + +def timed(fn): + """Wrap ``fn(img) -> float`` so it returns execution time in ms instead. + + Touches ``img.data`` first so the lazy blob load isn't counted. + """ + def _fn(obs): + img = obs.data + _ = img.data # warm lazy load, this actually loads from sql + t0 = time.perf_counter() + fn(img) + return (time.perf_counter() - t0) * 1000 + return _fn + + +plot = Plot() + +plot.add( + images.transform(throttle(0.5)).map_data(lambda obs: obs.data.brightness), + label="brightness", + color=color.blue, +) + +plot.add( + images.transform(throttle(0.5)).map_data(lambda obs: slow_brightness(obs.data)), + label="slow_brightness", + style=Style.dashed, + color=color.red, +) + +plot.add( + images.transform(throttle(0.5)).map_data(timed(lambda img: img.brightness)), + label="brightness (ms)", + axis="time", + color=color.blue, + opacity=0.5, +) + +plot.add( + images.transform(throttle(0.5)).map_data(timed(slow_brightness)), + label="slow_brightness (ms)", + axis="time", + color=color.red, + opacity=0.5, +) + + +plot.to_svg("assets/plot_brightness_algo.svg") + +delta_plot = Plot() + +delta_plot.add( + images.transform(throttle(0.5)).map_data( + lambda obs: obs.data.brightness - slow_brightness(obs.data) + ), + label="delta (fast - slow)", + color=color.green, +) + 
+delta_plot.to_svg("assets/plot_brightness_algo_delta.svg") + +``` + +![output](assets/plot_brightness_algo.svg) + +![output](assets/plot_brightness_algo_delta.svg) + +We see that the new algo is strictly better. + +The example above loads the same data and iterates over it once per plot line; it's a bit slow, but readable and easy to write during development. Below is an example that generates the same results more efficiently: + +```python +import time + +from dimos.memory2.store.sqlite import SqliteStore +from dimos.memory2.transform import throttle +from dimos.memory2.vis import color +from dimos.memory2.vis.plot.elements import HLine, Series, Style +from dimos.memory2.vis.plot.plot import Plot +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.utils.data import get_data + +store = SqliteStore(path=get_data("go2_bigoffice.db")) +images = store.streams.color_image + + +def slow_brightness(img: Image) -> float: + """Naive full-pixel mean, for reference.""" + max_val = 65535.0 if img.format in (ImageFormat.GRAY16, ImageFormat.DEPTH16) else 255.0 + return float(img.data.mean() / max_val) + + +def timed(fn, img): + """Call ``fn(img)`` once, return (value, ms).""" + t0 = time.perf_counter() + v = fn(img) + return v, (time.perf_counter() - t0) * 1000 + + +def compute(obs): + """One pass per image: both values, both times, delta.""" + img = obs.data + _ = img.data # warm lazy load so only compute is timed + fast_v, fast_ms = timed(lambda i: i.brightness, img) + slow_v, slow_ms = timed(slow_brightness, img) + return { + "fast": fast_v, + "slow": slow_v, + "fast_ms": fast_ms, + "slow_ms": slow_ms, + "delta": fast_v - slow_v, + } + + +# Iterate the source once; all five series below read from the cache. 
+metrics = images.transform(throttle(0.5)).map_data(compute).materialize() + +plot = Plot() +plot.add(metrics.map_data(lambda o: o.data["fast"]), + label="brightness", color=color.blue) +plot.add(metrics.map_data(lambda o: o.data["slow"]), + label="slow_brightness", color=color.red, style=Style.dashed) +plot.add(metrics.map_data(lambda o: o.data["fast_ms"]), + label="brightness (ms)", axis="time", color=color.blue, opacity=0.5) +plot.add(metrics.map_data(lambda o: o.data["slow_ms"]), + label="slow_brightness (ms)", axis="time", color=color.red, opacity=0.5) +plot.to_svg("assets/plot_brightness_algo.svg") + +delta_plot = Plot() +delta_plot.add(metrics.map_data(lambda o: o.data["delta"]), + label="delta (fast - slow)", color=color.green) +delta_plot.add(HLine(y=0, style=Style.dashed, color=color.red)) +delta_plot.to_svg("assets/plot_brightness_algo_delta.svg") +``` diff --git a/docs/capabilities/memory/assets/.gitattributes b/docs/capabilities/memory/assets/.gitattributes new file mode 100644 index 0000000000..769e3570f2 --- /dev/null +++ b/docs/capabilities/memory/assets/.gitattributes @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e3041f7efbfcb80f33ede705be7c37bc41b68329bd433ccac0912675d09449e +size 257 diff --git a/docs/capabilities/memory/assets/all_images.png b/docs/capabilities/memory/assets/all_images.png new file mode 100644 index 0000000000..721fea793b --- /dev/null +++ b/docs/capabilities/memory/assets/all_images.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a873f01d1d9cf853cbadff9f63c49310817c126ca2249c165f857084227058b7 +size 7355457 diff --git a/docs/capabilities/memory/assets/brightness.svg b/docs/capabilities/memory/assets/brightness.svg new file mode 100644 index 0000000000..93cde2f25e --- /dev/null +++ b/docs/capabilities/memory/assets/brightness.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d7d10930a71a727d0783b06114c690c23735bb24c4c2ab8dc902f3aabd22125 +size 
333222 diff --git a/docs/capabilities/memory/assets/color_image.svg b/docs/capabilities/memory/assets/color_image.svg new file mode 100644 index 0000000000..09aa3da03c --- /dev/null +++ b/docs/capabilities/memory/assets/color_image.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:056e134b19ddd55d69d97f50f5f0ed9ffc324dcbb874ceee49405dee3a155b4f +size 828447 diff --git a/docs/capabilities/memory/assets/embedding.svg b/docs/capabilities/memory/assets/embedding.svg new file mode 100644 index 0000000000..d0b23cc68f --- /dev/null +++ b/docs/capabilities/memory/assets/embedding.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6724056792c73ea06d6b17122b24a963ec801a4f0afdb65701e1e63f93691271 +size 210886 diff --git a/docs/capabilities/memory/assets/embedding_focused.svg b/docs/capabilities/memory/assets/embedding_focused.svg new file mode 100644 index 0000000000..897357f05e --- /dev/null +++ b/docs/capabilities/memory/assets/embedding_focused.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:71d748748f18881a767cc530af1d6e9acdec91072ba54ac6b1ad154e580a3ebb +size 34640 diff --git a/docs/capabilities/memory/assets/grid.png b/docs/capabilities/memory/assets/grid.png new file mode 100644 index 0000000000..27d7dd939f --- /dev/null +++ b/docs/capabilities/memory/assets/grid.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fa3a8d7ab0b11f50b08377a005ec9797f35283ac8fb53f5d0dde88cda23aeb62 +size 1858645 diff --git a/docs/capabilities/memory/assets/peak_detections.svg b/docs/capabilities/memory/assets/peak_detections.svg new file mode 100644 index 0000000000..255d4fc70a --- /dev/null +++ b/docs/capabilities/memory/assets/peak_detections.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:522ec0640c685f5637743ee4a4e9b6691be7838d6a1d40d9295e4ffde8c961f1 +size 35245 diff --git a/docs/capabilities/memory/assets/peak_space.svg 
b/docs/capabilities/memory/assets/peak_space.svg new file mode 100644 index 0000000000..7a6e729166 --- /dev/null +++ b/docs/capabilities/memory/assets/peak_space.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b7e8ea30b99cf2a86da4b9c25bba67de9f7d2e23e854ea347af4625b63f2231 +size 33190 diff --git a/docs/capabilities/memory/assets/plants.png b/docs/capabilities/memory/assets/plants.png new file mode 100644 index 0000000000..4edff3ffd8 --- /dev/null +++ b/docs/capabilities/memory/assets/plants.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4732b965e06e47aa9b1f8e8cfdca4aa859832442719f46c9a7d77b5e202dcfc4 +size 173223 diff --git a/docs/capabilities/memory/assets/plants_auto.png b/docs/capabilities/memory/assets/plants_auto.png new file mode 100644 index 0000000000..8647a3458f --- /dev/null +++ b/docs/capabilities/memory/assets/plants_auto.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:744d9661017781864661621af6aab25625408b602b84a5eec228ada445587d59 +size 1011358 diff --git a/docs/capabilities/memory/assets/plants_meaningful.png b/docs/capabilities/memory/assets/plants_meaningful.png new file mode 100644 index 0000000000..77b864981d --- /dev/null +++ b/docs/capabilities/memory/assets/plants_meaningful.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:02a9b707e75f17b9557f004f77a578d2a3010f6402d609dcfd5985da9082e490 +size 132383 diff --git a/docs/capabilities/memory/assets/plants_peak_detections.png b/docs/capabilities/memory/assets/plants_peak_detections.png new file mode 100644 index 0000000000..06193bccfd --- /dev/null +++ b/docs/capabilities/memory/assets/plants_peak_detections.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cc9fcf23e5b72e4d0005e54487d52860460662c82a08a9d9582286349016d98c +size 1564191 diff --git a/docs/capabilities/memory/assets/plot_brightness_algo.svg b/docs/capabilities/memory/assets/plot_brightness_algo.svg 
new file mode 100644 index 0000000000..11942261a1 --- /dev/null +++ b/docs/capabilities/memory/assets/plot_brightness_algo.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c738e3425aded30ed0c0af864837d6a6a526c4abf9655e426b9cc96eeed10520 +size 80187 diff --git a/docs/capabilities/memory/assets/plot_brightness_algo_delta.svg b/docs/capabilities/memory/assets/plot_brightness_algo_delta.svg new file mode 100644 index 0000000000..f6719d0beb --- /dev/null +++ b/docs/capabilities/memory/assets/plot_brightness_algo_delta.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5032a62e9891e849a568389b1be253c895ee399fec3f6010e0339bfd6faee5e5 +size 35161 diff --git a/docs/capabilities/memory/assets/plot_colors.svg b/docs/capabilities/memory/assets/plot_colors.svg new file mode 100644 index 0000000000..fc194cc8ca --- /dev/null +++ b/docs/capabilities/memory/assets/plot_colors.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:802a0e541d26b370cb5eacec25461d1c708330ceff973f54ed4330cba6441db9 +size 76264 diff --git a/docs/capabilities/memory/assets/plot_named.svg b/docs/capabilities/memory/assets/plot_named.svg new file mode 100644 index 0000000000..00986f6d33 --- /dev/null +++ b/docs/capabilities/memory/assets/plot_named.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1d48cf86d1faf94abed5aa5271b9a42feb099960f5a71ce5bcd0f5a5e4b68471 +size 31180 diff --git a/docs/capabilities/memory/assets/plot_plantness.svg b/docs/capabilities/memory/assets/plot_plantness.svg new file mode 100644 index 0000000000..c3d3b529a9 --- /dev/null +++ b/docs/capabilities/memory/assets/plot_plantness.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:058b8f99cb88b11fbd77fbbaa36c104a78ec6868b983f2499cda9f13e83b13f8 +size 28561 diff --git a/docs/capabilities/memory/assets/plot_plantness_autopeaks.svg b/docs/capabilities/memory/assets/plot_plantness_autopeaks.svg new file mode 100644 
index 0000000000..a33a048c70 --- /dev/null +++ b/docs/capabilities/memory/assets/plot_plantness_autopeaks.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2502a08c864b23ba933adc4cd06a721f4d0aaa0631e6b73d346c7fa5d7ad5965 +size 32085 diff --git a/docs/capabilities/memory/assets/plot_plantness_autopeaks2.svg b/docs/capabilities/memory/assets/plot_plantness_autopeaks2.svg new file mode 100644 index 0000000000..811f432047 --- /dev/null +++ b/docs/capabilities/memory/assets/plot_plantness_autopeaks2.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:660595c2d7bf72df91c62a22f4f7a9ff42274debbf21a2d62047d389ede5b4e5 +size 35079 diff --git a/docs/capabilities/memory/assets/plot_plantness_autopeaks_map.svg b/docs/capabilities/memory/assets/plot_plantness_autopeaks_map.svg new file mode 100644 index 0000000000..a0a5552381 --- /dev/null +++ b/docs/capabilities/memory/assets/plot_plantness_autopeaks_map.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8b6e17ad03a5202097b1d1eacbc24ecaa2ee5d1ff39f1b5404d8cb6d09ef584 +size 66566 diff --git a/docs/capabilities/memory/assets/plot_plantness_brightness.svg b/docs/capabilities/memory/assets/plot_plantness_brightness.svg new file mode 100644 index 0000000000..95525661f8 --- /dev/null +++ b/docs/capabilities/memory/assets/plot_plantness_brightness.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:19bf4a9034729e87056d17d3c8afb34e3e3ee57a27829b4157f3547a7e56995e +size 48452 diff --git a/docs/capabilities/memory/assets/plot_plantness_gap_fill.svg b/docs/capabilities/memory/assets/plot_plantness_gap_fill.svg new file mode 100644 index 0000000000..119647b9db --- /dev/null +++ b/docs/capabilities/memory/assets/plot_plantness_gap_fill.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70898fb90c688b522cbb4cfc4288e024656ecdf4392b1876acf3afb0da92b968 +size 29091 diff --git 
a/docs/capabilities/memory/assets/plot_plantness_marked.svg b/docs/capabilities/memory/assets/plot_plantness_marked.svg new file mode 100644 index 0000000000..c5212ec2a4 --- /dev/null +++ b/docs/capabilities/memory/assets/plot_plantness_marked.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40a930f7b71d66eb272a399833e855b555de55ba9dd6fe09af79397aa11a860b +size 30531 diff --git a/docs/capabilities/memory/assets/plot_plantness_significant.svg b/docs/capabilities/memory/assets/plot_plantness_significant.svg new file mode 100644 index 0000000000..d298fbbd92 --- /dev/null +++ b/docs/capabilities/memory/assets/plot_plantness_significant.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5039f3ae678ef9c5986d1e93d54b12cc1d268278b2bd8037e02c11ea7f7848ce +size 29491 diff --git a/docs/capabilities/memory/assets/plot_robot_data.svg b/docs/capabilities/memory/assets/plot_robot_data.svg new file mode 100644 index 0000000000..ff8b3d33e1 --- /dev/null +++ b/docs/capabilities/memory/assets/plot_robot_data.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d95d619243b3636c4c8a92aaa39d648b68b967f1e2fc5d487db385cf6f3151c0 +size 109884 diff --git a/docs/capabilities/memory/assets/speed.svg b/docs/capabilities/memory/assets/speed.svg new file mode 100644 index 0000000000..da6384e339 --- /dev/null +++ b/docs/capabilities/memory/assets/speed.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f97972f6f8be8393859ee5b6234094c28d6fb6d0720926eeb7826d0769d1d32 +size 828282 diff --git a/docs/capabilities/memory/demo_rerun.py b/docs/capabilities/memory/demo_rerun.py new file mode 100644 index 0000000000..3b0bcf5536 --- /dev/null +++ b/docs/capabilities/memory/demo_rerun.py @@ -0,0 +1,88 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pickle +from typing import TypeVar + +from dimos.mapping.occupancy.inflation import simple_inflate +from dimos.mapping.pointclouds.occupancy import general_occupancy +from dimos.memory2.store.sqlite import SqliteStore +from dimos.memory2.transform import normalize, smooth +from dimos.memory2.vis.color import Color +from dimos.memory2.vis.space.elements import Point +from dimos.memory2.vis.space.space import Space +from dimos.memory2.vis.utils import mosaic +from dimos.models.embedding.clip import CLIPModel +from dimos.utils.data import get_data + +T = TypeVar("T") + +store = SqliteStore(path=get_data("go2_bigoffice.db")) +global_map = pickle.loads(get_data("unitree_go2_bigoffice_map.pickle").read_bytes()) +costmap = simple_inflate(general_occupancy(global_map), 0.05) + +clip = CLIPModel() +embedded = store.streams.color_image_embedded + +drawing = Space() +# drawing.add(costmap) +# drawing.add(global_map) + +search_text = "robot" +search_vector = clip.embed_text(search_text) + +# store.streams.color_image.transform(speed()).transform(smooth(30)).transform(normalize()).tap( +# lambda obs: drawing.add(Point(obs.pose_stamped, color=Color.from_cmap("turbo", obs.data))) +# ).drain() + +store.streams.color_image.map(lambda obs: obs.derive(data=obs.data.brightness)).transform( + smooth(30) +).transform(normalize()).tap( + lambda obs: drawing.add(Point(obs.pose_stamped, color=Color.from_cmap("turbo", obs.data))) +).drain() + + +# # fmt: off +# embedded.search(search_vector, k=10) \ +# .tap(drawing.add) \ +# .tap(lambda obs: 
drawing.add(store.streams.lidar.at(obs.ts).first().data)).drain() +# # fmt: on + +from dimos.models.vl.moondream import MoondreamVlModel + +moondream = MoondreamVlModel() +moondream.start() + +from dimos.models.vl.florence import Florence2Model + +florence = Florence2Model() +florence.start() + +search_results = ( + embedded.search(search_vector, k=18) + .tap(lambda obs: drawing.add(obs.derive(data=florence.caption(obs.data)))) + .map(lambda obs: obs.derive(data=moondream.query_detections(obs.data, search_text))) + .materialize() +) + +drawing.add(mosaic(search_results)) + +# fmt: off +search_results \ + .tap(drawing.add) \ + .tap(lambda obs: drawing.add(store.streams.lidar.at(obs.ts).first().data)) \ + .drain() +# fmt: on + +drawing.to_rerun() diff --git a/docs/capabilities/memory/index.md b/docs/capabilities/memory/index.md new file mode 100644 index 0000000000..290cdfd37d --- /dev/null +++ b/docs/capabilities/memory/index.md @@ -0,0 +1,210 @@ +
Python + +```python fold session=mem output=none +import pickle +from dimos.mapping.pointclouds.occupancy import general_occupancy, simple_occupancy, height_cost_occupancy +from dimos.mapping.occupancy.inflation import simple_inflate +from dimos.memory2.store.sqlite import SqliteStore +from dimos.memory2.vis.color import Color +from dimos.memory2.transform import downsample, throttle, speed, smooth +from dimos.memory2.vis.space.space import Space +from dimos.utils.data import get_data +from dimos.memory2.vis.space.elements import Point +``` + +
+ +we init our recording, investigate available streams + +```python session=mem +store = SqliteStore(path=get_data("go2_bigoffice.db")) + +for name, stream in store.streams.items(): + print(stream.summary()) +``` + + +``` +Stream("color_image"): 4164 items, 2025-12-26 11:09:08 — 2025-12-26 11:14:00 (292.5s) +Stream("color_image_embedded"): 267 items, 2025-12-26 11:09:12 — 2025-12-26 11:14:00 (288.4s) +Stream("lidar"): 2251 items, 2025-12-26 11:09:08 — 2025-12-26 11:14:00 (292.3s) +Stream("odom"): 5465 items, 2025-12-26 11:09:08 — 2025-12-26 11:14:00 (292.5s) +``` + +Any stream is drawable + +```python session=mem output=none +global_map = pickle.loads(get_data("unitree_go2_bigoffice_map.pickle").read_bytes()) + +drawing = Space() + +# this is not necessary but we use a global map as a nice base for a drawing +drawing.add(global_map) +drawing.add(store.streams.color_image) +drawing.to_svg("assets/color_image.svg") +``` + + + +our drawing system applies turbo color scheme to timestamps by default + +![output](assets/color_image.svg) + +we can create new streams by querying existing streams, and we can save, further transform or draw those + +```python session=mem output=none + +drawing = Space() +drawing.add(global_map) + +drawing.add( + store.streams.color_image \ + # calculate speed in m/s by checking distance between poses and timestamps of observations + .transform(speed()) \ + # rolling window average + .transform(smooth(50))) + +drawing.to_svg("assets/speed.svg") +``` + +![output](assets/speed.svg) + +we can do all kinds of things with this, for example map out room lighting + +```python session=mem output=none +drawing = Space() +drawing.add(global_map) + +drawing.add( + store.streams.color_image \ + # here we will take 4fps because brightness calculation loads the actual image + # observation.data triggers another db query to fetch the data + # otherwise observations only hold positions and timestamps + .transform(throttle(0.25)) \ + # we calculate 
brightness + .map(lambda obs: obs.derive(data=obs.data.brightness))) + +drawing.to_svg("assets/brightness.svg") +``` + +![output](assets/brightness.svg) + +Knowing the above, we can create embeddings for the full stream: + +```python session=mem skip +from dimos.models.embedding.clip import CLIPModel +from dimos.msgs.sensor_msgs.Image import Image +from dimos.memory2.transform import QualityWindow +from dimos.memory2.embed import EmbedImages + +embedded = store.stream("color_image_embedded", Image) +clip = CLIPModel() + +# Downsample to 2Hz, filter dark images, then embed +pipeline = ( + store.streams.color_image.filter(lambda obs: obs.data.brightness > 0.1) + .transform(QualityWindow(lambda img: img.sharpness, window=0.5)) + .transform(EmbedImages(clip)) + .save(embedded) +) + +print(pipeline) + +``` + +this pipeline is lazy, so it is ready to execute but has not run yet; we can execute it by iterating, or calling .drain() + +```python skip +for count, obs in enumerate(pipeline): + print(f" [{count}] ts={obs.ts:.2f} pose={obs.pose}") +``` + +let's query it! + +```python session=mem output=none +from dimos.models.embedding.clip import CLIPModel + +drawing = Space() +drawing.add(global_map) + +clip = CLIPModel() +search_vector = clip.embed_text("shop") +drawing.add(store.streams.color_image_embedded.search(search_vector)) + +drawing.to_svg("assets/embedding.svg") +``` + +![output](assets/embedding.svg) + +We don't really have to deal with the whole global map actually, let's get the top 30 embeddings, and render only lidar around those. 
+ +```python session=mem output=none +from dimos.models.embedding.clip import CLIPModel +from dimos.mapping.voxels import VoxelMapTransformer +drawing = Space() + +# this is defined here, but not executed +matches = store.streams.color_image_embedded.search(search_vector, k=30) + +print(matches) # Stream("color_image_embedded") | vector_search(k=30) + +# here we execute it once, and feed it into a global mapper, then draw the map +drawing.add( + matches.map(lambda obs: store.streams.lidar.at(obs.ts).last()) \ + .transform(VoxelMapTransformer()) \ + .last().data) + +# then we add matches to the map +drawing.add(matches) + +drawing.to_svg("assets/embedding_focused.svg") +``` + + +``` +Stream("color_image_embedded") | vector_search(k=30) +08:19:54.129 [inf][dimos/mapping/voxels.py ] VoxelGrid using device: CUDA:0 +``` + +![output](assets/embedding_focused.svg) + +
Python + +```python fold session=mem +import matplotlib +import matplotlib.pyplot as plt +import math + +def plot_mosaic(frames, path, cols=5): + matplotlib.use("Agg") + rows = math.ceil(len(frames) / cols) + aspect = frames[0].width / frames[0].height + fig_w, fig_h = 12, 12 * rows / (cols * aspect) + + fig, axes = plt.subplots(rows, cols, figsize=(fig_w, fig_h)) + fig.patch.set_facecolor("black") + for i, ax in enumerate(axes.flat): + if i < len(frames): + ax.imshow(frames[i].data) + for spine in ax.spines.values(): + spine.set_color("black") + spine.set_linewidth(0) + ax.set_xticks([]) + ax.set_yticks([]) + else: + ax.axis("off") + plt.subplots_adjust(wspace=0.02, hspace=0.02, left=0, right=1, top=1, bottom=0) + plt.savefig(path, facecolor="black", dpi=100, bbox_inches="tight", pad_inches=0) + plt.close() + +``` + +
+ +let's view those images + +```python session=mem +plot_mosaic(matches.map(lambda obs: obs.data).to_list(), "assets/grid.png") +``` + +![output](assets/grid.png) diff --git a/docs/capabilities/memory/plot.md b/docs/capabilities/memory/plot.md new file mode 100644 index 0000000000..3a3716a8fa --- /dev/null +++ b/docs/capabilities/memory/plot.md @@ -0,0 +1,418 @@ + +## color cycle + +You add streams, system auto assigns colors + +```python session=plot output=none +import math +import random + +from dimos.memory2.vis.plot.elements import Series +from dimos.memory2.vis.plot.plot import Plot + +rng = random.Random(42) +xs = [i * 0.1 for i in range(120)] + +color_check = Plot() +for i in range(14): + phase = rng.uniform(0, 2 * math.pi) + freq = rng.uniform(0.5, 1.8) + amp = rng.uniform(0.6, 1.4) + offset = i * 0.5 # vertical separation so curves don't overlap + ys = [amp * math.sin(freq * x + phase) + offset for x in xs] + + color_check.add(Series(ts=xs, values=ys, label=f"curve {i + 1}")) + +color_check.to_svg("assets/plot_colors.svg") +``` + +![output](assets/plot_colors.svg) + +named colors can also be used explicitly. when you pin a series to one of +the named colors, the auto-cycle excludes it for the remaining series, so +you never end up with two lines that share a color by accident. 
+ +```python session=plot output=none +from dimos.memory2.vis import color +from dimos.memory2.vis.plot.elements import Series, HLine, Style + +p = Plot() +# auto → blue +p.add(Series(ts=xs, values=[math.sin(x) for x in xs])) +# explicit red, dotted +p.add(Series(ts=xs, values=[math.cos(x) for x in xs], color=color.red, style=Style.dotted)) +# auto → yellow (red is excluded) +p.add(Series(ts=xs, values=[math.sin(2 * x) for x in xs])) +# explicit color +p.add(HLine(y=0, style=Style.dashed, opacity=0.5, color="#ff0000")) +p.to_svg("assets/plot_named.svg") +``` + +![output](assets/plot_named.svg) + +## speed plot + +you can assign different axes to different time series, label them etc + +```python session=robotdata output=none +from dimos.memory2.store.sqlite import SqliteStore +from dimos.memory2.transform import smooth, speed, throttle +from dimos.memory2.vis import color +from dimos.memory2.vis.plot.elements import Series +from dimos.memory2.vis.plot.plot import Plot +from dimos.utils.data import get_data + +store = SqliteStore(path=get_data("go2_bigoffice.db")) +images = store.streams.color_image + +plot = Plot() +plot.add( + images.transform(speed()).transform(smooth(40)), + label="speed (m/s)", + opacity=0.75 +) + +plot.add( + images.transform(throttle(0.5)).map_data(lambda obs: obs.data.brightness).transform(smooth(10)), + label="brightness", + color=color.blue, +) + +plot.add( + images.transform(throttle(0.5)).scan_data(images.first().ts, lambda state, obs: [state, obs.ts - state]), + label="time", + axis="time", + opacity=0.5 +) + +plot.to_svg("assets/plot_robot_data.svg") +``` + +![output](assets/plot_robot_data.svg) + +## Semantic search + +Let's find some plants! 
+ +```python session=robotdata +from dimos.memory2.vis.plot.elements import Series, HLine, Style +from dimos.memory2.vis import color +from dimos.memory2.transform import normalize, smooth_time + +from dimos.models.embedding.clip import CLIPModel +clip = CLIPModel() +search_vector = clip.embed_text("plant") + +# we will cache this into memory since it takes a second, +# and use it to play with graphing +plantness_query = ( + store.streams.color_image_embedded + .search(search_vector) + # search() returns observations sorted by similarity, we re-sort by time + .order_by("ts") +) + +# we've built our query +print(plantness_query) + +# we evaluate it into a in-memory stream, +# since we want to further process/plot multiple times +plantness_query_materialized = plantness_query.materialize() + +print(plantness_query_materialized) +print(plantness_query_materialized.summary()) + +# let's create a numerical stream +plantness_similarity = plantness_query_materialized.map_data(lambda obs: obs.similarity).materialize() + +plot = Plot() + +plot.add(plantness_similarity, + label="plant-ness", + color=color.green, +) + +plot.to_svg("assets/plot_plantness.svg") +``` + + +``` +Stream("color_image_embedded") | vector_search() | order_by(ts) +Stream("materialize") +Stream("materialize"): 267 items, 2025-12-26 11:09:12 — 2025-12-26 11:14:00 (288.4s) +``` + +![output](assets/plot_plantness.svg) + +We can be pretty sure the robot saw some plants by peaks at beginning and end of data, but this graph doesn't look great, why? + +Embeddings are calculated according to some minimum picture brightness. Completely dark images are both useless and also semantically close to everything. 
+ +Let's investigate how our embedding stream relates to image brightness: + +```python session=robotdata + +plot = Plot() + +plot.add(plantness_similarity, + label="plant-ness", + color=color.green, +) + +plot.add( + images.transform(throttle(0.5)).map_data(lambda obs: obs.data.brightness), + label="brightness", + axis="brightness" +) + +plot.add(HLine(y=0.15, style=Style.dashed, color=color.red)) + +plot.to_svg("assets/plot_plantness_brightness.svg") +``` + +![output](assets/plot_plantness_brightness.svg) +We see that stuff isn't embedded below some minimum brightness. + +Let's now fill the gaps in our semantic graph a bit, looks super ugly above, we will tell plotter to consider unmapped values as zero and connect values that are within 7.5 seconds, smooth with 5 second time window, and normalize the data + +```python session=robotdata + +plot = Plot() + +plot.add( + plantness_similarity \ + .transform(smooth_time(5.0)) \ + .transform(normalize()), \ + label="plant-ness", + color=color.green, + gap_fill=0.0, + connect=7.5 +) + +plot.to_svg("assets/plot_plantness_gap_fill.svg") + +``` + +![output](assets/plot_plantness_gap_fill.svg) + +Looks better, these are some very obvious peaks, I'm curious let's see what was captured then. 
+ +Let's auto-detect the peaks, extract images from those moments, and run a 2D detector + +```python session=robotdata +from dimos.mapping.voxels import VoxelMapTransformer +from dimos.memory2.vis.space.space import Space +from dimos.memory2.transform import peaks +from dimos.memory2.vis.color import ColorRange +from dimos.memory2.vis.plot.elements import VLine +from dimos.memory2.vis.utils import mosaic +from dimos.memory2.stream import Stream +from itertools import chain + +semantic_peaks = plantness_query_materialized.transform(peaks(key=lambda obs: obs.similarity, distance=1.0)) + +# load all lidar frames captured in the radius around the semantic peaks +# feed them into a global mapper to get a single pointcloud around our areas of interest +global_map = semantic_peaks \ + .map(lambda obs: store.streams.lidar.near(obs.pose_stamped, radius=0.5).first()) \ + .transform(VoxelMapTransformer()) \ + .last().data + +drawing = Space() +drawing.add(global_map) +drawing.add(semantic_peaks) +drawing.to_svg("assets/plot_plantness_autopeaks_map.svg") + +peakColor = ColorRange("turbo") +for i, p in enumerate(semantic_peaks): + print(f"t={p.ts - plantness_similarity.first().ts:6.1f}s score={p.similarity:.3f} prominence={p.tags['peak_prominence']:.3f}") + plot.add(VLine(p.ts, color=peakColor(i))) + +plot.to_svg("assets/plot_plantness_autopeaks.svg") + +from dimos.models.vl.moondream import MoondreamVlModel +moondream = MoondreamVlModel() +moondream.start() + +# peaks is still a stream of image observations (with prominence and semantic similarity metadata) +# so we can just draw it directly via mosaic that takes image streams +m = mosaic(semantic_peaks.map_data(lambda obs: moondream.query_detections(obs.data, "plant"))) + +m.data.save("assets/plants_auto.png") +``` + + +``` +14:59:33.042 [inf][dimos/mapping/voxels.py ] VoxelGrid using device: CUDA:0 +t= 14.1s score=0.224 prominence=0.031 +t= 26.3s score=0.225 prominence=0.033 +t= 32.7s score=0.224 prominence=0.022 +t= 37.0s 
score=0.259 prominence=0.067 +t= 60.6s score=0.227 prominence=0.031 +t= 61.5s score=0.218 prominence=0.026 +t= 76.3s score=0.221 prominence=0.031 +t= 84.0s score=0.223 prominence=0.027 +t= 89.1s score=0.219 prominence=0.020 +t= 162.9s score=0.224 prominence=0.041 +t= 168.0s score=0.219 prominence=0.031 +t= 172.4s score=0.218 prominence=0.020 +t= 240.4s score=0.243 prominence=0.047 +t= 245.6s score=0.224 prominence=0.028 +t= 279.6s score=0.230 prominence=0.030 +``` + + +![output](assets/plot_plantness_autopeaks.svg) + +![output](assets/plants_auto.png) + +![output](assets/plot_plantness_autopeaks_map.svg) + +## Which peaks are significant? + +We got 15 peaks back, we ran a detector on all of them so we can start projecting into 3D but let's say we want some sort of pre-filter of just globally significant peaks. we can see most peaks prominence sits around 0.02–0.03 and only a couple (0.067 at t=37s, 0.047 at t=240s) really stand out. We might want to auto detect those. + +`significant()` replaces that guesswork by thresholding on the distribution of prominences itself. Default outlier detection uses MAD (median absolute deviation) + +Once we put the surviving peaks on the timeline we get two very obvious plants. 
+ +```python session=robotdata +from dimos.memory2.transform import significant + +plot = Plot() +plot.add( + plantness_similarity.transform(smooth_time(5.0)).transform(normalize()), + label="plant-ness", color=color.green, gap_fill=0.0, connect=7.5, +) + +meaningful_peaks = semantic_peaks.transform(significant(method="mad")) + +for peak in meaningful_peaks: + plot.add(VLine(peak.ts, color=color.red)) + +m = mosaic(meaningful_peaks) +m.data.save("assets/plants_meaningful.png") + +plot.to_svg("assets/plot_plantness_significant.svg") +``` + + +![output](assets/plot_plantness_significant.svg) + +![output](assets/plants_meaningful.png) + +Rule of thumb: keep a small absolute floor on `peaks(prominence=...)` to +reject shape-noise, then let `significant()` pick the statistical cutoff. + +## Semantic peak analysis + +Let's focus on those two peaks. load all images in the vicinity of a detection, + +We'll also pull all lidar frames in their vicinity and reconstruct global maps for those areas. + +```python session=robotdata + +from dimos.memory2.vis.space.elements import Point +from dimos.memory2.transform import QualityWindow + +drawing = Space() + +# TODO actual near/at filters need to accept observation streams in order to easily +# reconstruct all frames in vicinity of another stream +# for now for simplicity here we are focusing only on one semantic hotspot. 
+meaningful_peak = meaningful_peaks.first() + +# load all images captured in the radius around the semantic peak +near_images = images.near(meaningful_peak.pose_stamped, radius=2.5) \ + .filter(lambda obs: obs.data.brightness > 0.1) \ + .transform(QualityWindow(lambda img: img.sharpness, window=0.5)) + +# load all lidar frames captured in the radius around the semantic peak +# feed them into a global mapper to get a single pointcloud around our area of interest +global_map = store.streams.lidar.near(meaningful_peak.pose_stamped, radius=2.5) \ + .transform(VoxelMapTransformer()) \ + .last().data + +# run our global mapper only on lidar frames around the POI +drawing.add(global_map) +drawing.add(meaningful_peak.pose_stamped, color=color.green) + +# run a detector, filter small weird detections +detections = (near_images + .map_data(lambda obs: moondream.query_detections(obs.data, "plant")) + .map_data(lambda obs: obs.data.filter(lambda det: det.bbox_2d_volume() > 3000)) + .filter(lambda obs: len(obs.data) > 0) + .materialize()) + # materialize this stream since we'll want to re-use it later + +drawing.add(detections) +drawing.to_svg("assets/peak_space.svg") + +m = mosaic(detections) +m.data.save("assets/plants_peak_detections.png") +``` + +![output](assets/peak_space.svg) +![output](assets/plants_peak_detections.png) + +## 3D Projection + +```python session=robotdata output=none +from dimos.perception.detection.type.detection3d.imageDetections3DPC import ( + ImageDetections3DPC, +) +from dimos.robot.unitree.go2.connection import ( + _camera_info_static as go2_camerainfo, + BASE_TO_OPTICAL, +) +from dimos.memory2.vis.space.elements import Box3D +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + +# TODO We need a nicer way to get optical transform for image streams +# depending on the source +def world_to_optical(base_pose): + return 
-(Transform.from_pose("base_link", base_pose) + BASE_TO_OPTICAL) + +drawing = Space() + +drawing.add(global_map) + +drawing.add(detections) + +camera_info = go2_camerainfo() + +detections3d = (detections + .map_data(lambda obs: ImageDetections3DPC.from_2d( + obs.data, + global_map, + camera_info, + world_to_optical(obs.pose_stamped), + )) + .filter(lambda obs: len(obs.data) > 0)) + +# TODO detection3d needs to be a natural thing to render +for obs in detections3d: + for d3d in obs.data: + aabb = d3d.get_bounding_box() + c, e = aabb.get_center(), aabb.get_extent() + drawing.add(Box3D( + center=Pose(float(c[0]), float(c[1]), float(c[2])), + size=Vector3(float(e[0]), float(e[1]), float(e[2])), + color=color.green, label="plant", + )) + +drawing.to_svg("assets/peak_detections.svg") + +``` + +![output](assets/peak_detections.svg) + +# TODO further steps + +- These are 3D bounding boxes with associated pointclouds, render in rerun + +- Some basic statistical outlier filters - we have many overlapping detections here and we can be pretty sure there are plants right of the robot, but unclear about left. + +- Now that we have 3d locations in space, we can load all camera images observing detections in space (not just rely on radius around the embedding peak) see in how many of these images we actually detect an object. 
(another strategy for false positive filtering) diff --git a/docs/usage/cli.md b/docs/usage/cli.md index 7a25ee4ae3..017b441c7e 100644 --- a/docs/usage/cli.md +++ b/docs/usage/cli.md @@ -16,7 +16,7 @@ dimos [GLOBAL OPTIONS] COMMAND [ARGS] | `--robot-ips` | TEXT | `None` | Multiple robot IPs | | `--simulation` / `--no-simulation` | bool | `False` | Enable MuJoCo simulation | | `--replay` / `--no-replay` | bool | `False` | Use recorded replay data | -| `--replay-dir` | TEXT | `go2_sf_office` | Replay dataset directory name | +| `--replay-db` | TEXT | `go2_bigoffice` | Replay memory2 SQLite database name | | `--new-memory` / `--no-new-memory` | bool | `False` | Clear persistent memory on start | | `--viewer` | `rerun\|rerun-web\|rerun-connect\|foxglove\|none` | `rerun` | Visualization backend | | `--n-workers` | INT | `2` | Number of forkserver workers | diff --git a/flake.nix b/flake.nix index c22b1f7791..7517bc9dd2 100644 --- a/flake.nix +++ b/flake.nix @@ -77,6 +77,7 @@ { vals.pkg=pkgs.python312Packages.pip; flags={}; } { vals.pkg=pkgs.python312Packages.setuptools; flags={}; } { vals.pkg=pkgs.python312Packages.virtualenv; flags={}; } + { vals.pkg=pkgs.uv; flags={}; } { vals.pkg=pkgs.pre-commit; flags={}; } ### Runtime deps diff --git a/pyproject.toml b/pyproject.toml index fa35dd79de..4cc8e3d9f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -248,7 +248,7 @@ dev = [ "watchdog>=3.0.0", # docs - "md-babel-py==1.1.1", + "md-babel-py==1.1.3", # LSP "python-lsp-server[all]==1.14.0", @@ -307,7 +307,6 @@ drone = [ ] dds = [ - "dimos[dev]", "cyclonedds>=0.10.5", ] diff --git a/uv.lock b/uv.lock index aebf6f9055..7357f46359 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'darwin'", @@ -1775,45 +1775,7 @@ cuda = [ { name = "xformers", marker = "platform_machine == 'x86_64'" }, ] dds = [ - { name = "coverage" }, { name = 
"cyclonedds" }, - { name = "lxml-stubs" }, - { name = "md-babel-py" }, - { name = "mypy" }, - { name = "pandas-stubs" }, - { name = "pre-commit" }, - { name = "py-spy" }, - { name = "pytest" }, - { name = "pytest-asyncio" }, - { name = "pytest-env" }, - { name = "pytest-mock" }, - { name = "pytest-timeout" }, - { name = "python-lsp-ruff" }, - { name = "python-lsp-server", extra = ["all"] }, - { name = "requests-mock" }, - { name = "ruff" }, - { name = "scipy-stubs", version = "1.15.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "scipy-stubs", version = "1.17.1.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "terminaltexteffects" }, - { name = "types-colorama" }, - { name = "types-defusedxml" }, - { name = "types-gevent" }, - { name = "types-greenlet" }, - { name = "types-jmespath" }, - { name = "types-jsonschema" }, - { name = "types-networkx" }, - { name = "types-protobuf" }, - { name = "types-psutil" }, - { name = "types-psycopg2" }, - { name = "types-pysocks" }, - { name = "types-pytz" }, - { name = "types-pyyaml" }, - { name = "types-requests" }, - { name = "types-simplejson" }, - { name = "types-tabulate" }, - { name = "types-tensorflow" }, - { name = "types-tqdm" }, - { name = "watchdog" }, ] dev = [ { name = "coverage" }, @@ -1994,7 +1956,6 @@ requires-dist = [ { name = "cyclonedds", marker = "extra == 'dds'", specifier = ">=0.10.5" }, { name = "dimos", extras = ["agents", "web", "perception", "visualization"], marker = "extra == 'base'" }, { name = "dimos", extras = ["base"], marker = "extra == 'unitree'" }, - { name = "dimos", extras = ["dev"], marker = "extra == 'dds'" }, { name = "dimos-lcm" }, { name = "dimos-lcm", marker = "extra == 'docker'" }, { name = "dimos-viewer", specifier = ">=0.30.0a2" }, @@ -2027,7 +1988,7 @@ requires-dist = [ { name = "lxml-stubs", marker = "extra == 'dev'", specifier = ">=0.5.1,<1" }, { name = "lz4", 
specifier = ">=4.4.5" }, { name = "matplotlib", marker = "extra == 'manipulation'", specifier = ">=3.7.1" }, - { name = "md-babel-py", marker = "extra == 'dev'", specifier = "==1.1.1" }, + { name = "md-babel-py", marker = "extra == 'dev'", specifier = "==1.1.3" }, { name = "moondream", marker = "extra == 'perception'" }, { name = "mujoco", marker = "extra == 'sim'", specifier = ">=3.3.4" }, { name = "mypy", marker = "extra == 'dev'", specifier = "==1.19.0" }, @@ -4817,11 +4778,11 @@ wheels = [ [[package]] name = "md-babel-py" -version = "1.1.1" +version = "1.1.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ca/b3/f814d429edf2848ba03079a3f6da443e6d45b984a7fc22766cb73939d289/md_babel_py-1.1.1.tar.gz", hash = "sha256:826fea96b7415eeaab7607ed5e8eb6d7723f22b9f1005af1b7da12f68766123d", size = 30547, upload-time = "2026-01-20T06:27:32.496Z" } +sdist = { url = "https://files.pythonhosted.org/packages/93/d5/abfe601ac4b3414eb9065d1ac789d93dec3ce4ad9b3fad6d953a6325d7dc/md_babel_py-1.1.3.tar.gz", hash = "sha256:8a6efbd2a2d1e8a1c5e963451cfeab01df9759964eba17b5e69466839ab2d3cd", size = 30631, upload-time = "2026-04-18T10:45:50.11Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c6/4a/dbe497b41432a98c7d4f043cf112410957553ce27e56bc366714695f53a9/md_babel_py-1.1.1-py3-none-any.whl", hash = "sha256:4df82011f123f13b6f9979226e69b0ce06209d94e4c029b60eeb2f54a709d2d0", size = 25836, upload-time = "2026-01-20T06:27:31.514Z" }, + { url = "https://files.pythonhosted.org/packages/b0/37/27dafbaa7d80ce0a31bb9c308f224f7ad5720012870e09d51e9684372daa/md_babel_py-1.1.3-py3-none-any.whl", hash = "sha256:bc432ad570e435e24f01a6f944d9a467d97ec1c374033337a96e0b555fb180e3", size = 25879, upload-time = "2026-04-18T10:45:48.968Z" }, ] [[package]] @@ -5616,6 +5577,7 @@ resolution-markers = [ ] wheels = [ { url = 
"https://files.pythonhosted.org/packages/82/6c/90d3f532f608a03a13c1d6c16c266ffa3828e8011b1549d3b61db2ad59f5/nvidia_cublas_cu12-12.9.1.4-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:7a950dae01add3b415a5a5cdc4ec818fb5858263e9cca59004bb99fdbbd3a5d6", size = 575006342, upload-time = "2025-06-05T20:04:16.902Z" }, + { url = "https://files.pythonhosted.org/packages/77/3c/aa88abe01f3be3d1f8f787d1d33dc83e76fec05945f9a28fbb41cfb99cd5/nvidia_cublas_cu12-12.9.1.4-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:453611eb21a7c1f2c2156ed9f3a45b691deda0440ec550860290dc901af5b4c2", size = 581242350, upload-time = "2025-06-05T20:04:51.979Z" }, { url = "https://files.pythonhosted.org/packages/45/a1/a17fade6567c57452cfc8f967a40d1035bb9301db52f27808167fbb2be2f/nvidia_cublas_cu12-12.9.1.4-py3-none-win_amd64.whl", hash = "sha256:1e5fee10662e6e52bd71dec533fbbd4971bb70a5f24f3bc3793e5c2e9dc640bf", size = 553153899, upload-time = "2025-06-05T20:13:35.556Z" }, ] @@ -5674,6 +5636,7 @@ resolution-markers = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/e0/0279bd94539fda525e0c8538db29b72a5a8495b0c12173113471d28bce78/nvidia_cuda_runtime_cu12-12.9.79-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:83469a846206f2a733db0c42e223589ab62fd2fabac4432d2f8802de4bded0a4", size = 3515012, upload-time = "2025-06-05T20:00:35.519Z" }, + { url = "https://files.pythonhosted.org/packages/bc/46/a92db19b8309581092a3add7e6fceb4c301a3fd233969856a8cbf042cd3c/nvidia_cuda_runtime_cu12-12.9.79-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:25bba2dfb01d48a9b59ca474a1ac43c6ebf7011f1b0b8cc44f54eb6ac48a96c3", size = 3493179, upload-time = "2025-06-05T20:00:53.735Z" }, { url = "https://files.pythonhosted.org/packages/59/df/e7c3a360be4f7b93cee39271b792669baeb3846c58a4df6dfcf187a7ffab/nvidia_cuda_runtime_cu12-12.9.79-py3-none-win_amd64.whl", hash = "sha256:8e018af8fa02363876860388bd10ccb89eb9ab8fb0aa749aaf58430a9f7c4891", size = 3591604, 
upload-time = "2025-06-05T20:11:17.036Z" }, ] @@ -9930,12 +9893,19 @@ name = "triton" version = "3.6.0" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/44/ba/b1b04f4b291a3205d95ebd24465de0e5bf010a2df27a4e58a9b5f039d8f2/triton-3.6.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c723cfb12f6842a0ae94ac307dba7e7a44741d720a40cf0e270ed4a4e3be781", size = 175972180, upload-time = "2026-01-20T16:15:53.664Z" }, { url = "https://files.pythonhosted.org/packages/8c/f7/f1c9d3424ab199ac53c2da567b859bcddbb9c9e7154805119f8bd95ec36f/triton-3.6.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6550fae429e0667e397e5de64b332d1e5695b73650ee75a6146e2e902770bea", size = 188105201, upload-time = "2026-01-20T16:00:29.272Z" }, + { url = "https://files.pythonhosted.org/packages/0f/2c/96f92f3c60387e14cc45aed49487f3486f89ea27106c1b1376913c62abe4/triton-3.6.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49df5ef37379c0c2b5c0012286f80174fcf0e073e5ade1ca9a86c36814553651", size = 176081190, upload-time = "2026-01-20T16:16:00.523Z" }, { url = "https://files.pythonhosted.org/packages/e0/12/b05ba554d2c623bffa59922b94b0775673de251f468a9609bc9e45de95e9/triton-3.6.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8e323d608e3a9bfcc2d9efcc90ceefb764a82b99dea12a86d643c72539ad5d3", size = 188214640, upload-time = "2026-01-20T16:00:35.869Z" }, + { url = "https://files.pythonhosted.org/packages/17/5d/08201db32823bdf77a0e2b9039540080b2e5c23a20706ddba942924ebcd6/triton-3.6.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:374f52c11a711fd062b4bfbb201fd9ac0a5febd28a96fb41b4a0f51dde3157f4", size = 176128243, upload-time = "2026-01-20T16:16:07.857Z" }, { url = 
"https://files.pythonhosted.org/packages/ab/a8/cdf8b3e4c98132f965f88c2313a4b493266832ad47fb52f23d14d4f86bb5/triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca", size = 188266850, upload-time = "2026-01-20T16:00:43.041Z" }, + { url = "https://files.pythonhosted.org/packages/3c/12/34d71b350e89a204c2c7777a9bba0dcf2f19a5bfdd70b57c4dbc5ffd7154/triton-3.6.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:448e02fe6dc898e9e5aa89cf0ee5c371e99df5aa5e8ad976a80b93334f3494fd", size = 176133521, upload-time = "2026-01-20T16:16:13.321Z" }, { url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, + { url = "https://files.pythonhosted.org/packages/ce/4e/41b0c8033b503fd3cfcd12392cdd256945026a91ff02452bef40ec34bee7/triton-3.6.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1722e172d34e32abc3eb7711d0025bb69d7959ebea84e3b7f7a341cd7ed694d6", size = 176276087, upload-time = "2026-01-20T16:16:18.989Z" }, { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" }, + { url = "https://files.pythonhosted.org/packages/49/55/5ecf0dcaa0f2fbbd4420f7ef227ee3cb172e91e5fede9d0ecaddc43363b4/triton-3.6.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef5523241e7d1abca00f1d240949eebdd7c673b005edbbce0aca95b8191f1d43", size = 176138577, upload-time = 
"2026-01-20T16:16:25.426Z" }, { url = "https://files.pythonhosted.org/packages/df/3d/9e7eee57b37c80cec63322c0231bb6da3cfe535a91d7a4d64896fcb89357/triton-3.6.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a17a5d5985f0ac494ed8a8e54568f092f7057ef60e1b0fa09d3fd1512064e803", size = 188273063, upload-time = "2026-01-20T16:01:07.278Z" }, + { url = "https://files.pythonhosted.org/packages/48/db/56ee649cab5eaff4757541325aca81f52d02d4a7cd3506776cad2451e060/triton-3.6.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b3a97e8ed304dfa9bd23bb41ca04cdf6b2e617d5e782a8653d616037a5d537d", size = 176274804, upload-time = "2026-01-20T16:16:31.528Z" }, { url = "https://files.pythonhosted.org/packages/f6/56/6113c23ff46c00aae423333eb58b3e60bdfe9179d542781955a5e1514cb3/triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7", size = 188397994, upload-time = "2026-01-20T16:01:14.236Z" }, ] From 551e6807874af9fc62b60385f8a9ebfb1800a5ba Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Sat, 25 Apr 2026 09:51:33 +0300 Subject: [PATCH 16/30] fix attributes (#1918) --- .gitattributes | 1 + docs/capabilities/memory/assets/.gitattributes | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.gitattributes b/.gitattributes index 55c93ccff2..c4ccd1f825 100644 --- a/.gitattributes +++ b/.gitattributes @@ -17,3 +17,4 @@ *.gif filter=lfs diff=lfs merge=lfs -text binary *.foxe filter=lfs diff=lfs merge=lfs -text binary docs/capabilities/memory/assets/** filter=lfs diff=lfs merge=lfs -text +docs/capabilities/memory/assets/.gitattributes -filter -diff -merge text diff --git a/docs/capabilities/memory/assets/.gitattributes b/docs/capabilities/memory/assets/.gitattributes index 769e3570f2..ddffe375a9 100644 --- a/docs/capabilities/memory/assets/.gitattributes +++ b/docs/capabilities/memory/assets/.gitattributes @@ -1,3 +1,2 
@@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0e3041f7efbfcb80f33ede705be7c37bc41b68329bd433ccac0912675d09449e -size 257 +plot_brightness_algo.svg filter=lfs diff=lfs merge=lfs -text +plot_brightness_algo_delta.svg filter=lfs diff=lfs merge=lfs -text From 349e36d8727b758f5e81a98a5c5bfe54816bc748 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Mon, 27 Apr 2026 13:33:09 -0500 Subject: [PATCH 17/30] Jeff/fix/rconnect2 (#1784) --- .gitignore | 3 + dimos/core/coordination/python_worker.py | 16 +- dimos/core/docker_module.py | 2 +- dimos/core/global_config.py | 11 +- dimos/hardware/sensors/camera/module.py | 5 +- .../lidar/fastlio2/fastlio_blueprints.py | 35 ++- .../sensors/lidar/livox/livox_blueprints.py | 4 +- dimos/manipulation/blueprints.py | 10 +- dimos/manipulation/grasping/demo_grasping.py | 4 +- .../wavefront_frontier_goal_selector.py | 11 + dimos/navigation/replanning_a_star/module.py | 18 +- .../movement_manager/movement_manager.py | 133 +++++++++ .../movement_manager/test_movement_manager.py | 117 ++++++++ .../demo_object_scene_registration.py | 4 +- dimos/robot/all_blueprints.py | 2 + dimos/robot/cli/dimos.py | 48 +++- .../drone/blueprints/basic/drone_basic.py | 17 +- .../blueprints/perceptive/unitree_g1_shm.py | 10 +- .../primitive/uintree_g1_primitive_no_nav.py | 19 +- .../agentic/unitree_go2_security.py | 4 +- .../go2/blueprints/basic/unitree_go2_basic.py | 34 +-- .../go2/blueprints/basic/unitree_go2_fleet.py | 6 +- .../unitree_go2_webrtc_keyboard_teleop.py | 4 + .../go2/blueprints/smart/unitree_go2.py | 6 +- dimos/robot/unitree/keyboard_teleop.py | 10 +- dimos/robot/unitree/mujoco_connection.py | 16 +- dimos/simulation/unity/blueprint.py | 4 +- dimos/teleop/quest/blueprints.py | 4 +- dimos/test_no_sections.py | 2 + dimos/utils/generic.py | 17 ++ dimos/visualization/rerun/bridge.py | 253 +++++++++--------- dimos/visualization/rerun/conftest.py | 45 ++++ dimos/visualization/rerun/constants.py | 31 +++ 
.../visualization/rerun/test_viewer_ws_e2e.py | 201 ++++++++++++++ .../rerun/test_websocket_server.py | 210 +++++++++++++++ dimos/visualization/rerun/websocket_server.py | 244 +++++++++++++++++ dimos/visualization/vis_module.py | 87 ++++++ .../web/websocket_vis/websocket_vis_module.py | 24 +- docs/development/conventions.md | 12 + docs/usage/cli.md | 4 +- docs/usage/visualization.md | 42 +-- pyproject.toml | 2 +- uv.lock | 26 +- 43 files changed, 1465 insertions(+), 292 deletions(-) create mode 100644 dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py create mode 100644 dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py create mode 100644 dimos/visualization/rerun/conftest.py create mode 100644 dimos/visualization/rerun/constants.py create mode 100644 dimos/visualization/rerun/test_viewer_ws_e2e.py create mode 100644 dimos/visualization/rerun/test_websocket_server.py create mode 100644 dimos/visualization/rerun/websocket_server.py create mode 100644 dimos/visualization/vis_module.py create mode 100644 docs/development/conventions.md diff --git a/.gitignore b/.gitignore index 267aee13e4..9b2c6a5442 100644 --- a/.gitignore +++ b/.gitignore @@ -73,6 +73,9 @@ CLAUDE.MD /.mcp.json *.speedscope.json +# Hidden/personal directories +.hidden/ + # Coverage htmlcov/ .coverage diff --git a/dimos/core/coordination/python_worker.py b/dimos/core/coordination/python_worker.py index 3c434a982e..6c3aab3a2d 100644 --- a/dimos/core/coordination/python_worker.py +++ b/dimos/core/coordination/python_worker.py @@ -18,6 +18,7 @@ import multiprocessing from multiprocessing.connection import Connection import os +import signal import sys import threading import traceback @@ -337,12 +338,15 @@ class _WorkerState: def _worker_entrypoint(conn: Connection, worker_id: int) -> None: apply_library_config() + # Ignore SIGINT so the coordinator can orchestrate shutdown via the pipe. 
+ # Without this, workers race with the coordinator: they start tearing down + # modules locally while the coordinator tries to send stop() RPCs, causing + # BrokenPipeErrors. + signal.signal(signal.SIGINT, signal.SIG_IGN) state = _WorkerState(instances={}, worker_id=worker_id) try: _worker_loop(conn, state) - except KeyboardInterrupt: - logger.info("Worker got KeyboardInterrupt.", worker_id=worker_id) except Exception as e: logger.error(f"Worker process error: {e}", exc_info=True) finally: @@ -361,12 +365,6 @@ def _worker_entrypoint(conn: Connection, worker_id: int) -> None: worker_id=worker_id, module_id=module_id, ) - except KeyboardInterrupt: - logger.warning( - "KeyboardInterrupt during worker stop", - module=type(instance).__name__, - worker_id=worker_id, - ) except Exception: logger.error("Error during worker shutdown", exc_info=True) @@ -433,7 +431,7 @@ def _worker_loop(conn: Connection, state: _WorkerState) -> None: if not conn.poll(timeout=0.1): continue request = conn.recv() - except (EOFError, KeyboardInterrupt): + except EOFError: break try: diff --git a/dimos/core/docker_module.py b/dimos/core/docker_module.py index 3ad9620556..f82a1b56db 100644 --- a/dimos/core/docker_module.py +++ b/dimos/core/docker_module.py @@ -30,7 +30,7 @@ from dimos.core.rpc_client import ModuleProxyProtocol, RpcCall from dimos.protocol.rpc.pubsubrpc import LCMRPC from dimos.utils.logging_config import setup_logger -from dimos.visualization.rerun.bridge import RERUN_GRPC_PORT, RERUN_WEB_PORT +from dimos.visualization.rerun.constants import RERUN_GRPC_PORT, RERUN_WEB_PORT if TYPE_CHECKING: from collections.abc import Callable diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py index 214401959e..435f421dd1 100644 --- a/dimos/core/global_config.py +++ b/dimos/core/global_config.py @@ -13,13 +13,16 @@ # limitations under the License. 
import re -from typing import Literal, TypeAlias from pydantic_settings import BaseSettings, SettingsConfigDict from dimos.models.vl.types import VlModelName - -ViewerBackend: TypeAlias = Literal["rerun", "rerun-web", "rerun-connect", "foxglove", "none"] +from dimos.visualization.rerun.constants import ( + RERUN_ENABLE_WEB, + RERUN_OPEN_DEFAULT, + RerunOpenOption, + ViewerBackend, +) def _get_all_numbers(s: str) -> list[float]: @@ -37,6 +40,8 @@ class GlobalConfig(BaseSettings): replay_db: str = "go2_bigoffice" new_memory: bool = False viewer: ViewerBackend = "rerun" + rerun_open: RerunOpenOption = RERUN_OPEN_DEFAULT + rerun_web: bool = RERUN_ENABLE_WEB n_workers: int = 2 memory_limit: str = "auto" mujoco_camera_position: str | None = None diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index 9b4f50920c..0fe0d8f030 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -21,6 +21,7 @@ from dimos.agents.annotation import skill from dimos.core.coordination.blueprints import autoconnect from dimos.core.core import rpc +from dimos.core.global_config import global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.hardware.sensors.camera.spec import CameraHardware @@ -31,7 +32,7 @@ from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.spec import perception -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module def default_transform() -> Transform: @@ -120,5 +121,5 @@ def stop(self) -> None: demo_camera = autoconnect( CameraModule.blueprint(), - RerunBridgeModule.blueprint(), + vis_module(viewer_backend=global_config.viewer), ) diff --git a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py index 
2946f1d247..2c2a64d61e 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py +++ b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py @@ -15,30 +15,45 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 from dimos.mapping.voxels import VoxelGridMapper -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module voxel_size = 0.05 mid360_fastlio = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=-1), - RerunBridgeModule.blueprint(), + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/lidar": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + }, + }, + ), ).global_config(n_workers=2, robot_model="mid360_fastlio2") mid360_fastlio_voxels = autoconnect( FastLio2.blueprint(), VoxelGridMapper.blueprint(voxel_size=voxel_size, carve_columns=False), - RerunBridgeModule.blueprint( - visual_override={ - "world/lidar": None, - } + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + "world/lidar": None, + }, + }, ), ).global_config(n_workers=3, robot_model="mid360_fastlio2_voxels") mid360_fastlio_voxels_native = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=3.0), - RerunBridgeModule.blueprint( - visual_override={ - "world/lidar": None, - } + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/lidar": None, + "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + }, + }, ), ).global_config(n_workers=2, robot_model="mid360_fastlio2") diff --git a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py index 34ebc33c2a..e437d73994 100644 --- a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py 
+++ b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py @@ -14,9 +14,9 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.hardware.sensors.lidar.livox.module import Mid360 -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module mid360 = autoconnect( Mid360.blueprint(), - RerunBridgeModule.blueprint(), + vis_module("rerun"), ).global_config(n_workers=2, robot_model="mid360") diff --git a/dimos/manipulation/blueprints.py b/dimos/manipulation/blueprints.py index f950ea8efa..1c006c1d04 100644 --- a/dimos/manipulation/blueprints.py +++ b/dimos/manipulation/blueprints.py @@ -44,7 +44,7 @@ from dimos.msgs.sensor_msgs.JointState import JointState from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule from dimos.robot.catalog.ufactory import xarm6 as _catalog_xarm6, xarm7 as _catalog_xarm7 -from dimos.robot.foxglove_bridge import FoxgloveBridge # TODO: migrate to rerun +from dimos.visualization.vis_module import vis_module # Single XArm6 planner (standalone, no coordinator) _xarm6_planner_cfg = _catalog_xarm6( @@ -196,14 +196,14 @@ use_aabb=True, max_obstacle_width=0.06, ), - FoxgloveBridge.blueprint(), # TODO: migrate to rerun + vis_module("foxglove"), ) .transports( { ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), } ) - .global_config(viewer="foxglove", n_workers=4) + .global_config(n_workers=4) ) @@ -289,7 +289,7 @@ from dimos.robot.catalog.ufactory import XARM7_SIM_PATH from dimos.simulation.engines.mujoco_sim_module import MujocoSimModule -from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode +from dimos.visualization.rerun.bridge import RerunBridgeModule _xarm7_sim_cfg = _catalog_xarm7( name="arm", @@ -323,7 +323,7 @@ hardware=[_xarm7_sim_cfg.to_hardware_component()], tasks=[_xarm7_sim_cfg.to_task_config()], ), - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode()), + 
RerunBridgeModule.blueprint(), ).transports( { ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), diff --git a/dimos/manipulation/grasping/demo_grasping.py b/dimos/manipulation/grasping/demo_grasping.py index 37e1d38f1e..4a1d4b2cf6 100644 --- a/dimos/manipulation/grasping/demo_grasping.py +++ b/dimos/manipulation/grasping/demo_grasping.py @@ -22,7 +22,7 @@ from dimos.manipulation.grasping.grasping import GraspingModule from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.visualization.vis_module import vis_module camera_module = RealSenseCamera.blueprint(enable_pointcloud=False) @@ -44,7 +44,7 @@ ("/tmp", "/tmp", "rw") ], # Grasp visualization debug standalone: python -m dimos.manipulation.grasping.visualize_grasps ), - FoxgloveBridge.blueprint(), + vis_module("foxglove"), McpServer.blueprint(), McpClient.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index b8dbe0dfc8..338d10d9b0 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -115,6 +115,7 @@ class WavefrontFrontierExplorer(Module): goal_reached: In[Bool] explore_cmd: In[Bool] stop_explore_cmd: In[Bool] + stop_movement: In[Bool] # LCM outputs goal_request: Out[PoseStamped] @@ -171,6 +172,10 @@ def start(self) -> None: unsub = self.stop_explore_cmd.subscribe(self._on_stop_explore_cmd) self.register_disposable(Disposable(unsub)) + if self.stop_movement.transport is not None: + unsub = self.stop_movement.subscribe(self._on_stop_movement) + self.register_disposable(Disposable(unsub)) + @rpc def stop(self) -> None: self.stop_exploration() 
@@ -201,6 +206,12 @@ def _on_stop_explore_cmd(self, msg: Bool) -> None: logger.info("Received exploration stop command via LCM") self.stop_exploration() + def _on_stop_movement(self, msg: Bool) -> None: + """Handle stop movement from teleop — cancel active exploration.""" + if msg.data and self.exploration_active: + logger.info("WavefrontFrontierExplorer: stop_movement received, stopping exploration") + self.stop_exploration() + def _count_costmap_information(self, costmap: OccupancyGrid) -> int: """ Count the amount of information in a costmap (free space + obstacles). diff --git a/dimos/navigation/replanning_a_star/module.py b/dimos/navigation/replanning_a_star/module.py index 2375af20ce..efc16b52d6 100644 --- a/dimos/navigation/replanning_a_star/module.py +++ b/dimos/navigation/replanning_a_star/module.py @@ -28,6 +28,9 @@ from dimos.msgs.nav_msgs.Path import Path from dimos.navigation.base import NavigationInterface, NavigationState from dimos.navigation.replanning_a_star.global_planner import GlobalPlanner +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() class ReplanningAStarPlanner(Module, NavigationInterface): @@ -36,10 +39,11 @@ class ReplanningAStarPlanner(Module, NavigationInterface): goal_request: In[PoseStamped] clicked_point: In[PointStamped] target: In[PoseStamped] + stop_movement: In[Bool] goal_reached: Out[Bool] navigation_state: Out[String] # TODO: set it - cmd_vel: Out[Twist] + nav_cmd_vel: Out[Twist] path: Out[Path] navigation_costmap: Out[OccupancyGrid] @@ -72,9 +76,14 @@ def start(self) -> None: ) ) + if self.stop_movement.transport is not None: + self.register_disposable( + Disposable(self.stop_movement.subscribe(self._on_stop_movement)) + ) + self.register_disposable(self._planner.path.subscribe(self.path.publish)) - self.register_disposable(self._planner.cmd_vel.subscribe(self.cmd_vel.publish)) + self.register_disposable(self._planner.cmd_vel.subscribe(self.nav_cmd_vel.publish)) 
self.register_disposable(self._planner.goal_reached.subscribe(self.goal_reached.publish)) @@ -92,6 +101,11 @@ def stop(self) -> None: super().stop() + def _on_stop_movement(self, msg: Bool) -> None: + if msg.data: + logger.info("ReplanningAStarPlanner: stop_movement received, cancelling goal") + self.cancel_goal() + @rpc def set_goal(self, goal: PoseStamped) -> bool: self._planner.handle_goal_request(goal) diff --git a/dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py b/dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py new file mode 100644 index 0000000000..5a2dd195c0 --- /dev/null +++ b/dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py @@ -0,0 +1,133 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""MovementManager: click-to-goal relay + teleop/nav velocity mux.""" + +from __future__ import annotations + +import math +import threading +import time +from typing import Any + +from dimos_lcm.std_msgs import Bool # type: ignore[import-untyped] +from reactivex.disposable import Disposable + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.PointStamped import PointStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class MovementManagerConfig(ModuleConfig): + tele_cooldown_sec: float = 1.0 + tele_cmd_vel_scaling: Twist = Twist(Vector3(1, 1, 1), Vector3(1, 1, 1)) + + +class MovementManager(Module): + """Combine tele_cmd_vel (keyboard controls) and nav_cmd_vel in a sane way, output cmd_vel""" + + config: MovementManagerConfig + + clicked_point: In[PointStamped] + nav_cmd_vel: In[Twist] + tele_cmd_vel: In[Twist] + + goal: Out[PointStamped] + way_point: Out[PointStamped] + cmd_vel: Out[Twist] + stop_movement: Out[Bool] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._lock = threading.Lock() + self._teleop_active = False + self._last_teleop_time = 0.0 + + @rpc + def start(self) -> None: + super().start() + self.register_disposable(Disposable(self.clicked_point.subscribe(self._on_click))) + self.register_disposable(Disposable(self.nav_cmd_vel.subscribe(self._on_nav))) + self.register_disposable(Disposable(self.tele_cmd_vel.subscribe(self._on_teleop))) + + @rpc + def stop(self) -> None: + with self._lock: + self._teleop_active = False + super().stop() + + def _on_click(self, msg: PointStamped) -> None: + if not all(math.isfinite(v) for v in (msg.x, msg.y, msg.z)): + logger.warning("Ignored invalid click", x=msg.x, y=msg.y, z=msg.z) + return + if abs(msg.x) > 500 or abs(msg.y) > 500 or 
abs(msg.z) > 50: + logger.warning("Ignored out-of-range click", x=msg.x, y=msg.y, z=msg.z) + return + + logger.debug("Goal", x=round(msg.x, 1), y=round(msg.y, 1), z=round(msg.z, 1)) + self.way_point.publish(msg) + self.goal.publish(msg) + + def _cancel_goal(self) -> None: + self.stop_movement.publish(Bool(data=True)) + # NOTE: this NaN goal is more of a safety fallback. + # It can be REALLY bad if a robot is supposed to stop moving but wont + # we should probably think a more robust/strict requirement on planners + cancel = PointStamped( + ts=time.time(), frame_id="map", x=float("nan"), y=float("nan"), z=float("nan") + ) + self.way_point.publish(cancel) + self.goal.publish(cancel) + logger.debug("Navigation cancelled — waiting for new goal") + + def _on_nav(self, msg: Twist) -> None: + with self._lock: + if self._teleop_active: + # check if cooldown has expired + elapsed = time.monotonic() - self._last_teleop_time + if elapsed < self.config.tele_cooldown_sec: + return + self._teleop_active = False + self.cmd_vel.publish(msg) + + def _on_teleop(self, msg: Twist) -> None: + with self._lock: + was_active = self._teleop_active + self._teleop_active = True + self._last_teleop_time = time.monotonic() + + if not was_active: + self._cancel_goal() + logger.info("Teleop active") + + scale = self.config.tele_cmd_vel_scaling + scaled = Twist( + linear=Vector3( + msg.linear.x * scale.linear.x, + msg.linear.y * scale.linear.y, + msg.linear.z * scale.linear.z, + ), + angular=Vector3( + msg.angular.x * scale.angular.x, + msg.angular.y * scale.angular.y, + msg.angular.z * scale.angular.z, + ), + ) + self.cmd_vel.publish(scaled) diff --git a/dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py b/dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py new file mode 100644 index 0000000000..6858055605 --- /dev/null +++ b/dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py @@ -0,0 +1,117 @@ +# Copyright 2026 
Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for MovementManager: click-to-goal + teleop/nav velocity mux.""" + +from __future__ import annotations + +import math +import time +from unittest.mock import MagicMock + +import pytest + +from dimos.msgs.geometry_msgs.PointStamped import PointStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.navigation.smart_nav.modules.movement_manager.movement_manager import ( + MovementManager, +) + + +@pytest.fixture() +def manager() -> MovementManager: + """Create a real MovementManager and mock the publish methods on its output streams.""" + module = MovementManager(tele_cooldown_sec=0.1) + module.cmd_vel.publish = MagicMock() + module.stop_movement.publish = MagicMock() + module.goal.publish = MagicMock() + module.way_point.publish = MagicMock() + yield module + module._close_module() + + +def _twist(lx: float = 0.0) -> Twist: + return Twist(linear=Vector3(lx, 0, 0), angular=Vector3(0, 0, 0)) + + +def _click(x: float = 1.0, y: float = 2.0, z: float = 0.0) -> PointStamped: + return PointStamped(ts=time.time(), frame_id="map", x=x, y=y, z=z) + + +def test_teleop_suppresses_nav_and_cancels_goal(manager: MovementManager) -> None: + """Teleop arriving should suppress nav, publish stop_movement, and cancel the goal with NaN.""" + manager.config.tele_cooldown_sec = 10.0 + manager._on_teleop(_twist(lx=0.3)) + + # Nav is suppressed + 
manager.cmd_vel.publish.reset_mock() # type: ignore[union-attr] + manager._on_nav(_twist(lx=0.9)) + manager.cmd_vel.publish.assert_not_called() # type: ignore[union-attr] + + # stop_movement fired + manager.stop_movement.publish.assert_called_once() # type: ignore[union-attr] + + # Goal cancelled with NaN + cancel_msg = manager.goal.publish.call_args[0][0] # type: ignore[union-attr] + assert math.isnan(cancel_msg.x) + + +def test_nav_resumes_after_cooldown(manager: MovementManager) -> None: + """After the cooldown expires, nav commands pass through again.""" + manager.config.tele_cooldown_sec = 0.05 + manager._on_teleop(_twist(lx=0.3)) + time.sleep(0.1) + manager.cmd_vel.publish.reset_mock() # type: ignore[union-attr] + + manager._on_nav(_twist(lx=0.9)) + manager.cmd_vel.publish.assert_called_once() # type: ignore[union-attr] + + +def test_valid_click_publishes_goal(manager: MovementManager) -> None: + """A valid click should publish to both goal and way_point.""" + click = _click(x=5.0, y=3.0, z=0.1) + manager._on_click(click) + manager.goal.publish.assert_called_once_with(click) # type: ignore[union-attr] + manager.way_point.publish.assert_called_once_with(click) # type: ignore[union-attr] + + +def test_invalid_clicks_rejected(manager: MovementManager) -> None: + """NaN, Inf, and out-of-range clicks should not publish.""" + for bad_click in [ + _click(x=float("nan")), + _click(x=float("inf")), + _click(x=600.0), + ]: + manager._on_click(bad_click) + manager.goal.publish.assert_not_called() # type: ignore[union-attr] + + +def test_tele_cmd_vel_scaling() -> None: + """tele_cmd_vel_scaling multiplies each teleop twist component independently.""" + scaling = Twist(Vector3(0.5, 2.0, 0.0), Vector3(1.0, 1.0, 0.25)) + module = MovementManager(tele_cooldown_sec=10.0, tele_cmd_vel_scaling=scaling) + module.cmd_vel.publish = MagicMock() + module.stop_movement.publish = MagicMock() + module.goal.publish = MagicMock() + module.way_point.publish = MagicMock() + + 
module._on_teleop(Twist(Vector3(1, 1, 1), Vector3(1, 1, 1))) + + published = module.cmd_vel.publish.call_args[0][0] # type: ignore[union-attr] + assert published.linear.x == pytest.approx(0.5) + assert published.linear.y == pytest.approx(2.0) + assert published.linear.z == pytest.approx(0.0) + assert published.angular.z == pytest.approx(0.25) + module._close_module() diff --git a/dimos/perception/demo_object_scene_registration.py b/dimos/perception/demo_object_scene_registration.py index c9b489f54b..28044dec13 100644 --- a/dimos/perception/demo_object_scene_registration.py +++ b/dimos/perception/demo_object_scene_registration.py @@ -20,7 +20,7 @@ from dimos.hardware.sensors.camera.zed.compat import ZEDCamera from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.visualization.vis_module import vis_module camera_choice = "zed" @@ -34,7 +34,7 @@ demo_object_scene_registration = autoconnect( camera_module, ObjectSceneRegistrationModule.blueprint(target_frame="world", prompt_mode=YoloePromptMode.LRPC), - FoxgloveBridge.blueprint(), + vis_module("foxglove"), McpServer.blueprint(), McpClient.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 881914b129..8e17e74e71 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -152,6 +152,7 @@ "mock-b1-connection-module": "dimos.robot.unitree.b1.connection.MockB1ConnectionModule", "module-a": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleA", "module-b": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleB", + "movement-manager": "dimos.navigation.smart_nav.modules.movement_manager.movement_manager.MovementManager", "mujoco-sim-module": "dimos.simulation.engines.mujoco_sim_module.MujocoSimModule", "navigation-module": 
"dimos.robot.unitree.rosnav.NavigationModule", "navigation-skill-container": "dimos.agents.skills.navigation.NavigationSkillContainer", @@ -174,6 +175,7 @@ "reid-module": "dimos.perception.detection.reid.module.ReidModule", "replanning-a-star-planner": "dimos.navigation.replanning_a_star.module.ReplanningAStarPlanner", "rerun-bridge-module": "dimos.visualization.rerun.bridge.RerunBridgeModule", + "rerun-web-socket-server": "dimos.visualization.rerun.websocket_server.RerunWebSocketServer", "ros-nav": "dimos.navigation.rosnav.ROSNav", "security-module": "dimos.experimental.security_demo.security_module.SecurityModule", "semantic-search": "dimos.memory2.module.SemanticSearch", diff --git a/dimos/robot/cli/dimos.py b/dimos/robot/cli/dimos.py index 37d1bd2be0..e99553c2b3 100644 --- a/dimos/robot/cli/dimos.py +++ b/dimos/robot/cli/dimos.py @@ -21,10 +21,11 @@ import json import os from pathlib import Path +import signal import sys import time import types -from typing import TYPE_CHECKING, Any, Union, get_args, get_origin +from typing import TYPE_CHECKING, Any, Union, cast, get_args, get_origin import click from dotenv import load_dotenv @@ -38,7 +39,10 @@ from dimos.core.daemon import daemonize, install_signal_handlers from dimos.core.global_config import GlobalConfig, global_config from dimos.core.run_registry import get_most_recent, is_pid_alive, stop_entry +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.protocol.service.lcmservice import autoconf from dimos.utils.logging_config import setup_logger +from dimos.visualization.rerun.constants import RerunOpenOption if TYPE_CHECKING: from dimos.core.coordination.blueprints import Blueprint, BlueprintAtom @@ -222,6 +226,10 @@ def run( cli_config_overrides: dict[str, Any] = ctx.obj + # this is a workaround until we have a proper way to have delayed-module-choice in blueprints + # ex: vis_module(viewer=global_config.viewer) is WRONG (viewer will always be default value) without this patch + 
global_config.update(**cli_config_overrides) + # Clean stale registry entries stale = cleanup_stale() if stale: @@ -660,17 +668,43 @@ def send( @main.command(name="rerun-bridge") def rerun_bridge_cmd( - viewer_mode: str = typer.Option( - "native", help="Viewer mode: native (desktop), web (browser), none (headless)" - ), memory_limit: str = typer.Option( "25%", help="Memory limit for Rerun viewer (e.g., '4GB', '16GB', '25%')" ), + rerun_open: str = typer.Option( + "native", help="How to open Rerun: one of native, web, both, none" + ), + rerun_web: bool = typer.Option( + True, "--rerun-web/--no-rerun-web", help="Enable/Disable Rerun web server" + ), ) -> None: - """Launch the Rerun visualization bridge.""" - from dimos.visualization.rerun.bridge import run_bridge + """Launch the Rerun visualization bridge. + + Standalone utility: runs the bridge directly in the main process (no + blueprint / worker pool) so users can attach a viewer to existing LCM + traffic without building a full module graph. + """ + # Deferred: RerunBridgeModule pulls in the rerun package (~1s), keep it + # out of the CLI's hot path so `dimos --help` stays fast. 
+ from dimos.visualization.rerun.bridge import RerunBridgeModule + + valid = get_args(RerunOpenOption) + if rerun_open not in valid: + raise typer.BadParameter( + f"rerun_open must be one of {valid}, got {rerun_open!r}", param_hint="--rerun-open" + ) + autoconf(check_only=True) + + bridge = RerunBridgeModule( + memory_limit=memory_limit, + rerun_open=cast("RerunOpenOption", rerun_open), + rerun_web=rerun_web, + pubsubs=[LCM()], + ) + bridge.start() - run_bridge(viewer_mode=viewer_mode, memory_limit=memory_limit) + signal.signal(signal.SIGINT, lambda *_: bridge.stop()) + signal.pause() if __name__ == "__main__": diff --git a/dimos/robot/drone/blueprints/basic/drone_basic.py b/dimos/robot/drone/blueprints/basic/drone_basic.py index c1838d6ac7..aaf82f6355 100644 --- a/dimos/robot/drone/blueprints/basic/drone_basic.py +++ b/dimos/robot/drone/blueprints/basic/drone_basic.py @@ -20,10 +20,9 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.core.global_config import global_config -from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.robot.drone.camera_module import DroneCameraModule from dimos.robot.drone.connection_module import DroneConnectionModule -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.visualization.vis_module import vis_module def _static_drone_body(rr: Any) -> list[Any]: @@ -60,23 +59,12 @@ def _drone_rerun_blueprint() -> Any: _rerun_config = { "blueprint": _drone_rerun_blueprint, - "pubsubs": [LCM()], "static": { "world/tf/base_link": _static_drone_body, }, } -# Conditional visualization -if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - _vis = FoxgloveBridge.blueprint() -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - - _vis = RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **_rerun_config) -else: - _vis = autoconnect() +_vis = 
vis_module(global_config.viewer, rerun_config=_rerun_config) # Determine connection string based on replay flag connection_string = "udp:0.0.0.0:14550" @@ -92,7 +80,6 @@ def _drone_rerun_blueprint() -> Any: outdoor=False, ), DroneCameraModule.blueprint(camera_intrinsics=[1000.0, 1000.0, 960.0, 540.0]), - WebsocketVisModule.blueprint(), ) __all__ = [ diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py index dd135a60a1..4941abad38 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py @@ -17,10 +17,11 @@ from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core.coordination.blueprints import autoconnect +from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image -from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1 import unitree_g1 +from dimos.visualization.vis_module import vis_module unitree_g1_shm = autoconnect( unitree_g1.transports( @@ -30,10 +31,9 @@ ), } ), - FoxgloveBridge.blueprint( - shm_channels=[ - "/color_image#sensor_msgs.Image", - ] + vis_module( + viewer_backend=global_config.viewer, + foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, ), ) diff --git a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py index b04443732f..eeabea7909 100644 --- a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py +++ b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py @@ -40,8 +40,7 @@ from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( WavefrontFrontierExplorer, ) -from dimos.protocol.pubsub.impl.lcmpubsub import LCM -from 
dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.visualization.vis_module import vis_module def _convert_camera_info(camera_info: Any) -> Any: @@ -94,7 +93,6 @@ def _g1_rerun_blueprint() -> Any: rerun_config = { "blueprint": _g1_rerun_blueprint, - "pubsubs": [LCM()], "visual_override": { "world/camera_info": _convert_camera_info, "world/navigation_costmap": _convert_navigation_costmap, @@ -104,18 +102,7 @@ def _g1_rerun_blueprint() -> Any: }, } -if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - _with_vis = autoconnect(FoxgloveBridge.blueprint()) -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - - _with_vis = autoconnect( - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config) - ) -else: - _with_vis = autoconnect() +_with_vis = vis_module(viewer_backend=global_config.viewer, rerun_config=rerun_config) def _create_webcam() -> Webcam: @@ -150,8 +137,6 @@ def _create_webcam() -> Webcam: VoxelGridMapper.blueprint(), CostMapper.blueprint(), WavefrontFrontierExplorer.blueprint(), - # Visualization - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_g1") .transports( diff --git a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py index be9e04a7fd..4b39a106b8 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py @@ -18,7 +18,7 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_agentic import unitree_go2_agentic -from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode +from dimos.visualization.rerun.bridge import 
RerunBridgeModule def _convert_camera_info(camera_info: Any) -> Any: @@ -85,7 +85,7 @@ def _go2_rerun_blueprint() -> Any: unitree_go2_security = autoconnect( unitree_go2_agentic, - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), + RerunBridgeModule.blueprint(**rerun_config), ) __all__ = ["unitree_go2_security"] diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py index 54a2c0f7c6..4f86ccb0a3 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py @@ -22,10 +22,9 @@ from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image -from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator from dimos.robot.unitree.go2.connection import GO2Connection -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.visualization.vis_module import vis_module # Mac has some issue with high bandwidth UDP, so we use pSHMTransport for color_image # actually we can use pSHMTransport for all platforms, and for all streams @@ -99,9 +98,6 @@ def _go2_rerun_blueprint() -> Any: rerun_config = { "blueprint": _go2_rerun_blueprint, - # any pubsub that supports subscribe_all and topic that supports str(topic) - # is acceptable here - "pubsubs": [LCM()], # Custom converters for specific rerun entity paths # Normally all these would be specified in their respectative modules # Until this is implemented we have central overrides here @@ -123,30 +119,20 @@ def _go2_rerun_blueprint() -> Any: }, } - -if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - with_vis = autoconnect( - _transports_base, - 
FoxgloveBridge.blueprint(shm_channels=["/color_image#sensor_msgs.Image"]), - ) -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - - with_vis = autoconnect( - _transports_base, - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), - ) -else: - with_vis = _transports_base +_with_vis = autoconnect( + _transports_base, + vis_module( + viewer_backend=global_config.viewer, + rerun_config=rerun_config, + foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, + ), +) unitree_go2_basic = ( autoconnect( - with_vis, + _with_vis, GO2Connection.blueprint(), - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py index a7a10767bf..bda362eeca 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py @@ -22,15 +22,13 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator -from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import with_vis +from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import _with_vis from dimos.robot.unitree.go2.fleet_connection import Go2FleetConnection -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule unitree_go2_fleet = ( autoconnect( - with_vis, + _with_vis, Go2FleetConnection.blueprint(), - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py 
index 01117ec3b5..3be0c62379 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py @@ -31,6 +31,10 @@ unitree_go2_webrtc_keyboard_teleop = autoconnect( unitree_go2_coordinator, KeyboardTeleop.blueprint(), +).remappings( + [ + (KeyboardTeleop, "tele_cmd_vel", "cmd_vel"), + ] ) __all__ = ["unitree_go2_webrtc_keyboard_teleop"] diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py index f353d995af..16711115ab 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py @@ -27,6 +27,7 @@ ) from dimos.navigation.patrolling.module import PatrollingModule from dimos.navigation.replanning_a_star.module import ReplanningAStarPlanner +from dimos.navigation.smart_nav.modules.movement_manager.movement_manager import MovementManager from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import unitree_go2_basic unitree_go2 = autoconnect( @@ -36,7 +37,8 @@ ReplanningAStarPlanner.blueprint(), WavefrontFrontierExplorer.blueprint(), PatrollingModule.blueprint(), -).global_config(n_workers=9, robot_model="unitree_go2") + MovementManager.blueprint(), +).global_config(n_workers=10, robot_model="unitree_go2") class Go2MemoryConfig(RecorderConfig): @@ -52,6 +54,6 @@ class Go2Memory(Recorder): unitree_go2_memory = autoconnect( unitree_go2, Go2Memory.blueprint(), -).global_config(n_workers=10) +).global_config(n_workers=11) __all__ = ["unitree_go2", "unitree_go2_memory"] diff --git a/dimos/robot/unitree/keyboard_teleop.py b/dimos/robot/unitree/keyboard_teleop.py index e3c78ecc52..3e8f76a1cc 100644 --- a/dimos/robot/unitree/keyboard_teleop.py +++ b/dimos/robot/unitree/keyboard_teleop.py @@ -38,14 +38,14 @@ class KeyboardTeleop(Module): """Pygame-based keyboard control module. 
- Outputs standard Twist messages on /cmd_vel for velocity control. + Outputs standard Twist messages on /tele_cmd_vel for velocity control. Speed constants can be tuned at the top of this file, or overridden per-instance by passing linear_speed / angular_speed / boost_multiplier / slow_multiplier to the constructor. """ - cmd_vel: Out[Twist] # Standard velocity commands + tele_cmd_vel: Out[Twist] # Standard velocity commands _stop_event: threading.Event _keys_held: set[int] | None = None @@ -86,7 +86,7 @@ def stop(self) -> None: stop_twist = Twist() stop_twist.linear = Vector3(0, 0, 0) stop_twist.angular = Vector3(0, 0, 0) - self.cmd_vel.publish(stop_twist) + self.tele_cmd_vel.publish(stop_twist) self._stop_event.set() @@ -119,7 +119,7 @@ def _pygame_loop(self) -> None: stop_twist = Twist() stop_twist.linear = Vector3(0, 0, 0) stop_twist.angular = Vector3(0, 0, 0) - self.cmd_vel.publish(stop_twist) + self.tele_cmd_vel.publish(stop_twist) print("EMERGENCY STOP!") elif event.key == pygame.K_ESCAPE: # ESC quits @@ -163,7 +163,7 @@ def _pygame_loop(self) -> None: twist.angular.z *= speed_multiplier # Always publish twist at 50Hz - self.cmd_vel.publish(twist) + self.tele_cmd_vel.publish(twist) self._update_display(twist) diff --git a/dimos/robot/unitree/mujoco_connection.py b/dimos/robot/unitree/mujoco_connection.py index 39c0904684..43ddeb6530 100644 --- a/dimos/robot/unitree/mujoco_connection.py +++ b/dimos/robot/unitree/mujoco_connection.py @@ -20,9 +20,12 @@ from collections.abc import Callable import functools import json +import os +from pathlib import Path import pickle import subprocess import sys +import sysconfig import threading import time from typing import Any, TypeVar @@ -126,12 +129,23 @@ def start(self) -> None: # Launch the subprocess try: - # mjpython must be used macOS (because of launch_passive inside mujoco_process.py) + # mjpython must be used on macOS (because of launch_passive inside mujoco_process.py). 
+ # It needs libpython on the dylib search path; uv-installed Pythons + # use @rpath which doesn't always resolve inside venvs, so we + # point DYLD_LIBRARY_PATH at the real libpython directory. executable = sys.executable if sys.platform != "darwin" else "mjpython" + env = os.environ.copy() + if sys.platform == "darwin": + # on some systems mujoco looks in the wrong place for shared libraries. So we force it to look in the right place + libdir = Path(sysconfig.get_config_var("LIBDIR") or "") + if libdir.is_dir(): + existing = env.get("DYLD_LIBRARY_PATH", "") + env["DYLD_LIBRARY_PATH"] = f"{libdir}:{existing}" if existing else str(libdir) self.process = subprocess.Popen( [executable, str(LAUNCHER_PATH), config_pickle, shm_names_json], stderr=subprocess.PIPE, + env=env, ) except Exception as e: diff --git a/dimos/simulation/unity/blueprint.py b/dimos/simulation/unity/blueprint.py index f7e2d34ccb..d9b29ee610 100644 --- a/dimos/simulation/unity/blueprint.py +++ b/dimos/simulation/unity/blueprint.py @@ -28,7 +28,7 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.simulation.unity.module import UnityBridgeModule -from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode +from dimos.visualization.rerun.bridge import RerunBridgeModule def _rerun_blueprint() -> Any: @@ -57,5 +57,5 @@ def _rerun_blueprint() -> Any: unity_sim = autoconnect( UnityBridgeModule.blueprint(), - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), + RerunBridgeModule.blueprint(**rerun_config), ) diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index 57c925c3f0..b825f29a17 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -26,12 +26,12 @@ from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.teleop.quest.quest_extensions import ArmTeleopModule from dimos.teleop.quest.quest_types import
Buttons -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module # Arm teleop with press-and-hold engage (has rerun viz) teleop_quest_rerun = autoconnect( ArmTeleopModule.blueprint(), - RerunBridgeModule.blueprint(), + vis_module("rerun"), ).transports( { ("left_controller_output", PoseStamped): LCMTransport("/teleop/left_delta", PoseStamped), diff --git a/dimos/test_no_sections.py b/dimos/test_no_sections.py index 902288b2e6..79f2d61b8f 100644 --- a/dimos/test_no_sections.py +++ b/dimos/test_no_sections.py @@ -52,6 +52,8 @@ ".tox", # third-party vendored code "gtsam", + # hidden/personal directories + ".hidden", } # Lines that match section patterns but are actually programmatic / intentional. diff --git a/dimos/utils/generic.py b/dimos/utils/generic.py index 84168ce057..200c7c6d86 100644 --- a/dimos/utils/generic.py +++ b/dimos/utils/generic.py @@ -16,10 +16,27 @@ import hashlib import json import os +import socket import string from typing import Any, Generic, TypeVar, overload import uuid +import psutil + + +def get_local_ips() -> list[tuple[str, str]]: + """Return ``(ip, interface_name)`` for every non-loopback IPv4 address. + + Picks up physical, virtual, and VPN interfaces (including Tailscale). 
+ """ + results: list[tuple[str, str]] = [] + for iface, addrs in psutil.net_if_addrs().items(): + for addr in addrs: + if addr.family == socket.AF_INET and not addr.address.startswith("127."): + results.append((addr.address, iface)) + return results + + _T = TypeVar("_T") diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index f2e3e51d08..f6744e74fb 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -18,18 +18,19 @@ from collections.abc import Callable from dataclasses import field -from functools import lru_cache +import socket import subprocess import time from typing import ( Any, - Literal, Protocol, TypeAlias, TypeGuard, cast, + get_args, runtime_checkable, ) +from urllib.parse import urlparse from reactivex.disposable import Disposable import rerun as rr @@ -37,19 +38,23 @@ import rerun.blueprint as rrb from rerun.blueprint import Blueprint from toolz import pipe # type: ignore[import-untyped] -import typer from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.pubsub.patterns import Glob, pattern_matches from dimos.protocol.pubsub.spec import SubscribeAllCapable +from dimos.utils.generic import get_local_ips from dimos.utils.logging_config import setup_logger +from dimos.visualization.rerun.constants import ( + RERUN_ENABLE_WEB, + RERUN_GRPC_PORT, + RERUN_OPEN_DEFAULT, + RERUN_WEB_PORT, + RerunOpenOption, +) from dimos.visualization.rerun.init import rerun_init -RERUN_GRPC_PORT = 9877 -RERUN_WEB_PORT = 9090 - # TODO OUT visual annotations # # In the future it would be nice if modules can annotate their individual OUTs with (general or rerun specific) @@ -95,7 +100,6 @@ BlueprintFactory: TypeAlias = Callable[[], "Blueprint"] -# to_rerun() can return a single archetype or a list of (entity_path, archetype) tuples RerunMulti: TypeAlias = "list[tuple[str, Archetype]]" RerunData: 
TypeAlias = "Archetype | RerunMulti" @@ -119,18 +123,16 @@ class RerunConvertible(Protocol): def to_rerun(self) -> RerunData: ... -ViewerMode = Literal["native", "web", "connect", "none"] - - def _hex_to_rgba(hex_color: str) -> int: """Convert '#RRGGBB' to a 0xRRGGBBAA int (fully opaque).""" h = hex_color.lstrip("#") - return (int(h, 16) << 8) | 0xFF + if len(h) == 6: + return int(h + "ff", 16) + return int(h[:8], 16) def _with_graph_tab(bp: Blueprint) -> Blueprint: """Add a Graph tab alongside the existing viewer layout without changing it.""" - root = bp.root_container return rrb.Blueprint( rrb.Tabs( @@ -156,50 +158,26 @@ def _default_blueprint() -> Blueprint: ) -# Maps global_config.viewer -> bridge viewer_mode. -# Evaluated at blueprint construction time (main process), not in start() (worker process). -_BACKEND_TO_MODE: dict[str, ViewerMode] = { - "rerun": "native", - "rerun-web": "web", - "rerun-connect": "connect", - "none": "none", -} - - -def _resolve_viewer_mode() -> ViewerMode: - from dimos.core.global_config import global_config - - return _BACKEND_TO_MODE.get(global_config.viewer, "native") - - class Config(ModuleConfig): - """Configuration for RerunBridgeModule.""" - pubsubs: list[SubscribeAllCapable[Any, Any]] = field(default_factory=lambda: [LCM()]) visual_override: dict[Glob | str, Callable[[Any], Archetype]] = field(default_factory=dict) - - # Static items logged once after start. Maps entity_path -> callable(rr) returning Archetype static: dict[str, Callable[[Any], Archetype]] = field(default_factory=dict) - - grpc_port: int = RERUN_GRPC_PORT - web_port: int = RERUN_WEB_PORT - - # Per-entity max update rate (Hz). Entities not listed are unthrottled. - # Use for heavy entities to prevent viewer backpressure. 
max_hz: dict[str, float] = field(default_factory=dict) entity_prefix: str = "world" topic_to_entity: Callable[[Any], str] | None = None - viewer_mode: ViewerMode = field(default_factory=_resolve_viewer_mode) connect_url: str = "rerun+http://127.0.0.1:9877/proxy" memory_limit: str = "25%" - - # Blueprint factory: callable(rrb) -> Blueprint for viewer layout configuration - # Set to None to disable default blueprint + rerun_open: RerunOpenOption = RERUN_OPEN_DEFAULT + rerun_web: bool = RERUN_ENABLE_WEB + web_port: int = RERUN_WEB_PORT blueprint: BlueprintFactory | None = _default_blueprint +Config.model_rebuild(_types_namespace={"Archetype": Archetype, "Blueprint": Blueprint}) + + class RerunBridgeModule(Module): """Bridge that logs messages from pubsubs to Rerun. @@ -217,22 +195,31 @@ class RerunBridgeModule(Module): """ config: Config + _last_log: dict[str, float] # TODO this doesn't belong here, either hardcode it or put it to rerun bridge config - GV_SCALE = 100.0 # graphviz inches to rerun screen units - MODULE_RADIUS = 30.0 - CHANNEL_RADIUS = 20.0 + GRAPH_VIZ_SCALE = 100.0 + MODULE_RADIUS = 20.0 + CHANNEL_RADIUS = 12.0 + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._last_log = {} + self._override_cache: dict[str, Callable[[Any], RerunData | None]] = {} - @lru_cache(maxsize=256) def _visual_override_for_entity_path( self, entity_path: str ) -> Callable[[Any], RerunData | None]: """Return a composed visual override for the entity path. Chains matching overrides from config, ending with final_convert - which handles .to_rerun() or passes through Archetypes. + which handles .to_rerun() or passes through Archetypes. Cached per + instance (not via ``lru_cache`` on a method, which would leak ``self``). 
""" - # find all matching converters for this entity path + cached = self._override_cache.get(entity_path) + if cached is not None: + return cached + matches = [ fn for pattern, fn in self.config.visual_override.items() @@ -241,9 +228,13 @@ def _visual_override_for_entity_path( # None means "suppress this topic entirely" if any(fn is None for fn in matches): - return lambda msg: None - # final step (ensures we return Archetype or None) + def suppressed(msg: Any) -> RerunData | None: + return None + + self._override_cache[entity_path] = suppressed + return suppressed + def final_convert(msg: Any) -> RerunData | None: if isinstance(msg, Archetype): return msg @@ -253,23 +244,21 @@ def final_convert(msg: Any) -> RerunData | None: return msg.to_rerun() return None - # compose all converters - return lambda msg: pipe(msg, *matches, final_convert) + def composed(msg: Any) -> RerunData | None: + return cast("RerunData | None", pipe(msg, *matches, final_convert)) + + self._override_cache[entity_path] = composed + return composed def _get_entity_path(self, topic: Any) -> str: - """Convert a topic to a Rerun entity path.""" if self.config.topic_to_entity: return self.config.topic_to_entity(topic) - # Default: use topic.name if available (LCM Topic), else str topic_str = getattr(topic, "name", None) or str(topic) - # Strip everything after # (LCM topic suffix) - topic_str = topic_str.split("#")[0] + topic_str = topic_str.split("#")[0] # strip LCM topic suffix return f"{self.config.entity_prefix}{topic_str}" def _on_message(self, msg: Any, topic: Any) -> None: - """Handle incoming message - log to rerun.""" - entity_path: str = self._get_entity_path(topic) # Throttle entities with a max_hz limit @@ -279,7 +268,6 @@ def _on_message(self, msg: Any, topic: Any) -> None: return self._last_log[entity_path] = now - # apply visual overrides (including final_convert which handles .to_rerun()) rerun_data: RerunData | None = self._visual_override_for_entity_path(entity_path)(msg) if not 
rerun_data: @@ -296,47 +284,87 @@ def _on_message(self, msg: Any, topic: Any) -> None: def start(self) -> None: super().start() - logger.info("Rerun bridge starting", viewer_mode=self.config.viewer_mode) + logger.info("Rerun bridge starting") - # Build throttle lookup: entity_path → min interval in seconds - self._last_log: dict[str, float] = {} + self._last_log = {} self._min_intervals: dict[str, float] = { entity: 1.0 / hz for entity, hz in self.config.max_hz.items() if hz > 0 } - # Initialize and spawn Rerun viewer rerun_init("dimos") - if self.config.viewer_mode == "native": + parsed = urlparse(self.config.connect_url.replace("rerun+", "", 1)) + grpc_port = parsed.port or RERUN_GRPC_PORT + + port_in_use = False + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + port_in_use = sock.connect_ex(("127.0.0.1", grpc_port)) == 0 + + if port_in_use: + logger.info(f"gRPC port {grpc_port} already in use, connecting to existing server") + rr.connect_grpc(url=self.config.connect_url) + server_uri = self.config.connect_url + else: + server_uri = rr.serve_grpc( + grpc_port=grpc_port, + server_memory_limit=self.config.memory_limit, + ) + logger.info(f"Rerun gRPC server ready at {server_uri}") + + if self.config.rerun_open not in get_args(RerunOpenOption): + logger.warning( + f"rerun_open was {self.config.rerun_open} which is not one of " + f"{get_args(RerunOpenOption)}" + ) + + spawned = False + if self.config.rerun_open in ("native", "both"): try: import rerun_bindings + # Use --connect so the viewer connects to the bridge's gRPC + # server rather than starting its own (which would conflict). 
rerun_bindings.spawn( - port=self.config.grpc_port, executable_name="dimos-viewer", memory_limit=self.config.memory_limit, + extra_args=["--connect", server_uri], ) - rr.connect_grpc(f"rerun+http://127.0.0.1:{self.config.grpc_port}/proxy") + spawned = True except ImportError: - rr.spawn(connect=True, memory_limit=self.config.memory_limit) + pass # dimos-viewer not installed except Exception: logger.warning( "dimos-viewer found but failed to spawn, falling back to stock rerun", exc_info=True, ) - rr.spawn(connect=True, memory_limit=self.config.memory_limit) - elif self.config.viewer_mode == "web": - server_uri = rr.serve_grpc() - rr.serve_web_viewer(connect_to=server_uri, open_browser=False) - elif self.config.viewer_mode == "connect": - rr.connect_grpc(self.config.connect_url) - # "none" - just init, no viewer (connect externally) + # fallback on normal (non-dimos-viewer) rerun + if not spawned: + try: + rr.spawn(connect=True, memory_limit=self.config.memory_limit) + spawned = True + except (RuntimeError, FileNotFoundError): + logger.warning( + "Rerun native viewer not available (headless?). 
" + "Bridge will continue without a viewer — data is still " + "accessible via --rerun-open web or by connecting a viewer to the gRPC server.", + exc_info=True, + ) + + open_web = self.config.rerun_open == "web" or self.config.rerun_open == "both" + if open_web or self.config.rerun_web: + rr.serve_web_viewer( + connect_to=server_uri, + open_browser=open_web, + web_port=self.config.web_port, + ) + + if self.config.rerun_open == "none" or (self.config.rerun_open == "native" and not spawned): + self._log_connect_hints(grpc_port) if self.config.blueprint: rr.send_blueprint(_with_graph_tab(self.config.blueprint())) - # Start pubsubs and subscribe to all messages for pubsub in self.config.pubsubs: logger.info(f"bridge listening on {pubsub.__class__.__name__}") if hasattr(pubsub, "start"): @@ -344,13 +372,35 @@ def start(self) -> None: unsub = pubsub.subscribe_all(self._on_message) self.register_disposable(Disposable(unsub)) - # Add pubsub stop as disposable for pubsub in self.config.pubsubs: if hasattr(pubsub, "stop"): self.register_disposable(Disposable(pubsub.stop)) # type: ignore[union-attr] self._log_static() + def _log_connect_hints(self, grpc_port: int) -> None: + """Log CLI commands for connecting a viewer to this bridge.""" + local_ips = get_local_ips() + hostname = socket.gethostname() + connect_url = f"rerun+http://127.0.0.1:{grpc_port}/proxy" + + lines = [ + "", + "=" * 60, + "Rerun gRPC server running (no viewer opened)", + "", + "Connect a viewer:", + f" dimos-viewer --connect {connect_url}", + ] + for ip, iface in local_ips: + lines.append(f" dimos-viewer --connect rerun+http://{ip}:{grpc_port}/proxy # {iface}") + lines.append("") + lines.append(f" hostname: {hostname}") + lines.append("=" * 60) + lines.append("") + + logger.info("\n".join(lines)) + def _log_static(self) -> None: for entity_path, factory in self.config.static.items(): data = factory(rr) @@ -371,7 +421,6 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: 
dot_code: The DOT-format graph (from ``introspection.blueprint.dot.render``). module_names: List of module class names (to distinguish modules from channels). """ - try: result = subprocess.run( ["dot", "-Tplain"], input=dot_code, text=True, capture_output=True, timeout=30 @@ -393,8 +442,8 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: if line.startswith("node "): parts = line.split() node_id = parts[1].strip('"') - x = float(parts[2]) * self.GV_SCALE - y = -float(parts[3]) * self.GV_SCALE + x = float(parts[2]) * self.GRAPH_VIZ_SCALE + y = -float(parts[3]) * self.GRAPH_VIZ_SCALE label = parts[6].strip('"') color = parts[9].strip('"') @@ -427,49 +476,5 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: @rpc def stop(self) -> None: + self._override_cache.clear() super().stop() - - -def run_bridge( - viewer_mode: str = "native", - memory_limit: str = "25%", -) -> None: - """Start a RerunBridgeModule with default LCM config and block until interrupted.""" - import signal - - from dimos.protocol.service.lcmservice import autoconf - - autoconf(check_only=True) - - bridge = RerunBridgeModule( - viewer_mode=viewer_mode, - memory_limit=memory_limit, - # any pubsub that supports subscribe_all and topic that supports str(topic) - # is acceptable here - pubsubs=[LCM()], - ) - - bridge.start() - - signal.signal(signal.SIGINT, lambda *_: bridge.stop()) - signal.pause() - - -app = typer.Typer() - - -@app.command() -def cli( - viewer_mode: str = typer.Option( - "native", help="Viewer mode: native (desktop), web (browser), none (headless)" - ), - memory_limit: str = typer.Option( - "25%", help="Memory limit for Rerun viewer (e.g., '4GB', '16GB', '25%')" - ), -) -> None: - """Rerun bridge for LCM messages.""" - run_bridge(viewer_mode=viewer_mode, memory_limit=memory_limit) - - -if __name__ == "__main__": - app() diff --git a/dimos/visualization/rerun/conftest.py b/dimos/visualization/rerun/conftest.py new file mode 
100644 index 0000000000..f269bb8015 --- /dev/null +++ b/dimos/visualization/rerun/conftest.py @@ -0,0 +1,45 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from collections.abc import Callable +import time + +import pytest +import websockets.asyncio.client as ws_client + + +def _wait_for_server(port: int, timeout: float = 5.0) -> None: + """Block until the WebSocket server on *port* accepts a connection.""" + + async def _probe() -> None: + async with ws_client.connect(f"ws://127.0.0.1:{port}/ws"): + pass + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + asyncio.run(_probe()) + return + except Exception: + time.sleep(0.05) + raise TimeoutError(f"Server on port {port} did not become ready within {timeout}s") + + +@pytest.fixture() +def wait_for_server() -> Callable[[int, float], None]: + """Fixture that returns a callable to wait for a WebSocket server.""" + return _wait_for_server diff --git a/dimos/visualization/rerun/constants.py b/dimos/visualization/rerun/constants.py new file mode 100644 index 0000000000..860c691cef --- /dev/null +++ b/dimos/visualization/rerun/constants.py @@ -0,0 +1,31 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rerun visualization defaults and type aliases. + +This module is intentionally free of heavy imports so it can be +loaded from lightweight entry-points like ``global_config`` and +``dimos --help`` without pulling in the Rerun SDK or the module +framework. +""" + +from typing import Literal, TypeAlias + +ViewerBackend: TypeAlias = Literal["rerun", "foxglove", "none"] +RerunOpenOption: TypeAlias = Literal["none", "web", "native", "both"] + +RERUN_OPEN_DEFAULT: RerunOpenOption = "native" +RERUN_ENABLE_WEB = False +RERUN_GRPC_PORT = 9876 +RERUN_WEB_PORT = 9877 diff --git a/dimos/visualization/rerun/test_viewer_ws_e2e.py b/dimos/visualization/rerun/test_viewer_ws_e2e.py new file mode 100644 index 0000000000..260699a3e8 --- /dev/null +++ b/dimos/visualization/rerun/test_viewer_ws_e2e.py @@ -0,0 +1,201 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""End-to-end tests for dimos-viewer ↔ RerunWebSocketServer protocol.""" + +from __future__ import annotations + +import asyncio +import json +import os +import subprocess +import threading +import time +from typing import Any + +import pytest +import websockets.asyncio.client as ws_client + +from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + +_E2E_PORT = 13032 + + +@pytest.fixture() +def server(wait_for_server: Any) -> RerunWebSocketServer: + module = RerunWebSocketServer(port=_E2E_PORT) + module.start() + wait_for_server(_E2E_PORT) + yield module # type: ignore[misc] + module.stop() + + +def _send_messages(port: int, messages: list[dict[str, Any]], *, delay: float = 0.05) -> None: + async def _run() -> None: + async with ws_client.connect(f"ws://127.0.0.1:{port}/ws") as ws: + for msg in messages: + await ws.send(json.dumps(msg)) + await asyncio.sleep(delay) + + asyncio.run(_run()) + + +class TestViewerProtocolE2E: + """Verify the Python-server side of the viewer ↔ DimOS protocol.""" + + def test_viewer_click_reaches_stream(self, server: RerunWebSocketServer) -> None: + """A viewer click over WebSocket publishes PointStamped.""" + received: list[Any] = [] + done = threading.Event() + unsub = server.clicked_point.subscribe(lambda pt: (received.append(pt), done.set())) + + _send_messages( + _E2E_PORT, + [ + { + "type": "click", + "x": 10.0, + "y": 20.0, + "z": 0.5, + "entity_path": "/world/robot", + "timestamp_ms": 42000, + } + ], + ) + + done.wait(timeout=3.0) + unsub() + + assert len(received) == 1 + pt = received[0] + assert pt.x == pytest.approx(10.0) + assert pt.y == pytest.approx(20.0) + assert pt.z == pytest.approx(0.5) + assert pt.frame_id == "/world/robot" + assert pt.ts == pytest.approx(42.0) + + def test_full_viewer_session_sequence(self, server: RerunWebSocketServer) -> None: + """Realistic session: heartbeats, click, twist, stop — only the click produces a point.""" + received: list[Any] = [] + done = threading.Event() + 
unsub = server.clicked_point.subscribe(lambda pt: (received.append(pt), done.set())) + + _send_messages( + _E2E_PORT, + [ + {"type": "heartbeat", "timestamp_ms": 1000}, + {"type": "heartbeat", "timestamp_ms": 2000}, + { + "type": "click", + "x": 3.14, + "y": 2.71, + "z": 1.41, + "entity_path": "/world", + "timestamp_ms": 3000, + }, + { + "type": "twist", + "linear_x": 0.5, + "linear_y": 0.0, + "linear_z": 0.0, + "angular_x": 0.0, + "angular_y": 0.0, + "angular_z": 0.0, + }, + {"type": "stop"}, + {"type": "heartbeat", "timestamp_ms": 4000}, + ], + delay=0.2, + ) + + done.wait(timeout=3.0) + unsub() + + assert len(received) == 1, f"Expected exactly 1 click, got {len(received)}" + assert received[0].x == pytest.approx(3.14) + assert received[0].y == pytest.approx(2.71) + assert received[0].z == pytest.approx(1.41) + + def test_reconnect_after_disconnect(self, server: RerunWebSocketServer) -> None: + """Server keeps accepting new connections after a client disconnects.""" + received: list[Any] = [] + all_done = threading.Event() + + def _on_pt(pt: Any) -> None: + received.append(pt) + if len(received) >= 2: + all_done.set() + + unsub = server.clicked_point.subscribe(_on_pt) + + _send_messages( + _E2E_PORT, + [{"type": "click", "x": 1.0, "y": 0.0, "z": 0.0, "entity_path": "", "timestamp_ms": 0}], + ) + _send_messages( + _E2E_PORT, + [{"type": "click", "x": 2.0, "y": 0.0, "z": 0.0, "entity_path": "", "timestamp_ms": 0}], + ) + + all_done.wait(timeout=5.0) + unsub() + + xs = sorted(pt.x for pt in received) + assert xs == [1.0, 2.0], f"Unexpected xs: {xs}" + + +class TestViewerBinaryConnectMode: + """Smoke test: dimos-viewer binary starts in --connect mode.""" + + @pytest.fixture() + def viewer_process(self, server: RerunWebSocketServer) -> subprocess.Popen[bytes]: + proc = subprocess.Popen( + [ + "dimos-viewer", + "--connect", + f"--ws-url=ws://127.0.0.1:{_E2E_PORT}/ws", + ], + env={**os.environ, "DISPLAY": ""}, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + 
yield proc # type: ignore[misc] + proc.terminate() + try: + proc.wait(timeout=3) + except subprocess.TimeoutExpired: + proc.kill() + + @pytest.mark.skip( + reason="Incompatible with current winit: fails without DISPLAY (headless CI exits before WS connect) and hangs with DISPLAY (GUI event loop blocks before printing URL).", + ) + def test_viewer_ws_client_connects(self, viewer_process: subprocess.Popen[bytes]) -> None: + """dimos-viewer --connect starts and its WS client connects to our server.""" + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if viewer_process.poll() is not None: + break + time.sleep(0.1) + + stdout = ( + viewer_process.stdout.read().decode(errors="replace") if viewer_process.stdout else "" + ) + stderr = ( + viewer_process.stderr.read().decode(errors="replace") if viewer_process.stderr else "" + ) + + combined = stdout + stderr + assert f"ws://127.0.0.1:{_E2E_PORT}" in combined, ( + f"Viewer did not attempt WS connection.\nstdout:\n{stdout}\nstderr:\n{stderr}" + ) diff --git a/dimos/visualization/rerun/test_websocket_server.py b/dimos/visualization/rerun/test_websocket_server.py new file mode 100644 index 0000000000..b4304cf7b4 --- /dev/null +++ b/dimos/visualization/rerun/test_websocket_server.py @@ -0,0 +1,210 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for RerunWebSocketServer.""" + +from __future__ import annotations + +import asyncio +import json +import threading +import time +from typing import Any + +import pytest +import websockets.asyncio.client as ws_client + +from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + +_TEST_PORT = 13031 + + +class MockViewerPublisher: + """Simulates dimos-viewer sending JSON events over WebSocket.""" + + def __init__(self, url: str) -> None: + self._url = url + self._ws: Any = None + self._loop: asyncio.AbstractEventLoop | None = None + + def __enter__(self) -> MockViewerPublisher: + self._loop = asyncio.new_event_loop() + self._ws = self._loop.run_until_complete(self._connect()) + return self + + def __exit__(self, *_: Any) -> None: + if self._ws is not None and self._loop is not None: + self._loop.run_until_complete(self._ws.close()) + if self._loop is not None: + self._loop.close() + + async def _connect(self) -> Any: + return await ws_client.connect(self._url) + + def send_click( + self, x: float, y: float, z: float, entity_path: str = "", timestamp_ms: int = 0 + ) -> None: + self._send( + { + "type": "click", + "x": x, + "y": y, + "z": z, + "entity_path": entity_path, + "timestamp_ms": timestamp_ms, + } + ) + + def send_twist( + self, + linear_x: float, + linear_y: float, + linear_z: float, + angular_x: float, + angular_y: float, + angular_z: float, + ) -> None: + self._send( + { + "type": "twist", + "linear_x": linear_x, + "linear_y": linear_y, + "linear_z": linear_z, + "angular_x": angular_x, + "angular_y": angular_y, + "angular_z": angular_z, + } + ) + + def send_stop(self) -> None: + self._send({"type": "stop"}) + + def flush(self, delay: float = 0.1) -> None: + time.sleep(delay) + + def _send(self, msg: dict[str, Any]) -> None: + assert self._loop is not None and self._ws is not None + self._loop.run_until_complete(self._ws.send(json.dumps(msg))) + + +@pytest.fixture() +def server(wait_for_server: Any) -> RerunWebSocketServer: + 
module = RerunWebSocketServer(port=_TEST_PORT) + module.start() + wait_for_server(_TEST_PORT) + yield module # type: ignore[misc] + module.stop() + + +@pytest.fixture() +def publisher(server: RerunWebSocketServer) -> MockViewerPublisher: + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as publisher: + yield publisher # type: ignore[misc] + + +# ── Tests ──────────────────────────────────────────────────────────────── + + +def test_click_publishes_point_stamped( + server: RerunWebSocketServer, publisher: MockViewerPublisher +) -> None: + """Click event arrives as PointStamped with correct coords, frame_id, and timestamp.""" + received: list[Any] = [] + done = threading.Event() + + unsub = server.clicked_point.subscribe(lambda point: (received.append(point), done.set())) + + publisher.send_click(1.5, 2.5, 0.0, "/robot/base", timestamp_ms=5000) + publisher.flush() + done.wait(timeout=2.0) + unsub() + + assert len(received) == 1 + point = received[0] + assert point.x == pytest.approx(1.5) + assert point.y == pytest.approx(2.5) + assert point.z == pytest.approx(0.0) + assert point.frame_id == "/robot/base" + assert point.ts == pytest.approx(5.0) + + +def test_twist_publishes_on_tele_cmd_vel( + server: RerunWebSocketServer, publisher: MockViewerPublisher +) -> None: + """Twist event arrives as Twist on tele_cmd_vel.""" + received: list[Any] = [] + done = threading.Event() + + unsub = server.tele_cmd_vel.subscribe(lambda twist: (received.append(twist), done.set())) + + publisher.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.8) + publisher.flush() + done.wait(timeout=2.0) + unsub() + + assert len(received) == 1 + assert received[0].linear.x == pytest.approx(0.5) + assert received[0].angular.z == pytest.approx(0.8) + + +def test_stop_publishes_zero_twist( + server: RerunWebSocketServer, publisher: MockViewerPublisher +) -> None: + """Stop event publishes a zero Twist on tele_cmd_vel.""" + received: list[Any] = [] + done = threading.Event() + + unsub = 
server.tele_cmd_vel.subscribe(lambda twist: (received.append(twist), done.set())) + + publisher.send_stop() + publisher.flush() + done.wait(timeout=2.0) + unsub() + + assert len(received) == 1 + assert received[0].is_zero() + + +def test_invalid_json_does_not_crash(server: RerunWebSocketServer) -> None: + """Malformed JSON is logged and dropped; server stays alive for the next message.""" + + async def _send_bad() -> None: + async with ws_client.connect(f"ws://127.0.0.1:{_TEST_PORT}/ws") as ws: + await ws.send("this is not json {{") + await asyncio.sleep(0.1) + await ws.send(json.dumps({"type": "heartbeat", "timestamp_ms": 0})) + await asyncio.sleep(0.1) + + asyncio.run(_send_bad()) + + +def test_mixed_message_sequence( + server: RerunWebSocketServer, publisher: MockViewerPublisher +) -> None: + """Realistic session: click, twist, stop — only the click produces a point.""" + received: list[Any] = [] + done = threading.Event() + unsub = server.clicked_point.subscribe(lambda point: (received.append(point), done.set())) + + publisher.send_click(7.0, 8.0, 9.0, "/map", timestamp_ms=1100) + publisher.send_twist(0.3, 0.0, 0.0, 0.0, 0.0, 0.2) + publisher.send_stop() + publisher.flush() + done.wait(timeout=2.0) + unsub() + + assert len(received) == 1 + assert received[0].x == pytest.approx(7.0) + assert received[0].y == pytest.approx(8.0) + assert received[0].z == pytest.approx(9.0) diff --git a/dimos/visualization/rerun/websocket_server.py b/dimos/visualization/rerun/websocket_server.py new file mode 100644 index 0000000000..0c0ac2acf2 --- /dev/null +++ b/dimos/visualization/rerun/websocket_server.py @@ -0,0 +1,244 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""WebSocket server module that receives events from dimos-viewer. + +When dimos-viewer is started with ``--connect``, LCM multicast is unavailable +across machines. The viewer falls back to sending click, twist, and stop events +as JSON over a WebSocket connection. This module acts as the server-side +counterpart: it listens for those connections and translates incoming messages +into DimOS stream publishes. + +Message format (newline-delimited JSON, ``"type"`` discriminant): + + {"type":"heartbeat","timestamp_ms":1234567890} + {"type":"click","x":1.0,"y":2.0,"z":3.0,"entity_path":"/world","timestamp_ms":1234567890} + {"type":"twist","linear_x":0.5,"linear_y":0.0,"linear_z":0.0, + "angular_x":0.0,"angular_y":0.0,"angular_z":0.8} + {"type":"stop"} +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import socket +import threading +from typing import Any, Literal, TypedDict, Union + +import websockets +import websockets.asyncio.server as ws_server + +from dimos.core.core import rpc +from dimos.core.global_config import global_config +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import Out +from dimos.msgs.geometry_msgs.PointStamped import PointStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.utils.generic import get_local_ips +from dimos.utils.logging_config import setup_logger +from dimos.visualization.rerun.constants import RERUN_GRPC_PORT + +logger = setup_logger() + + +class ClickMsg(TypedDict): + type: 
Literal["click"] + x: float + y: float + z: float + entity_path: str + timestamp_ms: int + + +class TwistMsg(TypedDict): + type: Literal["twist"] + linear_x: float + linear_y: float + linear_z: float + angular_x: float + angular_y: float + angular_z: float + + +class StopMsg(TypedDict): + type: Literal["stop"] + + +class HeartbeatMsg(TypedDict): + type: Literal["heartbeat"] + timestamp_ms: int + + +ViewerMsg = Union[ClickMsg, TwistMsg, StopMsg, HeartbeatMsg] + + +def _handshake_noise_filter(record: logging.LogRecord) -> bool: + """Drop noisy "opening handshake failed" records from port scanners etc.""" + msg = record.getMessage() + return not ("opening handshake failed" in msg or "did not receive a valid HTTP request" in msg) + + +class Config(ModuleConfig): + host: str | None = None + port: int = 3030 + start_timeout: float = 10.0 + + +class RerunWebSocketServer(Module): + """Receives dimos-viewer WebSocket events and publishes them as DimOS streams. + + The viewer connects to this module (not the other way around) when running + in ``--connect`` mode. Each click event is converted to a ``PointStamped`` + and published on the ``clicked_point`` stream so downstream modules (e.g. + ``ReplanningAStarPlanner``) can consume it without modification. + + Outputs: + clicked_point: 3-D world-space point from the most recent viewer click. + tele_cmd_vel: Twist velocity commands from keyboard teleop, including stop events. + + Note: ``stop_movement`` is owned by ``MovementManager`` — it will fire + that signal when it sees the first teleop twist arrive here. 
+ """ + + config: Config + + clicked_point: Out[PointStamped] + tele_cmd_vel: Out[Twist] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._stop_event: asyncio.Event | None = None + self._server_ready = threading.Event() + self.host = self.config.host if self.config.host is not None else global_config.listen_host + + @rpc + def start(self) -> None: + super().start() + assert self._loop is not None + asyncio.run_coroutine_threadsafe(self._serve(), self._loop) + self._server_ready.wait(timeout=self.config.start_timeout) + self._log_connect_hints() + + @rpc + def stop(self) -> None: + self._server_ready.wait(timeout=self.config.start_timeout) + if self._loop is not None and not self._loop.is_closed() and self._stop_event is not None: + self._loop.call_soon_threadsafe(self._stop_event.set) + super().stop() + + def _log_connect_hints(self) -> None: + """Log full dimos-viewer commands that viewers can use to connect.""" + local_ips = get_local_ips() + hostname = socket.gethostname() + host = self.host + ws_url = f"ws://{host}:{self.config.port}/ws" + grpc_url = f"rerun+http://{host}:{RERUN_GRPC_PORT}/proxy" + + lines = [ + "", + "=" * 60, + f"RerunWebSocketServer listening on {ws_url}", + "", + "Connect a viewer:", + f" dimos-viewer --connect {grpc_url} --ws-url {ws_url}", + ] + if local_ips: + lines.append("") + lines.append("From another machine on the network:") + for ip, iface in local_ips: + remote_grpc = f"rerun+http://{ip}:{RERUN_GRPC_PORT}/proxy" + remote_ws = f"ws://{ip}:{self.config.port}/ws" + lines.append( + f" dimos-viewer --connect {remote_grpc} --ws-url {remote_ws} # {iface}" + ) + lines.append("") + lines.append(f" hostname: {hostname}") + lines.append("=" * 60) + lines.append("") + + logger.info("\n".join(lines)) + + async def _serve(self) -> None: + self._stop_event = asyncio.Event() + + ws_logger = logging.getLogger("websockets.server") + ws_logger.addFilter(_handshake_noise_filter) + + async with ws_server.serve( + 
self._handle_client, + host=self.host, + port=self.config.port, + ping_interval=30, + ping_timeout=30, + logger=ws_logger, + ): + self._server_ready.set() + await self._stop_event.wait() + + async def _handle_client(self, websocket: Any) -> None: + if hasattr(websocket, "request") and websocket.request.path != "/ws": + await websocket.close(1008, "Not Found") + return + addr = websocket.remote_address + logger.info(f"RerunWebSocketServer: viewer connected from {addr}") + try: + async for raw in websocket: + self._dispatch(raw) + except websockets.ConnectionClosed: + pass + + def _dispatch(self, raw: str | bytes) -> None: + try: + msg: dict[str, Any] = json.loads(raw) + except json.JSONDecodeError: + logger.warning(f"RerunWebSocketServer: ignoring non-JSON message: {raw!r}") + return + + if not isinstance(msg, dict): + return + + msg_type = msg.get("type") + + if msg_type == "click": + self.clicked_point.publish( + PointStamped( + x=float(msg.get("x", 0)), + y=float(msg.get("y", 0)), + z=float(msg.get("z", 0)), + ts=float(msg.get("timestamp_ms", 0)) / 1000.0, + frame_id=str(msg.get("entity_path", "")), + ) + ) + + elif msg_type == "twist": + self.tele_cmd_vel.publish( + Twist( + linear=Vector3( + float(msg.get("linear_x", 0)), + float(msg.get("linear_y", 0)), + float(msg.get("linear_z", 0)), + ), + angular=Vector3( + float(msg.get("angular_x", 0)), + float(msg.get("angular_y", 0)), + float(msg.get("angular_z", 0)), + ), + ) + ) + + elif msg_type == "stop": + self.tele_cmd_vel.publish(Twist.zero()) diff --git a/dimos/visualization/vis_module.py b/dimos/visualization/vis_module.py new file mode 100644 index 0000000000..badcba34db --- /dev/null +++ b/dimos/visualization/vis_module.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared visualization module factory for all robot blueprints.""" + +from typing import Any, get_args + +from dimos.core.coordination.blueprints import Blueprint, autoconnect +from dimos.visualization.rerun.constants import ViewerBackend + + +def vis_module( + viewer_backend: ViewerBackend, + rerun_config: dict[str, Any] | None = None, + foxglove_config: dict[str, Any] | None = None, +) -> Blueprint: + """Create a visualization blueprint based on the selected viewer backend. + + Bundles the appropriate viewer module (Rerun or Foxglove) together with + the ``WebsocketVisModule`` and ``RerunWebSocketServer`` so that the web + dashboard and remote viewer connections work out of the box. 
+ + Example usage:: + + from dimos.core.global_config import global_config + viz = vis_module( + global_config.viewer, + rerun_config={ + "visual_override": { + "world/camera_info": lambda ci: ci.to_rerun(...), + }, + "static": { + "world/tf/base_link": lambda rr: [rr.Boxes3D(...)], + }, + }, + ) + """ + from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule + + if foxglove_config is None: + foxglove_config = {} + if rerun_config is None: + rerun_config = {} + + match viewer_backend: + case "foxglove": + from dimos.robot.foxglove_bridge import FoxgloveBridge + + return autoconnect( + FoxgloveBridge.blueprint(**foxglove_config), + RerunWebSocketServer.blueprint(), + WebsocketVisModule.blueprint(), + ) + case "rerun": + from dimos.core.global_config import global_config + from dimos.protocol.pubsub.impl.lcmpubsub import LCM + from dimos.visualization.rerun.bridge import RerunBridgeModule + + rerun_config = {**rerun_config} # copy (avoid mutation) + rerun_config.setdefault("pubsubs", [LCM()]) + rerun_config.setdefault("rerun_open", global_config.rerun_open) + rerun_config.setdefault("rerun_web", global_config.rerun_web) + return autoconnect( + RerunBridgeModule.blueprint( + **rerun_config, + ), + RerunWebSocketServer.blueprint(), + WebsocketVisModule.blueprint(), + ) + case "none": + return autoconnect(WebsocketVisModule.blueprint()) + case _: + valid = ", ".join(get_args(ViewerBackend)) + raise ValueError(f"Unknown viewer_backend {viewer_backend!r}. 
Expected one of: {valid}") diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 3d6b3df11c..1ce7e74502 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -105,7 +105,7 @@ class WebsocketVisModule(Module): gps_goal: Out[LatLon] explore_cmd: Out[Bool] stop_explore_cmd: Out[Bool] - cmd_vel: Out[Twist] + tele_cmd_vel: Out[Twist] movecmd_stamped: Out[TwistStamped] def __init__(self, **kwargs: Any) -> None: @@ -158,9 +158,11 @@ def start(self) -> None: self._uvicorn_server_thread = threading.Thread(target=self._run_uvicorn_server, daemon=True) self._uvicorn_server_thread.start() - # Auto-open browser only for rerun-web (dashboard with Rerun iframe + command center) - # For rerun and foxglove, users access the command center manually if needed - if self.config.g.viewer == "rerun-web": + # Auto-open the dashboard tab only when the user explicitly asked for a + # web-based viewer (rerun_open == "web" or "both"). `rerun_web` alone + # only means "serve the viewer"; it should not trigger a browser popup + # when the user chose the native viewer. + if self.config.g.viewer == "rerun" and self.config.g.rerun_open in ("web", "both"): url = f"http://localhost:{self.config.port}/" logger.info(f"Dimensional Command Center: {url}") @@ -236,11 +238,13 @@ def _create_server(self) -> None: async def serve_index(request): # type: ignore[no-untyped-def] """Serve appropriate HTML based on viewer mode.""" - # If running native Rerun, redirect to standalone command center - if self.config.g.viewer != "rerun-web": + # Serve the full dashboard (with Rerun iframe) only when the Rerun + # viewer is opened in the browser (rerun_open is "web" or "both"); + # otherwise redirect to the standalone command center. 
+ if not ( + self.config.g.viewer == "rerun" and self.config.g.rerun_open in ("web", "both") + ): return RedirectResponse(url="/command-center") - - # Otherwise serve full dashboard with Rerun iframe return FileResponse(_DASHBOARD_HTML, media_type="text/html") async def serve_command_center(request): # type: ignore[no-untyped-def] @@ -333,14 +337,14 @@ async def clear_gps_goals(sid: str) -> None: @self.sio.event # type: ignore[untyped-decorator] async def move_command(sid: str, data: dict[str, Any]) -> None: # Publish Twist if transport is configured - if self.cmd_vel and self.cmd_vel.transport: + if self.tele_cmd_vel and self.tele_cmd_vel.transport: twist = Twist( linear=Vector3(data["linear"]["x"], data["linear"]["y"], data["linear"]["z"]), angular=Vector3( data["angular"]["x"], data["angular"]["y"], data["angular"]["z"] ), ) - self.cmd_vel.publish(twist) + self.tele_cmd_vel.publish(twist) # Publish TwistStamped if transport is configured if self.movecmd_stamped and self.movecmd_stamped.transport: diff --git a/docs/development/conventions.md b/docs/development/conventions.md new file mode 100644 index 0000000000..2b25a7c3c6 --- /dev/null +++ b/docs/development/conventions.md @@ -0,0 +1,12 @@ +This is mostly to track when conventions change (with regard to codebase updates) because this codebase is under heavy development. Note: this is a non-exhaustive list of conventions. + +- Instead of using `RerunBridge` in blueprints we always use `vis_module`, which allows the CLI to control whether it's foxglove, rerun, or no-vis at all +- `global_config.py` shouldn't accidentally/indirectly import heavy libraries like rerun. But sometimes global_config needs the type definition or default value from a module. Preferably we import from the module file directly; however, when that's not possible, we create a config.py for just that module's config and import that into global_config.py. 
+- When adding visualization tools to a blueprint/autoconnect, instead of using RerunBridge or WebsocketVisModule directly we should always use `vis_module`, which right now should look something like `vis_module(viewer_backend=global_config.viewer, rerun_config={}),` +- `DEFAULT_THREAD_JOIN_TIMEOUT` is used for all thread.join timeouts +- Don't use print inside of tests +- Module configs should be specified as `config: ModuleSpecificConfigClass` +- To customize the way rerun renders something, right now we use a `rerun_config` dict. This will (hopefully) change very soon to be a per-module config instead of a per-blueprint config +- Similar to the `rerun_config` the `rrb` (rerun blueprint) is defined at a blueprint level right now, but ideally would be a per-module contribution with only a per-blueprint override of the layout. +- No `__init__.py` files +- Helper blueprints (like `_with_vis`) that should not be used on their own need to start with an underscore to avoid being picked up by the all_blueprints.py code generation step diff --git a/docs/usage/cli.md b/docs/usage/cli.md index 017b441c7e..bba73368b2 100644 --- a/docs/usage/cli.md +++ b/docs/usage/cli.md @@ -18,7 +18,9 @@ dimos [GLOBAL OPTIONS] COMMAND [ARGS] | `--replay` / `--no-replay` | bool | `False` | Use recorded replay data | | `--replay-db` | TEXT | `go2_bigoffice` | Replay memory2 SQLite database name | | `--new-memory` / `--no-new-memory` | bool | `False` | Clear persistent memory on start | -| `--viewer` | `rerun\|rerun-web\|rerun-connect\|foxglove\|none` | `rerun` | Visualization backend | +| `--viewer` | `rerun\|foxglove\|none` | `rerun` | Visualization backend | +| `--rerun-open` | `native\|web\|both\|none` | `native` | How to open the Rerun viewer | +| `--rerun-web` / `--no-rerun-web` | bool | `False` | Serve the Rerun web viewer | | `--n-workers` | INT | `2` | Number of forkserver workers | | `--memory-limit` | TEXT | `auto` | Rerun viewer memory limit | | `--mcp-port` | INT | `9990` | MCP 
server port | diff --git a/docs/usage/visualization.md b/docs/usage/visualization.md index 57ad460354..9ece977a68 100644 --- a/docs/usage/visualization.md +++ b/docs/usage/visualization.md @@ -1,37 +1,43 @@ # Viewer Backends -Dimos supports three visualization backends: Rerun (web or native) and Foxglove. +Dimos supports three visualization backends: `rerun` (default), `foxglove`, and `none`. ## Quick Start -Choose your viewer via the CLI (preferred): +Choose your viewer via the CLI: ```bash # Rerun native viewer (default) - dimos-viewer with built-in teleop + click-to-navigate dimos run unitree-go2 -# Explicitly select the viewer mode: +# Explicitly select the viewer backend: dimos --viewer rerun run unitree-go2 -dimos --viewer rerun-web run unitree-go2 dimos --viewer foxglove run unitree-go2 +dimos --viewer none run unitree-go2 ``` -Alternative (environment variable): +Control how the Rerun viewer opens with `--rerun-open` and `--rerun-web`: ```bash -# Rerun native viewer (default) - dimos-viewer with built-in teleop + click-to-navigate -VIEWER=rerun dimos run unitree-go2 +# Open native desktop viewer (default) +dimos --rerun-open native run unitree-go2 + +# Open web viewer in browser +dimos --rerun-open web run unitree-go2 + +# Open both native and web +dimos --rerun-open both run unitree-go2 -# Rerun web viewer - browser dashboard + teleop at http://localhost:7779 -VIEWER=rerun-web dimos run unitree-go2 +# No viewer (headless) — data still accessible via gRPC +dimos --rerun-open none run unitree-go2 -# Foxglove - Use Foxglove Studio instead of Rerun -VIEWER=foxglove dimos run unitree-go2 +# Serve the web viewer without auto-opening a browser +dimos --rerun-web --rerun-open native run unitree-go2 ``` ## Viewer Modes Explained -### Rerun Native (`rerun`) — Default +### Rerun Native (`rerun`, `--rerun-open native`) — Default **What you get:** - [dimos-viewer](https://github.com/dimensionalOS/dimos-viewer), a custom Dimensional fork of Rerun with built-in keyboard 
teleop and click-to-navigate @@ -41,7 +47,7 @@ VIEWER=foxglove dimos run unitree-go2 --- -### Rerun Web (`rerun-web`) +### Rerun Web (`rerun`, `--rerun-open web`) **What you get:** - Browser-based dashboard at http://localhost:7779 @@ -63,18 +69,16 @@ VIEWER=foxglove dimos run unitree-go2 ## Rendering with Custom Blueprints -To enable rerun within your own blueprint simply include `RerunBridgeModule`: +To enable visualization in your own blueprint, use `vis_module`: ```python -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.core.global_config import global_config +from dimos.visualization.vis_module import vis_module from dimos.hardware.sensors.camera.module import CameraModule -from dimos.protocol.pubsub.impl.lcmpubsub import LCM camera_demo = autoconnect( CameraModule.blueprint(), - RerunBridgeModule.blueprint( - viewer_mode="native", # native (desktop), web (browser), none (headless) - ), + vis_module(viewer_backend=global_config.viewer), ) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 4cc8e3d9f2..758b51fb00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,7 @@ dependencies = [ # TODO: rerun shouldn't be required but rn its in core (there is NO WAY to use dimos without rerun rn) # remove this once rerun is optional in core "rerun-sdk>=0.20.0", - "dimos-viewer>=0.30.0a2", + "dimos-viewer==0.30.0a6.dev99", "toolz>=1.1.0", "protobuf>=6.33.5,<7", "psutil>=7.0.0", diff --git a/uv.lock b/uv.lock index 7357f46359..c77994a0e4 100644 --- a/uv.lock +++ b/uv.lock @@ -1958,7 +1958,7 @@ requires-dist = [ { name = "dimos", extras = ["base"], marker = "extra == 'unitree'" }, { name = "dimos-lcm" }, { name = "dimos-lcm", marker = "extra == 'docker'" }, - { name = "dimos-viewer", specifier = ">=0.30.0a2" }, + { name = "dimos-viewer", specifier = "==0.30.0a6.dev99" }, { name = "dimos-viewer", marker = "extra == 'visualization'", specifier = ">=0.30.0a4" }, { name = "drake", marker = "platform_machine != 
'aarch64' and sys_platform == 'darwin' and extra == 'manipulation'", specifier = "==1.45.0" }, { name = "drake", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and extra == 'manipulation'", specifier = ">=1.40.0" }, @@ -2128,18 +2128,18 @@ wheels = [ [[package]] name = "dimos-viewer" -version = "0.30.0a6" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0b/90/ad6d0e1e177a10a0b4f7e736436b6d2741acaeb402ab59504347236744f4/dimos_viewer-0.30.0a6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e623a21e6992e263513847e12809a0d234d73fc7af42a6428e84ca165ba682d0", size = 35309553, upload-time = "2026-03-18T15:22:26.874Z" }, - { url = "https://files.pythonhosted.org/packages/a1/84/1c8f41ff2bd5b6ee143eb6119107397dac284fa4f1f8335623c498bd1d9c/dimos_viewer-0.30.0a6-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:36068a3293cb1c7f4db9f4e6c9fea2d7dd2a2527025f803585f4d3aaad9aedbd", size = 39072034, upload-time = "2026-03-18T15:22:29.592Z" }, - { url = "https://files.pythonhosted.org/packages/58/e6/d6214245e5b99e1da262d037f52d3d39c6b87c65acb516fb08f11378e932/dimos_viewer-0.30.0a6-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:2bf36e8c8bd9dd822bedd1cb2d80ee2bf74b58184ba33872494baed0395fa7ff", size = 41447599, upload-time = "2026-03-18T15:22:32.699Z" }, - { url = "https://files.pythonhosted.org/packages/48/04/80f566400776cab9af68b4a3c0132f55786acd1641ea39d8b75e797a2e22/dimos_viewer-0.30.0a6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:947cfa10c583b357d589c10cb466c63b3651a83d1013a254c0ba03fc2959bef7", size = 35309552, upload-time = "2026-03-18T15:22:35.395Z" }, - { url = "https://files.pythonhosted.org/packages/4c/c3/72157e0806951c2c71c70dcd783e27be8d694344d7ecdb94eaef1066cf99/dimos_viewer-0.30.0a6-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:53ca4ac1f0778f1d9afb317b6268c941c02b20af86dd2aaaf1ea79f2c1d1eeb8", size = 39072018, upload-time = "2026-03-18T15:22:38.043Z" 
}, - { url = "https://files.pythonhosted.org/packages/2f/92/959fc1e9cdcb5fd8d793b2c8515a6086c9f913ba470baad1f3182ae4c242/dimos_viewer-0.30.0a6-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:27e108060a942c92f7869a0e45693dfe1798896bd90cbac6d1ce019a682f8ba7", size = 41447647, upload-time = "2026-03-18T15:22:41.003Z" }, - { url = "https://files.pythonhosted.org/packages/ab/d6/d76763b60d82539e92777500551116306cfea462f6976ad814a3bdf57e1d/dimos_viewer-0.30.0a6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f4f49f973c51055cfd594b68a8e9d183c706f94b1513b6b69db900d05850f741", size = 35309553, upload-time = "2026-03-18T15:22:43.681Z" }, - { url = "https://files.pythonhosted.org/packages/26/ab/6ea7686c467caecdc74dd8d3a0267053ac74229b3afebc64cff180d5074c/dimos_viewer-0.30.0a6-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:791ef1c1d8d41db69a7d2b701ed3f0b6bc39cb3264aaef7300eddb576c8df7ed", size = 39072062, upload-time = "2026-03-18T15:22:46.264Z" }, - { url = "https://files.pythonhosted.org/packages/3c/87/fce7aac56d8a234d3d7c0911928bb3471d7852e35263b966d2aac5be42cd/dimos_viewer-0.30.0a6-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:dd976c39c38718b8373e1894d55b78c10bcb8c5716c8dbd5fba59141bc08ab3c", size = 41447667, upload-time = "2026-03-18T15:22:49.214Z" }, +version = "0.30.0a6.dev99" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/0e/d363be05f172bafe5f41a95db318891637e902c50edfdc642edec6bb5111/dimos_viewer-0.30.0a6.dev99-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cfa57e68e8f4094d4a38d202414046fd2419ff2875ace3f16b8581c3106feca4", size = 35405401, upload-time = "2026-04-17T04:19:10.126Z" }, + { url = "https://files.pythonhosted.org/packages/e7/ab/0730fed402b3b92e35194f11b76119754d619fa6bab00a1932b5c78f87b3/dimos_viewer-0.30.0a6.dev99-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:f3bc243342131c8c2b653cc6b76f04d65aad525f5560829b78aa1a7d31a9d375", size = 39167146, 
upload-time = "2026-04-17T04:19:14.177Z" }, + { url = "https://files.pythonhosted.org/packages/bb/d9/1415d5d7e609d69b05e8e1167a66dd7cb78f3933205f9b321ae18233384c/dimos_viewer-0.30.0a6.dev99-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b954083fcb8951641554fdea95425b3b5ac9415cd1b65410a137d38d3dd57b8a", size = 41536165, upload-time = "2026-04-17T04:19:17.379Z" }, + { url = "https://files.pythonhosted.org/packages/93/7c/7ee6049a753c01ccbe8357f9c5f789378103b87331e5ca7977f05adf5c42/dimos_viewer-0.30.0a6.dev99-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0387201efd1260f968853f0d7863876b6db375b2af15b22f221a893fcce6549c", size = 35405408, upload-time = "2026-04-17T04:19:20.08Z" }, + { url = "https://files.pythonhosted.org/packages/de/2e/9b4252a12c4b641ab1479a6a4d3d576e75fc42ca2a797d88e2e0626abda0/dimos_viewer-0.30.0a6.dev99-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:a0fae6f2077fc6ceb25e1ed33fb7ccf183ef3e2a30456aa5462b953c1419e547", size = 39167138, upload-time = "2026-04-17T04:19:23.292Z" }, + { url = "https://files.pythonhosted.org/packages/46/2a/4bd02c3d79df2aefc5be47afda6b95121937cef0a3f6b15d071691ec3ca7/dimos_viewer-0.30.0a6.dev99-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e844015f3ad193d50201c39abd3e3f34abbf03adbfb1075468696c1236df1409", size = 41536172, upload-time = "2026-04-17T04:19:26.421Z" }, + { url = "https://files.pythonhosted.org/packages/1b/b1/efcea9b9e21c4ab75e2df016a27e5045e30d91a494465ab0cc627d8d8bc3/dimos_viewer-0.30.0a6.dev99-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dc82061c2c025684c0fbed5392f793d137b1b0fc3aa1b601988bf4d2ee88aa27", size = 35405409, upload-time = "2026-04-17T04:19:29.574Z" }, + { url = "https://files.pythonhosted.org/packages/2d/8e/d482b0b9379c40ddd7547600543ce726fc3b5d10e396a876f22b2d76d0e6/dimos_viewer-0.30.0a6.dev99-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:0f6acfa0de3083e746ac43fe0d0a328d624bcb859dc698b1bbc592f444f52f15", size = 39167144, upload-time = 
"2026-04-17T04:19:32.301Z" }, + { url = "https://files.pythonhosted.org/packages/6d/eb/08922721c74ceaa99a824258db02c438d50f77c22ff80332cbc4b1a8db7b/dimos_viewer-0.30.0a6.dev99-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:56fa9139c49ec4bf96b12d6e98d3de3319a66876374ae57bda4534ab7a347765", size = 41536171, upload-time = "2026-04-17T04:19:35.29Z" }, ] [[package]] From 2a1ca9b2522b106c424cf7987233139b89e35a3c Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Mon, 27 Apr 2026 15:08:40 -0500 Subject: [PATCH 18/30] fix(types): resolve mypy 3.10 errors (#1921) --- .pre-commit-config.yaml | 1 - dimos/agents_deprecated/agent.py | 2 -- .../memory/spatial_vector_db.py | 4 +-- dimos/agents_deprecated/modules/base_agent.py | 3 ++ dimos/memory2/vis/plot/elements.py | 4 +-- dimos/memory2/vis/plot/plot.py | 4 +-- dimos/models/base.py | 4 +++ dimos/models/embedding/base.py | 1 - dimos/models/embedding/clip.py | 2 ++ dimos/models/embedding/mobileclip.py | 4 ++- dimos/models/embedding/test_embedding.py | 14 ++++++++ dimos/models/embedding/treid.py | 2 ++ dimos/models/qwen/bbox.py | 14 ++++++++ dimos/models/qwen/video_query.py | 17 ++++++++- dimos/models/segmentation/edge_tam.py | 5 ++- dimos/models/vl/base.py | 21 +++++++++++ dimos/models/vl/create.py | 17 +++++++++ dimos/models/vl/moondream.py | 16 +++++++++ dimos/models/vl/moondream_hosted.py | 22 ++++++++++-- dimos/models/vl/openai.py | 35 ++++++++++++++++--- dimos/models/vl/qwen.py | 32 ++++++++++++++--- dimos/models/vl/test_base.py | 14 ++++++++ dimos/models/vl/test_captioner.py | 14 ++++++++ dimos/models/vl/test_models.py | 0 dimos/models/vl/test_vlm.py | 18 ++++++++-- dimos/models/vl/types.py | 14 ++++++++ dimos/msgs/sensor_msgs/Image.py | 2 +- dimos/perception/detection/module2D.py | 3 +- .../detection/type/imageDetections.py | 8 ++++- dimos/protocol/rpc/pubsubrpc.py | 2 +- dimos/simulation/engines/mujoco_shm.py | 6 ++-- 31 files changed, 271 insertions(+), 34 deletions(-) delete mode 100644 
dimos/models/vl/test_models.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81d232da96..b5a06a16d0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,5 @@ default_stages: [pre-commit] default_install_hook_types: [pre-commit, commit-msg] -exclude: (dimos/models/.*)|(deprecated) repos: - repo: https://github.com/Lucas-C/pre-commit-hooks diff --git a/dimos/agents_deprecated/agent.py b/dimos/agents_deprecated/agent.py index 1d48ce2fa4..4515cd5bfb 100644 --- a/dimos/agents_deprecated/agent.py +++ b/dimos/agents_deprecated/agent.py @@ -897,5 +897,3 @@ def stream_query(self, query_text: str) -> Observable: # type: ignore[type-arg] return create( lambda observer, _: self._observable_query(observer, incoming_query=query_text) # type: ignore[arg-type] ) - - diff --git a/dimos/agents_deprecated/memory/spatial_vector_db.py b/dimos/agents_deprecated/memory/spatial_vector_db.py index b5e356dcc5..376fb00de9 100644 --- a/dimos/agents_deprecated/memory/spatial_vector_db.py +++ b/dimos/agents_deprecated/memory/spatial_vector_db.py @@ -227,8 +227,8 @@ def _process_query_results(self, results) -> list[dict]: # type: ignore[no-unty ) # Get the image from visual memory - #image = self.visual_memory.get(lookup_id) - #result["image"] = image + # image = self.visual_memory.get(lookup_id) + # result["image"] = image processed_results.append(result) diff --git a/dimos/agents_deprecated/modules/base_agent.py b/dimos/agents_deprecated/modules/base_agent.py index 5108dc5248..579ca6ee73 100644 --- a/dimos/agents_deprecated/modules/base_agent.py +++ b/dimos/agents_deprecated/modules/base_agent.py @@ -33,6 +33,7 @@ logger = setup_logger() + class BaseAgentConfig(ModuleConfig): model: str = "openai::gpt-4o-mini" system_prompt: str | None = None @@ -46,12 +47,14 @@ class BaseAgentConfig(ModuleConfig): rag_threshold: float = 0.45 process_all_inputs: bool = False + class BaseAgentModule(BaseAgent, Module): # type: ignore[misc] """Agent module that 
inherits from BaseAgent and adds DimOS module interface. This provides a thin wrapper around BaseAgent functionality, exposing it through the DimOS module system with RPC methods and stream I/O. """ + config: BaseAgentConfig # Module I/O - AgentMessage based communication diff --git a/dimos/memory2/vis/plot/elements.py b/dimos/memory2/vis/plot/elements.py index 8b5932da53..7f83de2b94 100644 --- a/dimos/memory2/vis/plot/elements.py +++ b/dimos/memory2/vis/plot/elements.py @@ -17,11 +17,11 @@ from __future__ import annotations from dataclasses import dataclass -from enum import StrEnum +from enum import Enum from typing import Union -class Style(StrEnum): +class Style(str, Enum): """Line style for Series and HLine elements. Values match matplotlib's `linestyle` names so they pass through directly diff --git a/dimos/memory2/vis/plot/plot.py b/dimos/memory2/vis/plot/plot.py index 6235e44bda..082b147125 100644 --- a/dimos/memory2/vis/plot/plot.py +++ b/dimos/memory2/vis/plot/plot.py @@ -16,13 +16,13 @@ from __future__ import annotations -from enum import StrEnum +from enum import Enum from typing import Any from dimos.memory2.vis.plot.elements import HLine, Markers, PlotElement, Series, VLine -class TimeAxis(StrEnum): +class TimeAxis(str, Enum): """How the x-axis is formatted. - ``raw``: unix timestamps as-is (matplotlib's default numeric formatter). diff --git a/dimos/models/base.py b/dimos/models/base.py index 5ded5196e6..c95418b2e9 100644 --- a/dimos/models/base.py +++ b/dimos/models/base.py @@ -27,12 +27,14 @@ # Device string type - 'cuda', 'cpu', 'cuda:0', 'cuda:1', etc. DeviceType = Annotated[str, "Device identifier (e.g., 'cuda', 'cpu', 'cuda:0')"] + class LocalModelConfig(BaseConfig): device: DeviceType = "cuda" if torch.cuda.is_available() else "cpu" dtype: torch.dtype = torch.float32 warmup: bool = False autostart: bool = False + class LocalModel(Resource, Configurable): """Base class for all local GPU/CPU models. 
@@ -121,11 +123,13 @@ def _ensure_cuda_initialized(self) -> None: except Exception: pass + class HuggingFaceModelConfig(LocalModelConfig): model_name: str = "" trust_remote_code: bool = True dtype: torch.dtype = torch.float16 + class HuggingFaceModel(LocalModel): """Base class for HuggingFace transformers-based models. diff --git a/dimos/models/embedding/base.py b/dimos/models/embedding/base.py index 4f5c4d8164..e3cad47a9a 100644 --- a/dimos/models/embedding/base.py +++ b/dimos/models/embedding/base.py @@ -167,5 +167,4 @@ def query( top_values, top_indices = similarities.topk(k=min(top_k, len(candidates))) return [(idx.item(), val.item()) for idx, val in zip(top_indices, top_values, strict=False)] - ... diff --git a/dimos/models/embedding/clip.py b/dimos/models/embedding/clip.py index 09bf15fc06..b0b4c99d76 100644 --- a/dimos/models/embedding/clip.py +++ b/dimos/models/embedding/clip.py @@ -26,10 +26,12 @@ from dimos.models.embedding.base import Embedding, EmbeddingModel, HuggingFaceEmbeddingModelConfig from dimos.msgs.sensor_msgs.Image import Image + class CLIPModelConfig(HuggingFaceEmbeddingModelConfig): model_name: str = "openai/clip-vit-base-patch32" dtype: torch.dtype = torch.float32 + class CLIPModel(EmbeddingModel, HuggingFaceModel): """CLIP embedding model for vision-language re-identification.""" diff --git a/dimos/models/embedding/mobileclip.py b/dimos/models/embedding/mobileclip.py index 7d7ab767fe..ff1f2efc5a 100644 --- a/dimos/models/embedding/mobileclip.py +++ b/dimos/models/embedding/mobileclip.py @@ -18,16 +18,18 @@ import open_clip from PIL import Image as PILImage import torch -import torch.nn.functional as F +import torch.nn.functional as F # noqa: N812 from dimos.models.base import LocalModel from dimos.models.embedding.base import Embedding, EmbeddingModel, EmbeddingModelConfig from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data + class MobileCLIPModelConfig(EmbeddingModelConfig): model_name: str = 
"MobileCLIP2-S4" + class MobileCLIPModel(EmbeddingModel, LocalModel): """MobileCLIP embedding model for vision-language re-identification.""" diff --git a/dimos/models/embedding/test_embedding.py b/dimos/models/embedding/test_embedding.py index 20aac83dbb..ec6e60c627 100644 --- a/dimos/models/embedding/test_embedding.py +++ b/dimos/models/embedding/test_embedding.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import time from typing import Any diff --git a/dimos/models/embedding/treid.py b/dimos/models/embedding/treid.py index 836e502d3e..7f56ab3d3c 100644 --- a/dimos/models/embedding/treid.py +++ b/dimos/models/embedding/treid.py @@ -28,12 +28,14 @@ from dimos.msgs.sensor_msgs.Image import Image from dimos.utils.data import get_data + # osnet models downloaded from https://kaiyangzhou.github.io/deep-person-reid/MODEL_ZOO.html # into dimos/data/models_torchreid/ # feel free to add more class TorchReIDModelConfig(EmbeddingModelConfig): model_name: str = "osnet_x1_0" + class TorchReIDModel(EmbeddingModel, LocalModel): """TorchReID embedding model for person re-identification.""" diff --git a/dimos/models/qwen/bbox.py b/dimos/models/qwen/bbox.py index bddee8308f..e16ffe1f48 100644 --- a/dimos/models/qwen/bbox.py +++ b/dimos/models/qwen/bbox.py @@ -1 +1,15 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + BBox = tuple[float, float, float, float] # (x1, y1, x2, y2) diff --git a/dimos/models/qwen/video_query.py b/dimos/models/qwen/video_query.py index 584b6ebd16..725973102f 100644 --- a/dimos/models/qwen/video_query.py +++ b/dimos/models/qwen/video_query.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Utility functions for one-off video frame queries using Qwen model.""" import json @@ -161,7 +175,8 @@ def query_single_frame( def get_bbox_from_qwen( - video_stream: Observable, object_name: str | None = None # type: ignore[type-arg] + video_stream: Observable, # type: ignore[type-arg] + object_name: str | None = None, ) -> tuple[BBox, float] | None: """Get bounding box coordinates from Qwen for a specific object or any object. 
diff --git a/dimos/models/segmentation/edge_tam.py b/dimos/models/segmentation/edge_tam.py index cf9cb19e43..88ec707b1c 100644 --- a/dimos/models/segmentation/edge_tam.py +++ b/dimos/models/segmentation/edge_tam.py @@ -79,15 +79,14 @@ def __init__( OmegaConf.update(cfg, key, value) if cfg.model._target_ != "sam2.sam2_video_predictor.SAM2VideoPredictor": - logger.warning( - f"Config target is {cfg.model._target_}, forcing SAM2VideoPredictor" - ) + logger.warning(f"Config target is {cfg.model._target_}, forcing SAM2VideoPredictor") cfg.model._target_ = "sam2.sam2_video_predictor.SAM2VideoPredictor" self._predictor = instantiate(cfg.model, _recursive_=True) # Suppress the per-frame "propagate in video" tqdm bar from sam2 import sam2.sam2_video_predictor as _svp + _svp.tqdm = lambda iterable, *a, **kw: iterable ckpt_path = str(get_data("models_edgetam") / "edgetam.pt") diff --git a/dimos/models/vl/base.py b/dimos/models/vl/base.py index dfb046b58f..9daeb62792 100644 --- a/dimos/models/vl/base.py +++ b/dimos/models/vl/base.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from __future__ import annotations from abc import ABC, abstractmethod @@ -18,6 +32,7 @@ logger = logging.getLogger(__name__) + class Captioner(ABC): """Interface for models that can generate image captions.""" @@ -47,9 +62,11 @@ def caption_batch(self, *images: Image) -> list[str]: """ return [self.caption(img) for img in images] + # Type alias for VLM detection format: [label, x1, y1, x2, y2] VlmDetection = tuple[str, float, float, float, float] + def vlm_detection_to_detection2d( vlm_detection: VlmDetection | list[str | float], track_id: int, @@ -103,9 +120,11 @@ def vlm_detection_to_detection2d( image=image, ) + # Type alias for VLM point format: [label, x, y] VlmPoint = tuple[str, float, float] + def vlm_point_to_detection2d_point( vlm_point: VlmPoint | list[str | float], track_id: int, @@ -152,12 +171,14 @@ def vlm_point_to_detection2d_point( track_id=track_id, ) + class VlModelConfig(BaseConfig): """Configuration for VlModel.""" auto_resize: tuple[int, int] | None = None """Optional (width, height) tuple. If set, images are resized to fit.""" + class VlModel(Captioner, Resource, Configurable): """Vision-language model that can answer questions about images. diff --git a/dimos/models/vl/create.py b/dimos/models/vl/create.py index b3c78cd7f1..362a95b000 100644 --- a/dimos/models/vl/create.py +++ b/dimos/models/vl/create.py @@ -1,14 +1,31 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from dimos.models.vl.base import VlModel from dimos.models.vl.types import VlModelName __all__ = ["VlModelName", "create"] + def create(name: VlModelName) -> VlModel: # This uses inline imports to only import what's needed. match name: case "qwen": from dimos.models.vl.qwen import QwenVlModel + return QwenVlModel() case "moondream": from dimos.models.vl.moondream import MoondreamVlModel + return MoondreamVlModel() diff --git a/dimos/models/vl/moondream.py b/dimos/models/vl/moondream.py index 98be03202d..e3cfe744ce 100644 --- a/dimos/models/vl/moondream.py +++ b/dimos/models/vl/moondream.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from functools import cached_property from typing import Any import warnings @@ -17,6 +31,7 @@ # Moondream works well with 512x512 max MOONDREAM_DEFAULT_AUTO_RESIZE = (512, 512) + class MoondreamConfig(HuggingFaceModelConfig, VlModelConfig): """Configuration for MoondreamVlModel.""" @@ -24,6 +39,7 @@ class MoondreamConfig(HuggingFaceModelConfig, VlModelConfig): dtype: torch.dtype = torch.bfloat16 auto_resize: tuple[int, int] | None = MOONDREAM_DEFAULT_AUTO_RESIZE + class MoondreamVlModel(HuggingFaceModel, VlModel): config: MoondreamConfig _model_class = AutoModelForCausalLM diff --git a/dimos/models/vl/moondream_hosted.py b/dimos/models/vl/moondream_hosted.py index 7f675f0990..76e55451a1 100644 --- a/dimos/models/vl/moondream_hosted.py +++ b/dimos/models/vl/moondream_hosted.py @@ -1,5 +1,20 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from functools import cached_property import os +from typing import Any import warnings import moondream as md # type: ignore[import-untyped] @@ -12,9 +27,11 @@ from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.perception.detection.type.detection2d.point import Detection2DPoint + class Config(VlModelConfig): api_key: str | None = None + class MoondreamHostedVlModel(VlModel): config: Config @@ -56,7 +73,9 @@ def caption(self, image: Image | np.ndarray, length: str = "normal") -> str: result = self._client.caption(pil_image, length=length) return result.get("caption", str(result)) # type: ignore[no-any-return] - def query_detections(self, image: Image, query: str, **kwargs) -> ImageDetections2D[Detection2DBBox]: # type: ignore[no-untyped-def] + def query_detections( + self, image: Image, query: str, **kwargs: Any + ) -> ImageDetections2D[Detection2DBBox]: """Detect objects using Moondream's hosted detect method. Args: @@ -146,4 +165,3 @@ def query_points( def stop(self) -> None: pass - diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py index 57dae01160..9b1e7bb1a4 100644 --- a/dimos/models/vl/openai.py +++ b/dimos/models/vl/openai.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from functools import cached_property import os from typing import Any @@ -11,10 +25,12 @@ logger = setup_logger() + class OpenAIVlModelConfig(VlModelConfig): model_name: str = "gpt-4o-mini" api_key: str | None = None + class OpenAIVlModel(VlModel): config: OpenAIVlModelConfig @@ -28,7 +44,13 @@ def _client(self) -> OpenAI: return OpenAI(api_key=api_key) - def query(self, image: Image | np.ndarray, query: str, response_format: dict | None = None, **kwargs) -> str: # type: ignore[no-untyped-def, type-arg] + def query( + self, + image: Image | np.ndarray, + query: str, + response_format: dict[str, Any] | None = None, + **kwargs: Any, + ) -> str: if isinstance(image, np.ndarray): import warnings @@ -69,7 +91,11 @@ def query(self, image: Image | np.ndarray, query: str, response_format: dict | N return response.choices[0].message.content # type: ignore[no-any-return] def query_batch( - self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs: Any + self, + images: list[Image], + query: str, + response_format: dict[str, Any] | None = None, + **kwargs: Any, ) -> list[str]: """Query VLM with multiple images using a single API call.""" if not images: @@ -78,7 +104,9 @@ def query_batch( content: list[dict[str, Any]] = [ { "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}"}, + "image_url": { + "url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}" + }, } for img in images ] @@ -98,4 +126,3 @@ def stop(self) -> None: """Release the OpenAI client.""" if "_client" in self.__dict__: del self.__dict__["_client"] - diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index a55dbdfeba..a59cd31355 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from functools import cached_property import os from typing import Any @@ -8,12 +22,14 @@ from dimos.models.vl.base import VlModel, VlModelConfig from dimos.msgs.sensor_msgs.Image import Image + class QwenVlModelConfig(VlModelConfig): """Configuration for Qwen VL model.""" model_name: str = "qwen2.5-vl-72b-instruct" api_key: str | None = None + class QwenVlModel(VlModel): config: QwenVlModelConfig @@ -66,17 +82,23 @@ def query(self, image: Image | np.ndarray, query: str) -> str: # type: ignore[o return response.choices[0].message.content # type: ignore[return-value] def query_batch( - self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs: Any + self, + images: list[Image], + query: str, + response_format: dict[str, Any] | None = None, + **kwargs: Any, ) -> list[str]: """Query VLM with multiple images using a single API call.""" if not images: return [] content: list[dict[str, Any]] = [ - { - "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}"}, - } + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}" + }, + } for img in images ] content.append({"type": "text", "text": query}) diff --git a/dimos/models/vl/test_base.py b/dimos/models/vl/test_base.py index b0b03e70fa..de7bc2898d 100644 --- a/dimos/models/vl/test_base.py +++ b/dimos/models/vl/test_base.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from unittest.mock import MagicMock from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations diff --git a/dimos/models/vl/test_captioner.py b/dimos/models/vl/test_captioner.py index 734c83290e..7dc73196f0 100644 --- a/dimos/models/vl/test_captioner.py +++ b/dimos/models/vl/test_captioner.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from collections.abc import Generator import time from typing import Protocol, TypeVar diff --git a/dimos/models/vl/test_models.py b/dimos/models/vl/test_models.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/models/vl/test_vlm.py b/dimos/models/vl/test_vlm.py index f8b1c281d5..1a7a3add4d 100644 --- a/dimos/models/vl/test_vlm.py +++ b/dimos/models/vl/test_vlm.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import time from typing import TYPE_CHECKING @@ -35,7 +49,7 @@ @pytest.mark.slow @pytest.mark.skipif_in_ci def test_vlm_bbox_detections(model_class: "type[VlModel]", model_name: str) -> None: - if model_class is MoondreamHostedVlModel and 'MOONDREAM_API_KEY' not in os.environ: + if model_class is MoondreamHostedVlModel and "MOONDREAM_API_KEY" not in os.environ: pytest.skip("Need MOONDREAM_API_KEY to run") image = Image.from_file(get_data("cafe.jpg")).to_rgb() @@ -110,7 +124,7 @@ def test_vlm_bbox_detections(model_class: "type[VlModel]", model_name: str) -> N def test_vlm_point_detections(model_class: "type[VlModel]", model_name: str) -> None: """Test VLM point detection capabilities.""" - if model_class is MoondreamHostedVlModel and 'MOONDREAM_API_KEY' not in os.environ: + if model_class is MoondreamHostedVlModel and "MOONDREAM_API_KEY" not in os.environ: pytest.skip("Need MOONDREAM_API_KEY to run") image = Image.from_file(get_data("cafe.jpg")).to_rgb() diff --git a/dimos/models/vl/types.py b/dimos/models/vl/types.py index ac8b0f024d..d20a61ae37 100644 --- a/dimos/models/vl/types.py +++ b/dimos/models/vl/types.py @@ -1,3 +1,17 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Literal VlModelName = Literal["qwen", "moondream"] diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index 58a328d2fa..e25ee72611 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -482,7 +482,7 @@ def lcm_encode(self, frame_id: str | None = None) -> bytes: channels = 1 if self.data.ndim == 2 else self.data.shape[2] msg.step = self.width * self.dtype.itemsize * channels - view = memoryview(np.ascontiguousarray(self.data)).cast("B") + view = memoryview(np.ascontiguousarray(self.data)).cast("B") # type: ignore[arg-type] msg.data_length = len(view) msg.data = view diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index fb07e02d3c..3f9aee84e4 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -78,7 +78,8 @@ def process_image_frame(self, image: Image) -> ImageDetections2D: imageDetections = self.detector.process_image(image) if not self.config.filter: return imageDetections - return imageDetections.filter(*self.config.filter) + filtered: ImageDetections2D = imageDetections.filter(*self.config.filter) + return filtered @simple_mcache def sharp_image_stream(self) -> Observable[Image]: diff --git a/dimos/perception/detection/type/imageDetections.py b/dimos/perception/detection/type/imageDetections.py index d1fac8669c..1eea0c9c3c 100644 --- a/dimos/perception/detection/type/imageDetections.py +++ b/dimos/perception/detection/type/imageDetections.py @@ -16,7 +16,13 @@ from functools import 
reduce from operator import add -from typing import TYPE_CHECKING, Generic, Self, TypeVar +import sys +from typing import TYPE_CHECKING, Generic, TypeVar + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self from dimos_lcm.vision_msgs import Detection2DArray diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py index 926bc7dfb8..abfd521666 100644 --- a/dimos/protocol/rpc/pubsubrpc.py +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -91,7 +91,7 @@ def __init__( def __getstate__(self) -> dict[str, Any]: state: dict[str, Any] if hasattr(super(), "__getstate__"): - state = super().__getstate__() # type: ignore[assignment] + state = super().__getstate__() # type: ignore[misc, assignment] else: state = self.__dict__.copy() diff --git a/dimos/simulation/engines/mujoco_shm.py b/dimos/simulation/engines/mujoco_shm.py index 15abb5d170..c0623c7915 100644 --- a/dimos/simulation/engines/mujoco_shm.py +++ b/dimos/simulation/engines/mujoco_shm.py @@ -203,7 +203,8 @@ def read_position_command(self, num_joints: int) -> NDArray[np.float64] | None: return None self._last_pos_cmd_seq = seq arr = self._array(self.shm.pos_t, MAX_JOINTS, np.float64) - return arr[:num_joints].copy() + result: NDArray[np.float64] = arr[:num_joints].copy() + return result def read_velocity_command(self, num_joints: int) -> NDArray[np.float64] | None: seq = self._get_seq(SEQ_VELOCITY_CMD) @@ -211,7 +212,8 @@ def read_velocity_command(self, num_joints: int) -> NDArray[np.float64] | None: return None self._last_vel_cmd_seq = seq arr = self._array(self.shm.vel_t, MAX_JOINTS, np.float64) - return arr[:num_joints].copy() + result: NDArray[np.float64] = arr[:num_joints].copy() + return result def read_gripper_command(self) -> float | None: seq = self._get_seq(SEQ_GRIPPER_CMD) From 0de4da76c0b8a9fa84103cce5308a90b330d57c2 Mon Sep 17 00:00:00 2001 From: RD <63036454+ruthwikdasyam@users.noreply.github.com> Date: Mon, 27 Apr 2026 17:45:15 -0700 
Subject: [PATCH 19/30] feat(go2): go2 SDK adapter + nix cyclonedds setup (#1885) Co-authored-by: Mustafa Bhadsorawala <39084056+mustafab0@users.noreply.github.com> --- .github/workflows/ci.yml | 2 +- .github/workflows/macos.yml | 2 +- .gitignore | 1 + dimos/control/coordinator.py | 1 + .../drive_trains/unitree_go2/README.md | 69 ++ .../drive_trains/unitree_go2/adapter.py | 691 ++++++++++++++++++ dimos/robot/all_blueprints.py | 1 + .../basic/unitree_go2_keyboard_teleop.py | 68 ++ docs/usage/transports/dds.md | 33 +- pyproject.toml | 7 + uv.lock | 53 +- 11 files changed, 924 insertions(+), 4 deletions(-) create mode 100644 dimos/hardware/drive_trains/unitree_go2/README.md create mode 100644 dimos/hardware/drive_trains/unitree_go2/adapter.py create mode 100644 dimos/robot/unitree/go2/blueprints/basic/unitree_go2_keyboard_teleop.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bcdb6f28d6..9d7b8caf8a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,7 +36,7 @@ jobs: git clean -ffdx -e .venv - name: Install Python dependencies - run: uv sync --all-extras --no-extra dds --frozen + run: uv sync --all-extras --no-extra dds --no-extra unitree-dds --frozen - name: Remove pydrake stubs run: | diff --git a/.github/workflows/macos.yml b/.github/workflows/macos.yml index 62684238f5..a289090c96 100644 --- a/.github/workflows/macos.yml +++ b/.github/workflows/macos.yml @@ -69,7 +69,7 @@ jobs: - name: Install dependencies run: | - uv sync --all-extras --no-extra dds --no-extra cuda --frozen + uv sync --all-extras --no-extra dds --no-extra unitree-dds --no-extra cuda --frozen - name: Build C++ extensions run: | diff --git a/.gitignore b/.gitignore index 9b2c6a5442..ea68926e96 100644 --- a/.gitignore +++ b/.gitignore @@ -65,6 +65,7 @@ yolo11n.pt *mobileclip* /results +/result **/cpp/result CLAUDE.MD diff --git a/dimos/control/coordinator.py b/dimos/control/coordinator.py index ba682b96ba..555db13522 100644 --- 
a/dimos/control/coordinator.py +++ b/dimos/control/coordinator.py @@ -249,6 +249,7 @@ def _create_twist_base_adapter(self, component: HardwareComponent) -> TwistBaseA dof=len(component.joints), address=component.address, hardware_id=component.hardware_id, + **component.adapter_kwargs, ) def _create_task_from_config(self, cfg: TaskConfig) -> ControlTask: diff --git a/dimos/hardware/drive_trains/unitree_go2/README.md b/dimos/hardware/drive_trains/unitree_go2/README.md new file mode 100644 index 0000000000..dea50cf728 --- /dev/null +++ b/dimos/hardware/drive_trains/unitree_go2/README.md @@ -0,0 +1,69 @@ +# Unitree Go2 drive-train adapter + +[`adapter.py`](adapter.py) — `UnitreeGo2TwistAdapter` (high-level): Twist `(vx, vy, wz)` via SportClient, with optional Rage Mode (`rage_mode=True`, ~2.5 m/s forward envelope). Auto-registered as `"unitree_go2"` and used by blueprints like `unitree-go2-keyboard-teleop`. This is the one you want for teleop, navigation, or anything velocity-commanded. + +--- + +## Running + +Build the CycloneDDS C library via nix (once per machine — creates +`./result` symlink at the repo root, which acts as a GC root): + +```bash +nix build nixpkgs#cyclonedds +``` + +Point your shell / venv at it so `cyclonedds-python` can find the C +library at install and runtime. Easiest: append to `.venv/bin/activate` +so it's set every time you activate the venv: + +```bash +cat >> .venv/bin/activate < float: + return lo if x < lo else hi if x > hi else x + + +@dataclass +class _Session: + """Active connection state for a Go2. + + The session object is created by connect() and set on the adapter under + _session_lock. All mutable state that can be touched by both the DDS + callback thread and the control thread lives here, guarded by `lock`. 
+ """ + + client: SportClient + motion_switcher: MotionSwitcherClient + lock: threading.Lock + state_sub: ChannelSubscriber | None = None + latest_state: SportModeState_ | None = None + enabled: bool = False + locomotion_ready: bool = False + + # Rage Mode joystick publisher (rt/wirelesscontroller_unprocessed path) + rage_active: bool = False + rage_pub: ChannelPublisher | None = None + rage_thread: threading.Thread | None = None + rage_stop: threading.Event | None = None + rage_cmd: tuple[float, float, float] = (0.0, 0.0, 0.0) + + +class UnitreeGo2TwistAdapter: + """TwistBaseAdapter for the Unitree Go2 quadruped via unitree_sdk2py (DDS). + + 3 DOF velocity: [vx, vy, wz]. + - vx: forward/backward linear velocity (m/s) + - vy: lateral (left positive) linear velocity (m/s) + - wz: yaw rate (rad/s) + + Thread model: + - _session_lock guards the self._session reference across threads. + - session.lock guards latest_state and SportClient RPC serialization. + Never take _session_lock while holding session.lock - the DDS callback + already holds session.lock briefly during state updates. + """ + + # AI-controller API ID for the Rage Mode toggle. + _SPORT_API_ID_RAGEMODE: int = 2059 + + # Rage velocity envelope (m/s, m/s, rad/s) from rage_mode_export_cfg.json. 
+ _RAGE_UP_VX: float = 2.5 + _RAGE_UP_VY: float = 1.0 + _RAGE_UP_VYAW: float = 5.0 + + _RAGE_PUBLISH_HZ: float = 100.0 + _RAGE_LY_SIGN: float = 1.0 # vx → ly + _RAGE_LX_SIGN: float = -1.0 # vy → lx + _RAGE_RX_SIGN: float = -1.0 # wz → rx + + def __init__( + self, + dof: int = 3, + speed_level: int = 1, + rage_mode: bool = False, + **_: Any, + ) -> None: + if dof != 3: + raise ValueError(f"Go2 only supports 3 DOF (vx, vy, wz), got {dof}") + + self._session: _Session | None = None + self._session_lock = threading.Lock() + self._speed_level = speed_level + self._rage_mode_default = rage_mode + self._last_guard_warn_ts: float = 0.0 + + def connect(self) -> bool: + """Connect to Go2, verify sport mode, stand up, enter FreeWalk. + + Sequence: + 1. ChannelFactoryInitialize(0) — default domain, default NIC. + 2. MotionSwitcher.Init + poll CheckMode() until a sport mode + is reported (DDS discovery) or _DISCOVERY_TIMEOUT_S elapses. + 3. Subscribe rt/sportmodestate for telemetry. + 4. SportClient.Init. + 5. _initialize_locomotion(): StandUp + FreeWalk + SpeedLevel. + 6. If rage_mode=True, set_rage_mode(True). + + Returns True on success, False on connect/init/locomotion + failure. On failure, logs guidance and the adapter stays in a + clean "not connected" state so a retry can succeed. + """ + with self._session_lock: + if self._session is not None: + logger.warning("[Go2] Already connected — disconnect first") + return False + + # ChannelFactoryInitialize raises if the factory already exists. 
+ try: + ChannelFactoryInitialize(0) + except Exception: + pass + + motion_switcher = MotionSwitcherClient() + motion_switcher.SetTimeout(0.5) + motion_switcher.Init() + + # Poll CheckMode() through DDS discovery + mode = "" + for _ in range(50): + try: + code, data = motion_switcher.CheckMode() + except (OSError, RuntimeError, TimeoutError): + time.sleep(0.1) + continue + if code == 0 and isinstance(data, dict): + mode = (data.get("name") or "").strip() + if mode: + break + time.sleep(0.1) + motion_switcher.SetTimeout(5.0) + if not mode: + logger.error("[Go2] No sport mode active") + return False + logger.info(f"[Go2] Sport mode '{mode}' active") + + client = SportClient() + client.SetTimeout(10.0) + + session = _Session( + client=client, + motion_switcher=motion_switcher, + lock=threading.Lock(), + ) + + def state_callback(msg: SportModeState_) -> None: + with session.lock: + session.latest_state = msg + + state_sub = ChannelSubscriber("rt/sportmodestate", SportModeState_) + state_sub.Init(state_callback, 10) + session.state_sub = state_sub + + with self._session_lock: + self._session = session + + # disconnect() must run on any failure + try: + client.Init() + logger.info("[Go2] Connected") + + if not self._initialize_locomotion(): + logger.error("[Go2] Failed to initialize locomotion mode") + self.disconnect() + return False + + if self._rage_mode_default and not self.set_rage_mode(True): + logger.warning("[Go2] Rage Mode enable failed — continuing with regular locomotion") + except Exception: + self.disconnect() + raise + + return True + + def disconnect(self) -> None: + """Stop motion, stand the robot down, and tear down DDS resources. + + Safe to call multiple times. Explicitly Close()s the state + subscriber to prevent DDS reader leaks across reconnects. 
+ """ + with self._session_lock: + session = self._session + self._session = None + + if session is None: + return + + self._stop_rage_joystick(session) + try: + with session.lock: + session.client.StopMove() + with session.lock: + session.client.StandDown() + except (OSError, RuntimeError, TimeoutError) as e: + logger.error(f"[Go2] Error during disconnect: {e}") + + if session.state_sub is not None: + try: + session.state_sub.Close() + except (OSError, RuntimeError) as e: + logger.error(f"[Go2] Error closing state subscriber: {e}") + + def is_connected(self) -> bool: + with self._session_lock: + return self._session is not None + + def get_dof(self) -> int: + """Always 3 for Go2 (vx, vy, wz).""" + return 3 + + def read_velocities(self) -> list[float]: + """Measured velocities [vx, vy, wz] from SportModeState_. + + Sources: + vx, vy: state.velocity[0], state.velocity[1] + wz: state.imu_state.gyroscope[2] + + Returns [0.0, 0.0, 0.0] during the startup gap before the first + DDS callback has populated latest_state. + """ + session = self._get_session() + with session.lock: + if session.latest_state is None: + return [0.0, 0.0, 0.0] + state = session.latest_state + return [ + float(state.velocity[0]), + float(state.velocity[1]), + float(state.imu_state.gyroscope[2]), + ] + + def read_odometry(self) -> list[float] | None: + """Measured pose [x, y, theta] from SportModeState_. + + Sources: + x, y: state.position[0], state.position[1] + theta: state.imu_state.rpy[2] (yaw) + + Returns None if no state message has arrived yet. + """ + session = self._get_session() + with session.lock: + if session.latest_state is None: + return None + state = session.latest_state + return [ + float(state.position[0]), + float(state.position[1]), + float(state.imu_state.rpy[2]), + ] + + def write_velocities(self, velocities: list[float]) -> bool: + """Send a Twist command [vx, vy, wz] to the Go2. 
+ + When Rage Mode is active, the command is stashed in + session.rage_cmd and the 100 Hz joystick publisher thread picks + it up on its next tick (Rage's FSM ignores SportClient.Move). + Otherwise the command is forwarded directly via + SportClient.Move() → FsmFreeWalk. + + Refuses (returns False) if: + - len(velocities) != 3 + - session not enabled (write_enable(True) not called) + - locomotion not ready (StandUp/FreeWalk incomplete) + + Guard warnings are rate-limited to 1 Hz since this is called + at 100 Hz from the tick loop. + """ + if len(velocities) != 3: + return False + + session = self._get_session() + + if not session.enabled: + self._warn_guard("Not enabled, ignoring velocity command") + return False + + if not session.locomotion_ready: + self._warn_guard("Locomotion not ready, ignoring velocity command") + return False + + vx, vy, wz = velocities + + if session.rage_active: + session.rage_cmd = (vx, vy, wz) + return True + + return self._send_velocity(vx, vy, wz) + + def _warn_guard(self, msg: str) -> None: + """Rate-limited guard warning (at most once per second). + + write_velocities runs at 100 Hz from the tick loop; without + throttling, a sustained guard miss would emit 100 warnings/s. + """ + now = time.monotonic() + if now - self._last_guard_warn_ts < 1.0: + return + self._last_guard_warn_ts = now + logger.warning(f"[Go2] {msg}") + + def write_stop(self) -> bool: + """Stop motion via SportClient.StopMove(). Leaves robot standing.""" + session = self._get_session() + with session.lock: + session.client.StopMove() + return True + + def write_enable(self, enable: bool) -> bool: + """Enable/disable velocity command path. + + enable=True: ensures locomotion is ready (re-initializes if needed), + then flips session.enabled. + enable=False: calls write_stop() and clears session.enabled. Does + NOT stand the robot down — use disconnect() for that. 
+ """ + session = self._get_session() + + if enable: + if not session.locomotion_ready: + if not self._initialize_locomotion(): + logger.error("[Go2] Failed to initialize locomotion") + return False + session.enabled = True + logger.info("[Go2] Enabled") + return True + + self.write_stop() + session.enabled = False + logger.info("[Go2] Disabled") + return True + + def read_enabled(self) -> bool: + with self._session_lock: + return self._session is not None and self._session.enabled + + def check_mode(self) -> str | None: + """Return the current MotionSwitcher mode name, or None on RPC fail. + + Wraps MotionSwitcher.CheckMode(). Empty string means no controller + active; None means the RPC returned a non-zero code or non-dict data. + """ + session = self._get_session() + code, data = session.motion_switcher.CheckMode() + if code == 0 and isinstance(data, dict): + return (data.get("name") or "").strip() + return None + + def get_sport_state(self) -> SportModeState_ | None: + """Return the latest SportModeState_ snapshot for diagnostics. + + Returned object is the live SDK message — do not mutate it. None + if no state message has arrived. 
+ """ + session = self._get_session() + with session.lock: + return session.latest_state + + def get_status(self) -> dict[str, Any]: + """One-shot snapshot of adapter + robot state""" + with self._session_lock: + session = self._session + + if session is None: + return { + "connected": False, + "mode": None, + "enabled": False, + "locomotion_ready": False, + "rage_active": False, + "speed_level": self._speed_level, + "has_state": False, + "velocity": None, + "position": None, + "body_height": None, + "sport_mode_num": None, + } + + mode = self.check_mode() + + with session.lock: + state = session.latest_state + enabled = session.enabled + locomotion_ready = session.locomotion_ready + rage_active = session.rage_active + + velocity: list[float] | None = None + position: list[float] | None = None + body_height: float | None = None + sport_mode_num: int | None = None + + if state is not None: + try: + velocity = [ + float(state.velocity[0]), + float(state.velocity[1]), + float(state.imu_state.gyroscope[2]), + ] + position = [ + float(state.position[0]), + float(state.position[1]), + float(state.imu_state.rpy[2]), + ] + body_height = float(state.body_height) + sport_mode_num = int(state.mode) + except (AttributeError, IndexError, TypeError, ValueError): + pass + + return { + "connected": True, + "mode": mode, + "enabled": enabled, + "locomotion_ready": locomotion_ready, + "rage_active": rage_active, + "speed_level": self._speed_level, + "has_state": state is not None, + "velocity": velocity, + "position": position, + "body_height": body_height, + "sport_mode_num": sport_mode_num, + } + + def set_speed_level(self, level: int) -> bool: + """Set the SportClient speed envelope at runtime. + + Go2 SDK convention: -1 = slow, 0 = normal, 1 = fast (max). When + Rage is active, the Rage envelope (_RAGE_UP_VX etc.) applies + instead. Updates self._speed_level so subsequent + _initialize_locomotion() calls apply the same level. + + Returns True if the RPC returned 0. 
+ """ + session = self._get_session() + with session.lock: + ret = session.client.SpeedLevel(level) + + if ret != 0: + logger.warning(f"[Go2] SpeedLevel({level}) returned {ret}") + return False + + self._speed_level = level + logger.info(f"[Go2] SpeedLevel set to {level}") + return True + + def set_rage_mode(self, enable: bool) -> bool: + """Toggle Rage Mode (api_id 2059) — widens forward envelope to ~2.5 m/s. + + Velocity input flows via rt/wirelesscontroller_unprocessed, not + SportClient.Move (FsmRageMode isn't in AiController::Move's dispatch). + Idempotent. Returns True on 2059 success; publisher/SwitchJoystick + failures are logged but don't fail the call. + """ + session = self._get_session() + + if session.rage_active == enable: + return True + + with session.lock: + ret = session.client.BalanceStand() + if ret != 0: + # Non-zero is usually benign here (already balanced / FSM transition + # in progress) — only fatal if the rage toggle below also fails. + logger.info(f"[Go2] BalanceStand returned {ret} (likely already balanced — proceeding)") + time.sleep(0.3) + + if not self._call_sport_api(self._SPORT_API_ID_RAGEMODE, {"data": enable}): + return False + + if enable: + time.sleep(2.0) # let FsmRageMode transition settle + self._start_rage_joystick(session) + with session.lock: + sj_ret = session.client.SwitchJoystick(True) + if sj_ret != 0: + logger.warning(f"[Go2] SwitchJoystick(True) after rage returned {sj_ret}") + else: + self._stop_rage_joystick(session) + with session.lock: + sj_ret = session.client.SwitchJoystick(False) + if sj_ret != 0: + logger.warning(f"[Go2] SwitchJoystick(False) after rage returned {sj_ret}") + + logger.info(f"[Go2] Rage Mode {'enabled' if enable else 'disabled'}") + return True + + def _start_rage_joystick(self, session: _Session) -> None: + """Create the WirelessController publisher and spawn the 100Hz thread.""" + if session.rage_pub is not None: + return + pub = ChannelPublisher("rt/wirelesscontroller_unprocessed", 
WirelessController_) + pub.Init() + session.rage_pub = pub + + session.rage_stop = threading.Event() + session.rage_cmd = (0.0, 0.0, 0.0) + session.rage_active = True + session.rage_thread = threading.Thread( + target=self._rage_joystick_loop, + args=(session,), + name="go2-rage-joystick", + daemon=True, + ) + session.rage_thread.start() + + def _stop_rage_joystick(self, session: _Session) -> None: + """Stop the publisher thread and release the DDS writer. + + Closes ChannelPublisher explicitly to avoid leaking the DDS writer + across repeated set_rage_mode(True/False) cycles. + """ + session.rage_active = False + if session.rage_stop is not None: + session.rage_stop.set() + if session.rage_thread is not None: + session.rage_thread.join(timeout=1.0) + session.rage_thread = None + session.rage_stop = None + if session.rage_pub is not None: + try: + session.rage_pub.Close() + except (OSError, RuntimeError) as e: + logger.warning(f"[Go2] Rage publisher Close raised: {e}") + session.rage_pub = None + + def _rage_joystick_loop(self, session: _Session) -> None: + """Publish the latest rage_cmd as a WirelessController_ message. + + Runs at _RAGE_PUBLISH_HZ. On each tick, reads session.rage_cmd, + normalizes to stick axes via the envelope constants, and writes + a WirelessController_ message. Exits when rage_stop is set or + the session's publisher is torn down. 
+ """ + period = 1.0 / self._RAGE_PUBLISH_HZ + msg = unitree_go_msg_dds__WirelessController_() + msg.keys = 0 + msg.ry = 0.0 + + while session.rage_stop is not None and not session.rage_stop.wait(period): + pub = session.rage_pub + if pub is None: + return + vx, vy, wz = session.rage_cmd + + ly = _clip(vx / self._RAGE_UP_VX, -1.0, 1.0) * self._RAGE_LY_SIGN + lx = _clip(vy / self._RAGE_UP_VY, -1.0, 1.0) * self._RAGE_LX_SIGN + rx = _clip(wz / self._RAGE_UP_VYAW, -1.0, 1.0) * self._RAGE_RX_SIGN + + msg.lx = float(lx) + msg.ly = float(ly) + msg.rx = float(rx) + + try: + pub.Write(msg) + except (OSError, RuntimeError) as e: + logger.warning(f"[Go2] Rage joystick publish raised: {e}") + return + + def _call_sport_api(self, api_id: int, payload: dict[str, Any] | None = None) -> bool: + """Generic escape hatch for undocumented mcf sport API IDs. + + SportClient's internal dispatcher rejects unregistered api_ids + with code 3103 (RPC_ERR_CLIENT_API_NOT_REG) before any message + leaves the process — the public SDK only registers its named + methods in __init__. We call _RegistApi() first (idempotent dict + set) so undocumented IDs like RAGEMODE reach the robot. + + Uses leading-underscore SDK methods (_RegistApi, _Call) — these + are not part of the public SDK contract. Verified working against + unitree-sdk2py-dimos>=1.0.2; retest if the SDK is upgraded. + + Returns True on RPC code 0. On failure, logs code + response. + """ + session = self._get_session() + body = json.dumps(payload or {}) + with session.lock: + session.client._RegistApi(api_id, 0) + code, data = session.client._Call(api_id, body) + + if code != 0: + logger.warning(f"[Go2] _Call({api_id}, {body}) -> code={code} data={data!r}") + return False + return True + + def _get_session(self) -> _Session: + """Return active session or raise RuntimeError if disconnected. 
+ + Note: callers using the returned session.lock must NEVER then + try to acquire self._session_lock — see the lock-ordering rule + in the class docstring. + """ + session = self._session + if session is None: + raise RuntimeError("Go2 not connected") + return session + + def _initialize_locomotion(self) -> bool: + """StandUp → 3s settle → FreeWalk → 2s settle → SpeedLevel. + + Called from connect() and from write_enable(True) if locomotion + was not yet ready. Assumes a sport mode is already active. + """ + session = self._get_session() + + if not self.check_mode(): + logger.error("[Go2] No sport mode active") + return False + + logger.info("[Go2] Standing up...") + with session.lock: + ret = session.client.StandUp() + if ret != 0: + logger.error(f"[Go2] StandUp failed with code {ret}") + return False + time.sleep(3) + + logger.info("[Go2] Activating FreeWalk...") + with session.lock: + ret = session.client.FreeWalk() + if ret != 0: + logger.error(f"[Go2] FreeWalk failed with code {ret}") + return False + time.sleep(2) + + with session.lock: + sl_ret = session.client.SpeedLevel(self._speed_level) + if sl_ret == 0: + logger.info(f"[Go2] SpeedLevel({self._speed_level}) applied") + else: + logger.warning(f"[Go2] SpeedLevel({self._speed_level}) returned {sl_ret}") + + session.locomotion_ready = True + logger.info("[Go2] Locomotion ready") + return True + + def _send_velocity(self, vx: float, vy: float, wz: float) -> bool: + session = self._get_session() + with session.lock: + ret = session.client.Move(vx, vy, wz) + if ret != 0: + logger.warning(f"[Go2] Move() returned code {ret}") + return False + return True + + +def register(registry: TwistBaseAdapterRegistry) -> None: + registry.register("unitree_go2", UnitreeGo2TwistAdapter) + + +__all__ = ["UnitreeGo2TwistAdapter"] diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 8e17e74e71..11b2ceb731 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -91,6 +91,7 @@ 
"unitree-go2-coordinator": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_coordinator:unitree_go2_coordinator", "unitree-go2-detection": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_detection:unitree_go2_detection", "unitree-go2-fleet": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_fleet:unitree_go2_fleet", + "unitree-go2-keyboard-teleop": "dimos.robot.unitree.go2.blueprints.basic.unitree_go2_keyboard_teleop:unitree_go2_keyboard_teleop", "unitree-go2-memory": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2:unitree_go2_memory", "unitree-go2-ros": "dimos.robot.unitree.go2.blueprints.smart.unitree_go2_ros:unitree_go2_ros", "unitree-go2-security": "dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_security:unitree_go2_security", diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_keyboard_teleop.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_keyboard_teleop.py new file mode 100644 index 0000000000..bd4216026b --- /dev/null +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_keyboard_teleop.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unitree Go2 keyboard teleop via ControlCoordinator (DDS/SDK2 path). + +WASD keys -> Twist -> coordinator twist_command -> UnitreeGo2TwistAdapter (DDS). 
+ +Usage: + dimos run unitree-go2-keyboard-teleop +""" + +from __future__ import annotations + +from dimos.control.components import HardwareComponent, HardwareType, make_twist_base_joints +from dimos.control.coordinator import ControlCoordinator, TaskConfig +from dimos.core.coordination.blueprints import autoconnect +from dimos.core.transport import LCMTransport +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.sensor_msgs.JointState import JointState +from dimos.robot.unitree.keyboard_teleop import KeyboardTeleop + +_go2_joints = make_twist_base_joints("go2") + +unitree_go2_keyboard_teleop = ( + autoconnect( + ControlCoordinator.blueprint( + hardware=[ + HardwareComponent( + hardware_id="go2", + hardware_type=HardwareType.BASE, + joints=_go2_joints, + adapter_type="unitree_go2", + adapter_kwargs={"rage_mode": False}, + ), + ], + tasks=[ + TaskConfig( + name="vel_go2", + type="velocity", + joint_names=_go2_joints, + priority=10, + ), + ], + ), + KeyboardTeleop.blueprint(), + ) + .transports( + { + ("twist_command", Twist): LCMTransport("/cmd_vel", Twist), + ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), + } + ) + .global_config(obstacle_avoidance=True) +) + +__all__ = ["unitree_go2_keyboard_teleop"] diff --git a/docs/usage/transports/dds.md b/docs/usage/transports/dds.md index 1aec0bafe5..924b9d43e8 100644 --- a/docs/usage/transports/dds.md +++ b/docs/usage/transports/dds.md @@ -1,6 +1,37 @@ # Installing DDS Transport Libs on Ubuntu -The `dds` extra provides DDS (Data Distribution Service) transport support via [Eclipse Cyclone DDS](https://cyclonedds.io/docs/cyclonedds-python/latest/). This requires installing system libraries before the Python package can be built. +The `dds` extra provides DDS (Data Distribution Service) transport support via [Eclipse Cyclone DDS](https://cyclonedds.io/docs/cyclonedds-python/latest/). 
The Python package builds C extensions against the CycloneDDS C library, so the C library must be installed before the Python package. + +## Recommended: nix-provided cyclonedds + +No `sudo`, no system pollution. Requires [Nix](/docs/installation/nix.md). + +```bash +nix build nixpkgs#cyclonedds # creates ./result symlink (GC root) +export CYCLONEDDS_HOME=$PWD/result +export LD_LIBRARY_PATH="$CYCLONEDDS_HOME/lib:$LD_LIBRARY_PATH" +uv pip install -e '.[dds]' +``` + +`LD_LIBRARY_PATH` must stay set at runtime. Persist with one of: + +```bash +# Per-venv (auto-set on `source .venv/bin/activate`) +cat >> .venv/bin/activate <> ~/.bashrc <=2.0.7" ] +unitree-dds = [ + "dimos[unitree]", + "unitree-sdk2py-dimos>=1.0.2", + "cyclonedds>=0.10.5", +] + manipulation = [ # Planning (Drake) "drake==1.45.0; sys_platform == 'darwin' and platform_machine != 'aarch64'", @@ -423,6 +429,7 @@ module = [ "torchreid", "turbojpeg", "ultralytics.*", + "unitree_sdk2py.*", "unitree_webrtc_connect.*", "xarm.*", "ament_index_python.*", diff --git a/uv.lock b/uv.lock index c77994a0e4..dfbc569f8a 100644 --- a/uv.lock +++ b/uv.lock @@ -1929,6 +1929,39 @@ unitree = [ { name = "unitree-webrtc-connect-leshy" }, { name = "uvicorn" }, ] +unitree-dds = [ + { name = "anthropic" }, + { name = "bitsandbytes", marker = "sys_platform == 'linux'" }, + { name = "cyclonedds" }, + { name = "dimos-viewer" }, + { name = "fastapi" }, + { name = "ffmpeg-python" }, + { name = "filterpy" }, + { name = "hydra-core" }, + { name = "langchain" }, + { name = "langchain-chroma" }, + { name = "langchain-core" }, + { name = "langchain-huggingface" }, + { name = "langchain-ollama" }, + { name = "langchain-openai" }, + { name = "langchain-text-splitters" }, + { name = "lap" }, + { name = "moondream" }, + { name = "ollama" }, + { name = "omegaconf" }, + { name = "openai" }, + { name = "openai-whisper" }, + { name = "pillow" }, + { name = "rerun-sdk" }, + { name = "sounddevice" }, + { name = "soundfile" }, + { name = 
"sse-starlette" }, + { name = "transformers", extra = ["torch"] }, + { name = "ultralytics" }, + { name = "unitree-sdk2py-dimos" }, + { name = "unitree-webrtc-connect-leshy" }, + { name = "uvicorn" }, +] visualization = [ { name = "dimos-viewer" }, { name = "rerun-sdk" }, @@ -1954,8 +1987,10 @@ requires-dist = [ { name = "ctransformers", extras = ["cuda"], marker = "extra == 'cuda'", specifier = "==0.2.27" }, { name = "cupy-cuda12x", marker = "platform_machine == 'x86_64' and extra == 'cuda'", specifier = "==13.6.0" }, { name = "cyclonedds", marker = "extra == 'dds'", specifier = ">=0.10.5" }, + { name = "cyclonedds", marker = "extra == 'unitree-dds'", specifier = ">=0.10.5" }, { name = "dimos", extras = ["agents", "web", "perception", "visualization"], marker = "extra == 'base'" }, { name = "dimos", extras = ["base"], marker = "extra == 'unitree'" }, + { name = "dimos", extras = ["unitree"], marker = "extra == 'unitree-dds'" }, { name = "dimos-lcm" }, { name = "dimos-lcm", marker = "extra == 'docker'" }, { name = "dimos-viewer", specifier = "==0.30.0a6.dev99" }, @@ -2100,6 +2135,7 @@ requires-dist = [ { name = "types-tqdm", marker = "extra == 'dev'", specifier = ">=4.67.0.20250809,<5" }, { name = "typing-extensions", marker = "python_full_version < '3.11'", specifier = ">=4.0" }, { name = "ultralytics", marker = "extra == 'perception'", specifier = ">=8.3.70" }, + { name = "unitree-sdk2py-dimos", marker = "extra == 'unitree-dds'", specifier = ">=1.0.2" }, { name = "unitree-webrtc-connect-leshy", marker = "extra == 'unitree'", specifier = ">=2.0.7" }, { name = "uvicorn", marker = "extra == 'web'", specifier = ">=0.34.0" }, { name = "watchdog", marker = "extra == 'dev'", specifier = ">=3.0.0" }, @@ -2109,7 +2145,7 @@ requires-dist = [ { name = "xformers", marker = "platform_machine == 'x86_64' and extra == 'cuda'", specifier = ">=0.0.20" }, { name = "yapf", marker = "extra == 'misc'", specifier = "==0.40.2" }, ] -provides-extras = ["misc", "visualization", "agents", 
"web", "perception", "unitree", "manipulation", "cpu", "cuda", "dev", "psql", "sim", "drone", "dds", "docker", "base"] +provides-extras = ["misc", "visualization", "agents", "web", "perception", "unitree", "unitree-dds", "manipulation", "cpu", "cuda", "dev", "psql", "sim", "drone", "dds", "docker", "base"] [[package]] name = "dimos-lcm" @@ -10292,6 +10328,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7f/c7/fb42228bb05473d248c110218ffb8b1ad2f76728ed8699856e5af21112ad/ultralytics_thop-2.0.18-py3-none-any.whl", hash = "sha256:2bb44851ad224b116c3995b02dd5e474a5ccf00acf237fe0edb9e1506ede04ec", size = 28941, upload-time = "2025-10-29T16:58:12.093Z" }, ] +[[package]] +name = "unitree-sdk2py-dimos" +version = "1.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cyclonedds" }, + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "opencv-python" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/93/d2/8c927709a432e6003a7ffdb434c2a3570c1b4ed97c9a0b7b85313e32f6bb/unitree_sdk2py_dimos-1.0.3.tar.gz", hash = "sha256:d0076b9501849a8f144dd076ffb3894c5c804c87cdad7521095c2bc893049438", size = 48758, upload-time = "2026-03-03T21:19:32.8Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/69/76b879edbf5eab1cb200bf818b87a5943effe441722429ab940ea38a6887/unitree_sdk2py_dimos-1.0.3-py3-none-any.whl", hash = "sha256:8057cad5de5877757bc586b2ba5ddbe84522ce4c8c3d624464ef975cafa5daec", size = 110245, upload-time = "2026-03-03T21:19:31.692Z" }, +] + [[package]] name = "unitree-webrtc-connect-leshy" version = "2.0.7" From 6599a97909340fcff912712c8a4801038b4e7f5b Mon Sep 17 00:00:00 2001 From: leshy Date: Tue, 28 Apr 2026 20:27:08 +0300 Subject: [PATCH 20/30] Revert "Jeff/fix/rconnect2" 
(#1924) --- .gitignore | 3 - dimos/core/coordination/python_worker.py | 16 +- dimos/core/docker_module.py | 2 +- dimos/core/global_config.py | 11 +- dimos/hardware/sensors/camera/module.py | 5 +- .../lidar/fastlio2/fastlio_blueprints.py | 35 +-- .../sensors/lidar/livox/livox_blueprints.py | 4 +- dimos/manipulation/blueprints.py | 10 +- dimos/manipulation/grasping/demo_grasping.py | 4 +- .../wavefront_frontier_goal_selector.py | 11 - dimos/navigation/replanning_a_star/module.py | 18 +- .../movement_manager/movement_manager.py | 133 --------- .../movement_manager/test_movement_manager.py | 117 -------- .../demo_object_scene_registration.py | 4 +- dimos/robot/all_blueprints.py | 2 - dimos/robot/cli/dimos.py | 48 +--- .../drone/blueprints/basic/drone_basic.py | 17 +- .../blueprints/perceptive/unitree_g1_shm.py | 10 +- .../primitive/uintree_g1_primitive_no_nav.py | 19 +- .../agentic/unitree_go2_security.py | 4 +- .../go2/blueprints/basic/unitree_go2_basic.py | 34 ++- .../go2/blueprints/basic/unitree_go2_fleet.py | 6 +- .../unitree_go2_webrtc_keyboard_teleop.py | 4 - .../go2/blueprints/smart/unitree_go2.py | 6 +- dimos/robot/unitree/keyboard_teleop.py | 10 +- dimos/robot/unitree/mujoco_connection.py | 16 +- dimos/simulation/unity/blueprint.py | 4 +- dimos/teleop/quest/blueprints.py | 4 +- dimos/test_no_sections.py | 2 - dimos/utils/generic.py | 17 -- dimos/visualization/rerun/bridge.py | 253 +++++++++--------- dimos/visualization/rerun/conftest.py | 45 ---- dimos/visualization/rerun/constants.py | 31 --- .../visualization/rerun/test_viewer_ws_e2e.py | 201 -------------- .../rerun/test_websocket_server.py | 210 --------------- dimos/visualization/rerun/websocket_server.py | 244 ----------------- dimos/visualization/vis_module.py | 87 ------ .../web/websocket_vis/websocket_vis_module.py | 24 +- docs/development/conventions.md | 12 - docs/usage/cli.md | 4 +- docs/usage/visualization.md | 42 ++- pyproject.toml | 2 +- uv.lock | 26 +- 43 files changed, 292 insertions(+), 1465 
deletions(-) delete mode 100644 dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py delete mode 100644 dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py delete mode 100644 dimos/visualization/rerun/conftest.py delete mode 100644 dimos/visualization/rerun/constants.py delete mode 100644 dimos/visualization/rerun/test_viewer_ws_e2e.py delete mode 100644 dimos/visualization/rerun/test_websocket_server.py delete mode 100644 dimos/visualization/rerun/websocket_server.py delete mode 100644 dimos/visualization/vis_module.py delete mode 100644 docs/development/conventions.md diff --git a/.gitignore b/.gitignore index ea68926e96..1816510c08 100644 --- a/.gitignore +++ b/.gitignore @@ -74,9 +74,6 @@ CLAUDE.MD /.mcp.json *.speedscope.json -# Hidden/personal directories -.hidden/ - # Coverage htmlcov/ .coverage diff --git a/dimos/core/coordination/python_worker.py b/dimos/core/coordination/python_worker.py index 6c3aab3a2d..3c434a982e 100644 --- a/dimos/core/coordination/python_worker.py +++ b/dimos/core/coordination/python_worker.py @@ -18,7 +18,6 @@ import multiprocessing from multiprocessing.connection import Connection import os -import signal import sys import threading import traceback @@ -338,15 +337,12 @@ class _WorkerState: def _worker_entrypoint(conn: Connection, worker_id: int) -> None: apply_library_config() - # Ignore SIGINT so the coordinator can orchestrate shutdown via the pipe. - # Without this, workers race with the coordinator: they start tearing down - # modules locally while the coordinator tries to send stop() RPCs, causing - # BrokenPipeErrors. 
- signal.signal(signal.SIGINT, signal.SIG_IGN) state = _WorkerState(instances={}, worker_id=worker_id) try: _worker_loop(conn, state) + except KeyboardInterrupt: + logger.info("Worker got KeyboardInterrupt.", worker_id=worker_id) except Exception as e: logger.error(f"Worker process error: {e}", exc_info=True) finally: @@ -365,6 +361,12 @@ def _worker_entrypoint(conn: Connection, worker_id: int) -> None: worker_id=worker_id, module_id=module_id, ) + except KeyboardInterrupt: + logger.warning( + "KeyboardInterrupt during worker stop", + module=type(instance).__name__, + worker_id=worker_id, + ) except Exception: logger.error("Error during worker shutdown", exc_info=True) @@ -431,7 +433,7 @@ def _worker_loop(conn: Connection, state: _WorkerState) -> None: if not conn.poll(timeout=0.1): continue request = conn.recv() - except EOFError: + except (EOFError, KeyboardInterrupt): break try: diff --git a/dimos/core/docker_module.py b/dimos/core/docker_module.py index f82a1b56db..3ad9620556 100644 --- a/dimos/core/docker_module.py +++ b/dimos/core/docker_module.py @@ -30,7 +30,7 @@ from dimos.core.rpc_client import ModuleProxyProtocol, RpcCall from dimos.protocol.rpc.pubsubrpc import LCMRPC from dimos.utils.logging_config import setup_logger -from dimos.visualization.rerun.constants import RERUN_GRPC_PORT, RERUN_WEB_PORT +from dimos.visualization.rerun.bridge import RERUN_GRPC_PORT, RERUN_WEB_PORT if TYPE_CHECKING: from collections.abc import Callable diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py index 435f421dd1..214401959e 100644 --- a/dimos/core/global_config.py +++ b/dimos/core/global_config.py @@ -13,16 +13,13 @@ # limitations under the License. 
import re +from typing import Literal, TypeAlias from pydantic_settings import BaseSettings, SettingsConfigDict from dimos.models.vl.types import VlModelName -from dimos.visualization.rerun.constants import ( - RERUN_ENABLE_WEB, - RERUN_OPEN_DEFAULT, - RerunOpenOption, - ViewerBackend, -) + +ViewerBackend: TypeAlias = Literal["rerun", "rerun-web", "rerun-connect", "foxglove", "none"] def _get_all_numbers(s: str) -> list[float]: @@ -40,8 +37,6 @@ class GlobalConfig(BaseSettings): replay_db: str = "go2_bigoffice" new_memory: bool = False viewer: ViewerBackend = "rerun" - rerun_open: RerunOpenOption = RERUN_OPEN_DEFAULT - rerun_web: bool = RERUN_ENABLE_WEB n_workers: int = 2 memory_limit: str = "auto" mujoco_camera_position: str | None = None diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index 0fe0d8f030..9b4f50920c 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -21,7 +21,6 @@ from dimos.agents.annotation import skill from dimos.core.coordination.blueprints import autoconnect from dimos.core.core import rpc -from dimos.core.global_config import global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.hardware.sensors.camera.spec import CameraHardware @@ -32,7 +31,7 @@ from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.spec import perception -from dimos.visualization.vis_module import vis_module +from dimos.visualization.rerun.bridge import RerunBridgeModule def default_transform() -> Transform: @@ -121,5 +120,5 @@ def stop(self) -> None: demo_camera = autoconnect( CameraModule.blueprint(), - vis_module(viewer_backend=global_config.viewer), + RerunBridgeModule.blueprint(), ) diff --git a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py index 
2c2a64d61e..2946f1d247 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py +++ b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py @@ -15,45 +15,30 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 from dimos.mapping.voxels import VoxelGridMapper -from dimos.visualization.vis_module import vis_module +from dimos.visualization.rerun.bridge import RerunBridgeModule voxel_size = 0.05 mid360_fastlio = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=-1), - vis_module( - "rerun", - rerun_config={ - "visual_override": { - "world/lidar": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - }, - }, - ), + RerunBridgeModule.blueprint(), ).global_config(n_workers=2, robot_model="mid360_fastlio2") mid360_fastlio_voxels = autoconnect( FastLio2.blueprint(), VoxelGridMapper.blueprint(voxel_size=voxel_size, carve_columns=False), - vis_module( - "rerun", - rerun_config={ - "visual_override": { - "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - "world/lidar": None, - }, - }, + RerunBridgeModule.blueprint( + visual_override={ + "world/lidar": None, + } ), ).global_config(n_workers=3, robot_model="mid360_fastlio2_voxels") mid360_fastlio_voxels_native = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=3.0), - vis_module( - "rerun", - rerun_config={ - "visual_override": { - "world/lidar": None, - "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - }, - }, + RerunBridgeModule.blueprint( + visual_override={ + "world/lidar": None, + } ), ).global_config(n_workers=2, robot_model="mid360_fastlio2") diff --git a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py index e437d73994..34ebc33c2a 100644 --- a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py 
+++ b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py @@ -14,9 +14,9 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.hardware.sensors.lidar.livox.module import Mid360 -from dimos.visualization.vis_module import vis_module +from dimos.visualization.rerun.bridge import RerunBridgeModule mid360 = autoconnect( Mid360.blueprint(), - vis_module("rerun"), + RerunBridgeModule.blueprint(), ).global_config(n_workers=2, robot_model="mid360") diff --git a/dimos/manipulation/blueprints.py b/dimos/manipulation/blueprints.py index 1c006c1d04..f950ea8efa 100644 --- a/dimos/manipulation/blueprints.py +++ b/dimos/manipulation/blueprints.py @@ -44,7 +44,7 @@ from dimos.msgs.sensor_msgs.JointState import JointState from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule from dimos.robot.catalog.ufactory import xarm6 as _catalog_xarm6, xarm7 as _catalog_xarm7 -from dimos.visualization.vis_module import vis_module +from dimos.robot.foxglove_bridge import FoxgloveBridge # TODO: migrate to rerun # Single XArm6 planner (standalone, no coordinator) _xarm6_planner_cfg = _catalog_xarm6( @@ -196,14 +196,14 @@ use_aabb=True, max_obstacle_width=0.06, ), - vis_module("foxglove"), + FoxgloveBridge.blueprint(), # TODO: migrate to rerun ) .transports( { ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), } ) - .global_config(n_workers=4) + .global_config(viewer="foxglove", n_workers=4) ) @@ -289,7 +289,7 @@ from dimos.robot.catalog.ufactory import XARM7_SIM_PATH from dimos.simulation.engines.mujoco_sim_module import MujocoSimModule -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode _xarm7_sim_cfg = _catalog_xarm7( name="arm", @@ -323,7 +323,7 @@ hardware=[_xarm7_sim_cfg.to_hardware_component()], tasks=[_xarm7_sim_cfg.to_task_config()], ), - RerunBridgeModule.blueprint(), + 
RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode()), ).transports( { ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), diff --git a/dimos/manipulation/grasping/demo_grasping.py b/dimos/manipulation/grasping/demo_grasping.py index 4a1d4b2cf6..37e1d38f1e 100644 --- a/dimos/manipulation/grasping/demo_grasping.py +++ b/dimos/manipulation/grasping/demo_grasping.py @@ -22,7 +22,7 @@ from dimos.manipulation.grasping.grasping import GraspingModule from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.visualization.vis_module import vis_module +from dimos.robot.foxglove_bridge import FoxgloveBridge camera_module = RealSenseCamera.blueprint(enable_pointcloud=False) @@ -44,7 +44,7 @@ ("/tmp", "/tmp", "rw") ], # Grasp visualization debug standalone: python -m dimos.manipulation.grasping.visualize_grasps ), - vis_module("foxglove"), + FoxgloveBridge.blueprint(), McpServer.blueprint(), McpClient.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index 338d10d9b0..b8dbe0dfc8 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -115,7 +115,6 @@ class WavefrontFrontierExplorer(Module): goal_reached: In[Bool] explore_cmd: In[Bool] stop_explore_cmd: In[Bool] - stop_movement: In[Bool] # LCM outputs goal_request: Out[PoseStamped] @@ -172,10 +171,6 @@ def start(self) -> None: unsub = self.stop_explore_cmd.subscribe(self._on_stop_explore_cmd) self.register_disposable(Disposable(unsub)) - if self.stop_movement.transport is not None: - unsub = self.stop_movement.subscribe(self._on_stop_movement) - self.register_disposable(Disposable(unsub)) - @rpc def stop(self) 
-> None: self.stop_exploration() @@ -206,12 +201,6 @@ def _on_stop_explore_cmd(self, msg: Bool) -> None: logger.info("Received exploration stop command via LCM") self.stop_exploration() - def _on_stop_movement(self, msg: Bool) -> None: - """Handle stop movement from teleop — cancel active exploration.""" - if msg.data and self.exploration_active: - logger.info("WavefrontFrontierExplorer: stop_movement received, stopping exploration") - self.stop_exploration() - def _count_costmap_information(self, costmap: OccupancyGrid) -> int: """ Count the amount of information in a costmap (free space + obstacles). diff --git a/dimos/navigation/replanning_a_star/module.py b/dimos/navigation/replanning_a_star/module.py index efc16b52d6..2375af20ce 100644 --- a/dimos/navigation/replanning_a_star/module.py +++ b/dimos/navigation/replanning_a_star/module.py @@ -28,9 +28,6 @@ from dimos.msgs.nav_msgs.Path import Path from dimos.navigation.base import NavigationInterface, NavigationState from dimos.navigation.replanning_a_star.global_planner import GlobalPlanner -from dimos.utils.logging_config import setup_logger - -logger = setup_logger() class ReplanningAStarPlanner(Module, NavigationInterface): @@ -39,11 +36,10 @@ class ReplanningAStarPlanner(Module, NavigationInterface): goal_request: In[PoseStamped] clicked_point: In[PointStamped] target: In[PoseStamped] - stop_movement: In[Bool] goal_reached: Out[Bool] navigation_state: Out[String] # TODO: set it - nav_cmd_vel: Out[Twist] + cmd_vel: Out[Twist] path: Out[Path] navigation_costmap: Out[OccupancyGrid] @@ -76,14 +72,9 @@ def start(self) -> None: ) ) - if self.stop_movement.transport is not None: - self.register_disposable( - Disposable(self.stop_movement.subscribe(self._on_stop_movement)) - ) - self.register_disposable(self._planner.path.subscribe(self.path.publish)) - self.register_disposable(self._planner.cmd_vel.subscribe(self.nav_cmd_vel.publish)) + self.register_disposable(self._planner.cmd_vel.subscribe(self.cmd_vel.publish)) 
self.register_disposable(self._planner.goal_reached.subscribe(self.goal_reached.publish)) @@ -101,11 +92,6 @@ def stop(self) -> None: super().stop() - def _on_stop_movement(self, msg: Bool) -> None: - if msg.data: - logger.info("ReplanningAStarPlanner: stop_movement received, cancelling goal") - self.cancel_goal() - @rpc def set_goal(self, goal: PoseStamped) -> bool: self._planner.handle_goal_request(goal) diff --git a/dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py b/dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py deleted file mode 100644 index 5a2dd195c0..0000000000 --- a/dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""MovementManager: click-to-goal relay + teleop/nav velocity mux.""" - -from __future__ import annotations - -import math -import threading -import time -from typing import Any - -from dimos_lcm.std_msgs import Bool # type: ignore[import-untyped] -from reactivex.disposable import Disposable - -from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig -from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs.PointStamped import PointStamped -from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.utils.logging_config import setup_logger - -logger = setup_logger() - - -class MovementManagerConfig(ModuleConfig): - tele_cooldown_sec: float = 1.0 - tele_cmd_vel_scaling: Twist = Twist(Vector3(1, 1, 1), Vector3(1, 1, 1)) - - -class MovementManager(Module): - """Combine tele_cmd_vel (keyboard controls) and nav_cmd_vel in a sane way, output cmd_vel""" - - config: MovementManagerConfig - - clicked_point: In[PointStamped] - nav_cmd_vel: In[Twist] - tele_cmd_vel: In[Twist] - - goal: Out[PointStamped] - way_point: Out[PointStamped] - cmd_vel: Out[Twist] - stop_movement: Out[Bool] - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._lock = threading.Lock() - self._teleop_active = False - self._last_teleop_time = 0.0 - - @rpc - def start(self) -> None: - super().start() - self.register_disposable(Disposable(self.clicked_point.subscribe(self._on_click))) - self.register_disposable(Disposable(self.nav_cmd_vel.subscribe(self._on_nav))) - self.register_disposable(Disposable(self.tele_cmd_vel.subscribe(self._on_teleop))) - - @rpc - def stop(self) -> None: - with self._lock: - self._teleop_active = False - super().stop() - - def _on_click(self, msg: PointStamped) -> None: - if not all(math.isfinite(v) for v in (msg.x, msg.y, msg.z)): - logger.warning("Ignored invalid click", x=msg.x, y=msg.y, z=msg.z) - return - if abs(msg.x) > 500 or abs(msg.y) > 500 or 
abs(msg.z) > 50: - logger.warning("Ignored out-of-range click", x=msg.x, y=msg.y, z=msg.z) - return - - logger.debug("Goal", x=round(msg.x, 1), y=round(msg.y, 1), z=round(msg.z, 1)) - self.way_point.publish(msg) - self.goal.publish(msg) - - def _cancel_goal(self) -> None: - self.stop_movement.publish(Bool(data=True)) - # NOTE: this NaN goal is more of a safety fallback. - # It can be REALLY bad if a robot is supposed to stop moving but wont - # we should probably think a more robust/strict requirement on planners - cancel = PointStamped( - ts=time.time(), frame_id="map", x=float("nan"), y=float("nan"), z=float("nan") - ) - self.way_point.publish(cancel) - self.goal.publish(cancel) - logger.debug("Navigation cancelled — waiting for new goal") - - def _on_nav(self, msg: Twist) -> None: - with self._lock: - if self._teleop_active: - # check if cooldown has expired - elapsed = time.monotonic() - self._last_teleop_time - if elapsed < self.config.tele_cooldown_sec: - return - self._teleop_active = False - self.cmd_vel.publish(msg) - - def _on_teleop(self, msg: Twist) -> None: - with self._lock: - was_active = self._teleop_active - self._teleop_active = True - self._last_teleop_time = time.monotonic() - - if not was_active: - self._cancel_goal() - logger.info("Teleop active") - - scale = self.config.tele_cmd_vel_scaling - scaled = Twist( - linear=Vector3( - msg.linear.x * scale.linear.x, - msg.linear.y * scale.linear.y, - msg.linear.z * scale.linear.z, - ), - angular=Vector3( - msg.angular.x * scale.angular.x, - msg.angular.y * scale.angular.y, - msg.angular.z * scale.angular.z, - ), - ) - self.cmd_vel.publish(scaled) diff --git a/dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py b/dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py deleted file mode 100644 index 6858055605..0000000000 --- a/dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2026 
Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for MovementManager: click-to-goal + teleop/nav velocity mux.""" - -from __future__ import annotations - -import math -import time -from unittest.mock import MagicMock - -import pytest - -from dimos.msgs.geometry_msgs.PointStamped import PointStamped -from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.navigation.smart_nav.modules.movement_manager.movement_manager import ( - MovementManager, -) - - -@pytest.fixture() -def manager() -> MovementManager: - """Create a real MovementManager and mock the publish methods on its output streams.""" - module = MovementManager(tele_cooldown_sec=0.1) - module.cmd_vel.publish = MagicMock() - module.stop_movement.publish = MagicMock() - module.goal.publish = MagicMock() - module.way_point.publish = MagicMock() - yield module - module._close_module() - - -def _twist(lx: float = 0.0) -> Twist: - return Twist(linear=Vector3(lx, 0, 0), angular=Vector3(0, 0, 0)) - - -def _click(x: float = 1.0, y: float = 2.0, z: float = 0.0) -> PointStamped: - return PointStamped(ts=time.time(), frame_id="map", x=x, y=y, z=z) - - -def test_teleop_suppresses_nav_and_cancels_goal(manager: MovementManager) -> None: - """Teleop arriving should suppress nav, publish stop_movement, and cancel the goal with NaN.""" - manager.config.tele_cooldown_sec = 10.0 - manager._on_teleop(_twist(lx=0.3)) - - # Nav is suppressed - 
manager.cmd_vel.publish.reset_mock() # type: ignore[union-attr] - manager._on_nav(_twist(lx=0.9)) - manager.cmd_vel.publish.assert_not_called() # type: ignore[union-attr] - - # stop_movement fired - manager.stop_movement.publish.assert_called_once() # type: ignore[union-attr] - - # Goal cancelled with NaN - cancel_msg = manager.goal.publish.call_args[0][0] # type: ignore[union-attr] - assert math.isnan(cancel_msg.x) - - -def test_nav_resumes_after_cooldown(manager: MovementManager) -> None: - """After the cooldown expires, nav commands pass through again.""" - manager.config.tele_cooldown_sec = 0.05 - manager._on_teleop(_twist(lx=0.3)) - time.sleep(0.1) - manager.cmd_vel.publish.reset_mock() # type: ignore[union-attr] - - manager._on_nav(_twist(lx=0.9)) - manager.cmd_vel.publish.assert_called_once() # type: ignore[union-attr] - - -def test_valid_click_publishes_goal(manager: MovementManager) -> None: - """A valid click should publish to both goal and way_point.""" - click = _click(x=5.0, y=3.0, z=0.1) - manager._on_click(click) - manager.goal.publish.assert_called_once_with(click) # type: ignore[union-attr] - manager.way_point.publish.assert_called_once_with(click) # type: ignore[union-attr] - - -def test_invalid_clicks_rejected(manager: MovementManager) -> None: - """NaN, Inf, and out-of-range clicks should not publish.""" - for bad_click in [ - _click(x=float("nan")), - _click(x=float("inf")), - _click(x=600.0), - ]: - manager._on_click(bad_click) - manager.goal.publish.assert_not_called() # type: ignore[union-attr] - - -def test_tele_cmd_vel_scaling() -> None: - """tele_cmd_vel_scaling multiplies each teleop twist component independently.""" - scaling = Twist(Vector3(0.5, 2.0, 0.0), Vector3(1.0, 1.0, 0.25)) - module = MovementManager(tele_cooldown_sec=10.0, tele_cmd_vel_scaling=scaling) - module.cmd_vel.publish = MagicMock() - module.stop_movement.publish = MagicMock() - module.goal.publish = MagicMock() - module.way_point.publish = MagicMock() - - 
module._on_teleop(Twist(Vector3(1, 1, 1), Vector3(1, 1, 1))) - - published = module.cmd_vel.publish.call_args[0][0] # type: ignore[union-attr] - assert published.linear.x == pytest.approx(0.5) - assert published.linear.y == pytest.approx(2.0) - assert published.linear.z == pytest.approx(0.0) - assert published.angular.z == pytest.approx(0.25) - module._close_module() diff --git a/dimos/perception/demo_object_scene_registration.py b/dimos/perception/demo_object_scene_registration.py index 28044dec13..c9b489f54b 100644 --- a/dimos/perception/demo_object_scene_registration.py +++ b/dimos/perception/demo_object_scene_registration.py @@ -20,7 +20,7 @@ from dimos.hardware.sensors.camera.zed.compat import ZEDCamera from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.visualization.vis_module import vis_module +from dimos.robot.foxglove_bridge import FoxgloveBridge camera_choice = "zed" @@ -34,7 +34,7 @@ demo_object_scene_registration = autoconnect( camera_module, ObjectSceneRegistrationModule.blueprint(target_frame="world", prompt_mode=YoloePromptMode.LRPC), - vis_module("foxglove"), + FoxgloveBridge.blueprint(), McpServer.blueprint(), McpClient.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 11b2ceb731..c794a67124 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -153,7 +153,6 @@ "mock-b1-connection-module": "dimos.robot.unitree.b1.connection.MockB1ConnectionModule", "module-a": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleA", "module-b": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleB", - "movement-manager": "dimos.navigation.smart_nav.modules.movement_manager.movement_manager.MovementManager", "mujoco-sim-module": "dimos.simulation.engines.mujoco_sim_module.MujocoSimModule", "navigation-module": 
"dimos.robot.unitree.rosnav.NavigationModule", "navigation-skill-container": "dimos.agents.skills.navigation.NavigationSkillContainer", @@ -176,7 +175,6 @@ "reid-module": "dimos.perception.detection.reid.module.ReidModule", "replanning-a-star-planner": "dimos.navigation.replanning_a_star.module.ReplanningAStarPlanner", "rerun-bridge-module": "dimos.visualization.rerun.bridge.RerunBridgeModule", - "rerun-web-socket-server": "dimos.visualization.rerun.websocket_server.RerunWebSocketServer", "ros-nav": "dimos.navigation.rosnav.ROSNav", "security-module": "dimos.experimental.security_demo.security_module.SecurityModule", "semantic-search": "dimos.memory2.module.SemanticSearch", diff --git a/dimos/robot/cli/dimos.py b/dimos/robot/cli/dimos.py index e99553c2b3..37d1bd2be0 100644 --- a/dimos/robot/cli/dimos.py +++ b/dimos/robot/cli/dimos.py @@ -21,11 +21,10 @@ import json import os from pathlib import Path -import signal import sys import time import types -from typing import TYPE_CHECKING, Any, Union, cast, get_args, get_origin +from typing import TYPE_CHECKING, Any, Union, get_args, get_origin import click from dotenv import load_dotenv @@ -39,10 +38,7 @@ from dimos.core.daemon import daemonize, install_signal_handlers from dimos.core.global_config import GlobalConfig, global_config from dimos.core.run_registry import get_most_recent, is_pid_alive, stop_entry -from dimos.protocol.pubsub.impl.lcmpubsub import LCM -from dimos.protocol.service.lcmservice import autoconf from dimos.utils.logging_config import setup_logger -from dimos.visualization.rerun.constants import RerunOpenOption if TYPE_CHECKING: from dimos.core.coordination.blueprints import Blueprint, BlueprintAtom @@ -226,10 +222,6 @@ def run( cli_config_overrides: dict[str, Any] = ctx.obj - # this is a workaround until we have a proper way to have delayed-module-choice in blueprints - # ex: vis_module(viewer=global_config.viewer) is WRONG (viewer will always be default value) without this patch - 
global_config.update(**cli_config_overrides) - # Clean stale registry entries stale = cleanup_stale() if stale: @@ -668,43 +660,17 @@ def send( @main.command(name="rerun-bridge") def rerun_bridge_cmd( + viewer_mode: str = typer.Option( + "native", help="Viewer mode: native (desktop), web (browser), none (headless)" + ), memory_limit: str = typer.Option( "25%", help="Memory limit for Rerun viewer (e.g., '4GB', '16GB', '25%')" ), - rerun_open: str = typer.Option( - "native", help="How to open Rerun: one of native, web, both, none" - ), - rerun_web: bool = typer.Option( - True, "--rerun-web/--no-rerun-web", help="Enable/Disable Rerun web server" - ), ) -> None: - """Launch the Rerun visualization bridge. - - Standalone utility: runs the bridge directly in the main process (no - blueprint / worker pool) so users can attach a viewer to existing LCM - traffic without building a full module graph. - """ - # Deferred: RerunBridgeModule pulls in the rerun package (~1s), keep it - # out of the CLI's hot path so `dimos --help` stays fast. 
- from dimos.visualization.rerun.bridge import RerunBridgeModule - - valid = get_args(RerunOpenOption) - if rerun_open not in valid: - raise typer.BadParameter( - f"rerun_open must be one of {valid}, got {rerun_open!r}", param_hint="--rerun-open" - ) - autoconf(check_only=True) - - bridge = RerunBridgeModule( - memory_limit=memory_limit, - rerun_open=cast("RerunOpenOption", rerun_open), - rerun_web=rerun_web, - pubsubs=[LCM()], - ) - bridge.start() + """Launch the Rerun visualization bridge.""" + from dimos.visualization.rerun.bridge import run_bridge - signal.signal(signal.SIGINT, lambda *_: bridge.stop()) - signal.pause() + run_bridge(viewer_mode=viewer_mode, memory_limit=memory_limit) if __name__ == "__main__": diff --git a/dimos/robot/drone/blueprints/basic/drone_basic.py b/dimos/robot/drone/blueprints/basic/drone_basic.py index aaf82f6355..c1838d6ac7 100644 --- a/dimos/robot/drone/blueprints/basic/drone_basic.py +++ b/dimos/robot/drone/blueprints/basic/drone_basic.py @@ -20,9 +20,10 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.core.global_config import global_config +from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.robot.drone.camera_module import DroneCameraModule from dimos.robot.drone.connection_module import DroneConnectionModule -from dimos.visualization.vis_module import vis_module +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule def _static_drone_body(rr: Any) -> list[Any]: @@ -59,12 +60,23 @@ def _drone_rerun_blueprint() -> Any: _rerun_config = { "blueprint": _drone_rerun_blueprint, + "pubsubs": [LCM()], "static": { "world/tf/base_link": _static_drone_body, }, } -_vis = vis_module(global_config.viewer, rerun_config=_rerun_config) +# Conditional visualization +if global_config.viewer == "foxglove": + from dimos.robot.foxglove_bridge import FoxgloveBridge + + _vis = FoxgloveBridge.blueprint() +elif global_config.viewer.startswith("rerun"): + from dimos.visualization.rerun.bridge 
import RerunBridgeModule, _resolve_viewer_mode + + _vis = RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **_rerun_config) +else: + _vis = autoconnect() # Determine connection string based on replay flag connection_string = "udp:0.0.0.0:14550" @@ -80,6 +92,7 @@ def _drone_rerun_blueprint() -> Any: outdoor=False, ), DroneCameraModule.blueprint(camera_intrinsics=[1000.0, 1000.0, 960.0, 540.0]), + WebsocketVisModule.blueprint(), ) __all__ = [ diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py index 4941abad38..dd135a60a1 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py @@ -17,11 +17,10 @@ from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core.coordination.blueprints import autoconnect -from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image +from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1 import unitree_g1 -from dimos.visualization.vis_module import vis_module unitree_g1_shm = autoconnect( unitree_g1.transports( @@ -31,9 +30,10 @@ ), } ), - vis_module( - viewer_backend=global_config.viewer, - foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, + FoxgloveBridge.blueprint( + shm_channels=[ + "/color_image#sensor_msgs.Image", + ] ), ) diff --git a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py index eeabea7909..b04443732f 100644 --- a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py +++ b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py @@ -40,7 +40,8 @@ from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector 
import ( WavefrontFrontierExplorer, ) -from dimos.visualization.vis_module import vis_module +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule def _convert_camera_info(camera_info: Any) -> Any: @@ -93,6 +94,7 @@ def _g1_rerun_blueprint() -> Any: rerun_config = { "blueprint": _g1_rerun_blueprint, + "pubsubs": [LCM()], "visual_override": { "world/camera_info": _convert_camera_info, "world/navigation_costmap": _convert_navigation_costmap, @@ -102,7 +104,18 @@ def _g1_rerun_blueprint() -> Any: }, } -_with_vis = vis_module(viewer_backend=global_config.viewer, rerun_config=rerun_config) +if global_config.viewer == "foxglove": + from dimos.robot.foxglove_bridge import FoxgloveBridge + + _with_vis = autoconnect(FoxgloveBridge.blueprint()) +elif global_config.viewer.startswith("rerun"): + from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode + + _with_vis = autoconnect( + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config) + ) +else: + _with_vis = autoconnect() def _create_webcam() -> Webcam: @@ -137,6 +150,8 @@ def _create_webcam() -> Webcam: VoxelGridMapper.blueprint(), CostMapper.blueprint(), WavefrontFrontierExplorer.blueprint(), + # Visualization + WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_g1") .transports( diff --git a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py index 4b39a106b8..be9e04a7fd 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py @@ -18,7 +18,7 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_agentic import unitree_go2_agentic -from dimos.visualization.rerun.bridge 
import RerunBridgeModule +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode def _convert_camera_info(camera_info: Any) -> Any: @@ -85,7 +85,7 @@ def _go2_rerun_blueprint() -> Any: unitree_go2_security = autoconnect( unitree_go2_agentic, - RerunBridgeModule.blueprint(**rerun_config), + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), ) __all__ = ["unitree_go2_security"] diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py index 4f86ccb0a3..54a2c0f7c6 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py @@ -22,9 +22,10 @@ from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image +from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator from dimos.robot.unitree.go2.connection import GO2Connection -from dimos.visualization.vis_module import vis_module +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule # Mac has some issue with high bandwidth UDP, so we use pSHMTransport for color_image # actually we can use pSHMTransport for all platforms, and for all streams @@ -98,6 +99,9 @@ def _go2_rerun_blueprint() -> Any: rerun_config = { "blueprint": _go2_rerun_blueprint, + # any pubsub that supports subscribe_all and topic that supports str(topic) + # is acceptable here + "pubsubs": [LCM()], # Custom converters for specific rerun entity paths # Normally all these would be specified in their respectative modules # Until this is implemented we have central overrides here @@ -119,20 +123,30 @@ def _go2_rerun_blueprint() -> Any: }, } -_with_vis = autoconnect( - _transports_base, - vis_module( - viewer_backend=global_config.viewer, - 
rerun_config=rerun_config, - foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, - ), -) + +if global_config.viewer == "foxglove": + from dimos.robot.foxglove_bridge import FoxgloveBridge + + with_vis = autoconnect( + _transports_base, + FoxgloveBridge.blueprint(shm_channels=["/color_image#sensor_msgs.Image"]), + ) +elif global_config.viewer.startswith("rerun"): + from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode + + with_vis = autoconnect( + _transports_base, + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), + ) +else: + with_vis = _transports_base unitree_go2_basic = ( autoconnect( - _with_vis, + with_vis, GO2Connection.blueprint(), + WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py index bda362eeca..a7a10767bf 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py @@ -22,13 +22,15 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator -from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import _with_vis +from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import with_vis from dimos.robot.unitree.go2.fleet_connection import Go2FleetConnection +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule unitree_go2_fleet = ( autoconnect( - _with_vis, + with_vis, Go2FleetConnection.blueprint(), + WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py 
b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py index 3be0c62379..01117ec3b5 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py @@ -31,10 +31,6 @@ unitree_go2_webrtc_keyboard_teleop = autoconnect( unitree_go2_coordinator, KeyboardTeleop.blueprint(), -).remappings( - [ - (KeyboardTeleop, "tele_cmd_vel", "cmd_vel"), - ] ) __all__ = ["unitree_go2_webrtc_keyboard_teleop"] diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py index 16711115ab..f353d995af 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py @@ -27,7 +27,6 @@ ) from dimos.navigation.patrolling.module import PatrollingModule from dimos.navigation.replanning_a_star.module import ReplanningAStarPlanner -from dimos.navigation.smart_nav.modules.movement_manager.movement_manager import MovementManager from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import unitree_go2_basic unitree_go2 = autoconnect( @@ -37,8 +36,7 @@ ReplanningAStarPlanner.blueprint(), WavefrontFrontierExplorer.blueprint(), PatrollingModule.blueprint(), - MovementManager.blueprint(), -).global_config(n_workers=10, robot_model="unitree_go2") +).global_config(n_workers=9, robot_model="unitree_go2") class Go2MemoryConfig(RecorderConfig): @@ -54,6 +52,6 @@ class Go2Memory(Recorder): unitree_go2_memory = autoconnect( unitree_go2, Go2Memory.blueprint(), -).global_config(n_workers=11) +).global_config(n_workers=10) __all__ = ["unitree_go2", "unitree_go2_memory"] diff --git a/dimos/robot/unitree/keyboard_teleop.py b/dimos/robot/unitree/keyboard_teleop.py index 3e8f76a1cc..e3c78ecc52 100644 --- a/dimos/robot/unitree/keyboard_teleop.py +++ b/dimos/robot/unitree/keyboard_teleop.py @@ -38,14 +38,14 @@ class KeyboardTeleop(Module): 
"""Pygame-based keyboard control module. - Outputs standard Twist messages on /tele_cmd_vel for velocity control. + Outputs standard Twist messages on /cmd_vel for velocity control. Speed constants can be tuned at the top of this file, or overridden per-instance by passing linear_speed / angular_speed / boost_multiplier / slow_multiplier to the constructor. """ - tele_cmd_vel: Out[Twist] # Standard velocity commands + cmd_vel: Out[Twist] # Standard velocity commands _stop_event: threading.Event _keys_held: set[int] | None = None @@ -86,7 +86,7 @@ def stop(self) -> None: stop_twist = Twist() stop_twist.linear = Vector3(0, 0, 0) stop_twist.angular = Vector3(0, 0, 0) - self.tele_cmd_vel.publish(stop_twist) + self.cmd_vel.publish(stop_twist) self._stop_event.set() @@ -119,7 +119,7 @@ def _pygame_loop(self) -> None: stop_twist = Twist() stop_twist.linear = Vector3(0, 0, 0) stop_twist.angular = Vector3(0, 0, 0) - self.tele_cmd_vel.publish(stop_twist) + self.cmd_vel.publish(stop_twist) print("EMERGENCY STOP!") elif event.key == pygame.K_ESCAPE: # ESC quits @@ -163,7 +163,7 @@ def _pygame_loop(self) -> None: twist.angular.z *= speed_multiplier # Always publish twist at 50Hz - self.tele_cmd_vel.publish(twist) + self.cmd_vel.publish(twist) self._update_display(twist) diff --git a/dimos/robot/unitree/mujoco_connection.py b/dimos/robot/unitree/mujoco_connection.py index 43ddeb6530..39c0904684 100644 --- a/dimos/robot/unitree/mujoco_connection.py +++ b/dimos/robot/unitree/mujoco_connection.py @@ -20,12 +20,9 @@ from collections.abc import Callable import functools import json -import os -from pathlib import Path import pickle import subprocess import sys -import sysconfig import threading import time from typing import Any, TypeVar @@ -129,23 +126,12 @@ def start(self) -> None: # Launch the subprocess try: - # mjpython must be used on macOS (because of launch_passive inside mujoco_process.py). 
- # It needs libpython on the dylib search path; uv-installed Pythons - # use @rpath which doesn't always resolve inside venvs, so we - # point DYLD_LIBRARY_PATH at the real libpython directory. + # mjpython must be used macOS (because of launch_passive inside mujoco_process.py) executable = sys.executable if sys.platform != "darwin" else "mjpython" - env = os.environ.copy() - if sys.platform == "darwin": - # on some systems mujoco looks in the wrong place for shared libraries. So we force it look in the right place - libdir = Path(sysconfig.get_config_var("LIBDIR") or "") - if libdir.is_dir(): - existing = env.get("DYLD_LIBRARY_PATH", "") - env["DYLD_LIBRARY_PATH"] = f"{libdir}:{existing}" if existing else str(libdir) self.process = subprocess.Popen( [executable, str(LAUNCHER_PATH), config_pickle, shm_names_json], stderr=subprocess.PIPE, - env=env, ) except Exception as e: diff --git a/dimos/simulation/unity/blueprint.py b/dimos/simulation/unity/blueprint.py index d9b29ee610..f7e2d34ccb 100644 --- a/dimos/simulation/unity/blueprint.py +++ b/dimos/simulation/unity/blueprint.py @@ -28,7 +28,7 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.simulation.unity.module import UnityBridgeModule -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode def _rerun_blueprint() -> Any: @@ -57,5 +57,5 @@ def _rerun_blueprint() -> Any: unity_sim = autoconnect( UnityBridgeModule.blueprint(), - RerunBridgeModule.blueprint(**rerun_config), + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), ) diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index b825f29a17..57c925c3f0 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -26,12 +26,12 @@ from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from 
dimos.teleop.quest.quest_extensions import ArmTeleopModule from dimos.teleop.quest.quest_types import Buttons -from dimos.visualization.vis_module import vis_module +from dimos.visualization.rerun.bridge import RerunBridgeModule # Arm teleop with press-and-hold engage (has rerun viz) teleop_quest_rerun = autoconnect( ArmTeleopModule.blueprint(), - vis_module("rerun"), + RerunBridgeModule.blueprint(), ).transports( { ("left_controller_output", PoseStamped): LCMTransport("/teleop/left_delta", PoseStamped), diff --git a/dimos/test_no_sections.py b/dimos/test_no_sections.py index 79f2d61b8f..902288b2e6 100644 --- a/dimos/test_no_sections.py +++ b/dimos/test_no_sections.py @@ -52,8 +52,6 @@ ".tox", # third-party vendored code "gtsam", - # hidden/personal directories - ".hidden", } # Lines that match section patterns but are actually programmatic / intentional. diff --git a/dimos/utils/generic.py b/dimos/utils/generic.py index 200c7c6d86..84168ce057 100644 --- a/dimos/utils/generic.py +++ b/dimos/utils/generic.py @@ -16,27 +16,10 @@ import hashlib import json import os -import socket import string from typing import Any, Generic, TypeVar, overload import uuid -import psutil - - -def get_local_ips() -> list[tuple[str, str]]: - """Return ``(ip, interface_name)`` for every non-loopback IPv4 address. - - Picks up physical, virtual, and VPN interfaces (including Tailscale). 
- """ - results: list[tuple[str, str]] = [] - for iface, addrs in psutil.net_if_addrs().items(): - for addr in addrs: - if addr.family == socket.AF_INET and not addr.address.startswith("127."): - results.append((addr.address, iface)) - return results - - _T = TypeVar("_T") diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index f6744e74fb..f2e3e51d08 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -18,19 +18,18 @@ from collections.abc import Callable from dataclasses import field -import socket +from functools import lru_cache import subprocess import time from typing import ( Any, + Literal, Protocol, TypeAlias, TypeGuard, cast, - get_args, runtime_checkable, ) -from urllib.parse import urlparse from reactivex.disposable import Disposable import rerun as rr @@ -38,23 +37,19 @@ import rerun.blueprint as rrb from rerun.blueprint import Blueprint from toolz import pipe # type: ignore[import-untyped] +import typer from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.pubsub.patterns import Glob, pattern_matches from dimos.protocol.pubsub.spec import SubscribeAllCapable -from dimos.utils.generic import get_local_ips from dimos.utils.logging_config import setup_logger -from dimos.visualization.rerun.constants import ( - RERUN_ENABLE_WEB, - RERUN_GRPC_PORT, - RERUN_OPEN_DEFAULT, - RERUN_WEB_PORT, - RerunOpenOption, -) from dimos.visualization.rerun.init import rerun_init +RERUN_GRPC_PORT = 9877 +RERUN_WEB_PORT = 9090 + # TODO OUT visual annotations # # In the future it would be nice if modules can annotate their individual OUTs with (general or rerun specific) @@ -100,6 +95,7 @@ BlueprintFactory: TypeAlias = Callable[[], "Blueprint"] +# to_rerun() can return a single archetype or a list of (entity_path, archetype) tuples RerunMulti: TypeAlias = "list[tuple[str, Archetype]]" RerunData: 
TypeAlias = "Archetype | RerunMulti" @@ -123,16 +119,18 @@ class RerunConvertible(Protocol): def to_rerun(self) -> RerunData: ... +ViewerMode = Literal["native", "web", "connect", "none"] + + def _hex_to_rgba(hex_color: str) -> int: """Convert '#RRGGBB' to a 0xRRGGBBAA int (fully opaque).""" h = hex_color.lstrip("#") - if len(h) == 6: - return int(h + "ff", 16) - return int(h[:8], 16) + return (int(h, 16) << 8) | 0xFF def _with_graph_tab(bp: Blueprint) -> Blueprint: """Add a Graph tab alongside the existing viewer layout without changing it.""" + root = bp.root_container return rrb.Blueprint( rrb.Tabs( @@ -158,24 +156,48 @@ def _default_blueprint() -> Blueprint: ) +# Maps global_config.viewer -> bridge viewer_mode. +# Evaluated at blueprint construction time (main process), not in start() (worker process). +_BACKEND_TO_MODE: dict[str, ViewerMode] = { + "rerun": "native", + "rerun-web": "web", + "rerun-connect": "connect", + "none": "none", +} + + +def _resolve_viewer_mode() -> ViewerMode: + from dimos.core.global_config import global_config + + return _BACKEND_TO_MODE.get(global_config.viewer, "native") + + class Config(ModuleConfig): + """Configuration for RerunBridgeModule.""" + pubsubs: list[SubscribeAllCapable[Any, Any]] = field(default_factory=lambda: [LCM()]) visual_override: dict[Glob | str, Callable[[Any], Archetype]] = field(default_factory=dict) + + # Static items logged once after start. Maps entity_path -> callable(rr) returning Archetype static: dict[str, Callable[[Any], Archetype]] = field(default_factory=dict) + + grpc_port: int = RERUN_GRPC_PORT + web_port: int = RERUN_WEB_PORT + + # Per-entity max update rate (Hz). Entities not listed are unthrottled. + # Use for heavy entities to prevent viewer backpressure. 
max_hz: dict[str, float] = field(default_factory=dict) entity_prefix: str = "world" topic_to_entity: Callable[[Any], str] | None = None + viewer_mode: ViewerMode = field(default_factory=_resolve_viewer_mode) connect_url: str = "rerun+http://127.0.0.1:9877/proxy" memory_limit: str = "25%" - rerun_open: RerunOpenOption = RERUN_OPEN_DEFAULT - rerun_web: bool = RERUN_ENABLE_WEB - web_port: int = RERUN_WEB_PORT - blueprint: BlueprintFactory | None = _default_blueprint - -Config.model_rebuild(_types_namespace={"Archetype": Archetype, "Blueprint": Blueprint}) + # Blueprint factory: callable(rrb) -> Blueprint for viewer layout configuration + # Set to None to disable default blueprint + blueprint: BlueprintFactory | None = _default_blueprint class RerunBridgeModule(Module): @@ -195,31 +217,22 @@ class RerunBridgeModule(Module): """ config: Config - _last_log: dict[str, float] # TODO this doesn't belong here, either hardcode it or put it to rerun bridge config - GRAPH_VIZ_SCALE = 100.0 - MODULE_RADIUS = 20.0 - CHANNEL_RADIUS = 12.0 - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._last_log = {} - self._override_cache: dict[str, Callable[[Any], RerunData | None]] = {} + GV_SCALE = 100.0 # graphviz inches to rerun screen units + MODULE_RADIUS = 30.0 + CHANNEL_RADIUS = 20.0 + @lru_cache(maxsize=256) def _visual_override_for_entity_path( self, entity_path: str ) -> Callable[[Any], RerunData | None]: """Return a composed visual override for the entity path. Chains matching overrides from config, ending with final_convert - which handles .to_rerun() or passes through Archetypes. Cached per - instance (not via ``lru_cache`` on a method, which would leak ``self``). + which handles .to_rerun() or passes through Archetypes. 
""" - cached = self._override_cache.get(entity_path) - if cached is not None: - return cached - + # find all matching converters for this entity path matches = [ fn for pattern, fn in self.config.visual_override.items() @@ -228,13 +241,9 @@ def _visual_override_for_entity_path( # None means "suppress this topic entirely" if any(fn is None for fn in matches): + return lambda msg: None - def suppressed(msg: Any) -> RerunData | None: - return None - - self._override_cache[entity_path] = suppressed - return suppressed - + # final step (ensures we return Archetype or None) def final_convert(msg: Any) -> RerunData | None: if isinstance(msg, Archetype): return msg @@ -244,21 +253,23 @@ def final_convert(msg: Any) -> RerunData | None: return msg.to_rerun() return None - def composed(msg: Any) -> RerunData | None: - return cast("RerunData | None", pipe(msg, *matches, final_convert)) - - self._override_cache[entity_path] = composed - return composed + # compose all converters + return lambda msg: pipe(msg, *matches, final_convert) def _get_entity_path(self, topic: Any) -> str: + """Convert a topic to a Rerun entity path.""" if self.config.topic_to_entity: return self.config.topic_to_entity(topic) + # Default: use topic.name if available (LCM Topic), else str topic_str = getattr(topic, "name", None) or str(topic) - topic_str = topic_str.split("#")[0] # strip LCM topic suffix + # Strip everything after # (LCM topic suffix) + topic_str = topic_str.split("#")[0] return f"{self.config.entity_prefix}{topic_str}" def _on_message(self, msg: Any, topic: Any) -> None: + """Handle incoming message - log to rerun.""" + entity_path: str = self._get_entity_path(topic) # Throttle entities with a max_hz limit @@ -268,6 +279,7 @@ def _on_message(self, msg: Any, topic: Any) -> None: return self._last_log[entity_path] = now + # apply visual overrides (including final_convert which handles .to_rerun()) rerun_data: RerunData | None = self._visual_override_for_entity_path(entity_path)(msg) if not 
rerun_data: @@ -284,87 +296,47 @@ def _on_message(self, msg: Any, topic: Any) -> None: def start(self) -> None: super().start() - logger.info("Rerun bridge starting") + logger.info("Rerun bridge starting", viewer_mode=self.config.viewer_mode) - self._last_log = {} + # Build throttle lookup: entity_path → min interval in seconds + self._last_log: dict[str, float] = {} self._min_intervals: dict[str, float] = { entity: 1.0 / hz for entity, hz in self.config.max_hz.items() if hz > 0 } + # Initialize and spawn Rerun viewer rerun_init("dimos") - parsed = urlparse(self.config.connect_url.replace("rerun+", "", 1)) - grpc_port = parsed.port or RERUN_GRPC_PORT - - port_in_use = False - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - port_in_use = sock.connect_ex(("127.0.0.1", grpc_port)) == 0 - - if port_in_use: - logger.info(f"gRPC port {grpc_port} already in use, connecting to existing server") - rr.connect_grpc(url=self.config.connect_url) - server_uri = self.config.connect_url - else: - server_uri = rr.serve_grpc( - grpc_port=grpc_port, - server_memory_limit=self.config.memory_limit, - ) - logger.info(f"Rerun gRPC server ready at {server_uri}") - - if self.config.rerun_open not in get_args(RerunOpenOption): - logger.warning( - f"rerun_open was {self.config.rerun_open} which is not one of " - f"{get_args(RerunOpenOption)}" - ) - - spawned = False - if self.config.rerun_open in ("native", "both"): + if self.config.viewer_mode == "native": try: import rerun_bindings - # Use --connect so the viewer connects to the bridge's gRPC - # server rather than starting its own (which would conflict). 
rerun_bindings.spawn( + port=self.config.grpc_port, executable_name="dimos-viewer", memory_limit=self.config.memory_limit, - extra_args=["--connect", server_uri], ) - spawned = True + rr.connect_grpc(f"rerun+http://127.0.0.1:{self.config.grpc_port}/proxy") except ImportError: - pass # dimos-viewer not installed + rr.spawn(connect=True, memory_limit=self.config.memory_limit) except Exception: logger.warning( "dimos-viewer found but failed to spawn, falling back to stock rerun", exc_info=True, ) + rr.spawn(connect=True, memory_limit=self.config.memory_limit) + elif self.config.viewer_mode == "web": + server_uri = rr.serve_grpc() + rr.serve_web_viewer(connect_to=server_uri, open_browser=False) - # fallback on normal (non-dimos-viewer) rerun - if not spawned: - try: - rr.spawn(connect=True, memory_limit=self.config.memory_limit) - spawned = True - except (RuntimeError, FileNotFoundError): - logger.warning( - "Rerun native viewer not available (headless?). " - "Bridge will continue without a viewer — data is still " - "accessible via --rerun-open web or by connecting a viewer to the gRPC server.", - exc_info=True, - ) - - open_web = self.config.rerun_open == "web" or self.config.rerun_open == "both" - if open_web or self.config.rerun_web: - rr.serve_web_viewer( - connect_to=server_uri, - open_browser=open_web, - web_port=self.config.web_port, - ) - - if self.config.rerun_open == "none" or (self.config.rerun_open == "native" and not spawned): - self._log_connect_hints(grpc_port) + elif self.config.viewer_mode == "connect": + rr.connect_grpc(self.config.connect_url) + # "none" - just init, no viewer (connect externally) if self.config.blueprint: rr.send_blueprint(_with_graph_tab(self.config.blueprint())) + # Start pubsubs and subscribe to all messages for pubsub in self.config.pubsubs: logger.info(f"bridge listening on {pubsub.__class__.__name__}") if hasattr(pubsub, "start"): @@ -372,35 +344,13 @@ def start(self) -> None: unsub = pubsub.subscribe_all(self._on_message) 
self.register_disposable(Disposable(unsub)) + # Add pubsub stop as disposable for pubsub in self.config.pubsubs: if hasattr(pubsub, "stop"): self.register_disposable(Disposable(pubsub.stop)) # type: ignore[union-attr] self._log_static() - def _log_connect_hints(self, grpc_port: int) -> None: - """Log CLI commands for connecting a viewer to this bridge.""" - local_ips = get_local_ips() - hostname = socket.gethostname() - connect_url = f"rerun+http://127.0.0.1:{grpc_port}/proxy" - - lines = [ - "", - "=" * 60, - "Rerun gRPC server running (no viewer opened)", - "", - "Connect a viewer:", - f" dimos-viewer --connect {connect_url}", - ] - for ip, iface in local_ips: - lines.append(f" dimos-viewer --connect rerun+http://{ip}:{grpc_port}/proxy # {iface}") - lines.append("") - lines.append(f" hostname: {hostname}") - lines.append("=" * 60) - lines.append("") - - logger.info("\n".join(lines)) - def _log_static(self) -> None: for entity_path, factory in self.config.static.items(): data = factory(rr) @@ -421,6 +371,7 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: dot_code: The DOT-format graph (from ``introspection.blueprint.dot.render``). module_names: List of module class names (to distinguish modules from channels). 
""" + try: result = subprocess.run( ["dot", "-Tplain"], input=dot_code, text=True, capture_output=True, timeout=30 @@ -442,8 +393,8 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: if line.startswith("node "): parts = line.split() node_id = parts[1].strip('"') - x = float(parts[2]) * self.GRAPH_VIZ_SCALE - y = -float(parts[3]) * self.GRAPH_VIZ_SCALE + x = float(parts[2]) * self.GV_SCALE + y = -float(parts[3]) * self.GV_SCALE label = parts[6].strip('"') color = parts[9].strip('"') @@ -476,5 +427,49 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: @rpc def stop(self) -> None: - self._override_cache.clear() super().stop() + + +def run_bridge( + viewer_mode: str = "native", + memory_limit: str = "25%", +) -> None: + """Start a RerunBridgeModule with default LCM config and block until interrupted.""" + import signal + + from dimos.protocol.service.lcmservice import autoconf + + autoconf(check_only=True) + + bridge = RerunBridgeModule( + viewer_mode=viewer_mode, + memory_limit=memory_limit, + # any pubsub that supports subscribe_all and topic that supports str(topic) + # is acceptable here + pubsubs=[LCM()], + ) + + bridge.start() + + signal.signal(signal.SIGINT, lambda *_: bridge.stop()) + signal.pause() + + +app = typer.Typer() + + +@app.command() +def cli( + viewer_mode: str = typer.Option( + "native", help="Viewer mode: native (desktop), web (browser), none (headless)" + ), + memory_limit: str = typer.Option( + "25%", help="Memory limit for Rerun viewer (e.g., '4GB', '16GB', '25%')" + ), +) -> None: + """Rerun bridge for LCM messages.""" + run_bridge(viewer_mode=viewer_mode, memory_limit=memory_limit) + + +if __name__ == "__main__": + app() diff --git a/dimos/visualization/rerun/conftest.py b/dimos/visualization/rerun/conftest.py deleted file mode 100644 index f269bb8015..0000000000 --- a/dimos/visualization/rerun/conftest.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2026 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import asyncio -from collections.abc import Callable -import time - -import pytest -import websockets.asyncio.client as ws_client - - -def _wait_for_server(port: int, timeout: float = 5.0) -> None: - """Block until the WebSocket server on *port* accepts a connection.""" - - async def _probe() -> None: - async with ws_client.connect(f"ws://127.0.0.1:{port}/ws"): - pass - - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - try: - asyncio.run(_probe()) - return - except Exception: - time.sleep(0.05) - raise TimeoutError(f"Server on port {port} did not become ready within {timeout}s") - - -@pytest.fixture() -def wait_for_server() -> Callable[[int, float], None]: - """Fixture that returns a callable to wait for a WebSocket server.""" - return _wait_for_server diff --git a/dimos/visualization/rerun/constants.py b/dimos/visualization/rerun/constants.py deleted file mode 100644 index 860c691cef..0000000000 --- a/dimos/visualization/rerun/constants.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Rerun visualization defaults and type aliases. - -This module is intentionally free of heavy imports so it can be -loaded from lightweight entry-points like ``global_config`` and -``dimos --help`` without pulling in the Rerun SDK or the module -framework. -""" - -from typing import Literal, TypeAlias - -ViewerBackend: TypeAlias = Literal["rerun", "foxglove", "none"] -RerunOpenOption: TypeAlias = Literal["none", "web", "native", "both"] - -RERUN_OPEN_DEFAULT: RerunOpenOption = "native" -RERUN_ENABLE_WEB = False -RERUN_GRPC_PORT = 9876 -RERUN_WEB_PORT = 9877 diff --git a/dimos/visualization/rerun/test_viewer_ws_e2e.py b/dimos/visualization/rerun/test_viewer_ws_e2e.py deleted file mode 100644 index 260699a3e8..0000000000 --- a/dimos/visualization/rerun/test_viewer_ws_e2e.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""End-to-end tests for dimos-viewer ↔ RerunWebSocketServer protocol.""" - -from __future__ import annotations - -import asyncio -import json -import os -import subprocess -import threading -import time -from typing import Any - -import pytest -import websockets.asyncio.client as ws_client - -from dimos.visualization.rerun.websocket_server import RerunWebSocketServer - -_E2E_PORT = 13032 - - -@pytest.fixture() -def server(wait_for_server: Any) -> RerunWebSocketServer: - module = RerunWebSocketServer(port=_E2E_PORT) - module.start() - wait_for_server(_E2E_PORT) - yield module # type: ignore[misc] - module.stop() - - -def _send_messages(port: int, messages: list[dict[str, Any]], *, delay: float = 0.05) -> None: - async def _run() -> None: - async with ws_client.connect(f"ws://127.0.0.1:{port}/ws") as ws: - for msg in messages: - await ws.send(json.dumps(msg)) - await asyncio.sleep(delay) - - asyncio.run(_run()) - - -class TestViewerProtocolE2E: - """Verify the Python-server side of the viewer ↔ DimOS protocol.""" - - def test_viewer_click_reaches_stream(self, server: RerunWebSocketServer) -> None: - """A viewer click over WebSocket publishes PointStamped.""" - received: list[Any] = [] - done = threading.Event() - unsub = server.clicked_point.subscribe(lambda pt: (received.append(pt), done.set())) - - _send_messages( - _E2E_PORT, - [ - { - "type": "click", - "x": 10.0, - "y": 20.0, - "z": 0.5, - "entity_path": "/world/robot", - "timestamp_ms": 42000, - } - ], - ) - - done.wait(timeout=3.0) - unsub() - - assert len(received) == 1 - pt = received[0] - assert pt.x == pytest.approx(10.0) - assert pt.y == pytest.approx(20.0) - assert pt.z == pytest.approx(0.5) - assert pt.frame_id == "/world/robot" - assert pt.ts == pytest.approx(42.0) - - def test_full_viewer_session_sequence(self, server: RerunWebSocketServer) -> None: - """Realistic session: heartbeats, click, twist, stop — only the click produces a point.""" - received: list[Any] = [] - done = threading.Event() - 
unsub = server.clicked_point.subscribe(lambda pt: (received.append(pt), done.set())) - - _send_messages( - _E2E_PORT, - [ - {"type": "heartbeat", "timestamp_ms": 1000}, - {"type": "heartbeat", "timestamp_ms": 2000}, - { - "type": "click", - "x": 3.14, - "y": 2.71, - "z": 1.41, - "entity_path": "/world", - "timestamp_ms": 3000, - }, - { - "type": "twist", - "linear_x": 0.5, - "linear_y": 0.0, - "linear_z": 0.0, - "angular_x": 0.0, - "angular_y": 0.0, - "angular_z": 0.0, - }, - {"type": "stop"}, - {"type": "heartbeat", "timestamp_ms": 4000}, - ], - delay=0.2, - ) - - done.wait(timeout=3.0) - unsub() - - assert len(received) == 1, f"Expected exactly 1 click, got {len(received)}" - assert received[0].x == pytest.approx(3.14) - assert received[0].y == pytest.approx(2.71) - assert received[0].z == pytest.approx(1.41) - - def test_reconnect_after_disconnect(self, server: RerunWebSocketServer) -> None: - """Server keeps accepting new connections after a client disconnects.""" - received: list[Any] = [] - all_done = threading.Event() - - def _on_pt(pt: Any) -> None: - received.append(pt) - if len(received) >= 2: - all_done.set() - - unsub = server.clicked_point.subscribe(_on_pt) - - _send_messages( - _E2E_PORT, - [{"type": "click", "x": 1.0, "y": 0.0, "z": 0.0, "entity_path": "", "timestamp_ms": 0}], - ) - _send_messages( - _E2E_PORT, - [{"type": "click", "x": 2.0, "y": 0.0, "z": 0.0, "entity_path": "", "timestamp_ms": 0}], - ) - - all_done.wait(timeout=5.0) - unsub() - - xs = sorted(pt.x for pt in received) - assert xs == [1.0, 2.0], f"Unexpected xs: {xs}" - - -class TestViewerBinaryConnectMode: - """Smoke test: dimos-viewer binary starts in --connect mode.""" - - @pytest.fixture() - def viewer_process(self, server: RerunWebSocketServer) -> subprocess.Popen[bytes]: - proc = subprocess.Popen( - [ - "dimos-viewer", - "--connect", - f"--ws-url=ws://127.0.0.1:{_E2E_PORT}/ws", - ], - env={**os.environ, "DISPLAY": ""}, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - 
yield proc # type: ignore[misc] - proc.terminate() - try: - proc.wait(timeout=3) - except subprocess.TimeoutExpired: - proc.kill() - - @pytest.mark.skip( - reason="Incompatible with current winit: fails without DISPLAY (headless CI exits before WS connect) and hangs with DISPLAY (GUI event loop blocks before printing URL).", - ) - def test_viewer_ws_client_connects(self, viewer_process: subprocess.Popen[bytes]) -> None: - """dimos-viewer --connect starts and its WS client connects to our server.""" - deadline = time.monotonic() + 5.0 - while time.monotonic() < deadline: - if viewer_process.poll() is not None: - break - time.sleep(0.1) - - stdout = ( - viewer_process.stdout.read().decode(errors="replace") if viewer_process.stdout else "" - ) - stderr = ( - viewer_process.stderr.read().decode(errors="replace") if viewer_process.stderr else "" - ) - - combined = stdout + stderr - assert f"ws://127.0.0.1:{_E2E_PORT}" in combined, ( - f"Viewer did not attempt WS connection.\nstdout:\n{stdout}\nstderr:\n{stderr}" - ) diff --git a/dimos/visualization/rerun/test_websocket_server.py b/dimos/visualization/rerun/test_websocket_server.py deleted file mode 100644 index b4304cf7b4..0000000000 --- a/dimos/visualization/rerun/test_websocket_server.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for RerunWebSocketServer.""" - -from __future__ import annotations - -import asyncio -import json -import threading -import time -from typing import Any - -import pytest -import websockets.asyncio.client as ws_client - -from dimos.visualization.rerun.websocket_server import RerunWebSocketServer - -_TEST_PORT = 13031 - - -class MockViewerPublisher: - """Simulates dimos-viewer sending JSON events over WebSocket.""" - - def __init__(self, url: str) -> None: - self._url = url - self._ws: Any = None - self._loop: asyncio.AbstractEventLoop | None = None - - def __enter__(self) -> MockViewerPublisher: - self._loop = asyncio.new_event_loop() - self._ws = self._loop.run_until_complete(self._connect()) - return self - - def __exit__(self, *_: Any) -> None: - if self._ws is not None and self._loop is not None: - self._loop.run_until_complete(self._ws.close()) - if self._loop is not None: - self._loop.close() - - async def _connect(self) -> Any: - return await ws_client.connect(self._url) - - def send_click( - self, x: float, y: float, z: float, entity_path: str = "", timestamp_ms: int = 0 - ) -> None: - self._send( - { - "type": "click", - "x": x, - "y": y, - "z": z, - "entity_path": entity_path, - "timestamp_ms": timestamp_ms, - } - ) - - def send_twist( - self, - linear_x: float, - linear_y: float, - linear_z: float, - angular_x: float, - angular_y: float, - angular_z: float, - ) -> None: - self._send( - { - "type": "twist", - "linear_x": linear_x, - "linear_y": linear_y, - "linear_z": linear_z, - "angular_x": angular_x, - "angular_y": angular_y, - "angular_z": angular_z, - } - ) - - def send_stop(self) -> None: - self._send({"type": "stop"}) - - def flush(self, delay: float = 0.1) -> None: - time.sleep(delay) - - def _send(self, msg: dict[str, Any]) -> None: - assert self._loop is not None and self._ws is not None - self._loop.run_until_complete(self._ws.send(json.dumps(msg))) - - -@pytest.fixture() -def server(wait_for_server: Any) -> RerunWebSocketServer: - 
module = RerunWebSocketServer(port=_TEST_PORT) - module.start() - wait_for_server(_TEST_PORT) - yield module # type: ignore[misc] - module.stop() - - -@pytest.fixture() -def publisher(server: RerunWebSocketServer) -> MockViewerPublisher: - with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as publisher: - yield publisher # type: ignore[misc] - - -# ── Tests ──────────────────────────────────────────────────────────────── - - -def test_click_publishes_point_stamped( - server: RerunWebSocketServer, publisher: MockViewerPublisher -) -> None: - """Click event arrives as PointStamped with correct coords, frame_id, and timestamp.""" - received: list[Any] = [] - done = threading.Event() - - unsub = server.clicked_point.subscribe(lambda point: (received.append(point), done.set())) - - publisher.send_click(1.5, 2.5, 0.0, "/robot/base", timestamp_ms=5000) - publisher.flush() - done.wait(timeout=2.0) - unsub() - - assert len(received) == 1 - point = received[0] - assert point.x == pytest.approx(1.5) - assert point.y == pytest.approx(2.5) - assert point.z == pytest.approx(0.0) - assert point.frame_id == "/robot/base" - assert point.ts == pytest.approx(5.0) - - -def test_twist_publishes_on_tele_cmd_vel( - server: RerunWebSocketServer, publisher: MockViewerPublisher -) -> None: - """Twist event arrives as Twist on tele_cmd_vel.""" - received: list[Any] = [] - done = threading.Event() - - unsub = server.tele_cmd_vel.subscribe(lambda twist: (received.append(twist), done.set())) - - publisher.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.8) - publisher.flush() - done.wait(timeout=2.0) - unsub() - - assert len(received) == 1 - assert received[0].linear.x == pytest.approx(0.5) - assert received[0].angular.z == pytest.approx(0.8) - - -def test_stop_publishes_zero_twist( - server: RerunWebSocketServer, publisher: MockViewerPublisher -) -> None: - """Stop event publishes a zero Twist on tele_cmd_vel.""" - received: list[Any] = [] - done = threading.Event() - - unsub = 
server.tele_cmd_vel.subscribe(lambda twist: (received.append(twist), done.set())) - - publisher.send_stop() - publisher.flush() - done.wait(timeout=2.0) - unsub() - - assert len(received) == 1 - assert received[0].is_zero() - - -def test_invalid_json_does_not_crash(server: RerunWebSocketServer) -> None: - """Malformed JSON is silently dropped; server stays alive for the next message.""" - - async def _send_bad() -> None: - async with ws_client.connect(f"ws://127.0.0.1:{_TEST_PORT}/ws") as ws: - await ws.send("this is not json {{") - await asyncio.sleep(0.1) - await ws.send(json.dumps({"type": "heartbeat", "timestamp_ms": 0})) - await asyncio.sleep(0.1) - - asyncio.run(_send_bad()) - - -def test_mixed_message_sequence( - server: RerunWebSocketServer, publisher: MockViewerPublisher -) -> None: - """Realistic session: heartbeat, click, twist, stop — only the click produces a point.""" - received: list[Any] = [] - done = threading.Event() - unsub = server.clicked_point.subscribe(lambda point: (received.append(point), done.set())) - - publisher.send_click(7.0, 8.0, 9.0, "/map", timestamp_ms=1100) - publisher.send_twist(0.3, 0.0, 0.0, 0.0, 0.0, 0.2) - publisher.send_stop() - publisher.flush() - done.wait(timeout=2.0) - unsub() - - assert len(received) == 1 - assert received[0].x == pytest.approx(7.0) - assert received[0].y == pytest.approx(8.0) - assert received[0].z == pytest.approx(9.0) diff --git a/dimos/visualization/rerun/websocket_server.py b/dimos/visualization/rerun/websocket_server.py deleted file mode 100644 index 0c0ac2acf2..0000000000 --- a/dimos/visualization/rerun/websocket_server.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""WebSocket server module that receives events from dimos-viewer. - -When dimos-viewer is started with ``--connect``, LCM multicast is unavailable -across machines. The viewer falls back to sending click, twist, and stop events -as JSON over a WebSocket connection. This module acts as the server-side -counterpart: it listens for those connections and translates incoming messages -into DimOS stream publishes. - -Message format (newline-delimited JSON, ``"type"`` discriminant): - - {"type":"heartbeat","timestamp_ms":1234567890} - {"type":"click","x":1.0,"y":2.0,"z":3.0,"entity_path":"/world","timestamp_ms":1234567890} - {"type":"twist","linear_x":0.5,"linear_y":0.0,"linear_z":0.0, - "angular_x":0.0,"angular_y":0.0,"angular_z":0.8} - {"type":"stop"} -""" - -from __future__ import annotations - -import asyncio -import json -import logging -import socket -import threading -from typing import Any, Literal, TypedDict, Union - -import websockets -import websockets.asyncio.server as ws_server - -from dimos.core.core import rpc -from dimos.core.global_config import global_config -from dimos.core.module import Module, ModuleConfig -from dimos.core.stream import Out -from dimos.msgs.geometry_msgs.PointStamped import PointStamped -from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.utils.generic import get_local_ips -from dimos.utils.logging_config import setup_logger -from dimos.visualization.rerun.constants import RERUN_GRPC_PORT - -logger = setup_logger() - - -class ClickMsg(TypedDict): - type: 
Literal["click"] - x: float - y: float - z: float - entity_path: str - timestamp_ms: int - - -class TwistMsg(TypedDict): - type: Literal["twist"] - linear_x: float - linear_y: float - linear_z: float - angular_x: float - angular_y: float - angular_z: float - - -class StopMsg(TypedDict): - type: Literal["stop"] - - -class HeartbeatMsg(TypedDict): - type: Literal["heartbeat"] - timestamp_ms: int - - -ViewerMsg = Union[ClickMsg, TwistMsg, StopMsg, HeartbeatMsg] - - -def _handshake_noise_filter(record: logging.LogRecord) -> bool: - """Drop noisy "opening handshake failed" records from port scanners etc.""" - msg = record.getMessage() - return not ("opening handshake failed" in msg or "did not receive a valid HTTP request" in msg) - - -class Config(ModuleConfig): - host: str | None = None - port: int = 3030 - start_timeout: float = 10.0 - - -class RerunWebSocketServer(Module): - """Receives dimos-viewer WebSocket events and publishes them as DimOS streams. - - The viewer connects to this module (not the other way around) when running - in ``--connect`` mode. Each click event is converted to a ``PointStamped`` - and published on the ``clicked_point`` stream so downstream modules (e.g. - ``ReplanningAStarPlanner``) can consume it without modification. - - Outputs: - clicked_point: 3-D world-space point from the most recent viewer click. - tele_cmd_vel: Twist velocity commands from keyboard teleop, including stop events. - - Note: ``stop_movement`` is owned by ``MovementManager`` — it will fire - that signal when it sees the first teleop twist arrive here. 
- """ - - config: Config - - clicked_point: Out[PointStamped] - tele_cmd_vel: Out[Twist] - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._stop_event: asyncio.Event | None = None - self._server_ready = threading.Event() - self.host = self.config.host if self.config.host is not None else global_config.listen_host - - @rpc - def start(self) -> None: - super().start() - assert self._loop is not None - asyncio.run_coroutine_threadsafe(self._serve(), self._loop) - self._server_ready.wait(timeout=self.config.start_timeout) - self._log_connect_hints() - - @rpc - def stop(self) -> None: - self._server_ready.wait(timeout=self.config.start_timeout) - if self._loop is not None and not self._loop.is_closed() and self._stop_event is not None: - self._loop.call_soon_threadsafe(self._stop_event.set) - super().stop() - - def _log_connect_hints(self) -> None: - """Log full dimos-viewer commands that viewers can use to connect.""" - local_ips = get_local_ips() - hostname = socket.gethostname() - host = self.host - ws_url = f"ws://{host}:{self.config.port}/ws" - grpc_url = f"rerun+http://{host}:{RERUN_GRPC_PORT}/proxy" - - lines = [ - "", - "=" * 60, - f"RerunWebSocketServer listening on {ws_url}", - "", - "Connect a viewer:", - f" dimos-viewer --connect {grpc_url} --ws-url {ws_url}", - ] - if local_ips: - lines.append("") - lines.append("From another machine on the network:") - for ip, iface in local_ips: - remote_grpc = f"rerun+http://{ip}:{RERUN_GRPC_PORT}/proxy" - remote_ws = f"ws://{ip}:{self.config.port}/ws" - lines.append( - f" dimos-viewer --connect {remote_grpc} --ws-url {remote_ws} # {iface}" - ) - lines.append("") - lines.append(f" hostname: {hostname}") - lines.append("=" * 60) - lines.append("") - - logger.info("\n".join(lines)) - - async def _serve(self) -> None: - self._stop_event = asyncio.Event() - - ws_logger = logging.getLogger("websockets.server") - ws_logger.addFilter(_handshake_noise_filter) - - async with ws_server.serve( - 
self._handle_client, - host=self.host, - port=self.config.port, - ping_interval=30, - ping_timeout=30, - logger=ws_logger, - ): - self._server_ready.set() - await self._stop_event.wait() - - async def _handle_client(self, websocket: Any) -> None: - if hasattr(websocket, "request") and websocket.request.path != "/ws": - await websocket.close(1008, "Not Found") - return - addr = websocket.remote_address - logger.info(f"RerunWebSocketServer: viewer connected from {addr}") - try: - async for raw in websocket: - self._dispatch(raw) - except websockets.ConnectionClosed: - pass - - def _dispatch(self, raw: str | bytes) -> None: - try: - msg: dict[str, Any] = json.loads(raw) - except json.JSONDecodeError: - logger.warning(f"RerunWebSocketServer: ignoring non-JSON message: {raw!r}") - return - - if not isinstance(msg, dict): - return - - msg_type = msg.get("type") - - if msg_type == "click": - self.clicked_point.publish( - PointStamped( - x=float(msg.get("x", 0)), - y=float(msg.get("y", 0)), - z=float(msg.get("z", 0)), - ts=float(msg.get("timestamp_ms", 0)) / 1000.0, - frame_id=str(msg.get("entity_path", "")), - ) - ) - - elif msg_type == "twist": - self.tele_cmd_vel.publish( - Twist( - linear=Vector3( - float(msg.get("linear_x", 0)), - float(msg.get("linear_y", 0)), - float(msg.get("linear_z", 0)), - ), - angular=Vector3( - float(msg.get("angular_x", 0)), - float(msg.get("angular_y", 0)), - float(msg.get("angular_z", 0)), - ), - ) - ) - - elif msg_type == "stop": - self.tele_cmd_vel.publish(Twist.zero()) diff --git a/dimos/visualization/vis_module.py b/dimos/visualization/vis_module.py deleted file mode 100644 index badcba34db..0000000000 --- a/dimos/visualization/vis_module.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Shared visualization module factory for all robot blueprints.""" - -from typing import Any, get_args - -from dimos.core.coordination.blueprints import Blueprint, autoconnect -from dimos.visualization.rerun.constants import ViewerBackend - - -def vis_module( - viewer_backend: ViewerBackend, - rerun_config: dict[str, Any] | None = None, - foxglove_config: dict[str, Any] | None = None, -) -> Blueprint: - """Create a visualization blueprint based on the selected viewer backend. - - Bundles the appropriate viewer module (Rerun or Foxglove) together with - the ``WebsocketVisModule`` and ``RerunWebSocketServer`` so that the web - dashboard and remote viewer connections work out of the box. 
- - Example usage:: - - from dimos.core.global_config import global_config - viz = vis_module( - global_config.viewer, - rerun_config={ - "visual_override": { - "world/camera_info": lambda ci: ci.to_rerun(...), - }, - "static": { - "world/tf/base_link": lambda rr: [rr.Boxes3D(...)], - }, - }, - ) - """ - from dimos.visualization.rerun.websocket_server import RerunWebSocketServer - from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule - - if foxglove_config is None: - foxglove_config = {} - if rerun_config is None: - rerun_config = {} - - match viewer_backend: - case "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - return autoconnect( - FoxgloveBridge.blueprint(**foxglove_config), - RerunWebSocketServer.blueprint(), - WebsocketVisModule.blueprint(), - ) - case "rerun": - from dimos.core.global_config import global_config - from dimos.protocol.pubsub.impl.lcmpubsub import LCM - from dimos.visualization.rerun.bridge import RerunBridgeModule - - rerun_config = {**rerun_config} # copy (avoid mutation) - rerun_config.setdefault("pubsubs", [LCM()]) - rerun_config.setdefault("rerun_open", global_config.rerun_open) - rerun_config.setdefault("rerun_web", global_config.rerun_web) - return autoconnect( - RerunBridgeModule.blueprint( - **rerun_config, - ), - RerunWebSocketServer.blueprint(), - WebsocketVisModule.blueprint(), - ) - case "none": - return autoconnect(WebsocketVisModule.blueprint()) - case _: - valid = ", ".join(get_args(ViewerBackend)) - raise ValueError(f"Unknown viewer_backend {viewer_backend!r}. 
Expected one of: {valid}") diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 1ce7e74502..3d6b3df11c 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -105,7 +105,7 @@ class WebsocketVisModule(Module): gps_goal: Out[LatLon] explore_cmd: Out[Bool] stop_explore_cmd: Out[Bool] - tele_cmd_vel: Out[Twist] + cmd_vel: Out[Twist] movecmd_stamped: Out[TwistStamped] def __init__(self, **kwargs: Any) -> None: @@ -158,11 +158,9 @@ def start(self) -> None: self._uvicorn_server_thread = threading.Thread(target=self._run_uvicorn_server, daemon=True) self._uvicorn_server_thread.start() - # Auto-open the dashboard tab only when the user explicitly asked for a - # web-based viewer (rerun_open == "web" or "both"). `rerun_web` alone - # only means "serve the viewer"; it should not trigger a browser popup - # when the user chose the native viewer. - if self.config.g.viewer == "rerun" and self.config.g.rerun_open in ("web", "both"): + # Auto-open browser only for rerun-web (dashboard with Rerun iframe + command center) + # For rerun and foxglove, users access the command center manually if needed + if self.config.g.viewer == "rerun-web": url = f"http://localhost:{self.config.port}/" logger.info(f"Dimensional Command Center: {url}") @@ -238,13 +236,11 @@ def _create_server(self) -> None: async def serve_index(request): # type: ignore[no-untyped-def] """Serve appropriate HTML based on viewer mode.""" - # Serve the full dashboard (with Rerun iframe) only when the rerun - # web server is enabled; otherwise redirect to the standalone - # command center. 
- if not ( - self.config.g.viewer == "rerun" and self.config.g.rerun_open in ("web", "both") - ): + # If running native Rerun, redirect to standalone command center + if self.config.g.viewer != "rerun-web": return RedirectResponse(url="/command-center") + + # Otherwise serve full dashboard with Rerun iframe return FileResponse(_DASHBOARD_HTML, media_type="text/html") async def serve_command_center(request): # type: ignore[no-untyped-def] @@ -337,14 +333,14 @@ async def clear_gps_goals(sid: str) -> None: @self.sio.event # type: ignore[untyped-decorator] async def move_command(sid: str, data: dict[str, Any]) -> None: # Publish Twist if transport is configured - if self.tele_cmd_vel and self.tele_cmd_vel.transport: + if self.cmd_vel and self.cmd_vel.transport: twist = Twist( linear=Vector3(data["linear"]["x"], data["linear"]["y"], data["linear"]["z"]), angular=Vector3( data["angular"]["x"], data["angular"]["y"], data["angular"]["z"] ), ) - self.tele_cmd_vel.publish(twist) + self.cmd_vel.publish(twist) # Publish TwistStamped if transport is configured if self.movecmd_stamped and self.movecmd_stamped.transport: diff --git a/docs/development/conventions.md b/docs/development/conventions.md deleted file mode 100644 index 2b25a7c3c6..0000000000 --- a/docs/development/conventions.md +++ /dev/null @@ -1,12 +0,0 @@ -This mostly to track when conventions change (with regard to codebase updates) because this codebase is under heavy development. Note: this is a non-exhaustive list of conventions. - -- Instead of using `RerunBridge` in blueprints we always use `vis_module` which allows the CLI to control if its foxglove, rerun, or no-vis at all -- When global_config.py shouldn't accidentally/indirectly import heavy libraries like rerun. But sometimes global_config needs the type definition or default value from a module. 
Preferably we import from the module file directly, however when thats not possible, we create a config.py for just that module's config and import that into global_config.py. -- When adding visualization tools to a blueprint/autoconnect, instead of using RerunBridge or WebsocketVisModule directly we should always use `vis_module`, which right now should look something like `vis_module(viewer_backend=global_config.viewer, rerun_config={}),` -- `DEFAULT_THREAD_JOIN_TIMEOUT` is used for all thread.join timeouts -- Don't use print inside of tests -- Module configs should be specified as `config: ModuleSpecificConfigClass` -- To customize the way rerun renders something, right now we use a `rerun_config` dict. This will (hopefully) change very soon to be a per-module config instead of a per-blueprint config -- Similar to the `rerun_config` the `rrb` (rerun blueprint) is defined at a blueprint level right now, but ideally would be a per-module contribution with only a per-blueprint override of the layout. 
-- No `__init__.py` files -- Helper blueprints (like `_with_vis`) that should not be used on their own need to start with an underscore to avoid being picked up by the all_blueprints.py code generation step diff --git a/docs/usage/cli.md b/docs/usage/cli.md index bba73368b2..017b441c7e 100644 --- a/docs/usage/cli.md +++ b/docs/usage/cli.md @@ -18,9 +18,7 @@ dimos [GLOBAL OPTIONS] COMMAND [ARGS] | `--replay` / `--no-replay` | bool | `False` | Use recorded replay data | | `--replay-db` | TEXT | `go2_bigoffice` | Replay memory2 SQLite database name | | `--new-memory` / `--no-new-memory` | bool | `False` | Clear persistent memory on start | -| `--viewer` | `rerun\|foxglove\|none` | `rerun` | Visualization backend | -| `--rerun-open` | `native\|web\|both\|none` | `native` | How to open the Rerun viewer | -| `--rerun-web` / `--no-rerun-web` | bool | `False` | Serve the Rerun web viewer | +| `--viewer` | `rerun\|rerun-web\|rerun-connect\|foxglove\|none` | `rerun` | Visualization backend | | `--n-workers` | INT | `2` | Number of forkserver workers | | `--memory-limit` | TEXT | `auto` | Rerun viewer memory limit | | `--mcp-port` | INT | `9990` | MCP server port | diff --git a/docs/usage/visualization.md b/docs/usage/visualization.md index 9ece977a68..57ad460354 100644 --- a/docs/usage/visualization.md +++ b/docs/usage/visualization.md @@ -1,43 +1,37 @@ # Viewer Backends -Dimos supports three visualization backends: `rerun` (default), `foxglove`, and `none`. +Dimos supports three visualization backends: Rerun (web or native) and Foxglove. 
## Quick Start -Choose your viewer via the CLI: +Choose your viewer via the CLI (preferred): ```bash # Rerun native viewer (default) - dimos-viewer with built-in teleop + click-to-navigate dimos run unitree-go2 -# Explicitly select the viewer backend: +# Explicitly select the viewer mode: dimos --viewer rerun run unitree-go2 +dimos --viewer rerun-web run unitree-go2 dimos --viewer foxglove run unitree-go2 -dimos --viewer none run unitree-go2 ``` -Control how the Rerun viewer opens with `--rerun-open` and `--rerun-web`: +Alternative (environment variable): ```bash -# Open native desktop viewer (default) -dimos --rerun-open native run unitree-go2 - -# Open web viewer in browser -dimos --rerun-open web run unitree-go2 - -# Open both native and web -dimos --rerun-open both run unitree-go2 +# Rerun native viewer (default) - dimos-viewer with built-in teleop + click-to-navigate +VIEWER=rerun dimos run unitree-go2 -# No viewer (headless) — data still accessible via gRPC -dimos --rerun-open none run unitree-go2 +# Rerun web viewer - browser dashboard + teleop at http://localhost:7779 +VIEWER=rerun-web dimos run unitree-go2 -# Serve the web viewer without auto-opening a browser -dimos --rerun-web --rerun-open native run unitree-go2 +# Foxglove - Use Foxglove Studio instead of Rerun +VIEWER=foxglove dimos run unitree-go2 ``` ## Viewer Modes Explained -### Rerun Native (`rerun`, `--rerun-open native`) — Default +### Rerun Native (`rerun`) — Default **What you get:** - [dimos-viewer](https://github.com/dimensionalOS/dimos-viewer), a custom Dimensional fork of Rerun with built-in keyboard teleop and click-to-navigate @@ -47,7 +41,7 @@ dimos --rerun-web --rerun-open native run unitree-go2 --- -### Rerun Web (`rerun`, `--rerun-open web`) +### Rerun Web (`rerun-web`) **What you get:** - Browser-based dashboard at http://localhost:7779 @@ -69,16 +63,18 @@ dimos --rerun-web --rerun-open native run unitree-go2 ## Rendering with Custom Blueprints -To enable visualization in your own 
blueprint, use `vis_module`: +To enable rerun within your own blueprint simply include `RerunBridgeModule`: ```python -from dimos.core.global_config import global_config -from dimos.visualization.vis_module import vis_module +from dimos.visualization.rerun.bridge import RerunBridgeModule from dimos.hardware.sensors.camera.module import CameraModule +from dimos.protocol.pubsub.impl.lcmpubsub import LCM camera_demo = autoconnect( CameraModule.blueprint(), - vis_module(viewer_backend=global_config.viewer), + RerunBridgeModule.blueprint( + viewer_mode="native", # native (desktop), web (browser), none (headless) + ), ) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index d5d9b2fa22..074670cb60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,7 @@ dependencies = [ # TODO: rerun shouldn't be required but rn its in core (there is NO WAY to use dimos without rerun rn) # remove this once rerun is optional in core "rerun-sdk>=0.20.0", - "dimos-viewer==0.30.0a6.dev99", + "dimos-viewer>=0.30.0a2", "toolz>=1.1.0", "protobuf>=6.33.5,<7", "psutil>=7.0.0", diff --git a/uv.lock b/uv.lock index dfbc569f8a..6e94931740 100644 --- a/uv.lock +++ b/uv.lock @@ -1993,7 +1993,7 @@ requires-dist = [ { name = "dimos", extras = ["unitree"], marker = "extra == 'unitree-dds'" }, { name = "dimos-lcm" }, { name = "dimos-lcm", marker = "extra == 'docker'" }, - { name = "dimos-viewer", specifier = "==0.30.0a6.dev99" }, + { name = "dimos-viewer", specifier = ">=0.30.0a2" }, { name = "dimos-viewer", marker = "extra == 'visualization'", specifier = ">=0.30.0a4" }, { name = "drake", marker = "platform_machine != 'aarch64' and sys_platform == 'darwin' and extra == 'manipulation'", specifier = "==1.45.0" }, { name = "drake", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and extra == 'manipulation'", specifier = ">=1.40.0" }, @@ -2164,18 +2164,18 @@ wheels = [ [[package]] name = "dimos-viewer" -version = "0.30.0a6.dev99" -source = { registry = 
"https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3d/0e/d363be05f172bafe5f41a95db318891637e902c50edfdc642edec6bb5111/dimos_viewer-0.30.0a6.dev99-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cfa57e68e8f4094d4a38d202414046fd2419ff2875ace3f16b8581c3106feca4", size = 35405401, upload-time = "2026-04-17T04:19:10.126Z" }, - { url = "https://files.pythonhosted.org/packages/e7/ab/0730fed402b3b92e35194f11b76119754d619fa6bab00a1932b5c78f87b3/dimos_viewer-0.30.0a6.dev99-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:f3bc243342131c8c2b653cc6b76f04d65aad525f5560829b78aa1a7d31a9d375", size = 39167146, upload-time = "2026-04-17T04:19:14.177Z" }, - { url = "https://files.pythonhosted.org/packages/bb/d9/1415d5d7e609d69b05e8e1167a66dd7cb78f3933205f9b321ae18233384c/dimos_viewer-0.30.0a6.dev99-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b954083fcb8951641554fdea95425b3b5ac9415cd1b65410a137d38d3dd57b8a", size = 41536165, upload-time = "2026-04-17T04:19:17.379Z" }, - { url = "https://files.pythonhosted.org/packages/93/7c/7ee6049a753c01ccbe8357f9c5f789378103b87331e5ca7977f05adf5c42/dimos_viewer-0.30.0a6.dev99-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0387201efd1260f968853f0d7863876b6db375b2af15b22f221a893fcce6549c", size = 35405408, upload-time = "2026-04-17T04:19:20.08Z" }, - { url = "https://files.pythonhosted.org/packages/de/2e/9b4252a12c4b641ab1479a6a4d3d576e75fc42ca2a797d88e2e0626abda0/dimos_viewer-0.30.0a6.dev99-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:a0fae6f2077fc6ceb25e1ed33fb7ccf183ef3e2a30456aa5462b953c1419e547", size = 39167138, upload-time = "2026-04-17T04:19:23.292Z" }, - { url = "https://files.pythonhosted.org/packages/46/2a/4bd02c3d79df2aefc5be47afda6b95121937cef0a3f6b15d071691ec3ca7/dimos_viewer-0.30.0a6.dev99-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e844015f3ad193d50201c39abd3e3f34abbf03adbfb1075468696c1236df1409", size = 41536172, upload-time = 
"2026-04-17T04:19:26.421Z" }, - { url = "https://files.pythonhosted.org/packages/1b/b1/efcea9b9e21c4ab75e2df016a27e5045e30d91a494465ab0cc627d8d8bc3/dimos_viewer-0.30.0a6.dev99-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dc82061c2c025684c0fbed5392f793d137b1b0fc3aa1b601988bf4d2ee88aa27", size = 35405409, upload-time = "2026-04-17T04:19:29.574Z" }, - { url = "https://files.pythonhosted.org/packages/2d/8e/d482b0b9379c40ddd7547600543ce726fc3b5d10e396a876f22b2d76d0e6/dimos_viewer-0.30.0a6.dev99-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:0f6acfa0de3083e746ac43fe0d0a328d624bcb859dc698b1bbc592f444f52f15", size = 39167144, upload-time = "2026-04-17T04:19:32.301Z" }, - { url = "https://files.pythonhosted.org/packages/6d/eb/08922721c74ceaa99a824258db02c438d50f77c22ff80332cbc4b1a8db7b/dimos_viewer-0.30.0a6.dev99-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:56fa9139c49ec4bf96b12d6e98d3de3319a66876374ae57bda4534ab7a347765", size = 41536171, upload-time = "2026-04-17T04:19:35.29Z" }, +version = "0.30.0a6" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/90/ad6d0e1e177a10a0b4f7e736436b6d2741acaeb402ab59504347236744f4/dimos_viewer-0.30.0a6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e623a21e6992e263513847e12809a0d234d73fc7af42a6428e84ca165ba682d0", size = 35309553, upload-time = "2026-03-18T15:22:26.874Z" }, + { url = "https://files.pythonhosted.org/packages/a1/84/1c8f41ff2bd5b6ee143eb6119107397dac284fa4f1f8335623c498bd1d9c/dimos_viewer-0.30.0a6-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:36068a3293cb1c7f4db9f4e6c9fea2d7dd2a2527025f803585f4d3aaad9aedbd", size = 39072034, upload-time = "2026-03-18T15:22:29.592Z" }, + { url = "https://files.pythonhosted.org/packages/58/e6/d6214245e5b99e1da262d037f52d3d39c6b87c65acb516fb08f11378e932/dimos_viewer-0.30.0a6-cp310-cp310-manylinux_2_28_x86_64.whl", hash = 
"sha256:2bf36e8c8bd9dd822bedd1cb2d80ee2bf74b58184ba33872494baed0395fa7ff", size = 41447599, upload-time = "2026-03-18T15:22:32.699Z" }, + { url = "https://files.pythonhosted.org/packages/48/04/80f566400776cab9af68b4a3c0132f55786acd1641ea39d8b75e797a2e22/dimos_viewer-0.30.0a6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:947cfa10c583b357d589c10cb466c63b3651a83d1013a254c0ba03fc2959bef7", size = 35309552, upload-time = "2026-03-18T15:22:35.395Z" }, + { url = "https://files.pythonhosted.org/packages/4c/c3/72157e0806951c2c71c70dcd783e27be8d694344d7ecdb94eaef1066cf99/dimos_viewer-0.30.0a6-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:53ca4ac1f0778f1d9afb317b6268c941c02b20af86dd2aaaf1ea79f2c1d1eeb8", size = 39072018, upload-time = "2026-03-18T15:22:38.043Z" }, + { url = "https://files.pythonhosted.org/packages/2f/92/959fc1e9cdcb5fd8d793b2c8515a6086c9f913ba470baad1f3182ae4c242/dimos_viewer-0.30.0a6-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:27e108060a942c92f7869a0e45693dfe1798896bd90cbac6d1ce019a682f8ba7", size = 41447647, upload-time = "2026-03-18T15:22:41.003Z" }, + { url = "https://files.pythonhosted.org/packages/ab/d6/d76763b60d82539e92777500551116306cfea462f6976ad814a3bdf57e1d/dimos_viewer-0.30.0a6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f4f49f973c51055cfd594b68a8e9d183c706f94b1513b6b69db900d05850f741", size = 35309553, upload-time = "2026-03-18T15:22:43.681Z" }, + { url = "https://files.pythonhosted.org/packages/26/ab/6ea7686c467caecdc74dd8d3a0267053ac74229b3afebc64cff180d5074c/dimos_viewer-0.30.0a6-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:791ef1c1d8d41db69a7d2b701ed3f0b6bc39cb3264aaef7300eddb576c8df7ed", size = 39072062, upload-time = "2026-03-18T15:22:46.264Z" }, + { url = "https://files.pythonhosted.org/packages/3c/87/fce7aac56d8a234d3d7c0911928bb3471d7852e35263b966d2aac5be42cd/dimos_viewer-0.30.0a6-cp312-cp312-manylinux_2_28_x86_64.whl", hash = 
"sha256:dd976c39c38718b8373e1894d55b78c10bcb8c5716c8dbd5fba59141bc08ab3c", size = 41447667, upload-time = "2026-03-18T15:22:49.214Z" }, ] [[package]] From 852fdac3c4bebf97e9cabd3356bb5ce66493eba6 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 28 Apr 2026 18:54:50 -0700 Subject: [PATCH 21/30] added trimesh and pycollada dependency to toml --- pyproject.toml | 4 ++++ uv.lock | 29 +++++++++++++++++++---------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 074670cb60..398903f457 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -214,6 +214,10 @@ manipulation = [ "pyrealsense2; sys_platform != 'darwin'", "xarm-python-sdk>=1.17.0", + # Mesh conversion (STL/DAE → OBJ for Drake collision geometry) + "trimesh", + "pycollada", + # Visualization (Optional) "kaleido>=0.2.1", "plotly>=5.9.0", diff --git a/uv.lock b/uv.lock index 6e94931740..9052f42804 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'darwin'", @@ -1848,8 +1848,10 @@ manipulation = [ { name = "matplotlib" }, { name = "piper-sdk" }, { name = "plotly" }, + { name = "pycollada" }, { name = "pyrealsense2", marker = "sys_platform != 'darwin'" }, { name = "pyyaml" }, + { name = "trimesh" }, { name = "xacro" }, { name = "xarm-python-sdk" }, ] @@ -2061,6 +2063,7 @@ requires-dist = [ { name = "psutil", specifier = ">=7.0.0" }, { name = "psycopg2-binary", marker = "extra == 'psql'", specifier = ">=2.9.11" }, { name = "py-spy", marker = "extra == 'dev'" }, + { name = "pycollada", marker = "extra == 'manipulation'" }, { name = "pydantic" }, { name = "pydantic", marker = "extra == 'docker'" }, { name = "pydantic-settings", specifier = ">=2.11.0,<3" }, @@ -2112,6 +2115,7 @@ requires-dist = [ { name = "toolz", specifier = ">=1.1.0" }, { name = "torchreid", marker = "extra == 'misc'", specifier = "==0.2.5" }, { name = 
"transformers", extras = ["torch"], marker = "extra == 'perception'", specifier = "==4.49.0" }, + { name = "trimesh", marker = "extra == 'manipulation'" }, { name = "typeguard", marker = "extra == 'misc'" }, { name = "typer", specifier = ">=0.19.2,<1" }, { name = "typer", marker = "extra == 'docker'", specifier = ">=0.19.2,<1" }, @@ -5613,7 +5617,6 @@ resolution-markers = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/82/6c/90d3f532f608a03a13c1d6c16c266ffa3828e8011b1549d3b61db2ad59f5/nvidia_cublas_cu12-12.9.1.4-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:7a950dae01add3b415a5a5cdc4ec818fb5858263e9cca59004bb99fdbbd3a5d6", size = 575006342, upload-time = "2025-06-05T20:04:16.902Z" }, - { url = "https://files.pythonhosted.org/packages/77/3c/aa88abe01f3be3d1f8f787d1d33dc83e76fec05945f9a28fbb41cfb99cd5/nvidia_cublas_cu12-12.9.1.4-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:453611eb21a7c1f2c2156ed9f3a45b691deda0440ec550860290dc901af5b4c2", size = 581242350, upload-time = "2025-06-05T20:04:51.979Z" }, { url = "https://files.pythonhosted.org/packages/45/a1/a17fade6567c57452cfc8f967a40d1035bb9301db52f27808167fbb2be2f/nvidia_cublas_cu12-12.9.1.4-py3-none-win_amd64.whl", hash = "sha256:1e5fee10662e6e52bd71dec533fbbd4971bb70a5f24f3bc3793e5c2e9dc640bf", size = 553153899, upload-time = "2025-06-05T20:13:35.556Z" }, ] @@ -5672,7 +5675,6 @@ resolution-markers = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/e0/0279bd94539fda525e0c8538db29b72a5a8495b0c12173113471d28bce78/nvidia_cuda_runtime_cu12-12.9.79-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:83469a846206f2a733db0c42e223589ab62fd2fabac4432d2f8802de4bded0a4", size = 3515012, upload-time = "2025-06-05T20:00:35.519Z" }, - { url = "https://files.pythonhosted.org/packages/bc/46/a92db19b8309581092a3add7e6fceb4c301a3fd233969856a8cbf042cd3c/nvidia_cuda_runtime_cu12-12.9.79-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:25bba2dfb01d48a9b59ca474a1ac43c6ebf7011f1b0b8cc44f54eb6ac48a96c3", size = 3493179, upload-time = "2025-06-05T20:00:53.735Z" }, { url = "https://files.pythonhosted.org/packages/59/df/e7c3a360be4f7b93cee39271b792669baeb3846c58a4df6dfcf187a7ffab/nvidia_cuda_runtime_cu12-12.9.79-py3-none-win_amd64.whl", hash = "sha256:8e018af8fa02363876860388bd10ccb89eb9ab8fb0aa749aaf58430a9f7c4891", size = 3591604, upload-time = "2025-06-05T20:11:17.036Z" }, ] @@ -7381,6 +7383,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/d8/a211b3f85e99a0daa2ddec96c949cac6824bd305b040571b82a03dd62636/pycodestyle-2.12.1-py2.py3-none-any.whl", hash = "sha256:46f0fb92069a7c28ab7bb558f05bfc0110dac69a0cd23c61ea0040283a9d78b3", size = 31284, upload-time = "2024-08-04T20:26:53.173Z" }, ] +[[package]] +name = "pycollada" +version = "0.9.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "python-dateutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5a/8d/52a5364a17eb96129962cae8d3ee7658775e085ad0ba38388684ad5944e9/pycollada-0.9.3.tar.gz", hash = "sha256:c34d6dcf0fe2eba5896f71c96d37a1c0fe1a61f08440fa0cfcec3dc2895d3302", size = 110826, upload-time = "2026-01-24T15:45:23.625Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/86/f1f61b7a0701f9d1299e5293d083318019f91021a4d449f94d59dbe024e4/pycollada-0.9.3-py3-none-any.whl", hash = "sha256:636e6496f60987586db82455ea7bbd9ade775e8181c6590c83b698b6cd53a9f5", size = 129206, upload-time = "2026-01-24T15:45:22.182Z" }, +] + [[package]] name = "pycparser" version = "3.0" @@ -9929,19 +9945,12 @@ name = "triton" version = "3.6.0" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/44/ba/b1b04f4b291a3205d95ebd24465de0e5bf010a2df27a4e58a9b5f039d8f2/triton-3.6.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c723cfb12f6842a0ae94ac307dba7e7a44741d720a40cf0e270ed4a4e3be781", size = 175972180, upload-time = "2026-01-20T16:15:53.664Z" }, { url = "https://files.pythonhosted.org/packages/8c/f7/f1c9d3424ab199ac53c2da567b859bcddbb9c9e7154805119f8bd95ec36f/triton-3.6.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6550fae429e0667e397e5de64b332d1e5695b73650ee75a6146e2e902770bea", size = 188105201, upload-time = "2026-01-20T16:00:29.272Z" }, - { url = "https://files.pythonhosted.org/packages/0f/2c/96f92f3c60387e14cc45aed49487f3486f89ea27106c1b1376913c62abe4/triton-3.6.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49df5ef37379c0c2b5c0012286f80174fcf0e073e5ade1ca9a86c36814553651", size = 176081190, upload-time = "2026-01-20T16:16:00.523Z" }, { url = "https://files.pythonhosted.org/packages/e0/12/b05ba554d2c623bffa59922b94b0775673de251f468a9609bc9e45de95e9/triton-3.6.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8e323d608e3a9bfcc2d9efcc90ceefb764a82b99dea12a86d643c72539ad5d3", size = 188214640, upload-time = "2026-01-20T16:00:35.869Z" }, - { url = "https://files.pythonhosted.org/packages/17/5d/08201db32823bdf77a0e2b9039540080b2e5c23a20706ddba942924ebcd6/triton-3.6.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:374f52c11a711fd062b4bfbb201fd9ac0a5febd28a96fb41b4a0f51dde3157f4", size = 176128243, upload-time = "2026-01-20T16:16:07.857Z" }, { url = "https://files.pythonhosted.org/packages/ab/a8/cdf8b3e4c98132f965f88c2313a4b493266832ad47fb52f23d14d4f86bb5/triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca", size = 188266850, upload-time = 
"2026-01-20T16:00:43.041Z" }, - { url = "https://files.pythonhosted.org/packages/3c/12/34d71b350e89a204c2c7777a9bba0dcf2f19a5bfdd70b57c4dbc5ffd7154/triton-3.6.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:448e02fe6dc898e9e5aa89cf0ee5c371e99df5aa5e8ad976a80b93334f3494fd", size = 176133521, upload-time = "2026-01-20T16:16:13.321Z" }, { url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, - { url = "https://files.pythonhosted.org/packages/ce/4e/41b0c8033b503fd3cfcd12392cdd256945026a91ff02452bef40ec34bee7/triton-3.6.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1722e172d34e32abc3eb7711d0025bb69d7959ebea84e3b7f7a341cd7ed694d6", size = 176276087, upload-time = "2026-01-20T16:16:18.989Z" }, { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = "2026-01-20T16:00:56.042Z" }, - { url = "https://files.pythonhosted.org/packages/49/55/5ecf0dcaa0f2fbbd4420f7ef227ee3cb172e91e5fede9d0ecaddc43363b4/triton-3.6.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef5523241e7d1abca00f1d240949eebdd7c673b005edbbce0aca95b8191f1d43", size = 176138577, upload-time = "2026-01-20T16:16:25.426Z" }, { url = "https://files.pythonhosted.org/packages/df/3d/9e7eee57b37c80cec63322c0231bb6da3cfe535a91d7a4d64896fcb89357/triton-3.6.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a17a5d5985f0ac494ed8a8e54568f092f7057ef60e1b0fa09d3fd1512064e803", size = 
188273063, upload-time = "2026-01-20T16:01:07.278Z" }, - { url = "https://files.pythonhosted.org/packages/48/db/56ee649cab5eaff4757541325aca81f52d02d4a7cd3506776cad2451e060/triton-3.6.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b3a97e8ed304dfa9bd23bb41ca04cdf6b2e617d5e782a8653d616037a5d537d", size = 176274804, upload-time = "2026-01-20T16:16:31.528Z" }, { url = "https://files.pythonhosted.org/packages/f6/56/6113c23ff46c00aae423333eb58b3e60bdfe9179d542781955a5e1514cb3/triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7", size = 188397994, upload-time = "2026-01-20T16:01:14.236Z" }, ] From 7d613a7c2107ba9fb21ae88e4a9c493a2a5cc0be Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 28 Apr 2026 19:05:18 -0700 Subject: [PATCH 22/30] openarm_description added to lfs --- data/.lfs/openarm_description.tar.gz | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 data/.lfs/openarm_description.tar.gz diff --git a/data/.lfs/openarm_description.tar.gz b/data/.lfs/openarm_description.tar.gz new file mode 100644 index 0000000000..b4cf5e04e9 --- /dev/null +++ b/data/.lfs/openarm_description.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e5dcbc94024986a46414b81486c3e88883847db75dd6ee90dc5fa6b88536b20f +size 70064686 From e88990f83a32c2610614089fe575b3985fd7718f Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 28 Apr 2026 19:05:39 -0700 Subject: [PATCH 23/30] openarm adapter updated to use LFS path --- dimos/hardware/manipulators/openarm/adapter.py | 13 +++++++------ dimos/robot/catalog/openarm.py | 18 +++++++----------- dimos/robot/manipulators/openarm/blueprints.py | 4 ++-- 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/dimos/hardware/manipulators/openarm/adapter.py b/dimos/hardware/manipulators/openarm/adapter.py index d80a8c1bf2..9ec7f60d30 100644 --- 
a/dimos/hardware/manipulators/openarm/adapter.py +++ b/dimos/hardware/manipulators/openarm/adapter.py @@ -18,6 +18,8 @@ import time from pathlib import Path + +from dimos.utils.data import LfsPath from typing import TYPE_CHECKING, Any import numpy as np @@ -60,7 +62,8 @@ def _socketcan_iface_up(name: str) -> bool: (0x06, MotorType.DM4310), # joint6 (0x07, MotorType.DM4310), # joint7 ] -_OPENARM_V10_GRIPPER_MOTOR: tuple[int, MotorType] = (0x08, MotorType.DM4310) +# Gripper (motor id 0x08, DM4310) is on the bus but not currently wired up +# through the adapter — see the gripper-write methods which return None/False. # Physical joint limits (measured). Joints 1 & 2 are mirrored between sides. _V10_POS_LOWER_LEFT = [-3.45, -3.30, -1.50, -0.01, -1.50, -0.75, -1.50] @@ -91,11 +94,9 @@ class OpenArmAdapter: interface: python-can backend; "virtual" for unit tests """ - # Per-side URDFs for Pinocchio gravity model - _REPO_ROOT = Path(__file__).resolve().parents[4] - _OPENARM_PKG = _REPO_ROOT / "data" / "openarm_description" - _URDF_LEFT = _OPENARM_PKG / "urdf" / "robot" / "openarm_v10_left.urdf" - _URDF_RIGHT = _OPENARM_PKG / "urdf" / "robot" / "openarm_v10_right.urdf" + # Per-side URDFs for Pinocchio gravity model (LFS-backed) + _URDF_LEFT = LfsPath("openarm_description/urdf/robot/openarm_v10_left.urdf") + _URDF_RIGHT = LfsPath("openarm_description/urdf/robot/openarm_v10_right.urdf") def __init__( self, diff --git a/dimos/robot/catalog/openarm.py b/dimos/robot/catalog/openarm.py index afc25db838..b6271ea836 100644 --- a/dimos/robot/catalog/openarm.py +++ b/dimos/robot/catalog/openarm.py @@ -16,10 +16,10 @@ from __future__ import annotations -from pathlib import Path from typing import Any from dimos.robot.config import RobotConfig +from dimos.utils.data import LfsPath # Collision exclusion pairs — structural mesh overlaps in the OpenArm URDF. 
# link5 and link7 collision meshes overlap by ~3mm at zero pose (and every @@ -29,20 +29,16 @@ ("openarm_right_link5", "openarm_right_link7"), ] -# Local path during bring-up. Swap to ``LfsPath("openarm_description/...")`` -# once the URDF is migrated to LFS. -_REPO_ROOT = Path(__file__).resolve().parents[3] -_OPENARM_PKG = _REPO_ROOT / "data" / "openarm_description" -_OPENARM_MODEL_PATH = _OPENARM_PKG / "urdf" / "robot" / "openarm_v10_bimanual.urdf" +# LFS-backed: data/.lfs/openarm_description.tar.gz extracts to data/openarm_description/ +_OPENARM_PKG = LfsPath("openarm_description") +_OPENARM_MODEL_PATH = _OPENARM_PKG / "urdf/robot/openarm_v10_bimanual.urdf" # Per-side URDFs: extracted from bimanual expansion, only one arm + torso each. # Avoids phantom-arm collisions when Drake loads both sides into one world. -_OPENARM_LEFT_MODEL = _OPENARM_PKG / "urdf" / "robot" / "openarm_v10_left.urdf" -_OPENARM_RIGHT_MODEL = _OPENARM_PKG / "urdf" / "robot" / "openarm_v10_right.urdf" +_OPENARM_LEFT_MODEL = _OPENARM_PKG / "urdf/robot/openarm_v10_left.urdf" +_OPENARM_RIGHT_MODEL = _OPENARM_PKG / "urdf/robot/openarm_v10_right.urdf" # Pre-expanded single-arm URDF for Pinocchio FK (keyboard teleop, IK, etc.) -# Pinocchio doesn't handle xacro — this file is the expansion of the bimanual -# xacro with only one side's links kept. 
-OPENARM_V10_FK_MODEL = _OPENARM_PKG / "urdf" / "robot" / "openarm_v10_single.urdf" +OPENARM_V10_FK_MODEL = _OPENARM_PKG / "urdf/robot/openarm_v10_single.urdf" def openarm_arm( diff --git a/dimos/robot/manipulators/openarm/blueprints.py b/dimos/robot/manipulators/openarm/blueprints.py index 13070d9fcb..d0075677d9 100644 --- a/dimos/robot/manipulators/openarm/blueprints.py +++ b/dimos/robot/manipulators/openarm/blueprints.py @@ -159,7 +159,7 @@ _teleop_cfg = _openarm_single(name="arm") keyboard_teleop_openarm_mock = autoconnect( - KeyboardTeleopModule.blueprint(model_path=str(OPENARM_V10_FK_MODEL), ee_joint_id=_teleop_cfg.dof), + KeyboardTeleopModule.blueprint(model_path=OPENARM_V10_FK_MODEL, ee_joint_id=_teleop_cfg.dof), ControlCoordinator.blueprint( tick_rate=100.0, publish_joint_state=True, @@ -191,7 +191,7 @@ _teleop_hw_cfg = _openarm_single(name="arm", adapter_type="openarm", address=LEFT_CAN) keyboard_teleop_openarm = autoconnect( - KeyboardTeleopModule.blueprint(model_path=str(OPENARM_V10_FK_MODEL), ee_joint_id=_teleop_hw_cfg.dof), + KeyboardTeleopModule.blueprint(model_path=OPENARM_V10_FK_MODEL, ee_joint_id=_teleop_hw_cfg.dof), ControlCoordinator.blueprint( tick_rate=100.0, publish_joint_state=True, From f98fb2235bea4c78c045b820c22b3488d83f2624 Mon Sep 17 00:00:00 2001 From: leshy Date: Fri, 24 Apr 2026 21:06:34 +0300 Subject: [PATCH 24/30] Feat/memory2 - plotting, examples, recorder module, semantic search (#1769) Co-authored-by: RD <63036454+ruthwikdasyam@users.noreply.github.com> --- dimos/memory2/vis/plot/elements.py | 4 ++-- dimos/memory2/vis/plot/plot.py | 4 ++-- dimos/perception/detection/module2D.py | 3 +-- dimos/perception/detection/type/imageDetections.py | 3 +-- uv.lock | 11 ++++++++++- 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/dimos/memory2/vis/plot/elements.py b/dimos/memory2/vis/plot/elements.py index 7f83de2b94..8b5932da53 100644 --- a/dimos/memory2/vis/plot/elements.py +++ b/dimos/memory2/vis/plot/elements.py @@ -17,11 
+17,11 @@ from __future__ import annotations from dataclasses import dataclass -from enum import Enum +from enum import StrEnum from typing import Union -class Style(str, Enum): +class Style(StrEnum): """Line style for Series and HLine elements. Values match matplotlib's `linestyle` names so they pass through directly diff --git a/dimos/memory2/vis/plot/plot.py b/dimos/memory2/vis/plot/plot.py index 082b147125..6235e44bda 100644 --- a/dimos/memory2/vis/plot/plot.py +++ b/dimos/memory2/vis/plot/plot.py @@ -16,13 +16,13 @@ from __future__ import annotations -from enum import Enum +from enum import StrEnum from typing import Any from dimos.memory2.vis.plot.elements import HLine, Markers, PlotElement, Series, VLine -class TimeAxis(str, Enum): +class TimeAxis(StrEnum): """How the x-axis is formatted. - ``raw``: unix timestamps as-is (matplotlib's default numeric formatter). diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index 3f9aee84e4..fb07e02d3c 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -78,8 +78,7 @@ def process_image_frame(self, image: Image) -> ImageDetections2D: imageDetections = self.detector.process_image(image) if not self.config.filter: return imageDetections - filtered: ImageDetections2D = imageDetections.filter(*self.config.filter) - return filtered + return imageDetections.filter(*self.config.filter) @simple_mcache def sharp_image_stream(self) -> Observable[Image]: diff --git a/dimos/perception/detection/type/imageDetections.py b/dimos/perception/detection/type/imageDetections.py index 1eea0c9c3c..6820b6210d 100644 --- a/dimos/perception/detection/type/imageDetections.py +++ b/dimos/perception/detection/type/imageDetections.py @@ -16,8 +16,7 @@ from functools import reduce from operator import add -import sys -from typing import TYPE_CHECKING, Generic, TypeVar +from typing import TYPE_CHECKING, Generic, Self, TypeVar if sys.version_info >= (3, 11): from 
typing import Self diff --git a/uv.lock b/uv.lock index 9052f42804..f429ae7b71 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'darwin'", @@ -5617,6 +5617,7 @@ resolution-markers = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/82/6c/90d3f532f608a03a13c1d6c16c266ffa3828e8011b1549d3b61db2ad59f5/nvidia_cublas_cu12-12.9.1.4-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:7a950dae01add3b415a5a5cdc4ec818fb5858263e9cca59004bb99fdbbd3a5d6", size = 575006342, upload-time = "2025-06-05T20:04:16.902Z" }, + { url = "https://files.pythonhosted.org/packages/77/3c/aa88abe01f3be3d1f8f787d1d33dc83e76fec05945f9a28fbb41cfb99cd5/nvidia_cublas_cu12-12.9.1.4-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:453611eb21a7c1f2c2156ed9f3a45b691deda0440ec550860290dc901af5b4c2", size = 581242350, upload-time = "2025-06-05T20:04:51.979Z" }, { url = "https://files.pythonhosted.org/packages/45/a1/a17fade6567c57452cfc8f967a40d1035bb9301db52f27808167fbb2be2f/nvidia_cublas_cu12-12.9.1.4-py3-none-win_amd64.whl", hash = "sha256:1e5fee10662e6e52bd71dec533fbbd4971bb70a5f24f3bc3793e5c2e9dc640bf", size = 553153899, upload-time = "2025-06-05T20:13:35.556Z" }, ] @@ -5675,6 +5676,7 @@ resolution-markers = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/e0/0279bd94539fda525e0c8538db29b72a5a8495b0c12173113471d28bce78/nvidia_cuda_runtime_cu12-12.9.79-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:83469a846206f2a733db0c42e223589ab62fd2fabac4432d2f8802de4bded0a4", size = 3515012, upload-time = "2025-06-05T20:00:35.519Z" }, + { url = "https://files.pythonhosted.org/packages/bc/46/a92db19b8309581092a3add7e6fceb4c301a3fd233969856a8cbf042cd3c/nvidia_cuda_runtime_cu12-12.9.79-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = 
"sha256:25bba2dfb01d48a9b59ca474a1ac43c6ebf7011f1b0b8cc44f54eb6ac48a96c3", size = 3493179, upload-time = "2025-06-05T20:00:53.735Z" }, { url = "https://files.pythonhosted.org/packages/59/df/e7c3a360be4f7b93cee39271b792669baeb3846c58a4df6dfcf187a7ffab/nvidia_cuda_runtime_cu12-12.9.79-py3-none-win_amd64.whl", hash = "sha256:8e018af8fa02363876860388bd10ccb89eb9ab8fb0aa749aaf58430a9f7c4891", size = 3591604, upload-time = "2025-06-05T20:11:17.036Z" }, ] @@ -9945,12 +9947,19 @@ name = "triton" version = "3.6.0" source = { registry = "https://pypi.org/simple" } wheels = [ + { url = "https://files.pythonhosted.org/packages/44/ba/b1b04f4b291a3205d95ebd24465de0e5bf010a2df27a4e58a9b5f039d8f2/triton-3.6.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c723cfb12f6842a0ae94ac307dba7e7a44741d720a40cf0e270ed4a4e3be781", size = 175972180, upload-time = "2026-01-20T16:15:53.664Z" }, { url = "https://files.pythonhosted.org/packages/8c/f7/f1c9d3424ab199ac53c2da567b859bcddbb9c9e7154805119f8bd95ec36f/triton-3.6.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6550fae429e0667e397e5de64b332d1e5695b73650ee75a6146e2e902770bea", size = 188105201, upload-time = "2026-01-20T16:00:29.272Z" }, + { url = "https://files.pythonhosted.org/packages/0f/2c/96f92f3c60387e14cc45aed49487f3486f89ea27106c1b1376913c62abe4/triton-3.6.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49df5ef37379c0c2b5c0012286f80174fcf0e073e5ade1ca9a86c36814553651", size = 176081190, upload-time = "2026-01-20T16:16:00.523Z" }, { url = "https://files.pythonhosted.org/packages/e0/12/b05ba554d2c623bffa59922b94b0775673de251f468a9609bc9e45de95e9/triton-3.6.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8e323d608e3a9bfcc2d9efcc90ceefb764a82b99dea12a86d643c72539ad5d3", size = 188214640, upload-time = "2026-01-20T16:00:35.869Z" }, + { url = 
"https://files.pythonhosted.org/packages/17/5d/08201db32823bdf77a0e2b9039540080b2e5c23a20706ddba942924ebcd6/triton-3.6.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:374f52c11a711fd062b4bfbb201fd9ac0a5febd28a96fb41b4a0f51dde3157f4", size = 176128243, upload-time = "2026-01-20T16:16:07.857Z" }, { url = "https://files.pythonhosted.org/packages/ab/a8/cdf8b3e4c98132f965f88c2313a4b493266832ad47fb52f23d14d4f86bb5/triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca", size = 188266850, upload-time = "2026-01-20T16:00:43.041Z" }, + { url = "https://files.pythonhosted.org/packages/3c/12/34d71b350e89a204c2c7777a9bba0dcf2f19a5bfdd70b57c4dbc5ffd7154/triton-3.6.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:448e02fe6dc898e9e5aa89cf0ee5c371e99df5aa5e8ad976a80b93334f3494fd", size = 176133521, upload-time = "2026-01-20T16:16:13.321Z" }, { url = "https://files.pythonhosted.org/packages/f9/0b/37d991d8c130ce81a8728ae3c25b6e60935838e9be1b58791f5997b24a54/triton-3.6.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10c7f76c6e72d2ef08df639e3d0d30729112f47a56b0c81672edc05ee5116ac9", size = 188289450, upload-time = "2026-01-20T16:00:49.136Z" }, + { url = "https://files.pythonhosted.org/packages/ce/4e/41b0c8033b503fd3cfcd12392cdd256945026a91ff02452bef40ec34bee7/triton-3.6.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1722e172d34e32abc3eb7711d0025bb69d7959ebea84e3b7f7a341cd7ed694d6", size = 176276087, upload-time = "2026-01-20T16:16:18.989Z" }, { url = "https://files.pythonhosted.org/packages/35/f8/9c66bfc55361ec6d0e4040a0337fb5924ceb23de4648b8a81ae9d33b2b38/triton-3.6.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d002e07d7180fd65e622134fbd980c9a3d4211fb85224b56a0a0efbd422ab72f", size = 188400296, upload-time = 
"2026-01-20T16:00:56.042Z" }, + { url = "https://files.pythonhosted.org/packages/49/55/5ecf0dcaa0f2fbbd4420f7ef227ee3cb172e91e5fede9d0ecaddc43363b4/triton-3.6.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef5523241e7d1abca00f1d240949eebdd7c673b005edbbce0aca95b8191f1d43", size = 176138577, upload-time = "2026-01-20T16:16:25.426Z" }, { url = "https://files.pythonhosted.org/packages/df/3d/9e7eee57b37c80cec63322c0231bb6da3cfe535a91d7a4d64896fcb89357/triton-3.6.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a17a5d5985f0ac494ed8a8e54568f092f7057ef60e1b0fa09d3fd1512064e803", size = 188273063, upload-time = "2026-01-20T16:01:07.278Z" }, + { url = "https://files.pythonhosted.org/packages/48/db/56ee649cab5eaff4757541325aca81f52d02d4a7cd3506776cad2451e060/triton-3.6.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b3a97e8ed304dfa9bd23bb41ca04cdf6b2e617d5e782a8653d616037a5d537d", size = 176274804, upload-time = "2026-01-20T16:16:31.528Z" }, { url = "https://files.pythonhosted.org/packages/f6/56/6113c23ff46c00aae423333eb58b3e60bdfe9179d542781955a5e1514cb3/triton-3.6.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46bd1c1af4b6704e554cad2eeb3b0a6513a980d470ccfa63189737340c7746a7", size = 188397994, upload-time = "2026-01-20T16:01:14.236Z" }, ] From 9a8d00ae780cfd3230194bc380931e9ab61aabe9 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Mon, 27 Apr 2026 13:33:09 -0500 Subject: [PATCH 25/30] Jeff/fix/rconnect2 (#1784) --- .gitignore | 3 + dimos/core/coordination/python_worker.py | 16 +- dimos/core/docker_module.py | 2 +- dimos/core/global_config.py | 11 +- dimos/hardware/sensors/camera/module.py | 5 +- .../lidar/fastlio2/fastlio_blueprints.py | 35 ++- .../sensors/lidar/livox/livox_blueprints.py | 4 +- dimos/manipulation/blueprints.py | 10 +- dimos/manipulation/grasping/demo_grasping.py | 4 +- .../wavefront_frontier_goal_selector.py | 11 + 
dimos/navigation/replanning_a_star/module.py | 18 +- .../movement_manager/movement_manager.py | 133 +++++++++ .../movement_manager/test_movement_manager.py | 117 ++++++++ .../demo_object_scene_registration.py | 4 +- dimos/robot/all_blueprints.py | 2 + dimos/robot/cli/dimos.py | 48 +++- .../drone/blueprints/basic/drone_basic.py | 17 +- .../blueprints/perceptive/unitree_g1_shm.py | 10 +- .../primitive/uintree_g1_primitive_no_nav.py | 19 +- .../agentic/unitree_go2_security.py | 4 +- .../go2/blueprints/basic/unitree_go2_basic.py | 34 +-- .../go2/blueprints/basic/unitree_go2_fleet.py | 6 +- .../unitree_go2_webrtc_keyboard_teleop.py | 4 + .../go2/blueprints/smart/unitree_go2.py | 6 +- dimos/robot/unitree/keyboard_teleop.py | 10 +- dimos/robot/unitree/mujoco_connection.py | 16 +- dimos/simulation/unity/blueprint.py | 4 +- dimos/teleop/quest/blueprints.py | 4 +- dimos/test_no_sections.py | 2 + dimos/utils/generic.py | 17 ++ dimos/visualization/rerun/bridge.py | 253 +++++++++--------- dimos/visualization/rerun/conftest.py | 45 ++++ dimos/visualization/rerun/constants.py | 31 +++ .../visualization/rerun/test_viewer_ws_e2e.py | 201 ++++++++++++++ .../rerun/test_websocket_server.py | 210 +++++++++++++++ dimos/visualization/rerun/websocket_server.py | 244 +++++++++++++++++ dimos/visualization/vis_module.py | 87 ++++++ .../web/websocket_vis/websocket_vis_module.py | 24 +- docs/development/conventions.md | 12 + docs/usage/cli.md | 4 +- docs/usage/visualization.md | 42 +-- pyproject.toml | 2 +- uv.lock | 26 +- 43 files changed, 1465 insertions(+), 292 deletions(-) create mode 100644 dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py create mode 100644 dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py create mode 100644 dimos/visualization/rerun/conftest.py create mode 100644 dimos/visualization/rerun/constants.py create mode 100644 dimos/visualization/rerun/test_viewer_ws_e2e.py create mode 100644 
dimos/visualization/rerun/test_websocket_server.py create mode 100644 dimos/visualization/rerun/websocket_server.py create mode 100644 dimos/visualization/vis_module.py create mode 100644 docs/development/conventions.md diff --git a/.gitignore b/.gitignore index 1816510c08..ea68926e96 100644 --- a/.gitignore +++ b/.gitignore @@ -74,6 +74,9 @@ CLAUDE.MD /.mcp.json *.speedscope.json +# Hidden/personal directories +.hidden/ + # Coverage htmlcov/ .coverage diff --git a/dimos/core/coordination/python_worker.py b/dimos/core/coordination/python_worker.py index 3c434a982e..6c3aab3a2d 100644 --- a/dimos/core/coordination/python_worker.py +++ b/dimos/core/coordination/python_worker.py @@ -18,6 +18,7 @@ import multiprocessing from multiprocessing.connection import Connection import os +import signal import sys import threading import traceback @@ -337,12 +338,15 @@ class _WorkerState: def _worker_entrypoint(conn: Connection, worker_id: int) -> None: apply_library_config() + # Ignore SIGINT so the coordinator can orchestrate shutdown via the pipe. + # Without this, workers race with the coordinator: they start tearing down + # modules locally while the coordinator tries to send stop() RPCs, causing + # BrokenPipeErrors. 
+ signal.signal(signal.SIGINT, signal.SIG_IGN) state = _WorkerState(instances={}, worker_id=worker_id) try: _worker_loop(conn, state) - except KeyboardInterrupt: - logger.info("Worker got KeyboardInterrupt.", worker_id=worker_id) except Exception as e: logger.error(f"Worker process error: {e}", exc_info=True) finally: @@ -361,12 +365,6 @@ def _worker_entrypoint(conn: Connection, worker_id: int) -> None: worker_id=worker_id, module_id=module_id, ) - except KeyboardInterrupt: - logger.warning( - "KeyboardInterrupt during worker stop", - module=type(instance).__name__, - worker_id=worker_id, - ) except Exception: logger.error("Error during worker shutdown", exc_info=True) @@ -433,7 +431,7 @@ def _worker_loop(conn: Connection, state: _WorkerState) -> None: if not conn.poll(timeout=0.1): continue request = conn.recv() - except (EOFError, KeyboardInterrupt): + except EOFError: break try: diff --git a/dimos/core/docker_module.py b/dimos/core/docker_module.py index 3ad9620556..f82a1b56db 100644 --- a/dimos/core/docker_module.py +++ b/dimos/core/docker_module.py @@ -30,7 +30,7 @@ from dimos.core.rpc_client import ModuleProxyProtocol, RpcCall from dimos.protocol.rpc.pubsubrpc import LCMRPC from dimos.utils.logging_config import setup_logger -from dimos.visualization.rerun.bridge import RERUN_GRPC_PORT, RERUN_WEB_PORT +from dimos.visualization.rerun.constants import RERUN_GRPC_PORT, RERUN_WEB_PORT if TYPE_CHECKING: from collections.abc import Callable diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py index 214401959e..435f421dd1 100644 --- a/dimos/core/global_config.py +++ b/dimos/core/global_config.py @@ -13,13 +13,16 @@ # limitations under the License. 
import re -from typing import Literal, TypeAlias from pydantic_settings import BaseSettings, SettingsConfigDict from dimos.models.vl.types import VlModelName - -ViewerBackend: TypeAlias = Literal["rerun", "rerun-web", "rerun-connect", "foxglove", "none"] +from dimos.visualization.rerun.constants import ( + RERUN_ENABLE_WEB, + RERUN_OPEN_DEFAULT, + RerunOpenOption, + ViewerBackend, +) def _get_all_numbers(s: str) -> list[float]: @@ -37,6 +40,8 @@ class GlobalConfig(BaseSettings): replay_db: str = "go2_bigoffice" new_memory: bool = False viewer: ViewerBackend = "rerun" + rerun_open: RerunOpenOption = RERUN_OPEN_DEFAULT + rerun_web: bool = RERUN_ENABLE_WEB n_workers: int = 2 memory_limit: str = "auto" mujoco_camera_position: str | None = None diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index 9b4f50920c..0fe0d8f030 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -21,6 +21,7 @@ from dimos.agents.annotation import skill from dimos.core.coordination.blueprints import autoconnect from dimos.core.core import rpc +from dimos.core.global_config import global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.hardware.sensors.camera.spec import CameraHardware @@ -31,7 +32,7 @@ from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.spec import perception -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module def default_transform() -> Transform: @@ -120,5 +121,5 @@ def stop(self) -> None: demo_camera = autoconnect( CameraModule.blueprint(), - RerunBridgeModule.blueprint(), + vis_module(viewer_backend=global_config.viewer), ) diff --git a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py index 
2946f1d247..2c2a64d61e 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py +++ b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py @@ -15,30 +15,45 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 from dimos.mapping.voxels import VoxelGridMapper -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module voxel_size = 0.05 mid360_fastlio = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=-1), - RerunBridgeModule.blueprint(), + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/lidar": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + }, + }, + ), ).global_config(n_workers=2, robot_model="mid360_fastlio2") mid360_fastlio_voxels = autoconnect( FastLio2.blueprint(), VoxelGridMapper.blueprint(voxel_size=voxel_size, carve_columns=False), - RerunBridgeModule.blueprint( - visual_override={ - "world/lidar": None, - } + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + "world/lidar": None, + }, + }, ), ).global_config(n_workers=3, robot_model="mid360_fastlio2_voxels") mid360_fastlio_voxels_native = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=3.0), - RerunBridgeModule.blueprint( - visual_override={ - "world/lidar": None, - } + vis_module( + "rerun", + rerun_config={ + "visual_override": { + "world/lidar": None, + "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), + }, + }, ), ).global_config(n_workers=2, robot_model="mid360_fastlio2") diff --git a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py index 34ebc33c2a..e437d73994 100644 --- a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py 
+++ b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py @@ -14,9 +14,9 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.hardware.sensors.lidar.livox.module import Mid360 -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module mid360 = autoconnect( Mid360.blueprint(), - RerunBridgeModule.blueprint(), + vis_module("rerun"), ).global_config(n_workers=2, robot_model="mid360") diff --git a/dimos/manipulation/blueprints.py b/dimos/manipulation/blueprints.py index f950ea8efa..1c006c1d04 100644 --- a/dimos/manipulation/blueprints.py +++ b/dimos/manipulation/blueprints.py @@ -44,7 +44,7 @@ from dimos.msgs.sensor_msgs.JointState import JointState from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule from dimos.robot.catalog.ufactory import xarm6 as _catalog_xarm6, xarm7 as _catalog_xarm7 -from dimos.robot.foxglove_bridge import FoxgloveBridge # TODO: migrate to rerun +from dimos.visualization.vis_module import vis_module # Single XArm6 planner (standalone, no coordinator) _xarm6_planner_cfg = _catalog_xarm6( @@ -196,14 +196,14 @@ use_aabb=True, max_obstacle_width=0.06, ), - FoxgloveBridge.blueprint(), # TODO: migrate to rerun + vis_module("foxglove"), ) .transports( { ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), } ) - .global_config(viewer="foxglove", n_workers=4) + .global_config(n_workers=4) ) @@ -289,7 +289,7 @@ from dimos.robot.catalog.ufactory import XARM7_SIM_PATH from dimos.simulation.engines.mujoco_sim_module import MujocoSimModule -from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode +from dimos.visualization.rerun.bridge import RerunBridgeModule _xarm7_sim_cfg = _catalog_xarm7( name="arm", @@ -323,7 +323,7 @@ hardware=[_xarm7_sim_cfg.to_hardware_component()], tasks=[_xarm7_sim_cfg.to_task_config()], ), - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode()), + 
RerunBridgeModule.blueprint(), ).transports( { ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), diff --git a/dimos/manipulation/grasping/demo_grasping.py b/dimos/manipulation/grasping/demo_grasping.py index 37e1d38f1e..4a1d4b2cf6 100644 --- a/dimos/manipulation/grasping/demo_grasping.py +++ b/dimos/manipulation/grasping/demo_grasping.py @@ -22,7 +22,7 @@ from dimos.manipulation.grasping.grasping import GraspingModule from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.visualization.vis_module import vis_module camera_module = RealSenseCamera.blueprint(enable_pointcloud=False) @@ -44,7 +44,7 @@ ("/tmp", "/tmp", "rw") ], # Grasp visualization debug standalone: python -m dimos.manipulation.grasping.visualize_grasps ), - FoxgloveBridge.blueprint(), + vis_module("foxglove"), McpServer.blueprint(), McpClient.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index b8dbe0dfc8..338d10d9b0 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -115,6 +115,7 @@ class WavefrontFrontierExplorer(Module): goal_reached: In[Bool] explore_cmd: In[Bool] stop_explore_cmd: In[Bool] + stop_movement: In[Bool] # LCM outputs goal_request: Out[PoseStamped] @@ -171,6 +172,10 @@ def start(self) -> None: unsub = self.stop_explore_cmd.subscribe(self._on_stop_explore_cmd) self.register_disposable(Disposable(unsub)) + if self.stop_movement.transport is not None: + unsub = self.stop_movement.subscribe(self._on_stop_movement) + self.register_disposable(Disposable(unsub)) + @rpc def stop(self) -> None: self.stop_exploration() 
@@ -201,6 +206,12 @@ def _on_stop_explore_cmd(self, msg: Bool) -> None: logger.info("Received exploration stop command via LCM") self.stop_exploration() + def _on_stop_movement(self, msg: Bool) -> None: + """Handle stop movement from teleop — cancel active exploration.""" + if msg.data and self.exploration_active: + logger.info("WavefrontFrontierExplorer: stop_movement received, stopping exploration") + self.stop_exploration() + def _count_costmap_information(self, costmap: OccupancyGrid) -> int: """ Count the amount of information in a costmap (free space + obstacles). diff --git a/dimos/navigation/replanning_a_star/module.py b/dimos/navigation/replanning_a_star/module.py index 2375af20ce..efc16b52d6 100644 --- a/dimos/navigation/replanning_a_star/module.py +++ b/dimos/navigation/replanning_a_star/module.py @@ -28,6 +28,9 @@ from dimos.msgs.nav_msgs.Path import Path from dimos.navigation.base import NavigationInterface, NavigationState from dimos.navigation.replanning_a_star.global_planner import GlobalPlanner +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() class ReplanningAStarPlanner(Module, NavigationInterface): @@ -36,10 +39,11 @@ class ReplanningAStarPlanner(Module, NavigationInterface): goal_request: In[PoseStamped] clicked_point: In[PointStamped] target: In[PoseStamped] + stop_movement: In[Bool] goal_reached: Out[Bool] navigation_state: Out[String] # TODO: set it - cmd_vel: Out[Twist] + nav_cmd_vel: Out[Twist] path: Out[Path] navigation_costmap: Out[OccupancyGrid] @@ -72,9 +76,14 @@ def start(self) -> None: ) ) + if self.stop_movement.transport is not None: + self.register_disposable( + Disposable(self.stop_movement.subscribe(self._on_stop_movement)) + ) + self.register_disposable(self._planner.path.subscribe(self.path.publish)) - self.register_disposable(self._planner.cmd_vel.subscribe(self.cmd_vel.publish)) + self.register_disposable(self._planner.cmd_vel.subscribe(self.nav_cmd_vel.publish)) 
self.register_disposable(self._planner.goal_reached.subscribe(self.goal_reached.publish)) @@ -92,6 +101,11 @@ def stop(self) -> None: super().stop() + def _on_stop_movement(self, msg: Bool) -> None: + if msg.data: + logger.info("ReplanningAStarPlanner: stop_movement received, cancelling goal") + self.cancel_goal() + @rpc def set_goal(self, goal: PoseStamped) -> bool: self._planner.handle_goal_request(goal) diff --git a/dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py b/dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py new file mode 100644 index 0000000000..5a2dd195c0 --- /dev/null +++ b/dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py @@ -0,0 +1,133 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""MovementManager: click-to-goal relay + teleop/nav velocity mux.""" + +from __future__ import annotations + +import math +import threading +import time +from typing import Any + +from dimos_lcm.std_msgs import Bool # type: ignore[import-untyped] +from reactivex.disposable import Disposable + +from dimos.core.core import rpc +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import In, Out +from dimos.msgs.geometry_msgs.PointStamped import PointStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class MovementManagerConfig(ModuleConfig): + tele_cooldown_sec: float = 1.0 + tele_cmd_vel_scaling: Twist = Twist(Vector3(1, 1, 1), Vector3(1, 1, 1)) + + +class MovementManager(Module): + """Combine tele_cmd_vel (keyboard controls) and nav_cmd_vel in a sane way, output cmd_vel""" + + config: MovementManagerConfig + + clicked_point: In[PointStamped] + nav_cmd_vel: In[Twist] + tele_cmd_vel: In[Twist] + + goal: Out[PointStamped] + way_point: Out[PointStamped] + cmd_vel: Out[Twist] + stop_movement: Out[Bool] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._lock = threading.Lock() + self._teleop_active = False + self._last_teleop_time = 0.0 + + @rpc + def start(self) -> None: + super().start() + self.register_disposable(Disposable(self.clicked_point.subscribe(self._on_click))) + self.register_disposable(Disposable(self.nav_cmd_vel.subscribe(self._on_nav))) + self.register_disposable(Disposable(self.tele_cmd_vel.subscribe(self._on_teleop))) + + @rpc + def stop(self) -> None: + with self._lock: + self._teleop_active = False + super().stop() + + def _on_click(self, msg: PointStamped) -> None: + if not all(math.isfinite(v) for v in (msg.x, msg.y, msg.z)): + logger.warning("Ignored invalid click", x=msg.x, y=msg.y, z=msg.z) + return + if abs(msg.x) > 500 or abs(msg.y) > 500 or 
abs(msg.z) > 50: + logger.warning("Ignored out-of-range click", x=msg.x, y=msg.y, z=msg.z) + return + + logger.debug("Goal", x=round(msg.x, 1), y=round(msg.y, 1), z=round(msg.z, 1)) + self.way_point.publish(msg) + self.goal.publish(msg) + + def _cancel_goal(self) -> None: + self.stop_movement.publish(Bool(data=True)) + # NOTE: this NaN goal is more of a safety fallback. + # It can be REALLY bad if a robot is supposed to stop moving but wont + # we should probably think a more robust/strict requirement on planners + cancel = PointStamped( + ts=time.time(), frame_id="map", x=float("nan"), y=float("nan"), z=float("nan") + ) + self.way_point.publish(cancel) + self.goal.publish(cancel) + logger.debug("Navigation cancelled — waiting for new goal") + + def _on_nav(self, msg: Twist) -> None: + with self._lock: + if self._teleop_active: + # check if cooldown has expired + elapsed = time.monotonic() - self._last_teleop_time + if elapsed < self.config.tele_cooldown_sec: + return + self._teleop_active = False + self.cmd_vel.publish(msg) + + def _on_teleop(self, msg: Twist) -> None: + with self._lock: + was_active = self._teleop_active + self._teleop_active = True + self._last_teleop_time = time.monotonic() + + if not was_active: + self._cancel_goal() + logger.info("Teleop active") + + scale = self.config.tele_cmd_vel_scaling + scaled = Twist( + linear=Vector3( + msg.linear.x * scale.linear.x, + msg.linear.y * scale.linear.y, + msg.linear.z * scale.linear.z, + ), + angular=Vector3( + msg.angular.x * scale.angular.x, + msg.angular.y * scale.angular.y, + msg.angular.z * scale.angular.z, + ), + ) + self.cmd_vel.publish(scaled) diff --git a/dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py b/dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py new file mode 100644 index 0000000000..6858055605 --- /dev/null +++ b/dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py @@ -0,0 +1,117 @@ +# Copyright 2026 
Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for MovementManager: click-to-goal + teleop/nav velocity mux.""" + +from __future__ import annotations + +import math +import time +from unittest.mock import MagicMock + +import pytest + +from dimos.msgs.geometry_msgs.PointStamped import PointStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.navigation.smart_nav.modules.movement_manager.movement_manager import ( + MovementManager, +) + + +@pytest.fixture() +def manager() -> MovementManager: + """Create a real MovementManager and mock the publish methods on its output streams.""" + module = MovementManager(tele_cooldown_sec=0.1) + module.cmd_vel.publish = MagicMock() + module.stop_movement.publish = MagicMock() + module.goal.publish = MagicMock() + module.way_point.publish = MagicMock() + yield module + module._close_module() + + +def _twist(lx: float = 0.0) -> Twist: + return Twist(linear=Vector3(lx, 0, 0), angular=Vector3(0, 0, 0)) + + +def _click(x: float = 1.0, y: float = 2.0, z: float = 0.0) -> PointStamped: + return PointStamped(ts=time.time(), frame_id="map", x=x, y=y, z=z) + + +def test_teleop_suppresses_nav_and_cancels_goal(manager: MovementManager) -> None: + """Teleop arriving should suppress nav, publish stop_movement, and cancel the goal with NaN.""" + manager.config.tele_cooldown_sec = 10.0 + manager._on_teleop(_twist(lx=0.3)) + + # Nav is suppressed + 
manager.cmd_vel.publish.reset_mock() # type: ignore[union-attr] + manager._on_nav(_twist(lx=0.9)) + manager.cmd_vel.publish.assert_not_called() # type: ignore[union-attr] + + # stop_movement fired + manager.stop_movement.publish.assert_called_once() # type: ignore[union-attr] + + # Goal cancelled with NaN + cancel_msg = manager.goal.publish.call_args[0][0] # type: ignore[union-attr] + assert math.isnan(cancel_msg.x) + + +def test_nav_resumes_after_cooldown(manager: MovementManager) -> None: + """After the cooldown expires, nav commands pass through again.""" + manager.config.tele_cooldown_sec = 0.05 + manager._on_teleop(_twist(lx=0.3)) + time.sleep(0.1) + manager.cmd_vel.publish.reset_mock() # type: ignore[union-attr] + + manager._on_nav(_twist(lx=0.9)) + manager.cmd_vel.publish.assert_called_once() # type: ignore[union-attr] + + +def test_valid_click_publishes_goal(manager: MovementManager) -> None: + """A valid click should publish to both goal and way_point.""" + click = _click(x=5.0, y=3.0, z=0.1) + manager._on_click(click) + manager.goal.publish.assert_called_once_with(click) # type: ignore[union-attr] + manager.way_point.publish.assert_called_once_with(click) # type: ignore[union-attr] + + +def test_invalid_clicks_rejected(manager: MovementManager) -> None: + """NaN, Inf, and out-of-range clicks should not publish.""" + for bad_click in [ + _click(x=float("nan")), + _click(x=float("inf")), + _click(x=600.0), + ]: + manager._on_click(bad_click) + manager.goal.publish.assert_not_called() # type: ignore[union-attr] + + +def test_tele_cmd_vel_scaling() -> None: + """tele_cmd_vel_scaling multiplies each teleop twist component independently.""" + scaling = Twist(Vector3(0.5, 2.0, 0.0), Vector3(1.0, 1.0, 0.25)) + module = MovementManager(tele_cooldown_sec=10.0, tele_cmd_vel_scaling=scaling) + module.cmd_vel.publish = MagicMock() + module.stop_movement.publish = MagicMock() + module.goal.publish = MagicMock() + module.way_point.publish = MagicMock() + + 
module._on_teleop(Twist(Vector3(1, 1, 1), Vector3(1, 1, 1))) + + published = module.cmd_vel.publish.call_args[0][0] # type: ignore[union-attr] + assert published.linear.x == pytest.approx(0.5) + assert published.linear.y == pytest.approx(2.0) + assert published.linear.z == pytest.approx(0.0) + assert published.angular.z == pytest.approx(0.25) + module._close_module() diff --git a/dimos/perception/demo_object_scene_registration.py b/dimos/perception/demo_object_scene_registration.py index c9b489f54b..28044dec13 100644 --- a/dimos/perception/demo_object_scene_registration.py +++ b/dimos/perception/demo_object_scene_registration.py @@ -20,7 +20,7 @@ from dimos.hardware.sensors.camera.zed.compat import ZEDCamera from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.visualization.vis_module import vis_module camera_choice = "zed" @@ -34,7 +34,7 @@ demo_object_scene_registration = autoconnect( camera_module, ObjectSceneRegistrationModule.blueprint(target_frame="world", prompt_mode=YoloePromptMode.LRPC), - FoxgloveBridge.blueprint(), + vis_module("foxglove"), McpServer.blueprint(), McpClient.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index c794a67124..11b2ceb731 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -153,6 +153,7 @@ "mock-b1-connection-module": "dimos.robot.unitree.b1.connection.MockB1ConnectionModule", "module-a": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleA", "module-b": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleB", + "movement-manager": "dimos.navigation.smart_nav.modules.movement_manager.movement_manager.MovementManager", "mujoco-sim-module": "dimos.simulation.engines.mujoco_sim_module.MujocoSimModule", "navigation-module": 
"dimos.robot.unitree.rosnav.NavigationModule", "navigation-skill-container": "dimos.agents.skills.navigation.NavigationSkillContainer", @@ -175,6 +176,7 @@ "reid-module": "dimos.perception.detection.reid.module.ReidModule", "replanning-a-star-planner": "dimos.navigation.replanning_a_star.module.ReplanningAStarPlanner", "rerun-bridge-module": "dimos.visualization.rerun.bridge.RerunBridgeModule", + "rerun-web-socket-server": "dimos.visualization.rerun.websocket_server.RerunWebSocketServer", "ros-nav": "dimos.navigation.rosnav.ROSNav", "security-module": "dimos.experimental.security_demo.security_module.SecurityModule", "semantic-search": "dimos.memory2.module.SemanticSearch", diff --git a/dimos/robot/cli/dimos.py b/dimos/robot/cli/dimos.py index 37d1bd2be0..e99553c2b3 100644 --- a/dimos/robot/cli/dimos.py +++ b/dimos/robot/cli/dimos.py @@ -21,10 +21,11 @@ import json import os from pathlib import Path +import signal import sys import time import types -from typing import TYPE_CHECKING, Any, Union, get_args, get_origin +from typing import TYPE_CHECKING, Any, Union, cast, get_args, get_origin import click from dotenv import load_dotenv @@ -38,7 +39,10 @@ from dimos.core.daemon import daemonize, install_signal_handlers from dimos.core.global_config import GlobalConfig, global_config from dimos.core.run_registry import get_most_recent, is_pid_alive, stop_entry +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.protocol.service.lcmservice import autoconf from dimos.utils.logging_config import setup_logger +from dimos.visualization.rerun.constants import RerunOpenOption if TYPE_CHECKING: from dimos.core.coordination.blueprints import Blueprint, BlueprintAtom @@ -222,6 +226,10 @@ def run( cli_config_overrides: dict[str, Any] = ctx.obj + # this is a workaround until we have a proper way to have delayed-module-choice in blueprints + # ex: vis_module(viewer=global_config.viewer) is WRONG (viewer will always be default value) without this patch + 
global_config.update(**cli_config_overrides) + # Clean stale registry entries stale = cleanup_stale() if stale: @@ -660,17 +668,43 @@ def send( @main.command(name="rerun-bridge") def rerun_bridge_cmd( - viewer_mode: str = typer.Option( - "native", help="Viewer mode: native (desktop), web (browser), none (headless)" - ), memory_limit: str = typer.Option( "25%", help="Memory limit for Rerun viewer (e.g., '4GB', '16GB', '25%')" ), + rerun_open: str = typer.Option( + "native", help="How to open Rerun: one of native, web, both, none" + ), + rerun_web: bool = typer.Option( + True, "--rerun-web/--no-rerun-web", help="Enable/Disable Rerun web server" + ), ) -> None: - """Launch the Rerun visualization bridge.""" - from dimos.visualization.rerun.bridge import run_bridge + """Launch the Rerun visualization bridge. + + Standalone utility: runs the bridge directly in the main process (no + blueprint / worker pool) so users can attach a viewer to existing LCM + traffic without building a full module graph. + """ + # Deferred: RerunBridgeModule pulls in the rerun package (~1s), keep it + # out of the CLI's hot path so `dimos --help` stays fast. 
+ from dimos.visualization.rerun.bridge import RerunBridgeModule + + valid = get_args(RerunOpenOption) + if rerun_open not in valid: + raise typer.BadParameter( + f"rerun_open must be one of {valid}, got {rerun_open!r}", param_hint="--rerun-open" + ) + autoconf(check_only=True) + + bridge = RerunBridgeModule( + memory_limit=memory_limit, + rerun_open=cast("RerunOpenOption", rerun_open), + rerun_web=rerun_web, + pubsubs=[LCM()], + ) + bridge.start() - run_bridge(viewer_mode=viewer_mode, memory_limit=memory_limit) + signal.signal(signal.SIGINT, lambda *_: bridge.stop()) + signal.pause() if __name__ == "__main__": diff --git a/dimos/robot/drone/blueprints/basic/drone_basic.py b/dimos/robot/drone/blueprints/basic/drone_basic.py index c1838d6ac7..aaf82f6355 100644 --- a/dimos/robot/drone/blueprints/basic/drone_basic.py +++ b/dimos/robot/drone/blueprints/basic/drone_basic.py @@ -20,10 +20,9 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.core.global_config import global_config -from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.robot.drone.camera_module import DroneCameraModule from dimos.robot.drone.connection_module import DroneConnectionModule -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.visualization.vis_module import vis_module def _static_drone_body(rr: Any) -> list[Any]: @@ -60,23 +59,12 @@ def _drone_rerun_blueprint() -> Any: _rerun_config = { "blueprint": _drone_rerun_blueprint, - "pubsubs": [LCM()], "static": { "world/tf/base_link": _static_drone_body, }, } -# Conditional visualization -if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - _vis = FoxgloveBridge.blueprint() -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - - _vis = RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **_rerun_config) -else: - _vis = autoconnect() +_vis = 
vis_module(global_config.viewer, rerun_config=_rerun_config) # Determine connection string based on replay flag connection_string = "udp:0.0.0.0:14550" @@ -92,7 +80,6 @@ def _drone_rerun_blueprint() -> Any: outdoor=False, ), DroneCameraModule.blueprint(camera_intrinsics=[1000.0, 1000.0, 960.0, 540.0]), - WebsocketVisModule.blueprint(), ) __all__ = [ diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py index dd135a60a1..4941abad38 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py @@ -17,10 +17,11 @@ from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core.coordination.blueprints import autoconnect +from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image -from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1 import unitree_g1 +from dimos.visualization.vis_module import vis_module unitree_g1_shm = autoconnect( unitree_g1.transports( @@ -30,10 +31,9 @@ ), } ), - FoxgloveBridge.blueprint( - shm_channels=[ - "/color_image#sensor_msgs.Image", - ] + vis_module( + viewer_backend=global_config.viewer, + foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, ), ) diff --git a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py index b04443732f..eeabea7909 100644 --- a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py +++ b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py @@ -40,8 +40,7 @@ from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( WavefrontFrontierExplorer, ) -from dimos.protocol.pubsub.impl.lcmpubsub import LCM -from 
dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.visualization.vis_module import vis_module def _convert_camera_info(camera_info: Any) -> Any: @@ -94,7 +93,6 @@ def _g1_rerun_blueprint() -> Any: rerun_config = { "blueprint": _g1_rerun_blueprint, - "pubsubs": [LCM()], "visual_override": { "world/camera_info": _convert_camera_info, "world/navigation_costmap": _convert_navigation_costmap, @@ -104,18 +102,7 @@ def _g1_rerun_blueprint() -> Any: }, } -if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - _with_vis = autoconnect(FoxgloveBridge.blueprint()) -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - - _with_vis = autoconnect( - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config) - ) -else: - _with_vis = autoconnect() +_with_vis = vis_module(viewer_backend=global_config.viewer, rerun_config=rerun_config) def _create_webcam() -> Webcam: @@ -150,8 +137,6 @@ def _create_webcam() -> Webcam: VoxelGridMapper.blueprint(), CostMapper.blueprint(), WavefrontFrontierExplorer.blueprint(), - # Visualization - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_g1") .transports( diff --git a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py index be9e04a7fd..4b39a106b8 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py @@ -18,7 +18,7 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_agentic import unitree_go2_agentic -from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode +from dimos.visualization.rerun.bridge import 
RerunBridgeModule def _convert_camera_info(camera_info: Any) -> Any: @@ -85,7 +85,7 @@ def _go2_rerun_blueprint() -> Any: unitree_go2_security = autoconnect( unitree_go2_agentic, - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), + RerunBridgeModule.blueprint(**rerun_config), ) __all__ = ["unitree_go2_security"] diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py index 54a2c0f7c6..4f86ccb0a3 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py @@ -22,10 +22,9 @@ from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image -from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator from dimos.robot.unitree.go2.connection import GO2Connection -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.visualization.vis_module import vis_module # Mac has some issue with high bandwidth UDP, so we use pSHMTransport for color_image # actually we can use pSHMTransport for all platforms, and for all streams @@ -99,9 +98,6 @@ def _go2_rerun_blueprint() -> Any: rerun_config = { "blueprint": _go2_rerun_blueprint, - # any pubsub that supports subscribe_all and topic that supports str(topic) - # is acceptable here - "pubsubs": [LCM()], # Custom converters for specific rerun entity paths # Normally all these would be specified in their respectative modules # Until this is implemented we have central overrides here @@ -123,30 +119,20 @@ def _go2_rerun_blueprint() -> Any: }, } - -if global_config.viewer == "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - with_vis = autoconnect( - _transports_base, - 
FoxgloveBridge.blueprint(shm_channels=["/color_image#sensor_msgs.Image"]), - ) -elif global_config.viewer.startswith("rerun"): - from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode - - with_vis = autoconnect( - _transports_base, - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), - ) -else: - with_vis = _transports_base +_with_vis = autoconnect( + _transports_base, + vis_module( + viewer_backend=global_config.viewer, + rerun_config=rerun_config, + foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, + ), +) unitree_go2_basic = ( autoconnect( - with_vis, + _with_vis, GO2Connection.blueprint(), - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py index a7a10767bf..bda362eeca 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py @@ -22,15 +22,13 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator -from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import with_vis +from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import _with_vis from dimos.robot.unitree.go2.fleet_connection import Go2FleetConnection -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule unitree_go2_fleet = ( autoconnect( - with_vis, + _with_vis, Go2FleetConnection.blueprint(), - WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py 
index 01117ec3b5..3be0c62379 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py @@ -31,6 +31,10 @@ unitree_go2_webrtc_keyboard_teleop = autoconnect( unitree_go2_coordinator, KeyboardTeleop.blueprint(), +).remappings( + [ + (KeyboardTeleop, "tele_cmd_vel", "cmd_vel"), + ] ) __all__ = ["unitree_go2_webrtc_keyboard_teleop"] diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py index f353d995af..16711115ab 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py @@ -27,6 +27,7 @@ ) from dimos.navigation.patrolling.module import PatrollingModule from dimos.navigation.replanning_a_star.module import ReplanningAStarPlanner +from dimos.navigation.smart_nav.modules.movement_manager.movement_manager import MovementManager from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import unitree_go2_basic unitree_go2 = autoconnect( @@ -36,7 +37,8 @@ ReplanningAStarPlanner.blueprint(), WavefrontFrontierExplorer.blueprint(), PatrollingModule.blueprint(), -).global_config(n_workers=9, robot_model="unitree_go2") + MovementManager.blueprint(), +).global_config(n_workers=10, robot_model="unitree_go2") class Go2MemoryConfig(RecorderConfig): @@ -52,6 +54,6 @@ class Go2Memory(Recorder): unitree_go2_memory = autoconnect( unitree_go2, Go2Memory.blueprint(), -).global_config(n_workers=10) +).global_config(n_workers=11) __all__ = ["unitree_go2", "unitree_go2_memory"] diff --git a/dimos/robot/unitree/keyboard_teleop.py b/dimos/robot/unitree/keyboard_teleop.py index e3c78ecc52..3e8f76a1cc 100644 --- a/dimos/robot/unitree/keyboard_teleop.py +++ b/dimos/robot/unitree/keyboard_teleop.py @@ -38,14 +38,14 @@ class KeyboardTeleop(Module): """Pygame-based keyboard control module. 
- Outputs standard Twist messages on /cmd_vel for velocity control. + Outputs standard Twist messages on /tele_cmd_vel for velocity control. Speed constants can be tuned at the top of this file, or overridden per-instance by passing linear_speed / angular_speed / boost_multiplier / slow_multiplier to the constructor. """ - cmd_vel: Out[Twist] # Standard velocity commands + tele_cmd_vel: Out[Twist] # Standard velocity commands _stop_event: threading.Event _keys_held: set[int] | None = None @@ -86,7 +86,7 @@ def stop(self) -> None: stop_twist = Twist() stop_twist.linear = Vector3(0, 0, 0) stop_twist.angular = Vector3(0, 0, 0) - self.cmd_vel.publish(stop_twist) + self.tele_cmd_vel.publish(stop_twist) self._stop_event.set() @@ -119,7 +119,7 @@ def _pygame_loop(self) -> None: stop_twist = Twist() stop_twist.linear = Vector3(0, 0, 0) stop_twist.angular = Vector3(0, 0, 0) - self.cmd_vel.publish(stop_twist) + self.tele_cmd_vel.publish(stop_twist) print("EMERGENCY STOP!") elif event.key == pygame.K_ESCAPE: # ESC quits @@ -163,7 +163,7 @@ def _pygame_loop(self) -> None: twist.angular.z *= speed_multiplier # Always publish twist at 50Hz - self.cmd_vel.publish(twist) + self.tele_cmd_vel.publish(twist) self._update_display(twist) diff --git a/dimos/robot/unitree/mujoco_connection.py b/dimos/robot/unitree/mujoco_connection.py index 39c0904684..43ddeb6530 100644 --- a/dimos/robot/unitree/mujoco_connection.py +++ b/dimos/robot/unitree/mujoco_connection.py @@ -20,9 +20,12 @@ from collections.abc import Callable import functools import json +import os +from pathlib import Path import pickle import subprocess import sys +import sysconfig import threading import time from typing import Any, TypeVar @@ -126,12 +129,23 @@ def start(self) -> None: # Launch the subprocess try: - # mjpython must be used macOS (because of launch_passive inside mujoco_process.py) + # mjpython must be used on macOS (because of launch_passive inside mujoco_process.py). 
+ # It needs libpython on the dylib search path; uv-installed Pythons + # use @rpath which doesn't always resolve inside venvs, so we + # point DYLD_LIBRARY_PATH at the real libpython directory. executable = sys.executable if sys.platform != "darwin" else "mjpython" + env = os.environ.copy() + if sys.platform == "darwin": + # On some systems mujoco looks in the wrong place for shared libraries, so we force it to look in the right place. LIBDIR may be unset; Path("") would mean the cwd, so check it is non-empty first. + libdir = sysconfig.get_config_var("LIBDIR") + if libdir and Path(libdir).is_dir(): + existing = env.get("DYLD_LIBRARY_PATH", "") + env["DYLD_LIBRARY_PATH"] = f"{libdir}:{existing}" if existing else str(libdir) self.process = subprocess.Popen( [executable, str(LAUNCHER_PATH), config_pickle, shm_names_json], stderr=subprocess.PIPE, + env=env, ) except Exception as e: diff --git a/dimos/simulation/unity/blueprint.py b/dimos/simulation/unity/blueprint.py index f7e2d34ccb..d9b29ee610 100644 --- a/dimos/simulation/unity/blueprint.py +++ b/dimos/simulation/unity/blueprint.py @@ -28,7 +28,7 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.simulation.unity.module import UnityBridgeModule -from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode +from dimos.visualization.rerun.bridge import RerunBridgeModule def _rerun_blueprint() -> Any: @@ -57,5 +57,5 @@ def _rerun_blueprint() -> Any: unity_sim = autoconnect( UnityBridgeModule.blueprint(), - RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), + RerunBridgeModule.blueprint(**rerun_config), ) diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index 57c925c3f0..b825f29a17 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -26,12 +26,12 @@ from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.teleop.quest.quest_extensions import ArmTeleopModule from dimos.teleop.quest.quest_types import
Buttons -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.vis_module import vis_module # Arm teleop with press-and-hold engage (has rerun viz) teleop_quest_rerun = autoconnect( ArmTeleopModule.blueprint(), - RerunBridgeModule.blueprint(), + vis_module("rerun"), ).transports( { ("left_controller_output", PoseStamped): LCMTransport("/teleop/left_delta", PoseStamped), diff --git a/dimos/test_no_sections.py b/dimos/test_no_sections.py index 902288b2e6..79f2d61b8f 100644 --- a/dimos/test_no_sections.py +++ b/dimos/test_no_sections.py @@ -52,6 +52,8 @@ ".tox", # third-party vendored code "gtsam", + # hidden/personal directories + ".hidden", } # Lines that match section patterns but are actually programmatic / intentional. diff --git a/dimos/utils/generic.py b/dimos/utils/generic.py index 84168ce057..200c7c6d86 100644 --- a/dimos/utils/generic.py +++ b/dimos/utils/generic.py @@ -16,10 +16,27 @@ import hashlib import json import os +import socket import string from typing import Any, Generic, TypeVar, overload import uuid +import psutil + + +def get_local_ips() -> list[tuple[str, str]]: + """Return ``(ip, interface_name)`` for every non-loopback IPv4 address. + + Picks up physical, virtual, and VPN interfaces (including Tailscale). 
+ """ + results: list[tuple[str, str]] = [] + for iface, addrs in psutil.net_if_addrs().items(): + for addr in addrs: + if addr.family == socket.AF_INET and not addr.address.startswith("127."): + results.append((addr.address, iface)) + return results + + _T = TypeVar("_T") diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index f2e3e51d08..f6744e74fb 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -18,18 +18,19 @@ from collections.abc import Callable from dataclasses import field -from functools import lru_cache +import socket import subprocess import time from typing import ( Any, - Literal, Protocol, TypeAlias, TypeGuard, cast, + get_args, runtime_checkable, ) +from urllib.parse import urlparse from reactivex.disposable import Disposable import rerun as rr @@ -37,19 +38,23 @@ import rerun.blueprint as rrb from rerun.blueprint import Blueprint from toolz import pipe # type: ignore[import-untyped] -import typer from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.pubsub.patterns import Glob, pattern_matches from dimos.protocol.pubsub.spec import SubscribeAllCapable +from dimos.utils.generic import get_local_ips from dimos.utils.logging_config import setup_logger +from dimos.visualization.rerun.constants import ( + RERUN_ENABLE_WEB, + RERUN_GRPC_PORT, + RERUN_OPEN_DEFAULT, + RERUN_WEB_PORT, + RerunOpenOption, +) from dimos.visualization.rerun.init import rerun_init -RERUN_GRPC_PORT = 9877 -RERUN_WEB_PORT = 9090 - # TODO OUT visual annotations # # In the future it would be nice if modules can annotate their individual OUTs with (general or rerun specific) @@ -95,7 +100,6 @@ BlueprintFactory: TypeAlias = Callable[[], "Blueprint"] -# to_rerun() can return a single archetype or a list of (entity_path, archetype) tuples RerunMulti: TypeAlias = "list[tuple[str, Archetype]]" RerunData: 
TypeAlias = "Archetype | RerunMulti" @@ -119,18 +123,16 @@ class RerunConvertible(Protocol): def to_rerun(self) -> RerunData: ... -ViewerMode = Literal["native", "web", "connect", "none"] - - def _hex_to_rgba(hex_color: str) -> int: """Convert '#RRGGBB' to a 0xRRGGBBAA int (fully opaque).""" h = hex_color.lstrip("#") - return (int(h, 16) << 8) | 0xFF + if len(h) == 6: + return int(h + "ff", 16) + return int(h[:8], 16) def _with_graph_tab(bp: Blueprint) -> Blueprint: """Add a Graph tab alongside the existing viewer layout without changing it.""" - root = bp.root_container return rrb.Blueprint( rrb.Tabs( @@ -156,50 +158,26 @@ def _default_blueprint() -> Blueprint: ) -# Maps global_config.viewer -> bridge viewer_mode. -# Evaluated at blueprint construction time (main process), not in start() (worker process). -_BACKEND_TO_MODE: dict[str, ViewerMode] = { - "rerun": "native", - "rerun-web": "web", - "rerun-connect": "connect", - "none": "none", -} - - -def _resolve_viewer_mode() -> ViewerMode: - from dimos.core.global_config import global_config - - return _BACKEND_TO_MODE.get(global_config.viewer, "native") - - class Config(ModuleConfig): - """Configuration for RerunBridgeModule.""" - pubsubs: list[SubscribeAllCapable[Any, Any]] = field(default_factory=lambda: [LCM()]) visual_override: dict[Glob | str, Callable[[Any], Archetype]] = field(default_factory=dict) - - # Static items logged once after start. Maps entity_path -> callable(rr) returning Archetype static: dict[str, Callable[[Any], Archetype]] = field(default_factory=dict) - - grpc_port: int = RERUN_GRPC_PORT - web_port: int = RERUN_WEB_PORT - - # Per-entity max update rate (Hz). Entities not listed are unthrottled. - # Use for heavy entities to prevent viewer backpressure. 
max_hz: dict[str, float] = field(default_factory=dict) entity_prefix: str = "world" topic_to_entity: Callable[[Any], str] | None = None - viewer_mode: ViewerMode = field(default_factory=_resolve_viewer_mode) connect_url: str = "rerun+http://127.0.0.1:9877/proxy" memory_limit: str = "25%" - - # Blueprint factory: callable(rrb) -> Blueprint for viewer layout configuration - # Set to None to disable default blueprint + rerun_open: RerunOpenOption = RERUN_OPEN_DEFAULT + rerun_web: bool = RERUN_ENABLE_WEB + web_port: int = RERUN_WEB_PORT blueprint: BlueprintFactory | None = _default_blueprint +Config.model_rebuild(_types_namespace={"Archetype": Archetype, "Blueprint": Blueprint}) + + class RerunBridgeModule(Module): """Bridge that logs messages from pubsubs to Rerun. @@ -217,22 +195,31 @@ class RerunBridgeModule(Module): """ config: Config + _last_log: dict[str, float] # TODO this doesn't belong here, either hardcode it or put it to rerun bridge config - GV_SCALE = 100.0 # graphviz inches to rerun screen units - MODULE_RADIUS = 30.0 - CHANNEL_RADIUS = 20.0 + GRAPH_VIZ_SCALE = 100.0 + MODULE_RADIUS = 20.0 + CHANNEL_RADIUS = 12.0 + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._last_log = {} + self._override_cache: dict[str, Callable[[Any], RerunData | None]] = {} - @lru_cache(maxsize=256) def _visual_override_for_entity_path( self, entity_path: str ) -> Callable[[Any], RerunData | None]: """Return a composed visual override for the entity path. Chains matching overrides from config, ending with final_convert - which handles .to_rerun() or passes through Archetypes. + which handles .to_rerun() or passes through Archetypes. Cached per + instance (not via ``lru_cache`` on a method, which would leak ``self``). 
""" - # find all matching converters for this entity path + cached = self._override_cache.get(entity_path) + if cached is not None: + return cached + matches = [ fn for pattern, fn in self.config.visual_override.items() @@ -241,9 +228,13 @@ def _visual_override_for_entity_path( # None means "suppress this topic entirely" if any(fn is None for fn in matches): - return lambda msg: None - # final step (ensures we return Archetype or None) + def suppressed(msg: Any) -> RerunData | None: + return None + + self._override_cache[entity_path] = suppressed + return suppressed + def final_convert(msg: Any) -> RerunData | None: if isinstance(msg, Archetype): return msg @@ -253,23 +244,21 @@ def final_convert(msg: Any) -> RerunData | None: return msg.to_rerun() return None - # compose all converters - return lambda msg: pipe(msg, *matches, final_convert) + def composed(msg: Any) -> RerunData | None: + return cast("RerunData | None", pipe(msg, *matches, final_convert)) + + self._override_cache[entity_path] = composed + return composed def _get_entity_path(self, topic: Any) -> str: - """Convert a topic to a Rerun entity path.""" if self.config.topic_to_entity: return self.config.topic_to_entity(topic) - # Default: use topic.name if available (LCM Topic), else str topic_str = getattr(topic, "name", None) or str(topic) - # Strip everything after # (LCM topic suffix) - topic_str = topic_str.split("#")[0] + topic_str = topic_str.split("#")[0] # strip LCM topic suffix return f"{self.config.entity_prefix}{topic_str}" def _on_message(self, msg: Any, topic: Any) -> None: - """Handle incoming message - log to rerun.""" - entity_path: str = self._get_entity_path(topic) # Throttle entities with a max_hz limit @@ -279,7 +268,6 @@ def _on_message(self, msg: Any, topic: Any) -> None: return self._last_log[entity_path] = now - # apply visual overrides (including final_convert which handles .to_rerun()) rerun_data: RerunData | None = self._visual_override_for_entity_path(entity_path)(msg) if not 
rerun_data: @@ -296,47 +284,87 @@ def _on_message(self, msg: Any, topic: Any) -> None: def start(self) -> None: super().start() - logger.info("Rerun bridge starting", viewer_mode=self.config.viewer_mode) + logger.info("Rerun bridge starting") - # Build throttle lookup: entity_path → min interval in seconds - self._last_log: dict[str, float] = {} + self._last_log = {} self._min_intervals: dict[str, float] = { entity: 1.0 / hz for entity, hz in self.config.max_hz.items() if hz > 0 } - # Initialize and spawn Rerun viewer rerun_init("dimos") - if self.config.viewer_mode == "native": + parsed = urlparse(self.config.connect_url.replace("rerun+", "", 1)) + grpc_port = parsed.port or RERUN_GRPC_PORT + + port_in_use = False + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + port_in_use = sock.connect_ex(("127.0.0.1", grpc_port)) == 0 + + if port_in_use: + logger.info(f"gRPC port {grpc_port} already in use, connecting to existing server") + rr.connect_grpc(url=self.config.connect_url) + server_uri = self.config.connect_url + else: + server_uri = rr.serve_grpc( + grpc_port=grpc_port, + server_memory_limit=self.config.memory_limit, + ) + logger.info(f"Rerun gRPC server ready at {server_uri}") + + if self.config.rerun_open not in get_args(RerunOpenOption): + logger.warning( + f"rerun_open was {self.config.rerun_open} which is not one of " + f"{get_args(RerunOpenOption)}" + ) + + spawned = False + if self.config.rerun_open in ("native", "both"): try: import rerun_bindings + # Use --connect so the viewer connects to the bridge's gRPC + # server rather than starting its own (which would conflict). 
rerun_bindings.spawn( - port=self.config.grpc_port, executable_name="dimos-viewer", memory_limit=self.config.memory_limit, + extra_args=["--connect", server_uri], ) - rr.connect_grpc(f"rerun+http://127.0.0.1:{self.config.grpc_port}/proxy") + spawned = True except ImportError: - rr.spawn(connect=True, memory_limit=self.config.memory_limit) + pass # dimos-viewer not installed except Exception: logger.warning( "dimos-viewer found but failed to spawn, falling back to stock rerun", exc_info=True, ) - rr.spawn(connect=True, memory_limit=self.config.memory_limit) - elif self.config.viewer_mode == "web": - server_uri = rr.serve_grpc() - rr.serve_web_viewer(connect_to=server_uri, open_browser=False) - elif self.config.viewer_mode == "connect": - rr.connect_grpc(self.config.connect_url) - # "none" - just init, no viewer (connect externally) + # fallback on normal (non-dimos-viewer) rerun + if not spawned: + try: + rr.spawn(connect=True, memory_limit=self.config.memory_limit) + spawned = True + except (RuntimeError, FileNotFoundError): + logger.warning( + "Rerun native viewer not available (headless?). 
" + "Bridge will continue without a viewer — data is still " + "accessible via --rerun-open web or by connecting a viewer to the gRPC server.", + exc_info=True, + ) + + open_web = self.config.rerun_open == "web" or self.config.rerun_open == "both" + if open_web or self.config.rerun_web: + rr.serve_web_viewer( + connect_to=server_uri, + open_browser=open_web, + web_port=self.config.web_port, + ) + + if self.config.rerun_open == "none" or (self.config.rerun_open == "native" and not spawned): + self._log_connect_hints(grpc_port) if self.config.blueprint: rr.send_blueprint(_with_graph_tab(self.config.blueprint())) - # Start pubsubs and subscribe to all messages for pubsub in self.config.pubsubs: logger.info(f"bridge listening on {pubsub.__class__.__name__}") if hasattr(pubsub, "start"): @@ -344,13 +372,35 @@ def start(self) -> None: unsub = pubsub.subscribe_all(self._on_message) self.register_disposable(Disposable(unsub)) - # Add pubsub stop as disposable for pubsub in self.config.pubsubs: if hasattr(pubsub, "stop"): self.register_disposable(Disposable(pubsub.stop)) # type: ignore[union-attr] self._log_static() + def _log_connect_hints(self, grpc_port: int) -> None: + """Log CLI commands for connecting a viewer to this bridge.""" + local_ips = get_local_ips() + hostname = socket.gethostname() + connect_url = f"rerun+http://127.0.0.1:{grpc_port}/proxy" + + lines = [ + "", + "=" * 60, + "Rerun gRPC server running (no viewer opened)", + "", + "Connect a viewer:", + f" dimos-viewer --connect {connect_url}", + ] + for ip, iface in local_ips: + lines.append(f" dimos-viewer --connect rerun+http://{ip}:{grpc_port}/proxy # {iface}") + lines.append("") + lines.append(f" hostname: {hostname}") + lines.append("=" * 60) + lines.append("") + + logger.info("\n".join(lines)) + def _log_static(self) -> None: for entity_path, factory in self.config.static.items(): data = factory(rr) @@ -371,7 +421,6 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: 
dot_code: The DOT-format graph (from ``introspection.blueprint.dot.render``). module_names: List of module class names (to distinguish modules from channels). """ - try: result = subprocess.run( ["dot", "-Tplain"], input=dot_code, text=True, capture_output=True, timeout=30 @@ -393,8 +442,8 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: if line.startswith("node "): parts = line.split() node_id = parts[1].strip('"') - x = float(parts[2]) * self.GV_SCALE - y = -float(parts[3]) * self.GV_SCALE + x = float(parts[2]) * self.GRAPH_VIZ_SCALE + y = -float(parts[3]) * self.GRAPH_VIZ_SCALE label = parts[6].strip('"') color = parts[9].strip('"') @@ -427,49 +476,5 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: @rpc def stop(self) -> None: + self._override_cache.clear() super().stop() - - -def run_bridge( - viewer_mode: str = "native", - memory_limit: str = "25%", -) -> None: - """Start a RerunBridgeModule with default LCM config and block until interrupted.""" - import signal - - from dimos.protocol.service.lcmservice import autoconf - - autoconf(check_only=True) - - bridge = RerunBridgeModule( - viewer_mode=viewer_mode, - memory_limit=memory_limit, - # any pubsub that supports subscribe_all and topic that supports str(topic) - # is acceptable here - pubsubs=[LCM()], - ) - - bridge.start() - - signal.signal(signal.SIGINT, lambda *_: bridge.stop()) - signal.pause() - - -app = typer.Typer() - - -@app.command() -def cli( - viewer_mode: str = typer.Option( - "native", help="Viewer mode: native (desktop), web (browser), none (headless)" - ), - memory_limit: str = typer.Option( - "25%", help="Memory limit for Rerun viewer (e.g., '4GB', '16GB', '25%')" - ), -) -> None: - """Rerun bridge for LCM messages.""" - run_bridge(viewer_mode=viewer_mode, memory_limit=memory_limit) - - -if __name__ == "__main__": - app() diff --git a/dimos/visualization/rerun/conftest.py b/dimos/visualization/rerun/conftest.py new file mode 
100644 index 0000000000..f269bb8015 --- /dev/null +++ b/dimos/visualization/rerun/conftest.py @@ -0,0 +1,45 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from collections.abc import Callable +import time + +import pytest +import websockets.asyncio.client as ws_client + + +def _wait_for_server(port: int, timeout: float = 5.0) -> None: + """Block until the WebSocket server on *port* accepts a connection.""" + + async def _probe() -> None: + async with ws_client.connect(f"ws://127.0.0.1:{port}/ws"): + pass + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + asyncio.run(_probe()) + return + except Exception: + time.sleep(0.05) + raise TimeoutError(f"Server on port {port} did not become ready within {timeout}s") + + +@pytest.fixture() +def wait_for_server() -> Callable[[int, float], None]: + """Fixture that returns a callable to wait for a WebSocket server.""" + return _wait_for_server diff --git a/dimos/visualization/rerun/constants.py b/dimos/visualization/rerun/constants.py new file mode 100644 index 0000000000..860c691cef --- /dev/null +++ b/dimos/visualization/rerun/constants.py @@ -0,0 +1,31 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Rerun visualization defaults and type aliases. + +This module is intentionally free of heavy imports so it can be +loaded from lightweight entry-points like ``global_config`` and +``dimos --help`` without pulling in the Rerun SDK or the module +framework. +""" + +from typing import Literal, TypeAlias + +ViewerBackend: TypeAlias = Literal["rerun", "foxglove", "none"] +RerunOpenOption: TypeAlias = Literal["none", "web", "native", "both"] + +RERUN_OPEN_DEFAULT: RerunOpenOption = "native" +RERUN_ENABLE_WEB = False +RERUN_GRPC_PORT = 9876 +RERUN_WEB_PORT = 9090 diff --git a/dimos/visualization/rerun/test_viewer_ws_e2e.py b/dimos/visualization/rerun/test_viewer_ws_e2e.py new file mode 100644 index 0000000000..260699a3e8 --- /dev/null +++ b/dimos/visualization/rerun/test_viewer_ws_e2e.py @@ -0,0 +1,201 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +"""End-to-end tests for dimos-viewer ↔ RerunWebSocketServer protocol.""" + +from __future__ import annotations + +import asyncio +import json +import os +import subprocess +import threading +import time +from typing import Any + +import pytest +import websockets.asyncio.client as ws_client + +from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + +_E2E_PORT = 13032 + + +@pytest.fixture() +def server(wait_for_server: Any) -> RerunWebSocketServer: + module = RerunWebSocketServer(port=_E2E_PORT) + module.start() + wait_for_server(_E2E_PORT) + yield module # type: ignore[misc] + module.stop() + + +def _send_messages(port: int, messages: list[dict[str, Any]], *, delay: float = 0.05) -> None: + async def _run() -> None: + async with ws_client.connect(f"ws://127.0.0.1:{port}/ws") as ws: + for msg in messages: + await ws.send(json.dumps(msg)) + await asyncio.sleep(delay) + + asyncio.run(_run()) + + +class TestViewerProtocolE2E: + """Verify the Python-server side of the viewer ↔ DimOS protocol.""" + + def test_viewer_click_reaches_stream(self, server: RerunWebSocketServer) -> None: + """A viewer click over WebSocket publishes PointStamped.""" + received: list[Any] = [] + done = threading.Event() + unsub = server.clicked_point.subscribe(lambda pt: (received.append(pt), done.set())) + + _send_messages( + _E2E_PORT, + [ + { + "type": "click", + "x": 10.0, + "y": 20.0, + "z": 0.5, + "entity_path": "/world/robot", + "timestamp_ms": 42000, + } + ], + ) + + done.wait(timeout=3.0) + unsub() + + assert len(received) == 1 + pt = received[0] + assert pt.x == pytest.approx(10.0) + assert pt.y == pytest.approx(20.0) + assert pt.z == pytest.approx(0.5) + assert pt.frame_id == "/world/robot" + assert pt.ts == pytest.approx(42.0) + + def test_full_viewer_session_sequence(self, server: RerunWebSocketServer) -> None: + """Realistic session: heartbeats, click, twist, stop — only the click produces a point.""" + received: list[Any] = [] + done = threading.Event() + 
unsub = server.clicked_point.subscribe(lambda pt: (received.append(pt), done.set())) + + _send_messages( + _E2E_PORT, + [ + {"type": "heartbeat", "timestamp_ms": 1000}, + {"type": "heartbeat", "timestamp_ms": 2000}, + { + "type": "click", + "x": 3.14, + "y": 2.71, + "z": 1.41, + "entity_path": "/world", + "timestamp_ms": 3000, + }, + { + "type": "twist", + "linear_x": 0.5, + "linear_y": 0.0, + "linear_z": 0.0, + "angular_x": 0.0, + "angular_y": 0.0, + "angular_z": 0.0, + }, + {"type": "stop"}, + {"type": "heartbeat", "timestamp_ms": 4000}, + ], + delay=0.2, + ) + + done.wait(timeout=3.0) + unsub() + + assert len(received) == 1, f"Expected exactly 1 click, got {len(received)}" + assert received[0].x == pytest.approx(3.14) + assert received[0].y == pytest.approx(2.71) + assert received[0].z == pytest.approx(1.41) + + def test_reconnect_after_disconnect(self, server: RerunWebSocketServer) -> None: + """Server keeps accepting new connections after a client disconnects.""" + received: list[Any] = [] + all_done = threading.Event() + + def _on_pt(pt: Any) -> None: + received.append(pt) + if len(received) >= 2: + all_done.set() + + unsub = server.clicked_point.subscribe(_on_pt) + + _send_messages( + _E2E_PORT, + [{"type": "click", "x": 1.0, "y": 0.0, "z": 0.0, "entity_path": "", "timestamp_ms": 0}], + ) + _send_messages( + _E2E_PORT, + [{"type": "click", "x": 2.0, "y": 0.0, "z": 0.0, "entity_path": "", "timestamp_ms": 0}], + ) + + all_done.wait(timeout=5.0) + unsub() + + xs = sorted(pt.x for pt in received) + assert xs == [1.0, 2.0], f"Unexpected xs: {xs}" + + +class TestViewerBinaryConnectMode: + """Smoke test: dimos-viewer binary starts in --connect mode.""" + + @pytest.fixture() + def viewer_process(self, server: RerunWebSocketServer) -> subprocess.Popen[bytes]: + proc = subprocess.Popen( + [ + "dimos-viewer", + "--connect", + f"--ws-url=ws://127.0.0.1:{_E2E_PORT}/ws", + ], + env={**os.environ, "DISPLAY": ""}, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + 
yield proc # type: ignore[misc] + proc.terminate() + try: + proc.wait(timeout=3) + except subprocess.TimeoutExpired: + proc.kill() + + @pytest.mark.skip( + reason="Incompatible with current winit: fails without DISPLAY (headless CI exits before WS connect) and hangs with DISPLAY (GUI event loop blocks before printing URL).", + ) + def test_viewer_ws_client_connects(self, viewer_process: subprocess.Popen[bytes]) -> None: + """dimos-viewer --connect starts and its WS client connects to our server.""" + deadline = time.monotonic() + 5.0 + while time.monotonic() < deadline: + if viewer_process.poll() is not None: + break + time.sleep(0.1) + + stdout = ( + viewer_process.stdout.read().decode(errors="replace") if viewer_process.stdout else "" + ) + stderr = ( + viewer_process.stderr.read().decode(errors="replace") if viewer_process.stderr else "" + ) + + combined = stdout + stderr + assert f"ws://127.0.0.1:{_E2E_PORT}" in combined, ( + f"Viewer did not attempt WS connection.\nstdout:\n{stdout}\nstderr:\n{stderr}" + ) diff --git a/dimos/visualization/rerun/test_websocket_server.py b/dimos/visualization/rerun/test_websocket_server.py new file mode 100644 index 0000000000..b4304cf7b4 --- /dev/null +++ b/dimos/visualization/rerun/test_websocket_server.py @@ -0,0 +1,210 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for RerunWebSocketServer.""" + +from __future__ import annotations + +import asyncio +import json +import threading +import time +from typing import Any + +import pytest +import websockets.asyncio.client as ws_client + +from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + +_TEST_PORT = 13031 + + +class MockViewerPublisher: + """Simulates dimos-viewer sending JSON events over WebSocket.""" + + def __init__(self, url: str) -> None: + self._url = url + self._ws: Any = None + self._loop: asyncio.AbstractEventLoop | None = None + + def __enter__(self) -> MockViewerPublisher: + self._loop = asyncio.new_event_loop() + self._ws = self._loop.run_until_complete(self._connect()) + return self + + def __exit__(self, *_: Any) -> None: + if self._ws is not None and self._loop is not None: + self._loop.run_until_complete(self._ws.close()) + if self._loop is not None: + self._loop.close() + + async def _connect(self) -> Any: + return await ws_client.connect(self._url) + + def send_click( + self, x: float, y: float, z: float, entity_path: str = "", timestamp_ms: int = 0 + ) -> None: + self._send( + { + "type": "click", + "x": x, + "y": y, + "z": z, + "entity_path": entity_path, + "timestamp_ms": timestamp_ms, + } + ) + + def send_twist( + self, + linear_x: float, + linear_y: float, + linear_z: float, + angular_x: float, + angular_y: float, + angular_z: float, + ) -> None: + self._send( + { + "type": "twist", + "linear_x": linear_x, + "linear_y": linear_y, + "linear_z": linear_z, + "angular_x": angular_x, + "angular_y": angular_y, + "angular_z": angular_z, + } + ) + + def send_stop(self) -> None: + self._send({"type": "stop"}) + + def flush(self, delay: float = 0.1) -> None: + time.sleep(delay) + + def _send(self, msg: dict[str, Any]) -> None: + assert self._loop is not None and self._ws is not None + self._loop.run_until_complete(self._ws.send(json.dumps(msg))) + + +@pytest.fixture() +def server(wait_for_server: Any) -> RerunWebSocketServer: + 
module = RerunWebSocketServer(port=_TEST_PORT) + module.start() + wait_for_server(_TEST_PORT) + yield module # type: ignore[misc] + module.stop() + + +@pytest.fixture() +def publisher(server: RerunWebSocketServer) -> MockViewerPublisher: + with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as publisher: + yield publisher # type: ignore[misc] + + +# ── Tests ──────────────────────────────────────────────────────────────── + + +def test_click_publishes_point_stamped( + server: RerunWebSocketServer, publisher: MockViewerPublisher +) -> None: + """Click event arrives as PointStamped with correct coords, frame_id, and timestamp.""" + received: list[Any] = [] + done = threading.Event() + + unsub = server.clicked_point.subscribe(lambda point: (received.append(point), done.set())) + + publisher.send_click(1.5, 2.5, 0.0, "/robot/base", timestamp_ms=5000) + publisher.flush() + done.wait(timeout=2.0) + unsub() + + assert len(received) == 1 + point = received[0] + assert point.x == pytest.approx(1.5) + assert point.y == pytest.approx(2.5) + assert point.z == pytest.approx(0.0) + assert point.frame_id == "/robot/base" + assert point.ts == pytest.approx(5.0) + + +def test_twist_publishes_on_tele_cmd_vel( + server: RerunWebSocketServer, publisher: MockViewerPublisher +) -> None: + """Twist event arrives as Twist on tele_cmd_vel.""" + received: list[Any] = [] + done = threading.Event() + + unsub = server.tele_cmd_vel.subscribe(lambda twist: (received.append(twist), done.set())) + + publisher.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.8) + publisher.flush() + done.wait(timeout=2.0) + unsub() + + assert len(received) == 1 + assert received[0].linear.x == pytest.approx(0.5) + assert received[0].angular.z == pytest.approx(0.8) + + +def test_stop_publishes_zero_twist( + server: RerunWebSocketServer, publisher: MockViewerPublisher +) -> None: + """Stop event publishes a zero Twist on tele_cmd_vel.""" + received: list[Any] = [] + done = threading.Event() + + unsub = 
server.tele_cmd_vel.subscribe(lambda twist: (received.append(twist), done.set())) + + publisher.send_stop() + publisher.flush() + done.wait(timeout=2.0) + unsub() + + assert len(received) == 1 + assert received[0].is_zero() + + +def test_invalid_json_does_not_crash(server: RerunWebSocketServer) -> None: + """Malformed JSON is silently dropped; server stays alive for the next message.""" + + async def _send_bad() -> None: + async with ws_client.connect(f"ws://127.0.0.1:{_TEST_PORT}/ws") as ws: + await ws.send("this is not json {{") + await asyncio.sleep(0.1) + await ws.send(json.dumps({"type": "heartbeat", "timestamp_ms": 0})) + await asyncio.sleep(0.1) + + asyncio.run(_send_bad()) + + +def test_mixed_message_sequence( + server: RerunWebSocketServer, publisher: MockViewerPublisher +) -> None: + """Realistic session: heartbeat, click, twist, stop — only the click produces a point.""" + received: list[Any] = [] + done = threading.Event() + unsub = server.clicked_point.subscribe(lambda point: (received.append(point), done.set())) + + publisher.send_click(7.0, 8.0, 9.0, "/map", timestamp_ms=1100) + publisher.send_twist(0.3, 0.0, 0.0, 0.0, 0.0, 0.2) + publisher.send_stop() + publisher.flush() + done.wait(timeout=2.0) + unsub() + + assert len(received) == 1 + assert received[0].x == pytest.approx(7.0) + assert received[0].y == pytest.approx(8.0) + assert received[0].z == pytest.approx(9.0) diff --git a/dimos/visualization/rerun/websocket_server.py b/dimos/visualization/rerun/websocket_server.py new file mode 100644 index 0000000000..0c0ac2acf2 --- /dev/null +++ b/dimos/visualization/rerun/websocket_server.py @@ -0,0 +1,244 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""WebSocket server module that receives events from dimos-viewer. + +When dimos-viewer is started with ``--connect``, LCM multicast is unavailable +across machines. The viewer falls back to sending click, twist, and stop events +as JSON over a WebSocket connection. This module acts as the server-side +counterpart: it listens for those connections and translates incoming messages +into DimOS stream publishes. + +Message format (newline-delimited JSON, ``"type"`` discriminant): + + {"type":"heartbeat","timestamp_ms":1234567890} + {"type":"click","x":1.0,"y":2.0,"z":3.0,"entity_path":"/world","timestamp_ms":1234567890} + {"type":"twist","linear_x":0.5,"linear_y":0.0,"linear_z":0.0, + "angular_x":0.0,"angular_y":0.0,"angular_z":0.8} + {"type":"stop"} +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import socket +import threading +from typing import Any, Literal, TypedDict, Union + +import websockets +import websockets.asyncio.server as ws_server + +from dimos.core.core import rpc +from dimos.core.global_config import global_config +from dimos.core.module import Module, ModuleConfig +from dimos.core.stream import Out +from dimos.msgs.geometry_msgs.PointStamped import PointStamped +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.utils.generic import get_local_ips +from dimos.utils.logging_config import setup_logger +from dimos.visualization.rerun.constants import RERUN_GRPC_PORT + +logger = setup_logger() + + +class ClickMsg(TypedDict): + type: 
Literal["click"] + x: float + y: float + z: float + entity_path: str + timestamp_ms: int + + +class TwistMsg(TypedDict): + type: Literal["twist"] + linear_x: float + linear_y: float + linear_z: float + angular_x: float + angular_y: float + angular_z: float + + +class StopMsg(TypedDict): + type: Literal["stop"] + + +class HeartbeatMsg(TypedDict): + type: Literal["heartbeat"] + timestamp_ms: int + + +ViewerMsg = Union[ClickMsg, TwistMsg, StopMsg, HeartbeatMsg] + + +def _handshake_noise_filter(record: logging.LogRecord) -> bool: + """Drop noisy "opening handshake failed" records from port scanners etc.""" + msg = record.getMessage() + return not ("opening handshake failed" in msg or "did not receive a valid HTTP request" in msg) + + +class Config(ModuleConfig): + host: str | None = None + port: int = 3030 + start_timeout: float = 10.0 + + +class RerunWebSocketServer(Module): + """Receives dimos-viewer WebSocket events and publishes them as DimOS streams. + + The viewer connects to this module (not the other way around) when running + in ``--connect`` mode. Each click event is converted to a ``PointStamped`` + and published on the ``clicked_point`` stream so downstream modules (e.g. + ``ReplanningAStarPlanner``) can consume it without modification. + + Outputs: + clicked_point: 3-D world-space point from the most recent viewer click. + tele_cmd_vel: Twist velocity commands from keyboard teleop, including stop events. + + Note: ``stop_movement`` is owned by ``MovementManager`` — it will fire + that signal when it sees the first teleop twist arrive here. 
+ """ + + config: Config + + clicked_point: Out[PointStamped] + tele_cmd_vel: Out[Twist] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._stop_event: asyncio.Event | None = None + self._server_ready = threading.Event() + self.host = self.config.host if self.config.host is not None else global_config.listen_host + + @rpc + def start(self) -> None: + super().start() + assert self._loop is not None + asyncio.run_coroutine_threadsafe(self._serve(), self._loop) + self._server_ready.wait(timeout=self.config.start_timeout) + self._log_connect_hints() + + @rpc + def stop(self) -> None: + self._server_ready.wait(timeout=self.config.start_timeout) + if self._loop is not None and not self._loop.is_closed() and self._stop_event is not None: + self._loop.call_soon_threadsafe(self._stop_event.set) + super().stop() + + def _log_connect_hints(self) -> None: + """Log full dimos-viewer commands that viewers can use to connect.""" + local_ips = get_local_ips() + hostname = socket.gethostname() + host = self.host + ws_url = f"ws://{host}:{self.config.port}/ws" + grpc_url = f"rerun+http://{host}:{RERUN_GRPC_PORT}/proxy" + + lines = [ + "", + "=" * 60, + f"RerunWebSocketServer listening on {ws_url}", + "", + "Connect a viewer:", + f" dimos-viewer --connect {grpc_url} --ws-url {ws_url}", + ] + if local_ips: + lines.append("") + lines.append("From another machine on the network:") + for ip, iface in local_ips: + remote_grpc = f"rerun+http://{ip}:{RERUN_GRPC_PORT}/proxy" + remote_ws = f"ws://{ip}:{self.config.port}/ws" + lines.append( + f" dimos-viewer --connect {remote_grpc} --ws-url {remote_ws} # {iface}" + ) + lines.append("") + lines.append(f" hostname: {hostname}") + lines.append("=" * 60) + lines.append("") + + logger.info("\n".join(lines)) + + async def _serve(self) -> None: + self._stop_event = asyncio.Event() + + ws_logger = logging.getLogger("websockets.server") + ws_logger.addFilter(_handshake_noise_filter) + + async with ws_server.serve( + 
self._handle_client, + host=self.host, + port=self.config.port, + ping_interval=30, + ping_timeout=30, + logger=ws_logger, + ): + self._server_ready.set() + await self._stop_event.wait() + + async def _handle_client(self, websocket: Any) -> None: + if hasattr(websocket, "request") and websocket.request.path != "/ws": + await websocket.close(1008, "Not Found") + return + addr = websocket.remote_address + logger.info(f"RerunWebSocketServer: viewer connected from {addr}") + try: + async for raw in websocket: + self._dispatch(raw) + except websockets.ConnectionClosed: + pass + + def _dispatch(self, raw: str | bytes) -> None: + try: + msg: dict[str, Any] = json.loads(raw) + except json.JSONDecodeError: + logger.warning(f"RerunWebSocketServer: ignoring non-JSON message: {raw!r}") + return + + if not isinstance(msg, dict): + return + + msg_type = msg.get("type") + + if msg_type == "click": + self.clicked_point.publish( + PointStamped( + x=float(msg.get("x", 0)), + y=float(msg.get("y", 0)), + z=float(msg.get("z", 0)), + ts=float(msg.get("timestamp_ms", 0)) / 1000.0, + frame_id=str(msg.get("entity_path", "")), + ) + ) + + elif msg_type == "twist": + self.tele_cmd_vel.publish( + Twist( + linear=Vector3( + float(msg.get("linear_x", 0)), + float(msg.get("linear_y", 0)), + float(msg.get("linear_z", 0)), + ), + angular=Vector3( + float(msg.get("angular_x", 0)), + float(msg.get("angular_y", 0)), + float(msg.get("angular_z", 0)), + ), + ) + ) + + elif msg_type == "stop": + self.tele_cmd_vel.publish(Twist.zero()) diff --git a/dimos/visualization/vis_module.py b/dimos/visualization/vis_module.py new file mode 100644 index 0000000000..badcba34db --- /dev/null +++ b/dimos/visualization/vis_module.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared visualization module factory for all robot blueprints.""" + +from typing import Any, get_args + +from dimos.core.coordination.blueprints import Blueprint, autoconnect +from dimos.visualization.rerun.constants import ViewerBackend + + +def vis_module( + viewer_backend: ViewerBackend, + rerun_config: dict[str, Any] | None = None, + foxglove_config: dict[str, Any] | None = None, +) -> Blueprint: + """Create a visualization blueprint based on the selected viewer backend. + + Bundles the appropriate viewer module (Rerun or Foxglove) together with + the ``WebsocketVisModule`` and ``RerunWebSocketServer`` so that the web + dashboard and remote viewer connections work out of the box. 
+ + Example usage:: + + from dimos.core.global_config import global_config + viz = vis_module( + global_config.viewer, + rerun_config={ + "visual_override": { + "world/camera_info": lambda ci: ci.to_rerun(...), + }, + "static": { + "world/tf/base_link": lambda rr: [rr.Boxes3D(...)], + }, + }, + ) + """ + from dimos.visualization.rerun.websocket_server import RerunWebSocketServer + from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule + + if foxglove_config is None: + foxglove_config = {} + if rerun_config is None: + rerun_config = {} + + match viewer_backend: + case "foxglove": + from dimos.robot.foxglove_bridge import FoxgloveBridge + + return autoconnect( + FoxgloveBridge.blueprint(**foxglove_config), + RerunWebSocketServer.blueprint(), + WebsocketVisModule.blueprint(), + ) + case "rerun": + from dimos.core.global_config import global_config + from dimos.protocol.pubsub.impl.lcmpubsub import LCM + from dimos.visualization.rerun.bridge import RerunBridgeModule + + rerun_config = {**rerun_config} # copy (avoid mutation) + rerun_config.setdefault("pubsubs", [LCM()]) + rerun_config.setdefault("rerun_open", global_config.rerun_open) + rerun_config.setdefault("rerun_web", global_config.rerun_web) + return autoconnect( + RerunBridgeModule.blueprint( + **rerun_config, + ), + RerunWebSocketServer.blueprint(), + WebsocketVisModule.blueprint(), + ) + case "none": + return autoconnect(WebsocketVisModule.blueprint()) + case _: + valid = ", ".join(get_args(ViewerBackend)) + raise ValueError(f"Unknown viewer_backend {viewer_backend!r}. 
Expected one of: {valid}") diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 3d6b3df11c..1ce7e74502 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -105,7 +105,7 @@ class WebsocketVisModule(Module): gps_goal: Out[LatLon] explore_cmd: Out[Bool] stop_explore_cmd: Out[Bool] - cmd_vel: Out[Twist] + tele_cmd_vel: Out[Twist] movecmd_stamped: Out[TwistStamped] def __init__(self, **kwargs: Any) -> None: @@ -158,9 +158,11 @@ def start(self) -> None: self._uvicorn_server_thread = threading.Thread(target=self._run_uvicorn_server, daemon=True) self._uvicorn_server_thread.start() - # Auto-open browser only for rerun-web (dashboard with Rerun iframe + command center) - # For rerun and foxglove, users access the command center manually if needed - if self.config.g.viewer == "rerun-web": + # Auto-open the dashboard tab only when the user explicitly asked for a + # web-based viewer (rerun_open == "web" or "both"). `rerun_web` alone + # only means "serve the viewer"; it should not trigger a browser popup + # when the user chose the native viewer. + if self.config.g.viewer == "rerun" and self.config.g.rerun_open in ("web", "both"): url = f"http://localhost:{self.config.port}/" logger.info(f"Dimensional Command Center: {url}") @@ -236,11 +238,13 @@ def _create_server(self) -> None: async def serve_index(request): # type: ignore[no-untyped-def] """Serve appropriate HTML based on viewer mode.""" - # If running native Rerun, redirect to standalone command center - if self.config.g.viewer != "rerun-web": + # Serve the full dashboard (with Rerun iframe) only when the rerun + # web server is enabled; otherwise redirect to the standalone + # command center. 
+ if not ( + self.config.g.viewer == "rerun" and self.config.g.rerun_open in ("web", "both") + ): return RedirectResponse(url="/command-center") - - # Otherwise serve full dashboard with Rerun iframe return FileResponse(_DASHBOARD_HTML, media_type="text/html") async def serve_command_center(request): # type: ignore[no-untyped-def] @@ -333,14 +337,14 @@ async def clear_gps_goals(sid: str) -> None: @self.sio.event # type: ignore[untyped-decorator] async def move_command(sid: str, data: dict[str, Any]) -> None: # Publish Twist if transport is configured - if self.cmd_vel and self.cmd_vel.transport: + if self.tele_cmd_vel and self.tele_cmd_vel.transport: twist = Twist( linear=Vector3(data["linear"]["x"], data["linear"]["y"], data["linear"]["z"]), angular=Vector3( data["angular"]["x"], data["angular"]["y"], data["angular"]["z"] ), ) - self.cmd_vel.publish(twist) + self.tele_cmd_vel.publish(twist) # Publish TwistStamped if transport is configured if self.movecmd_stamped and self.movecmd_stamped.transport: diff --git a/docs/development/conventions.md b/docs/development/conventions.md new file mode 100644 index 0000000000..2b25a7c3c6 --- /dev/null +++ b/docs/development/conventions.md @@ -0,0 +1,12 @@ +This mostly to track when conventions change (with regard to codebase updates) because this codebase is under heavy development. Note: this is a non-exhaustive list of conventions. + +- Instead of using `RerunBridge` in blueprints we always use `vis_module` which allows the CLI to control if its foxglove, rerun, or no-vis at all +- When global_config.py shouldn't accidentally/indirectly import heavy libraries like rerun. But sometimes global_config needs the type definition or default value from a module. Preferably we import from the module file directly, however when thats not possible, we create a config.py for just that module's config and import that into global_config.py. 
+- When adding visualization tools to a blueprint/autoconnect, instead of using RerunBridge or WebsocketVisModule directly we should always use `vis_module`, which right now should look something like `vis_module(viewer_backend=global_config.viewer, rerun_config={}),` +- `DEFAULT_THREAD_JOIN_TIMEOUT` is used for all thread.join timeouts +- Don't use print inside of tests +- Module configs should be specified as `config: ModuleSpecificConfigClass` +- To customize the way rerun renders something, right now we use a `rerun_config` dict. This will (hopefully) change very soon to be a per-module config instead of a per-blueprint config +- Similar to the `rerun_config` the `rrb` (rerun blueprint) is defined at a blueprint level right now, but ideally would be a per-module contribution with only a per-blueprint override of the layout. +- No `__init__.py` files +- Helper blueprints (like `_with_vis`) that should not be used on their own need to start with an underscore to avoid being picked up by the all_blueprints.py code generation step diff --git a/docs/usage/cli.md b/docs/usage/cli.md index 017b441c7e..bba73368b2 100644 --- a/docs/usage/cli.md +++ b/docs/usage/cli.md @@ -18,7 +18,9 @@ dimos [GLOBAL OPTIONS] COMMAND [ARGS] | `--replay` / `--no-replay` | bool | `False` | Use recorded replay data | | `--replay-db` | TEXT | `go2_bigoffice` | Replay memory2 SQLite database name | | `--new-memory` / `--no-new-memory` | bool | `False` | Clear persistent memory on start | -| `--viewer` | `rerun\|rerun-web\|rerun-connect\|foxglove\|none` | `rerun` | Visualization backend | +| `--viewer` | `rerun\|foxglove\|none` | `rerun` | Visualization backend | +| `--rerun-open` | `native\|web\|both\|none` | `native` | How to open the Rerun viewer | +| `--rerun-web` / `--no-rerun-web` | bool | `False` | Serve the Rerun web viewer | | `--n-workers` | INT | `2` | Number of forkserver workers | | `--memory-limit` | TEXT | `auto` | Rerun viewer memory limit | | `--mcp-port` | INT | `9990` | MCP 
server port | diff --git a/docs/usage/visualization.md b/docs/usage/visualization.md index 57ad460354..9ece977a68 100644 --- a/docs/usage/visualization.md +++ b/docs/usage/visualization.md @@ -1,37 +1,43 @@ # Viewer Backends -Dimos supports three visualization backends: Rerun (web or native) and Foxglove. +Dimos supports three visualization backends: `rerun` (default), `foxglove`, and `none`. ## Quick Start -Choose your viewer via the CLI (preferred): +Choose your viewer via the CLI: ```bash # Rerun native viewer (default) - dimos-viewer with built-in teleop + click-to-navigate dimos run unitree-go2 -# Explicitly select the viewer mode: +# Explicitly select the viewer backend: dimos --viewer rerun run unitree-go2 -dimos --viewer rerun-web run unitree-go2 dimos --viewer foxglove run unitree-go2 +dimos --viewer none run unitree-go2 ``` -Alternative (environment variable): +Control how the Rerun viewer opens with `--rerun-open` and `--rerun-web`: ```bash -# Rerun native viewer (default) - dimos-viewer with built-in teleop + click-to-navigate -VIEWER=rerun dimos run unitree-go2 +# Open native desktop viewer (default) +dimos --rerun-open native run unitree-go2 + +# Open web viewer in browser +dimos --rerun-open web run unitree-go2 + +# Open both native and web +dimos --rerun-open both run unitree-go2 -# Rerun web viewer - browser dashboard + teleop at http://localhost:7779 -VIEWER=rerun-web dimos run unitree-go2 +# No viewer (headless) — data still accessible via gRPC +dimos --rerun-open none run unitree-go2 -# Foxglove - Use Foxglove Studio instead of Rerun -VIEWER=foxglove dimos run unitree-go2 +# Serve the web viewer without auto-opening a browser +dimos --rerun-web --rerun-open native run unitree-go2 ``` ## Viewer Modes Explained -### Rerun Native (`rerun`) — Default +### Rerun Native (`rerun`, `--rerun-open native`) — Default **What you get:** - [dimos-viewer](https://github.com/dimensionalOS/dimos-viewer), a custom Dimensional fork of Rerun with built-in keyboard 
teleop and click-to-navigate @@ -41,7 +47,7 @@ VIEWER=foxglove dimos run unitree-go2 --- -### Rerun Web (`rerun-web`) +### Rerun Web (`rerun`, `--rerun-open web`) **What you get:** - Browser-based dashboard at http://localhost:7779 @@ -63,18 +69,16 @@ VIEWER=foxglove dimos run unitree-go2 ## Rendering with Custom Blueprints -To enable rerun within your own blueprint simply include `RerunBridgeModule`: +To enable visualization in your own blueprint, use `vis_module`: ```python -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.core.global_config import global_config +from dimos.visualization.vis_module import vis_module from dimos.hardware.sensors.camera.module import CameraModule -from dimos.protocol.pubsub.impl.lcmpubsub import LCM camera_demo = autoconnect( CameraModule.blueprint(), - RerunBridgeModule.blueprint( - viewer_mode="native", # native (desktop), web (browser), none (headless) - ), + vis_module(viewer_backend=global_config.viewer), ) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 398903f457..1ca77c7c7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,7 @@ dependencies = [ # TODO: rerun shouldn't be required but rn its in core (there is NO WAY to use dimos without rerun rn) # remove this once rerun is optional in core "rerun-sdk>=0.20.0", - "dimos-viewer>=0.30.0a2", + "dimos-viewer==0.30.0a6.dev99", "toolz>=1.1.0", "protobuf>=6.33.5,<7", "psutil>=7.0.0", diff --git a/uv.lock b/uv.lock index f429ae7b71..5cb8d2ef0c 100644 --- a/uv.lock +++ b/uv.lock @@ -1995,7 +1995,7 @@ requires-dist = [ { name = "dimos", extras = ["unitree"], marker = "extra == 'unitree-dds'" }, { name = "dimos-lcm" }, { name = "dimos-lcm", marker = "extra == 'docker'" }, - { name = "dimos-viewer", specifier = ">=0.30.0a2" }, + { name = "dimos-viewer", specifier = "==0.30.0a6.dev99" }, { name = "dimos-viewer", marker = "extra == 'visualization'", specifier = ">=0.30.0a4" }, { name = "drake", marker = "platform_machine 
!= 'aarch64' and sys_platform == 'darwin' and extra == 'manipulation'", specifier = "==1.45.0" }, { name = "drake", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and extra == 'manipulation'", specifier = ">=1.40.0" }, @@ -2168,18 +2168,18 @@ wheels = [ [[package]] name = "dimos-viewer" -version = "0.30.0a6" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0b/90/ad6d0e1e177a10a0b4f7e736436b6d2741acaeb402ab59504347236744f4/dimos_viewer-0.30.0a6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e623a21e6992e263513847e12809a0d234d73fc7af42a6428e84ca165ba682d0", size = 35309553, upload-time = "2026-03-18T15:22:26.874Z" }, - { url = "https://files.pythonhosted.org/packages/a1/84/1c8f41ff2bd5b6ee143eb6119107397dac284fa4f1f8335623c498bd1d9c/dimos_viewer-0.30.0a6-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:36068a3293cb1c7f4db9f4e6c9fea2d7dd2a2527025f803585f4d3aaad9aedbd", size = 39072034, upload-time = "2026-03-18T15:22:29.592Z" }, - { url = "https://files.pythonhosted.org/packages/58/e6/d6214245e5b99e1da262d037f52d3d39c6b87c65acb516fb08f11378e932/dimos_viewer-0.30.0a6-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:2bf36e8c8bd9dd822bedd1cb2d80ee2bf74b58184ba33872494baed0395fa7ff", size = 41447599, upload-time = "2026-03-18T15:22:32.699Z" }, - { url = "https://files.pythonhosted.org/packages/48/04/80f566400776cab9af68b4a3c0132f55786acd1641ea39d8b75e797a2e22/dimos_viewer-0.30.0a6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:947cfa10c583b357d589c10cb466c63b3651a83d1013a254c0ba03fc2959bef7", size = 35309552, upload-time = "2026-03-18T15:22:35.395Z" }, - { url = "https://files.pythonhosted.org/packages/4c/c3/72157e0806951c2c71c70dcd783e27be8d694344d7ecdb94eaef1066cf99/dimos_viewer-0.30.0a6-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:53ca4ac1f0778f1d9afb317b6268c941c02b20af86dd2aaaf1ea79f2c1d1eeb8", size = 39072018, upload-time = 
"2026-03-18T15:22:38.043Z" }, - { url = "https://files.pythonhosted.org/packages/2f/92/959fc1e9cdcb5fd8d793b2c8515a6086c9f913ba470baad1f3182ae4c242/dimos_viewer-0.30.0a6-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:27e108060a942c92f7869a0e45693dfe1798896bd90cbac6d1ce019a682f8ba7", size = 41447647, upload-time = "2026-03-18T15:22:41.003Z" }, - { url = "https://files.pythonhosted.org/packages/ab/d6/d76763b60d82539e92777500551116306cfea462f6976ad814a3bdf57e1d/dimos_viewer-0.30.0a6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f4f49f973c51055cfd594b68a8e9d183c706f94b1513b6b69db900d05850f741", size = 35309553, upload-time = "2026-03-18T15:22:43.681Z" }, - { url = "https://files.pythonhosted.org/packages/26/ab/6ea7686c467caecdc74dd8d3a0267053ac74229b3afebc64cff180d5074c/dimos_viewer-0.30.0a6-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:791ef1c1d8d41db69a7d2b701ed3f0b6bc39cb3264aaef7300eddb576c8df7ed", size = 39072062, upload-time = "2026-03-18T15:22:46.264Z" }, - { url = "https://files.pythonhosted.org/packages/3c/87/fce7aac56d8a234d3d7c0911928bb3471d7852e35263b966d2aac5be42cd/dimos_viewer-0.30.0a6-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:dd976c39c38718b8373e1894d55b78c10bcb8c5716c8dbd5fba59141bc08ab3c", size = 41447667, upload-time = "2026-03-18T15:22:49.214Z" }, +version = "0.30.0a6.dev99" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/0e/d363be05f172bafe5f41a95db318891637e902c50edfdc642edec6bb5111/dimos_viewer-0.30.0a6.dev99-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cfa57e68e8f4094d4a38d202414046fd2419ff2875ace3f16b8581c3106feca4", size = 35405401, upload-time = "2026-04-17T04:19:10.126Z" }, + { url = "https://files.pythonhosted.org/packages/e7/ab/0730fed402b3b92e35194f11b76119754d619fa6bab00a1932b5c78f87b3/dimos_viewer-0.30.0a6.dev99-cp310-cp310-manylinux_2_28_aarch64.whl", hash = 
"sha256:f3bc243342131c8c2b653cc6b76f04d65aad525f5560829b78aa1a7d31a9d375", size = 39167146, upload-time = "2026-04-17T04:19:14.177Z" }, + { url = "https://files.pythonhosted.org/packages/bb/d9/1415d5d7e609d69b05e8e1167a66dd7cb78f3933205f9b321ae18233384c/dimos_viewer-0.30.0a6.dev99-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b954083fcb8951641554fdea95425b3b5ac9415cd1b65410a137d38d3dd57b8a", size = 41536165, upload-time = "2026-04-17T04:19:17.379Z" }, + { url = "https://files.pythonhosted.org/packages/93/7c/7ee6049a753c01ccbe8357f9c5f789378103b87331e5ca7977f05adf5c42/dimos_viewer-0.30.0a6.dev99-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0387201efd1260f968853f0d7863876b6db375b2af15b22f221a893fcce6549c", size = 35405408, upload-time = "2026-04-17T04:19:20.08Z" }, + { url = "https://files.pythonhosted.org/packages/de/2e/9b4252a12c4b641ab1479a6a4d3d576e75fc42ca2a797d88e2e0626abda0/dimos_viewer-0.30.0a6.dev99-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:a0fae6f2077fc6ceb25e1ed33fb7ccf183ef3e2a30456aa5462b953c1419e547", size = 39167138, upload-time = "2026-04-17T04:19:23.292Z" }, + { url = "https://files.pythonhosted.org/packages/46/2a/4bd02c3d79df2aefc5be47afda6b95121937cef0a3f6b15d071691ec3ca7/dimos_viewer-0.30.0a6.dev99-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e844015f3ad193d50201c39abd3e3f34abbf03adbfb1075468696c1236df1409", size = 41536172, upload-time = "2026-04-17T04:19:26.421Z" }, + { url = "https://files.pythonhosted.org/packages/1b/b1/efcea9b9e21c4ab75e2df016a27e5045e30d91a494465ab0cc627d8d8bc3/dimos_viewer-0.30.0a6.dev99-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dc82061c2c025684c0fbed5392f793d137b1b0fc3aa1b601988bf4d2ee88aa27", size = 35405409, upload-time = "2026-04-17T04:19:29.574Z" }, + { url = "https://files.pythonhosted.org/packages/2d/8e/d482b0b9379c40ddd7547600543ce726fc3b5d10e396a876f22b2d76d0e6/dimos_viewer-0.30.0a6.dev99-cp312-cp312-manylinux_2_28_aarch64.whl", hash = 
"sha256:0f6acfa0de3083e746ac43fe0d0a328d624bcb859dc698b1bbc592f444f52f15", size = 39167144, upload-time = "2026-04-17T04:19:32.301Z" }, + { url = "https://files.pythonhosted.org/packages/6d/eb/08922721c74ceaa99a824258db02c438d50f77c22ff80332cbc4b1a8db7b/dimos_viewer-0.30.0a6.dev99-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:56fa9139c49ec4bf96b12d6e98d3de3319a66876374ae57bda4534ab7a347765", size = 41536171, upload-time = "2026-04-17T04:19:35.29Z" }, ] [[package]] From 5b98212decf1b4a51fc300370e96c866a5a8da11 Mon Sep 17 00:00:00 2001 From: Jeff Hykin Date: Mon, 27 Apr 2026 15:08:40 -0500 Subject: [PATCH 26/30] fix(types): resolve mypy 3.10 errors (#1921) --- dimos/memory2/vis/plot/elements.py | 4 ++-- dimos/memory2/vis/plot/plot.py | 4 ++-- dimos/perception/detection/module2D.py | 3 ++- dimos/perception/detection/type/imageDetections.py | 8 +++++++- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/dimos/memory2/vis/plot/elements.py b/dimos/memory2/vis/plot/elements.py index 8b5932da53..7f83de2b94 100644 --- a/dimos/memory2/vis/plot/elements.py +++ b/dimos/memory2/vis/plot/elements.py @@ -17,11 +17,11 @@ from __future__ import annotations from dataclasses import dataclass -from enum import StrEnum +from enum import Enum from typing import Union -class Style(StrEnum): +class Style(str, Enum): """Line style for Series and HLine elements. Values match matplotlib's `linestyle` names so they pass through directly diff --git a/dimos/memory2/vis/plot/plot.py b/dimos/memory2/vis/plot/plot.py index 6235e44bda..082b147125 100644 --- a/dimos/memory2/vis/plot/plot.py +++ b/dimos/memory2/vis/plot/plot.py @@ -16,13 +16,13 @@ from __future__ import annotations -from enum import StrEnum +from enum import Enum from typing import Any from dimos.memory2.vis.plot.elements import HLine, Markers, PlotElement, Series, VLine -class TimeAxis(StrEnum): +class TimeAxis(str, Enum): """How the x-axis is formatted. 
- ``raw``: unix timestamps as-is (matplotlib's default numeric formatter). diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index fb07e02d3c..3f9aee84e4 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -78,7 +78,8 @@ def process_image_frame(self, image: Image) -> ImageDetections2D: imageDetections = self.detector.process_image(image) if not self.config.filter: return imageDetections - return imageDetections.filter(*self.config.filter) + filtered: ImageDetections2D = imageDetections.filter(*self.config.filter) + return filtered @simple_mcache def sharp_image_stream(self) -> Observable[Image]: diff --git a/dimos/perception/detection/type/imageDetections.py b/dimos/perception/detection/type/imageDetections.py index 6820b6210d..e3212a7804 100644 --- a/dimos/perception/detection/type/imageDetections.py +++ b/dimos/perception/detection/type/imageDetections.py @@ -16,7 +16,13 @@ from functools import reduce from operator import add -from typing import TYPE_CHECKING, Generic, Self, TypeVar +import sys +from typing import TYPE_CHECKING, Generic, TypeVar + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self if sys.version_info >= (3, 11): from typing import Self From 060ff686170ff5ce4f2b923da125af6ec6517904 Mon Sep 17 00:00:00 2001 From: leshy Date: Tue, 28 Apr 2026 20:27:08 +0300 Subject: [PATCH 27/30] Revert "Jeff/fix/rconnect2" (#1924) --- .gitignore | 3 - dimos/core/coordination/python_worker.py | 16 +- dimos/core/docker_module.py | 2 +- dimos/core/global_config.py | 11 +- dimos/hardware/sensors/camera/module.py | 5 +- .../lidar/fastlio2/fastlio_blueprints.py | 35 +-- .../sensors/lidar/livox/livox_blueprints.py | 4 +- dimos/manipulation/blueprints.py | 10 +- dimos/manipulation/grasping/demo_grasping.py | 4 +- .../wavefront_frontier_goal_selector.py | 11 - dimos/navigation/replanning_a_star/module.py | 18 +- 
.../movement_manager/movement_manager.py | 133 --------- .../movement_manager/test_movement_manager.py | 117 -------- .../demo_object_scene_registration.py | 4 +- dimos/robot/all_blueprints.py | 2 - dimos/robot/cli/dimos.py | 48 +--- .../drone/blueprints/basic/drone_basic.py | 17 +- .../blueprints/perceptive/unitree_g1_shm.py | 10 +- .../primitive/uintree_g1_primitive_no_nav.py | 19 +- .../agentic/unitree_go2_security.py | 4 +- .../go2/blueprints/basic/unitree_go2_basic.py | 34 ++- .../go2/blueprints/basic/unitree_go2_fleet.py | 6 +- .../unitree_go2_webrtc_keyboard_teleop.py | 4 - .../go2/blueprints/smart/unitree_go2.py | 6 +- dimos/robot/unitree/keyboard_teleop.py | 10 +- dimos/robot/unitree/mujoco_connection.py | 16 +- dimos/simulation/unity/blueprint.py | 4 +- dimos/teleop/quest/blueprints.py | 4 +- dimos/test_no_sections.py | 2 - dimos/utils/generic.py | 17 -- dimos/visualization/rerun/bridge.py | 253 +++++++++--------- dimos/visualization/rerun/conftest.py | 45 ---- dimos/visualization/rerun/constants.py | 31 --- .../visualization/rerun/test_viewer_ws_e2e.py | 201 -------------- .../rerun/test_websocket_server.py | 210 --------------- dimos/visualization/rerun/websocket_server.py | 244 ----------------- dimos/visualization/vis_module.py | 87 ------ .../web/websocket_vis/websocket_vis_module.py | 24 +- docs/development/conventions.md | 12 - docs/usage/cli.md | 4 +- docs/usage/visualization.md | 42 ++- pyproject.toml | 2 +- uv.lock | 26 +- 43 files changed, 292 insertions(+), 1465 deletions(-) delete mode 100644 dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py delete mode 100644 dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py delete mode 100644 dimos/visualization/rerun/conftest.py delete mode 100644 dimos/visualization/rerun/constants.py delete mode 100644 dimos/visualization/rerun/test_viewer_ws_e2e.py delete mode 100644 dimos/visualization/rerun/test_websocket_server.py delete mode 100644 
dimos/visualization/rerun/websocket_server.py delete mode 100644 dimos/visualization/vis_module.py delete mode 100644 docs/development/conventions.md diff --git a/.gitignore b/.gitignore index ea68926e96..1816510c08 100644 --- a/.gitignore +++ b/.gitignore @@ -74,9 +74,6 @@ CLAUDE.MD /.mcp.json *.speedscope.json -# Hidden/personal directories -.hidden/ - # Coverage htmlcov/ .coverage diff --git a/dimos/core/coordination/python_worker.py b/dimos/core/coordination/python_worker.py index 6c3aab3a2d..3c434a982e 100644 --- a/dimos/core/coordination/python_worker.py +++ b/dimos/core/coordination/python_worker.py @@ -18,7 +18,6 @@ import multiprocessing from multiprocessing.connection import Connection import os -import signal import sys import threading import traceback @@ -338,15 +337,12 @@ class _WorkerState: def _worker_entrypoint(conn: Connection, worker_id: int) -> None: apply_library_config() - # Ignore SIGINT so the coordinator can orchestrate shutdown via the pipe. - # Without this, workers race with the coordinator: they start tearing down - # modules locally while the coordinator tries to send stop() RPCs, causing - # BrokenPipeErrors. 
- signal.signal(signal.SIGINT, signal.SIG_IGN) state = _WorkerState(instances={}, worker_id=worker_id) try: _worker_loop(conn, state) + except KeyboardInterrupt: + logger.info("Worker got KeyboardInterrupt.", worker_id=worker_id) except Exception as e: logger.error(f"Worker process error: {e}", exc_info=True) finally: @@ -365,6 +361,12 @@ def _worker_entrypoint(conn: Connection, worker_id: int) -> None: worker_id=worker_id, module_id=module_id, ) + except KeyboardInterrupt: + logger.warning( + "KeyboardInterrupt during worker stop", + module=type(instance).__name__, + worker_id=worker_id, + ) except Exception: logger.error("Error during worker shutdown", exc_info=True) @@ -431,7 +433,7 @@ def _worker_loop(conn: Connection, state: _WorkerState) -> None: if not conn.poll(timeout=0.1): continue request = conn.recv() - except EOFError: + except (EOFError, KeyboardInterrupt): break try: diff --git a/dimos/core/docker_module.py b/dimos/core/docker_module.py index f82a1b56db..3ad9620556 100644 --- a/dimos/core/docker_module.py +++ b/dimos/core/docker_module.py @@ -30,7 +30,7 @@ from dimos.core.rpc_client import ModuleProxyProtocol, RpcCall from dimos.protocol.rpc.pubsubrpc import LCMRPC from dimos.utils.logging_config import setup_logger -from dimos.visualization.rerun.constants import RERUN_GRPC_PORT, RERUN_WEB_PORT +from dimos.visualization.rerun.bridge import RERUN_GRPC_PORT, RERUN_WEB_PORT if TYPE_CHECKING: from collections.abc import Callable diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py index 435f421dd1..214401959e 100644 --- a/dimos/core/global_config.py +++ b/dimos/core/global_config.py @@ -13,16 +13,13 @@ # limitations under the License. 
import re +from typing import Literal, TypeAlias from pydantic_settings import BaseSettings, SettingsConfigDict from dimos.models.vl.types import VlModelName -from dimos.visualization.rerun.constants import ( - RERUN_ENABLE_WEB, - RERUN_OPEN_DEFAULT, - RerunOpenOption, - ViewerBackend, -) + +ViewerBackend: TypeAlias = Literal["rerun", "rerun-web", "rerun-connect", "foxglove", "none"] def _get_all_numbers(s: str) -> list[float]: @@ -40,8 +37,6 @@ class GlobalConfig(BaseSettings): replay_db: str = "go2_bigoffice" new_memory: bool = False viewer: ViewerBackend = "rerun" - rerun_open: RerunOpenOption = RERUN_OPEN_DEFAULT - rerun_web: bool = RERUN_ENABLE_WEB n_workers: int = 2 memory_limit: str = "auto" mujoco_camera_position: str | None = None diff --git a/dimos/hardware/sensors/camera/module.py b/dimos/hardware/sensors/camera/module.py index 0fe0d8f030..9b4f50920c 100644 --- a/dimos/hardware/sensors/camera/module.py +++ b/dimos/hardware/sensors/camera/module.py @@ -21,7 +21,6 @@ from dimos.agents.annotation import skill from dimos.core.coordination.blueprints import autoconnect from dimos.core.core import rpc -from dimos.core.global_config import global_config from dimos.core.module import Module, ModuleConfig from dimos.core.stream import Out from dimos.hardware.sensors.camera.spec import CameraHardware @@ -32,7 +31,7 @@ from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier from dimos.spec import perception -from dimos.visualization.vis_module import vis_module +from dimos.visualization.rerun.bridge import RerunBridgeModule def default_transform() -> Transform: @@ -121,5 +120,5 @@ def stop(self) -> None: demo_camera = autoconnect( CameraModule.blueprint(), - vis_module(viewer_backend=global_config.viewer), + RerunBridgeModule.blueprint(), ) diff --git a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py index 
2c2a64d61e..2946f1d247 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py +++ b/dimos/hardware/sensors/lidar/fastlio2/fastlio_blueprints.py @@ -15,45 +15,30 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.hardware.sensors.lidar.fastlio2.module import FastLio2 from dimos.mapping.voxels import VoxelGridMapper -from dimos.visualization.vis_module import vis_module +from dimos.visualization.rerun.bridge import RerunBridgeModule voxel_size = 0.05 mid360_fastlio = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=-1), - vis_module( - "rerun", - rerun_config={ - "visual_override": { - "world/lidar": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - }, - }, - ), + RerunBridgeModule.blueprint(), ).global_config(n_workers=2, robot_model="mid360_fastlio2") mid360_fastlio_voxels = autoconnect( FastLio2.blueprint(), VoxelGridMapper.blueprint(voxel_size=voxel_size, carve_columns=False), - vis_module( - "rerun", - rerun_config={ - "visual_override": { - "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - "world/lidar": None, - }, - }, + RerunBridgeModule.blueprint( + visual_override={ + "world/lidar": None, + } ), ).global_config(n_workers=3, robot_model="mid360_fastlio2_voxels") mid360_fastlio_voxels_native = autoconnect( FastLio2.blueprint(voxel_size=voxel_size, map_voxel_size=voxel_size, map_freq=3.0), - vis_module( - "rerun", - rerun_config={ - "visual_override": { - "world/lidar": None, - "world/global_map": lambda grid: grid.to_rerun(voxel_size=voxel_size, mode="boxes"), - }, - }, + RerunBridgeModule.blueprint( + visual_override={ + "world/lidar": None, + } ), ).global_config(n_workers=2, robot_model="mid360_fastlio2") diff --git a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py index e437d73994..34ebc33c2a 100644 --- a/dimos/hardware/sensors/lidar/livox/livox_blueprints.py 
+++ b/dimos/hardware/sensors/lidar/livox/livox_blueprints.py @@ -14,9 +14,9 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.hardware.sensors.lidar.livox.module import Mid360 -from dimos.visualization.vis_module import vis_module +from dimos.visualization.rerun.bridge import RerunBridgeModule mid360 = autoconnect( Mid360.blueprint(), - vis_module("rerun"), + RerunBridgeModule.blueprint(), ).global_config(n_workers=2, robot_model="mid360") diff --git a/dimos/manipulation/blueprints.py b/dimos/manipulation/blueprints.py index 1c006c1d04..f950ea8efa 100644 --- a/dimos/manipulation/blueprints.py +++ b/dimos/manipulation/blueprints.py @@ -44,7 +44,7 @@ from dimos.msgs.sensor_msgs.JointState import JointState from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule from dimos.robot.catalog.ufactory import xarm6 as _catalog_xarm6, xarm7 as _catalog_xarm7 -from dimos.visualization.vis_module import vis_module +from dimos.robot.foxglove_bridge import FoxgloveBridge # TODO: migrate to rerun # Single XArm6 planner (standalone, no coordinator) _xarm6_planner_cfg = _catalog_xarm6( @@ -196,14 +196,14 @@ use_aabb=True, max_obstacle_width=0.06, ), - vis_module("foxglove"), + FoxgloveBridge.blueprint(), # TODO: migrate to rerun ) .transports( { ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), } ) - .global_config(n_workers=4) + .global_config(viewer="foxglove", n_workers=4) ) @@ -289,7 +289,7 @@ from dimos.robot.catalog.ufactory import XARM7_SIM_PATH from dimos.simulation.engines.mujoco_sim_module import MujocoSimModule -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode _xarm7_sim_cfg = _catalog_xarm7( name="arm", @@ -323,7 +323,7 @@ hardware=[_xarm7_sim_cfg.to_hardware_component()], tasks=[_xarm7_sim_cfg.to_task_config()], ), - RerunBridgeModule.blueprint(), + 
RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode()), ).transports( { ("joint_state", JointState): LCMTransport("/coordinator/joint_state", JointState), diff --git a/dimos/manipulation/grasping/demo_grasping.py b/dimos/manipulation/grasping/demo_grasping.py index 4a1d4b2cf6..37e1d38f1e 100644 --- a/dimos/manipulation/grasping/demo_grasping.py +++ b/dimos/manipulation/grasping/demo_grasping.py @@ -22,7 +22,7 @@ from dimos.manipulation.grasping.grasping import GraspingModule from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.visualization.vis_module import vis_module +from dimos.robot.foxglove_bridge import FoxgloveBridge camera_module = RealSenseCamera.blueprint(enable_pointcloud=False) @@ -44,7 +44,7 @@ ("/tmp", "/tmp", "rw") ], # Grasp visualization debug standalone: python -m dimos.manipulation.grasping.visualize_grasps ), - vis_module("foxglove"), + FoxgloveBridge.blueprint(), McpServer.blueprint(), McpClient.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index 338d10d9b0..b8dbe0dfc8 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -115,7 +115,6 @@ class WavefrontFrontierExplorer(Module): goal_reached: In[Bool] explore_cmd: In[Bool] stop_explore_cmd: In[Bool] - stop_movement: In[Bool] # LCM outputs goal_request: Out[PoseStamped] @@ -172,10 +171,6 @@ def start(self) -> None: unsub = self.stop_explore_cmd.subscribe(self._on_stop_explore_cmd) self.register_disposable(Disposable(unsub)) - if self.stop_movement.transport is not None: - unsub = self.stop_movement.subscribe(self._on_stop_movement) - self.register_disposable(Disposable(unsub)) - @rpc def stop(self) 
-> None: self.stop_exploration() @@ -206,12 +201,6 @@ def _on_stop_explore_cmd(self, msg: Bool) -> None: logger.info("Received exploration stop command via LCM") self.stop_exploration() - def _on_stop_movement(self, msg: Bool) -> None: - """Handle stop movement from teleop — cancel active exploration.""" - if msg.data and self.exploration_active: - logger.info("WavefrontFrontierExplorer: stop_movement received, stopping exploration") - self.stop_exploration() - def _count_costmap_information(self, costmap: OccupancyGrid) -> int: """ Count the amount of information in a costmap (free space + obstacles). diff --git a/dimos/navigation/replanning_a_star/module.py b/dimos/navigation/replanning_a_star/module.py index efc16b52d6..2375af20ce 100644 --- a/dimos/navigation/replanning_a_star/module.py +++ b/dimos/navigation/replanning_a_star/module.py @@ -28,9 +28,6 @@ from dimos.msgs.nav_msgs.Path import Path from dimos.navigation.base import NavigationInterface, NavigationState from dimos.navigation.replanning_a_star.global_planner import GlobalPlanner -from dimos.utils.logging_config import setup_logger - -logger = setup_logger() class ReplanningAStarPlanner(Module, NavigationInterface): @@ -39,11 +36,10 @@ class ReplanningAStarPlanner(Module, NavigationInterface): goal_request: In[PoseStamped] clicked_point: In[PointStamped] target: In[PoseStamped] - stop_movement: In[Bool] goal_reached: Out[Bool] navigation_state: Out[String] # TODO: set it - nav_cmd_vel: Out[Twist] + cmd_vel: Out[Twist] path: Out[Path] navigation_costmap: Out[OccupancyGrid] @@ -76,14 +72,9 @@ def start(self) -> None: ) ) - if self.stop_movement.transport is not None: - self.register_disposable( - Disposable(self.stop_movement.subscribe(self._on_stop_movement)) - ) - self.register_disposable(self._planner.path.subscribe(self.path.publish)) - self.register_disposable(self._planner.cmd_vel.subscribe(self.nav_cmd_vel.publish)) + self.register_disposable(self._planner.cmd_vel.subscribe(self.cmd_vel.publish)) 
self.register_disposable(self._planner.goal_reached.subscribe(self.goal_reached.publish)) @@ -101,11 +92,6 @@ def stop(self) -> None: super().stop() - def _on_stop_movement(self, msg: Bool) -> None: - if msg.data: - logger.info("ReplanningAStarPlanner: stop_movement received, cancelling goal") - self.cancel_goal() - @rpc def set_goal(self, goal: PoseStamped) -> bool: self._planner.handle_goal_request(goal) diff --git a/dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py b/dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py deleted file mode 100644 index 5a2dd195c0..0000000000 --- a/dimos/navigation/smart_nav/modules/movement_manager/movement_manager.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""MovementManager: click-to-goal relay + teleop/nav velocity mux.""" - -from __future__ import annotations - -import math -import threading -import time -from typing import Any - -from dimos_lcm.std_msgs import Bool # type: ignore[import-untyped] -from reactivex.disposable import Disposable - -from dimos.core.core import rpc -from dimos.core.module import Module, ModuleConfig -from dimos.core.stream import In, Out -from dimos.msgs.geometry_msgs.PointStamped import PointStamped -from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.utils.logging_config import setup_logger - -logger = setup_logger() - - -class MovementManagerConfig(ModuleConfig): - tele_cooldown_sec: float = 1.0 - tele_cmd_vel_scaling: Twist = Twist(Vector3(1, 1, 1), Vector3(1, 1, 1)) - - -class MovementManager(Module): - """Combine tele_cmd_vel (keyboard controls) and nav_cmd_vel in a sane way, output cmd_vel""" - - config: MovementManagerConfig - - clicked_point: In[PointStamped] - nav_cmd_vel: In[Twist] - tele_cmd_vel: In[Twist] - - goal: Out[PointStamped] - way_point: Out[PointStamped] - cmd_vel: Out[Twist] - stop_movement: Out[Bool] - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._lock = threading.Lock() - self._teleop_active = False - self._last_teleop_time = 0.0 - - @rpc - def start(self) -> None: - super().start() - self.register_disposable(Disposable(self.clicked_point.subscribe(self._on_click))) - self.register_disposable(Disposable(self.nav_cmd_vel.subscribe(self._on_nav))) - self.register_disposable(Disposable(self.tele_cmd_vel.subscribe(self._on_teleop))) - - @rpc - def stop(self) -> None: - with self._lock: - self._teleop_active = False - super().stop() - - def _on_click(self, msg: PointStamped) -> None: - if not all(math.isfinite(v) for v in (msg.x, msg.y, msg.z)): - logger.warning("Ignored invalid click", x=msg.x, y=msg.y, z=msg.z) - return - if abs(msg.x) > 500 or abs(msg.y) > 500 or 
abs(msg.z) > 50: - logger.warning("Ignored out-of-range click", x=msg.x, y=msg.y, z=msg.z) - return - - logger.debug("Goal", x=round(msg.x, 1), y=round(msg.y, 1), z=round(msg.z, 1)) - self.way_point.publish(msg) - self.goal.publish(msg) - - def _cancel_goal(self) -> None: - self.stop_movement.publish(Bool(data=True)) - # NOTE: this NaN goal is more of a safety fallback. - # It can be REALLY bad if a robot is supposed to stop moving but wont - # we should probably think a more robust/strict requirement on planners - cancel = PointStamped( - ts=time.time(), frame_id="map", x=float("nan"), y=float("nan"), z=float("nan") - ) - self.way_point.publish(cancel) - self.goal.publish(cancel) - logger.debug("Navigation cancelled — waiting for new goal") - - def _on_nav(self, msg: Twist) -> None: - with self._lock: - if self._teleop_active: - # check if cooldown has expired - elapsed = time.monotonic() - self._last_teleop_time - if elapsed < self.config.tele_cooldown_sec: - return - self._teleop_active = False - self.cmd_vel.publish(msg) - - def _on_teleop(self, msg: Twist) -> None: - with self._lock: - was_active = self._teleop_active - self._teleop_active = True - self._last_teleop_time = time.monotonic() - - if not was_active: - self._cancel_goal() - logger.info("Teleop active") - - scale = self.config.tele_cmd_vel_scaling - scaled = Twist( - linear=Vector3( - msg.linear.x * scale.linear.x, - msg.linear.y * scale.linear.y, - msg.linear.z * scale.linear.z, - ), - angular=Vector3( - msg.angular.x * scale.angular.x, - msg.angular.y * scale.angular.y, - msg.angular.z * scale.angular.z, - ), - ) - self.cmd_vel.publish(scaled) diff --git a/dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py b/dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py deleted file mode 100644 index 6858055605..0000000000 --- a/dimos/navigation/smart_nav/modules/movement_manager/test_movement_manager.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright 2026 
Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for MovementManager: click-to-goal + teleop/nav velocity mux.""" - -from __future__ import annotations - -import math -import time -from unittest.mock import MagicMock - -import pytest - -from dimos.msgs.geometry_msgs.PointStamped import PointStamped -from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.navigation.smart_nav.modules.movement_manager.movement_manager import ( - MovementManager, -) - - -@pytest.fixture() -def manager() -> MovementManager: - """Create a real MovementManager and mock the publish methods on its output streams.""" - module = MovementManager(tele_cooldown_sec=0.1) - module.cmd_vel.publish = MagicMock() - module.stop_movement.publish = MagicMock() - module.goal.publish = MagicMock() - module.way_point.publish = MagicMock() - yield module - module._close_module() - - -def _twist(lx: float = 0.0) -> Twist: - return Twist(linear=Vector3(lx, 0, 0), angular=Vector3(0, 0, 0)) - - -def _click(x: float = 1.0, y: float = 2.0, z: float = 0.0) -> PointStamped: - return PointStamped(ts=time.time(), frame_id="map", x=x, y=y, z=z) - - -def test_teleop_suppresses_nav_and_cancels_goal(manager: MovementManager) -> None: - """Teleop arriving should suppress nav, publish stop_movement, and cancel the goal with NaN.""" - manager.config.tele_cooldown_sec = 10.0 - manager._on_teleop(_twist(lx=0.3)) - - # Nav is suppressed - 
manager.cmd_vel.publish.reset_mock() # type: ignore[union-attr] - manager._on_nav(_twist(lx=0.9)) - manager.cmd_vel.publish.assert_not_called() # type: ignore[union-attr] - - # stop_movement fired - manager.stop_movement.publish.assert_called_once() # type: ignore[union-attr] - - # Goal cancelled with NaN - cancel_msg = manager.goal.publish.call_args[0][0] # type: ignore[union-attr] - assert math.isnan(cancel_msg.x) - - -def test_nav_resumes_after_cooldown(manager: MovementManager) -> None: - """After the cooldown expires, nav commands pass through again.""" - manager.config.tele_cooldown_sec = 0.05 - manager._on_teleop(_twist(lx=0.3)) - time.sleep(0.1) - manager.cmd_vel.publish.reset_mock() # type: ignore[union-attr] - - manager._on_nav(_twist(lx=0.9)) - manager.cmd_vel.publish.assert_called_once() # type: ignore[union-attr] - - -def test_valid_click_publishes_goal(manager: MovementManager) -> None: - """A valid click should publish to both goal and way_point.""" - click = _click(x=5.0, y=3.0, z=0.1) - manager._on_click(click) - manager.goal.publish.assert_called_once_with(click) # type: ignore[union-attr] - manager.way_point.publish.assert_called_once_with(click) # type: ignore[union-attr] - - -def test_invalid_clicks_rejected(manager: MovementManager) -> None: - """NaN, Inf, and out-of-range clicks should not publish.""" - for bad_click in [ - _click(x=float("nan")), - _click(x=float("inf")), - _click(x=600.0), - ]: - manager._on_click(bad_click) - manager.goal.publish.assert_not_called() # type: ignore[union-attr] - - -def test_tele_cmd_vel_scaling() -> None: - """tele_cmd_vel_scaling multiplies each teleop twist component independently.""" - scaling = Twist(Vector3(0.5, 2.0, 0.0), Vector3(1.0, 1.0, 0.25)) - module = MovementManager(tele_cooldown_sec=10.0, tele_cmd_vel_scaling=scaling) - module.cmd_vel.publish = MagicMock() - module.stop_movement.publish = MagicMock() - module.goal.publish = MagicMock() - module.way_point.publish = MagicMock() - - 
module._on_teleop(Twist(Vector3(1, 1, 1), Vector3(1, 1, 1))) - - published = module.cmd_vel.publish.call_args[0][0] # type: ignore[union-attr] - assert published.linear.x == pytest.approx(0.5) - assert published.linear.y == pytest.approx(2.0) - assert published.linear.z == pytest.approx(0.0) - assert published.angular.z == pytest.approx(0.25) - module._close_module() diff --git a/dimos/perception/demo_object_scene_registration.py b/dimos/perception/demo_object_scene_registration.py index 28044dec13..c9b489f54b 100644 --- a/dimos/perception/demo_object_scene_registration.py +++ b/dimos/perception/demo_object_scene_registration.py @@ -20,7 +20,7 @@ from dimos.hardware.sensors.camera.zed.compat import ZEDCamera from dimos.perception.detection.detectors.yoloe import YoloePromptMode from dimos.perception.object_scene_registration import ObjectSceneRegistrationModule -from dimos.visualization.vis_module import vis_module +from dimos.robot.foxglove_bridge import FoxgloveBridge camera_choice = "zed" @@ -34,7 +34,7 @@ demo_object_scene_registration = autoconnect( camera_module, ObjectSceneRegistrationModule.blueprint(target_frame="world", prompt_mode=YoloePromptMode.LRPC), - vis_module("foxglove"), + FoxgloveBridge.blueprint(), McpServer.blueprint(), McpClient.blueprint(), ).global_config(viewer="foxglove") diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 11b2ceb731..c794a67124 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -153,7 +153,6 @@ "mock-b1-connection-module": "dimos.robot.unitree.b1.connection.MockB1ConnectionModule", "module-a": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleA", "module-b": "dimos.robot.unitree.demo_error_on_name_conflicts.ModuleB", - "movement-manager": "dimos.navigation.smart_nav.modules.movement_manager.movement_manager.MovementManager", "mujoco-sim-module": "dimos.simulation.engines.mujoco_sim_module.MujocoSimModule", "navigation-module": 
"dimos.robot.unitree.rosnav.NavigationModule", "navigation-skill-container": "dimos.agents.skills.navigation.NavigationSkillContainer", @@ -176,7 +175,6 @@ "reid-module": "dimos.perception.detection.reid.module.ReidModule", "replanning-a-star-planner": "dimos.navigation.replanning_a_star.module.ReplanningAStarPlanner", "rerun-bridge-module": "dimos.visualization.rerun.bridge.RerunBridgeModule", - "rerun-web-socket-server": "dimos.visualization.rerun.websocket_server.RerunWebSocketServer", "ros-nav": "dimos.navigation.rosnav.ROSNav", "security-module": "dimos.experimental.security_demo.security_module.SecurityModule", "semantic-search": "dimos.memory2.module.SemanticSearch", diff --git a/dimos/robot/cli/dimos.py b/dimos/robot/cli/dimos.py index e99553c2b3..37d1bd2be0 100644 --- a/dimos/robot/cli/dimos.py +++ b/dimos/robot/cli/dimos.py @@ -21,11 +21,10 @@ import json import os from pathlib import Path -import signal import sys import time import types -from typing import TYPE_CHECKING, Any, Union, cast, get_args, get_origin +from typing import TYPE_CHECKING, Any, Union, get_args, get_origin import click from dotenv import load_dotenv @@ -39,10 +38,7 @@ from dimos.core.daemon import daemonize, install_signal_handlers from dimos.core.global_config import GlobalConfig, global_config from dimos.core.run_registry import get_most_recent, is_pid_alive, stop_entry -from dimos.protocol.pubsub.impl.lcmpubsub import LCM -from dimos.protocol.service.lcmservice import autoconf from dimos.utils.logging_config import setup_logger -from dimos.visualization.rerun.constants import RerunOpenOption if TYPE_CHECKING: from dimos.core.coordination.blueprints import Blueprint, BlueprintAtom @@ -226,10 +222,6 @@ def run( cli_config_overrides: dict[str, Any] = ctx.obj - # this is a workaround until we have a proper way to have delayed-module-choice in blueprints - # ex: vis_module(viewer=global_config.viewer) is WRONG (viewer will always be default value) without this patch - 
global_config.update(**cli_config_overrides) - # Clean stale registry entries stale = cleanup_stale() if stale: @@ -668,43 +660,17 @@ def send( @main.command(name="rerun-bridge") def rerun_bridge_cmd( + viewer_mode: str = typer.Option( + "native", help="Viewer mode: native (desktop), web (browser), none (headless)" + ), memory_limit: str = typer.Option( "25%", help="Memory limit for Rerun viewer (e.g., '4GB', '16GB', '25%')" ), - rerun_open: str = typer.Option( - "native", help="How to open Rerun: one of native, web, both, none" - ), - rerun_web: bool = typer.Option( - True, "--rerun-web/--no-rerun-web", help="Enable/Disable Rerun web server" - ), ) -> None: - """Launch the Rerun visualization bridge. - - Standalone utility: runs the bridge directly in the main process (no - blueprint / worker pool) so users can attach a viewer to existing LCM - traffic without building a full module graph. - """ - # Deferred: RerunBridgeModule pulls in the rerun package (~1s), keep it - # out of the CLI's hot path so `dimos --help` stays fast. 
- from dimos.visualization.rerun.bridge import RerunBridgeModule - - valid = get_args(RerunOpenOption) - if rerun_open not in valid: - raise typer.BadParameter( - f"rerun_open must be one of {valid}, got {rerun_open!r}", param_hint="--rerun-open" - ) - autoconf(check_only=True) - - bridge = RerunBridgeModule( - memory_limit=memory_limit, - rerun_open=cast("RerunOpenOption", rerun_open), - rerun_web=rerun_web, - pubsubs=[LCM()], - ) - bridge.start() + """Launch the Rerun visualization bridge.""" + from dimos.visualization.rerun.bridge import run_bridge - signal.signal(signal.SIGINT, lambda *_: bridge.stop()) - signal.pause() + run_bridge(viewer_mode=viewer_mode, memory_limit=memory_limit) if __name__ == "__main__": diff --git a/dimos/robot/drone/blueprints/basic/drone_basic.py b/dimos/robot/drone/blueprints/basic/drone_basic.py index aaf82f6355..c1838d6ac7 100644 --- a/dimos/robot/drone/blueprints/basic/drone_basic.py +++ b/dimos/robot/drone/blueprints/basic/drone_basic.py @@ -20,9 +20,10 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.core.global_config import global_config +from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.robot.drone.camera_module import DroneCameraModule from dimos.robot.drone.connection_module import DroneConnectionModule -from dimos.visualization.vis_module import vis_module +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule def _static_drone_body(rr: Any) -> list[Any]: @@ -59,12 +60,23 @@ def _drone_rerun_blueprint() -> Any: _rerun_config = { "blueprint": _drone_rerun_blueprint, + "pubsubs": [LCM()], "static": { "world/tf/base_link": _static_drone_body, }, } -_vis = vis_module(global_config.viewer, rerun_config=_rerun_config) +# Conditional visualization +if global_config.viewer == "foxglove": + from dimos.robot.foxglove_bridge import FoxgloveBridge + + _vis = FoxgloveBridge.blueprint() +elif global_config.viewer.startswith("rerun"): + from dimos.visualization.rerun.bridge 
import RerunBridgeModule, _resolve_viewer_mode + + _vis = RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **_rerun_config) +else: + _vis = autoconnect() # Determine connection string based on replay flag connection_string = "udp:0.0.0.0:14550" @@ -80,6 +92,7 @@ def _drone_rerun_blueprint() -> Any: outdoor=False, ), DroneCameraModule.blueprint(camera_intrinsics=[1000.0, 1000.0, 960.0, 540.0]), + WebsocketVisModule.blueprint(), ) __all__ = [ diff --git a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py index 4941abad38..dd135a60a1 100644 --- a/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py +++ b/dimos/robot/unitree/g1/blueprints/perceptive/unitree_g1_shm.py @@ -17,11 +17,10 @@ from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core.coordination.blueprints import autoconnect -from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image +from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.robot.unitree.g1.blueprints.perceptive.unitree_g1 import unitree_g1 -from dimos.visualization.vis_module import vis_module unitree_g1_shm = autoconnect( unitree_g1.transports( @@ -31,9 +30,10 @@ ), } ), - vis_module( - viewer_backend=global_config.viewer, - foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, + FoxgloveBridge.blueprint( + shm_channels=[ + "/color_image#sensor_msgs.Image", + ] ), ) diff --git a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py index eeabea7909..b04443732f 100644 --- a/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py +++ b/dimos/robot/unitree/g1/blueprints/primitive/uintree_g1_primitive_no_nav.py @@ -40,7 +40,8 @@ from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector 
import ( WavefrontFrontierExplorer, ) -from dimos.visualization.vis_module import vis_module +from dimos.protocol.pubsub.impl.lcmpubsub import LCM +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule def _convert_camera_info(camera_info: Any) -> Any: @@ -93,6 +94,7 @@ def _g1_rerun_blueprint() -> Any: rerun_config = { "blueprint": _g1_rerun_blueprint, + "pubsubs": [LCM()], "visual_override": { "world/camera_info": _convert_camera_info, "world/navigation_costmap": _convert_navigation_costmap, @@ -102,7 +104,18 @@ def _g1_rerun_blueprint() -> Any: }, } -_with_vis = vis_module(viewer_backend=global_config.viewer, rerun_config=rerun_config) +if global_config.viewer == "foxglove": + from dimos.robot.foxglove_bridge import FoxgloveBridge + + _with_vis = autoconnect(FoxgloveBridge.blueprint()) +elif global_config.viewer.startswith("rerun"): + from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode + + _with_vis = autoconnect( + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config) + ) +else: + _with_vis = autoconnect() def _create_webcam() -> Webcam: @@ -137,6 +150,8 @@ def _create_webcam() -> Webcam: VoxelGridMapper.blueprint(), CostMapper.blueprint(), WavefrontFrontierExplorer.blueprint(), + # Visualization + WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_g1") .transports( diff --git a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py index 4b39a106b8..be9e04a7fd 100644 --- a/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py +++ b/dimos/robot/unitree/go2/blueprints/agentic/unitree_go2_security.py @@ -18,7 +18,7 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.robot.unitree.go2.blueprints.agentic.unitree_go2_agentic import unitree_go2_agentic -from dimos.visualization.rerun.bridge 
import RerunBridgeModule +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode def _convert_camera_info(camera_info: Any) -> Any: @@ -85,7 +85,7 @@ def _go2_rerun_blueprint() -> Any: unitree_go2_security = autoconnect( unitree_go2_agentic, - RerunBridgeModule.blueprint(**rerun_config), + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), ) __all__ = ["unitree_go2_security"] diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py index 4f86ccb0a3..54a2c0f7c6 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_basic.py @@ -22,9 +22,10 @@ from dimos.core.global_config import global_config from dimos.core.transport import pSHMTransport from dimos.msgs.sensor_msgs.Image import Image +from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator from dimos.robot.unitree.go2.connection import GO2Connection -from dimos.visualization.vis_module import vis_module +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule # Mac has some issue with high bandwidth UDP, so we use pSHMTransport for color_image # actually we can use pSHMTransport for all platforms, and for all streams @@ -98,6 +99,9 @@ def _go2_rerun_blueprint() -> Any: rerun_config = { "blueprint": _go2_rerun_blueprint, + # any pubsub that supports subscribe_all and topic that supports str(topic) + # is acceptable here + "pubsubs": [LCM()], # Custom converters for specific rerun entity paths # Normally all these would be specified in their respectative modules # Until this is implemented we have central overrides here @@ -119,20 +123,30 @@ def _go2_rerun_blueprint() -> Any: }, } -_with_vis = autoconnect( - _transports_base, - vis_module( - viewer_backend=global_config.viewer, - 
rerun_config=rerun_config, - foxglove_config={"shm_channels": ["/color_image#sensor_msgs.Image"]}, - ), -) + +if global_config.viewer == "foxglove": + from dimos.robot.foxglove_bridge import FoxgloveBridge + + with_vis = autoconnect( + _transports_base, + FoxgloveBridge.blueprint(shm_channels=["/color_image#sensor_msgs.Image"]), + ) +elif global_config.viewer.startswith("rerun"): + from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode + + with_vis = autoconnect( + _transports_base, + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), + ) +else: + with_vis = _transports_base unitree_go2_basic = ( autoconnect( - _with_vis, + with_vis, GO2Connection.blueprint(), + WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py index bda362eeca..a7a10767bf 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_fleet.py @@ -22,13 +22,15 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.protocol.service.system_configurator.clock_sync import ClockSyncConfigurator -from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import _with_vis +from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import with_vis from dimos.robot.unitree.go2.fleet_connection import Go2FleetConnection +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule unitree_go2_fleet = ( autoconnect( - _with_vis, + with_vis, Go2FleetConnection.blueprint(), + WebsocketVisModule.blueprint(), ) .global_config(n_workers=4, robot_model="unitree_go2") .configurators(ClockSyncConfigurator()) diff --git a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py 
b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py index 3be0c62379..01117ec3b5 100644 --- a/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py +++ b/dimos/robot/unitree/go2/blueprints/basic/unitree_go2_webrtc_keyboard_teleop.py @@ -31,10 +31,6 @@ unitree_go2_webrtc_keyboard_teleop = autoconnect( unitree_go2_coordinator, KeyboardTeleop.blueprint(), -).remappings( - [ - (KeyboardTeleop, "tele_cmd_vel", "cmd_vel"), - ] ) __all__ = ["unitree_go2_webrtc_keyboard_teleop"] diff --git a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py index 16711115ab..f353d995af 100644 --- a/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py +++ b/dimos/robot/unitree/go2/blueprints/smart/unitree_go2.py @@ -27,7 +27,6 @@ ) from dimos.navigation.patrolling.module import PatrollingModule from dimos.navigation.replanning_a_star.module import ReplanningAStarPlanner -from dimos.navigation.smart_nav.modules.movement_manager.movement_manager import MovementManager from dimos.robot.unitree.go2.blueprints.basic.unitree_go2_basic import unitree_go2_basic unitree_go2 = autoconnect( @@ -37,8 +36,7 @@ ReplanningAStarPlanner.blueprint(), WavefrontFrontierExplorer.blueprint(), PatrollingModule.blueprint(), - MovementManager.blueprint(), -).global_config(n_workers=10, robot_model="unitree_go2") +).global_config(n_workers=9, robot_model="unitree_go2") class Go2MemoryConfig(RecorderConfig): @@ -54,6 +52,6 @@ class Go2Memory(Recorder): unitree_go2_memory = autoconnect( unitree_go2, Go2Memory.blueprint(), -).global_config(n_workers=11) +).global_config(n_workers=10) __all__ = ["unitree_go2", "unitree_go2_memory"] diff --git a/dimos/robot/unitree/keyboard_teleop.py b/dimos/robot/unitree/keyboard_teleop.py index 3e8f76a1cc..e3c78ecc52 100644 --- a/dimos/robot/unitree/keyboard_teleop.py +++ b/dimos/robot/unitree/keyboard_teleop.py @@ -38,14 +38,14 @@ class KeyboardTeleop(Module): 
"""Pygame-based keyboard control module. - Outputs standard Twist messages on /tele_cmd_vel for velocity control. + Outputs standard Twist messages on /cmd_vel for velocity control. Speed constants can be tuned at the top of this file, or overridden per-instance by passing linear_speed / angular_speed / boost_multiplier / slow_multiplier to the constructor. """ - tele_cmd_vel: Out[Twist] # Standard velocity commands + cmd_vel: Out[Twist] # Standard velocity commands _stop_event: threading.Event _keys_held: set[int] | None = None @@ -86,7 +86,7 @@ def stop(self) -> None: stop_twist = Twist() stop_twist.linear = Vector3(0, 0, 0) stop_twist.angular = Vector3(0, 0, 0) - self.tele_cmd_vel.publish(stop_twist) + self.cmd_vel.publish(stop_twist) self._stop_event.set() @@ -119,7 +119,7 @@ def _pygame_loop(self) -> None: stop_twist = Twist() stop_twist.linear = Vector3(0, 0, 0) stop_twist.angular = Vector3(0, 0, 0) - self.tele_cmd_vel.publish(stop_twist) + self.cmd_vel.publish(stop_twist) print("EMERGENCY STOP!") elif event.key == pygame.K_ESCAPE: # ESC quits @@ -163,7 +163,7 @@ def _pygame_loop(self) -> None: twist.angular.z *= speed_multiplier # Always publish twist at 50Hz - self.tele_cmd_vel.publish(twist) + self.cmd_vel.publish(twist) self._update_display(twist) diff --git a/dimos/robot/unitree/mujoco_connection.py b/dimos/robot/unitree/mujoco_connection.py index 43ddeb6530..39c0904684 100644 --- a/dimos/robot/unitree/mujoco_connection.py +++ b/dimos/robot/unitree/mujoco_connection.py @@ -20,12 +20,9 @@ from collections.abc import Callable import functools import json -import os -from pathlib import Path import pickle import subprocess import sys -import sysconfig import threading import time from typing import Any, TypeVar @@ -129,23 +126,12 @@ def start(self) -> None: # Launch the subprocess try: - # mjpython must be used on macOS (because of launch_passive inside mujoco_process.py). 
- # It needs libpython on the dylib search path; uv-installed Pythons - # use @rpath which doesn't always resolve inside venvs, so we - # point DYLD_LIBRARY_PATH at the real libpython directory. + # mjpython must be used macOS (because of launch_passive inside mujoco_process.py) executable = sys.executable if sys.platform != "darwin" else "mjpython" - env = os.environ.copy() - if sys.platform == "darwin": - # on some systems mujoco looks in the wrong place for shared libraries. So we force it look in the right place - libdir = Path(sysconfig.get_config_var("LIBDIR") or "") - if libdir.is_dir(): - existing = env.get("DYLD_LIBRARY_PATH", "") - env["DYLD_LIBRARY_PATH"] = f"{libdir}:{existing}" if existing else str(libdir) self.process = subprocess.Popen( [executable, str(LAUNCHER_PATH), config_pickle, shm_names_json], stderr=subprocess.PIPE, - env=env, ) except Exception as e: diff --git a/dimos/simulation/unity/blueprint.py b/dimos/simulation/unity/blueprint.py index d9b29ee610..f7e2d34ccb 100644 --- a/dimos/simulation/unity/blueprint.py +++ b/dimos/simulation/unity/blueprint.py @@ -28,7 +28,7 @@ from dimos.core.coordination.blueprints import autoconnect from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.simulation.unity.module import UnityBridgeModule -from dimos.visualization.rerun.bridge import RerunBridgeModule +from dimos.visualization.rerun.bridge import RerunBridgeModule, _resolve_viewer_mode def _rerun_blueprint() -> Any: @@ -57,5 +57,5 @@ def _rerun_blueprint() -> Any: unity_sim = autoconnect( UnityBridgeModule.blueprint(), - RerunBridgeModule.blueprint(**rerun_config), + RerunBridgeModule.blueprint(viewer_mode=_resolve_viewer_mode(), **rerun_config), ) diff --git a/dimos/teleop/quest/blueprints.py b/dimos/teleop/quest/blueprints.py index b825f29a17..57c925c3f0 100644 --- a/dimos/teleop/quest/blueprints.py +++ b/dimos/teleop/quest/blueprints.py @@ -26,12 +26,12 @@ from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from 
dimos.teleop.quest.quest_extensions import ArmTeleopModule from dimos.teleop.quest.quest_types import Buttons -from dimos.visualization.vis_module import vis_module +from dimos.visualization.rerun.bridge import RerunBridgeModule # Arm teleop with press-and-hold engage (has rerun viz) teleop_quest_rerun = autoconnect( ArmTeleopModule.blueprint(), - vis_module("rerun"), + RerunBridgeModule.blueprint(), ).transports( { ("left_controller_output", PoseStamped): LCMTransport("/teleop/left_delta", PoseStamped), diff --git a/dimos/test_no_sections.py b/dimos/test_no_sections.py index 79f2d61b8f..902288b2e6 100644 --- a/dimos/test_no_sections.py +++ b/dimos/test_no_sections.py @@ -52,8 +52,6 @@ ".tox", # third-party vendored code "gtsam", - # hidden/personal directories - ".hidden", } # Lines that match section patterns but are actually programmatic / intentional. diff --git a/dimos/utils/generic.py b/dimos/utils/generic.py index 200c7c6d86..84168ce057 100644 --- a/dimos/utils/generic.py +++ b/dimos/utils/generic.py @@ -16,27 +16,10 @@ import hashlib import json import os -import socket import string from typing import Any, Generic, TypeVar, overload import uuid -import psutil - - -def get_local_ips() -> list[tuple[str, str]]: - """Return ``(ip, interface_name)`` for every non-loopback IPv4 address. - - Picks up physical, virtual, and VPN interfaces (including Tailscale). 
- """ - results: list[tuple[str, str]] = [] - for iface, addrs in psutil.net_if_addrs().items(): - for addr in addrs: - if addr.family == socket.AF_INET and not addr.address.startswith("127."): - results.append((addr.address, iface)) - return results - - _T = TypeVar("_T") diff --git a/dimos/visualization/rerun/bridge.py b/dimos/visualization/rerun/bridge.py index f6744e74fb..f2e3e51d08 100644 --- a/dimos/visualization/rerun/bridge.py +++ b/dimos/visualization/rerun/bridge.py @@ -18,19 +18,18 @@ from collections.abc import Callable from dataclasses import field -import socket +from functools import lru_cache import subprocess import time from typing import ( Any, + Literal, Protocol, TypeAlias, TypeGuard, cast, - get_args, runtime_checkable, ) -from urllib.parse import urlparse from reactivex.disposable import Disposable import rerun as rr @@ -38,23 +37,19 @@ import rerun.blueprint as rrb from rerun.blueprint import Blueprint from toolz import pipe # type: ignore[import-untyped] +import typer from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.protocol.pubsub.impl.lcmpubsub import LCM from dimos.protocol.pubsub.patterns import Glob, pattern_matches from dimos.protocol.pubsub.spec import SubscribeAllCapable -from dimos.utils.generic import get_local_ips from dimos.utils.logging_config import setup_logger -from dimos.visualization.rerun.constants import ( - RERUN_ENABLE_WEB, - RERUN_GRPC_PORT, - RERUN_OPEN_DEFAULT, - RERUN_WEB_PORT, - RerunOpenOption, -) from dimos.visualization.rerun.init import rerun_init +RERUN_GRPC_PORT = 9877 +RERUN_WEB_PORT = 9090 + # TODO OUT visual annotations # # In the future it would be nice if modules can annotate their individual OUTs with (general or rerun specific) @@ -100,6 +95,7 @@ BlueprintFactory: TypeAlias = Callable[[], "Blueprint"] +# to_rerun() can return a single archetype or a list of (entity_path, archetype) tuples RerunMulti: TypeAlias = "list[tuple[str, Archetype]]" RerunData: 
TypeAlias = "Archetype | RerunMulti" @@ -123,16 +119,18 @@ class RerunConvertible(Protocol): def to_rerun(self) -> RerunData: ... +ViewerMode = Literal["native", "web", "connect", "none"] + + def _hex_to_rgba(hex_color: str) -> int: """Convert '#RRGGBB' to a 0xRRGGBBAA int (fully opaque).""" h = hex_color.lstrip("#") - if len(h) == 6: - return int(h + "ff", 16) - return int(h[:8], 16) + return (int(h, 16) << 8) | 0xFF def _with_graph_tab(bp: Blueprint) -> Blueprint: """Add a Graph tab alongside the existing viewer layout without changing it.""" + root = bp.root_container return rrb.Blueprint( rrb.Tabs( @@ -158,24 +156,48 @@ def _default_blueprint() -> Blueprint: ) +# Maps global_config.viewer -> bridge viewer_mode. +# Evaluated at blueprint construction time (main process), not in start() (worker process). +_BACKEND_TO_MODE: dict[str, ViewerMode] = { + "rerun": "native", + "rerun-web": "web", + "rerun-connect": "connect", + "none": "none", +} + + +def _resolve_viewer_mode() -> ViewerMode: + from dimos.core.global_config import global_config + + return _BACKEND_TO_MODE.get(global_config.viewer, "native") + + class Config(ModuleConfig): + """Configuration for RerunBridgeModule.""" + pubsubs: list[SubscribeAllCapable[Any, Any]] = field(default_factory=lambda: [LCM()]) visual_override: dict[Glob | str, Callable[[Any], Archetype]] = field(default_factory=dict) + + # Static items logged once after start. Maps entity_path -> callable(rr) returning Archetype static: dict[str, Callable[[Any], Archetype]] = field(default_factory=dict) + + grpc_port: int = RERUN_GRPC_PORT + web_port: int = RERUN_WEB_PORT + + # Per-entity max update rate (Hz). Entities not listed are unthrottled. + # Use for heavy entities to prevent viewer backpressure. 
max_hz: dict[str, float] = field(default_factory=dict) entity_prefix: str = "world" topic_to_entity: Callable[[Any], str] | None = None + viewer_mode: ViewerMode = field(default_factory=_resolve_viewer_mode) connect_url: str = "rerun+http://127.0.0.1:9877/proxy" memory_limit: str = "25%" - rerun_open: RerunOpenOption = RERUN_OPEN_DEFAULT - rerun_web: bool = RERUN_ENABLE_WEB - web_port: int = RERUN_WEB_PORT - blueprint: BlueprintFactory | None = _default_blueprint - -Config.model_rebuild(_types_namespace={"Archetype": Archetype, "Blueprint": Blueprint}) + # Blueprint factory: callable(rrb) -> Blueprint for viewer layout configuration + # Set to None to disable default blueprint + blueprint: BlueprintFactory | None = _default_blueprint class RerunBridgeModule(Module): @@ -195,31 +217,22 @@ class RerunBridgeModule(Module): """ config: Config - _last_log: dict[str, float] # TODO this doesn't belong here, either hardcode it or put it to rerun bridge config - GRAPH_VIZ_SCALE = 100.0 - MODULE_RADIUS = 20.0 - CHANNEL_RADIUS = 12.0 - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._last_log = {} - self._override_cache: dict[str, Callable[[Any], RerunData | None]] = {} + GV_SCALE = 100.0 # graphviz inches to rerun screen units + MODULE_RADIUS = 30.0 + CHANNEL_RADIUS = 20.0 + @lru_cache(maxsize=256) def _visual_override_for_entity_path( self, entity_path: str ) -> Callable[[Any], RerunData | None]: """Return a composed visual override for the entity path. Chains matching overrides from config, ending with final_convert - which handles .to_rerun() or passes through Archetypes. Cached per - instance (not via ``lru_cache`` on a method, which would leak ``self``). + which handles .to_rerun() or passes through Archetypes. 
""" - cached = self._override_cache.get(entity_path) - if cached is not None: - return cached - + # find all matching converters for this entity path matches = [ fn for pattern, fn in self.config.visual_override.items() @@ -228,13 +241,9 @@ def _visual_override_for_entity_path( # None means "suppress this topic entirely" if any(fn is None for fn in matches): + return lambda msg: None - def suppressed(msg: Any) -> RerunData | None: - return None - - self._override_cache[entity_path] = suppressed - return suppressed - + # final step (ensures we return Archetype or None) def final_convert(msg: Any) -> RerunData | None: if isinstance(msg, Archetype): return msg @@ -244,21 +253,23 @@ def final_convert(msg: Any) -> RerunData | None: return msg.to_rerun() return None - def composed(msg: Any) -> RerunData | None: - return cast("RerunData | None", pipe(msg, *matches, final_convert)) - - self._override_cache[entity_path] = composed - return composed + # compose all converters + return lambda msg: pipe(msg, *matches, final_convert) def _get_entity_path(self, topic: Any) -> str: + """Convert a topic to a Rerun entity path.""" if self.config.topic_to_entity: return self.config.topic_to_entity(topic) + # Default: use topic.name if available (LCM Topic), else str topic_str = getattr(topic, "name", None) or str(topic) - topic_str = topic_str.split("#")[0] # strip LCM topic suffix + # Strip everything after # (LCM topic suffix) + topic_str = topic_str.split("#")[0] return f"{self.config.entity_prefix}{topic_str}" def _on_message(self, msg: Any, topic: Any) -> None: + """Handle incoming message - log to rerun.""" + entity_path: str = self._get_entity_path(topic) # Throttle entities with a max_hz limit @@ -268,6 +279,7 @@ def _on_message(self, msg: Any, topic: Any) -> None: return self._last_log[entity_path] = now + # apply visual overrides (including final_convert which handles .to_rerun()) rerun_data: RerunData | None = self._visual_override_for_entity_path(entity_path)(msg) if not 
rerun_data: @@ -284,87 +296,47 @@ def _on_message(self, msg: Any, topic: Any) -> None: def start(self) -> None: super().start() - logger.info("Rerun bridge starting") + logger.info("Rerun bridge starting", viewer_mode=self.config.viewer_mode) - self._last_log = {} + # Build throttle lookup: entity_path → min interval in seconds + self._last_log: dict[str, float] = {} self._min_intervals: dict[str, float] = { entity: 1.0 / hz for entity, hz in self.config.max_hz.items() if hz > 0 } + # Initialize and spawn Rerun viewer rerun_init("dimos") - parsed = urlparse(self.config.connect_url.replace("rerun+", "", 1)) - grpc_port = parsed.port or RERUN_GRPC_PORT - - port_in_use = False - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - port_in_use = sock.connect_ex(("127.0.0.1", grpc_port)) == 0 - - if port_in_use: - logger.info(f"gRPC port {grpc_port} already in use, connecting to existing server") - rr.connect_grpc(url=self.config.connect_url) - server_uri = self.config.connect_url - else: - server_uri = rr.serve_grpc( - grpc_port=grpc_port, - server_memory_limit=self.config.memory_limit, - ) - logger.info(f"Rerun gRPC server ready at {server_uri}") - - if self.config.rerun_open not in get_args(RerunOpenOption): - logger.warning( - f"rerun_open was {self.config.rerun_open} which is not one of " - f"{get_args(RerunOpenOption)}" - ) - - spawned = False - if self.config.rerun_open in ("native", "both"): + if self.config.viewer_mode == "native": try: import rerun_bindings - # Use --connect so the viewer connects to the bridge's gRPC - # server rather than starting its own (which would conflict). 
rerun_bindings.spawn( + port=self.config.grpc_port, executable_name="dimos-viewer", memory_limit=self.config.memory_limit, - extra_args=["--connect", server_uri], ) - spawned = True + rr.connect_grpc(f"rerun+http://127.0.0.1:{self.config.grpc_port}/proxy") except ImportError: - pass # dimos-viewer not installed + rr.spawn(connect=True, memory_limit=self.config.memory_limit) except Exception: logger.warning( "dimos-viewer found but failed to spawn, falling back to stock rerun", exc_info=True, ) + rr.spawn(connect=True, memory_limit=self.config.memory_limit) + elif self.config.viewer_mode == "web": + server_uri = rr.serve_grpc() + rr.serve_web_viewer(connect_to=server_uri, open_browser=False) - # fallback on normal (non-dimos-viewer) rerun - if not spawned: - try: - rr.spawn(connect=True, memory_limit=self.config.memory_limit) - spawned = True - except (RuntimeError, FileNotFoundError): - logger.warning( - "Rerun native viewer not available (headless?). " - "Bridge will continue without a viewer — data is still " - "accessible via --rerun-open web or by connecting a viewer to the gRPC server.", - exc_info=True, - ) - - open_web = self.config.rerun_open == "web" or self.config.rerun_open == "both" - if open_web or self.config.rerun_web: - rr.serve_web_viewer( - connect_to=server_uri, - open_browser=open_web, - web_port=self.config.web_port, - ) - - if self.config.rerun_open == "none" or (self.config.rerun_open == "native" and not spawned): - self._log_connect_hints(grpc_port) + elif self.config.viewer_mode == "connect": + rr.connect_grpc(self.config.connect_url) + # "none" - just init, no viewer (connect externally) if self.config.blueprint: rr.send_blueprint(_with_graph_tab(self.config.blueprint())) + # Start pubsubs and subscribe to all messages for pubsub in self.config.pubsubs: logger.info(f"bridge listening on {pubsub.__class__.__name__}") if hasattr(pubsub, "start"): @@ -372,35 +344,13 @@ def start(self) -> None: unsub = pubsub.subscribe_all(self._on_message) 
self.register_disposable(Disposable(unsub)) + # Add pubsub stop as disposable for pubsub in self.config.pubsubs: if hasattr(pubsub, "stop"): self.register_disposable(Disposable(pubsub.stop)) # type: ignore[union-attr] self._log_static() - def _log_connect_hints(self, grpc_port: int) -> None: - """Log CLI commands for connecting a viewer to this bridge.""" - local_ips = get_local_ips() - hostname = socket.gethostname() - connect_url = f"rerun+http://127.0.0.1:{grpc_port}/proxy" - - lines = [ - "", - "=" * 60, - "Rerun gRPC server running (no viewer opened)", - "", - "Connect a viewer:", - f" dimos-viewer --connect {connect_url}", - ] - for ip, iface in local_ips: - lines.append(f" dimos-viewer --connect rerun+http://{ip}:{grpc_port}/proxy # {iface}") - lines.append("") - lines.append(f" hostname: {hostname}") - lines.append("=" * 60) - lines.append("") - - logger.info("\n".join(lines)) - def _log_static(self) -> None: for entity_path, factory in self.config.static.items(): data = factory(rr) @@ -421,6 +371,7 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: dot_code: The DOT-format graph (from ``introspection.blueprint.dot.render``). module_names: List of module class names (to distinguish modules from channels). 
""" + try: result = subprocess.run( ["dot", "-Tplain"], input=dot_code, text=True, capture_output=True, timeout=30 @@ -442,8 +393,8 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: if line.startswith("node "): parts = line.split() node_id = parts[1].strip('"') - x = float(parts[2]) * self.GRAPH_VIZ_SCALE - y = -float(parts[3]) * self.GRAPH_VIZ_SCALE + x = float(parts[2]) * self.GV_SCALE + y = -float(parts[3]) * self.GV_SCALE label = parts[6].strip('"') color = parts[9].strip('"') @@ -476,5 +427,49 @@ def log_blueprint_graph(self, dot_code: str, module_names: list[str]) -> None: @rpc def stop(self) -> None: - self._override_cache.clear() super().stop() + + +def run_bridge( + viewer_mode: str = "native", + memory_limit: str = "25%", +) -> None: + """Start a RerunBridgeModule with default LCM config and block until interrupted.""" + import signal + + from dimos.protocol.service.lcmservice import autoconf + + autoconf(check_only=True) + + bridge = RerunBridgeModule( + viewer_mode=viewer_mode, + memory_limit=memory_limit, + # any pubsub that supports subscribe_all and topic that supports str(topic) + # is acceptable here + pubsubs=[LCM()], + ) + + bridge.start() + + signal.signal(signal.SIGINT, lambda *_: bridge.stop()) + signal.pause() + + +app = typer.Typer() + + +@app.command() +def cli( + viewer_mode: str = typer.Option( + "native", help="Viewer mode: native (desktop), web (browser), none (headless)" + ), + memory_limit: str = typer.Option( + "25%", help="Memory limit for Rerun viewer (e.g., '4GB', '16GB', '25%')" + ), +) -> None: + """Rerun bridge for LCM messages.""" + run_bridge(viewer_mode=viewer_mode, memory_limit=memory_limit) + + +if __name__ == "__main__": + app() diff --git a/dimos/visualization/rerun/conftest.py b/dimos/visualization/rerun/conftest.py deleted file mode 100644 index f269bb8015..0000000000 --- a/dimos/visualization/rerun/conftest.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2026 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import asyncio -from collections.abc import Callable -import time - -import pytest -import websockets.asyncio.client as ws_client - - -def _wait_for_server(port: int, timeout: float = 5.0) -> None: - """Block until the WebSocket server on *port* accepts a connection.""" - - async def _probe() -> None: - async with ws_client.connect(f"ws://127.0.0.1:{port}/ws"): - pass - - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - try: - asyncio.run(_probe()) - return - except Exception: - time.sleep(0.05) - raise TimeoutError(f"Server on port {port} did not become ready within {timeout}s") - - -@pytest.fixture() -def wait_for_server() -> Callable[[int, float], None]: - """Fixture that returns a callable to wait for a WebSocket server.""" - return _wait_for_server diff --git a/dimos/visualization/rerun/constants.py b/dimos/visualization/rerun/constants.py deleted file mode 100644 index 860c691cef..0000000000 --- a/dimos/visualization/rerun/constants.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Rerun visualization defaults and type aliases. - -This module is intentionally free of heavy imports so it can be -loaded from lightweight entry-points like ``global_config`` and -``dimos --help`` without pulling in the Rerun SDK or the module -framework. -""" - -from typing import Literal, TypeAlias - -ViewerBackend: TypeAlias = Literal["rerun", "foxglove", "none"] -RerunOpenOption: TypeAlias = Literal["none", "web", "native", "both"] - -RERUN_OPEN_DEFAULT: RerunOpenOption = "native" -RERUN_ENABLE_WEB = False -RERUN_GRPC_PORT = 9876 -RERUN_WEB_PORT = 9877 diff --git a/dimos/visualization/rerun/test_viewer_ws_e2e.py b/dimos/visualization/rerun/test_viewer_ws_e2e.py deleted file mode 100644 index 260699a3e8..0000000000 --- a/dimos/visualization/rerun/test_viewer_ws_e2e.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""End-to-end tests for dimos-viewer ↔ RerunWebSocketServer protocol.""" - -from __future__ import annotations - -import asyncio -import json -import os -import subprocess -import threading -import time -from typing import Any - -import pytest -import websockets.asyncio.client as ws_client - -from dimos.visualization.rerun.websocket_server import RerunWebSocketServer - -_E2E_PORT = 13032 - - -@pytest.fixture() -def server(wait_for_server: Any) -> RerunWebSocketServer: - module = RerunWebSocketServer(port=_E2E_PORT) - module.start() - wait_for_server(_E2E_PORT) - yield module # type: ignore[misc] - module.stop() - - -def _send_messages(port: int, messages: list[dict[str, Any]], *, delay: float = 0.05) -> None: - async def _run() -> None: - async with ws_client.connect(f"ws://127.0.0.1:{port}/ws") as ws: - for msg in messages: - await ws.send(json.dumps(msg)) - await asyncio.sleep(delay) - - asyncio.run(_run()) - - -class TestViewerProtocolE2E: - """Verify the Python-server side of the viewer ↔ DimOS protocol.""" - - def test_viewer_click_reaches_stream(self, server: RerunWebSocketServer) -> None: - """A viewer click over WebSocket publishes PointStamped.""" - received: list[Any] = [] - done = threading.Event() - unsub = server.clicked_point.subscribe(lambda pt: (received.append(pt), done.set())) - - _send_messages( - _E2E_PORT, - [ - { - "type": "click", - "x": 10.0, - "y": 20.0, - "z": 0.5, - "entity_path": "/world/robot", - "timestamp_ms": 42000, - } - ], - ) - - done.wait(timeout=3.0) - unsub() - - assert len(received) == 1 - pt = received[0] - assert pt.x == pytest.approx(10.0) - assert pt.y == pytest.approx(20.0) - assert pt.z == pytest.approx(0.5) - assert pt.frame_id == "/world/robot" - assert pt.ts == pytest.approx(42.0) - - def test_full_viewer_session_sequence(self, server: RerunWebSocketServer) -> None: - """Realistic session: heartbeats, click, twist, stop — only the click produces a point.""" - received: list[Any] = [] - done = threading.Event() - 
unsub = server.clicked_point.subscribe(lambda pt: (received.append(pt), done.set())) - - _send_messages( - _E2E_PORT, - [ - {"type": "heartbeat", "timestamp_ms": 1000}, - {"type": "heartbeat", "timestamp_ms": 2000}, - { - "type": "click", - "x": 3.14, - "y": 2.71, - "z": 1.41, - "entity_path": "/world", - "timestamp_ms": 3000, - }, - { - "type": "twist", - "linear_x": 0.5, - "linear_y": 0.0, - "linear_z": 0.0, - "angular_x": 0.0, - "angular_y": 0.0, - "angular_z": 0.0, - }, - {"type": "stop"}, - {"type": "heartbeat", "timestamp_ms": 4000}, - ], - delay=0.2, - ) - - done.wait(timeout=3.0) - unsub() - - assert len(received) == 1, f"Expected exactly 1 click, got {len(received)}" - assert received[0].x == pytest.approx(3.14) - assert received[0].y == pytest.approx(2.71) - assert received[0].z == pytest.approx(1.41) - - def test_reconnect_after_disconnect(self, server: RerunWebSocketServer) -> None: - """Server keeps accepting new connections after a client disconnects.""" - received: list[Any] = [] - all_done = threading.Event() - - def _on_pt(pt: Any) -> None: - received.append(pt) - if len(received) >= 2: - all_done.set() - - unsub = server.clicked_point.subscribe(_on_pt) - - _send_messages( - _E2E_PORT, - [{"type": "click", "x": 1.0, "y": 0.0, "z": 0.0, "entity_path": "", "timestamp_ms": 0}], - ) - _send_messages( - _E2E_PORT, - [{"type": "click", "x": 2.0, "y": 0.0, "z": 0.0, "entity_path": "", "timestamp_ms": 0}], - ) - - all_done.wait(timeout=5.0) - unsub() - - xs = sorted(pt.x for pt in received) - assert xs == [1.0, 2.0], f"Unexpected xs: {xs}" - - -class TestViewerBinaryConnectMode: - """Smoke test: dimos-viewer binary starts in --connect mode.""" - - @pytest.fixture() - def viewer_process(self, server: RerunWebSocketServer) -> subprocess.Popen[bytes]: - proc = subprocess.Popen( - [ - "dimos-viewer", - "--connect", - f"--ws-url=ws://127.0.0.1:{_E2E_PORT}/ws", - ], - env={**os.environ, "DISPLAY": ""}, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - 
yield proc # type: ignore[misc] - proc.terminate() - try: - proc.wait(timeout=3) - except subprocess.TimeoutExpired: - proc.kill() - - @pytest.mark.skip( - reason="Incompatible with current winit: fails without DISPLAY (headless CI exits before WS connect) and hangs with DISPLAY (GUI event loop blocks before printing URL).", - ) - def test_viewer_ws_client_connects(self, viewer_process: subprocess.Popen[bytes]) -> None: - """dimos-viewer --connect starts and its WS client connects to our server.""" - deadline = time.monotonic() + 5.0 - while time.monotonic() < deadline: - if viewer_process.poll() is not None: - break - time.sleep(0.1) - - stdout = ( - viewer_process.stdout.read().decode(errors="replace") if viewer_process.stdout else "" - ) - stderr = ( - viewer_process.stderr.read().decode(errors="replace") if viewer_process.stderr else "" - ) - - combined = stdout + stderr - assert f"ws://127.0.0.1:{_E2E_PORT}" in combined, ( - f"Viewer did not attempt WS connection.\nstdout:\n{stdout}\nstderr:\n{stderr}" - ) diff --git a/dimos/visualization/rerun/test_websocket_server.py b/dimos/visualization/rerun/test_websocket_server.py deleted file mode 100644 index b4304cf7b4..0000000000 --- a/dimos/visualization/rerun/test_websocket_server.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for RerunWebSocketServer.""" - -from __future__ import annotations - -import asyncio -import json -import threading -import time -from typing import Any - -import pytest -import websockets.asyncio.client as ws_client - -from dimos.visualization.rerun.websocket_server import RerunWebSocketServer - -_TEST_PORT = 13031 - - -class MockViewerPublisher: - """Simulates dimos-viewer sending JSON events over WebSocket.""" - - def __init__(self, url: str) -> None: - self._url = url - self._ws: Any = None - self._loop: asyncio.AbstractEventLoop | None = None - - def __enter__(self) -> MockViewerPublisher: - self._loop = asyncio.new_event_loop() - self._ws = self._loop.run_until_complete(self._connect()) - return self - - def __exit__(self, *_: Any) -> None: - if self._ws is not None and self._loop is not None: - self._loop.run_until_complete(self._ws.close()) - if self._loop is not None: - self._loop.close() - - async def _connect(self) -> Any: - return await ws_client.connect(self._url) - - def send_click( - self, x: float, y: float, z: float, entity_path: str = "", timestamp_ms: int = 0 - ) -> None: - self._send( - { - "type": "click", - "x": x, - "y": y, - "z": z, - "entity_path": entity_path, - "timestamp_ms": timestamp_ms, - } - ) - - def send_twist( - self, - linear_x: float, - linear_y: float, - linear_z: float, - angular_x: float, - angular_y: float, - angular_z: float, - ) -> None: - self._send( - { - "type": "twist", - "linear_x": linear_x, - "linear_y": linear_y, - "linear_z": linear_z, - "angular_x": angular_x, - "angular_y": angular_y, - "angular_z": angular_z, - } - ) - - def send_stop(self) -> None: - self._send({"type": "stop"}) - - def flush(self, delay: float = 0.1) -> None: - time.sleep(delay) - - def _send(self, msg: dict[str, Any]) -> None: - assert self._loop is not None and self._ws is not None - self._loop.run_until_complete(self._ws.send(json.dumps(msg))) - - -@pytest.fixture() -def server(wait_for_server: Any) -> RerunWebSocketServer: - 
module = RerunWebSocketServer(port=_TEST_PORT) - module.start() - wait_for_server(_TEST_PORT) - yield module # type: ignore[misc] - module.stop() - - -@pytest.fixture() -def publisher(server: RerunWebSocketServer) -> MockViewerPublisher: - with MockViewerPublisher(f"ws://127.0.0.1:{_TEST_PORT}/ws") as publisher: - yield publisher # type: ignore[misc] - - -# ── Tests ──────────────────────────────────────────────────────────────── - - -def test_click_publishes_point_stamped( - server: RerunWebSocketServer, publisher: MockViewerPublisher -) -> None: - """Click event arrives as PointStamped with correct coords, frame_id, and timestamp.""" - received: list[Any] = [] - done = threading.Event() - - unsub = server.clicked_point.subscribe(lambda point: (received.append(point), done.set())) - - publisher.send_click(1.5, 2.5, 0.0, "/robot/base", timestamp_ms=5000) - publisher.flush() - done.wait(timeout=2.0) - unsub() - - assert len(received) == 1 - point = received[0] - assert point.x == pytest.approx(1.5) - assert point.y == pytest.approx(2.5) - assert point.z == pytest.approx(0.0) - assert point.frame_id == "/robot/base" - assert point.ts == pytest.approx(5.0) - - -def test_twist_publishes_on_tele_cmd_vel( - server: RerunWebSocketServer, publisher: MockViewerPublisher -) -> None: - """Twist event arrives as Twist on tele_cmd_vel.""" - received: list[Any] = [] - done = threading.Event() - - unsub = server.tele_cmd_vel.subscribe(lambda twist: (received.append(twist), done.set())) - - publisher.send_twist(0.5, 0.0, 0.0, 0.0, 0.0, 0.8) - publisher.flush() - done.wait(timeout=2.0) - unsub() - - assert len(received) == 1 - assert received[0].linear.x == pytest.approx(0.5) - assert received[0].angular.z == pytest.approx(0.8) - - -def test_stop_publishes_zero_twist( - server: RerunWebSocketServer, publisher: MockViewerPublisher -) -> None: - """Stop event publishes a zero Twist on tele_cmd_vel.""" - received: list[Any] = [] - done = threading.Event() - - unsub = 
server.tele_cmd_vel.subscribe(lambda twist: (received.append(twist), done.set())) - - publisher.send_stop() - publisher.flush() - done.wait(timeout=2.0) - unsub() - - assert len(received) == 1 - assert received[0].is_zero() - - -def test_invalid_json_does_not_crash(server: RerunWebSocketServer) -> None: - """Malformed JSON is silently dropped; server stays alive for the next message.""" - - async def _send_bad() -> None: - async with ws_client.connect(f"ws://127.0.0.1:{_TEST_PORT}/ws") as ws: - await ws.send("this is not json {{") - await asyncio.sleep(0.1) - await ws.send(json.dumps({"type": "heartbeat", "timestamp_ms": 0})) - await asyncio.sleep(0.1) - - asyncio.run(_send_bad()) - - -def test_mixed_message_sequence( - server: RerunWebSocketServer, publisher: MockViewerPublisher -) -> None: - """Realistic session: heartbeat, click, twist, stop — only the click produces a point.""" - received: list[Any] = [] - done = threading.Event() - unsub = server.clicked_point.subscribe(lambda point: (received.append(point), done.set())) - - publisher.send_click(7.0, 8.0, 9.0, "/map", timestamp_ms=1100) - publisher.send_twist(0.3, 0.0, 0.0, 0.0, 0.0, 0.2) - publisher.send_stop() - publisher.flush() - done.wait(timeout=2.0) - unsub() - - assert len(received) == 1 - assert received[0].x == pytest.approx(7.0) - assert received[0].y == pytest.approx(8.0) - assert received[0].z == pytest.approx(9.0) diff --git a/dimos/visualization/rerun/websocket_server.py b/dimos/visualization/rerun/websocket_server.py deleted file mode 100644 index 0c0ac2acf2..0000000000 --- a/dimos/visualization/rerun/websocket_server.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""WebSocket server module that receives events from dimos-viewer. - -When dimos-viewer is started with ``--connect``, LCM multicast is unavailable -across machines. The viewer falls back to sending click, twist, and stop events -as JSON over a WebSocket connection. This module acts as the server-side -counterpart: it listens for those connections and translates incoming messages -into DimOS stream publishes. - -Message format (newline-delimited JSON, ``"type"`` discriminant): - - {"type":"heartbeat","timestamp_ms":1234567890} - {"type":"click","x":1.0,"y":2.0,"z":3.0,"entity_path":"/world","timestamp_ms":1234567890} - {"type":"twist","linear_x":0.5,"linear_y":0.0,"linear_z":0.0, - "angular_x":0.0,"angular_y":0.0,"angular_z":0.8} - {"type":"stop"} -""" - -from __future__ import annotations - -import asyncio -import json -import logging -import socket -import threading -from typing import Any, Literal, TypedDict, Union - -import websockets -import websockets.asyncio.server as ws_server - -from dimos.core.core import rpc -from dimos.core.global_config import global_config -from dimos.core.module import Module, ModuleConfig -from dimos.core.stream import Out -from dimos.msgs.geometry_msgs.PointStamped import PointStamped -from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.utils.generic import get_local_ips -from dimos.utils.logging_config import setup_logger -from dimos.visualization.rerun.constants import RERUN_GRPC_PORT - -logger = setup_logger() - - -class ClickMsg(TypedDict): - type: 
Literal["click"] - x: float - y: float - z: float - entity_path: str - timestamp_ms: int - - -class TwistMsg(TypedDict): - type: Literal["twist"] - linear_x: float - linear_y: float - linear_z: float - angular_x: float - angular_y: float - angular_z: float - - -class StopMsg(TypedDict): - type: Literal["stop"] - - -class HeartbeatMsg(TypedDict): - type: Literal["heartbeat"] - timestamp_ms: int - - -ViewerMsg = Union[ClickMsg, TwistMsg, StopMsg, HeartbeatMsg] - - -def _handshake_noise_filter(record: logging.LogRecord) -> bool: - """Drop noisy "opening handshake failed" records from port scanners etc.""" - msg = record.getMessage() - return not ("opening handshake failed" in msg or "did not receive a valid HTTP request" in msg) - - -class Config(ModuleConfig): - host: str | None = None - port: int = 3030 - start_timeout: float = 10.0 - - -class RerunWebSocketServer(Module): - """Receives dimos-viewer WebSocket events and publishes them as DimOS streams. - - The viewer connects to this module (not the other way around) when running - in ``--connect`` mode. Each click event is converted to a ``PointStamped`` - and published on the ``clicked_point`` stream so downstream modules (e.g. - ``ReplanningAStarPlanner``) can consume it without modification. - - Outputs: - clicked_point: 3-D world-space point from the most recent viewer click. - tele_cmd_vel: Twist velocity commands from keyboard teleop, including stop events. - - Note: ``stop_movement`` is owned by ``MovementManager`` — it will fire - that signal when it sees the first teleop twist arrive here. 
- """ - - config: Config - - clicked_point: Out[PointStamped] - tele_cmd_vel: Out[Twist] - - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._stop_event: asyncio.Event | None = None - self._server_ready = threading.Event() - self.host = self.config.host if self.config.host is not None else global_config.listen_host - - @rpc - def start(self) -> None: - super().start() - assert self._loop is not None - asyncio.run_coroutine_threadsafe(self._serve(), self._loop) - self._server_ready.wait(timeout=self.config.start_timeout) - self._log_connect_hints() - - @rpc - def stop(self) -> None: - self._server_ready.wait(timeout=self.config.start_timeout) - if self._loop is not None and not self._loop.is_closed() and self._stop_event is not None: - self._loop.call_soon_threadsafe(self._stop_event.set) - super().stop() - - def _log_connect_hints(self) -> None: - """Log full dimos-viewer commands that viewers can use to connect.""" - local_ips = get_local_ips() - hostname = socket.gethostname() - host = self.host - ws_url = f"ws://{host}:{self.config.port}/ws" - grpc_url = f"rerun+http://{host}:{RERUN_GRPC_PORT}/proxy" - - lines = [ - "", - "=" * 60, - f"RerunWebSocketServer listening on {ws_url}", - "", - "Connect a viewer:", - f" dimos-viewer --connect {grpc_url} --ws-url {ws_url}", - ] - if local_ips: - lines.append("") - lines.append("From another machine on the network:") - for ip, iface in local_ips: - remote_grpc = f"rerun+http://{ip}:{RERUN_GRPC_PORT}/proxy" - remote_ws = f"ws://{ip}:{self.config.port}/ws" - lines.append( - f" dimos-viewer --connect {remote_grpc} --ws-url {remote_ws} # {iface}" - ) - lines.append("") - lines.append(f" hostname: {hostname}") - lines.append("=" * 60) - lines.append("") - - logger.info("\n".join(lines)) - - async def _serve(self) -> None: - self._stop_event = asyncio.Event() - - ws_logger = logging.getLogger("websockets.server") - ws_logger.addFilter(_handshake_noise_filter) - - async with ws_server.serve( - 
self._handle_client, - host=self.host, - port=self.config.port, - ping_interval=30, - ping_timeout=30, - logger=ws_logger, - ): - self._server_ready.set() - await self._stop_event.wait() - - async def _handle_client(self, websocket: Any) -> None: - if hasattr(websocket, "request") and websocket.request.path != "/ws": - await websocket.close(1008, "Not Found") - return - addr = websocket.remote_address - logger.info(f"RerunWebSocketServer: viewer connected from {addr}") - try: - async for raw in websocket: - self._dispatch(raw) - except websockets.ConnectionClosed: - pass - - def _dispatch(self, raw: str | bytes) -> None: - try: - msg: dict[str, Any] = json.loads(raw) - except json.JSONDecodeError: - logger.warning(f"RerunWebSocketServer: ignoring non-JSON message: {raw!r}") - return - - if not isinstance(msg, dict): - return - - msg_type = msg.get("type") - - if msg_type == "click": - self.clicked_point.publish( - PointStamped( - x=float(msg.get("x", 0)), - y=float(msg.get("y", 0)), - z=float(msg.get("z", 0)), - ts=float(msg.get("timestamp_ms", 0)) / 1000.0, - frame_id=str(msg.get("entity_path", "")), - ) - ) - - elif msg_type == "twist": - self.tele_cmd_vel.publish( - Twist( - linear=Vector3( - float(msg.get("linear_x", 0)), - float(msg.get("linear_y", 0)), - float(msg.get("linear_z", 0)), - ), - angular=Vector3( - float(msg.get("angular_x", 0)), - float(msg.get("angular_y", 0)), - float(msg.get("angular_z", 0)), - ), - ) - ) - - elif msg_type == "stop": - self.tele_cmd_vel.publish(Twist.zero()) diff --git a/dimos/visualization/vis_module.py b/dimos/visualization/vis_module.py deleted file mode 100644 index badcba34db..0000000000 --- a/dimos/visualization/vis_module.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Shared visualization module factory for all robot blueprints.""" - -from typing import Any, get_args - -from dimos.core.coordination.blueprints import Blueprint, autoconnect -from dimos.visualization.rerun.constants import ViewerBackend - - -def vis_module( - viewer_backend: ViewerBackend, - rerun_config: dict[str, Any] | None = None, - foxglove_config: dict[str, Any] | None = None, -) -> Blueprint: - """Create a visualization blueprint based on the selected viewer backend. - - Bundles the appropriate viewer module (Rerun or Foxglove) together with - the ``WebsocketVisModule`` and ``RerunWebSocketServer`` so that the web - dashboard and remote viewer connections work out of the box. 
- - Example usage:: - - from dimos.core.global_config import global_config - viz = vis_module( - global_config.viewer, - rerun_config={ - "visual_override": { - "world/camera_info": lambda ci: ci.to_rerun(...), - }, - "static": { - "world/tf/base_link": lambda rr: [rr.Boxes3D(...)], - }, - }, - ) - """ - from dimos.visualization.rerun.websocket_server import RerunWebSocketServer - from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule - - if foxglove_config is None: - foxglove_config = {} - if rerun_config is None: - rerun_config = {} - - match viewer_backend: - case "foxglove": - from dimos.robot.foxglove_bridge import FoxgloveBridge - - return autoconnect( - FoxgloveBridge.blueprint(**foxglove_config), - RerunWebSocketServer.blueprint(), - WebsocketVisModule.blueprint(), - ) - case "rerun": - from dimos.core.global_config import global_config - from dimos.protocol.pubsub.impl.lcmpubsub import LCM - from dimos.visualization.rerun.bridge import RerunBridgeModule - - rerun_config = {**rerun_config} # copy (avoid mutation) - rerun_config.setdefault("pubsubs", [LCM()]) - rerun_config.setdefault("rerun_open", global_config.rerun_open) - rerun_config.setdefault("rerun_web", global_config.rerun_web) - return autoconnect( - RerunBridgeModule.blueprint( - **rerun_config, - ), - RerunWebSocketServer.blueprint(), - WebsocketVisModule.blueprint(), - ) - case "none": - return autoconnect(WebsocketVisModule.blueprint()) - case _: - valid = ", ".join(get_args(ViewerBackend)) - raise ValueError(f"Unknown viewer_backend {viewer_backend!r}. 
Expected one of: {valid}") diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 1ce7e74502..3d6b3df11c 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -105,7 +105,7 @@ class WebsocketVisModule(Module): gps_goal: Out[LatLon] explore_cmd: Out[Bool] stop_explore_cmd: Out[Bool] - tele_cmd_vel: Out[Twist] + cmd_vel: Out[Twist] movecmd_stamped: Out[TwistStamped] def __init__(self, **kwargs: Any) -> None: @@ -158,11 +158,9 @@ def start(self) -> None: self._uvicorn_server_thread = threading.Thread(target=self._run_uvicorn_server, daemon=True) self._uvicorn_server_thread.start() - # Auto-open the dashboard tab only when the user explicitly asked for a - # web-based viewer (rerun_open == "web" or "both"). `rerun_web` alone - # only means "serve the viewer"; it should not trigger a browser popup - # when the user chose the native viewer. - if self.config.g.viewer == "rerun" and self.config.g.rerun_open in ("web", "both"): + # Auto-open browser only for rerun-web (dashboard with Rerun iframe + command center) + # For rerun and foxglove, users access the command center manually if needed + if self.config.g.viewer == "rerun-web": url = f"http://localhost:{self.config.port}/" logger.info(f"Dimensional Command Center: {url}") @@ -238,13 +236,11 @@ def _create_server(self) -> None: async def serve_index(request): # type: ignore[no-untyped-def] """Serve appropriate HTML based on viewer mode.""" - # Serve the full dashboard (with Rerun iframe) only when the rerun - # web server is enabled; otherwise redirect to the standalone - # command center. 
- if not ( - self.config.g.viewer == "rerun" and self.config.g.rerun_open in ("web", "both") - ): + # If running native Rerun, redirect to standalone command center + if self.config.g.viewer != "rerun-web": return RedirectResponse(url="/command-center") + + # Otherwise serve full dashboard with Rerun iframe return FileResponse(_DASHBOARD_HTML, media_type="text/html") async def serve_command_center(request): # type: ignore[no-untyped-def] @@ -337,14 +333,14 @@ async def clear_gps_goals(sid: str) -> None: @self.sio.event # type: ignore[untyped-decorator] async def move_command(sid: str, data: dict[str, Any]) -> None: # Publish Twist if transport is configured - if self.tele_cmd_vel and self.tele_cmd_vel.transport: + if self.cmd_vel and self.cmd_vel.transport: twist = Twist( linear=Vector3(data["linear"]["x"], data["linear"]["y"], data["linear"]["z"]), angular=Vector3( data["angular"]["x"], data["angular"]["y"], data["angular"]["z"] ), ) - self.tele_cmd_vel.publish(twist) + self.cmd_vel.publish(twist) # Publish TwistStamped if transport is configured if self.movecmd_stamped and self.movecmd_stamped.transport: diff --git a/docs/development/conventions.md b/docs/development/conventions.md deleted file mode 100644 index 2b25a7c3c6..0000000000 --- a/docs/development/conventions.md +++ /dev/null @@ -1,12 +0,0 @@ -This mostly to track when conventions change (with regard to codebase updates) because this codebase is under heavy development. Note: this is a non-exhaustive list of conventions. - -- Instead of using `RerunBridge` in blueprints we always use `vis_module` which allows the CLI to control if its foxglove, rerun, or no-vis at all -- When global_config.py shouldn't accidentally/indirectly import heavy libraries like rerun. But sometimes global_config needs the type definition or default value from a module. 
Preferably we import from the module file directly, however when thats not possible, we create a config.py for just that module's config and import that into global_config.py. -- When adding visualization tools to a blueprint/autoconnect, instead of using RerunBridge or WebsocketVisModule directly we should always use `vis_module`, which right now should look something like `vis_module(viewer_backend=global_config.viewer, rerun_config={}),` -- `DEFAULT_THREAD_JOIN_TIMEOUT` is used for all thread.join timeouts -- Don't use print inside of tests -- Module configs should be specified as `config: ModuleSpecificConfigClass` -- To customize the way rerun renders something, right now we use a `rerun_config` dict. This will (hopefully) change very soon to be a per-module config instead of a per-blueprint config -- Similar to the `rerun_config` the `rrb` (rerun blueprint) is defined at a blueprint level right now, but ideally would be a per-module contribution with only a per-blueprint override of the layout. 
-- No `__init__.py` files -- Helper blueprints (like `_with_vis`) that should not be used on their own need to start with an underscore to avoid being picked up by the all_blueprints.py code generation step diff --git a/docs/usage/cli.md b/docs/usage/cli.md index bba73368b2..017b441c7e 100644 --- a/docs/usage/cli.md +++ b/docs/usage/cli.md @@ -18,9 +18,7 @@ dimos [GLOBAL OPTIONS] COMMAND [ARGS] | `--replay` / `--no-replay` | bool | `False` | Use recorded replay data | | `--replay-db` | TEXT | `go2_bigoffice` | Replay memory2 SQLite database name | | `--new-memory` / `--no-new-memory` | bool | `False` | Clear persistent memory on start | -| `--viewer` | `rerun\|foxglove\|none` | `rerun` | Visualization backend | -| `--rerun-open` | `native\|web\|both\|none` | `native` | How to open the Rerun viewer | -| `--rerun-web` / `--no-rerun-web` | bool | `False` | Serve the Rerun web viewer | +| `--viewer` | `rerun\|rerun-web\|rerun-connect\|foxglove\|none` | `rerun` | Visualization backend | | `--n-workers` | INT | `2` | Number of forkserver workers | | `--memory-limit` | TEXT | `auto` | Rerun viewer memory limit | | `--mcp-port` | INT | `9990` | MCP server port | diff --git a/docs/usage/visualization.md b/docs/usage/visualization.md index 9ece977a68..57ad460354 100644 --- a/docs/usage/visualization.md +++ b/docs/usage/visualization.md @@ -1,43 +1,37 @@ # Viewer Backends -Dimos supports three visualization backends: `rerun` (default), `foxglove`, and `none`. +Dimos supports three visualization backends: Rerun (web or native) and Foxglove. 
## Quick Start -Choose your viewer via the CLI: +Choose your viewer via the CLI (preferred): ```bash # Rerun native viewer (default) - dimos-viewer with built-in teleop + click-to-navigate dimos run unitree-go2 -# Explicitly select the viewer backend: +# Explicitly select the viewer mode: dimos --viewer rerun run unitree-go2 +dimos --viewer rerun-web run unitree-go2 dimos --viewer foxglove run unitree-go2 -dimos --viewer none run unitree-go2 ``` -Control how the Rerun viewer opens with `--rerun-open` and `--rerun-web`: +Alternative (environment variable): ```bash -# Open native desktop viewer (default) -dimos --rerun-open native run unitree-go2 - -# Open web viewer in browser -dimos --rerun-open web run unitree-go2 - -# Open both native and web -dimos --rerun-open both run unitree-go2 +# Rerun native viewer (default) - dimos-viewer with built-in teleop + click-to-navigate +VIEWER=rerun dimos run unitree-go2 -# No viewer (headless) — data still accessible via gRPC -dimos --rerun-open none run unitree-go2 +# Rerun web viewer - browser dashboard + teleop at http://localhost:7779 +VIEWER=rerun-web dimos run unitree-go2 -# Serve the web viewer without auto-opening a browser -dimos --rerun-web --rerun-open native run unitree-go2 +# Foxglove - Use Foxglove Studio instead of Rerun +VIEWER=foxglove dimos run unitree-go2 ``` ## Viewer Modes Explained -### Rerun Native (`rerun`, `--rerun-open native`) — Default +### Rerun Native (`rerun`) — Default **What you get:** - [dimos-viewer](https://github.com/dimensionalOS/dimos-viewer), a custom Dimensional fork of Rerun with built-in keyboard teleop and click-to-navigate @@ -47,7 +41,7 @@ dimos --rerun-web --rerun-open native run unitree-go2 --- -### Rerun Web (`rerun`, `--rerun-open web`) +### Rerun Web (`rerun-web`) **What you get:** - Browser-based dashboard at http://localhost:7779 @@ -69,16 +63,18 @@ dimos --rerun-web --rerun-open native run unitree-go2 ## Rendering with Custom Blueprints -To enable visualization in your own 
blueprint, use `vis_module`: +To enable rerun within your own blueprint simply include `RerunBridgeModule`: ```python -from dimos.core.global_config import global_config -from dimos.visualization.vis_module import vis_module +from dimos.visualization.rerun.bridge import RerunBridgeModule from dimos.hardware.sensors.camera.module import CameraModule +from dimos.protocol.pubsub.impl.lcmpubsub import LCM camera_demo = autoconnect( CameraModule.blueprint(), - vis_module(viewer_backend=global_config.viewer), + RerunBridgeModule.blueprint( + viewer_mode="native", # native (desktop), web (browser), none (headless) + ), ) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 1ca77c7c7f..398903f457 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,7 @@ dependencies = [ # TODO: rerun shouldn't be required but rn its in core (there is NO WAY to use dimos without rerun rn) # remove this once rerun is optional in core "rerun-sdk>=0.20.0", - "dimos-viewer==0.30.0a6.dev99", + "dimos-viewer>=0.30.0a2", "toolz>=1.1.0", "protobuf>=6.33.5,<7", "psutil>=7.0.0", diff --git a/uv.lock b/uv.lock index 5cb8d2ef0c..f429ae7b71 100644 --- a/uv.lock +++ b/uv.lock @@ -1995,7 +1995,7 @@ requires-dist = [ { name = "dimos", extras = ["unitree"], marker = "extra == 'unitree-dds'" }, { name = "dimos-lcm" }, { name = "dimos-lcm", marker = "extra == 'docker'" }, - { name = "dimos-viewer", specifier = "==0.30.0a6.dev99" }, + { name = "dimos-viewer", specifier = ">=0.30.0a2" }, { name = "dimos-viewer", marker = "extra == 'visualization'", specifier = ">=0.30.0a4" }, { name = "drake", marker = "platform_machine != 'aarch64' and sys_platform == 'darwin' and extra == 'manipulation'", specifier = "==1.45.0" }, { name = "drake", marker = "platform_machine != 'aarch64' and sys_platform != 'darwin' and extra == 'manipulation'", specifier = ">=1.40.0" }, @@ -2168,18 +2168,18 @@ wheels = [ [[package]] name = "dimos-viewer" -version = "0.30.0a6.dev99" -source = { registry = 
"https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3d/0e/d363be05f172bafe5f41a95db318891637e902c50edfdc642edec6bb5111/dimos_viewer-0.30.0a6.dev99-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cfa57e68e8f4094d4a38d202414046fd2419ff2875ace3f16b8581c3106feca4", size = 35405401, upload-time = "2026-04-17T04:19:10.126Z" }, - { url = "https://files.pythonhosted.org/packages/e7/ab/0730fed402b3b92e35194f11b76119754d619fa6bab00a1932b5c78f87b3/dimos_viewer-0.30.0a6.dev99-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:f3bc243342131c8c2b653cc6b76f04d65aad525f5560829b78aa1a7d31a9d375", size = 39167146, upload-time = "2026-04-17T04:19:14.177Z" }, - { url = "https://files.pythonhosted.org/packages/bb/d9/1415d5d7e609d69b05e8e1167a66dd7cb78f3933205f9b321ae18233384c/dimos_viewer-0.30.0a6.dev99-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:b954083fcb8951641554fdea95425b3b5ac9415cd1b65410a137d38d3dd57b8a", size = 41536165, upload-time = "2026-04-17T04:19:17.379Z" }, - { url = "https://files.pythonhosted.org/packages/93/7c/7ee6049a753c01ccbe8357f9c5f789378103b87331e5ca7977f05adf5c42/dimos_viewer-0.30.0a6.dev99-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0387201efd1260f968853f0d7863876b6db375b2af15b22f221a893fcce6549c", size = 35405408, upload-time = "2026-04-17T04:19:20.08Z" }, - { url = "https://files.pythonhosted.org/packages/de/2e/9b4252a12c4b641ab1479a6a4d3d576e75fc42ca2a797d88e2e0626abda0/dimos_viewer-0.30.0a6.dev99-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:a0fae6f2077fc6ceb25e1ed33fb7ccf183ef3e2a30456aa5462b953c1419e547", size = 39167138, upload-time = "2026-04-17T04:19:23.292Z" }, - { url = "https://files.pythonhosted.org/packages/46/2a/4bd02c3d79df2aefc5be47afda6b95121937cef0a3f6b15d071691ec3ca7/dimos_viewer-0.30.0a6.dev99-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:e844015f3ad193d50201c39abd3e3f34abbf03adbfb1075468696c1236df1409", size = 41536172, upload-time = 
"2026-04-17T04:19:26.421Z" }, - { url = "https://files.pythonhosted.org/packages/1b/b1/efcea9b9e21c4ab75e2df016a27e5045e30d91a494465ab0cc627d8d8bc3/dimos_viewer-0.30.0a6.dev99-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dc82061c2c025684c0fbed5392f793d137b1b0fc3aa1b601988bf4d2ee88aa27", size = 35405409, upload-time = "2026-04-17T04:19:29.574Z" }, - { url = "https://files.pythonhosted.org/packages/2d/8e/d482b0b9379c40ddd7547600543ce726fc3b5d10e396a876f22b2d76d0e6/dimos_viewer-0.30.0a6.dev99-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:0f6acfa0de3083e746ac43fe0d0a328d624bcb859dc698b1bbc592f444f52f15", size = 39167144, upload-time = "2026-04-17T04:19:32.301Z" }, - { url = "https://files.pythonhosted.org/packages/6d/eb/08922721c74ceaa99a824258db02c438d50f77c22ff80332cbc4b1a8db7b/dimos_viewer-0.30.0a6.dev99-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:56fa9139c49ec4bf96b12d6e98d3de3319a66876374ae57bda4534ab7a347765", size = 41536171, upload-time = "2026-04-17T04:19:35.29Z" }, +version = "0.30.0a6" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0b/90/ad6d0e1e177a10a0b4f7e736436b6d2741acaeb402ab59504347236744f4/dimos_viewer-0.30.0a6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e623a21e6992e263513847e12809a0d234d73fc7af42a6428e84ca165ba682d0", size = 35309553, upload-time = "2026-03-18T15:22:26.874Z" }, + { url = "https://files.pythonhosted.org/packages/a1/84/1c8f41ff2bd5b6ee143eb6119107397dac284fa4f1f8335623c498bd1d9c/dimos_viewer-0.30.0a6-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:36068a3293cb1c7f4db9f4e6c9fea2d7dd2a2527025f803585f4d3aaad9aedbd", size = 39072034, upload-time = "2026-03-18T15:22:29.592Z" }, + { url = "https://files.pythonhosted.org/packages/58/e6/d6214245e5b99e1da262d037f52d3d39c6b87c65acb516fb08f11378e932/dimos_viewer-0.30.0a6-cp310-cp310-manylinux_2_28_x86_64.whl", hash = 
"sha256:2bf36e8c8bd9dd822bedd1cb2d80ee2bf74b58184ba33872494baed0395fa7ff", size = 41447599, upload-time = "2026-03-18T15:22:32.699Z" }, + { url = "https://files.pythonhosted.org/packages/48/04/80f566400776cab9af68b4a3c0132f55786acd1641ea39d8b75e797a2e22/dimos_viewer-0.30.0a6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:947cfa10c583b357d589c10cb466c63b3651a83d1013a254c0ba03fc2959bef7", size = 35309552, upload-time = "2026-03-18T15:22:35.395Z" }, + { url = "https://files.pythonhosted.org/packages/4c/c3/72157e0806951c2c71c70dcd783e27be8d694344d7ecdb94eaef1066cf99/dimos_viewer-0.30.0a6-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:53ca4ac1f0778f1d9afb317b6268c941c02b20af86dd2aaaf1ea79f2c1d1eeb8", size = 39072018, upload-time = "2026-03-18T15:22:38.043Z" }, + { url = "https://files.pythonhosted.org/packages/2f/92/959fc1e9cdcb5fd8d793b2c8515a6086c9f913ba470baad1f3182ae4c242/dimos_viewer-0.30.0a6-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:27e108060a942c92f7869a0e45693dfe1798896bd90cbac6d1ce019a682f8ba7", size = 41447647, upload-time = "2026-03-18T15:22:41.003Z" }, + { url = "https://files.pythonhosted.org/packages/ab/d6/d76763b60d82539e92777500551116306cfea462f6976ad814a3bdf57e1d/dimos_viewer-0.30.0a6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f4f49f973c51055cfd594b68a8e9d183c706f94b1513b6b69db900d05850f741", size = 35309553, upload-time = "2026-03-18T15:22:43.681Z" }, + { url = "https://files.pythonhosted.org/packages/26/ab/6ea7686c467caecdc74dd8d3a0267053ac74229b3afebc64cff180d5074c/dimos_viewer-0.30.0a6-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:791ef1c1d8d41db69a7d2b701ed3f0b6bc39cb3264aaef7300eddb576c8df7ed", size = 39072062, upload-time = "2026-03-18T15:22:46.264Z" }, + { url = "https://files.pythonhosted.org/packages/3c/87/fce7aac56d8a234d3d7c0911928bb3471d7852e35263b966d2aac5be42cd/dimos_viewer-0.30.0a6-cp312-cp312-manylinux_2_28_x86_64.whl", hash = 
"sha256:dd976c39c38718b8373e1894d55b78c10bcb8c5716c8dbd5fba59141bc08ab3c", size = 41447667, upload-time = "2026-03-18T15:22:49.214Z" }, ] [[package]] From 20ba0f31be56e039a80d4bfecb95c70243f7a286 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 28 Apr 2026 19:33:16 -0700 Subject: [PATCH 28/30] ruff fixes + unused noqa cleanup --- .../hardware/manipulators/openarm/adapter.py | 41 ++++----- dimos/hardware/manipulators/openarm/driver.py | 54 ++++++----- .../manipulators/openarm/test_driver.py | 7 +- dimos/robot/catalog/openarm.py | 5 +- .../openarm/scripts/openarm_can_probe.py | 89 +++++++++++++------ .../openarm/scripts/openarm_set_mit_mode.py | 33 ++++--- dimos/utils/workspace.py | 24 +++-- 7 files changed, 150 insertions(+), 103 deletions(-) diff --git a/dimos/hardware/manipulators/openarm/adapter.py b/dimos/hardware/manipulators/openarm/adapter.py index 9ec7f60d30..851185a26e 100644 --- a/dimos/hardware/manipulators/openarm/adapter.py +++ b/dimos/hardware/manipulators/openarm/adapter.py @@ -16,10 +16,8 @@ from __future__ import annotations -import time from pathlib import Path - -from dimos.utils.data import LfsPath +import time from typing import TYPE_CHECKING, Any import numpy as np @@ -35,6 +33,7 @@ JointLimits, ManipulatorInfo, ) +from dimos.utils.data import LfsPath if TYPE_CHECKING: from dimos.hardware.manipulators.registry import AdapterRegistry @@ -163,7 +162,7 @@ def connect(self) -> bool: interface=self._interface, ) self._bus.open() - except Exception as e: # noqa: BLE001 + except Exception as e: print(f"ERROR: OpenArm {self._side}@{self._address} connect failed: {e}") self._bus = None return False @@ -200,8 +199,10 @@ def connect(self) -> bool: urdf = str(self._URDF_LEFT if self._side == "left" else self._URDF_RIGHT) self._pin_model = pinocchio.buildModelFromUrdf(urdf) self._pin_data = self._pin_model.createData() - print(f"OpenArm {self._side}: gravity compensation enabled (nq={self._pin_model.nq})") - except Exception as e: # noqa: BLE001 + print( 
+ f"OpenArm {self._side}: gravity compensation enabled (nq={self._pin_model.nq})" + ) + except Exception as e: print(f"WARNING: gravity comp disabled — {e}") self._pin_model = None self._pin_data = None @@ -213,7 +214,7 @@ def disconnect(self) -> None: return try: self._bus.disable_all() - except Exception: # noqa: BLE001 + except Exception: pass self._enabled = False self._bus.close() @@ -324,15 +325,10 @@ def _compute_gravity_torques(self, q: list[float]) -> list[float]: import pinocchio q_arr = np.array(q, dtype=np.float64) - tau_g = pinocchio.computeGeneralizedGravity( - self._pin_model, self._pin_data, q_arr - ) + tau_g = pinocchio.computeGeneralizedGravity(self._pin_model, self._pin_data, q_arr) # Clamp to motor torque limits for safety limits = [m.limits for m in self._motors] # (p_max, v_max, t_max) - return [ - float(np.clip(tau_g[i], -lim[2], lim[2])) - for i, lim in enumerate(limits) - ] + return [float(np.clip(tau_g[i], -lim[2], lim[2])) for i, lim in enumerate(limits)] # ------------------------------------------------------------------ # Commands @@ -354,7 +350,7 @@ def write_joint_positions( tau_ff = self._compute_gravity_torques(q_current) commands = [ (q, 0.0, kp * velocity, kd, tau) - for q, kp, kd, tau in zip(positions, self._kp, self._kd, tau_ff) + for q, kp, kd, tau in zip(positions, self._kp, self._kd, tau_ff, strict=False) ] self._bus.send_mit_many(commands) self._last_cmd_q = list(positions) @@ -376,7 +372,7 @@ def write_joint_velocities(self, velocities: list[float]) -> bool: tau_ff = self._compute_gravity_torques(q_current) commands = [ (q_anchor, dq, 0.0, kd, tau) - for q_anchor, dq, kd, tau in zip(anchor, velocities, self._kd, tau_ff) + for q_anchor, dq, kd, tau in zip(anchor, velocities, self._kd, tau_ff, strict=False) ] self._bus.send_mit_many(commands) return True @@ -386,11 +382,12 @@ def write_stop(self) -> bool: return False try: q_now = self.read_joint_positions() - except Exception: # noqa: BLE001 + except Exception: q_now = 
[0.0] * self._dof tau_ff = self._compute_gravity_torques(q_now) - commands = [(q, 0.0, kp, kd, tau) - for q, kp, kd, tau in zip(q_now, self._kp, self._kd, tau_ff)] + commands = [ + (q, 0.0, kp, kd, tau) for q, kp, kd, tau in zip(q_now, self._kp, self._kd, tau_ff, strict=False) + ] self._bus.send_mit_many(commands) self._last_cmd_q = q_now return True @@ -425,9 +422,7 @@ def write_clear_errors(self) -> bool: def read_cartesian_position(self) -> dict[str, float] | None: return None - def write_cartesian_position( - self, pose: dict[str, float], velocity: float = 1.0 - ) -> bool: + def write_cartesian_position(self, pose: dict[str, float], velocity: float = 1.0) -> bool: return False def read_gripper_position(self) -> float | None: @@ -441,7 +436,7 @@ def read_force_torque(self) -> list[float] | None: # ── Registry hook (required for auto-discovery) ─────────────────── -def register(registry: "AdapterRegistry") -> None: +def register(registry: AdapterRegistry) -> None: registry.register("openarm", OpenArmAdapter) diff --git a/dimos/hardware/manipulators/openarm/driver.py b/dimos/hardware/manipulators/openarm/driver.py index 843e7dc64b..0a964d92db 100644 --- a/dimos/hardware/manipulators/openarm/driver.py +++ b/dimos/hardware/manipulators/openarm/driver.py @@ -20,12 +20,12 @@ from __future__ import annotations +from dataclasses import dataclass import enum import errno import struct import threading import time -from dataclasses import dataclass from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -57,19 +57,19 @@ class MotorType(str, enum.Enum): # (p_max [rad], v_max [rad/s], t_max [Nm]) _MOTOR_LIMITS: dict[MotorType, tuple[float, float, float]] = { - MotorType.DM3507: (12.5, 50.0, 5.0), - MotorType.DM4310: (12.5, 30.0, 10.0), - MotorType.DM4310_48V: (12.5, 50.0, 10.0), - MotorType.DM4340: (12.5, 8.0, 28.0), - MotorType.DM4340_48V: (12.5, 10.0, 28.0), - MotorType.DM6006: (12.5, 45.0, 20.0), - MotorType.DM8006: (12.5, 45.0, 40.0), - MotorType.DM8009: (12.5, 45.0, 
54.0), - MotorType.DM10010L: (12.5, 25.0, 200.0), - MotorType.DM10010: (12.5, 20.0, 200.0), - MotorType.DMH3510: (12.5, 280.0, 1.0), - MotorType.DMH6215: (12.5, 45.0, 10.0), - MotorType.DMG6220: (12.5, 45.0, 10.0), + MotorType.DM3507: (12.5, 50.0, 5.0), + MotorType.DM4310: (12.5, 30.0, 10.0), + MotorType.DM4310_48V: (12.5, 50.0, 10.0), + MotorType.DM4340: (12.5, 8.0, 28.0), + MotorType.DM4340_48V: (12.5, 10.0, 28.0), + MotorType.DM6006: (12.5, 45.0, 20.0), + MotorType.DM8006: (12.5, 45.0, 40.0), + MotorType.DM8009: (12.5, 45.0, 54.0), + MotorType.DM10010L: (12.5, 25.0, 200.0), + MotorType.DM10010: (12.5, 20.0, 200.0), + MotorType.DMH3510: (12.5, 280.0, 1.0), + MotorType.DMH6215: (12.5, 45.0, 10.0), + MotorType.DMG6220: (12.5, 45.0, 10.0), } # MIT gain ranges (protocol-fixed, same for every motor type) @@ -139,11 +139,11 @@ def pack_mit_frame( class MotorState: """Decoded state from a Damiao reply frame.""" - q: float # rad - dq: float # rad/s - tau: float # Nm - t_mos: int # °C - t_rotor: int # °C + q: float # rad + dq: float # rad/s + tau: float # Nm + t_mos: int # °C + t_rotor: int # °C timestamp: float # monotonic seconds when received @@ -240,7 +240,7 @@ def __init__( self._interface = interface self._by_recv: dict[int, DamiaoMotor] = {m.effective_recv_id: m for m in motors} - self._bus: "can.BusABC | None" = None + self._bus: can.BusABC | None = None self._rx_thread: threading.Thread | None = None self._rx_stop = threading.Event() self._state_lock = threading.Lock() @@ -275,7 +275,7 @@ def close(self) -> None: finally: self._bus = None - def __enter__(self) -> "OpenArmBus": + def __enter__(self) -> OpenArmBus: self.open() return self @@ -327,10 +327,8 @@ def send_mit_many( ) -> None: """One MIT frame per motor; commands[i] → self.motors[i] = (q, dq, kp, kd, tau).""" if len(commands) != len(self._motors): - raise ValueError( - f"expected {len(self._motors)} commands, got {len(commands)}" - ) - for i, (motor, cmd) in enumerate(zip(self._motors, commands)): + 
raise ValueError(f"expected {len(self._motors)} commands, got {len(commands)}") + for i, (motor, cmd) in enumerate(zip(self._motors, commands, strict=False)): q, dq, kp, kd, tau = cmd data = pack_mit_frame(motor.motor_type, q, dq, kp, kd, tau) self._send_raw(motor.send_id, data) @@ -362,9 +360,7 @@ def wait_all_states(self, timeout: float = 0.5) -> bool: deadline = time.monotonic() + timeout while time.monotonic() < deadline: with self._state_lock: - if all( - m.effective_recv_id in self._states for m in self._motors - ): + if all(m.effective_recv_id in self._states for m in self._motors): return True time.sleep(0.005) return False @@ -422,11 +418,11 @@ def _rx_loop(self) -> None: __all__ = [ "CTRL_MODE_MIT", - "DamiaoMotor", "KD_MAX", "KD_MIN", "KP_MAX", "KP_MIN", + "DamiaoMotor", "MotorState", "MotorType", "OpenArmBus", diff --git a/dimos/hardware/manipulators/openarm/test_driver.py b/dimos/hardware/manipulators/openarm/test_driver.py index 829f7970ef..4f5a1a16a9 100644 --- a/dimos/hardware/manipulators/openarm/test_driver.py +++ b/dimos/hardware/manipulators/openarm/test_driver.py @@ -21,9 +21,9 @@ from dimos.hardware.manipulators.openarm.driver import ( CTRL_MODE_MIT, - DamiaoMotor, KD_MAX, KP_MAX, + DamiaoMotor, MotorType, OpenArmBus, float_to_uint, @@ -33,7 +33,6 @@ uint_to_float, ) - # --------------------------------------------------------------------------- # Pack / unpack primitives # --------------------------------------------------------------------------- @@ -192,9 +191,7 @@ def test_rx_thread_populates_state_cache() -> None: 28, ] ) - sender.send( - can.Message(arbitration_id=0x11, data=payload, is_extended_id=False) - ) + sender.send(can.Message(arbitration_id=0x11, data=payload, is_extended_id=False)) # Poll briefly for the RX thread to consume it deadline = time.monotonic() + 0.5 s = None diff --git a/dimos/robot/catalog/openarm.py b/dimos/robot/catalog/openarm.py index b6271ea836..b6e1238cf2 100644 --- a/dimos/robot/catalog/openarm.py +++ 
b/dimos/robot/catalog/openarm.py @@ -81,7 +81,10 @@ def openarm_arm( # Merge adapter_kwargs rather than replace, so callers can add keys # (e.g. auto_set_mit_mode) without clobbering the catalog's "side". if "adapter_kwargs" in overrides: - defaults["adapter_kwargs"] = {**defaults["adapter_kwargs"], **overrides.pop("adapter_kwargs")} + defaults["adapter_kwargs"] = { + **defaults["adapter_kwargs"], + **overrides.pop("adapter_kwargs"), + } defaults.update(overrides) return RobotConfig(**defaults) diff --git a/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py b/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py index ed57f16691..8c9f401382 100755 --- a/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py +++ b/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py @@ -11,6 +11,7 @@ python dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py --channel can0 python dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py --channel can1 --ids 1,2,3,4,5,6,7 """ + from __future__ import annotations import argparse @@ -26,7 +27,7 @@ # [p_max rad, v_max rad/s, t_max Nm] LIMITS: dict[str, tuple[float, float, float]] = { "DM4310": (12.5, 30.0, 10.0), - "DM4340": (12.5, 8.0, 28.0), + "DM4340": (12.5, 8.0, 28.0), "DM8006": (12.5, 45.0, 40.0), } @@ -42,7 +43,7 @@ (0x08, "DM4310"), # gripper ] -ENABLE = bytes([0xFF] * 7 + [0xFC]) +ENABLE = bytes([0xFF] * 7 + [0xFC]) DISABLE = bytes([0xFF] * 7 + [0xFD]) FD = False # set by --fd at runtime; defaults to classical CAN @ 1 Mbit @@ -57,24 +58,28 @@ def parse_state(motor_type: str, data: bytes) -> tuple[float, float, float, int, if len(data) < 8: return None p_max, v_max, t_max = LIMITS[motor_type] - q_u = (data[1] << 8) | data[2] - dq_u = (data[3] << 4) | (data[4] >> 4) + q_u = (data[1] << 8) | data[2] + dq_u = (data[3] << 4) | (data[4] >> 4) tau_u = ((data[4] & 0x0F) << 8) | data[5] - q = uint_to_float(q_u, -p_max, p_max, 16) - dq = uint_to_float(dq_u, -v_max, v_max, 12) + q = 
uint_to_float(q_u, -p_max, p_max, 16) + dq = uint_to_float(dq_u, -v_max, v_max, 12) tau = uint_to_float(tau_u, -t_max, t_max, 12) return q, dq, tau, data[6], data[7] -def probe_motor(bus: can.BusABC, send_id: int, recv_id: int, - motor_type: str, timeout: float = 0.2) -> bool: +def probe_motor( + bus: can.BusABC, send_id: int, recv_id: int, motor_type: str, timeout: float = 0.2 +) -> bool: """Enable motor, wait for state reply on recv_id, print result, disable.""" # Flush any stale frames while bus.recv(0.0) is not None: pass - bus.send(can.Message(arbitration_id=send_id, data=ENABLE, - is_extended_id=False, is_fd=FD, bitrate_switch=FD)) + bus.send( + can.Message( + arbitration_id=send_id, data=ENABLE, is_extended_id=False, is_fd=FD, bitrate_switch=FD + ) + ) t0 = time.monotonic() while time.monotonic() - t0 < timeout: msg = bus.recv(timeout - (time.monotonic() - t0)) @@ -85,29 +90,55 @@ def probe_motor(bus: can.BusABC, send_id: int, recv_id: int, parsed = parse_state(motor_type, bytes(msg.data)) if parsed is None: print(f" 0x{send_id:02X} ({motor_type}): short reply {list(msg.data)}") - bus.send(can.Message(arbitration_id=send_id, data=DISABLE, - is_extended_id=False, is_fd=FD, bitrate_switch=FD)) + bus.send( + can.Message( + arbitration_id=send_id, + data=DISABLE, + is_extended_id=False, + is_fd=FD, + bitrate_switch=FD, + ) + ) return False q, dq, tau, t_mos, t_rot = parsed - print(f" 0x{send_id:02X} ({motor_type:>6}): " - f"q={q:+.3f} rad dq={dq:+.3f} rad/s tau={tau:+.3f} Nm " - f"T_mos={t_mos}C T_rotor={t_rot}C") - bus.send(can.Message(arbitration_id=send_id, data=DISABLE, - is_extended_id=False, is_fd=FD, bitrate_switch=FD)) + print( + f" 0x{send_id:02X} ({motor_type:>6}): " + f"q={q:+.3f} rad dq={dq:+.3f} rad/s tau={tau:+.3f} Nm " + f"T_mos={t_mos}C T_rotor={t_rot}C" + ) + bus.send( + can.Message( + arbitration_id=send_id, + data=DISABLE, + is_extended_id=False, + is_fd=FD, + bitrate_switch=FD, + ) + ) return True - print(f" 0x{send_id:02X} ({motor_type:>6}): 
NO REPLY on 0x{recv_id:02X} within {timeout*1e3:.0f}ms") - bus.send(can.Message(arbitration_id=send_id, data=DISABLE, - is_extended_id=False, is_fd=FD, bitrate_switch=FD)) + print( + f" 0x{send_id:02X} ({motor_type:>6}): NO REPLY on 0x{recv_id:02X} within {timeout * 1e3:.0f}ms" + ) + bus.send( + can.Message( + arbitration_id=send_id, data=DISABLE, is_extended_id=False, is_fd=FD, bitrate_switch=FD + ) + ) return False def main() -> int: - ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + ap = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) ap.add_argument("--channel", default="can0", help="SocketCAN interface (default: can0)") - ap.add_argument("--fd", action="store_true", help="Use CAN-FD (requires FD-capable adapter). Default is classical CAN @ 1 Mbit, which is what most gs_usb adapters support.") - ap.add_argument("--ids", default=None, - help="Comma-separated send IDs to probe (default: 1..8)") + ap.add_argument( + "--fd", + action="store_true", + help="Use CAN-FD (requires FD-capable adapter). 
Default is classical CAN @ 1 Mbit, which is what most gs_usb adapters support.", + ) + ap.add_argument("--ids", default=None, help="Comma-separated send IDs to probe (default: 1..8)") ap.add_argument("--timeout", type=float, default=0.2, help="Reply timeout per motor (s)") args = ap.parse_args() @@ -127,7 +158,10 @@ def main() -> int: return 1 if not iface_up: print(f"ERROR: SocketCAN interface '{args.channel}' is DOWN.", file=sys.stderr) - print(f" Run: sudo ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh {args.channel}", file=sys.stderr) + print( + f" Run: sudo ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh {args.channel}", + file=sys.stderr, + ) return 1 print(f"Opening {args.channel} ({'CAN-FD' if FD else 'classical CAN'})...") @@ -135,7 +169,10 @@ def main() -> int: bus = can.Bus(interface="socketcan", channel=args.channel, fd=FD) except Exception as e: print(f"ERROR opening {args.channel}: {e}", file=sys.stderr) - print(" Did you run 'sudo ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh' first?", file=sys.stderr) + print( + " Did you run 'sudo ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh' first?", + file=sys.stderr, + ) return 1 try: diff --git a/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py b/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py index 0b7f5b1cb1..5db1e7f89f 100755 --- a/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py +++ b/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py @@ -24,6 +24,7 @@ # CAN-FD (only if your adapter supports it) python dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py --channel can0 --fd """ + from __future__ import annotations import argparse @@ -43,13 +44,17 @@ def write_ctrl_mode(bus: can.BusABC, send_id: int, fd: bool) -> bool: val = struct.pack("> 8) & 0xFF, 0x55, RID_CTRL_MODE, - val[0], val[1], val[2], val[3]]) + data = bytes( + [send_id & 0xFF, (send_id >> 8) & 0xFF, 0x55, 
RID_CTRL_MODE, val[0], val[1], val[2], val[3]] + ) # Flush while bus.recv(0.0) is not None: pass - bus.send(can.Message(arbitration_id=0x7FF, data=data, - is_extended_id=False, is_fd=fd, bitrate_switch=fd)) + bus.send( + can.Message( + arbitration_id=0x7FF, data=data, is_extended_id=False, is_fd=fd, bitrate_switch=fd + ) + ) # Wait for ack on 0x7FF (per openarm_can param response) t0 = time.monotonic() while time.monotonic() - t0 < 0.2: @@ -61,19 +66,24 @@ def write_ctrl_mode(bus: can.BusABC, send_id: int, fd: bool) -> bool: rid = msg.data[3] if rid == RID_CTRL_MODE: echoed = int(struct.unpack(" int: - ap = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + ap = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) ap.add_argument("--channel", default="can0") ap.add_argument("--fd", action="store_true", help="Use CAN-FD (default: classical CAN)") - ap.add_argument("--id", type=lambda s: int(s, 0), default=None, - help="Single send ID (default: all 8)") + ap.add_argument( + "--id", type=lambda s: int(s, 0), default=None, help="Single send ID (default: all 8)" + ) args = ap.parse_args() fd = args.fd @@ -87,7 +97,10 @@ def main() -> int: return 1 if not (flags & 0x1): print(f"ERROR: SocketCAN interface '{args.channel}' is DOWN.", file=sys.stderr) - print(f" Run: sudo ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh {args.channel}", file=sys.stderr) + print( + f" Run: sudo ./dimos/robot/manipulators/openarm/scripts/openarm_can_up.sh {args.channel}", + file=sys.stderr, + ) return 1 print(f"Opening {args.channel} ({'CAN-FD' if fd else 'classical'})") diff --git a/dimos/utils/workspace.py b/dimos/utils/workspace.py index 8f0e496143..70d21f5a7e 100644 --- a/dimos/utils/workspace.py +++ b/dimos/utils/workspace.py @@ -20,9 +20,9 @@ from __future__ import annotations import argparse +from pathlib import Path import sys import time -from pathlib import Path from 
typing import Any import numpy as np @@ -64,7 +64,9 @@ def _sample(self, n: int, seed: int) -> None: self.configs[i] = q J = self._pin.getJointJacobian( - self.model, self.data, self.ee_id, + self.model, + self.data, + self.ee_id, self._pin.ReferenceFrame.LOCAL_WORLD_ALIGNED, ) JJt = J[:3, :] @ J[:3, :].T @@ -114,9 +116,9 @@ def stats(self) -> str: lines = [ "Workspace stats:", f" Samples: {len(p):,}", - f" X range: [{p[:,0].min():.3f}, {p[:,0].max():.3f}] m", - f" Y range: [{p[:,1].min():.3f}, {p[:,1].max():.3f}] m", - f" Z range: [{p[:,2].min():.3f}, {p[:,2].max():.3f}] m", + f" X range: [{p[:, 0].min():.3f}, {p[:, 0].max():.3f}] m", + f" Y range: [{p[:, 1].min():.3f}, {p[:, 1].max():.3f}] m", + f" Z range: [{p[:, 2].min():.3f}, {p[:, 2].max():.3f}] m", f" Max reach from origin: {np.linalg.norm(p, axis=1).max():.3f} m", f" Manipulability: [{self.manipulability.min():.4f}, {self.manipulability.max():.4f}]", ] @@ -125,7 +127,7 @@ def stats(self) -> str: hull = ConvexHull(p) lines.append(f" Convex hull volume: {hull.volume:.4f} m³") - except Exception: # noqa: BLE001 + except Exception: pass return "\n".join(lines) @@ -222,7 +224,9 @@ def _cmd_suggest(args: argparse.Namespace) -> int: p, q = ws.positions[idx], ws.configs[idx] pos_str = f"({p[0]:+.3f}, {p[1]:+.3f}, {p[2]:+.3f})" q_str = "[" + ", ".join(f"{v:.2f}" for v in q) + "]" - print(f"{rank:>3} {dists[idx]:>6.3f} {ws.manipulability[idx]:>7.4f} {pos_str:>30} {q_str}") + print( + f"{rank:>3} {dists[idx]:>6.3f} {ws.manipulability[idx]:>7.4f} {pos_str:>30} {q_str}" + ) return 0 @@ -257,8 +261,10 @@ def _cmd_interactive(args: argparse.Namespace) -> int: if result["reachable"]: render_target(meshcat, (x, y, z), "target", (0.0, 1.0, 0.0)) render_target(meshcat, result["best_position"], "best_ee", (0.0, 0.5, 1.0)) - print(f" REACHABLE — {result['n_configs']} configs, " - f"manip={result['mean_manipulability']:.4f}") + print( + f" REACHABLE — {result['n_configs']} configs, " + 
f"manip={result['mean_manipulability']:.4f}" + ) print(f" Joint config: {[round(q, 3) for q in result['best_config']]}") else: render_target(meshcat, (x, y, z), "target", (1.0, 0.0, 0.0)) From 21dd274883e182583f4b7266e24f257f30a38b41 Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Tue, 28 Apr 2026 20:00:30 -0700 Subject: [PATCH 29/30] converted relateive paths to absolute paths --- .../hardware/manipulators/openarm/adapter.py | 31 ++-------------- dimos/hardware/manipulators/openarm/driver.py | 36 ------------------- .../manipulators/openarm/test_driver.py | 16 ++++----- .../openarm/scripts/openarm_can_probe.py | 14 ++++++++ .../openarm/scripts/openarm_set_mit_mode.py | 14 ++++++++ .../manipulation/openarm_integration.md | 12 +++---- 6 files changed, 43 insertions(+), 80 deletions(-) diff --git a/dimos/hardware/manipulators/openarm/adapter.py b/dimos/hardware/manipulators/openarm/adapter.py index 851185a26e..f68aec32ce 100644 --- a/dimos/hardware/manipulators/openarm/adapter.py +++ b/dimos/hardware/manipulators/openarm/adapter.py @@ -138,10 +138,6 @@ def __init__( self._pin_model: Any = None self._pin_data: Any = None - # ------------------------------------------------------------------ - # Lifecycle - # ------------------------------------------------------------------ - def connect(self) -> bool: # Preflight: verify the SocketCAN interface is up before opening the bus. 
# Bringing the interface up requires root privileges, so we don't do it @@ -223,10 +219,6 @@ def disconnect(self) -> None: def is_connected(self) -> bool: return self._bus is not None - # ------------------------------------------------------------------ - # Info - # ------------------------------------------------------------------ - def get_info(self) -> ManipulatorInfo: return ManipulatorInfo( vendor="Enactic", @@ -250,10 +242,6 @@ def get_limits(self) -> JointLimits: velocity_max=list(_V10_VEL_MAX), ) - # ------------------------------------------------------------------ - # Mode - # ------------------------------------------------------------------ - def set_control_mode(self, mode: ControlMode) -> bool: # OpenArm runs exclusively in Damiao MIT register mode; we emulate # dimos ControlModes by tuning kp/kd/q/dq/tau on each MIT frame. @@ -271,10 +259,6 @@ def set_control_mode(self, mode: ControlMode) -> bool: def get_control_mode(self) -> ControlMode: return self._control_mode - # ------------------------------------------------------------------ - # State reads - # ------------------------------------------------------------------ - def _states_or_raise(self) -> list[Any]: if self._bus is None: raise RuntimeError("OpenArmAdapter not connected") @@ -314,10 +298,6 @@ def read_error(self) -> tuple[int, str]: return 1, f"rotor over-temperature ({t_rotor}°C)" return 0, "" - # ------------------------------------------------------------------ - # Gravity compensation - # ------------------------------------------------------------------ - def _compute_gravity_torques(self, q: list[float]) -> list[float]: """Pinocchio G(q), clamped to motor torque limits. 
Zero if model not loaded.""" if self._pin_model is None or self._pin_data is None: @@ -330,10 +310,6 @@ def _compute_gravity_torques(self, q: list[float]) -> list[float]: limits = [m.limits for m in self._motors] # (p_max, v_max, t_max) return [float(np.clip(tau_g[i], -lim[2], lim[2])) for i, lim in enumerate(limits)] - # ------------------------------------------------------------------ - # Commands - # ------------------------------------------------------------------ - def write_joint_positions( self, positions: list[float], @@ -386,7 +362,8 @@ def write_stop(self) -> bool: q_now = [0.0] * self._dof tau_ff = self._compute_gravity_torques(q_now) commands = [ - (q, 0.0, kp, kd, tau) for q, kp, kd, tau in zip(q_now, self._kp, self._kd, tau_ff, strict=False) + (q, 0.0, kp, kd, tau) + for q, kp, kd, tau in zip(q_now, self._kp, self._kd, tau_ff, strict=False) ] self._bus.send_mit_many(commands) self._last_cmd_q = q_now @@ -415,10 +392,6 @@ def write_clear_errors(self) -> bool: self._enabled = True return True - # ------------------------------------------------------------------ - # Cartesian / gripper / F/T — not supported at this layer - # ------------------------------------------------------------------ - def read_cartesian_position(self) -> dict[str, float] | None: return None diff --git a/dimos/hardware/manipulators/openarm/driver.py b/dimos/hardware/manipulators/openarm/driver.py index 0a964d92db..5408187647 100644 --- a/dimos/hardware/manipulators/openarm/driver.py +++ b/dimos/hardware/manipulators/openarm/driver.py @@ -32,11 +32,6 @@ import can -# --------------------------------------------------------------------------- -# Motor tables (from enactic/openarm_can dm_motor_constants.hpp) -# --------------------------------------------------------------------------- - - class MotorType(str, enum.Enum): """Damiao motor types used on OpenArm. 
Values match the reference library.""" @@ -85,11 +80,6 @@ class MotorType(str, enum.Enum): CTRL_MODE_MIT = 1 -# --------------------------------------------------------------------------- -# Pack / unpack helpers (pure — safe to unit test in isolation) -# --------------------------------------------------------------------------- - - def _clamp(x: float, lo: float, hi: float) -> float: if x < lo: return lo @@ -186,11 +176,6 @@ def pack_write_param_frame(send_id: int, rid: int, value_u32: int) -> bytes: ) -# --------------------------------------------------------------------------- -# Motor descriptor -# --------------------------------------------------------------------------- - - @dataclass(frozen=True) class DamiaoMotor: """One Damiao motor on a CAN bus. recv_id defaults to send_id | 0x10.""" @@ -208,11 +193,6 @@ def limits(self) -> tuple[float, float, float]: return _MOTOR_LIMITS[self.motor_type] -# --------------------------------------------------------------------------- -# Bus wrapper with background receive thread -# --------------------------------------------------------------------------- - - class OpenArmBus: """One SocketCAN bus with a background RX thread caching latest state.""" @@ -246,10 +226,6 @@ def __init__( self._state_lock = threading.Lock() self._states: dict[int, MotorState] = {} - # ------------------------------------------------------------------ - # Lifecycle - # ------------------------------------------------------------------ - def open(self) -> None: """Open the CAN bus and start the background RX thread.""" if self._bus is not None: @@ -282,10 +258,6 @@ def __enter__(self) -> OpenArmBus: def __exit__(self, *_exc: object) -> None: self.close() - # ------------------------------------------------------------------ - # Control commands - # ------------------------------------------------------------------ - def enable_all(self) -> None: for m in self._motors: self._send_raw(m.send_id, _pack_control_command(_CMD_ENABLE)) @@ -337,10 
+309,6 @@ def send_mit_many( if i < len(self._motors) - 1: time.sleep(0.0005) - # ------------------------------------------------------------------ - # State access - # ------------------------------------------------------------------ - @property def motors(self) -> tuple[DamiaoMotor, ...]: return tuple(self._motors) @@ -365,10 +333,6 @@ def wait_all_states(self, timeout: float = 0.5) -> bool: time.sleep(0.005) return False - # ------------------------------------------------------------------ - # Internals - # ------------------------------------------------------------------ - def _send_raw(self, arbitration_id: int, data: bytes) -> None: if self._bus is None: raise RuntimeError("bus not open — call .open() first") diff --git a/dimos/hardware/manipulators/openarm/test_driver.py b/dimos/hardware/manipulators/openarm/test_driver.py index 4f5a1a16a9..c65a972bd6 100644 --- a/dimos/hardware/manipulators/openarm/test_driver.py +++ b/dimos/hardware/manipulators/openarm/test_driver.py @@ -5,6 +5,13 @@ # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Unit tests for the Damiao MIT-mode driver — no hardware required. Uses ``can.Bus(interface="virtual")`` for loopback. 
@@ -33,10 +40,6 @@ uint_to_float, ) -# --------------------------------------------------------------------------- -# Pack / unpack primitives -# --------------------------------------------------------------------------- - def test_float_to_uint_endpoints_and_roundtrip() -> None: # Endpoints @@ -137,11 +140,6 @@ def test_pack_write_param_ctrl_mode_mit() -> None: assert struct.unpack(" OpenArmBus: return OpenArmBus(channel=channel, motors=motors, fd=False, interface="virtual") diff --git a/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py b/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py index 8c9f401382..9c740ef485 100755 --- a/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py +++ b/dimos/robot/manipulators/openarm/scripts/openarm_can_probe.py @@ -1,4 +1,18 @@ #!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Probe an OpenArm on a SocketCAN interface. Enumerates all 8 expected Damiao motors (7 arm joints + gripper) on one CAN bus diff --git a/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py b/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py index 5db1e7f89f..01d9bbdb17 100755 --- a/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py +++ b/dimos/robot/manipulators/openarm/scripts/openarm_set_mit_mode.py @@ -1,4 +1,18 @@ #!/usr/bin/env python3 +# Copyright 2025-2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Write CTRL_MODE = MIT (1) to one or all OpenArm motors. Damiao motors have a persistent CTRL_MODE register (RID=10). If a motor was diff --git a/docs/capabilities/manipulation/openarm_integration.md b/docs/capabilities/manipulation/openarm_integration.md index 306a44fd16..49a78ccde7 100644 --- a/docs/capabilities/manipulation/openarm_integration.md +++ b/docs/capabilities/manipulation/openarm_integration.md @@ -6,7 +6,7 @@ Guide for running the **OpenArm** — an open-source bimanual 7-DOF research arm Related: - Upstream hardware + C++ reference: [enactic/openarm_can](https://github.com/enactic/openarm_can) -- How to integrate any new arm: [adding_a_custom_arm.md](adding_a_custom_arm.md) +- How to integrate any new arm: [adding_a_custom_arm.md](/docs/capabilities/manipulation/adding_a_custom_arm.md) --- @@ -57,7 +57,7 @@ data/openarm_description/ # URDF + meshes (in-tree; may migrate to LFS) └── openarm_v10_single.urdf # standalone arm (Pinocchio FK for teleop) ``` -Workspace analysis is generic and lives in [dimos/utils/workspace.py](../../../dimos/utils/workspace.py) — works for any URDF, not just OpenArm. +Workspace analysis is generic and lives in [dimos/utils/workspace.py](/dimos/utils/workspace.py) — works for any URDF, not just OpenArm. 
--- @@ -219,7 +219,7 @@ If you don't know which Cartesian targets are reachable, check first with the wo ### Which CAN bus is which arm -Linux assigns `can0`/`can1` in USB-enumeration order, which isn't guaranteed stable across reboots or cable swaps. If the arms come up "swapped" (commanding `left_arm` moves the physical right arm), flip these two constants at the top of [blueprints.py](../../../dimos/robot/manipulators/openarm/blueprints.py): +Linux assigns `can0`/`can1` in USB-enumeration order, which isn't guaranteed stable across reboots or cable swaps. If the arms come up "swapped" (commanding `left_arm` moves the physical right arm), flip these two constants at the top of [blueprints.py](/dimos/robot/manipulators/openarm/blueprints.py): ```python LEFT_CAN = "can0" @@ -230,7 +230,7 @@ No other code changes are needed. ### Gain tuning (MIT kp/kd) -Defaults live in [adapter.py](../../../dimos/hardware/manipulators/openarm/adapter.py). Gains are per-joint because the shoulder motors (DM8006, 40 Nm) tolerate higher kp than the wrist motors (DM4310, 10 Nm): +Defaults live in [adapter.py](/dimos/hardware/manipulators/openarm/adapter.py). Gains are per-joint because the shoulder motors (DM8006, 40 Nm) tolerate higher kp than the wrist motors (DM4310, 10 Nm): ```python _DEFAULT_KP = [100.0, 100.0, 80.0, 80.0, 60.0, 60.0, 60.0] @@ -248,7 +248,7 @@ The URDFs use the xacro-generated limits (which include per-side offsets for mir ### Disabling auto MIT-mode write -The adapter writes `CTRL_MODE=MIT` to every motor at `connect()`. It's idempotent (writing the same value is a no-op), so this is safe to leave on. To verify that a previous write persisted across a power cycle, flip `AUTO_SET_MIT_MODE = False` in [blueprints.py](../../../dimos/robot/manipulators/openarm/blueprints.py) and restart — the arms should still respond. +The adapter writes `CTRL_MODE=MIT` to every motor at `connect()`. It's idempotent (writing the same value is a no-op), so this is safe to leave on. 
To verify that a previous write persisted across a power cycle, flip `AUTO_SET_MIT_MODE = False` in [blueprints.py](/dimos/robot/manipulators/openarm/blueprints.py) and restart — the arms should still respond. --- @@ -339,7 +339,7 @@ Persistent across power cycles. - **`COLLISION_AT_START` during planning.** `link5` and `link7` collision meshes overlap by 3 mm at every configuration. Handled by `OPENARM_COLLISION_EXCLUSIONS` in the catalog. If you see it anyway, the exclusion pairs may not be getting applied — check that the collision filter log line appears during world build. - **`INVALID_START` during planning.** Hardware encoder noise pushed a joint 1 mrad past a URDF limit. Joint4 used to be exactly `lower=0.0` which tripped this — it's now `-0.01` to give breathing room. If you see it on a different joint, widen that limit by ~10 mrad. - **"Transmit buffer full" (ENOBUFS) at 100 Hz.** Kernel TX queue too small. The bringup script sets `txqueuelen 1000`; the driver also retries on ENOBUFS. If you still see the error, check `ip -details link show canX | grep qlen`. -- **Arms swap sides.** USB enumeration order flipped. Swap `LEFT_CAN` / `RIGHT_CAN` in [blueprints.py](../../../dimos/robot/manipulators/openarm/blueprints.py). +- **Arms swap sides.** USB enumeration order flipped. Swap `LEFT_CAN` / `RIGHT_CAN` in [blueprints.py](/dimos/robot/manipulators/openarm/blueprints.py). 
--- From 6a44087de215df573074e5f70f8bfac4b30be0cd Mon Sep 17 00:00:00 2001 From: mustafab0 Date: Fri, 1 May 2026 14:25:32 -0700 Subject: [PATCH 30/30] updated urdf to increase min limit for joint4 as real hw can go lower than the limit --- data/.lfs/openarm_description.tar.gz | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/data/.lfs/openarm_description.tar.gz b/data/.lfs/openarm_description.tar.gz index b4cf5e04e9..54aa76da41 100644 --- a/data/.lfs/openarm_description.tar.gz +++ b/data/.lfs/openarm_description.tar.gz @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e5dcbc94024986a46414b81486c3e88883847db75dd6ee90dc5fa6b88536b20f -size 70064686 +oid sha256:4da176b6c210b9796bb2ee1a29c15ee9a67578b9ae906eb89a6ec8a44b7f303a +size 70064687