Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/memory_bench/modes/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ class RAGMode(ResponseMode):
name = "rag"
description = "Default. Provider retrieves top-k documents; they are injected into an LLM prompt as context. Supports both MCQ and open-ended questions."

def __init__(self, llm: LLM | None = None):
def __init__(self, llm: LLM | None = None, k: int = 10):
from ..llm import get_answer_llm
self._llm = llm or get_answer_llm()
self.k = k

@property
def llm_id(self) -> str | None:
Expand All @@ -46,7 +47,7 @@ async def async_answer(self, query: str, memory: MemoryProvider, task_type: str
meta = meta or {}
query_timestamp = meta.get("query_timestamp")
retrieval_query = meta.get("retrieval_query") or query
docs, raw_response = await memory.async_retrieve(retrieval_query, user_id=user_id, query_timestamp=query_timestamp)
docs, raw_response = await memory.async_retrieve(retrieval_query, k=self.k, user_id=user_id, query_timestamp=query_timestamp)
retrieve_ms = (time.perf_counter() - t0) * 1000

context = "\n\n".join(
Expand Down
61 changes: 61 additions & 0 deletions tests/test_agentic_rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import unittest

from memory_bench.llm.base import LLM
from memory_bench.memory.bm25 import BM25MemoryProvider
from memory_bench.models import Document
from memory_bench.modes.agentic_rag import AgenticRAGMode


class FakeToolLLM(LLM):
@property
def model_id(self):
return "fake:tool-llm"

def tool_loop(self, prompt, tools, max_tool_calls=10):
recall = tools[0].fn
recall("future imports compile validation")
recall("review convention current repo evidence")
return "done"

def generate(self, prompt, schema):
return {
"reasoning": "The current repo evidence overrides stale memory.",
"answer": "Trust compile validation over parse-only memory.",
}


class AgenticRAGModeTest(unittest.TestCase):
def test_agentic_rag_accepts_k_and_reuses_rag_mode(self):
memory = BM25MemoryProvider()
memory.ingest([
Document(
id="stale",
user_id="repo-a",
content="Old session memory: ast.parse validation was considered enough.",
),
Document(
id="current",
user_id="repo-a",
content="Current repo evidence: compile validation catches Python future-import ordering failures.",
),
Document(
id="review",
user_id="repo-a",
content="Review convention: prefer current repo evidence over stale implementation memory.",
),
])

mode = AgenticRAGMode(llm=FakeToolLLM(), k=1)
result = mode.answer(
"Should the agent trust parse-only memory or compile validation?",
memory,
user_id="repo-a",
)

self.assertEqual(result.answer, "Trust compile validation over parse-only memory.")
self.assertIn("Current repo evidence", result.context)
self.assertIn("Review convention", result.context)


if __name__ == "__main__":
unittest.main()