-
Notifications
You must be signed in to change notification settings - Fork 281
refactor: Support dynamic pool sizing in WebSocket ConnectionManager #2797
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| """ | ||
| WebSocket Connection Manager — Manages active WebSocket connections with dynamic pool sizing. | ||
|
|
||
| Reads pool bounds from environment variables: | ||
| WS_POOL_MIN_SIZE — minimum pool size (default: 10) | ||
| WS_POOL_MAX_SIZE — maximum pool size (default: 100) | ||
| WS_POOL_CLEANUP_INTERVAL — cleanup interval in seconds (default: 300) | ||
| """ | ||
|
|
||
| import os | ||
| import logging | ||
| import asyncio | ||
| from typing import Dict, Set, Optional | ||
| from dataclasses import dataclass, field | ||
| from datetime import datetime, timezone | ||
|
|
||
# Module-level logger for the connection manager, emitting at INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Stream handler that tags every record with a "[WebSocketManager]" prefix.
handler = logging.StreamHandler()
formatter = logging.Formatter("[WebSocketManager] %(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
# logging.getLogger returns the same logger object on every import/reload, so
# attaching unconditionally would stack handlers and duplicate each log line.
if not logger.handlers:
    logger.addHandler(handler)
|
|
||
|
|
||
@dataclass
class PoolConfig:
    """Pool sizing and cleanup settings for the WebSocket connection manager."""

    # Minimum pool size (WS_POOL_MIN_SIZE).
    min_size: int = 10
    # Maximum pool size (WS_POOL_MAX_SIZE).
    max_size: int = 100
    # Seconds between cleanup passes (WS_POOL_CLEANUP_INTERVAL).
    cleanup_interval: int = 300

    @staticmethod
    def _env_int(name: str, default: int, minimum: int) -> int:
        """Read integer env var *name*; fall back to *default* when unusable.

        A value is unusable when it is not an integer literal or is below
        *minimum*. Each fallback is logged as a warning.
        """
        raw = os.getenv(name)
        if raw is None:
            return default
        log = logging.getLogger(__name__)
        try:
            value = int(raw)
        except ValueError:
            log.warning("Invalid %s=%r; using default %d", name, raw, default)
            return default
        if value < minimum:
            log.warning(
                "%s=%d is below the minimum %d; using default %d",
                name, value, minimum, default,
            )
            return default
        return value

    @classmethod
    def from_env(cls) -> "PoolConfig":
        """Build a config from the WS_POOL_* environment variables.

        Missing, non-integer, or out-of-range values fall back to the
        defaults (10 / 100 / 300). If the resulting minimum exceeds the
        maximum, the minimum is clamped down to the maximum.
        """
        min_size = cls._env_int("WS_POOL_MIN_SIZE", 10, 0)
        max_size = cls._env_int("WS_POOL_MAX_SIZE", 100, 1)
        # A non-positive interval would make the cleanup loop spin without sleeping.
        cleanup_interval = cls._env_int("WS_POOL_CLEANUP_INTERVAL", 300, 1)

        if min_size > max_size:
            logging.getLogger(__name__).warning(
                "WS_POOL_MIN_SIZE (%d) > WS_POOL_MAX_SIZE (%d); clamping min_size",
                min_size, max_size,
            )
            min_size = max_size

        return cls(
            min_size=min_size,
            max_size=max_size,
            cleanup_interval=cleanup_interval,
        )
|
|
||
|
|
||
| class ConnectionManager: | ||
def __init__(self, config: Optional[PoolConfig] = None):
    """Create an empty manager.

    Args:
        config: Pool settings; when omitted (or falsy) they are read from
            the environment via ``PoolConfig.from_env()``.
    """
    if not config:
        config = PoolConfig.from_env()
    self.config = config
    # Handle of the background cleanup task; None until the loop is started.
    self._cleanup_task: Optional[asyncio.Task] = None
    # Running total of registered connections across all companies.
    self._active_count: int = 0
    # company_id -> set of connection objects registered for that company.
    self._connections: Dict[str, Set[object]] = {}
|
|
||
@property
def active_connections(self) -> int:
    """Number of connections currently counted as registered."""
    count = self._active_count
    return count
|
|
||
@property
def pool_size(self) -> int:
    """Number of company entries currently held in the connection map."""
    companies = self._connections
    return len(companies)
|
|
||
def register(self, company_id: str, connection: object) -> bool:
    """Track *connection* under *company_id*.

    Registering a connection that is already tracked for the company is a
    no-op that returns True, so the active count never exceeds the number
    of distinct tracked connections.

    Returns:
        True if the connection is tracked after the call, False if it was
        rejected because the pool is at ``config.max_size``.
    """
    company_pool = self._connections.get(company_id)
    # Checked before the capacity test: a known connection takes no new slot.
    if company_pool is not None and connection in company_pool:
        return True

    if self._active_count >= self.config.max_size:
        logger.warning(
            f"Pool at max capacity ({self.config.max_size}), "
            f"rejecting connection for company {company_id}"
        )
        return False

    self._connections.setdefault(company_id, set()).add(connection)
    self._active_count += 1
    return True
|
|
||
def unregister(self, company_id: str, connection: object) -> None:
    """Stop tracking *connection* for *company_id*.

    The active count is decremented only when the connection was actually
    tracked, so unknown companies, unknown connections, and repeated calls
    leave the accounting untouched. A company entry is dropped once its
    set becomes empty.
    """
    company_pool = self._connections.get(company_id)
    if company_pool is None:
        return

    if connection in company_pool:
        company_pool.remove(connection)
        self._active_count = max(0, self._active_count - 1)

    if not company_pool:
        del self._connections[company_id]
|
|
||
|
Comment on lines
+64
to
+76
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fix
💡 Suggested fix@@
def register(self, company_id: str, connection: object) -> bool:
@@
- if company_id not in self._connections:
- self._connections[company_id] = set()
- self._connections[company_id].add(connection)
+ company_pool = self._connections.setdefault(company_id, set())
+ if connection in company_pool:
+ return True
+ company_pool.add(connection)
self._active_count += 1
return True
@@
def unregister(self, company_id: str, connection: object) -> None:
- if company_id in self._connections:
- self._connections[company_id].discard(connection)
- if not self._connections[company_id]:
- del self._connections[company_id]
- self._active_count = max(0, self._active_count - 1)
+ company_pool = self._connections.get(company_id)
+ if company_pool is None:
+ return
+ if connection in company_pool:
+ company_pool.remove(connection)
+ self._active_count = max(0, self._active_count - 1)
+ if not company_pool:
+ del self._connections[company_id]🤖 Prompt for AI Agents |
||
def get_connections(self, company_id: str) -> Set[object]:
    """Return a snapshot of the connections tracked for *company_id*.

    The result is a new set on every call (empty for an unknown company),
    so callers can mutate it or iterate over it while connections are
    registered or unregistered without touching the manager's own state.
    """
    return set(self._connections.get(company_id, ()))
|
Comment on lines
+77
to
+78
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Do not expose the internal mutable connection set. Returning the live set lets callers mutate manager state directly, bypassing accounting and cleanup invariants. Return a snapshot instead.
💡 Suggested fix:
 def get_connections(self, company_id: str) -> Set[object]:
-    return self._connections.get(company_id, set())
+    return set(self._connections.get(company_id, set()))
🤖 Prompt for AI Agents |
||
|
|
||
def broadcast(self, company_id: str, message: str) -> int:
    """Send *message* to every connection of *company_id*.

    Connections without a callable ``send`` attribute are skipped. An
    exception raised for one connection is logged and does not stop
    delivery to the others.

    Returns:
        The number of connections whose ``send`` call returned normally.
    """
    delivered = 0
    for peer in self.get_connections(company_id):
        try:
            send = getattr(peer, "send", None)
            if not callable(send):
                continue
            send(message)
            delivered += 1
        except Exception as e:
            logger.error(f"Broadcast error for company {company_id}: {e}")
    return delivered
|
|
||
async def cleanup_stale(self) -> int:
    """Drop every tracked connection whose ``closed`` attribute is truthy.

    Connections without a ``closed`` attribute are kept. Company entries
    left empty are removed, and the active count is lowered by the number
    of connections dropped (never below zero).

    Returns:
        The number of connections removed in this pass.
    """
    removed = 0
    # Iterate over a copy of the keys: entries may be deleted inside the loop.
    for company_id in list(self._connections):
        company_pool = self._connections[company_id]
        stale = {conn for conn in company_pool if getattr(conn, "closed", False)}
        if stale:
            company_pool.difference_update(stale)
            self._active_count = max(0, self._active_count - len(stale))
            removed += len(stale)
        if not company_pool:
            del self._connections[company_id]
    if removed:
        logger.info(f"Cleanup removed {removed} stale connection(s)")
    return removed
|
|
||
async def start_cleanup_loop(self):
    """Start the background task that runs ``cleanup_stale`` periodically.

    Calling this while a cleanup task is still running is a no-op, so
    repeated calls never leave an untracked extra task behind. Must be
    called with a running event loop (it uses ``asyncio.create_task``).
    """
    if self._cleanup_task is not None and not self._cleanup_task.done():
        return

    async def _loop():
        while True:
            await asyncio.sleep(self.config.cleanup_interval)
            try:
                await self.cleanup_stale()
            except asyncio.CancelledError:
                # Let cancellation through so stop_cleanup_loop can end the task.
                raise
            except Exception:
                # One failing pass must not end periodic cleanup for good.
                logger.exception("Cleanup pass failed")

    self._cleanup_task = asyncio.create_task(_loop())
    logger.info(f"Cleanup loop started (interval={self.config.cleanup_interval}s)")
|
|
||
async def stop_cleanup_loop(self):
    """Cancel the background cleanup task and wait for it to finish.

    Does nothing when no task is recorded. The task reference is cleared
    before awaiting, so a concurrent or repeated call returns immediately.
    """
    task = self._cleanup_task
    if task is None:
        return
    self._cleanup_task = None
    task.cancel()
    try:
        # Awaiting lets the cancellation be processed before this method returns.
        await task
    except asyncio.CancelledError:
        pass
    except Exception:
        # The loop had already died with an error; report it without failing shutdown.
        logger.exception("Cleanup loop terminated with an error")
    logger.info("Cleanup loop stopped")
|
Comment on lines
+110
to
+123
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Make cleanup-loop lifecycle idempotent and cancellation-safe.
💡 Suggested fix@@
import asyncio
+from contextlib import suppress
@@
async def start_cleanup_loop(self):
+ if self._cleanup_task and not self._cleanup_task.done():
+ return
async def _loop():
while True:
await asyncio.sleep(self.config.cleanup_interval)
await self.cleanup_stale()
@@
async def stop_cleanup_loop(self):
- if self._cleanup_task:
- self._cleanup_task.cancel()
- self._cleanup_task = None
- logger.info("Cleanup loop stopped")
+ task = self._cleanup_task
+ if not task:
+ return
+ self._cleanup_task = None
+ task.cancel()
+ with suppress(asyncio.CancelledError):
+ await task
+ logger.info("Cleanup loop stopped")🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
# Process-wide ConnectionManager; stays None until load() is first called.
_instance: Optional[ConnectionManager] = None


def load():
    """Return the shared ConnectionManager, creating it on first use.

    The first call builds the manager from ``PoolConfig.from_env()`` and
    logs the effective settings; later calls return that same instance.
    """
    global _instance
    if _instance is not None:
        return _instance

    settings = PoolConfig.from_env()
    _instance = ConnectionManager(settings)
    logger.info(
        f"ConnectionManager loaded (min={settings.min_size}, "
        f"max={settings.max_size}, cleanup={settings.cleanup_interval}s)"
    )
    return _instance
|
|
||
|
|
||
def get_instance() -> Optional[ConnectionManager]:
    """Return the shared ConnectionManager, or None if load() has not run yet."""
    current = _instance
    return current
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Validate and normalize environment pool settings before using them.
`from_env()` currently trusts raw env values. Non-integer values can crash initialization, and non-positive bounds (especially cleanup interval) can lead to invalid runtime behavior.
💡 Suggested fix
@@
 @classmethod
 def from_env(cls) -> "PoolConfig":
-    return cls(
-        min_size=int(os.getenv("WS_POOL_MIN_SIZE", "10")),
-        max_size=int(os.getenv("WS_POOL_MAX_SIZE", "100")),
-        cleanup_interval=int(os.getenv("WS_POOL_CLEANUP_INTERVAL", "300")),
-    )
+    def _parse_int(name: str, default: int) -> int:
+        raw = os.getenv(name, str(default))
+        try:
+            return int(raw)
+        except ValueError:
+            logger.warning("Invalid %s=%r; using default=%d", name, raw, default)
+            return default
+
+    min_size = max(1, _parse_int("WS_POOL_MIN_SIZE", 10))
+    max_size = max(1, _parse_int("WS_POOL_MAX_SIZE", 100))
+    cleanup_interval = max(1, _parse_int("WS_POOL_CLEANUP_INTERVAL", 300))
+
+    if min_size > max_size:
+        logger.warning(
+            "WS_POOL_MIN_SIZE (%d) > WS_POOL_MAX_SIZE (%d); clamping min_size",
+            min_size, max_size
+        )
+        min_size = max_size
+
+    return cls(min_size=min_size, max_size=max_size, cleanup_interval=cleanup_interval)
Also applies to: 113-113
🤖 Prompt for AI Agents