From f15553fc82a195f313d276499ec245b33f74657b Mon Sep 17 00:00:00 2001 From: Shahmir Varqha Date: Fri, 26 Jun 2026 11:44:24 +0800 Subject: [PATCH] add open chats atom, bust cache --- frontend/src/__tests__/chat-utils.test.ts | 143 +++++++++++++++++- .../components/chat/chat-history-popover.tsx | 2 +- frontend/src/components/chat/chat-panel.tsx | 18 +-- frontend/src/components/chat/chat-tabs.tsx | 95 ++++++++++++ frontend/src/core/ai/state.ts | 125 +++++++++++++-- 5 files changed, 356 insertions(+), 27 deletions(-) create mode 100644 frontend/src/components/chat/chat-tabs.tsx diff --git a/frontend/src/__tests__/chat-utils.test.ts b/frontend/src/__tests__/chat-utils.test.ts index ab880417469..8a8a179e08b 100644 --- a/frontend/src/__tests__/chat-utils.test.ts +++ b/frontend/src/__tests__/chat-utils.test.ts @@ -4,9 +4,28 @@ import type { UIMessage } from "ai"; import { describe, expect, it } from "vitest"; import { Maps } from "@/utils/maps"; import { replaceMessagesInChat } from "../core/ai/chat-utils"; -import type { Chat, ChatId, ChatState } from "../core/ai/state"; +import { + closeChatTab, + type Chat, + type ChatId, + type ChatState, + MAX_STORED_CHATS, + openChatTab, + pruneChats, +} from "../core/ai/state"; const CHAT_1 = "chat-1" as ChatId; +const CHAT_2 = "chat-2" as ChatId; + +function makeChat(id: number): Chat { + return { + id: `chat-${id}` as ChatId, + title: `Chat ${id}`, + messages: [{ id: `m-${id}`, role: "user", parts: [] }], + createdAt: id, + updatedAt: id, + }; +} function asMap(list: Iterable) { return Maps.keyBy(list, (c) => c.id); @@ -30,6 +49,7 @@ describe("replaceMessagesInChat", () => { }, ]), activeChatId: CHAT_1, + openChatIds: [CHAT_1], }; it("replaces messages in a chat", () => { @@ -61,3 +81,124 @@ describe("replaceMessagesInChat", () => { expect(result).toEqual(mockChatState); }); }); + +describe("openChatTab", () => { + const baseState: ChatState = { + chats: asMap([ + { + id: CHAT_1, + title: "Chat 1", + messages: [], + createdAt: 1000, + updatedAt: 1000, + }, + { + id: CHAT_2, + title: "Chat 2", + messages: [], + createdAt: 2000, + updatedAt: 2000, + }, + ]), + activeChatId: null, + openChatIds: [], + }; + + it("opens a chat tab and sets it active", () => { + const result = openChatTab(baseState, CHAT_1); + expect(result.activeChatId).toBe(CHAT_1); + expect(result.openChatIds).toEqual([CHAT_1]); + }); + + it("does not duplicate open tabs", () => { + const state = { ...baseState, openChatIds: [CHAT_1], activeChatId: CHAT_2 }; + const result = openChatTab(state, CHAT_1); + expect(result.openChatIds).toEqual([CHAT_1]); + expect(result.activeChatId).toBe(CHAT_1); + }); +}); + +describe("closeChatTab", () => { + const baseState: ChatState = { + chats: asMap([ + { + id: CHAT_1, + title: "Chat 1", + messages: [{ id: "m1", role: "user", parts: [] }], + createdAt: 1000, + updatedAt: 1000, + }, + { + id: CHAT_2, + title: "Chat 2", + messages: [{ id: "m2", role: "user", parts: [] }], + createdAt: 2000, + updatedAt: 2000, + }, + ]), + activeChatId: CHAT_2, + openChatIds: [CHAT_1, CHAT_2], + }; + + it("hides a tab without deleting the chat", () => { + const result = closeChatTab(baseState, CHAT_1); + expect(result.openChatIds).toEqual([CHAT_2]); + expect(result.activeChatId).toBe(CHAT_2); + expect(result.chats.has(CHAT_1)).toBe(true); + }); + + it("activates a neighbor when closing the active tab", () => { + const result = closeChatTab(baseState, CHAT_2); + expect(result.openChatIds).toEqual([CHAT_1]); + expect(result.activeChatId).toBe(CHAT_1); + expect(result.chats.has(CHAT_2)).toBe(true); + }); + + it("clears active chat when closing the last tab", () => { + const state = { ...baseState, openChatIds: [CHAT_2], activeChatId: CHAT_2 }; + const result = closeChatTab(state, CHAT_2); + expect(result.openChatIds).toEqual([]); + expect(result.activeChatId).toBeNull(); + expect(result.chats.has(CHAT_2)).toBe(true); + }); + + it("returns unchanged state when the tab is not open", () => { + const state = { ...baseState, openChatIds: [CHAT_2], activeChatId: CHAT_2 }; + const result = closeChatTab(state, CHAT_1); + expect(result).toBe(state); + }); +}); + +describe("pruneChats", () => { + it("returns the same map when under the cap", () => { + const chats = Maps.keyBy( + Array.from({ length: 5 }, (_, i) => makeChat(i)), + (c) => c.id, + ); + expect(pruneChats(chats, [])).toBe(chats); + }); + + it("keeps the most recently updated chats up to the cap", () => { + const chats = Maps.keyBy( + Array.from({ length: MAX_STORED_CHATS + 5 }, (_, i) => makeChat(i)), + (c) => c.id, + ); + const result = pruneChats(chats, []); + expect(result.size).toBe(MAX_STORED_CHATS); + // The 5 oldest (updatedAt 0..4) should be evicted. + expect(result.has("chat-0" as ChatId)).toBe(false); + expect(result.has("chat-4" as ChatId)).toBe(false); + expect(result.has("chat-5" as ChatId)).toBe(true); + }); + + it("never evicts protected (open) chats, even when old", () => { + const chats = Maps.keyBy( + Array.from({ length: MAX_STORED_CHATS + 5 }, (_, i) => makeChat(i)), + (c) => c.id, + ); + const oldId = "chat-0" as ChatId; + const result = pruneChats(chats, [oldId]); + expect(result.has(oldId)).toBe(true); + expect(result.size).toBe(MAX_STORED_CHATS + 1); + }); +}); diff --git a/frontend/src/components/chat/chat-history-popover.tsx b/frontend/src/components/chat/chat-history-popover.tsx index ab3ffa54966..0a38ddab6fb 100644 --- a/frontend/src/components/chat/chat-history-popover.tsx +++ b/frontend/src/components/chat/chat-history-popover.tsx @@ -21,7 +21,7 @@ import { groupChatsByDate } from "./chat-history-utils"; interface ChatHistoryPopoverProps { activeChatId: ChatId | undefined; - setActiveChat: (id: ChatId | null) => void; + setActiveChat: (chatId: ChatId | null) => void; } export const ChatHistoryPopover: React.FC = ({ diff --git a/frontend/src/components/chat/chat-panel.tsx b/frontend/src/components/chat/chat-panel.tsx index 6167b2465f4..49c1e097dee 100644 --- a/frontend/src/components/chat/chat-panel.tsx +++ b/frontend/src/components/chat/chat-panel.tsx @@ -42,6 +42,7 @@ import { AiModelId } from "@/core/ai/ids/ids"; import { useStagedAICellsActions } from "@/core/ai/staged-cells"; import { activeChatAtom, + addChatAndOpenTab, type Chat, type ChatId, chatStateAtom, @@ -84,6 +85,7 @@ import { } from "./chat-components"; import { renderUIMessage } from "./chat-display"; import { ChatHistoryPopover } from "./chat-history-popover"; +import { ChatTabs } from "./chat-tabs"; import { convertToFileUIPart, generateChatTitle, @@ -101,7 +103,7 @@ const DEFAULT_MODE = "manual"; interface ChatHeaderProps { onNewChat: () => void; activeChatId: ChatId | undefined; - setActiveChat: (id: ChatId | null) => void; + setActiveChat: (chatId: ChatId | null) => void; } const ChatHeader: React.FC = ({ @@ -613,17 +615,8 @@ const ChatPanelBody = () => { updatedAt: now, }; - // Create new chat and set as active - setChatState((prev) => { - const newChats = new Map(prev.chats); - newChats.set(newChat.id, newChat); - const newState = { - ...prev, - chats: newChats, - activeChatId: newChat.id, - }; - return newState; - }); + // Create new chat, open it as a tab, and set as active + setChatState((prev) => addChatAndOpenTab(prev, newChat)); const fileParts = initialAttachments && initialAttachments.length > 0 @@ -783,6 +776,7 @@ const ChatPanelBody = () => { activeChatId={activeChat?.id} setActiveChat={setActiveChat} /> +
void; + onClose: (chatId: ChatId) => void; +} + +const ChatTab = memo(({ chat, isActive, onSelect, onClose }) => { + return ( +
onSelect(chat.id)} + > + + {chat.title} + + +
+ ); +}); +ChatTab.displayName = "ChatTab"; + +export const ChatTabs = memo(() => { + const [chatState, setChatState] = useAtom(chatStateAtom); + const setActiveChat = useSetAtom(activeChatAtom); + + const openChats = useMemo(() => { + return chatState.openChatIds + .map((id) => chatState.chats.get(id)) + .filter((chat): chat is Chat => chat !== undefined); + }, [chatState.chats, chatState.openChatIds]); + + const handleSelectChat = useEvent((chatId: ChatId) => { + setActiveChat(chatId); + }); + + const handleCloseChat = useEvent((chatId: ChatId) => { + setChatState((prev) => closeChatTab(prev, chatId)); + }); + + if (openChats.length === 0) { + return null; + } + + return ( +
+
+ {openChats.map((chat) => ( + + ))} +
+
+ ); +}); +ChatTabs.displayName = "ChatTabs"; diff --git a/frontend/src/core/ai/state.ts b/frontend/src/core/ai/state.ts index 7aef97b0d39..0ad09142e43 100644 --- a/frontend/src/core/ai/state.ts +++ b/frontend/src/core/ai/state.ts @@ -9,7 +9,8 @@ import { adaptForLocalStorage, jotaiJsonStorage } from "@/utils/storage/jotai"; import type { TypedString } from "@/utils/typed"; import type { CellId } from "../cells/ids"; -const KEY = "marimo:ai:chatState:v5"; +const KEY = "marimo:ai:chatState:v6"; +export const MAX_STORED_CHATS = 25; export type ChatId = TypedString<"ChatId">; @@ -48,6 +49,8 @@ export interface Chat { export interface ChatState { chats: Map; activeChatId: ChatId | null; + /** Chat ids with an open tab, in left-to-right order. */ + openChatIds: ChatId[]; } function removeEmptyChats(chatState: Map): Map { @@ -64,21 +67,115 @@ function removeEmptyChats(chatState: Map): Map { return result; } +interface SerializableChatState { + chats: [ChatId, Chat][]; + activeChatId: ChatId | null; + openChatIds?: ChatId[]; +} + +function sanitizeOpenChatIds( + chats: Map, + openChatIds: ChatId[], +): ChatId[] { + return openChatIds.filter((id) => chats.has(id)); +} + +/** + * Keep the most recently updated chats, plus any chat in + * `protectedIds` (i.e. with an open tab) regardless of age. + */ +export function pruneChats( + chats: Map, + protectedIds: Iterable, +): Map { + if (chats.size <= MAX_STORED_CHATS) { + return chats; + } + const protectedSet = new Set(protectedIds); + const byRecency = [...chats.values()].toSorted( + (a, b) => b.updatedAt - a.updatedAt, + ); + const kept = new Map(); + for (const chat of byRecency) { + if (kept.size < MAX_STORED_CHATS || protectedSet.has(chat.id)) { + kept.set(chat.id, chat); + } + } + return kept; +} + +export function openChatTab(chatState: ChatState, chatId: ChatId): ChatState { + if (!chatState.chats.has(chatId)) { + return chatState; + } + const openChatIds = chatState.openChatIds.includes(chatId) + ? chatState.openChatIds + : [...chatState.openChatIds, chatId]; + return { + ...chatState, + openChatIds, + activeChatId: chatId, + }; +} + +export function closeChatTab(chatState: ChatState, chatId: ChatId): ChatState { + const closedIndex = chatState.openChatIds.indexOf(chatId); + if (closedIndex === -1) { + return chatState; + } + + const openChatIds = chatState.openChatIds.filter((id) => id !== chatId); + + // When closing the active tab, fall back to the neighbor at the same index. + const activeChatId = + chatState.activeChatId === chatId + ? (openChatIds[Math.min(closedIndex, openChatIds.length - 1)] ?? null) + : chatState.activeChatId; + + return { + ...chatState, + openChatIds, + activeChatId, + }; +} + +export function addChatAndOpenTab(chatState: ChatState, chat: Chat): ChatState { + const chats = new Map(chatState.chats); + chats.set(chat.id, chat); + return openChatTab({ ...chatState, chats }, chat.id); +} + export const chatStateAtom = atomWithStorage( KEY, { chats: new Map(), activeChatId: null, + openChatIds: [], }, adaptForLocalStorage({ - toSerializable: (value: ChatState) => ({ - chats: [...removeEmptyChats(value.chats).entries()], - activeChatId: value.activeChatId, - }), - fromSerializable: (value) => ({ - chats: new Map(value.chats), - activeChatId: value.activeChatId, - }), + toSerializable: (value: ChatState) => { + const chats = pruneChats( + removeEmptyChats(value.chats), + value.openChatIds, + ); + return { + chats: [...chats.entries()], + activeChatId: value.activeChatId, + openChatIds: sanitizeOpenChatIds(chats, value.openChatIds), + }; + }, + fromSerializable: (value: SerializableChatState) => { + const chats = new Map(value.chats); + const openChatIds = sanitizeOpenChatIds( + chats, + value.openChatIds ?? (value.activeChatId ? [value.activeChatId] : []), + ); + return { + chats, + activeChatId: value.activeChatId, + openChatIds, + }; + }, }), ); @@ -90,10 +187,12 @@ export const activeChatAtom = atom( } return state.chats.get(state.activeChatId); }, + // oxlint-disable-next-line marimo/prefer-object-params (_get, set, chatId: ChatId | null) => { - set(chatStateAtom, (prev) => ({ - ...prev, - activeChatId: chatId, - })); + set(chatStateAtom, (prev) => + chatId === null + ? { ...prev, activeChatId: null } + : openChatTab(prev, chatId), + ); }, );