diff --git a/src/utils/context_manager.py b/src/utils/context_manager.py index b56d6a651..2c8891f6b 100644 --- a/src/utils/context_manager.py +++ b/src/utils/context_manager.py @@ -13,7 +13,7 @@ from src.config import load_yaml_config -logger = logging.getLogger(__name__) +logger = logging.getLogger("src.utils.token_manager") def get_search_config(): @@ -46,10 +46,9 @@ def count_tokens(self, messages: List[BaseMessage]) -> int: Returns: Number of tokens """ - total_tokens = 0 - for message in messages: - total_tokens += self._count_message_tokens(message) - return total_tokens + # Micro-optimization: use localvar lookup, avoid attribute access in hot loop + _count_message_tokens = self._count_message_tokens + return sum(_count_message_tokens(m) for m in messages) def _count_message_tokens(self, message: BaseMessage) -> int: """ @@ -154,7 +153,6 @@ def compress_messages(self, state: dict) -> List[BaseMessage]: Returns: Compressed state with compressed messages """ - # If not set token_limit, return original state if self.token_limit is None: logger.info("No token_limit set, the context management doesn't work.") return state @@ -165,14 +163,16 @@ def compress_messages(self, state: dict) -> List[BaseMessage]: messages = state["messages"] - if not self.is_over_limit(messages): + # Avoid recomputation: always do the count once, as this is expensive + token_count = self.count_tokens(messages) + if token_count <= self.token_limit: return state - # 2. Compress messages + # Compress messages compressed_messages = self._compress_messages(messages) logger.info( - f"Message compression completed: {self.count_tokens(messages)} -> {self.count_tokens(compressed_messages)} tokens" + f"Message compression completed: {token_count} -> {self.count_tokens(compressed_messages)} tokens" ) state["messages"] = compressed_messages @@ -191,16 +191,18 @@ def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]: available_token = self.token_limit prefix_messages = [] + _count_message_tokens = self._count_message_tokens + _truncate_message_content = self._truncate_message_content # 1. Preserve head messages of specified length to retain system prompts and user input - for i in range(min(self.preserve_prefix_message_count, len(messages))): - cur_token_cnt = self._count_message_tokens(messages[i]) + prefix_count = min(self.preserve_prefix_message_count, len(messages)) + for i in range(prefix_count): + cur_token_cnt = _count_message_tokens(messages[i]) if available_token > 0 and available_token >= cur_token_cnt: prefix_messages.append(messages[i]) available_token -= cur_token_cnt elif available_token > 0: - # Truncate content to fit available tokens - truncated_message = self._truncate_message_content( + truncated_message = _truncate_message_content( messages[i], available_token ) prefix_messages.append(truncated_message) @@ -209,24 +211,26 @@ def _compress_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]: break # 2. Compress subsequent messages from the tail, some messages may be discarded - messages = messages[len(prefix_messages) :] + # Faster: don't slice in loop, just reuse views + tail_messages = messages[prefix_count:] suffix_messages = [] - for i in range(len(messages) - 1, -1, -1): - cur_token_cnt = self._count_message_tokens(messages[i]) + for i in range(len(tail_messages) - 1, -1, -1): + msg = tail_messages[i] + cur_token_cnt = _count_message_tokens(msg) if cur_token_cnt > 0 and available_token >= cur_token_cnt: - suffix_messages = [messages[i]] + suffix_messages + suffix_messages.append(msg) available_token -= cur_token_cnt elif available_token > 0: - # Truncate content to fit available tokens - truncated_message = self._truncate_message_content( - messages[i], available_token - ) - suffix_messages = [truncated_message] + suffix_messages + truncated_message = _truncate_message_content(msg, available_token) + suffix_messages.append(truncated_message) + # Reverse once to avoid multiple list concatenations + suffix_messages.reverse() return prefix_messages + suffix_messages else: break + suffix_messages.reverse() return prefix_messages + suffix_messages def _truncate_message_content(