-
Notifications
You must be signed in to change notification settings - Fork 616
fix(cache-middleware): bound cache size and enforce safe fuzzy threshold #1879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
24de910
9f60ef6
ee05bcd
2036cf2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,6 +30,7 @@ | |
|
|
||
| import json | ||
| import logging | ||
| from collections import OrderedDict | ||
| from collections.abc import AsyncIterator | ||
| from typing import Any | ||
|
|
||
|
|
@@ -44,6 +45,14 @@ | |
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| # Default bound on cache size. The previous implementation used an unbounded | ||
| # dict which, under sustained unique input, grew without limit — a memory- | ||
| # exhaustion DoS and, combined with fuzzy matching, a long-lived surface for | ||
| # cross-request confusion. OrderedDict-backed LRU evicts the oldest entry | ||
| # when the cache exceeds this bound. | ||
| _DEFAULT_MAX_CACHE_ENTRIES = 1024 | ||
|
|
||
|
|
||
| class CacheMiddleware(FunctionMiddleware): | ||
| """Cache middleware that memoizes function outputs based on input similarity. | ||
|
|
||
|
|
@@ -67,19 +76,36 @@ class CacheMiddleware(FunctionMiddleware): | |
| computation. | ||
| """ | ||
|
|
||
| def __init__(self, *, enabled_mode: str, similarity_threshold: float) -> None: | ||
| def __init__( | ||
| self, | ||
| *, | ||
| enabled_mode: str, | ||
| similarity_threshold: float, | ||
| max_entries: int = _DEFAULT_MAX_CACHE_ENTRIES, | ||
| ) -> None: | ||
| """Initialize the cache middleware. | ||
|
|
||
| Args: | ||
| enabled_mode: Either "always" or "eval". If "eval", only caches | ||
| when Context.is_evaluating is True. | ||
| similarity_threshold: Similarity threshold between 0 and 1. | ||
| If 1.0, performs exact matching. Otherwise uses fuzzy matching. | ||
| similarity_threshold: Similarity threshold in [0, 1.0]. If 1.0, | ||
| performs exact matching. Lower values enable difflib-based | ||
| fuzzy matching; note that difflib is quadratic in the worst | ||
| case, so large caches with low thresholds may have a | ||
| performance cost. Values near 0 increase the risk of cache | ||
| collisions where different inputs return the same cached | ||
| response. | ||
| max_entries: Maximum number of cache entries. When exceeded, the | ||
| oldest entry is evicted (LRU). Defaults to | ||
| _DEFAULT_MAX_CACHE_ENTRIES. | ||
| """ | ||
| super().__init__(is_final=True) | ||
| self._enabled_mode = enabled_mode | ||
| self._similarity_threshold = similarity_threshold | ||
| self._cache: dict[str, Any] = {} | ||
| # OrderedDict gives O(1) LRU: move_to_end() on hit, popitem(last=False) | ||
| # to evict the oldest when we exceed max_entries. | ||
| self._cache: OrderedDict[str, Any] = OrderedDict() | ||
| self._max_entries = max_entries | ||
|
|
||
| # ==================== Abstract Method Implementations ==================== | ||
|
|
||
|
|
@@ -142,22 +168,13 @@ def _find_similar_key(self, input_str: str) -> str | None: | |
| # Exact matching - fast path | ||
| return input_str if input_str in self._cache else None | ||
|
|
||
| # Fuzzy matching using difflib | ||
| import difflib | ||
|
|
||
| best_match = None | ||
| best_ratio = 0.0 | ||
|
|
||
| for cached_key in self._cache: | ||
| # Use SequenceMatcher for similarity computation | ||
| matcher = difflib.SequenceMatcher(None, input_str, cached_key) | ||
| ratio = matcher.ratio() | ||
|
|
||
| if ratio >= self._similarity_threshold and ratio > best_ratio: | ||
| best_ratio = ratio | ||
| best_match = cached_key | ||
|
|
||
| return best_match | ||
| best_matches = difflib.get_close_matches( | ||
| input_str, self._cache.keys(), n=1, cutoff=self._similarity_threshold) | ||
| if best_matches: | ||
| return best_matches[0] | ||
| return None | ||
|
|
||
| async def function_middleware_invoke(self, | ||
| *args: Any, | ||
|
|
@@ -199,20 +216,31 @@ async def function_middleware_invoke(self, | |
| # Phase 1: Preprocess - look for a similar cached input | ||
| similar_key = self._find_similar_key(input_str) | ||
| if similar_key is not None: | ||
| # Cache hit - short-circuit and return cached output | ||
| # Cache hit - short-circuit and return cached output. | ||
| # Move the hit entry to the MRU end so LRU eviction prefers truly | ||
| # old entries, not just recently-useful ones. | ||
| logger.debug("Cache hit for function %s with similarity %.2f", | ||
| context.name, | ||
| 1.0 if similar_key == input_str else self._similarity_threshold) | ||
| self._cache.move_to_end(similar_key) | ||
| # Phase 4: Continue - return cached result | ||
| return self._cache[similar_key] | ||
|
|
||
| # Phase 2: Call next - no cache hit, call next middleware/function | ||
| logger.debug("Cache miss for function %s", context.name) | ||
| result = await call_next(*args, **kwargs) | ||
|
|
||
| # Phase 3: Postprocess - cache the result for future use | ||
| # Phase 3: Postprocess - cache the result for future use. Enforce the | ||
| # LRU bound BEFORE insert so the new entry always lands in a cache of | ||
| # size <= max_entries, preventing unbounded memory growth (DoS). | ||
| self._cache[input_str] = result | ||
| logger.debug("Cached result for function %s", context.name) | ||
| self._cache.move_to_end(input_str) | ||
|
Comment on lines
236
to
+237
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this call to |
||
| while len(self._cache) > self._max_entries: | ||
| self._cache.popitem(last=False) | ||
| logger.debug("Cached result for function %s (size=%d/%d)", | ||
| context.name, | ||
| len(self._cache), | ||
| self._max_entries) | ||
|
|
||
| # Phase 4: Continue - return the fresh result | ||
| return result | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |||||
| from pydantic import Field | ||||||
|
|
||||||
| from nat.data_models.middleware import FunctionMiddlewareBaseConfig | ||||||
| from nat.middleware.cache.cache_middleware import _DEFAULT_MAX_CACHE_ENTRIES | ||||||
|
|
||||||
|
|
||||||
| class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"): | ||||||
|
|
@@ -31,14 +32,35 @@ class CacheMiddlewareConfig(FunctionMiddlewareBaseConfig, name="cache"): | |||||
| enabled_mode: Controls when caching is active: | ||||||
| - "always": Cache is always enabled | ||||||
| - "eval": Cache only active when Context.is_evaluating is True | ||||||
| similarity_threshold: Float between 0 and 1 for input matching: | ||||||
| - 1.0: Exact string matching (fastest) | ||||||
| - < 1.0: Fuzzy matching using difflib similarity | ||||||
| similarity_threshold: Float in [0, 1.0] for input matching: | ||||||
| - 1.0: Exact string matching (fastest, recommended) | ||||||
| - < 1.0: Fuzzy matching via difflib. Note that difflib is | ||||||
| quadratic in the worst case, so large caches with low | ||||||
| thresholds may have a performance cost. Values near 0 | ||||||
| increase the risk of cache collisions where different | ||||||
| inputs return the same cached response. | ||||||
| max_entries: Upper bound on cached entries. When exceeded, the | ||||||
| least-recently-used entry is evicted. Must be a positive int; | ||||||
| defaults to _DEFAULT_MAX_CACHE_ENTRIES. | ||||||
| """ | ||||||
|
|
||||||
| enabled_mode: Literal["always", "eval"] = Field( | ||||||
| default="eval", description="When caching is enabled: 'always' or 'eval' (only during evaluation)") | ||||||
| similarity_threshold: float = Field(default=1.0, | ||||||
| ge=0.0, | ||||||
| le=1.0, | ||||||
| description="Similarity threshold between 0 and 1. Use 1.0 for exact matching") | ||||||
| similarity_threshold: float = Field( | ||||||
| default=1.0, | ||||||
| ge=0, | ||||||
| le=1.0, | ||||||
| description=( | ||||||
| "Similarity threshold in [0, 1.0]. Use 1.0 for exact matching (recommended). " | ||||||
| "Lower values enable fuzzy matching via difflib; note that difflib is quadratic " | ||||||
| "in the worst case, so large caches with low thresholds may have a performance " | ||||||
| "cost. Values near 0 increase the risk of cache collisions where different " | ||||||
| "inputs return the same cached response."), | ||||||
| ) | ||||||
| max_entries: int = Field( | ||||||
| default=_DEFAULT_MAX_CACHE_ENTRIES, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
Remove the const global, 1024 isn't a number that needs to be a const. |
||||||
| ge=1, | ||||||
| description=("Maximum number of cache entries before LRU eviction. Must be >= 1. " | ||||||
| "Prevents memory-exhaustion DoS from unbounded cache growth under " | ||||||
| "sustained unique inputs."), | ||||||
| ) | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,7 +62,7 @@ def test_default_initialization(self): | |
|
|
||
| def test_custom_initialization(self): | ||
| """Test custom initialization.""" | ||
| middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.8) | ||
| middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Revert, I don't think this is needed |
||
| # Check attributes are set | ||
| assert hasattr(middleware, '_enabled_mode') | ||
| assert hasattr(middleware, '_similarity_threshold') | ||
|
|
@@ -109,7 +109,7 @@ async def mock_next_call(*args, **kwargs): | |
|
|
||
| async def test_fuzzy_match_caching(self, middleware_context): | ||
| """Test fuzzy matching with similarity_threshold < 1.0.""" | ||
| middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.8) | ||
| middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same, not needed |
||
|
|
||
| call_count = 0 | ||
|
|
||
|
|
@@ -267,8 +267,7 @@ async def mock_next_call(*args, **kwargs): | |
|
|
||
| def test_similarity_computation_for_different_thresholds(self): | ||
| """Test similarity computation for different thresholds.""" | ||
| # This is more of a unit test for the similarity logic | ||
| middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.5) | ||
| middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.9) | ||
|
Comment on lines
-271
to
+270
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Revert |
||
|
|
||
| # Directly test internal methods | ||
| # Add a cached entry | ||
|
|
@@ -278,14 +277,14 @@ def test_similarity_computation_for_different_thresholds(self): | |
| # Test various similarity levels | ||
| # Exact match | ||
| assert middleware._find_similar_key(test_key) == test_key # noqa | ||
| # Very similar | ||
| # Very similar (one char shorter, ~0.95 ratio) | ||
| assert middleware._find_similar_key("hello worl") == test_key # noqa | ||
| # Too different - use a completely different string | ||
| assert middleware._find_similar_key("xyz123abc") is None # noqa | ||
|
|
||
| async def test_multiple_similar_entries(self, middleware_context): | ||
| """Test behavior with multiple similar cached entries.""" | ||
| middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.7) | ||
| middleware = CacheMiddleware(enabled_mode="always", similarity_threshold=0.85) | ||
|
Comment on lines
-288
to
+287
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Revert |
||
|
|
||
| # Pre-populate cache with similar entries | ||
| key1 = middleware._serialize_input( # noqa | ||
|
|
@@ -306,3 +305,73 @@ async def mock_next_call(*args, **kwargs): | |
| input_str = {"value": "test input X", "number": 42} | ||
| await middleware.function_middleware_invoke(input_str, call_next=mock_next_call, context=middleware_context) | ||
| # The exact behavior depends on which cached key is most similar | ||
|
|
||
|
|
||
| class TestMaxEntriesLruEviction: | ||
| """The cache must bound its size to prevent memory-exhaustion DoS. | ||
|
|
||
| The previous implementation used an unbounded dict; sustained unique | ||
| inputs would grow the cache without limit, eventually crashing the | ||
| process. LRU eviction ensures the cache stays within max_entries. | ||
| """ | ||
|
|
||
| async def test_default_max_entries_is_positive(self): | ||
| mw = CacheMiddleware(enabled_mode="always", similarity_threshold=1.0) | ||
| assert mw._max_entries > 0 # noqa: SLF001 | ||
|
|
||
| async def test_cache_evicts_oldest_when_exceeding_max_entries(self, middleware_context): | ||
| """Insert more unique entries than max_entries; verify size stays bounded.""" | ||
| mw = CacheMiddleware( | ||
| enabled_mode="always", | ||
| similarity_threshold=1.0, # exact match keeps the test deterministic | ||
| max_entries=3, | ||
| ) | ||
|
|
||
| call_count = 0 | ||
|
|
||
| async def mock_next_call(*_args, **_kwargs): | ||
| nonlocal call_count | ||
| call_count += 1 | ||
| return _TestOutput(result=f"result_{call_count}") | ||
|
|
||
| for i in range(10): | ||
| await mw.function_middleware_invoke( | ||
| {"value": f"unique_input_{i}"}, | ||
| call_next=mock_next_call, | ||
| context=middleware_context, | ||
| ) | ||
|
|
||
| assert len(mw._cache) == 3 # noqa: SLF001 | ||
| # The MOST recent three inserts should be what's left. | ||
| latest_keys = list(mw._cache.keys()) # noqa: SLF001 | ||
| for i in range(7, 10): | ||
| assert any(f"unique_input_{i}" in k for k in latest_keys) | ||
|
|
||
| async def test_cache_hit_promotes_entry_to_most_recently_used(self, middleware_context): | ||
| """A cache hit should move the entry to MRU so later evictions spare it.""" | ||
| mw = CacheMiddleware( | ||
| enabled_mode="always", | ||
| similarity_threshold=1.0, | ||
| max_entries=3, | ||
| ) | ||
|
|
||
| async def mock_next_call(*_args, **_kwargs): | ||
| return _TestOutput(result="r") | ||
|
|
||
| # Fill the cache with A, B, C (A is oldest) | ||
| for key in ("A", "B", "C"): | ||
| await mw.function_middleware_invoke( | ||
| {"value": key}, call_next=mock_next_call, context=middleware_context) | ||
|
|
||
| # Hit A again — should promote A to the MRU end | ||
| await mw.function_middleware_invoke( | ||
| {"value": "A"}, call_next=mock_next_call, context=middleware_context) | ||
|
|
||
| # Now insert D — B (now oldest) should be evicted, not A. | ||
| await mw.function_middleware_invoke( | ||
| {"value": "D"}, call_next=mock_next_call, context=middleware_context) | ||
|
|
||
| keys = "".join(list(mw._cache.keys())) # noqa: SLF001 | ||
| assert '"value": "A"' in keys | ||
| assert '"value": "D"' in keys | ||
| assert '"value": "B"' not in keys | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is needed as a global const; the default value should be in the config class. Since we only construct
CacheMiddleware from the config object, the constructor similarly doesn't need a default value.