Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import numpy as np

from astrbot.core.exceptions import KnowledgeBaseUploadError


class EmbeddingStorage:
def __init__(self, dimension: int, path: str | None = None) -> None:
Expand All @@ -16,6 +18,20 @@ def __init__(self, dimension: int, path: str | None = None) -> None:
self.index = None
if path and os.path.exists(path):
self.index = faiss.read_index(path)
actual_dimension = self.index.d
if actual_dimension != dimension:
raise KnowledgeBaseUploadError(
stage="embedding",
user_message=(
"向量化失败:知识库索引维度与当前嵌入模型维度不一致"
f"(索引维度 {actual_dimension},当前模型配置维度 {dimension})。"
"请使用原嵌入模型,或删除并重建知识库索引。"
),
details={
"index_dimension": actual_dimension,
"provider_dimension": dimension,
},
)
else:
base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index)
Expand Down
23 changes: 16 additions & 7 deletions astrbot/core/db/vec_db/faiss_impl/vec_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,22 @@ async def insert_batch(

start = time.time()
logger.debug(f"Generating embeddings for {len(contents)} contents...")
vectors = await self.embedding_provider.get_embeddings_batch(
contents,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
)
try:
vectors = await self.embedding_provider.get_embeddings_batch(
contents,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
)
except KnowledgeBaseUploadError:
raise
except Exception as exc:
raise KnowledgeBaseUploadError(
stage="embedding",
user_message=f"向量化失败:批量生成嵌入向量时出错。{exc}",
details={"content_count": content_count},
) from exc
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
end = time.time()
logger.debug(
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds.",
Expand Down
35 changes: 23 additions & 12 deletions astrbot/core/knowledge_base/kb_mgr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

from astrbot.core import logger
from astrbot.core.exceptions import KnowledgeBaseUploadError
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.utils.astrbot_path import get_astrbot_knowledge_base_path

Expand Down Expand Up @@ -198,6 +199,19 @@ async def update_kb(
}
previous_init_error = kb_helper.init_error

def rollback_state() -> None:
kb.kb_name = previous_state["kb_name"]
kb.description = previous_state["description"]
kb.emoji = previous_state["emoji"]
kb.embedding_provider_id = previous_state["embedding_provider_id"]
kb.rerank_provider_id = previous_state["rerank_provider_id"]
kb.chunk_size = previous_state["chunk_size"]
kb.chunk_overlap = previous_state["chunk_overlap"]
kb.top_k_dense = previous_state["top_k_dense"]
kb.top_k_sparse = previous_state["top_k_sparse"]
kb.top_m_final = previous_state["top_m_final"]
kb_helper.init_error = previous_init_error

if kb_name is not None:
kb.kb_name = kb_name
if description is not None:
Expand Down Expand Up @@ -229,24 +243,21 @@ async def update_kb(

try:
await new_helper.initialize()
except KnowledgeBaseUploadError as e:
rollback_state()
logger.error(
f"知识库 {kb.kb_name}({kb.kb_id}) 重新初始化失败,继续使用旧实例: {e}",
exc_info=True,
)
raise
except Exception as e:
# Roll back in-memory settings and keep current helper available.
kb.kb_name = previous_state["kb_name"]
kb.description = previous_state["description"]
kb.emoji = previous_state["emoji"]
kb.embedding_provider_id = previous_state["embedding_provider_id"]
kb.rerank_provider_id = previous_state["rerank_provider_id"]
kb.chunk_size = previous_state["chunk_size"]
kb.chunk_overlap = previous_state["chunk_overlap"]
kb.top_k_dense = previous_state["top_k_dense"]
kb.top_k_sparse = previous_state["top_k_sparse"]
kb.top_m_final = previous_state["top_m_final"]
kb_helper.init_error = previous_init_error
rollback_state()
logger.error(
f"知识库 {kb.kb_name}({kb.kb_id}) 重新初始化失败,继续使用旧实例: {e}",
exc_info=True,
)
return kb_helper
raise ValueError(f"知识库重新初始化失败:{e}") from e
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

同上,将通用异常包装为 ValueError 可能会掩盖错误的本质。虽然这里是为了向 API 调用方暴露错误,但建议考虑使用更具体的异常类型(例如将其包装为 KnowledgeBaseUploadError),或者至少确保错误消息足够清晰且不丢失原始异常的上下文。另外,由于此处处理逻辑与前文高度相似,建议将其重构为共享的助手函数以避免代码重复。

References
  1. When implementing similar functionality for different cases (e.g., direct vs. quoted attachments), refactor the logic into a shared helper function to avoid code duplication.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里关于重复逻辑抽取成共享 helper 的方向我认同,不过这次我没有继续做这一步重构。当前这次 PR 的目标是尽量以最小 diff 修复 #7794 对应的两个核心问题:索引维度校验,以及初始化失败时不要伪成功。目前已经先修复了更关键的一点:KnowledgeBaseUploadError 不再在 update_kb() 中被降级为 ValueError,从而保留了结构化错误信息。至于进一步抽取共享 helper,我倾向于放在后续单独的整理里处理,避免这次 PR scope 继续扩大。


async with self.kb_db.get_db() as session:
session.add(kb)
Expand Down
80 changes: 75 additions & 5 deletions tests/unit/test_faiss_vec_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from astrbot.core.db.vec_db.faiss_impl.embedding_storage import EmbeddingStorage
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.exceptions import KnowledgeBaseUploadError

Expand All @@ -22,9 +23,7 @@ async def test_insert_batch_skips_empty_contents() -> None:


@pytest.mark.asyncio
async def test_insert_batch_raises_friendly_error_for_embedding_count_mismatch() -> (
None
):
async def test_insert_batch_raises_friendly_error_for_embedding_count_mismatch() -> None:
vec_db = FaissVecDB.__new__(FaissVecDB)
vec_db.embedding_provider = AsyncMock()
vec_db.embedding_provider.get_embeddings_batch.return_value = [[0.1, 0.2]]
Expand All @@ -40,7 +39,78 @@ async def test_insert_batch_raises_friendly_error_for_embedding_count_mismatch()
ids=["doc-1", "doc-2"],
)

assert "向量化失败" in str(exc_info.value)
assert exc_info.value.stage == "embedding"
assert "期望 2,实际 1" in str(exc_info.value)
assert exc_info.value.details["expected_contents"] == 2
assert exc_info.value.details["actual_vectors"] == 1
vec_db.document_storage.insert_documents_batch.assert_not_awaited()
vec_db.embedding_storage.insert_batch.assert_not_awaited()


@pytest.mark.asyncio
async def test_insert_batch_wraps_embedding_batch_failures_as_embedding_error() -> None:
    """A generic provider failure during batch embedding is surfaced as a
    stage-tagged ``KnowledgeBaseUploadError`` and nothing is persisted.
    """
    db = FaissVecDB.__new__(FaissVecDB)
    db.embedding_provider = AsyncMock()
    db.document_storage = AsyncMock()
    db.embedding_storage = AsyncMock()
    db.embedding_storage.dimension = 2
    # Simulate the provider blowing up mid-batch with an arbitrary error.
    db.embedding_provider.get_embeddings_batch.side_effect = Exception("rate limit")

    with pytest.raises(KnowledgeBaseUploadError) as exc_info:
        await FaissVecDB.insert_batch(
            db,
            contents=["chunk-1"],
            metadatas=[{}],
            ids=["doc-1"],
        )

    err = exc_info.value
    assert err.stage == "embedding"
    rendered = str(err)
    assert "批量生成嵌入向量时出错" in rendered
    assert "rate limit" in rendered
    # No storage writes may happen once embedding generation has failed.
    db.document_storage.insert_documents_batch.assert_not_awaited()
    db.embedding_storage.insert_batch.assert_not_awaited()


def test_embedding_storage_rejects_existing_index_dimension_mismatch() -> None:
    """Opening a persisted FAISS index whose dimension differs from the
    configured provider dimension must raise a structured embedding-stage
    error instead of silently loading the index.
    """
    fake_index = MagicMock(d=768)

    exists_ctx = patch(
        "astrbot.core.db.vec_db.faiss_impl.embedding_storage.os.path.exists",
        return_value=True,
    )
    read_ctx = patch(
        "astrbot.core.db.vec_db.faiss_impl.embedding_storage.faiss.read_index",
        return_value=fake_index,
    )
    with exists_ctx, read_ctx, pytest.raises(KnowledgeBaseUploadError) as exc_info:
        EmbeddingStorage(1536, "existing-index.faiss")

    err = exc_info.value
    assert err.stage == "embedding"
    assert "知识库索引维度与当前嵌入模型维度不一致" in str(err)
    assert err.details == {
        "index_dimension": 768,
        "provider_dimension": 1536,
    }


def test_embedding_storage_accepts_existing_index_dimension_match() -> None:
    """A persisted FAISS index whose dimension matches the configured one is
    loaded and reused as-is.
    """
    fake_index = MagicMock(d=768)

    exists_ctx = patch(
        "astrbot.core.db.vec_db.faiss_impl.embedding_storage.os.path.exists",
        return_value=True,
    )
    read_ctx = patch(
        "astrbot.core.db.vec_db.faiss_impl.embedding_storage.faiss.read_index",
        return_value=fake_index,
    )
    with exists_ctx, read_ctx:
        storage = EmbeddingStorage(768, "existing-index.faiss")

    # The on-disk index is adopted directly; the configured dimension sticks.
    assert storage.index is fake_index
    assert storage.dimension == 768
Loading
Loading