diff --git a/.dockerignore b/.dockerignore index 00ed5497..635cecca 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,54 +1,7 @@ -# Git -.git -.gitignore +# Exclude everything by default +* -# Python -__pycache__/ -*.py[cod] -*$py.class -*.so -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -*.egg-info/ -.installed.cfg -*.egg - -# Virtual Environment -venv/ -env/ -ENV/ - -# IDE -.idea/ -.vscode/ -*.swp -*.swo - -# Docker -Dockerfile -.dockerignore - -# Documentation -docs/ -*.md -*.rst - -# Tests -tests/ -.pytest_cache/ -.coverage -htmlcov/ - -# Misc -.DS_Store \ No newline at end of file +# Allow only what the Dockerfile needs +!collab_env/ +!scripts/ +!pyproject.toml diff --git a/.gcloudignore b/.gcloudignore new file mode 100644 index 00000000..06b38a7c --- /dev/null +++ b/.gcloudignore @@ -0,0 +1,13 @@ +# Upload only what the Docker build needs +# .gcloudignore controls what gcloud builds submit uploads + +# Start by ignoring everything +* + +# Allow only what Dockerfile.tracking-studio needs +!collab_env/ +!scripts/ +!pyproject.toml +!Dockerfile.tracking-studio +!cloudbuild.yaml +!.dockerignore diff --git a/Dockerfile.tracking-studio b/Dockerfile.tracking-studio new file mode 100644 index 00000000..7916ee7f --- /dev/null +++ b/Dockerfile.tracking-studio @@ -0,0 +1,48 @@ +FROM python:3.10-slim + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + ffmpeg \ + libgl1 \ + libglib2.0-0 \ + libsm6 \ + libxext6 \ + libxrender-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install uv for fast dependency resolution +RUN pip install uv + +WORKDIR /workspace + +# Install CPU-only PyTorch first (prevents ultralytics pulling 900MB CUDA version) +RUN uv pip install --system --no-cache \ + torch torchvision --index-url https://download.pytorch.org/whl/cpu + +# Install Python dependencies (Roboflow models load via .pt download, no inference SDK needed) +COPY pyproject.toml . +RUN uv pip install --system --no-cache \ + nicegui \ + ultralytics \ + supervision \ + google-cloud-storage \ + gcsfs \ + opencv-python-headless \ + pandas \ + numpy \ + loguru + +# Copy application code +COPY collab_env/ collab_env/ +COPY scripts/ scripts/ + +# Create model cache and tmp directories +RUN mkdir -p /workspace/models +RUN mkdir -p /tmp/videos /tmp/outputs /tmp/uploads + +ENV PORT=8080 +ENV PYTHONUNBUFFERED=1 +ENV PYTHONPATH=/workspace +ENV NICEGUI_RELOAD=false + +CMD ["python", "scripts/tracking/run_tracking_studio.py"] diff --git a/README.rst b/README.rst index 7ee0cd2f..bb7b2145 100644 --- a/README.rst +++ b/README.rst @@ -117,5 +117,6 @@ Detailed documentation for specific modules: * `GNN Training `_ - Graph Neural Network training and rollouts * `Simulation `_ - Boids simulation and output format * `Tracking `_ - Animal tracking and thermal video processing +* `Tracking Studio `_ - Interactive web GUI for video object detection and tracking For contributing guidelines, see `CONTRIBUTING.md `_. diff --git a/cloudbuild.yaml b/cloudbuild.yaml new file mode 100644 index 00000000..48cb536b --- /dev/null +++ b/cloudbuild.yaml @@ -0,0 +1,35 @@ +steps: + - name: 'gcr.io/cloud-builders/docker' + args: + - 'build' + - '-t' + - 'gcr.io/$PROJECT_ID/tracking-studio:latest' + - '-f' + - 'Dockerfile.tracking-studio' + - '.' 
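+  # Note: the build context is the repo root, but .gcloudignore (what
+  # `gcloud builds submit` uploads) and .dockerignore (what the Docker build
+  # sees) restrict it to collab_env/, scripts/, and the build files, so
+  # uploads stay small.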
+ + - name: 'gcr.io/cloud-builders/docker' + args: ['push', 'gcr.io/$PROJECT_ID/tracking-studio:latest'] + + - name: 'gcr.io/google.com/cloudsdktool/cloud-sdk' + entrypoint: gcloud + args: + - 'run' + - 'deploy' + - 'tracking-studio' + - '--image=gcr.io/$PROJECT_ID/tracking-studio:latest' + - '--platform=managed' + - '--region=us-central1' + - '--memory=4Gi' + - '--cpu=2' + - '--timeout=900' + - '--concurrency=1' + - '--max-instances=5' + - '--set-secrets=ROBOFLOW_API_KEY=roboflow-api-key:latest' + - '--no-allow-unauthenticated' + +images: + - 'gcr.io/$PROJECT_ID/tracking-studio:latest' + +options: + machineType: 'N1_HIGHCPU_8' diff --git a/collab_env/tracking/visualization.py b/collab_env/tracking/visualization.py index 5e76edb2..7076a63f 100644 --- a/collab_env/tracking/visualization.py +++ b/collab_env/tracking/visualization.py @@ -135,6 +135,7 @@ def overlay_tracks_on_video( # Get frame size sample_frame = cv2.imread(str(frame_paths[0])) + assert sample_frame is not None, f"Failed to read frame: {frame_paths[0]}" h, w = sample_frame.shape[:2] writer = cv2.VideoWriter( str(output_video), @@ -146,6 +147,8 @@ def overlay_tracks_on_video( for frame_path in frame_paths: frame_idx = int(frame_path.stem.split("_")[-1]) frame = cv2.imread(str(frame_path)) + if frame is None: + continue frame_tracks = df[df["frame"] == frame_idx] for _, row in frame_tracks.iterrows(): diff --git a/collab_env/tracking_studio/__init__.py b/collab_env/tracking_studio/__init__.py new file mode 100644 index 00000000..38508125 --- /dev/null +++ b/collab_env/tracking_studio/__init__.py @@ -0,0 +1,12 @@ +""" +NiceGUI-based Video Tracking Studio + +A web application for interactive video tracking with support for: +- GCS bucket video browsing and upload +- YOLO and Roboflow model selection +- Real-time tracking visualization +- ByteTrack with Re-ID support +- CSV output export +""" + +__version__ = "0.1.0" diff --git a/collab_env/tracking_studio/app.py b/collab_env/tracking_studio/app.py new file mode 100644 index 00000000..3a698b00 --- /dev/null +++ b/collab_env/tracking_studio/app.py @@ -0,0 +1,1224 @@ +""" +Main NiceGUI Tracking Studio Application + +Single-page interactive app for video tracking with: +- GCS bucket browsing and video upload +- Model selection (YOLO and Roboflow) +- Real-time tracking visualization +- CSV output download +""" + +from nicegui import ui +import asyncio +from pathlib import Path +from typing import Optional +import uuid +import os +import base64 +from loguru import logger + +import json as _json + +from .gcs_browser import GCSVideoBrowser +from .model_manager import ModelManager +from .video_processor import VideoTracker +from .video_converter import convert_to_h264, needs_conversion + +# Persistent preferences file (last used model/video settings) +_PREFS_FILE = Path.home() / ".tracking_studio.json" + + +def load_preferences() -> dict: + """Load saved preferences from dot file.""" + try: + if _PREFS_FILE.exists(): + return _json.loads(_PREFS_FILE.read_text()) + except Exception as e: + logger.warning(f"Failed to load preferences: {e}") + return {} + + +def save_preferences(prefs: dict): + """Save preferences to dot file (merges with existing).""" + try: + existing = load_preferences() + existing.update(prefs) + _PREFS_FILE.write_text(_json.dumps(existing, indent=2)) + except Exception as e: + logger.warning(f"Failed to save preferences: {e}") + + +# Load ByteTrack parameter definitions +def load_bytetrack_params(): + """Load ByteTrack parameter schema from JSON""" + import json + + 
params_file = Path(__file__).parent / "bytetrack_params.json" + with open(params_file, "r") as f: + return json.load(f) + + +bytetrack_params_schema = load_bytetrack_params() + + +# Initialize services +def get_credentials_path(): + """Get GCS credentials path from environment or default. Returns None for ADC.""" + env_path = os.getenv("GCS_CREDENTIALS") + if env_path: + return env_path + default = "/workspace/config/collab-data-463313-c340ad86b28e.json" + if os.path.exists(default): + return default + return None # GCSClient will use Application Default Credentials + + +gcs_browser: "Optional[GCSVideoBrowser]" +try: + gcs_browser = GCSVideoBrowser(credentials_path=get_credentials_path()) +except Exception as e: + logger.error(f"Failed to initialize GCS browser: {e}") + gcs_browser = None + +model_manager = ModelManager() + + +@ui.page("/") +async def index(): + """Main tracking studio page""" + session_id = str(uuid.uuid4())[:8] + prefs = load_preferences() + + # State variables (stored in page context) + import threading + + state = { + "selected_bucket": None, + "selected_video_path": None, + "selected_model": None, + "processing": False, + "results": None, + "uploaded_model": None, # Uploaded model .pt file + "stop_event": None, # Hard stop + "pause_event": None, # Pause/resume + "skip_frames_event": None, # Skip forward signal + "video_path": None, # Path to current video being processed + "video_loaded": False, # Video is loaded and ready for playback + "model_loaded": False, # Model is loaded and ready for tracking + "loaded_model": None, # Reference to loaded model object + } + + # UI Layout + with ui.column().classes("w-full p-4 gap-3"): + # Header + with ui.row().classes("w-full items-center mb-2"): + ui.label("🎯 Video Tracking Studio").classes("text-2xl font-bold") + ui.space() + ui.label("Real-time object detection and tracking").classes( + "text-sm text-gray-600" + ) + + # Row 1: Video source selection (Bucket | Folder | Video | Upload) + with ui.card().classes("w-full shadow-md p-2"): + with ui.row().classes("w-full gap-2 items-end"): + # GCS selection + if gcs_browser: + bucket_select = ui.select( + label="Bucket", + options=[], + ).style("width: 250px") + try: + buckets = gcs_browser.list_buckets() + bucket_select.options = buckets + if buckets: + bucket_select.value = buckets[0] + except Exception as e: + logger.error(f"Failed to list buckets: {e}") + + folder_select = ui.select( + label="Folder", options=[""], value="", clearable=True + ).style("width: 350px") + + video_select = ui.select(label="Video", options=[]).classes( + "flex-grow" + ) + + async def update_folders(e): + """Update folder list when bucket changes""" + try: + bucket = bucket_select.value + if bucket: + folders = gcs_browser.list_folders(bucket, "") + folder_select.options = [""] + folders + folder_select.value = "" + folder_select.update() + await update_video_list(None) + except Exception as error: + logger.error(f"Failed to list folders: {error}") + + async def update_video_list(e): + """Update video list when bucket or folder changes""" + try: + bucket = bucket_select.value + folder = folder_select.value or "" + if bucket: + videos = gcs_browser.list_videos(bucket, folder) + video_select.options = [v["rel_path"] for v in videos] + video_select.update() + except Exception as error: + logger.error(f"Failed to list videos: {error}") + + def enable_load_video_btn(e=None): + """Enable Load Video button when a GCS video is selected""" + if video_select.value: + load_video_btn.enable() + + 
bucket_select.on("update:model-value", update_folders) + folder_select.on("update:model-value", update_video_list) + video_select.on("update:model-value", enable_load_video_btn) + + # Restore last used bucket from preferences + saved_bucket = prefs.get("video_bucket") + if saved_bucket and saved_bucket in buckets: + bucket_select.value = saved_bucket + + if bucket_select.value: + + async def _restore_gcs_selection(): + await update_folders(None) + saved_folder = prefs.get("video_folder", "") + if saved_folder and saved_folder in folder_select.options: + folder_select.value = saved_folder + await update_video_list(None) + saved_video = prefs.get("video_name") + if saved_video and saved_video in video_select.options: + video_select.value = saved_video + load_video_btn.enable() + + ui.timer(0.1, _restore_gcs_selection, once=True) + + # Upload widget + async def handle_upload(e): + """Handle user video upload and auto-load it""" + try: + upload_path = Path(f"/tmp/uploads/{session_id}") + upload_path.mkdir(parents=True, exist_ok=True) + uploaded_file = upload_path / e.name + uploaded_file.write_bytes(e.content.read()) + ui.notify(f"Uploaded: {e.name}") + await load_video(local_video=uploaded_file) + except Exception as error: + logger.error(f"Upload failed: {error}") + ui.notify(f"Upload failed: {error}", type="negative") + + ui.upload( + on_upload=handle_upload, + auto_upload=True, + ).props( + "accept=video/mp4,video/quicktime,video/x-msvideo dense flat" + ).props("label=Upload").style("width: 120px; height: 40px") + + # Row 2: Model/Params (left) | Controls + Preview (right) + with ui.row().classes("w-full gap-3"): + # LEFT: Model + Parameters (stacked) + with ( + ui.column().classes("gap-3").style("flex: 0 0 280px; min-width: 280px") + ): + # Model card + with ui.card().classes("w-full shadow-md p-3"): + ui.label("Model").classes("text-sm font-semibold mb-2") + + model_source = ui.select( + label="Source", + options=["YOLO", "Roboflow", "Custom"], + value=prefs.get("model_source", "Roboflow"), + ).classes("w-full") + + # YOLO model selection + yolo_container = ui.column().classes("w-full mt-2") + yolo_container.visible = ( + prefs.get("model_source", "Roboflow") == "YOLO" + ) + with yolo_container: + yolo_model_input = ( + ui.input( + label="Model Name", + placeholder="e.g., yolo11n.pt", + value=prefs.get("yolo_model_name", "yolo11n.pt"), + ) + .classes("w-full") + .tooltip("Enter any YOLO model name (will auto-download)") + ) + + # Roboflow model selection + rf_container = ui.column().classes("w-full mt-2 gap-2") + rf_container.visible = ( + prefs.get("model_source", "Roboflow") == "Roboflow" + ) + with rf_container: + # Populate project dropdown from Roboflow workspace + try: + _rf_project_options = model_manager.list_roboflow_projects() + except Exception as _err: + logger.warning(f"Could not list Roboflow projects: {_err}") + _rf_project_options = [] + + _saved_rf_project = prefs.get("rf_project_id", "") + if ( + _saved_rf_project + and _saved_rf_project not in _rf_project_options + ): + _rf_project_options = [ + _saved_rf_project, + *_rf_project_options, + ] + + rf_project_input = ( + ui.select( + label="Project ID", + options=_rf_project_options, + value=_saved_rf_project or None, + ) + .classes("w-full") + .tooltip("Pick a project from your Roboflow workspace") + ) + + # Store raw version data for detail dialog + _rf_versions_raw = {} + + async def list_rf_models(): + """Query Roboflow for available model versions""" + project_id = rf_project_input.value + if not project_id: + 
ui.notify("Please enter project ID", type="warning") + return + try: + rf_version_select.options = {} + rf_version_select.value = None + rf_version_select.disable() + rf_detail_btn.visible = False + _rf_versions_raw.clear() + versions = model_manager.list_roboflow_project_models( + project_id + ) + if versions: + options = {} + _rf_versions_raw.clear() + for v in versions: + parts = [f"v{v['version']}"] + if v["name"]: + parts.append(v["name"]) + parts.append(f"{v['images']} imgs") + options[v["version"]] = " | ".join(parts) + _rf_versions_raw[v["version"]] = v.get( + "raw", {} + ) + rf_version_select.options = options + rf_version_select.value = versions[0]["version"] + rf_version_select.enable() + rf_detail_btn.visible = True + enable_load_model_btn() + ui.notify( + f"Found {len(versions)} versions", + type="positive", + ) + else: + ui.notify("No versions found", type="warning") + except Exception as error: + logger.error(f"Failed to list models: {error}") + ui.notify(f"Error: {error}", type="negative") + + def show_version_detail(): + """Show full JSON for the selected version in a dialog""" + import json + + ver = rf_version_select.value + raw = _rf_versions_raw.get(ver, {}) + if not raw: + ui.notify("No version data available", type="warning") + return + with ( + ui.dialog() as dlg, + ui.card().style("min-width: 500px; max-height: 80vh;"), + ): + ui.label(f"Version {ver} Details").classes( + "text-sm font-semibold" + ) + ui.code(json.dumps(raw, indent=2, default=str)).classes( + "w-full text-xs" + ).style("max-height: 60vh; overflow: auto;") + ui.button("Close", on_click=dlg.close).props( + "size=sm flat" + ) + dlg.open() + + with ui.row().classes("w-full gap-2 items-center no-wrap"): + rf_version_select = ui.select( + label="Version", + options=[], + ).classes("flex-grow") + rf_version_select.disable() + rf_detail_btn = ui.button( + "Details", on_click=show_version_detail + ).props("size=sm flat") + rf_detail_btn.visible = False + + rf_project_input.on( + "update:model-value", lambda _e: list_rf_models() + ) + + # Custom model upload + custom_container = ui.column().classes("w-full mt-2") + custom_container.visible = ( + prefs.get("model_source", "Roboflow") == "Custom" + ) + with custom_container: + + async def handle_model_upload(e): + """Handle model .pt file upload""" + try: + model_upload_path = Path(f"/tmp/models/{session_id}") + model_upload_path.mkdir(parents=True, exist_ok=True) + uploaded_model_file = model_upload_path / e.name + uploaded_model_file.write_bytes(e.content.read()) + state["uploaded_model"] = uploaded_model_file + ui.notify(f"Model uploaded: {e.name}", type="positive") + logger.info(f"Model uploaded to: {uploaded_model_file}") + load_model_btn.enable() + except Exception as error: + logger.error(f"Model upload failed: {error}") + ui.notify( + f"Model upload failed: {error}", type="negative" + ) + + ui.upload( + on_upload=handle_model_upload, + auto_upload=True, + ).props("accept=.pt dense flat").props( + "label=Upload Model (.pt)" + ).classes("w-full") + + # Toggle visibility based on model source + def toggle_model_ui(e=None): + value = model_source.value + yolo_container.visible = value == "YOLO" + rf_container.visible = value == "Roboflow" + custom_container.visible = value == "Custom" + enable_load_model_btn() + + def enable_load_model_btn(e=None): + """Enable Load Model button when model is selected""" + if model_source.value == "YOLO" and yolo_model_input.value: + load_model_btn.enable() + elif ( + model_source.value == "Roboflow" and 
rf_version_select.value + ): + load_model_btn.enable() + elif model_source.value == "Custom" and state.get( + "uploaded_model" + ): + load_model_btn.enable() + + model_source.on("update:model-value", toggle_model_ui) + yolo_model_input.on("update:model-value", enable_load_model_btn) + rf_version_select.on("update:model-value", enable_load_model_btn) + + # Auto-fetch Roboflow versions if saved project exists + if prefs.get("model_source") == "Roboflow" and prefs.get( + "rf_project_id" + ): + + async def _restore_rf_version(): + await list_rf_models() + saved_ver = prefs.get("rf_version") + if saved_ver and saved_ver in ( + rf_version_select.options or {} + ): + rf_version_select.value = saved_ver + enable_load_model_btn() + + ui.timer(0.1, _restore_rf_version, once=True) + elif prefs.get("model_source") == "YOLO": + ui.timer(0.1, enable_load_model_btn, once=True) + + # Parameters card + params_card = ui.card().classes("w-full shadow-md p-3") + with params_card: + ui.label("⚙️ Parameters").classes("text-sm font-semibold mb-2") + + # Detection confidence (not in ByteTrack params) + with ui.column().classes("w-full gap-1"): + conf_label = ui.label("Confidence: 0.50").classes("text-xs") + conf_slider = ( + ui.slider(min=0.1, max=0.9, step=0.05, value=0.5) + .classes("w-full") + .tooltip("Detection confidence threshold") + ) + conf_slider.on( + "update:model-value", + lambda e: conf_label.set_text(f"Confidence: {e.args:.2f}"), + ) + + # Dynamic ByteTrack parameters from JSON + param_widgets = {} # Store references to UI elements + with ui.column().classes("w-full gap-1 mt-2"): + for param_name, param_config in bytetrack_params_schema.items(): + if param_config["type"] == "float": + # Float slider + default_val = param_config["default"] + min_val, max_val = param_config["range"] + + # Create label with tooltip + param_label = ui.label( + f"{param_name.replace('_', ' ').title()}: {default_val:.2f}" + ).classes("text-xs") + param_label.tooltip(param_config["description"]) + + # Create slider + step = 0.05 if max_val <= 1.0 else 0.1 + param_slider = ui.slider( + min=min_val, + max=max_val, + step=step, + value=default_val, + ).classes("w-full") + + # Update label on change + param_slider.on( + "update:model-value", + lambda e, lbl=param_label, name=param_name: ( + lbl.set_text( + f"{name.replace('_', ' ').title()}: {e.args:.2f}" + ) + ), + ) + param_widgets[param_name] = param_slider + + elif param_config["type"] == "int": + # Int slider + default_val = param_config["default"] + min_val = param_config["range"][0] + max_val = ( + param_config["range"][1] + if param_config["range"][1] + else 300 + ) + + param_label = ui.label( + f"{param_name.replace('_', ' ').title()}: {default_val}" + ).classes("text-xs") + param_label.tooltip(param_config["description"]) + + param_slider = ui.slider( + min=min_val, max=max_val, step=1, value=default_val + ).classes("w-full") + + param_slider.on( + "update:model-value", + lambda e, lbl=param_label, name=param_name: ( + lbl.set_text( + f"{name.replace('_', ' ').title()}: {int(e.args)}" + ) + ), + ) + param_widgets[param_name] = param_slider + + elif param_config["type"] == "bool": + # Checkbox + param_checkbox = ui.checkbox( + param_name.replace("_", " ").title(), + value=param_config["default"], + ).classes("text-xs") + param_checkbox.tooltip(param_config["description"]) + param_widgets[param_name] = param_checkbox + + elif param_config["type"] == "string": + # Dropdown for options + param_select = ui.select( + label=param_name.replace("_", " ").title(), + 
options=param_config["options"], + value=param_config["default"], + ).classes("w-full text-xs") + param_select.tooltip(param_config["description"]) + param_widgets[param_name] = param_select + + # Detection-only mode toggle + detection_only_checkbox = ui.checkbox( + "Detection only (no tracking)", + value=False, + ).classes("text-xs mt-2") + detection_only_checkbox.tooltip( + "Run detection without ByteTrack — shows raw detections per frame" + ) + + # Skip frames (for fast-forward, not a ByteTrack param) + with ui.row().classes("w-full items-center gap-2 mt-2"): + skip_frames_label = ui.label("Skip: every frame").classes( + "text-xs" + ) + skip_frames_slider = ui.slider( + min=1, max=30, step=1, value=1 + ).style("width: 100px") + skip_frames_slider.tooltip( + "Process every Nth frame (1 = all frames)" + ) + skip_frames_slider.on( + "update:model-value", + lambda e: skip_frames_label.set_text( + "Skip: every frame" + if int(e.args) == 1 + else f"Skip: every {int(e.args)} frames" + ), + ) + + # GUI refresh rate (display updates, not a ByteTrack param) + with ui.row().classes("w-full items-center gap-2 mt-1"): + display_update_label = ui.label("Display: every frame").classes( + "text-xs" + ) + display_update_slider = ui.slider( + min=1, max=30, step=1, value=1 + ).style("width: 100px") + display_update_slider.tooltip( + "Update display every Nth frame (1 = every frame, higher = skip display frames)" + ) + display_update_slider.on( + "update:model-value", + lambda e: display_update_label.set_text( + "Display: every frame" + if int(e.args) == 1 + else f"Display: every {int(e.args)} frames" + ), + ) + + # RIGHT: Controls + Preview (stacked vertically) + with ui.column().classes("flex-grow gap-3"): + # Controls card + with ui.card().classes("w-full shadow-md p-3"): + # Row 1: Load buttons + with ui.row().classes("w-full items-center gap-2 mb-2"): + load_video_btn = ui.button("Load Video").props( + "color=primary icon=video_file" + ) + load_video_btn.disable() # Enabled when video selected + + load_model_btn = ui.button("Load Model").props( + "color=primary icon=model_training" + ) + load_model_btn.disable() # Enabled when model selected + + ui.separator().props("vertical") + + with ui.column().classes("flex-grow gap-1"): + status_label = ui.label("Select video and model").classes( + "text-xs" + ) + + # Row 2: Playback controls + with ui.row().classes("w-full items-center gap-2"): + start_btn = ui.button("Start Tracking").props( + "color=positive icon=play_arrow" + ) + start_btn.disable() # Enabled when both video and model loaded + + pause_btn = ui.button("Pause").props("color=warning icon=pause") + pause_btn.disable() + + stop_btn = ui.button("Stop").props("color=negative icon=stop") + stop_btn.disable() + + ui.separator().props("vertical") + + status_indicator = ui.label("Ready").classes( + "text-xs flex-grow" + ) + + # Time slider for seeking + with ui.column().classes("w-full gap-1 mt-2"): + time_label = ui.label("Frame: 0 / 0").classes( + "text-xs text-gray-600" + ) + time_slider = ui.slider(min=0, max=100, value=0).classes( + "w-full" + ) + time_slider.disable() + + _preview_pending = [False] + + async def preview_frame_on_drag(e): + """Preview frame during slider drag — reads frame via cv2""" + if _preview_pending[0] or not state.get("video_path"): + return + _preview_pending[0] = True + try: + import cv2 + + target_frame = int(e.args) + video_display.content = "" # Clear SVG overlay + if not state.get("processing"): + reset_tracker_state() + + def read_frame(path, idx): + cap = 
cv2.VideoCapture(str(path)) + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + ret, f = cap.read() + cap.release() + return f if ret else None + + frame = await asyncio.to_thread( + read_frame, state["video_path"], target_frame + ) + if frame is not None: + _, buf = cv2.imencode( + ".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 50] + ) + b64 = base64.b64encode(buf).decode() + video_display.set_source( + f"data:image/jpeg;base64,{b64}" + ) + + total_frames = state.get("total_frames", target_frame) + time_label.text = ( + f"Frame: {target_frame} / {total_frames}" + ) + finally: + _preview_pending[0] = False + + async def seek_to_frame(e): + """Seek to specific frame when slider is released""" + target_frame = int(e.args) + if state.get("processing") and state.get( + "skip_frames_event" + ): + current_frame = state.get("current_frame", 0) + if target_frame != current_frame: + state["skip_frames_event"]["skip_amount"] = ( + target_frame - current_frame + ) + ui.notify( + f"Seeking to frame {target_frame}...", + type="info", + ) + elif state.get("video_path"): + # Not processing: show the frame at release position + def read_frame(path, idx): + import cv2 + + cap = cv2.VideoCapture(str(path)) + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + ret, f = cap.read() + cap.release() + return f if ret else None + + frame = await asyncio.to_thread( + read_frame, state["video_path"], target_frame + ) + if frame is not None: + import cv2 + + _, buf = cv2.imencode( + ".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 70] + ) + b64 = base64.b64encode(buf).decode() + video_display.set_source( + f"data:image/jpeg;base64,{b64}" + ) + + # Live preview during drag + time_slider.on("update:model-value", preview_frame_on_drag) + # Actual seek on release + time_slider.on("change", seek_to_frame) + + # Preview card + with ui.card().classes("w-full shadow-md p-3"): + ui.label("Live Preview").classes("text-sm font-semibold mb-2") + video_container = ( + ui.element("div") + .classes("border-2 border-gray-200 rounded bg-gray-50") + .style("max-width: 100%; resize: horizontal; overflow: hidden;") + ) + with video_container: + video_display = ui.interactive_image("").style("width: 100%;") + + # Debug: actual parameters passed to detector/tracker + debug_params_card = ui.card().classes("w-full shadow-md p-3 hidden") + with debug_params_card: + ui.label("Active Parameters").classes("text-xs font-semibold mb-1") + debug_params_label = ( + ui.label("") + .classes("text-xs font-mono text-gray-600") + .style("white-space: pre-wrap;") + ) + + # Results (initially hidden, separate row) + results_container = ui.card().classes("w-full shadow-md p-3 hidden") + with results_container: + with ui.row().classes("w-full items-center gap-3"): + ui.label("✅ Results").classes("text-sm font-semibold") + stats_label = ui.label().classes("text-sm flex-grow") + download_track_btn = ui.button("Download CSV").props( + "color=primary icon=download size=sm" + ) + + # Event handlers + def reset_tracker_state(): + """Reset YOLO model's internal tracker so track IDs start fresh.""" + model = state.get("loaded_model") + if model and hasattr(model, "predictor") and model.predictor is not None: + # Full predictor reset — Ultralytics will create a fresh one on next call + model.predictor = None + + async def load_video(local_video=None): + """Load and prepare video for viewing/tracking. + + Args: + local_video: Path to a local video file (e.g. from upload). + If None, downloads the video selected in the GCS dropdowns. 
+ """ + from nicegui import context + + try: + status_label.text = "Loading video..." + load_video_btn.disable() + + # Capture client context before threading + client = context.client + + if local_video is None: + # Download from GCS dropdowns + if gcs_browser and bucket_select.value and video_select.value: + status_label.text = "Downloading video..." + bucket = bucket_select.value + video_name = video_select.value + gcs_path = f"{bucket}/{video_name}" + + local_video_dir = Path(f"/tmp/videos/{session_id}") + local_video_dir.mkdir(parents=True, exist_ok=True) + local_video = local_video_dir / Path(video_name).name + + await asyncio.to_thread( + gcs_browser.download_video, gcs_path, str(local_video) + ) + else: + raise ValueError( + "No video selected. Please select a video from the dropdowns." + ) + + # Ensure browser-compatible H.264 MP4 + if await asyncio.to_thread(needs_conversion, local_video): + converted_video = local_video.parent / f"{local_video.stem}_h264.mp4" + # Check if codec is already h264 (just needs container remux) + import subprocess + + try: + probe = subprocess.run( + [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + "stream=codec_name", + "-of", + "default=noprint_wrappers=1:nokey=1", + str(local_video), + ], + capture_output=True, + text=True, + check=True, + ) + is_h264 = probe.stdout.strip() == "h264" + except Exception: + is_h264 = False + + if is_h264: + status_label.text = "Remuxing to MP4..." + await asyncio.to_thread( + convert_to_h264, local_video, converted_video, remux_only=True + ) + else: + status_label.text = "Converting to H.264..." + await asyncio.to_thread( + convert_to_h264, local_video, converted_video + ) + local_video = converted_video + with client: + ui.notify("Video converted to H.264") + + # Restore context for UI updates after threading + with client: + # Store video path in state + state["video_path"] = local_video + state["video_loaded"] = True + + # Read video metadata and first frame + import cv2 + + cap = cv2.VideoCapture(str(local_video)) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) or 30 + w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + ret, first_frame = cap.read() + cap.release() + + state["total_frames"] = total_frames + state["video_fps"] = fps + state["video_width"] = w + state["video_height"] = h + + # Size container to video dimensions + video_container.style( + f"width: {w}px; max-width: 100%; aspect-ratio: {w}/{h};" + f" resize: horizontal; overflow: hidden;" + ) + + # Show first frame + video_display.content = "" + reset_tracker_state() + if ret: + _, buf = cv2.imencode( + ".jpg", first_frame, [cv2.IMWRITE_JPEG_QUALITY, 80] + ) + b64 = base64.b64encode(buf).decode() + video_display.set_source(f"data:image/jpeg;base64,{b64}") + logger.info(f"Displayed first frame ({w}x{h})") + + # Enable time slider for playback + time_slider.enable() + time_slider.set_value(0) + time_slider.props(f"max={total_frames}") + time_label.text = f"Frame: 0 / {total_frames}" + + status_label.text = "Video loaded ✓" + ui.notify("Video loaded successfully", type="positive") + + # Save video selection to preferences + video_prefs = {} + if gcs_browser and bucket_select.value: + video_prefs["video_bucket"] = bucket_select.value + video_prefs["video_folder"] = folder_select.value or "" + video_prefs["video_name"] = video_select.value or "" + save_preferences(video_prefs) + + # Enable Start button if model is also loaded + if 
state["model_loaded"]: + start_btn.enable() + + except Exception as e: + logger.error(f"Failed to load video: {e}", exc_info=True) + with client: + status_label.text = "Error loading video" + ui.notify(f"Error: {str(e)}", type="negative") + finally: + with client: + load_video_btn.enable() + + async def load_model(): + """Load selected model""" + from nicegui import context + + try: + status_label.text = "Loading model..." + load_model_btn.disable() + + # Capture client context before threading + client = context.client + + if model_source.value == "YOLO": + model = await asyncio.to_thread( + model_manager.load_yolo_model, yolo_model_input.value + ) + elif model_source.value == "Roboflow": + # Roboflow mode: download from API + project_id = rf_project_input.value + version = rf_version_select.value + + if not project_id or not version: + raise ValueError( + "Please select a Roboflow model (project ID and version)" + ) + + # Validate project ID format (should be workspace/project) + project_parts = project_id.split("/") + if len(project_parts) != 2: + raise ValueError( + f"Invalid project ID: '{project_id}'\n" + f"Expected format: workspace/project (e.g., 'dima-sdrkv/ratsmerged20260211')" + ) + + # Construct full model ID: workspace/project/version + model_id = f"{project_id}/{version}" + logger.info(f"Loading Roboflow model: {model_id}") + + model = await asyncio.to_thread( + model_manager.load_roboflow_model, model_id + ) + else: # Custom + # Custom mode: use uploaded model file + if state.get("uploaded_model"): + model_path = str(state["uploaded_model"]) + logger.info(f"Loading uploaded model: {model_path}") + model = await asyncio.to_thread( + model_manager.load_roboflow_model, model_path + ) + else: + raise ValueError("Please upload a model .pt file") + + # Restore context for UI updates after threading + with client: + # Store model in state + state["loaded_model"] = model + state["model_loaded"] = True + + # Detect tracker type for display + from ultralytics import YOLO + + tracker_type = ( + "YOLO Native" if isinstance(model, YOLO) else "Supervision" + ) + state["tracker_type"] = tracker_type + + status_label.text = f"Model loaded ✓ ({tracker_type} tracking)" + ui.notify("Model loaded successfully", type="positive") + + # Save model selection to preferences + model_prefs = {"model_source": model_source.value} + if model_source.value == "YOLO": + model_prefs["yolo_model_name"] = yolo_model_input.value + elif model_source.value == "Roboflow": + model_prefs["rf_project_id"] = rf_project_input.value + model_prefs["rf_version"] = rf_version_select.value + save_preferences(model_prefs) + + # Enable Start button if video is also loaded + if state["video_loaded"]: + start_btn.enable() + + except Exception as e: + logger.error(f"Failed to load model: {e}", exc_info=True) + # Restore context for error UI updates + with client: + status_label.text = "Error loading model" + ui.notify(f"Error: {str(e)}", type="negative") + finally: + with client: + load_model_btn.enable() + + def pause_tracking(): + """Pause/resume tracking""" + if state["pause_event"]: + if state["pause_event"].is_set(): + # Currently paused, resume + state["pause_event"].clear() + pause_btn.props("icon=pause") + pause_btn.text = "Pause" + status_indicator.text = "Resuming..." 
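+                # pause_event is shared with the VideoTracker worker, which is
+                # assumed to poll it between frames: set = paused, clear =
+                # running (start_tracking creates it in the clear state).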
+                ui.notify("Resumed", type="info")
+            else:
+                # Currently running, pause
+                state["pause_event"].set()
+                pause_btn.props("icon=play_arrow")
+                pause_btn.text = "Resume"
+                status_indicator.text = "Paused"
+                ui.notify("Paused", type="warning")
+
+    def stop_tracking():
+        """Hard stop - terminates processing"""
+        if state["stop_event"]:
+            state["stop_event"].set()
+            status_indicator.text = "Stopping..."
+            ui.notify("Stopping tracking...", type="negative")
+
+    async def start_tracking():
+        """Start tracking on already-loaded video with already-loaded model"""
+        if not state.get("video_loaded") or not state.get("model_loaded"):
+            ui.notify("Please load video and model first", type="warning")
+            return
+
+        state["processing"] = True
+        reset_tracker_state()
+        state["stop_event"] = threading.Event()  # Hard stop
+        state["pause_event"] = threading.Event()  # Pause (starts clear = not paused)
+        state["skip_frames_event"] = {"skip_amount": 0}  # Skip forward
+        # Resume from current slider position (preserved after stop)
+        start_frame = int(time_slider.value) if time_slider.value else 0
+        state["current_frame"] = start_frame
+        start_btn.disable()
+        pause_btn.text = "Pause"
+        pause_btn.props("icon=pause")
+        pause_btn.enable()
+        stop_btn.enable()
+        params_card.style("opacity: 0.5; pointer-events: none;")
+        results_container.classes(add="hidden")
+
+        try:
+            # Show mode in progress label
+            if detection_only_checkbox.value:
+                status_indicator.text = "Starting detection..."
+            else:
+                tracker_type = state.get("tracker_type", "Unknown")
+                status_indicator.text = f"Starting tracking ({tracker_type})..."
+
+            # Use already-loaded video and model from state
+            local_video = state["video_path"]
+            model = state["loaded_model"]
+
+            # Frame callback for real-time UI updates
+            display_interval = int(display_update_slider.value)
+            _track_colors = [
+                "#00FF00",
+                "#FF0000",
+                "#0080FF",
+                "#FFFF00",
+                "#FF00FF",
+                "#00FFFF",
+                "#FF8000",
+                "#8000FF",
+                "#00FF80",
+                "#FF0080",
+                "#80FF00",
+                "#0040FF",
+            ]
+
+            async def frame_callback(frame, detections, frame_idx, total_frames):
+                """Update UI: JPEG frame + SVG bbox overlay on every callback."""
+                import cv2
+
+                state["current_frame"] = frame_idx
+
+                # Update base JPEG image (rate controlled by Display Update slider)
+                _, buf = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 50])
+                b64 = base64.b64encode(buf).decode()
+                video_display.set_source(f"data:image/jpeg;base64,{b64}")
+
+                # Update SVG overlay with detection bboxes
+                svg_rects = []
+                det_only = detection_only_checkbox.value
+                if len(detections) > 0:
+                    for i, bbox in enumerate(detections.xyxy):
+                        x1, y1, x2, y2 = bbox
+                        bw, bh = x2 - x1, y2 - y1
+                        conf = detections.confidence[i]
+
+                        if det_only:
+                            color = _track_colors[0]
+                            label = f"{conf:.2f}"
+                        else:
+                            tid = (
+                                int(detections.tracker_id[i])
+                                if detections.tracker_id is not None
+                                else 0
+                            )
+                            color = _track_colors[tid % len(_track_colors)]
+                            label = f"#{tid} {conf:.2f}"
+
+                        # Outlined box with its label drawn just above the
+                        # top-left corner
+                        svg_rects.append(
+                            f'<rect x="{x1}" y="{y1}" width="{bw}" height="{bh}" fill="none" stroke="{color}" stroke-width="2" />'
+                            f'<text x="{x1}" y="{y1 - 5}" fill="{color}" font-size="14">{label}</text>'
+                        )
+
+                video_display.content = "\n".join(svg_rects)
+
+                time_slider.set_value(frame_idx)
+                time_label.text = f"Frame: {frame_idx} / {total_frames}"
+
+            # Build tracker config from dynamic parameter widgets
+            tracker_config = {
+                "skip_frames": int(skip_frames_slider.value),  # Fast-forward
+            }
+            # Add ByteTrack parameters from param_widgets
+            for param_name, widget in param_widgets.items():
+                if hasattr(widget, "value"):
+                    tracker_config[param_name] = widget.value
+
+            # Log and display actual parameters
+            active_params = {
+                "start_frame": start_frame,
+ "confidence": conf_slider.value, + "detection_only": detection_only_checkbox.value, + "display_interval": display_interval, + **tracker_config, + } + logger.info(f"Tracking params: {active_params}") + debug_params_label.text = " ".join( + f"{k}={v}" for k, v in active_params.items() + ) + debug_params_card.classes(remove="hidden") + + # Initialize tracker with dynamic parameters + tracker = VideoTracker( + model=model, + tracker_config=tracker_config, + confidence=conf_slider.value, + detection_only=detection_only_checkbox.value, + display_interval=display_interval, + frame_callback=frame_callback, + stop_event=state["stop_event"], + pause_event=state["pause_event"], + skip_frames_event=state["skip_frames_event"], + ) + + output_dir = f"/tmp/outputs/{session_id}" + results = await tracker.process_video_realtime( + str(local_video), output_dir, start_frame=start_frame + ) + + # Show results + status_indicator.text = "Complete!" + + state["results"] = results + + if detection_only_checkbox.value: + stats_label.text = ( + f"Processed {results['stats']['total_frames']} frames | " + f"{results['stats']['total_detections']} detections" + ) + else: + stats_label.text = ( + f"Processed {results['stats']['total_frames']} frames | " + f"{results['stats']['total_detections']} detections | " + f"{results['stats']['unique_tracks']} unique tracks" + ) + + # Setup download button (only tracking CSV) + download_track_btn.on_click(lambda: ui.download(results["tracking_csv"])) + + results_container.classes(remove="hidden") + + except Exception as e: + logger.error(f"Tracking failed: {e}", exc_info=True) + status_indicator.text = f"Error: {str(e)}" + ui.notify(f"Error: {str(e)}", type="negative") + + finally: + state["processing"] = False + state["stop_event"] = None + state["pause_event"] = None + state["skip_frames_event"] = None + start_btn.enable() + pause_btn.text = "Pause" + pause_btn.props("icon=pause") + pause_btn.disable() + stop_btn.disable() + params_card.style(remove="opacity: 0.5; pointer-events: none;") + + # Wire up buttons to event handlers (after functions are defined) + load_video_btn.on_click(lambda: load_video()) + load_model_btn.on_click(load_model) + start_btn.on_click(start_tracking) + pause_btn.on_click(lambda: pause_tracking()) + stop_btn.on_click(lambda: stop_tracking()) + + +# Run the NiceGUI app directly (no function wrapper for reload compatibility) +ui.run( + host="0.0.0.0", + port=int(os.getenv("PORT", 8080)), + reload=os.getenv("NICEGUI_RELOAD", "true").lower() == "true", + title="Tracking Studio", +) diff --git a/collab_env/tracking_studio/bytetrack_params.json b/collab_env/tracking_studio/bytetrack_params.json new file mode 100644 index 00000000..5700217a --- /dev/null +++ b/collab_env/tracking_studio/bytetrack_params.json @@ -0,0 +1,37 @@ +{ + "track_high_thresh": { + "type": "float", + "default": 0.25, + "range": [0.0, 1.0], + "description": "Detection confidence threshold that separates high-confidence and low-confidence detections. Detections above this score enter Stage 1 (primary IoU matching). Detections between track_low_thresh and this value enter Stage 2 (secondary matching to rescue lost tracks). Raise to restrict primary matching to only the most confident detections; lower to feed more detections into Stage 1." + }, + "track_low_thresh": { + "type": "float", + "default": 0.1, + "range": [0.0, 1.0], + "description": "Absolute minimum detection confidence. Detections below this score are discarded entirely and never participate in any matching stage. 
Lower to recover very marginal detections at the cost of more noise; raise to filter out weak false positives."
+  },
+  "new_track_thresh": {
+    "type": "float",
+    "default": 0.25,
+    "range": [0.0, 1.0],
+    "description": "Minimum detection confidence required to initialize a new track. Unmatched detections from Stage 1 must exceed this score to spawn a new track ID. Higher values prevent spurious tracks from false positives; lower values allow tracks to start from weaker detections."
+  },
+  "track_buffer": {
+    "type": "int",
+    "default": 30,
+    "range": [1, null],
+    "description": "Number of frames a lost track is kept alive before permanent deletion. Internally scaled by frame rate: max_time_lost = int(fps / 30.0 * track_buffer). Higher values let tracks survive longer occlusions but increase ID switch risk when the object reappears far from its last position. Lower values remove lost tracks faster."
+  },
+  "match_thresh": {
+    "type": "float",
+    "default": 0.8,
+    "range": [0.0, 1.0],
+    "description": "IoU-distance gating threshold for Stage 1 (high-confidence) association. Passed as cost_limit to the linear assignment solver. Since cost = 1 - IoU, a threshold of 0.8 accepts matches with IoU >= 0.2. Higher values are more lenient (easier to match); lower values require stronger spatial overlap. Tune with detector quality."
+  },
+  "fuse_score": {
+    "type": "bool",
+    "default": true,
+    "description": "When enabled, IoU similarity is multiplied by the detection confidence score before matching: fused_cost = 1 - (iou_similarity * detection_score). This biases matching toward high-confidence detections. Disable if your detector's confidence scores are poorly calibrated or inconsistent."
+  }
+}
diff --git a/collab_env/tracking_studio/gcs_browser.py b/collab_env/tracking_studio/gcs_browser.py
new file mode 100644
index 00000000..76a9f720
--- /dev/null
+++ b/collab_env/tracking_studio/gcs_browser.py
@@ -0,0 +1,171 @@
+"""
+GCS Video Browser Component
+
+Provides interface for browsing and downloading videos from Google Cloud Storage.
+"""
+
+from typing import List, Dict, Optional
+from loguru import logger
+
+from collab_env.data.gcs_utils import GCSClient
+
+
+class GCSVideoBrowser:
+    """Browser for selecting and downloading videos from GCS buckets"""
+
+    def __init__(self, credentials_path: Optional[str]):
+        """
+        Initialize GCS browser.
+
+        Args:
+            credentials_path: Path to GCS service account credentials JSON, or None to use Application Default Credentials
+        """
+        self.gcs = GCSClient(credentials_path=credentials_path)
+        logger.info("GCS Video Browser initialized")
+
+    def list_buckets(self) -> List[str]:
+        """
+        List all available GCS buckets.
+
+        Returns:
+            List of bucket names
+        """
+        try:
+            buckets = self.gcs.list_buckets()
+            logger.info(f"Found {len(buckets)} buckets")
+            return buckets
+        except Exception as e:
+            logger.error(f"Failed to list buckets: {e}")
+            return []
+
+    def list_folders(self, bucket: str, prefix: str = "") -> List[str]:
+        """
+        List immediate subfolders in a bucket path.
+
+        Note: GCS doesn't have real folders - they're just prefixes in object names.
+        This function extracts unique first-level directory prefixes.
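+
+        Example (hypothetical layout): a bucket holding objects
+        "videos/2024/a.mp4" and "videos/2023/b.mp4" yields
+        list_folders(bucket, "videos/") == ["2023", "2024"].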
+ + Args: + bucket: GCS bucket name + prefix: Path prefix within bucket (should end with / if not empty) + + Returns: + List of folder names (relative to prefix) + """ + try: + # Ensure prefix ends with / if not empty + if prefix and not prefix.endswith("/"): + prefix = prefix + "/" + + # Get all objects recursively to find folder-like structures + pattern = f"{bucket}/{prefix}**" if prefix else f"{bucket}/**" + all_paths = self.gcs.glob(pattern) + + # Extract unique immediate subdirectories + unique_folders = set() + for path in all_paths: + # Remove bucket prefix + rel_path = path.replace(f"{bucket}/", "") + + # Remove the current prefix if any + if prefix: + if not rel_path.startswith(prefix): + continue + rel_path = rel_path[len(prefix) :] + + # Get first directory component after prefix + if "/" in rel_path: + folder = rel_path.split("/")[0] + if folder: # Skip empty strings + unique_folders.add(folder) + + folder_list = sorted(list(unique_folders)) + logger.info( + f"Found {len(folder_list)} folder prefixes in {bucket}/{prefix}" + ) + return folder_list + + except Exception as e: + logger.error(f"Failed to list folders in {bucket}/{prefix}: {e}") + return [] + + def list_videos(self, bucket: str, prefix: str = "") -> List[Dict[str, str]]: + """ + List video files (.mp4, .mov, .avi) in a bucket path. + + Args: + bucket: GCS bucket name + prefix: Path prefix within bucket + + Returns: + List of dicts with video metadata: {name, path, rel_path} + """ + try: + # Build pattern for video files - ensure prefix ends with / if not empty + if prefix and not prefix.endswith("/"): + prefix = prefix + "/" + + # Search for multiple video formats + video_extensions = ["*.mp4", "*.mov", "*.avi", "*.MP4", "*.MOV", "*.AVI"] + all_files = [] + + for ext in video_extensions: + pattern = ( + f"{bucket}/{prefix}**/{ext}" if prefix else f"{bucket}/**/{ext}" + ) + files = self.gcs.glob(pattern) + all_files.extend(files) + + videos = [] + seen_paths = set() # Avoid duplicates from case-insensitive extensions + + for file_path in all_files: + if file_path in seen_paths: + continue + seen_paths.add(file_path) + + # Extract filename + filename = file_path.split("/")[-1] + + # Get relative path from bucket + rel_path = file_path.replace(f"{bucket}/", "") + + videos.append( + { + "name": filename, + "path": file_path, + "rel_path": rel_path, + } + ) + + logger.info(f"Found {len(videos)} videos in {bucket}/{prefix}") + return sorted(videos, key=lambda x: x["name"]) + + except Exception as e: + logger.error(f"Failed to list videos in {bucket}/{prefix}: {e}") + return [] + + def download_video(self, gcs_path: str, local_path: str) -> str: + """ + Download video from GCS to local path. 
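+
+        Example (hypothetical bucket and paths):
+            browser.download_video("my-bucket/videos/test.mp4", "/tmp/test.mp4")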
+ + Args: + gcs_path: Full GCS path (e.g., "bucket/path/video.mp4" or "gs://bucket/path/video.mp4") + local_path: Local destination path + + Returns: + Local path to downloaded video + """ + try: + # Remove gs:// prefix if present + if gcs_path.startswith("gs://"): + gcs_path = gcs_path[5:] + + logger.info(f"Downloading {gcs_path} to {local_path}") + self.gcs.download_file(gcs_path, local_path, overwrite=True) + logger.info(f"Successfully downloaded video to {local_path}") + return local_path + + except Exception as e: + logger.error(f"Failed to download video from {gcs_path}: {e}") + raise diff --git a/collab_env/tracking_studio/model_manager.py b/collab_env/tracking_studio/model_manager.py new file mode 100644 index 00000000..d3b1383c --- /dev/null +++ b/collab_env/tracking_studio/model_manager.py @@ -0,0 +1,500 @@ +""" +Model Manager Component + +Handles loading and managing detection models (YOLO and Roboflow). +""" + +import os +from pathlib import Path +from typing import List, Optional +from loguru import logger + +from ultralytics import YOLO + + +class ModelManager: + """Manager for detection models (YOLO and Roboflow)""" + + def __init__(self, roboflow_api_key: Optional[str] = None): + """ + Initialize model manager. + + Args: + roboflow_api_key: Roboflow API key (or read from env) + """ + self.roboflow_api_key = roboflow_api_key or os.getenv("ROBOFLOW_API_KEY") + self.local_models_dir = Path("/workspace/models") + + # Check if running locally (models in ~/.cache/ultralytics) + if not self.local_models_dir.exists(): + # Use default Ultralytics cache directory + self.local_models_dir = Path.home() / ".cache" / "ultralytics" + + logger.info(f"Model directory: {self.local_models_dir}") + + def list_local_yolo_models(self) -> List[str]: + """ + Return available YOLO models (YOLO11 and YOLO26 variants). + + Returns: + List of model filenames + """ + # Auto-downloadable models (Ultralytics will download them) + auto_downloadable = [ + "yolo11n.pt", + "yolo11s.pt", + "yolo11m.pt", + ] + + # Models that must exist locally (not auto-downloadable) + local_only = [ + "yolo26n-fast.pt", + "yolo26s-fast.pt", + "yolo26m-fast.pt", + ] + + available = [] + + # Add auto-downloadable models (always available) + available.extend(auto_downloadable) + + # Add local-only models only if they exist + for model in local_only: + model_path = self.local_models_dir / model + if model_path.exists(): + available.append(model) + logger.debug(f"Found local YOLO26 model: {model}") + + logger.info(f"Available YOLO models: {available}") + return available + + def load_yolo_model(self, model_name: str) -> YOLO: + """ + Load YOLO model - will download automatically if available. + + Args: + model_name: Model filename (e.g., "yolo11n.pt", "yolo26n-fast.pt") + + Returns: + Loaded YOLO model + """ + try: + logger.info(f"Loading YOLO model: {model_name}") + # Pass directly to YOLO - it will handle local files or auto-download + model = YOLO(model_name) + logger.info(f"Successfully loaded YOLO model: {model_name}") + return model + + except Exception as e: + logger.error(f"Failed to load YOLO model {model_name}: {e}") + raise ValueError( + f"Failed to load model '{model_name}'.\n\n" + f"Possible solutions:\n" + f"- Check the model name is correct\n" + f"- Download manually and place in {self.local_models_dir}\n" + f"- Use the 'Custom' upload option to upload your .pt file" + ) from e + + def _validate_roboflow_model_id(self, model_id: str) -> str: + """ + Validate and format Roboflow model ID. 
+ + Accepts: + - project/version (e.g., "ratsmerged20260211/1") + - workspace/project/version (e.g., "myworkspace/ratsmerged20260211/1") + + Returns properly formatted model ID. + """ + parts = model_id.split("/") + + if len(parts) == 2: + # project/version format + logger.info(f"Model ID format: project/version ({model_id})") + return model_id + elif len(parts) == 3: + # workspace/project/version format + logger.info(f"Model ID format: workspace/project/version ({model_id})") + return model_id + else: + raise ValueError( + f"Invalid model ID format: {model_id}\n" + f"Expected: 'project/version' or 'workspace/project/version'" + ) + + def load_roboflow_model(self, model_id: str): + """ + Load Roboflow model using Inference SDK or local file path. + + Args: + model_id: Model ID in format "project/version", "workspace/project/version", + or a local file path to a .pt file + + Returns: + Loaded Roboflow model or YOLO model from local file + """ + # Check if model_id is a local file path + if ( + model_id.startswith("/") + or model_id.startswith("~") + or model_id.endswith(".pt") + ): + logger.info(f"Loading Roboflow model from local file: {model_id}") + model_path = Path(model_id).expanduser() + + if not model_path.exists(): + raise FileNotFoundError(f"Model file not found: {model_path}") + + logger.info(f"Loading YOLO model from: {model_path}") + model = YOLO(str(model_path)) + logger.info( + f"Successfully loaded Roboflow model from local file: {model_id}" + ) + return model + + if not self.roboflow_api_key: + raise ValueError( + "ROBOFLOW_API_KEY not set. Please provide API key in environment or constructor." + ) + + # Validate model ID format + model_id = self._validate_roboflow_model_id(model_id) + + # Try downloading .pt file first (for YOLO native tracking) + # This provides better performance and supports all ByteTrack parameters + try: + logger.info( + f"Downloading Roboflow model weights for YOLO native tracking: {model_id}" + ) + model = self._load_roboflow_with_pipeline(model_id) + logger.info( + f"Successfully loaded Roboflow model with native tracking: {model_id}" + ) + return model + + except Exception as download_error: + # Fallback to get_model() (inference API) if download fails + logger.warning(f"Download failed: {download_error}") + logger.info( + "Attempting fallback: loading with inference API (Supervision tracking)" + ) + + try: + from inference import get_model + + # Extract project/version from workspace/project/version if needed + parts = model_id.split("/") + if len(parts) == 3: + # workspace/project/version -> project/version + project_version = f"{parts[1]}/{parts[2]}" + logger.info( + f"Trying to load Roboflow model with get_model(): {project_version}" + ) + model = get_model( + model_id=project_version, api_key=self.roboflow_api_key + ) + else: + # Already project/version format + logger.info( + f"Trying to load Roboflow model with get_model(): {model_id}" + ) + model = get_model(model_id=model_id, api_key=self.roboflow_api_key) + + logger.info( + f"Successfully loaded Roboflow model via inference API: {model_id}" + ) + return model + + except ImportError: + logger.error( + "inference library not installed. Install with: pip install inference" + ) + raise + except Exception as inference_error: + # Both methods failed + logger.error(f"Inference API also failed: {inference_error}") + error_msg = ( + f"Failed to load Roboflow model '{model_id}'.\n\n" + f"Tried:\n" + f"1. Downloading model weights (.pt file): {str(download_error)}\n" + f"2. 
Loading via inference API: {str(inference_error)}\n\n"
+                    f"Possible solutions:\n"
+                    f"- Verify model ID format: workspace/project/version (e.g., 'dima-sdrkv/ratsmerged20260211/1')\n"
+                    f"- Check model exists at https://app.roboflow.com/\n"
+                    f"- Ensure ROBOFLOW_API_KEY has access to this model\n"
+                    f"- Try uploading the .pt file directly using 'Custom' option"
+                )
+                raise ValueError(error_msg) from inference_error
+
+    def _load_roboflow_with_pipeline(self, model_id: str):
+        """
+        Download Roboflow model weights via the /ptFile endpoint (primary path).
+
+        This downloads the model weights once via the API, then runs inference
+        locally. Much faster than HTTP inference for every frame. The
+        get_model() inference-API path in load_roboflow_model is the fallback.
+        """
+        import requests
+
+        logger.info(
+            f"Downloading Roboflow model weights for local inference: {model_id}"
+        )
+
+        # Parse model ID to get workspace/project/version
+        parts = model_id.split("/")
+        if len(parts) == 2:
+            # project/version format - need workspace
+            raise ValueError(
+                f"Model ID '{model_id}' missing workspace.\n"
+                f"For model download, use format: workspace/project/version"
+            )
+        elif len(parts) == 3:
+            # workspace/project/version format
+            workspace, project, version = parts
+        else:
+            raise ValueError(f"Invalid model ID format: {model_id}")
+
+        # Create cache directory for downloaded models
+        cache_dir = self.local_models_dir / "roboflow_cache"
+        cache_dir.mkdir(parents=True, exist_ok=True)
+
+        # Check if model already downloaded
+        model_cache_name = f"{workspace}_{project}_v{version}.pt"
+        cached_model_path = cache_dir / model_cache_name
+
+        if cached_model_path.exists():
+            logger.info(f"Using cached Roboflow model: {cached_model_path}")
+            return YOLO(str(cached_model_path))
+
+        # Download model weights from Roboflow using /ptFile endpoint
+        logger.info("Fetching model weights URL from Roboflow API...")
+
+        try:
+            # Call /ptFile endpoint to get signed download URL
+            ptfile_url = (
+                f"https://api.roboflow.com/{workspace}/{project}/{version}/ptFile"
+            )
+            logger.info(f"Requesting weights URL from: {ptfile_url}")
+
+            response = requests.get(
+                ptfile_url, params={"api_key": self.roboflow_api_key}, timeout=10
+            )
+            response.raise_for_status()
+
+            # Parse response to get weightsUrl
+            data = response.json()
+            if "weightsUrl" not in data:
+                raise ValueError(f"No weightsUrl in response: {data}")
+
+            weights_url = data["weightsUrl"]
+            logger.info("Got weights URL, downloading...")
+
+            # Download the .pt file from signed URL
+            response = requests.get(weights_url, stream=True, timeout=120)
+            response.raise_for_status()
+
+            # Save to cache
+            with open(cached_model_path, "wb") as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    f.write(chunk)
+
+            logger.info(
+                f"Downloaded model weights: {cached_model_path} ({cached_model_path.stat().st_size} bytes)"
+            )
+
+            # Load with Ultralytics YOLO
+            model = YOLO(str(cached_model_path))
+            logger.info(
+                f"Successfully loaded Roboflow model for local inference: {model_id}"
+            )
+            return model
+
+        except requests.exceptions.HTTPError as e:
+            error_msg = (
+                f"Failed to download Roboflow model weights for '{model_id}'.\n\n"
+                f"HTTP Error {e.response.status_code}: {e.response.text[:200]}\n\n"
+                f"Possible solutions:\n"
+                f"1. Verify model ID format is workspace/project/version\n"
+                f"2. Check ROBOFLOW_API_KEY has access to this model\n"
+                f"3. Ensure model exists at https://app.roboflow.com/\n"
+                f"4. 
Upload model weights manually via 'Custom' option" + ) + logger.error(error_msg) + raise ValueError(error_msg) from e + except Exception as e: + error_msg = ( + f"Failed to download Roboflow model weights for '{model_id}'.\n\n" + f"Error: {str(e)}\n\n" + f"Try uploading model weights manually via 'Custom' option." + ) + logger.error(error_msg) + raise ValueError(error_msg) from e + + def list_roboflow_projects(self) -> List[str]: + """ + Query Roboflow API for all projects in the workspace tied to the API key. + + Returns: + List of project IDs in "workspace/project" format, sorted alphabetically. + """ + import requests + + if not self.roboflow_api_key: + raise ValueError("ROBOFLOW_API_KEY not set") + + try: + # Root endpoint with API key returns workspace info (may include + # workspace name and/or a nested workspace object with projects). + root = requests.get( + "https://api.roboflow.com/", + params={"api_key": self.roboflow_api_key}, + timeout=10, + ) + root.raise_for_status() + root_data = root.json() + logger.debug(f"Roboflow root response keys: {list(root_data.keys())}") + + # Collect candidate workspace names from various possible shapes + workspace_names: List[str] = [] + ws_field = root_data.get("workspace") + if isinstance(ws_field, str): + workspace_names.append(ws_field) + elif isinstance(ws_field, dict): + name = ws_field.get("url") or ws_field.get("name") + if name: + workspace_names.append(name) + for w in root_data.get("workspaces", []) or []: + if isinstance(w, str): + workspace_names.append(w) + elif isinstance(w, dict): + name = w.get("url") or w.get("name") + if name: + workspace_names.append(name) + + if not workspace_names: + raise ValueError( + f"Could not resolve any workspace from API key. " + f"Root response: {root_data}" + ) + + project_ids: List[str] = [] + for workspace in workspace_names: + ws = requests.get( + f"https://api.roboflow.com/{workspace}", + params={"api_key": self.roboflow_api_key}, + timeout=10, + ) + ws.raise_for_status() + data = ws.json() + projects = ( + data.get("workspace", {}).get("projects") + or data.get("projects") + or [] + ) + logger.info( + f"Roboflow workspace '{workspace}': {len(projects)} projects" + ) + for p in projects: + if isinstance(p, str): + pid = p + else: + pid = p.get("id") or p.get("url") or p.get("name") or "" + if not pid: + continue + if "/" not in pid: + pid = f"{workspace}/{pid}" + project_ids.append(pid) + + project_ids = sorted(set(project_ids)) + logger.info( + f"Found {len(project_ids)} total Roboflow projects across " + f"{len(workspace_names)} workspace(s)" + ) + return project_ids + except requests.exceptions.HTTPError as e: + error_msg = ( + f"Failed to list Roboflow projects: HTTP {e.response.status_code}" + ) + logger.error(error_msg) + raise ValueError(error_msg) from e + except Exception as e: + error_msg = f"Failed to list Roboflow projects: {str(e)}" + logger.error(error_msg) + raise ValueError(error_msg) from e + + def list_roboflow_project_models(self, project_id: str) -> List[dict]: + """ + Query Roboflow API for available model versions in a project. 
+ + Args: + project_id: Project ID in format "workspace/project" (e.g., "dima-sdrkv/ratsmerged20260211") + + Returns: + List of dicts with keys: version, name, images, map + """ + import requests + + if not self.roboflow_api_key: + raise ValueError("ROBOFLOW_API_KEY not set") + + parts = project_id.split("/") + if len(parts) != 2: + raise ValueError("Project ID must be in format: workspace/project") + + workspace, project = parts + + try: + url = f"https://api.roboflow.com/{workspace}/{project}" + logger.info(f"Querying Roboflow project models: {url}") + + response = requests.get( + url, params={"api_key": self.roboflow_api_key}, timeout=10 + ) + response.raise_for_status() + + data = response.json() + + versions = [] + if "versions" in data: + for vd in data["versions"]: + version_num = vd.get("id", "") + if isinstance(version_num, str) and "/" in version_num: + version_num = version_num.split("/")[-1] + if not version_num: + version_num = vd.get("version") + if not version_num: + continue + + map_val = vd.get("model", {}).get("map", "") + if map_val and str(map_val) != "NaN": + map_str = f"{float(map_val):.1f}%" + else: + map_str = "" + + versions.append( + { + "version": str(version_num), + "name": vd.get("name", ""), + "images": vd.get("images", 0), + "map": map_str, + "raw": vd, + } + ) + + versions.sort( + key=lambda x: int(x["version"]) if x["version"].isdigit() else 0, + reverse=True, + ) + logger.info( + f"Found {len(versions)} versions: {[v['version'] for v in versions]}" + ) + return versions + + except requests.exceptions.HTTPError as e: + error_msg = ( + f"Failed to query Roboflow project: HTTP {e.response.status_code}" + ) + logger.error(error_msg) + raise ValueError(error_msg) from e + except Exception as e: + error_msg = f"Failed to query Roboflow project: {str(e)}" + logger.error(error_msg) + raise ValueError(error_msg) from e diff --git a/collab_env/tracking_studio/video_converter.py b/collab_env/tracking_studio/video_converter.py new file mode 100644 index 00000000..ed1acfb1 --- /dev/null +++ b/collab_env/tracking_studio/video_converter.py @@ -0,0 +1,121 @@ +""" +Video Format Converter Component + +Ensures videos are in H.264 format for browser compatibility. +""" + +import subprocess +from pathlib import Path +from loguru import logger + + +def needs_conversion(video_path: Path) -> bool: + """ + Check if video needs H.264 conversion. + + Args: + video_path: Path to video file + + Returns: + True if conversion needed, False otherwise + """ + try: + cmd = [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + "stream=codec_name", + "-of", + "default=noprint_wrappers=1:nokey=1", + str(video_path), + ] + + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + codec = result.stdout.strip() + + logger.info(f"Video codec: {codec}") + if codec != "h264": + return True + # H.264 in non-MP4 container (e.g. .mov) may not play in all browsers + ext = Path(video_path).suffix.lower() + if ext not in (".mp4", ".m4v"): + logger.info( + f"H.264 in {ext} container — will remux to .mp4 for browser compatibility" + ) + return True + return False + + except subprocess.CalledProcessError as e: + logger.error(f"Failed to check video codec: {e}") + # Assume conversion needed if check fails + return True + except FileNotFoundError: + logger.error("ffprobe not found. 
Please install ffmpeg.") + raise + + +def convert_to_h264( + input_path: Path, output_path: Path, remux_only: bool = False +) -> Path: + """ + Convert video to H.264 format using ffmpeg. + + Args: + input_path: Original video file + output_path: Output path for converted video + remux_only: If True, copy streams without re-encoding (fast container change) + + Returns: + Path to converted video + """ + try: + if remux_only: + logger.info(f"Remuxing {input_path} to MP4 container (no re-encoding)") + cmd = [ + "ffmpeg", + "-i", + str(input_path), + "-c", + "copy", # Copy all streams without re-encoding + "-movflags", + "+faststart", + "-y", + str(output_path), + ] + else: + logger.info(f"Converting {input_path} to H.264 format") + cmd = [ + "ffmpeg", + "-i", + str(input_path), + "-c:v", + "libx264", + "-preset", + "fast", + "-crf", + "23", + "-c:a", + "aac", + "-b:a", + "128k", + "-movflags", + "+faststart", + "-y", + str(output_path), + ] + + subprocess.run(cmd, check=True, capture_output=True) + + logger.info(f"Successfully converted video to {output_path}") + return output_path + + except subprocess.CalledProcessError as e: + logger.error(f"Failed to convert video: {e}") + logger.error(f"stderr: {e.stderr.decode() if e.stderr else 'N/A'}") + raise + except FileNotFoundError: + logger.error("ffmpeg not found. Please install ffmpeg.") + raise diff --git a/collab_env/tracking_studio/video_processor.py b/collab_env/tracking_studio/video_processor.py new file mode 100644 index 00000000..3df6eef0 --- /dev/null +++ b/collab_env/tracking_studio/video_processor.py @@ -0,0 +1,428 @@ +""" +Video Processor Component + +Core tracking pipeline with ByteTrack. +""" + +import asyncio +import threading +import cv2 +import supervision as sv +from ultralytics import YOLO +import pandas as pd +from pathlib import Path +from concurrent.futures import Future +from typing import Callable, Coroutine, Dict, Optional, Union, Any +import tempfile +import yaml +from loguru import logger + + +class VideoTracker: + """Video tracking processor with detection and tracking""" + + def __init__( + self, + model: Union[YOLO, Any], # YOLO or Roboflow model + tracker_config: Dict, # ByteTrack parameters + confidence: float = 0.5, + detection_only: bool = False, + display_interval: int = 10, + frame_callback: Optional[Callable[..., Coroutine[Any, Any, None]]] = None, + stop_event: Optional[threading.Event] = None, + pause_event: Optional[threading.Event] = None, + skip_frames_event: Optional[Dict] = None, + ): + """ + Initialize video tracker. 
+ + Args: + model: Detection model (YOLO or Roboflow) + tracker_config: Tracker configuration dict + confidence: Detection confidence threshold + detection_only: If True, run detection without tracking (no track IDs) + display_interval: Update display every Nth frame (1 = every frame) + frame_callback: Async callback for frame updates (frame, frame_idx, total_frames) + stop_event: Threading event to signal hard stop + pause_event: Threading event to signal pause/resume + skip_frames_event: Dict with skip_amount for forward seeking + """ + self.model = model + self.confidence = confidence + self.detection_only = detection_only + self.display_interval = max(1, display_interval) + self.frame_callback = frame_callback + self.stop_event = stop_event or threading.Event() + self.pause_event = pause_event or threading.Event() + self.skip_frames_event = skip_frames_event or {"skip_amount": 0} + self._pending_update: Optional[Future[None]] = None + + # Store tracker config for use with model.track() + self.tracker_config = tracker_config + + # Check if model supports native tracking + self.use_native_tracking = isinstance(model, YOLO) + + if detection_only: + logger.info("Detection-only mode (no tracking)") + self.tracker = None + self.tracker_yaml_path = None + elif not self.use_native_tracking: + # For Roboflow inference models (fallback), initialize supervision tracker + logger.info( + "Using supervision ByteTrack (Roboflow inference model fallback)" + ) + self.tracker = sv.ByteTrack( + track_activation_threshold=tracker_config.get( + "track_high_thresh", 0.25 + ), + lost_track_buffer=tracker_config.get("track_buffer", 30), + minimum_matching_threshold=tracker_config.get("match_thresh", 0.8), + minimum_consecutive_frames=1, + frame_rate=30, + ) + self.tracker_yaml_path = None + else: + logger.info("Using Ultralytics native ByteTrack (supports all parameters)") + self.tracker = None + # Create temporary ByteTrack YAML config from parameters + self.tracker_yaml_path = self._create_bytetrack_config(tracker_config) + + # Fast-forward: Skip frames for faster preview + self.skip_frames = tracker_config.get( + "skip_frames", 1 + ) # 1 = process every frame + + logger.info( + f"VideoTracker initialized (confidence: {self.confidence}, native_tracking: {self.use_native_tracking})" + ) + + def _create_bytetrack_config(self, config: Dict) -> str: + """ + Create a temporary ByteTrack YAML config file from parameters. 
+ + Args: + config: Tracker configuration dict + + Returns: + Path to temporary YAML config file + """ + # Map our parameter names to Ultralytics ByteTrack YAML format + bytetrack_yaml = { + "tracker_type": "bytetrack", + "track_high_thresh": config.get("track_high_thresh", 0.25), + "track_low_thresh": config.get("track_low_thresh", 0.1), + "new_track_thresh": config.get("new_track_thresh", 0.25), + "track_buffer": config.get("track_buffer", 30), + "match_thresh": config.get("match_thresh", 0.8), + "fuse_score": config.get("fuse_score", True), + } + + # Create temporary YAML file + temp_file = tempfile.NamedTemporaryFile( + mode="w", suffix=".yaml", delete=False, prefix="bytetrack_" + ) + + with temp_file as f: + yaml.dump(bytetrack_yaml, f, default_flow_style=False) + + logger.info(f"Created ByteTrack config: {temp_file.name}") + logger.debug(f"Config values: {bytetrack_yaml}") + + return temp_file.name + + def _process_video_sync( + self, video_path: str, output_dir: str, event_loop, start_frame: int = 0 + ) -> Dict[str, Any]: + """ + Synchronous video processing function (runs in background thread). + + Args: + video_path: Path to input video + output_dir: Directory for output CSV + event_loop: Main asyncio event loop for scheduling UI updates + start_frame: Frame index to start processing from (0-based) + + Returns: + Dict with tracking_csv path and stats + """ + logger.info(f"Processing video in background thread: {video_path}") + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise ValueError(f"Failed to open video: {video_path}") + + fps = cap.get(cv2.CAP_PROP_FPS) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + logger.info(f"Video info: {total_frames} frames, {fps} fps, {width}x{height}") + + detections_list = [] + tracking_list = [] + + frame_idx = start_frame + if start_frame > 0: + cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + logger.info(f"Starting from frame {start_frame}") + while frame_idx < total_frames: + # Check if stop was requested (hard stop) + if self.stop_event.is_set(): + logger.info(f"Stop requested at frame {frame_idx}, stopping processing") + break + + # Check if pause was requested + while self.pause_event.is_set(): + import time + + time.sleep(0.1) # Wait while paused + if self.stop_event.is_set(): + break + + # Check if seek was requested (forward or backward) + if self.skip_frames_event["skip_amount"] != 0: + skip_to = frame_idx + self.skip_frames_event["skip_amount"] + # Clamp to valid range + skip_to = max(0, min(skip_to, total_frames - 1)) + logger.info(f"Seeking from frame {frame_idx} to {skip_to}") + cap.set(cv2.CAP_PROP_POS_FRAMES, skip_to) + self.skip_frames_event["skip_amount"] = 0 # Reset + frame_idx = skip_to + continue + + ret, frame = cap.read() + if not ret: + logger.warning(f"Failed to read frame {frame_idx}, stopping") + break + + # Fast-forward: Skip frames if requested + if self.skip_frames > 1 and frame_idx % self.skip_frames != 0: + frame_idx += 1 + continue + + # 1. 
Run detection (and optionally tracking) + try: + if self.detection_only: + # Detection only - no tracking + if self.use_native_tracking: + results = self.model( + source=frame, + conf=self.confidence, + verbose=False, + )[0] + detections = sv.Detections.from_ultralytics(results) + else: + results = self.model.infer(frame, confidence=self.confidence)[0] + detections = sv.Detections.from_inference(results) + tracked_detections = detections + + elif self.use_native_tracking: + # Use Ultralytics native tracking (supports all ByteTrack parameters) + results = self.model.track( + source=frame, + conf=self.confidence, + persist=True, # Maintain track IDs across frames + tracker=self.tracker_yaml_path, # Custom ByteTrack config + verbose=False, + )[0] + + # Convert to supervision Detections (with track IDs) + tracked_detections = sv.Detections.from_ultralytics(results) + + # Also get detections without tracking for stats + detections = tracked_detections + else: + # Roboflow inference model (fallback to supervision ByteTrack) + logger.debug(f"Running Roboflow inference on frame {frame_idx}...") + results = self.model.infer(frame, confidence=self.confidence)[0] + detections = sv.Detections.from_inference(results) + logger.debug(f"Frame {frame_idx}: {len(detections)} detections") + + # Update tracker (adds track IDs via supervision ByteTrack) + assert self.tracker is not None + tracked_detections = self.tracker.update_with_detections(detections) + + except Exception as e: + logger.error( + f"Detection/tracking failed on frame {frame_idx}: {e}", + exc_info=True, + ) + detections = sv.Detections.empty() + tracked_detections = sv.Detections.empty() + + # 2. Save detections + if detections.confidence is None or detections.class_id is None: + frame_idx += 1 + continue + for i, (bbox, conf, class_id) in enumerate( + zip(detections.xyxy, detections.confidence, detections.class_id) + ): + detections_list.append( + { + "frame": frame_idx, + "x1": bbox[0], + "y1": bbox[1], + "x2": bbox[2], + "y2": bbox[3], + "confidence": conf, + "class": class_id, + } + ) + + # 3. Save tracking data (with track IDs if tracking is enabled) + if ( + not self.detection_only + and tracked_detections.tracker_id is not None + and len(tracked_detections) > 0 + ): + confidences: Any = ( + tracked_detections.confidence + if tracked_detections.confidence is not None + else [] + ) + class_ids: Any = ( + tracked_detections.class_id + if tracked_detections.class_id is not None + else [] + ) + for bbox, track_id, conf, class_id in zip( + tracked_detections.xyxy, + tracked_detections.tracker_id, + confidences, + class_ids, + ): + tracking_list.append( + { + "track_id": int(track_id), + "frame": frame_idx, + "x1": int(bbox[0]), + "y1": int(bbox[1]), + "x2": int(bbox[2]), + "y2": int(bbox[3]), + "confidence": float(conf), + "class": int(class_id), + } + ) + + # 4. 
Send frame + detections to UI for display + is_last = frame_idx >= total_frames - 1 + should_display = (frame_idx % self.display_interval == 0) or is_last + + if should_display and self.frame_callback is not None and event_loop: + # Skip if previous UI update is still in-flight (prevents queue buildup) + if self._pending_update is None or self._pending_update.done(): + self._pending_update = asyncio.run_coroutine_threadsafe( + self.frame_callback( + frame, tracked_detections, frame_idx, total_frames + ), + event_loop, + ) + + # Increment frame counter for next iteration + frame_idx += 1 + + cap.release() + + # Cleanup temporary tracker config file if created + if self.tracker_yaml_path: + try: + import os + + os.unlink(self.tracker_yaml_path) + logger.debug( + f"Cleaned up temporary tracker config: {self.tracker_yaml_path}" + ) + except Exception as e: + logger.warning(f"Failed to cleanup tracker config: {e}") + + unique_tracks = ( + len(set(t["track_id"] for t in tracking_list)) if tracking_list else 0 + ) + logger.info( + f"Processing complete: {total_frames} frames, " + f"{len(detections_list)} detections, " + f"{unique_tracks} unique tracks" + ) + + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + if self.detection_only: + # Save detections CSV (no track IDs) + output_csv = output_path / "detections.csv" + if len(detections_list) > 0: + det_df = pd.DataFrame(detections_list) + det_df = det_df[ + ["frame", "x1", "y1", "x2", "y2", "confidence", "class"] + ] + det_df.to_csv(output_csv, index=False) + else: + pd.DataFrame( + columns=["frame", "x1", "y1", "x2", "y2", "confidence", "class"] + ).to_csv(output_csv, index=False) + logger.info(f"Saved detections CSV to {output_csv}") + else: + # Save tracking CSV (with track IDs) + output_csv = output_path / "tracking.csv" + if len(tracking_list) > 0: + tracking_df = pd.DataFrame(tracking_list) + tracking_df = tracking_df[ + ["track_id", "frame", "x1", "y1", "x2", "y2", "confidence", "class"] + ] + tracking_df.to_csv(output_csv, index=False) + else: + pd.DataFrame( + columns=[ + "track_id", + "frame", + "x1", + "y1", + "x2", + "y2", + "confidence", + "class", + ] + ).to_csv(output_csv, index=False) + logger.info(f"Saved tracking CSV to {output_csv}") + + return { + "tracking_csv": str(output_csv), + "stats": { + "total_frames": total_frames, + "total_detections": len(detections_list), + "unique_tracks": unique_tracks, + "fps": fps, + }, + } + + async def process_video_realtime( + self, video_path: str, output_dir: str, start_frame: int = 0 + ) -> Dict[str, Any]: + """ + Process video frame-by-frame with real-time UI updates. + + This runs the heavy processing in a background thread to prevent + blocking the asyncio event loop and WebSocket connections. + + Args: + video_path: Path to input video + output_dir: Directory for output CSV + start_frame: Frame index to start processing from (0-based) + + Returns: + Dict with tracking_csv path and stats + """ + # Get current event loop for scheduling UI updates from background thread + loop = asyncio.get_running_loop() + + # Run processing in background thread + logger.info( + f"Starting video processing in background thread (frame {start_frame})..." 
+ ) + result = await asyncio.to_thread( + self._process_video_sync, video_path, output_dir, loop, start_frame + ) + + logger.info("Video processing complete") + return result diff --git a/docs/tracking/tracking_web_gui.md b/docs/tracking/tracking_web_gui.md new file mode 100644 index 00000000..bede6029 --- /dev/null +++ b/docs/tracking/tracking_web_gui.md @@ -0,0 +1,359 @@ +# Tracking Studio Web GUI + +Interactive web-based application for real-time video object detection and tracking using YOLO and ByteTrack. + +## Overview + +The Tracking Studio provides a user-friendly interface for: +- Loading videos from Google Cloud Storage or local uploads +- Selecting detection models (YOLO, Roboflow, or custom .pt files) +- Tuning ByteTrack parameters in real-time +- Visualizing tracking results with live preview +- Exporting tracking data to CSV + +## Quick Start + +### Prerequisites + +1. **Python 3.10** with the project installed: + + ```bash + pip install -e . + ``` + +2. **FFmpeg** (for video format conversion): + + ```bash + # macOS + brew install ffmpeg + # Ubuntu/Debian + sudo apt install ffmpeg + ``` + +3. **GCS credentials** (for browsing videos in Google Cloud Storage): + - Place your service account JSON at `config-local/collab-data-463313-c340ad86b28e.json` + - Or set the env var: `export GCS_CREDENTIALS=/path/to/credentials.json` + - If neither is set, GCS browsing is disabled (video upload still works) + +4. **Roboflow API key** (only needed for Roboflow models): + + ```bash + export ROBOFLOW_API_KEY=your_api_key_here + ``` + + Get your key from [Roboflow settings](https://app.roboflow.com/settings/api) + +### Running the Application + +```bash +# From the repository root +python scripts/tracking/run_tracking_studio.py +``` + +The application will start on `http://localhost:8080` + +## Workflow + +### 1. Load Video + +**From Google Cloud Storage:** +1. Select **Bucket** (e.g., `collab-data-463313`) +2. Select **Folder** (optional subfolder) +3. Select **Video** from the list +4. Click **Load Video** + +**From Local Upload:** +1. Click **Upload** button +2. Select video file (.mp4, .mov, .avi) +3. Click **Load Video** + +The first frame will display in the preview area. + +### 2. Load Model + +Choose one of three model sources: + +#### YOLO Models +- Enter any YOLO model name (e.g., `yolo11n.pt`, `yolo26n.pt`) +- Models will auto-download if available from Ultralytics +- Click **Load Model** + +#### Roboflow Models +- Enter **Project ID** in format: `workspace/project` +- Click **List Models** to fetch available versions +- Select a **Version** from dropdown +- Click **Load Model** + +**Supported types:** +- Object detection models (standard) +- Instance segmentation models (extracts bounding boxes only) + +#### Custom Models +- Upload your own `.pt` file +- Click **Load Model** + +### 3. 
Configure Parameters
+
+**Detection Confidence**
+- Threshold for detection scores (0.1 - 0.9)
+- Higher = fewer false positives, but may miss detections
+- Lower = more detections, but may include noise
+
+**ByteTrack Parameters** (see [ByteTrack Parameters](#bytetrack-parameters) below)
+
+**Skip Frames**
+- Process every Nth frame (1 = all frames, 30 = every 30th frame)
+- Use higher values for faster preview on long videos
+- Skipped frames are not run through the detector, so they produce no rows in the output CSV
+
+**Display Update**
+- Update preview every Nth frame (1-30, default: 10)
+- Lower values = smoother preview (more network traffic)
+- Higher values = less frequent updates (lower bandwidth)
+- For a 30 fps video: every 10 frames = ~3 updates/second, every 5 frames = ~6 updates/second
+
+### 4. Start Tracking
+
+1. Click **Start Tracking**
+2. Watch the live preview with bounding boxes and track IDs
+3. Use **Pause** to temporarily halt processing
+4. Use **Stop** to terminate early
+5. Drag the time slider to jump to specific frames
+
+### 5. Export Results
+
+When complete, click **Download CSV** to save tracking data.
+
+## ByteTrack Parameters
+
+ByteTrack uses a two-stage association algorithm to track objects across frames:
+
+### `track_high_thresh` (default: 0.25)
+Detection confidence threshold separating high-confidence and low-confidence detections.
+- Detections **above** this → Stage 1 (primary IoU matching)
+- Detections **between** `track_low_thresh` and this → Stage 2 (secondary matching)
+- **Raise** to restrict primary matching to the most confident detections
+- **Lower** to feed more detections into Stage 1
+
+### `track_low_thresh` (default: 0.1)
+Absolute minimum detection confidence.
+- Detections **below** this are discarded entirely
+- **Lower** to recover marginal detections (at the cost of more noise)
+- **Raise** to filter out weak false positives
+
+### `new_track_thresh` (default: 0.25)
+Minimum confidence required to initialize a new track.
+- Unmatched detections from Stage 1 must exceed this to spawn new track IDs
+- **Higher** prevents spurious tracks from false positives
+- **Lower** allows tracks to start from weaker detections
+
+### `track_buffer` (default: 30)
+Number of frames a lost track is kept alive before deletion.
+- Internally scaled by frame rate: `max_time_lost = int(fps / 30.0 * track_buffer)`
+- **Higher** values let tracks survive longer occlusions (higher risk of ID switches)
+- **Lower** values remove lost tracks faster
+
+### `match_thresh` (default: 0.8)
+IoU-distance gating threshold for Stage 1 association.
+- Cost = 1 - IoU, so a threshold of 0.8 accepts matches with IoU ≥ 0.2
+- **Higher** values are more lenient (easier to match)
+- **Lower** values require stronger spatial overlap
+- Tune based on detector quality
+
+### `fuse_score` (default: true)
+Multiply IoU similarity by detection confidence before matching.
+- Formula: `fused_cost = 1 - (iou_similarity * detection_score)`
+- **Enable** to bias matching toward high-confidence detections
+- **Disable** if the detector's confidence scores are poorly calibrated
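+
+The studio writes these six values into a temporary Ultralytics tracker YAML
+before each run. The sketch below shows the same mechanism used directly,
+outside the GUI; `input.mp4` and the stock `yolo11n.pt` weights are
+placeholders:
+
+```python
+import tempfile
+
+import yaml
+from ultralytics import YOLO
+
+params = {
+    "tracker_type": "bytetrack",
+    "track_high_thresh": 0.25,
+    "track_low_thresh": 0.1,
+    "new_track_thresh": 0.25,
+    "track_buffer": 30,
+    "match_thresh": 0.8,
+    "fuse_score": True,
+}
+
+# Ultralytics reads tracker settings from a YAML file, so write one out
+with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
+    yaml.dump(params, f)
+
+model = YOLO("yolo11n.pt")
+for result in model.track("input.mp4", tracker=f.name, stream=True, verbose=False):
+    if result.boxes is not None and result.boxes.id is not None:
+        print(result.boxes.id.int().tolist())  # track IDs present in this frame
+```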
+
+## Output Format
+
+### Tracking CSV
+
+File: `tracking.csv`
+
+| Column | Type | Description |
+|--------|------|-------------|
+| `track_id` | int | Unique object track identifier |
+| `frame` | int | Frame number (0-indexed) |
+| `x1` | int | Bounding box top-left X |
+| `y1` | int | Bounding box top-left Y |
+| `x2` | int | Bounding box bottom-right X |
+| `y2` | int | Bounding box bottom-right Y |
+| `confidence` | float | Detection confidence score |
+| `class` | int | Object class ID |
+
+Example:
+```csv
+track_id,frame,x1,y1,x2,y2,confidence,class
+1,0,245,150,345,280,0.87,0
+1,1,247,152,346,281,0.85,0
+2,1,450,200,550,320,0.92,0
+```
+
+## Model Support
+
+### YOLO Models (Ultralytics)
+- ✅ YOLO11 series (`yolo11n.pt`, `yolo11s.pt`, `yolo11m.pt`, etc.)
+- ✅ YOLO26 series (`yolo26n.pt`, `yolo26s.pt`, etc.)
+- ✅ Custom trained YOLO models (.pt files)
+- Uses **native Ultralytics tracking** (supports all 6 ByteTrack parameters)
+
+### Roboflow Models
+- ✅ Object detection models
+- ✅ Instance segmentation models (bounding boxes only, masks ignored)
+- Uses **weights download + local inference** (YOLO-compatible weights)
+- Falls back to **supervision ByteTrack** if native tracking is unavailable
+
+### Custom Models
+- ✅ Upload any YOLO-compatible `.pt` file
+- Must be loadable by the Ultralytics YOLO framework
+
+## Architecture
+
+### Components
+
+**Frontend: NiceGUI**
+- Reactive web interface
+- Real-time frame updates via WebSocket
+- Slider-based parameter tuning
+
+**Backend: FastAPI (via NiceGUI)**
+- Async video processing with `asyncio.to_thread()`
+- Background thread handles heavy CV operations
+- Event loop scheduling for UI updates
+
+**Video Processing: [video_processor.py](../../collab_env/tracking_studio/video_processor.py)**
+- OpenCV video capture
+- Frame-by-frame detection + tracking
+- Supervision library for annotation
+- Temporary YAML config for ByteTrack parameters
+
+**Model Management: [model_manager.py](../../collab_env/tracking_studio/model_manager.py)**
+- YOLO model loading (Ultralytics)
+- Roboflow model loading (weights download, with inference SDK fallback)
+- Model caching for faster reloads
+
+## Video Format Support
+
+### Supported Formats
+- MP4 (H.264 codec) - **recommended**
+- MOV (QuickTime)
+- AVI (uncompressed)
+
+### Automatic Conversion
+Videos not already in browser-playable H.264 MP4 are converted on load:
+- Non-H.264 sources (e.g. mjpeg, raw) are re-encoded with libx264 (CRF 23, original resolution and frame rate)
+- H.264 streams in other containers (e.g. `.mov`) are remuxed to MP4 without re-encoding
+- Uses FFmpeg via `video_converter.py`
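+
+The conversion helpers can also be used on their own, for example to
+pre-convert a clip before loading it (the path here is a placeholder):
+
+```python
+from pathlib import Path
+
+from collab_env.tracking_studio.video_converter import convert_to_h264, needs_conversion
+
+src = Path("clip.avi")  # hypothetical input video
+if needs_conversion(src):
+    convert_to_h264(src, src.with_suffix(".mp4"))
+```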
+
+## Playback Controls
+
+### Real-time Controls
+- **Start Tracking**: Begin processing the video
+- **Pause**: Temporarily halt (can resume)
+- **Stop**: Hard stop (terminates processing)
+
+### Seeking
+- Drag the **time slider** during processing to jump to a specific frame
+- Shows a raw frame preview while dragging (no tracking)
+- Releasing the slider seeks the tracker forward or backward
+
+### Preview Updates
+- Updates every 10 frames by default for performance (~3x per second at 30 fps)
+- Shows annotated frames with bounding boxes and track IDs
+
+## Performance Tips
+
+### For Long Videos
+1. Use **Skip Frames** to process every Nth frame for a faster preview
+2. Increase **track_buffer** to maintain tracks across skipped frames
+3. Lower **detection confidence** if objects are being missed
+
+### For Crowded Scenes
+1. Raise **track_high_thresh** to focus on confident detections
+2. Lower **match_thresh** to require tighter spatial overlap
+3. Increase **new_track_thresh** to reduce spurious tracks
+
+### For Fast Motion
+1. Raise **match_thresh** to accept looser spatial matches
+2. Increase **track_buffer** to keep tracks alive longer
+3. Process all frames (Skip = 1) for smoother tracking
+
+## Troubleshooting
+
+### "No detections found"
+- Lower the **detection confidence** slider
+- Check video quality and lighting
+- Try a different model (e.g., `yolo11m.pt` instead of `yolo11n.pt`)
+
+### "Too many false positives"
+- Raise **detection confidence**
+- Increase **new_track_thresh**
+- Raise **track_high_thresh**
+
+### "Track IDs jumping/switching"
+- Lower **track_high_thresh** to feed more detections into Stage 1
+- Raise **match_thresh** for more lenient matching
+- Increase **track_buffer** to keep lost tracks alive longer
+- Enable **fuse_score** if disabled
+
+### "Video conversion failed"
+- Check the FFmpeg installation: `ffmpeg -version`
+- Ensure the video file is not corrupted
+- Try converting manually: `ffmpeg -i input.mp4 -c:v libx264 output.mp4`
+
+### "Model loading failed"
+- YOLO: Check the model name spelling and internet connection
+- Roboflow: Verify `ROBOFLOW_API_KEY` is set and has access to the model
+- Custom: Ensure the `.pt` file is in a YOLO-compatible format
+
+## Batch Processing
+
+For offline batch processing without the GUI, use the [full tracking pipeline notebook](full_pipeline.ipynb) or the direct API:
+
+```python
+import asyncio
+
+from collab_env.tracking_studio.model_manager import ModelManager
+from collab_env.tracking_studio.video_processor import VideoTracker
+
+# Load model
+manager = ModelManager()
+model = manager.load_yolo_model("yolo11n.pt")
+
+# Configure tracker
+tracker_config = {
+    "track_high_thresh": 0.25,
+    "track_low_thresh": 0.1,
+    "new_track_thresh": 0.25,
+    "track_buffer": 30,
+    "match_thresh": 0.8,
+    "fuse_score": True,
+}
+
+tracker = VideoTracker(model=model, tracker_config=tracker_config, confidence=0.5)
+
+# Process video (process_video_realtime is a coroutine)
+results = asyncio.run(tracker.process_video_realtime("input.mp4", "/tmp/output"))
+print(f"Saved to: {results['tracking_csv']}")
+```
+
+## Cloud Deployment
+
+The Tracking Studio can be deployed to Google Cloud Run using the provided `Dockerfile.tracking-studio` and `cloudbuild.yaml`:
+
+```bash
+gcloud builds submit --config=cloudbuild.yaml
+```
+
+This builds a CPU-only Docker image and deploys it to Cloud Run with 4GB RAM and 2 vCPUs.
+
+**Not recommended for interactive use.** The real-time frame preview relies on WebSocket streaming between the browser and server. Cloud Run's request-based scaling, cold starts, and network latency make the interactive experience significantly worse than running locally. For best results, run the studio on a local machine or a persistent VM with a GPU.
+
+Cloud Run deployment is better suited for batch processing or short demo sessions where latency is acceptable.
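+
+Either way, the exported CSV loads directly into pandas for downstream
+analysis. A minimal sketch, assuming a completed run wrote
+`/tmp/output/tracking.csv`:
+
+```python
+import pandas as pd
+
+df = pd.read_csv("/tmp/output/tracking.csv")
+
+# Track length in frames, per track ID
+lengths = df.groupby("track_id")["frame"].agg(["min", "max", "count"])
+print(lengths.sort_values("count", ascending=False).head())
+
+# Bounding-box centers, e.g. for trajectory plotting
+df["cx"] = (df["x1"] + df["x2"]) / 2
+df["cy"] = (df["y1"] + df["y2"]) / 2
+```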
+ +## References + +- [ByteTrack Paper](https://arxiv.org/abs/2110.06864) +- [Ultralytics YOLO](https://docs.ultralytics.com/) +- [Supervision Library](https://supervision.roboflow.com/) +- [Roboflow Inference](https://inference.roboflow.com/) diff --git a/pyproject.toml b/pyproject.toml index f2b507d1..04713936 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,10 @@ dependencies = [ "rich", "pyarrow", "starbars", - "importlib-metadata" + "importlib-metadata", + "nicegui>=1.4.0", + "supervision>=0.18.0", + "inference>=0.28.0" # Upgraded for yolo26n-seg support ] [tool.setuptools] # NEW @@ -68,6 +71,7 @@ packages = [ "collab_env.sim.util", "collab_env.tracking", "collab_env.tracking.model", + "collab_env.tracking_studio", "collab_env.utils" ] @@ -93,7 +97,7 @@ dev = [ "pytest-benchmark", "mypy==1.18.2", "mypy-extensions==1.1.0", - "ruff", + "ruff==0.15.2", "nbval", "nbqa==1.9.1", "types-requests", @@ -104,6 +108,9 @@ dev = [ requires = ["setuptools<70.0"] build-backend = "setuptools.build_meta" +[tool.ruff] +target-version = "py39" + [tool.mypy] ignore_missing_imports = true warn_unused_ignores = true diff --git a/scripts/test_notebooks.sh b/scripts/test_notebooks.sh index dcca0946..4965fac0 100755 --- a/scripts/test_notebooks.sh +++ b/scripts/test_notebooks.sh @@ -11,6 +11,7 @@ EXCLUDED_NOTEBOOKS=( "docs/alignment/align.ipynb" "docs/alignment/reprojection.ipynb" "docs/tracking/full_pipeline.ipynb" + "docs/gnn/gnn3D/sample.ipynb" ) # Notebooks requiring GCS credentials - excluded when SKIP_GCS_TESTS is set diff --git a/scripts/tracking/bytetrack_video_inference.py b/scripts/tracking/bytetrack_video_inference.py new file mode 100644 index 00000000..f0bf1e3f --- /dev/null +++ b/scripts/tracking/bytetrack_video_inference.py @@ -0,0 +1,170 @@ +"""Real-time ByteTracker inference on video with live visualization""" + +import cv2 +import numpy as np +import supervision as sv +from ultralytics import YOLO +from pathlib import Path +import argparse + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, required=True, help="Path to YOLO model") + parser.add_argument("path_to_video", type=str, help="Path to video file") + parser.add_argument("--confidence", type=float, default=0.5, help="Confidence threshold") + parser.add_argument("--track_activation", type=float, default=0.2, help="Track activation threshold") + parser.add_argument("--lost_buffer", type=int, default=90, help="Lost track buffer frames") + parser.add_argument("--match_threshold", type=float, default=0.8, help="Minimum matching threshold") + parser.add_argument("--min_frames", type=int, default=5, help="Minimum consecutive frames") + + args = parser.parse_args() + + # Load model + model = YOLO(args.model_path) + + # Initialize tracker + tracker = sv.ByteTrack( + track_activation_threshold=args.track_activation, + lost_track_buffer=args.lost_buffer, + minimum_matching_threshold=args.match_threshold, + minimum_consecutive_frames=args.min_frames + ) + + # Initialize annotators + box_annotator = sv.BoxAnnotator(thickness=1) + mask_annotator = sv.MaskAnnotator(opacity=0.4) + label_annotator = sv.LabelAnnotator( + text_scale=0.3, + text_thickness=1, + text_padding=3, + text_position=sv.Position.TOP_LEFT, + color=sv.Color.BLACK, + text_color=sv.Color.WHITE, + border_radius=2, + smart_position=True + ) + + # Open video + cap = cv2.VideoCapture(args.path_to_video) + + if not cap.isOpened(): + print(f"Error: Could not open video {args.path_to_video}") + return + + # Get video properties + 
fps = cap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+    print(f"Video: {width}x{height} @ {fps:.2f} fps, {total_frames} frames")
+    print(f"Confidence: {args.confidence}")
+    print(f"ByteTracker params: activation={args.track_activation}, buffer={args.lost_buffer}, match={args.match_threshold}, min_frames={args.min_frames}")
+    print("\nPress 'q' to quit, 'p' to pause/unpause, SPACE to step frame when paused")
+
+    frame_idx = 0
+    paused = False
+    step = False
+
+    # Pad inference size up to the nearest multiple of 32 (YOLO stride requirement)
+    target_height = ((height + 31) // 32) * 32
+    target_width = ((width + 31) // 32) * 32
+
+    while True:
+        if not paused:
+            ret, frame = cap.read()
+            if not ret:
+                print("\nEnd of video")
+                break
+
+            # Run YOLO detection (imgsz expects (height, width) order)
+            results = model.predict(
+                source=frame,
+                conf=args.confidence,
+                verbose=False,
+                imgsz=(target_height, target_width)
+            )
+
+            # Process detections
+            if results and results[0].boxes:
+                boxes = results[0].boxes
+                masks = results[0].masks
+
+                # Resize masks to original frame size
+                if masks is not None:
+                    mask_array = masks.data.cpu().numpy()
+                    resized_masks = np.zeros((mask_array.shape[0], height, width))
+                    for i in range(mask_array.shape[0]):
+                        resized_masks[i] = cv2.resize(
+                            mask_array[i],
+                            (width, height),
+                            interpolation=cv2.INTER_LINEAR
+                        )
+                    resized_masks = resized_masks > 0.5
+
+                    # Create detections with masks
+                    detections = sv.Detections(
+                        xyxy=boxes.xyxy.cpu().numpy(),
+                        mask=resized_masks,
+                        confidence=boxes.conf.cpu().numpy(),
+                        class_id=boxes.cls.cpu().numpy().astype(np.int32),
+                    )
+                else:
+                    # No masks, just boxes
+                    detections = sv.Detections(
+                        xyxy=boxes.xyxy.cpu().numpy(),
+                        confidence=boxes.conf.cpu().numpy(),
+                        class_id=boxes.cls.cpu().numpy().astype(np.int32),
+                    )
+
+                # Update tracker
+                detections = tracker.update_with_detections(detections)
+
+                # Create labels
+                labels = [
+                    f"#{int(tid)} ({conf:.2f})"
+                    for tid, conf in zip(detections.tracker_id, detections.confidence)
+                ]
+
+                # Annotate frame
+                annotated_frame = frame.copy()
+                if detections.mask is not None:
+                    annotated_frame = mask_annotator.annotate(annotated_frame, detections=detections)
+                else:
+                    annotated_frame = box_annotator.annotate(annotated_frame, detections=detections)
+                annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=labels)
+
+                # Add info overlay
+                info_text = f"Frame: {frame_idx}/{total_frames} | Detections: {len(detections)}"
+                cv2.putText(annotated_frame, info_text, (10, 20),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
+            else:
+                annotated_frame = frame.copy()
+                info_text = f"Frame: {frame_idx}/{total_frames} | Detections: 0"
+                cv2.putText(annotated_frame, info_text, (10, 20),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
+
+            frame_idx += 1
+
+            if step:
+                # A single stepped frame was processed; pause again
+                paused = True
+                step = False
+
+        # Display frame
+        cv2.imshow('ByteTracker Inference', annotated_frame)
+
+        # Handle key presses
+        key = cv2.waitKey(1 if not paused else 0) & 0xFF
+
+        if key == ord('q'):
+            print("\nQuitting...")
+            break
+        elif key == ord('p'):
+            paused = not paused
+            print(f"\n{'Paused' if paused else 'Resumed'}")
+        elif key == ord(' ') and paused:
+            # Step exactly one frame forward, then stay paused
+            paused = False
+            step = True
+
+    cap.release()
+    cv2.destroyAllWindows()
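+
+# Example invocation (paths are hypothetical):
+#   python scripts/tracking/bytetrack_video_inference.py \
+#       --model_path runs/detect/train/weights/best.pt \
+#       --confidence 0.3 thermal_video.mp4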
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/tracking/configs/botsort_thermal_rats.yaml b/scripts/tracking/configs/botsort_thermal_rats.yaml
new file mode 100644
index 00000000..ec33ec3f
--- /dev/null
+++ b/scripts/tracking/configs/botsort_thermal_rats.yaml
@@ -0,0 +1,21 @@
+# Custom BoT-SORT config optimized for thermal rat tracking
+# BoT-SORT handles camera motion better than ByteTrack
+
+tracker_type: botsort
+
+# Detection thresholds
+track_high_thresh: 0.3    # High confidence threshold
+track_low_thresh: 0.1     # Low confidence threshold
+new_track_thresh: 0.2     # Threshold for creating new tracks
+
+# Track management
+track_buffer: 120         # Frames to keep lost tracks alive
+match_thresh: 0.7         # IoU threshold for matching
+
+# BoT-SORT specific
+cmc_method: sparseOptFlow # Camera motion compensation method
+with_reid: False          # ReID features
+
+# Kalman filter settings (for motion prediction)
+std_weight_position: 0.05
+std_weight_velocity: 0.00625
diff --git a/scripts/tracking/configs/bytetrack_thermal_rats.yaml b/scripts/tracking/configs/bytetrack_thermal_rats.yaml
new file mode 100644
index 00000000..7c144a15
--- /dev/null
+++ b/scripts/tracking/configs/bytetrack_thermal_rats.yaml
@@ -0,0 +1,16 @@
+# Custom ByteTrack config optimized for thermal rat tracking
+# Lower thresholds to reduce dropped detections
+
+tracker_type: bytetrack
+
+# Detection thresholds
+track_high_thresh: 0.3  # High confidence threshold (lowered from default 0.5)
+track_low_thresh: 0.1   # Low confidence floor for second-stage association (kept at default 0.1)
+new_track_thresh: 0.2   # Threshold for creating new tracks (lowered from 0.4)
+
+# Track management
+track_buffer: 120       # Frames to keep lost tracks alive (increased from 30)
+match_thresh: 0.7       # IoU-distance threshold for matching (lowered from 0.8; requires tighter overlap)
+
+# Optional features
+with_reid: False        # ReID features (not needed for rats)
diff --git a/scripts/tracking/roboflow_video_inference.py b/scripts/tracking/roboflow_video_inference.py
new file mode 100644
index 00000000..3f56c643
--- /dev/null
+++ b/scripts/tracking/roboflow_video_inference.py
@@ -0,0 +1,23 @@
+# Import the InferencePipeline object
+from inference import InferencePipeline
+# Import the built-in render_boxes sink for visualizing results
+from inference.core.interfaces.stream.sinks import render_boxes
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_id", type=str, required=True)
+    parser.add_argument("path_to_video", type=str)
+
+    args = parser.parse_args()
+
+    # Initialize a pipeline object
+    pipeline = InferencePipeline.init(
+        model_id=args.model_id,              # Roboflow model to use
+        video_reference=args.path_to_video,  # Path to video, device id (int, usually 0 for built-in webcams), or RTSP stream URL
+        on_prediction=render_boxes,          # Function to run after each prediction
+    )
+    pipeline.start()
+    pipeline.join()
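+
+# Example invocation (model id is hypothetical; requires ROBOFLOW_API_KEY
+# to be set in the environment):
+#   python scripts/tracking/roboflow_video_inference.py \
+#       --model_id my-project/1 video.mp4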
diff --git a/scripts/tracking/run_tracking_studio.py b/scripts/tracking/run_tracking_studio.py
new file mode 100755
index 00000000..e26a88d3
--- /dev/null
+++ b/scripts/tracking/run_tracking_studio.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python3
+"""
+Entry point for the Tracking Studio NiceGUI application.
+
+Run with: python scripts/tracking/run_tracking_studio.py
+"""
+
+# Simply import the app module; ui.run() is called at module level
+import collab_env.tracking_studio.app  # noqa: F401
diff --git a/scripts/tracking/yolo_native_tracking.py b/scripts/tracking/yolo_native_tracking.py
new file mode 100644
index 00000000..155c1bcd
--- /dev/null
+++ b/scripts/tracking/yolo_native_tracking.py
@@ -0,0 +1,109 @@
+"""Real-time tracking using YOLO's native track() method"""
+
+import argparse
+
+import cv2
+import numpy as np
+import torch
+from ultralytics import YOLO
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_path", type=str, required=True, help="Path to YOLO model")
+    parser.add_argument("path_to_video", type=str, help="Path to video file")
+    parser.add_argument("--confidence", type=float, default=0.2, help="Confidence threshold")
+    parser.add_argument("--tracker", type=str, default="bytetrack.yaml",
+                        help="Tracker config: bytetrack.yaml, botsort.yaml, or path to custom .yaml")
+    parser.add_argument("--iou", type=float, default=0.5, help="IOU threshold for NMS")
+    parser.add_argument("--no-persist", action="store_true", help="Don't persist tracks between frames (default: persist=True)")
+
+    args = parser.parse_args()
+
+    # Load model
+    model = YOLO(args.model_path)
+
+    # Open video
+    cap = cv2.VideoCapture(args.path_to_video)
+
+    if not cap.isOpened():
+        print(f"Error: Could not open video {args.path_to_video}")
+        return
+
+    # Get video properties
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+    print(f"Video: {width}x{height} @ {fps:.2f} fps, {total_frames} frames")
+    print(f"Tracker: {args.tracker}")
+    print(f"Confidence: {args.confidence}, IOU: {args.iou}")
+    print("\nPress 'q' to quit, 'p' to pause/unpause, SPACE to step frame when paused")
+
+    frame_idx = 0
+    paused = False
+    step = False
+    current_frame = None
+
+    # Pad inference size up to the nearest multiple of 32 (YOLO stride requirement)
+    target_height = ((height + 31) // 32) * 32
+    target_width = ((width + 31) // 32) * 32
+
+    while True:
+        if not paused:
+            ret, frame = cap.read()
+            if not ret:
+                print("\nEnd of video")
+                break
+
+            # Run YOLO tracking (track() method does detection + tracking in one step!)
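+            # persist=True is required here because frames are fed to track()
+            # one at a time; without it the tracker state would reset on every
+            # call and IDs would never carry across frames.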
+            results = model.track(
+                source=frame,
+                conf=args.confidence,
+                iou=args.iou,
+                tracker=args.tracker,
+                persist=not args.no_persist,  # Persist tracks between frames (True by default)
+                verbose=False,
+                imgsz=(target_height, target_width),  # Ultralytics expects (height, width)
+                device='mps' if torch.backends.mps.is_available() else 'cpu'  # Use Apple MPS when available
+            )
+
+            # Get annotated frame with tracking visualization
+            # YOLO's plot() method draws boxes, masks, and track IDs automatically
+            annotated_frame = results[0].plot()
+
+            # Add custom info overlay
+            if results[0].boxes is not None and results[0].boxes.id is not None:
+                n_detections = len(results[0].boxes.id)
+                track_ids = results[0].boxes.id.cpu().numpy().astype(int)
+                unique_tracks = len(np.unique(track_ids))
+                info_text = f"Frame: {frame_idx}/{total_frames} | Detections: {n_detections} | Unique IDs: {unique_tracks}"
+            else:
+                info_text = f"Frame: {frame_idx}/{total_frames} | Detections: 0"
+
+            cv2.putText(annotated_frame, info_text, (10, 20),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+
+            current_frame = annotated_frame
+            frame_idx += 1
+
+            if step:
+                # A single stepped frame was processed; pause again
+                paused = True
+                step = False
+
+        # Display frame
+        if current_frame is not None:
+            cv2.imshow('YOLO Native Tracking', current_frame)
+
+        # Handle key presses
+        key = cv2.waitKey(1 if not paused else 0) & 0xFF
+
+        if key == ord('q'):
+            print("\nQuitting...")
+            break
+        elif key == ord('p'):
+            paused = not paused
+            print(f"\n{'Paused' if paused else 'Resumed'}")
+        elif key == ord(' ') and paused:
+            # Step exactly one frame forward, then stay paused
+            paused = False
+            step = True
+
+    cap.release()
+    cv2.destroyAllWindows()
+
+if __name__ == "__main__":
+    main()
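+
+# Example invocation (paths are hypothetical; the tracker config below is
+# the one added in this change):
+#   python scripts/tracking/yolo_native_tracking.py --model_path yolo11n.pt \
+#       --tracker scripts/tracking/configs/bytetrack_thermal_rats.yaml video.mp4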