diff --git a/packages/nvidia_nat_core/src/nat/middleware/cache/cache_middleware.py b/packages/nvidia_nat_core/src/nat/middleware/cache/cache_middleware.py index 0ed26e22df..70def61313 100644 --- a/packages/nvidia_nat_core/src/nat/middleware/cache/cache_middleware.py +++ b/packages/nvidia_nat_core/src/nat/middleware/cache/cache_middleware.py @@ -30,6 +30,7 @@ import json import logging +from collections import OrderedDict from collections.abc import AsyncIterator from typing import Any @@ -44,6 +45,14 @@ logger = logging.getLogger(__name__) +# Default bound on cache size. The previous implementation used an unbounded +# dict which, under sustained unique input, grew without limit — a memory- +# exhaustion DoS and, combined with fuzzy matching, a long-lived surface for +# cross-request confusion. OrderedDict-backed LRU evicts the oldest entry +# when the cache exceeds this bound. +_DEFAULT_MAX_CACHE_ENTRIES = 1024 + + class CacheMiddleware(FunctionMiddleware): """Cache middleware that memoizes function outputs based on input similarity. @@ -67,19 +76,36 @@ class CacheMiddleware(FunctionMiddleware): computation. """ - def __init__(self, *, enabled_mode: str, similarity_threshold: float) -> None: + def __init__( + self, + *, + enabled_mode: str, + similarity_threshold: float, + max_entries: int = _DEFAULT_MAX_CACHE_ENTRIES, + ) -> None: """Initialize the cache middleware. Args: enabled_mode: Either "always" or "eval". If "eval", only caches when Context.is_evaluating is True. - similarity_threshold: Similarity threshold between 0 and 1. - If 1.0, performs exact matching. Otherwise uses fuzzy matching. + similarity_threshold: Similarity threshold in [0, 1.0]. If 1.0, + performs exact matching. Lower values enable difflib-based + fuzzy matching; note that difflib is quadratic in the worst + case, so large caches with low thresholds may have a + performance cost. Values near 0 increase the risk of cache + collisions where different inputs return the same cached + response. + max_entries: Maximum number of cache entries. When exceeded, the + oldest entry is evicted (LRU). Defaults to + _DEFAULT_MAX_CACHE_ENTRIES. """ super().__init__(is_final=True) self._enabled_mode = enabled_mode self._similarity_threshold = similarity_threshold - self._cache: dict[str, Any] = {} + # OrderedDict gives O(1) LRU: move_to_end() on hit, popitem(last=False) + # to evict the oldest when we exceed max_entries. + self._cache: OrderedDict[str, Any] = OrderedDict() + self._max_entries = max_entries # ==================== Abstract Method Implementations ==================== @@ -142,22 +168,13 @@ def _find_similar_key(self, input_str: str) -> str | None: # Exact matching - fast path return input_str if input_str in self._cache else None - # Fuzzy matching using difflib import difflib - best_match = None - best_ratio = 0.0 - - for cached_key in self._cache: - # Use SequenceMatcher for similarity computation - matcher = difflib.SequenceMatcher(None, input_str, cached_key) - ratio = matcher.ratio() - - if ratio >= self._similarity_threshold and ratio > best_ratio: - best_ratio = ratio - best_match = cached_key - - return best_match + best_matches = difflib.get_close_matches( + input_str, self._cache.keys(), n=1, cutoff=self._similarity_threshold) + if best_matches: + return best_matches[0] + return None async def function_middleware_invoke(self, *args: Any, @@ -199,10 +216,13 @@ async def function_middleware_invoke(self, # Phase 1: Preprocess - look for a similar cached input similar_key = self._find_similar_key(input_str) if similar_key is not None: - # Cache hit - short-circuit and return cached output + # Cache hit - short-circuit and return cached output. + # Move the hit entry to the MRU end so LRU eviction prefers truly + # old entries, not just recently-useful ones. logger.debug("Cache hit for function %s with similarity %.2f", context.name, 1.0 if similar_key == input_str else self._similarity_threshold) + self._cache.move_to_end(similar_key) # Phase 4: Continue - return cached result return self._cache[similar_key] @@ -210,9 +230,17 @@ async def function_middleware_invoke(self, logger.debug("Cache miss for function %s", context.name) result = await call_next(*args, **kwargs) - # Phase 3: Postprocess - cache the result for future use + # Phase 3: Postprocess - cache the result for future use. Enforce the + # LRU bound BEFORE insert so the new entry always lands in a cache of + # size <= max_entries, preventing unbounded memory growth (DoS). self._cache[input_str] = result - logger.debug("Cached result for function %s", context.name) + self._cache.move_to_end(input_str) + while len(self._cache) > self._max_entries: + self._cache.popitem(last=False) + logger.debug("Cached result for function %s (size=%d/%d)", + context.name, + len(self._cache), + self._max_entries) # Phase 4: Continue - return the fresh result return result diff --git a/packages/nvidia_nat_core/src/nat/middleware/cache/cache_middleware_config.py b/packages/nvidia_nat_core/src/nat/middleware/cache/cache_middleware_config.py index 7f24f03485..1ad122c773 100644 --- a/packages/nvidia_nat_core/src/nat/middleware/cache/cache_middleware_config.py +++ b/packages/nvidia_nat_core/src/nat/middleware/cache/cache_middleware_config.py @@ -19,6 +19,7 @@ from pydantic import Field from nat.data_models.middleware import FunctionMiddlewareBaseConfig +from nat.middleware.cache.cache_middleware import _DEFAULT_MAX_CACHE_ENTRIES class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"): @@ -31,14 +32,35 @@ class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"): enabled_mode: Controls when caching is active: - "always": Cache is always enabled - "eval": Cache only active when Context.is_evaluating is True - similarity_threshold: Float between 0 and 1 for input matching: - - 1.0: Exact string matching (fastest) - - < 1.0: Fuzzy matching using difflib similarity + similarity_threshold: Float in [0, 1.0] for input matching: + - 1.0: Exact string matching (fastest, recommended) + - < 1.0: Fuzzy matching via difflib. Note that difflib is + quadratic in the worst case, so large caches with low + thresholds may have a performance cost. Values near 0 + increase the risk of cache collisions where different + inputs return the same cached response. + max_entries: Upper bound on cached entries. When exceeded, the + least-recently-used entry is evicted. Must be a positive int; + defaults to _DEFAULT_MAX_CACHE_ENTRIES. """ enabled_mode: Literal["always", "eval"] = Field( default="eval", description="When caching is enabled: 'always' or 'eval' (only during evaluation)") - similarity_threshold: float = Field(default=1.0, - ge=0.0, - le=1.0, - description="Similarity threshold between 0 and 1. Use 1.0 for exact matching") + similarity_threshold: float = Field( + default=1.0, + ge=0, + le=1.0, + description=( + "Similarity threshold in [0, 1.0]. Use 1.0 for exact matching (recommended). " + "Lower values enable fuzzy matching via difflib; note that difflib is quadratic " + "in the worst case, so large caches with low thresholds may have a performance " + "cost. Values near 0 increase the risk of cache collisions where different " + "inputs return the same cached response."), + ) + max_entries: int = Field( + default=_DEFAULT_MAX_CACHE_ENTRIES, + ge=1, + description=("Maximum number of cache entries before LRU eviction. Must be >= 1. " + "Prevents memory-exhaustion DoS from unbounded cache growth under " + "sustained unique inputs."), + ) diff --git a/packages/nvidia_nat_core/src/nat/middleware/cache/register.py b/packages/nvidia_nat_core/src/nat/middleware/cache/register.py index 48d2702095..e8a77b997e 100644 --- a/packages/nvidia_nat_core/src/nat/middleware/cache/register.py +++ b/packages/nvidia_nat_core/src/nat/middleware/cache/register.py @@ -30,4 +30,8 @@ async def cache_middleware(config: CacheMiddlewareConfig, builder: Builder): Yields: A configured cache middleware instance """ - yield CacheMiddleware(enabled_mode=config.enabled_mode, similarity_threshold=config.similarity_threshold) + yield CacheMiddleware( + enabled_mode=config.enabled_mode, + similarity_threshold=config.similarity_threshold, + max_entries=config.max_entries, + ) diff --git a/packages/nvidia_nat_core/tests/nat/middleware/test_cache_middleware.py b/packages/nvidia_nat_core/tests/nat/middleware/test_cache_middleware.py index c1e14f5de1..aa5cb2a034 100644 --- a/packages/nvidia_nat_core/tests/nat/middleware/test_cache_middleware.py +++ b/packages/nvidia_nat_core/tests/nat/middleware/test_cache_middleware.py @@ -62,7 +62,7 @@ def test_default_initialization(self): def test_custom_initialization(self): """Test custom initialization.""" - middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.8) + middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9) # Check attributes are set assert hasattr(middleware, '_enabled_mode') assert hasattr(middleware, '_similarity_threshold') @@ -109,7 +109,7 @@ async def mock_next_call(*args, **kwargs): async def test_fuzzy_match_caching(self, middleware_context): """Test fuzzy matching with similarity_threshold < 1.0.""" - middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.8) + middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9) call_count = 0 @@ -267,8 +267,7 @@ async def mock_next_call(*args, **kwargs): def test_similarity_computation_for_different_thresholds(self): """Test similarity computation for different thresholds.""" - # This is more of a unit test for the similarity logic - middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.5) + middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9) # Directly test internal methods # Add a cached entry @@ -278,14 +277,14 @@ def test_similarity_computation_for_different_thresholds(self): # Test various similarity levels # Exact match assert middleware._find_similar_key(test_key) == test_key # noqa - # Very similar + # Very similar (one char shorter, ~0.95 ratio) assert middleware._find_similar_key("hello worl") == test_key # noqa # Too different - use a completely different string assert middleware._find_similar_key("xyz123abc") is None # noqa async def test_multiple_similar_entries(self, middleware_context): """Test behavior with multiple similar cached entries.""" - middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.7) + middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.85) # Pre-populate cache with similar entries key1 = middleware._serialize_input( # noqa @@ -306,3 +305,73 @@ async def mock_next_call(*args, **kwargs): input_str = {"value": "test input X", "number": 42} await middleware.function_middleware_invoke(input_str, call_next=mock_next_call, context=middleware_context) # The exact behavior depends on which cached key is most similar + + +class TestMaxEntriesLruEviction: + """The cache must bound its size to prevent memory-exhaustion DoS. + + The previous implementation used an unbounded dict; sustained unique + inputs would grow the cache without limit, eventually crashing the + process. LRU eviction ensures the cache stays within max_entries. + """ + + async def test_default_max_entries_is_positive(self): + mw = CacheMiddleware(enabled_mode="always", similarity_threshold=1.0) + assert mw._max_entries > 0 # noqa: SLF001 + + async def test_cache_evicts_oldest_when_exceeding_max_entries(self, middleware_context): + """Insert more unique entries than max_entries; verify size stays bounded.""" + mw = CacheMiddleware( + enabled_mode="always", + similarity_threshold=1.0, # exact match keeps the test deterministic + max_entries=3, + ) + + call_count = 0 + + async def mock_next_call(*_args, **_kwargs): + nonlocal call_count + call_count += 1 + return _TestOutput(result=f"result_{call_count}") + + for i in range(10): + await mw.function_middleware_invoke( + {"value": f"unique_input_{i}"}, + call_next=mock_next_call, + context=middleware_context, + ) + + assert len(mw._cache) == 3 # noqa: SLF001 + # The MOST recent three inserts should be what's left. + latest_keys = list(mw._cache.keys()) # noqa: SLF001 + for i in range(7, 10): + assert any(f"unique_input_{i}" in k for k in latest_keys) + + async def test_cache_hit_promotes_entry_to_most_recently_used(self, middleware_context): + """A cache hit should move the entry to MRU so later evictions spare it.""" + mw = CacheMiddleware( + enabled_mode="always", + similarity_threshold=1.0, + max_entries=3, + ) + + async def mock_next_call(*_args, **_kwargs): + return _TestOutput(result="r") + + # Fill the cache with A, B, C (A is oldest) + for key in ("A", "B", "C"): + await mw.function_middleware_invoke( + {"value": key}, call_next=mock_next_call, context=middleware_context) + + # Hit A again — should promote A to the MRU end + await mw.function_middleware_invoke( + {"value": "A"}, call_next=mock_next_call, context=middleware_context) + + # Now insert D — B (now oldest) should be evicted, not A. + await mw.function_middleware_invoke( + {"value": "D"}, call_next=mock_next_call, context=middleware_context) + + keys = "".join(list(mw._cache.keys())) # noqa: SLF001 + assert '"value": "A"' in keys + assert '"value": "D"' in keys + assert '"value": "B"' not in keys