Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 26 additions & 22 deletions src/utils/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from src.config import load_yaml_config

logger = logging.getLogger(__name__)
logger = logging.getLogger("src.utils.token_manager")


def get_search_config():
Expand Down Expand Up @@ -46,10 +46,9 @@ def count_tokens(self, messages: List[BaseMessage]) -> int:
Returns:
Number of tokens
"""
total_tokens = 0
for message in messages:
total_tokens += self._count_message_tokens(message)
return total_tokens
# Micro-optimization: bind the method to a local variable to avoid repeated attribute lookups in the hot loop
_count_message_tokens = self._count_message_tokens
return sum(_count_message_tokens(m) for m in messages)

def _count_message_tokens(self, message: BaseMessage) -> int:
"""
Expand Down Expand Up @@ -154,7 +153,6 @@ def compress_messages(self, state: dict) -> List[BaseMessage]:
Returns:
Compressed state with compressed messages
"""
# If not set token_limit, return original state
if self.token_limit is None:
logger.info("No token_limit set, the context management doesn't work.")
return state
Expand All @@ -165,14 +163,16 @@ def compress_messages(self, state: dict) -> List[BaseMessage]:

messages = state["messages"]

if not self.is_over_limit(messages):
# Avoid recomputation: always do the count once, as this is expensive
token_count = self.count_tokens(messages)
if token_count <= self.token_limit:
return state

# 2. Compress messages
# Compress messages
compressed_messages = self._compress_messages(messages)

logger.info(
f"Message compression completed: {self.count_tokens(messages)} -> {self.count_tokens(compressed_messages)} tokens"
f"Message compression completed: {token_count} -> {self.count_tokens(compressed_messages)} tokens"
)

state["messages"] = compressed_messages
Expand All @@ -191,16 +191,18 @@ def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:

available_token = self.token_limit
prefix_messages = []
_count_message_tokens = self._count_message_tokens
_truncate_message_content = self._truncate_message_content

# 1. Preserve head messages of specified length to retain system prompts and user input
for i in range(min(self.preserve_prefix_message_count, len(messages))):
cur_token_cnt = self._count_message_tokens(messages[i])
prefix_count = min(self.preserve_prefix_message_count, len(messages))
for i in range(prefix_count):
cur_token_cnt = _count_message_tokens(messages[i])
if available_token > 0 and available_token >= cur_token_cnt:
prefix_messages.append(messages[i])
available_token -= cur_token_cnt
elif available_token > 0:
# Truncate content to fit available tokens
truncated_message = self._truncate_message_content(
truncated_message = _truncate_message_content(
messages[i], available_token
)
prefix_messages.append(truncated_message)
Expand All @@ -209,24 +211,26 @@ def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
break

# 2. Compress subsequent messages from the tail, some messages may be discarded
messages = messages[len(prefix_messages) :]
# Slice the tail once, before the loop (note: list slicing makes a shallow copy, not a view)
tail_messages = messages[prefix_count:]
suffix_messages = []
for i in range(len(messages) - 1, -1, -1):
cur_token_cnt = self._count_message_tokens(messages[i])
for i in range(len(tail_messages) - 1, -1, -1):
msg = tail_messages[i]
cur_token_cnt = _count_message_tokens(msg)

if cur_token_cnt > 0 and available_token >= cur_token_cnt:
suffix_messages = [messages[i]] + suffix_messages
suffix_messages.append(msg)
available_token -= cur_token_cnt
elif available_token > 0:
# Truncate content to fit available tokens
truncated_message = self._truncate_message_content(
messages[i], available_token
)
suffix_messages = [truncated_message] + suffix_messages
truncated_message = _truncate_message_content(msg, available_token)
suffix_messages.append(truncated_message)
# Reverse once to avoid multiple list concatenations
suffix_messages.reverse()
return prefix_messages + suffix_messages
else:
break

suffix_messages.reverse()
return prefix_messages + suffix_messages

def _truncate_message_content(
Expand Down