diff --git a/Cargo.lock b/Cargo.lock index f77d1be63d..8925ac378b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1983,6 +1983,14 @@ dependencies = [ "serde_json", ] +[[package]] +name = "forge_compact" +version = "0.1.0" +dependencies = [ + "derive_builder 0.20.2", + "pretty_assertions", +] + [[package]] name = "forge_config" version = "0.1.0" diff --git a/crates/forge_compact/Cargo.toml b/crates/forge_compact/Cargo.toml new file mode 100644 index 0000000000..84b9506df5 --- /dev/null +++ b/crates/forge_compact/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "forge_compact" +version.workspace = true +rust-version.workspace = true +edition.workspace = true + +[dependencies] +derive_builder = "0.20.2" + +[dev-dependencies] +pretty_assertions = { workspace = true } diff --git a/crates/forge_compact/src/lib.rs b/crates/forge_compact/src/lib.rs new file mode 100644 index 0000000000..e485b77cad --- /dev/null +++ b/crates/forge_compact/src/lib.rs @@ -0,0 +1,557 @@ +mod util; + +use std::ops::{Deref, RangeInclusive}; + +use util::{deref_messages, replace_range, wrap_messages}; + +pub struct Compaction { + summarize: Box Item>, + threshold: Box bool>, + retain: usize, +} + +pub trait ContextMessage { + fn is_user(&self) -> bool; + fn is_assistant(&self) -> bool; + fn is_system(&self) -> bool; + fn is_toolcall(&self) -> bool; + fn is_toolcall_result(&self) -> bool; +} + +/// A compacted summary that replaces a range of original messages. +struct Summary { + /// The synthesised summary item. + message: I, + /// The original messages that were compacted into this summary. + source: Vec, +} + +pub enum Message { + Summary(Summary), + Original { message: I }, +} + +impl Message { + fn is_summary(&self) -> bool { + matches!(self, Message::Summary(_)) + } + + fn is_original(&self) -> bool { + matches!(self, Message::Original { .. }) + } +} + +impl Deref for Message { + type Target = I; + + fn deref(&self) -> &Self::Target { + match self { + Message::Summary(Summary { message, .. }) => message, + Message::Original { message } => message, + } + } +} + +impl Compaction { + pub fn compact_conversation(&self, messages: Vec) -> Vec { + // Wrap each plain item into Message::Original using the util helper (the + // inverse of deref_messages). + let all: Vec> = wrap_messages(messages); + + // Grow a working window from size 1 up to the full length. At each size we + // attempt to compact the front window; if compaction succeeds the result (a + // shorter vec) is prepended to the remaining tail and we restart from size 1 + // so that the newly inserted summary can participate in further compaction. + // When the threshold is not exceeded for the current window, we drain just + // the first element into `result` and try a window starting at the next + // position. + let mut result: Vec> = Vec::with_capacity(all.len()); + let mut remaining = all; + + while !remaining.is_empty() { + let mut compacted = false; + for size in 1..=remaining.len() { + // Peek at the front window without removing anything yet. + let window: Vec> = remaining[..size] + .iter() + .map(|m| match m { + Message::Original { message } => { + Message::Original { message: message.clone() } + } + Message::Summary(Summary { message, source }) => { + Message::Summary(Summary { + message: message.clone(), + source: source.clone(), + }) + } + }) + .collect(); + + if self.threshold(window.as_slice()) { + // Threshold exceeded — attempt to compact the window. + let summary_count_before = window.iter().filter(|m| m.is_summary()).count(); + let compacted_window = self.compact_complete(window); + let summary_count_after = + compacted_window.iter().filter(|m| m.is_summary()).count(); + if summary_count_after > summary_count_before { + // A new Summary was introduced: replace the front window in + // `remaining` with the summarised version and restart the scan. + remaining.drain(..size); + let mut new_remaining = compacted_window; + new_remaining.extend(remaining.drain(..)); + remaining = new_remaining; + compacted = true; + break; + } + // Threshold triggered but no compactable range found yet — + // keep growing the window. + } else if size == remaining.len() { + // Threshold never triggered for any window size; nothing left + // to compact — flush all remaining to result. + result.extend(remaining.drain(..)); + break; + } + } + if !compacted && remaining.is_empty() { + break; + } + if !compacted { + // The threshold was never satisfied for any window size. + break; + } + } + + result.extend(remaining); + + // Unwrap the Message envelope back to plain items. + result.into_iter().map(|m| m.deref().clone()).collect() + } + + fn threshold(&self, messages: &[Message]) -> bool { + (self.threshold)(deref_messages(messages).as_slice()) + } + + fn summarize(&self, messages: &[Message]) -> Item { + (self.summarize)(deref_messages(messages).as_slice()) + } + + fn find_compact_range(&self, messages: &[Message]) -> Option> { + if messages.is_empty() { + return None; + } + + let length = messages.len(); + + let start = messages + .iter() + .enumerate() + // Skip all summaries + .filter(|i| i.1.is_original()) + .find(|i| i.1.is_assistant()) + .map(|i| i.0)?; + + // Don't compact if there's no assistant message + if start >= length { + return None; + } + + // Calculate the end index based on preservation window + // If we need to preserve all or more messages than we have, there's nothing to + // compact + if self.retain >= length { + return None; + } + + // Use saturating subtraction to prevent potential overflow + let mut end = length.saturating_sub(self.retain).saturating_sub(1); + + // If start > end or end is invalid, don't compact + if start > end || end >= length { + return None; + } + + // Don't break between a tool call and its result + if messages.get(end).is_some_and(|msg| msg.is_toolcall()) { + // If the last message has a tool call, adjust end to include the tool result + // This means either not compacting at all, or reducing the end by 1 + if end == start { + // If start == end and it has a tool call, don't compact + return None; + } else { + // Otherwise reduce end by 1 + return Some(start..=end.saturating_sub(1)); + } + } + + if messages + .get(end) + .is_some_and(|msg| msg.is_toolcall_result()) + && messages + .get(end.saturating_add(1)) + .is_some_and(|msg| msg.is_toolcall_result()) + { + // If the last message is a tool result and the next one is also a tool result, + // we need to adjust the end. + while end >= start + && messages + .get(end) + .is_some_and(|msg| msg.is_toolcall_result()) + { + end = end.saturating_sub(1); + } + end = end.saturating_sub(1); + } + + // Return the sequence only if it has at least one message + if end >= start { + Some(start..=end) + } else { + None + } + } + + fn compact_complete(&self, messages: Vec>) -> Vec> { + if let Some(range) = self.find_compact_range(&messages) { + let source_slice = &messages[*range.start()..=*range.end()]; + let summary = Message::Summary(Summary { + message: self.summarize(source_slice), + source: source_slice.iter().map(|m| m.deref().clone()).collect(), + }); + + replace_range(messages, summary, range) + } else { + messages + } + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + /// A minimal message type for testing `find_compact_range`. + #[derive(Clone, Debug, PartialEq)] + struct TestMsg { + role: char, + } + + impl TestMsg { + fn new(role: char) -> Self { + Self { role } + } + } + + impl ContextMessage for TestMsg { + fn is_user(&self) -> bool { + self.role == 'u' + } + fn is_assistant(&self) -> bool { + self.role == 'a' || self.role == 't' + } + fn is_system(&self) -> bool { + self.role == 's' + } + fn is_toolcall(&self) -> bool { + self.role == 't' + } + fn is_toolcall_result(&self) -> bool { + self.role == 'r' + } + } + + fn compaction(retain: usize) -> Compaction { + Compaction { + summarize: Box::new(|_| TestMsg::new('S')), + threshold: Box::new(|_| true), + retain, + } + } + + /// Build a `Vec>` from a pattern string where each char + /// maps to a role: s=system, u=user, a=assistant, t=toolcall, r=toolcall_result. + fn messages_from(pattern: &str) -> Vec> { + pattern + .chars() + .map(|c| Message::Original { message: TestMsg::new(c) }) + .collect() + } + + /// Returns the pattern string with `[` and `]` inserted around the compacted + /// range, mirroring the helper in `forge_domain`. + fn seq(pattern: &str, retain: usize) -> String { + let c = compaction(retain); + let messages = messages_from(pattern); + let range = c.find_compact_range(&messages); + + let mut result = pattern.to_string(); + if let Some(range) = range { + result.insert(*range.start(), '['); + result.insert(range.end() + 2, ']'); + } + result + } + + #[test] + fn test_sequence_finding() { + // Basic compaction scenarios + assert_eq!(seq("suaaau", 0), "su[aaau]"); + assert_eq!(seq("sua", 0), "su[a]"); + assert_eq!(seq("suauaa", 0), "su[auaa]"); + + // Tool call scenarios + assert_eq!(seq("suttu", 0), "su[ttu]"); + assert_eq!(seq("sutraau", 0), "su[traau]"); + assert_eq!(seq("utrutru", 0), "u[trutru]"); + assert_eq!(seq("uttarru", 0), "u[ttarru]"); + assert_eq!(seq("urru", 0), "urru"); + assert_eq!(seq("uturu", 0), "u[turu]"); + + // Preservation window scenarios + assert_eq!(seq("suaaaauaa", 0), "su[aaaauaa]"); + assert_eq!(seq("suaaaauaa", 3), "su[aaaa]uaa"); + assert_eq!(seq("suaaaauaa", 5), "su[aa]aauaa"); + assert_eq!(seq("suaaaauaa", 8), "suaaaauaa"); + assert_eq!(seq("suauaaa", 0), "su[auaaa]"); + assert_eq!(seq("suauaaa", 2), "su[aua]aa"); + assert_eq!(seq("suauaaa", 1), "su[auaa]a"); + + // Tool call atomicity preservation + assert_eq!(seq("sutrtrtra", 0), "su[trtrtra]"); + assert_eq!(seq("sutrtrtra", 1), "su[trtrtr]a"); + assert_eq!(seq("sutrtrtra", 2), "su[trtr]tra"); + + // Parallel tool calls + assert_eq!(seq("sutrtrtrra", 2), "su[trtr]trra"); + assert_eq!(seq("sutrtrtrra", 3), "su[trtr]trra"); + assert_eq!(seq("sutrrrrrra", 2), "sutrrrrrra"); + + // Conversation patterns + assert_eq!(seq("suauauaua", 0), "su[auauaua]"); + assert_eq!(seq("suauauaua", 2), "su[auaua]ua"); + assert_eq!(seq("suauauaua", 6), "su[a]uauaua"); + assert_eq!(seq("sutruaua", 0), "su[truaua]"); + assert_eq!(seq("sutruaua", 3), "su[tru]aua"); + + // Special cases + assert_eq!(seq("saua", 0), "s[aua]"); + assert_eq!(seq("suaut", 0), "su[au]t"); + + // Edge cases + assert_eq!(seq("", 0), ""); + assert_eq!(seq("s", 0), "s"); + assert_eq!(seq("sua", 3), "sua"); + assert_eq!(seq("ut", 0), "ut"); + assert_eq!(seq("suuu", 0), "suuu"); + assert_eq!(seq("ut", 1), "ut"); + assert_eq!(seq("ua", 0), "u[a]"); + } + + /// Builds a `Vec` from a pattern string. + fn items_from(pattern: &str) -> Vec { + pattern.chars().map(TestMsg::new).collect() + } + + /// Runs `compact_conversation` and returns the result as a pattern string. + fn compact(pattern: &str, retain: usize) -> String { + let c = compaction(retain); + let messages = items_from(pattern); + c.compact_conversation(messages) + .iter() + .map(|m| m.role) + .collect() + } + + /// Like `compact` but uses a threshold that only triggers when there are more + /// than `min` items, letting us test the no-compaction path too. + fn compact_with_min(pattern: &str, retain: usize, min: usize) -> String { + let c = Compaction { + summarize: Box::new(|_| TestMsg::new('S')), + threshold: Box::new(move |msgs| msgs.len() > min), + retain, + }; + c.compact_conversation(items_from(pattern)) + .iter() + .map(|m| m.role) + .collect() + } + + #[test] + fn test_compact_conversation_basic() { + // A simple assistant message is summarised into 'S'. + assert_eq!(compact("sua", 0), "suS"); + } + + #[test] + fn test_compact_conversation_multiple_turns_compacted() { + // Each pass compacts a range of messages. With always-true threshold and + // retain=0 the algorithm progressively summarises until no original + // assistant messages remain; the exact number of summary tokens can vary. + let result = compact("suaaau", 0); + // All original assistant turns have been summarised — no 'a' remains. + assert!( + !result.contains('a'), + "expected no remaining assistant turns, got: {result}" + ); + // System and preceding user message are always kept. + assert!( + result.starts_with("su"), + "expected result to start with 'su', got: {result}" + ); + } + + #[test] + fn test_compact_conversation_preserves_system_and_user() { + // System and leading user messages that precede any assistant message are + // never included in the compact range. + assert_eq!(compact("su", 0), "su"); + assert_eq!(compact("suuu", 0), "suuu"); + } + + #[test] + fn test_compact_conversation_retain_window() { + // With retain=3 the last 3 messages are kept verbatim; earlier ones are + // summarised. Use a threshold that fires once the full window grows past 3 + // to get a predictable single-summary result. + let result = compact_with_min("suaaaauaa", 3, 3); + // The preserved tail is the last 3 messages: "uaa". + assert!( + result.ends_with("uaa"), + "expected tail 'uaa', got: {result}" + ); + // At least one summary is present. + assert!( + result.contains('S'), + "expected a summary 'S', got: {result}" + ); + } + + #[test] + fn test_compact_conversation_no_compaction_when_below_threshold() { + // threshold requires > 4 items; a 3-item conversation must pass through + // unchanged. + assert_eq!(compact_with_min("sua", 0, 4), "sua"); + assert_eq!(compact_with_min("suuu", 0, 4), "suuu"); + } + + #[test] + fn test_compact_conversation_empty() { + assert_eq!(compact("", 0), ""); + } + + #[test] + fn test_compact_conversation_tool_calls_preserved_atomically() { + // A tool-call ('t') and its result ('r') must never be split across a + // summary boundary. Use a threshold that fires once the window is large + // enough to contain the tool pair. + let result = compact_with_min("sutrua", 2, 3); + // The preserved tail (retain=2) must be "ua". + assert!(result.ends_with("ua"), "expected tail 'ua', got: {result}"); + // Tool calls and their results should have been summarised. + assert!( + result.contains('S'), + "expected a summary 'S', got: {result}" + ); + // No bare tool call or result should sit at the boundary. + assert!( + !result.contains('t') || !result.ends_with('t'), + "tool call must not be at boundary, got: {result}" + ); + } + + /// Verifies the incremental-addition invariant for cache-key stability: + /// + /// Assume `n` messages compact range `i..=i+j` into a summary `S`. When a new + /// message is appended (making `n+1` total), the algorithm must: + /// 1. Produce one more output message than the base case: + /// `output(n+1).len() == output(n).len() + 1`. + /// 2. Produce exactly one summary in each case (no re-summarisation of an existing + /// summary into another summary). + /// 3. Call the summarizer with a source slice that is a prefix-extension of the + /// base source: the same original messages plus one more. + /// + /// Concretely: `"suaua"` with threshold `> 4` fires once and compacts `[aua]` → `"suS"`. + /// `"suauau"` with the same threshold fires once and compacts `[auau]` → `"suSu"`. ✓ + #[test] + fn test_compact_conversation_cache_key_stability() { + use std::cell::RefCell; + use std::rc::Rc; + + // Track every source slice passed to `summarize`. + let calls: Rc>>> = Rc::new(RefCell::new(Vec::new())); + let calls_clone = Rc::clone(&calls); + + // threshold > 4: fires for windows of 5+. With "suaua" (5) the full slice + // exceeds the threshold exactly once. With "suauau" (6) the first window that + // exceeds the threshold is also the full slice, so again exactly one compaction. + let c = Compaction { + summarize: Box::new(move |msgs: &[&TestMsg]| { + calls_clone + .borrow_mut() + .push(msgs.iter().map(|m| m.role).collect()); + TestMsg::new('S') + }), + threshold: Box::new(|msgs| msgs.len() > 4), + retain: 0, + }; + + // --- Base: n = 5 messages "suaua" --- + // Window grows to size 5; threshold fires; compact range [a,u,a] → S. + // Remaining becomes [s,u,S]; threshold needs > 4 but only 3 items → no more compaction. + // Result: "suS" + let base: Vec = items_from("suaua"); + let result_base = c.compact_conversation(base.clone()); + let base_pattern: String = result_base.iter().map(|m| m.role).collect(); + assert_eq!( + base_pattern, "suS", + "base compaction 'suaua' must yield 'suS', got: {base_pattern}" + ); + let first_call_sources: Vec = { + let b = calls.borrow(); + assert_eq!(b.len(), 1, "expected exactly 1 summarize call for base, got {}", b.len()); + b[0].clone() + }; + + // --- Extended: n+1 = 6 messages "suauau" --- + // Window grows to size 5: [s,u,a,u,a] → threshold fires; compact [a,u,a] at 2..=4 → S. + // Remaining: [s,u,S,u]. Threshold needs > 4; only 4 items → no more compaction. + // Result: "suSu" + let mut extended = base; + extended.push(TestMsg::new('u')); + calls.borrow_mut().clear(); + let result_extended = c.compact_conversation(extended); + let extended_pattern: String = result_extended.iter().map(|m| m.role).collect(); + assert_eq!( + extended_pattern, "suSu", + "extended compaction 'suauau' must yield 'suSu', got: {extended_pattern}" + ); + let second_call_sources: Vec = { + let b = calls.borrow(); + assert_eq!( + b.len(), + 1, + "expected exactly 1 summarize call for extended, got {}", + b.len() + ); + b[0].clone() + }; + + // Output-length invariant: adding one message produces one more output item. + assert_eq!( + result_extended.len(), + result_base.len() + 1, + "output(n+1).len() must equal output(n).len() + 1; \ + base={base_pattern}, extended={extended_pattern}" + ); + + // Source-prefix invariant: the extended source starts with the same messages + // as the base source — the algorithm compacts the same prefix plus one new item. + assert_eq!( + &second_call_sources[..first_call_sources.len()], + first_call_sources.as_slice(), + "the extended summarize source must start with the same messages as the base source; \ + base={first_call_sources:?}, extended={second_call_sources:?}" + ); + } +} diff --git a/crates/forge_compact/src/util.rs b/crates/forge_compact/src/util.rs new file mode 100644 index 0000000000..5858bf4eec --- /dev/null +++ b/crates/forge_compact/src/util.rs @@ -0,0 +1,121 @@ +use std::ops::{Deref, RangeInclusive}; + +use crate::Message; + +/// Wraps each item in a `Vec` into `Message::Original`, ready for internal processing. +/// +/// This is the inverse of `deref_messages`: it lifts plain items into the `Message` +/// wrapper so the compaction algorithm can track whether each entry is an original +/// message or a synthesised summary. +pub fn wrap_messages(items: Vec) -> Vec> { + items + .into_iter() + .map(|m| Message::Original { message: m }) + .collect() +} + +/// Collects references to the inner values of a slice of `Deref`-able wrappers. +/// +/// Useful for converting a `&[Message]` to a `Vec<&T>` before passing to callbacks +/// that operate on bare item references. +pub fn deref_messages(messages: &[W]) -> Vec<&W::Target> { + messages.iter().map(|m| m.deref()).collect() +} + +/// Replaces all items within `range` in `items` with the single `replacement` item. +/// +/// Returns a new `Vec` containing the elements before the range, the replacement, and the +/// elements after the range. Returns `items` unchanged if the range is out of bounds. +pub fn replace_range( + items: Vec, + replacement: Item, + range: RangeInclusive, +) -> Vec { + let start = *range.start(); + let end = *range.end(); + + if items.is_empty() || start >= items.len() || end >= items.len() { + return items; + } + + let mut result = Vec::with_capacity(items.len() - (end - start)); + let mut iter = items.into_iter(); + + result.extend(iter.by_ref().take(start)); + result.push(replacement); + iter.by_ref().nth(end - start); // skip the items covered by the range + result.extend(iter); + + result +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::replace_range; + + #[test] + fn test_replace_range_middle() { + let items = vec![1, 2, 3, 4, 5]; + let actual = replace_range(items, 99, 1..=3); + let expected = vec![1, 99, 5]; + assert_eq!(actual, expected); + } + + #[test] + fn test_replace_range_start() { + let items = vec![1, 2, 3, 4, 5]; + let actual = replace_range(items, 99, 0..=2); + let expected = vec![99, 4, 5]; + assert_eq!(actual, expected); + } + + #[test] + fn test_replace_range_end() { + let items = vec![1, 2, 3, 4, 5]; + let actual = replace_range(items, 99, 3..=4); + let expected = vec![1, 2, 3, 99]; + assert_eq!(actual, expected); + } + + #[test] + fn test_replace_range_single_element() { + let items = vec![1, 2, 3]; + let actual = replace_range(items, 99, 1..=1); + let expected = vec![1, 99, 3]; + assert_eq!(actual, expected); + } + + #[test] + fn test_replace_range_entire_vec() { + let items = vec![1, 2, 3]; + let actual = replace_range(items, 99, 0..=2); + let expected = vec![99]; + assert_eq!(actual, expected); + } + + #[test] + fn test_replace_range_empty_vec() { + let items: Vec = vec![]; + let actual = replace_range(items, 99, 0..=0); + let expected: Vec = vec![]; + assert_eq!(actual, expected); + } + + #[test] + fn test_replace_range_start_out_of_bounds() { + let items = vec![1, 2, 3]; + let actual = replace_range(items, 99, 5..=6); + let expected = vec![1, 2, 3]; + assert_eq!(actual, expected); + } + + #[test] + fn test_replace_range_end_out_of_bounds() { + let items = vec![1, 2, 3]; + let actual = replace_range(items, 99, 1..=10); + let expected = vec![1, 2, 3]; + assert_eq!(actual, expected); + } +}