Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 142 additions & 0 deletions backend/services/websocket_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""
WebSocket Connection Manager — Manages active WebSocket connections with dynamic pool sizing.

Reads pool bounds from environment variables:
WS_POOL_MIN_SIZE — minimum pool size (default: 10)
WS_POOL_MAX_SIZE — maximum pool size (default: 100)
WS_POOL_CLEANUP_INTERVAL — cleanup interval in seconds (default: 300)
"""

import os
import logging
import asyncio
from typing import Dict, Set, Optional
from dataclasses import dataclass, field
from datetime import datetime, timezone

# Module-level logger; records are tagged with a "[WebSocketManager]" prefix.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

handler = logging.StreamHandler()
formatter = logging.Formatter("[WebSocketManager] %(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
# Attach the handler only once: logging.getLogger returns the same logger
# object every time, so re-executing this module (e.g. importlib.reload)
# would otherwise stack handlers and emit every record multiple times.
if not logger.handlers:
    logger.addHandler(handler)


@dataclass
class PoolConfig:
    """Size bounds and cleanup timing for the WebSocket connection pool."""

    # Minimum pool size (env: WS_POOL_MIN_SIZE).
    min_size: int = 10
    # Maximum pool size (env: WS_POOL_MAX_SIZE).
    max_size: int = 100
    # Seconds between cleanup passes (env: WS_POOL_CLEANUP_INTERVAL).
    cleanup_interval: int = 300

    @classmethod
    def from_env(cls) -> "PoolConfig":
        """Build a config from the ``WS_POOL_*`` environment variables.

        Each variable is optional. A value that is missing, not an integer,
        or below its allowed minimum falls back to the built-in default
        (10 / 100 / 300) with a warning instead of raising, so a bad
        deployment setting cannot crash start-up. If the resulting minimum
        exceeds the maximum, the minimum is clamped down to the maximum.

        Returns:
            A ``PoolConfig`` with ``0 <= min_size <= max_size``,
            ``max_size >= 1`` and ``cleanup_interval >= 1``.
        """
        log = logging.getLogger(__name__)

        def _read_int(name: str, default: int, minimum: int) -> int:
            # Parse one env var, falling back to ``default`` on bad input.
            raw = os.getenv(name)
            if raw is None:
                return default
            try:
                value = int(raw)
            except ValueError:
                log.warning("Invalid %s=%r; using default %d", name, raw, default)
                return default
            if value < minimum:
                log.warning(
                    "%s=%d is below the minimum of %d; using default %d",
                    name, value, minimum, default,
                )
                return default
            return value

        min_size = _read_int("WS_POOL_MIN_SIZE", 10, 0)
        max_size = _read_int("WS_POOL_MAX_SIZE", 100, 1)
        # A non-positive interval would make a sleep-based cleanup loop spin.
        cleanup_interval = _read_int("WS_POOL_CLEANUP_INTERVAL", 300, 1)

        if min_size > max_size:
            log.warning(
                "WS_POOL_MIN_SIZE (%d) > WS_POOL_MAX_SIZE (%d); clamping min_size",
                min_size, max_size,
            )
            min_size = max_size

        return cls(
            min_size=min_size,
            max_size=max_size,
            cleanup_interval=cleanup_interval,
        )
Comment on lines +33 to +38

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate and normalize environment pool settings before using them.

from_env() currently trusts raw env values. Non-integer values can crash initialization, and non-positive bounds (especially cleanup interval) can lead to invalid runtime behavior.

💡 Suggested fix
@@
     `@classmethod`
     def from_env(cls) -> "PoolConfig":
-        return cls(
-            min_size=int(os.getenv("WS_POOL_MIN_SIZE", "10")),
-            max_size=int(os.getenv("WS_POOL_MAX_SIZE", "100")),
-            cleanup_interval=int(os.getenv("WS_POOL_CLEANUP_INTERVAL", "300")),
-        )
+        def _parse_int(name: str, default: int) -> int:
+            raw = os.getenv(name, str(default))
+            try:
+                return int(raw)
+            except ValueError:
+                logger.warning("Invalid %s=%r; using default=%d", name, raw, default)
+                return default
+
+        min_size = max(1, _parse_int("WS_POOL_MIN_SIZE", 10))
+        max_size = max(1, _parse_int("WS_POOL_MAX_SIZE", 100))
+        cleanup_interval = max(1, _parse_int("WS_POOL_CLEANUP_INTERVAL", 300))
+
+        if min_size > max_size:
+            logger.warning(
+                "WS_POOL_MIN_SIZE (%d) > WS_POOL_MAX_SIZE (%d); clamping min_size",
+                min_size, max_size
+            )
+            min_size = max_size
+
+        return cls(min_size=min_size, max_size=max_size, cleanup_interval=cleanup_interval)

Also applies to: 113-113

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@backend/services/websocket_manager.py` around lines 33 - 38, The from_env
classmethod in PoolConfig currently trusts raw env values and can crash or
produce invalid configs; update PoolConfig.from_env to validate and normalize
the three env-derived fields (min_size, max_size, cleanup_interval): parse
integers safely with a fallback/default when parsing fails, enforce min_size >=
0, cleanup_interval > 0 (replace non-positive with a sane default like 300), and
ensure min_size <= max_size (swap or clamp values if necessary). Use the
existing symbol names (from_env, PoolConfig, WS_POOL_MIN_SIZE, WS_POOL_MAX_SIZE,
WS_POOL_CLEANUP_INTERVAL) so callers remain unchanged and add minimal logging or
warnings when fallbacks are used.



class ConnectionManager:
    """Track active WebSocket connections grouped by company.

    Connections are stored in one set per ``company_id``. A running total,
    ``_active_count``, is compared against ``config.max_size`` to cap the
    number of connections accepted across all companies.
    """

    def __init__(self, config: Optional["PoolConfig"] = None):
        """Create an empty manager.

        Args:
            config: Pool settings. When omitted, they are read from the
                environment via ``PoolConfig.from_env()``.
        """
        self.config = config if config is not None else PoolConfig.from_env()
        # company_id -> set of connection objects registered for it.
        self._connections: Dict[str, Set[object]] = {}
        # Total connections across all companies; must mirror _connections.
        self._active_count: int = 0
        # Handle of the background cleanup task, if one has been started.
        self._cleanup_task: Optional[asyncio.Task] = None

    @property
    def active_connections(self) -> int:
        """Total number of registered connections across all companies."""
        return self._active_count

    @property
    def pool_size(self) -> int:
        """Number of companies that currently have at least one connection."""
        return len(self._connections)

    def register(self, company_id: str, connection: object) -> bool:
        """Add ``connection`` to the pool of ``company_id``.

        Registering a connection that is already present is a no-op that
        returns ``True``; the active count is only incremented when the
        connection is actually new, so repeated calls cannot inflate it.

        Returns:
            ``True`` if the connection is registered after the call,
            ``False`` if it was rejected because the pool is at
            ``config.max_size``.
        """
        pool = self._connections.get(company_id)
        if pool is not None and connection in pool:
            return True

        if self._active_count >= self.config.max_size:
            logger.warning(
                f"Pool at max capacity ({self.config.max_size}), "
                f"rejecting connection for company {company_id}"
            )
            return False

        if pool is None:
            # Create the company bucket only once the connection is accepted,
            # so rejected registrations never leave empty sets behind.
            pool = self._connections[company_id] = set()
        pool.add(connection)
        self._active_count += 1
        return True

    def unregister(self, company_id: str, connection: object) -> None:
        """Remove ``connection`` from the pool of ``company_id``.

        Unknown companies and unknown connections are ignored. The active
        count is decremented only when a connection was really removed, so
        double-unregistering cannot under-count the remaining connections.
        """
        pool = self._connections.get(company_id)
        if pool is None:
            return
        if connection in pool:
            pool.remove(connection)
            self._active_count = max(0, self._active_count - 1)
        if not pool:
            # Drop empty buckets so pool_size reflects companies in use.
            del self._connections[company_id]

Comment on lines +64 to +76

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Fix _active_count drift in register/unregister paths.

_active_count is incremented/decremented without checking whether set membership actually changed. This can desync capacity tracking and reject valid connections.

💡 Suggested fix
@@
     def register(self, company_id: str, connection: object) -> bool:
@@
-        if company_id not in self._connections:
-            self._connections[company_id] = set()
-        self._connections[company_id].add(connection)
+        company_pool = self._connections.setdefault(company_id, set())
+        if connection in company_pool:
+            return True
+        company_pool.add(connection)
         self._active_count += 1
         return True
@@
     def unregister(self, company_id: str, connection: object) -> None:
-        if company_id in self._connections:
-            self._connections[company_id].discard(connection)
-            if not self._connections[company_id]:
-                del self._connections[company_id]
-            self._active_count = max(0, self._active_count - 1)
+        company_pool = self._connections.get(company_id)
+        if company_pool is None:
+            return
+        if connection in company_pool:
+            company_pool.remove(connection)
+            self._active_count = max(0, self._active_count - 1)
+        if not company_pool:
+            del self._connections[company_id]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@backend/services/websocket_manager.py` around lines 64 - 76, The
_active_count is changed unconditionally in register/unregister causing drift;
update register(self, company_id, connection) to only increment _active_count
when the connection was not already present in self._connections[company_id]
(i.e., check membership before add), and update unregister(self, company_id,
connection) to only decrement _active_count when discard actually removed an
existing connection (i.e., check membership before discard or detect removal),
keeping _connections and _active_count consistent; refer to the methods register
and unregister and the attributes _connections and _active_count when making the
change.

def get_connections(self, company_id: str) -> Set[object]:
    """Return a snapshot of the connections registered for ``company_id``.

    The result is a new set on every call (empty when the company is
    unknown). Returning a copy rather than the internal set means callers
    cannot add or remove connections behind the manager's back, and can
    iterate safely while connections are being registered or unregistered.
    """
    return set(self._connections.get(company_id, ()))
Comment on lines +77 to +78

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Do not expose the internal mutable connection set.

Returning the live set lets callers mutate manager state directly, bypassing accounting and cleanup invariants. Return a snapshot instead.

💡 Suggested fix
     def get_connections(self, company_id: str) -> Set[object]:
-        return self._connections.get(company_id, set())
+        return set(self._connections.get(company_id, set()))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@backend/services/websocket_manager.py` around lines 77 - 78, get_connections
currently returns the live mutable set from self._connections, letting callers
mutate manager state; change it to return an immutable snapshot (e.g., a
frozenset) or a shallow copy of the set so callers cannot alter internal state.
Update the get_connections method to fetch the set via
self._connections.get(company_id, set()) and return a snapshot (frozenset(...)
or set(...)) instead of the original object, preserving the original key and
value types but preventing external mutation of the manager's _connections.


def broadcast(self, company_id: str, message: str) -> int:
    """Send ``message`` to every connection of ``company_id``.

    Connections without a callable ``send`` attribute are skipped. An
    exception raised by one connection is logged and does not stop delivery
    to the remaining ones.

    Returns:
        The number of connections whose ``send`` call completed without
        raising.
    """
    # NOTE(review): ``send`` is invoked synchronously; if a connection's
    # ``send`` is a coroutine function it is not awaited here — confirm
    # against the connection type actually registered.
    sent = 0
    # Iterate over a snapshot: a ``send`` implementation (or an error handler
    # it triggers) may unregister connections, and mutating a set while
    # iterating it raises RuntimeError.
    for conn in tuple(self.get_connections(company_id)):
        send = getattr(conn, "send", None)
        if not callable(send):
            continue
        try:
            send(message)
        except Exception as e:
            logger.error(f"Broadcast error for company {company_id}: {e}")
        else:
            sent += 1
    return sent

async def cleanup_stale(self) -> int:
    """Remove connections that report themselves as closed.

    A connection is treated as stale when it has a truthy ``closed``
    attribute; connections without that attribute are kept. Company buckets
    left empty are deleted and ``_active_count`` is reduced by the number of
    connections removed (never below zero).

    Returns:
        The number of connections removed in this pass.
    """
    # NOTE(review): only ``closed`` is consulted. No idle-time check (e.g. on
    # a ``last_active`` attribute) is applied — confirm whether one is wanted.
    removed = 0
    # Copy the keys first because buckets may be deleted during the loop.
    for company_id in list(self._connections):
        pool = self._connections[company_id]
        closed = {conn for conn in pool if getattr(conn, "closed", False)}
        if closed:
            pool.difference_update(closed)
            removed += len(closed)
        if not pool:
            del self._connections[company_id]
    if removed:
        self._active_count = max(0, self._active_count - removed)
        logger.info(f"Cleanup removed {removed} stale connection(s)")
    return removed

async def start_cleanup_loop(self):
    """Start the periodic background cleanup task.

    The task sleeps ``config.cleanup_interval`` seconds and then awaits
    ``cleanup_stale()``, forever. Calling this while a previous task is
    still running is a no-op, so repeated calls never spawn duplicate loops.
    Must be called from within a running event loop.
    """
    current = self._cleanup_task
    if current is not None and not current.done():
        return

    async def _loop():
        while True:
            await asyncio.sleep(self.config.cleanup_interval)
            try:
                await self.cleanup_stale()
            except Exception:
                # Keep the loop alive: one failed pass must not silently end
                # all future cleanups. Cancellation is a BaseException and
                # still propagates.
                logger.exception("Cleanup pass failed")

    self._cleanup_task = asyncio.create_task(_loop())
    logger.info(f"Cleanup loop started (interval={self.config.cleanup_interval}s)")

async def stop_cleanup_loop(self):
    """Cancel the background cleanup task and wait until it has finished.

    Does nothing when no task is recorded. The task reference is cleared
    before awaiting, so a concurrent second call sees no task and returns.
    """
    task = self._cleanup_task
    if task is None:
        return
    self._cleanup_task = None
    task.cancel()
    try:
        # cancel() only requests cancellation; awaiting lets the task unwind
        # so it is no longer pending when this method returns.
        await task
    except asyncio.CancelledError:
        pass
    except Exception:
        # The task had already failed on its own; report it, don't raise.
        logger.exception("Cleanup loop had terminated with an error")
    logger.info("Cleanup loop stopped")
Comment on lines +110 to +123

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Make cleanup-loop lifecycle idempotent and cancellation-safe.

start_cleanup_loop() can spawn multiple background tasks on repeated calls. stop_cleanup_loop() cancels without awaiting completion.

💡 Suggested fix
@@
 import asyncio
+from contextlib import suppress
@@
     async def start_cleanup_loop(self):
+        if self._cleanup_task and not self._cleanup_task.done():
+            return
         async def _loop():
             while True:
                 await asyncio.sleep(self.config.cleanup_interval)
                 await self.cleanup_stale()
@@
     async def stop_cleanup_loop(self):
-        if self._cleanup_task:
-            self._cleanup_task.cancel()
-            self._cleanup_task = None
-            logger.info("Cleanup loop stopped")
+        task = self._cleanup_task
+        if not task:
+            return
+        self._cleanup_task = None
+        task.cancel()
+        with suppress(asyncio.CancelledError):
+            await task
+        logger.info("Cleanup loop stopped")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@backend/services/websocket_manager.py` around lines 110 - 123, The cleanup
loop lifecycle is not idempotent and cancellation-unsafe: modify
start_cleanup_loop to only create the background task if self._cleanup_task is
None or finished (check task.done()), and in stop_cleanup_loop cancel and then
await the task completion (wrap await in try/except asyncio.CancelledError and
optionally asyncio.TimeoutError) before setting self._cleanup_task = None;
ensure you handle race conditions by checking task existence again after
awaiting and log appropriately — update the methods start_cleanup_loop and
stop_cleanup_loop accordingly (use self._cleanup_task, _loop coroutine, and
asyncio.create_task references).



# Process-wide ConnectionManager singleton; None until load() creates it.
_instance: Optional[ConnectionManager] = None


def load():
    """Return the shared ConnectionManager, creating it on first use.

    The first call reads the pool settings from the environment, builds the
    manager, stores it in the module-level ``_instance`` and logs the
    effective settings. Later calls return that same object.
    """
    global _instance
    if _instance is not None:
        return _instance

    config = PoolConfig.from_env()
    manager = ConnectionManager(config)
    _instance = manager
    logger.info(
        f"ConnectionManager loaded (min={config.min_size}, "
        f"max={config.max_size}, cleanup={config.cleanup_interval}s)"
    )
    return manager


def get_instance() -> Optional[ConnectionManager]:
    """Return the module-level ConnectionManager, or None if none is set.

    Unlike ``load()``, this never creates the manager.
    """
    return _instance
Loading