Fix memory leak in worker batch processing #119

Open

mihow wants to merge 2 commits into main from fix/worker-memory-leak

Conversation

@mihow
Collaborator

@mihow mihow commented Mar 3, 2026

Summary

  • Extract batch processing into _process_batch() so large intermediates (image tensors, crops, batched crops) go out of scope after each batch
  • Replace all_detections accumulator list with a total_detections counter — the list grew unboundedly but was only used for len() in a log message
  • Add torch.cuda.empty_cache() between batches to free CUDA allocator caches
  • Add memory regression test that processes 50 tasks and asserts RSS growth stays bounded

Measurements (500 tasks / 250 batches, separate processes, RTX 4090)

|  | Old code | Fixed code |
| --- | --- | --- |
| Total RSS growth | 1749 MB | 859 MB (52% less growth) |
| Per-batch rate (OLS) | 7.15 MB/batch | 3.43 MB/batch |
| Projected at 31K images | 108 GB | 52 GB |

The remaining ~3.4 MB/batch growth appears to be from PyTorch/CUDA allocator internals and ML model class state — not from _process_job() scope. Addressing that would require changes deeper in the ML model classes.
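
For context on how a per-batch rate like the ones above can be derived, here is a least-squares slope over per-batch RSS samples (the sample numbers below are illustrative, not the measurements from the table):

```python
def ols_slope(ys):
    """Ordinary least-squares slope of ys over indices 0..n-1 (MB/batch)."""
    n = len(ys)
    xs = range(n)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    return cov / var

# illustrative RSS samples (MB) taken after each batch
rss = [1000.0, 1007.2, 1014.1, 1021.5, 1028.4]
print(round(ols_slope(rss), 2))  # 7.11
```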

Test plan

  • Memory regression test passes (25 batches, <150 MB growth)
  • All 12 existing test_worker.py tests pass (no behavior change)
  • Deploy to a GPU server and monitor RSS during a real job

Closes #118

🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • New Features

    • Added optional batch completion callbacks for monitoring processing progress during job execution.
    • Enhanced per-batch metrics collection including detection and classification timing statistics.
  • Tests

    • Introduced memory regression test validating batch processing stability without resource leaks across iterations.
  • Chores

    • Improved internal logging for per-batch flow with detailed timing metrics.

@coderabbitai

coderabbitai bot commented Mar 3, 2026

📝 Walkthrough

Walkthrough

The changes refactor batch processing in the worker to eliminate a memory leak by introducing a modular _process_batch helper function, replacing an unbounded detections list with a counter, implementing explicit per-batch cleanup, and adding an optional callback mechanism for monitoring batch completion. A regression test validates memory stability across batches.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Test Coverage<br>`trapdata/antenna/tests/test_memory_leak.py` | New regression test that simulates the complete batch processing pipeline, monitors resident memory (RSS) after each batch via an `on_batch_complete` callback, and validates that memory growth remains below a 150 MB threshold across batches. |
| Worker Batch Processing Refactor<br>`trapdata/antenna/worker.py` | Extracted batch processing logic into a new `_process_batch()` helper function; replaced the unbounded `all_detections` list with a `total_detections` counter; added an optional `on_batch_complete` callback to `_process_job()` for per-batch monitoring; updated logging to reflect per-batch metrics and cumulative totals; removed the unused `post_batch_results` import; added future annotations and type hints. |
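
The callback-driven memory check described above might look roughly like this (a simplified sketch; `on_batch_complete` follows the PR's naming, while the sample numbers and the helper `rss_growth_ok` are illustrative):

```python
samples: list[float] = []

def on_batch_complete(batch_index: int, rss_mb: float) -> None:
    """Callback invoked by the worker after each batch (per the PR's API)."""
    samples.append(rss_mb)

def rss_growth_ok(rss_samples_mb, threshold_mb=150.0) -> bool:
    """True if RSS growth from first to last batch stays under threshold."""
    return (rss_samples_mb[-1] - rss_samples_mb[0]) < threshold_mb

# simulate 25 batches whose RSS grows by 3.5 MB each (illustrative numbers)
for i in range(25):
    on_batch_complete(i, 1000.0 + i * 3.5)

print(rss_growth_ok(samples))  # True: 84 MB growth is under the 150 MB cap
```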

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

🐰 A leak's been plugged, the memory's free,
Batches now clean as they well should be,
No more tensors in limbo so long,
_process_batch keeps things moving along,
With callbacks to watch, we'll know all is well! 🎉

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00%, which is below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title 'Fix memory leak in worker batch processing' directly describes the main change: extracting batch processing logic and adding memory cleanup to resolve the identified memory leak. |
| Linked Issues check | ✅ Passed | All coding objectives from issue #118 are addressed: all_detections replaced with a total_detections counter, batch processing extracted into _process_batch() to scope large intermediates, torch.cuda.empty_cache() added between batches, and a regression test validating memory growth is included. |
| Out of Scope Changes check | ✅ Passed | All changes relate directly to fixing memory leak issue #118: the new test validates the fix, the worker.py refactoring implements cleanup, and the _process_job signature changes support memory-profiling callbacks. |


mihow and others added 2 commits March 3, 2026 00:30
Extract batch processing body into _process_batch() so all large
intermediates (image_tensors, crops, batched_crops, image_detections,
detector_results) go out of scope after each batch and are freed by
reference counting.

Replace all_detections accumulator list with a total_detections counter
— the list grew unboundedly across all batches but was only used for
len() in a final log message.

Add torch.cuda.empty_cache() at end of each batch to free CUDA allocator
caches.

Add optional on_batch_complete callback to _process_job() for memory
profiling.

Add memory leak regression test: RSS growth across 25 batches must be
< 150 MB. Test showed 88.0 MB growth (well under threshold) with the fix.

Re-applied on top of main's refactored worker.py, which added:
- _apply_binary_classification() helper (called from _process_batch)
- ResultPoster for async result posting
- processing_service_name parameter to _process_job

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

DataLoader batches may yield image_id elements as Tensor scalars.
Cast to str before passing to AntennaTaskResultError(image_id=...)
which expects str | None.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
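
The cast this commit describes can be illustrated in isolation (a hypothetical helper; the real code passes the value into `AntennaTaskResultError(image_id=...)`, and a fake scalar class stands in for a 0-d Tensor so the sketch runs without torch):

```python
from __future__ import annotations

def to_image_id(raw) -> str | None:
    """Normalize a DataLoader-yielded id (possibly a Tensor scalar) to str."""
    if raw is None:
        return None
    if hasattr(raw, "item"):  # torch.Tensor scalars expose .item()
        raw = raw.item()
    return str(raw)

class FakeTensorScalar:
    """Stand-in for a 0-d torch.Tensor holding an integer id."""
    def __init__(self, value):
        self.value = value
    def item(self):
        return self.value

print(to_image_id(FakeTensorScalar(42)))  # prints 42 (as a str)
print(to_image_id(None))                  # prints None
```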
@mihow mihow force-pushed the fix/worker-memory-leak branch from 84b20fa to 841fe87 on March 3, 2026 at 08:43

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@trapdata/antenna/tests/test_memory_leak.py`:
- Around line 27-31: The helper _get_rss_mb currently assumes Linux and opens
"/proc/self/statm"; make it robust by detecting availability and raising a clear
SkipTest (or returning None) when "/proc/self/statm" or
os.sysconf("SC_PAGE_SIZE") is not available so tests don't fail on non-Linux
platforms: update the _get_rss_mb function to check
os.path.exists("/proc/self/statm") and handle exceptions around os.sysconf, and
modify any callers in the same test file (the checks around lines where
_get_rss_mb is used, e.g., the blocks referenced at 61-63) to skip the test when
_get_rss_mb indicates unavailability instead of proceeding and failing.
- Around line 51-58: The helper function _make_settings lacks a return type
annotation; add one to its signature (e.g., def _make_settings() -> MagicMock:)
and ensure MagicMock is imported (from unittest.mock import MagicMock) or use
the concrete Settings type if one exists; update the signature in the
_make_settings definition to include the chosen return type so the function is
properly type-hinted.

In `@trapdata/antenna/worker.py`:
- Around line 247-251: The length-check raises a ValueError but subsequent code
still uses zip(..., strict=True) which can re-raise and prevent emitting
AntennaTaskResultError for the batch; update the handler so when the lengths
mismatch (image_ids, reply_subjects, image_urls) you do not rely on zip(...,
strict=True) — instead detect the mismatch and immediately construct and emit an
AntennaTaskResultError for each affected image (or the whole batch) without
using strict zip, or iterate by index using a safe min/len loop or
itertools.zip_longest with explicit None checks; make the same change for the
second occurrence that references zip(..., strict=True) (the block around the
other reported lines) so mismatched payloads always result in
AntennaTaskResultError emissions rather than a secondary exception.

ℹ️ Review info

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a0cf1c6 and 841fe87.

📒 Files selected for processing (2)
  • trapdata/antenna/tests/test_memory_leak.py
  • trapdata/antenna/worker.py

Comment on lines +27 to +31

```python
def _get_rss_mb() -> float:
    """Current RSS in MB, read from /proc/self/statm (Linux-only)."""
    with open("/proc/self/statm") as f:
        pages = int(f.read().split()[1])  # resident pages
    return pages * os.sysconf("SC_PAGE_SIZE") / (1024 * 1024)
```

⚠️ Potential issue | 🟡 Minor

Guard the Linux-specific RSS probe to avoid cross-platform test failures.

/proc/self/statm is Linux-only; skip this test when unavailable.

💡 Proposed fix

```diff
 @pytest.mark.slow
 def test_rss_stable_across_batches(self):
+    if not os.path.exists("/proc/self/statm"):
+        pytest.skip("RSS regression test requires Linux /proc/self/statm")
     """RSS should not grow more than 150 MB across 25+ batches.
```

Also applies to: 61-63


Comment on lines +51 to +58

```python
def _make_settings(self):
    settings = MagicMock()
    settings.antenna_api_base_url = "http://testserver/api/v2"
    settings.antenna_api_auth_token = "test-token"
    settings.antenna_api_batch_size = 2
    settings.num_workers = 0
    settings.localization_batch_size = 2
    return settings
```

⚠️ Potential issue | 🟡 Minor

Add a return type annotation to _make_settings.

This signature is missing a return type hint.

💡 Proposed fix

```diff
-    def _make_settings(self):
+    def _make_settings(self) -> MagicMock:
         settings = MagicMock()
         settings.antenna_api_base_url = "http://testserver/api/v2"
         settings.antenna_api_auth_token = "test-token"
         settings.antenna_api_batch_size = 2
         settings.num_workers = 0
         settings.localization_batch_size = 2
         return settings
```

As per coding guidelines, trapdata/**/*.py: "Use type hints in function signatures to document expected types without requiring extensive documentation."

🧰 Tools
🪛 Ruff (0.15.2)

[error] 54-54: Possible hardcoded password assigned to: "antenna_api_auth_token"

(S105)


Comment on lines +247 to +251

```python
if len(image_ids) != len(reply_subjects) or len(image_ids) != len(image_urls):
    raise ValueError(
        f"Length mismatch: image_ids ({len(image_ids)}), "
        f"reply_subjects ({len(reply_subjects)}), image_urls ({len(image_urls)})"
    )
```

⚠️ Potential issue | 🟠 Major

Exception fallback can fail on the same length mismatch it is handling.

If the payload-length checks fail, the handler currently uses zip(..., strict=True) and can raise again, which prevents emitting AntennaTaskResultError results for the batch.

💡 Proposed fix

```diff
-        for reply_subject, image_id in zip(reply_subjects, image_ids, strict=True):
+        max_len = max(len(reply_subjects), len(image_ids))
+        for idx in range(max_len):
+            reply_subject = reply_subjects[idx] if idx < len(reply_subjects) else None
+            image_id = image_ids[idx] if idx < len(image_ids) else None
             batch_results.append(
                 AntennaTaskResult(
                     reply_subject=reply_subject,
                     result=AntennaTaskResultError(
                         error=f"Batch processing error: {e}",
                         image_id=str(image_id) if image_id is not None else None,
                     ),
                 )
             )
```

Also applies to: 373-374

🧰 Tools
🪛 Ruff (0.15.2)

[warning] 248-251: Abstract raise to an inner function

(TRY301)


[warning] 248-251: Avoid specifying long messages outside the exception class

(TRY003)


@mihow
Collaborator Author

mihow commented Mar 3, 2026

@carlosgjs you may still be interested in this one! I was getting OOM errors in production on jobs over a few hours. No matter the worker type.



Development

Successfully merging this pull request may close these issues.

Memory leak in worker batch processing causes OOM
