diff --git a/python/tests/detail/test_collection_dql.py b/python/tests/detail/test_collection_dql.py index 8eb04e316..a49a446b2 100644 --- a/python/tests/detail/test_collection_dql.py +++ b/python/tests/detail/test_collection_dql.py @@ -827,6 +827,7 @@ def test_query_multivector_rrf(self, full_collection: Collection, doc_num): rrf_reranker = RrfReRanker() multi_query_result = full_collection.query( multi_query_vectors, + topk=3, reranker=rrf_reranker, ) assert len(multi_query_result) > 0, ( @@ -876,8 +877,11 @@ def test_query_multivector_weighted( batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert") doc_fields, doc_vectors = generate_vectordict_random(full_collection.schema) - weight_list = [weights[v] for v in DEFAULT_VECTOR_FIELD_NAME.values()] - weighted_reranker = WeightedReRanker(weights=weight_list) + # Weights are positional, aligned with the multi_query_vectors order + # (DEFAULT_VECTOR_FIELD_NAME insertion order). Metric normalization is + # automatic from each field's schema. + weights_list = [weights[v] for v in DEFAULT_VECTOR_FIELD_NAME.values()] + weighted_reranker = WeightedReRanker(weights_list) single_query_results = {} for k, v in DEFAULT_VECTOR_FIELD_NAME.items(): @@ -894,6 +898,7 @@ def test_query_multivector_weighted( multi_query_result = full_collection.query( multi_query_vectors, + topk=3, reranker=weighted_reranker, ) assert len(multi_query_result) > 0, ( diff --git a/python/tests/test_collection.py b/python/tests/test_collection.py index b16e2eea6..7eba2e22d 100644 --- a/python/tests/test_collection.py +++ b/python/tests/test_collection.py @@ -27,7 +27,6 @@ InvertIndexParam, LogLevel, LogType, - MetricType, OptimizeOption, StatusCode, Query, @@ -1105,7 +1104,8 @@ def test_collection_query_with_weighted_reranker_by_multi_dense_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): """Test multi-vector query with Weighted reranker on multiple dense vectors.""" - reranker = WeightedReRanker(weights=[0.6, 0.4]) + weights = [0.6, 0.4] + reranker = WeightedReRanker(weights=weights) result = collection_with_multiple_docs.query( [ Query(field_name="dense", vector=multiple_docs[0].vector("dense")), @@ -1121,7 +1121,8 @@ def test_collection_query_with_weighted_reranker_by_multi_sparse_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): """Test multi-vector query with Weighted reranker on multiple sparse vectors.""" - reranker = WeightedReRanker(weights=[0.6, 0.4]) + weights = [0.6, 0.4] + reranker = WeightedReRanker(weights=weights) result = collection_with_multiple_docs.query( [ Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")), @@ -1140,7 +1141,8 @@ def test_collection_query_with_weighted_reranker_by_hybrid_vector( self, collection_with_multiple_docs: Collection, multiple_docs ): """Test multi-vector query with Weighted reranker combining dense + sparse.""" - reranker = WeightedReRanker(weights=[0.7, 0.3]) + weights = [0.7, 0.3] + reranker = WeightedReRanker(weights=weights) result = collection_with_multiple_docs.query( [ Query(field_name="dense", vector=multiple_docs[0].vector("dense")), @@ -1158,7 +1160,7 @@ def test_collection_query_with_callback_reranker_by_multi_dense_vector( """Test multi-vector query with CallbackReRanker (Python callback via C++).""" callback_invoked = [] - def my_rerank_callback(query_results, topn): + def my_rerank_callback(query_results, fields, topn): callback_invoked.append(True) all_docs = [] for docs in query_results: @@ -1190,7 +1192,7 @@ def test_collection_query_with_callback_reranker_by_hybrid_vector( ): """Test multi-vector query with CallbackReRanker combining dense + sparse.""" - def my_rerank_callback(query_results, topn): + def my_rerank_callback(query_results, fields, topn): all_docs = [] for docs in query_results: all_docs.extend(docs) diff --git a/python/tests/test_reranker.py b/python/tests/test_reranker.py index 19350e78a..4c0461bc1 100644 --- a/python/tests/test_reranker.py +++ b/python/tests/test_reranker.py @@ -17,7 +17,7 @@ import pytest import os -from zvec import Doc +from zvec import Doc, MetricType, VectorSchema, DataType, FlatIndexParam from zvec.extension.multi_vector_reranker import ( CallbackReRanker, RrfReRanker, @@ -33,16 +33,19 @@ # ---------------------------- -# RrfRanker Test Case +# RrfReRanker Test Case # ---------------------------- class TestRrfReRanker: def test_init(self): reranker = RrfReRanker(rank_constant=100) assert reranker.rank_constant == 100 - def test_rerank_delegates_to_cpp(self): - """RrfReRanker.rerank() delegates to C++ (raises TypeError with Python Docs).""" + def test_default_rank_constant(self): reranker = RrfReRanker() + assert reranker.rank_constant == 60 + + def test_rerank(self): + reranker = RrfReRanker(rank_constant=60) doc1 = Doc(id="1", score=0.8) doc2 = Doc(id="2", score=0.7) @@ -51,66 +54,68 @@ def test_rerank_delegates_to_cpp(self): query_results = [[doc1, doc2, doc3], [doc3, doc1, doc4]] - with pytest.raises((TypeError, RuntimeError)): - reranker.rerank(query_results, topn=3) + results = reranker.rerank(query_results, topn=3) - def test_get_object_returns_cpp_reranker(self): - """_get_object() returns a valid C++ reranker instance.""" - reranker = RrfReRanker() - assert reranker._get_object() is not None + assert len(results) <= 3 + + for doc in results: + assert hasattr(doc, "score") + + scores = [doc.score for doc in results] + assert scores == sorted(scores, reverse=True) # ---------------------------- -# WeightedRanker Test Case +# WeightedReRanker Test Case # ---------------------------- class TestWeightedReRanker: + @staticmethod + def _make_fields(metrics): + return [ + VectorSchema( + name=f"vector{i}", + data_type=DataType.VECTOR_FP32, + dimension=4, + index_param=FlatIndexParam(metric_type=metric), + ) + for i, metric in enumerate(metrics) + ] + def test_init(self): - weights = [0.7, 0.3] - reranker = WeightedReRanker( - weights=weights, - ) - assert list(reranker.weights) == weights + reranker = WeightedReRanker([0.7, 0.3]) + assert reranker.weights == [0.7, 0.3] - def test_rerank_delegates_to_cpp(self): - """WeightedReRanker.rerank() delegates to C++ (raises TypeError with Python Docs).""" - weights = [0.7, 0.3] - reranker = WeightedReRanker(weights=weights) + def test_rerank(self): + reranker = WeightedReRanker([0.7, 0.3]) doc1 = Doc(id="1", score=0.8) doc2 = Doc(id="2", score=0.7) doc3 = Doc(id="3", score=0.9) query_results = [[doc1, doc2], [doc2, doc3]] + fields = self._make_fields([MetricType.L2, MetricType.L2]) + + results = reranker.rerank(query_results, topn=3, fields=fields) - with pytest.raises((TypeError, RuntimeError)): - reranker.rerank(query_results, topn=3) + assert len(results) <= 3 - def test_get_object_returns_cpp_reranker(self): - """_get_object() returns a valid C++ reranker instance.""" - reranker = WeightedReRanker(weights=[0.5, 0.5]) - assert reranker._get_object() is not None + for doc in results: + assert hasattr(doc, "score") # ---------------------------- # CallbackReRanker Test Case # ---------------------------- class TestCallbackReRanker: - def test_init(self): - def my_callback(query_results, topn): - return [] - - reranker = CallbackReRanker(callback=my_callback) - assert reranker._get_object() is not None - def test_rerank(self): - def my_callback(query_results, topn): + def my_callback(query_results, fields, topn): all_docs = [] for docs in query_results: all_docs.extend(docs) all_docs.sort(key=lambda d: d.score, reverse=True) return all_docs[:topn] - reranker = CallbackReRanker(callback=my_callback) + reranker = CallbackReRanker(my_callback) doc1 = Doc(id="1", score=0.8) doc2 = Doc(id="2", score=0.9) @@ -128,22 +133,15 @@ def my_callback(query_results, topn): def test_callback_with_topn(self): received_topn = [] - def my_callback(query_results, topn): + def my_callback(query_results, fields, topn): received_topn.append(topn) return [] - reranker = CallbackReRanker(callback=my_callback) + reranker = CallbackReRanker(my_callback) reranker.rerank([[Doc(id="1", score=0.5)]], topn=7) assert received_topn == [7] - def test_get_object_returns_cpp_reranker(self): - def my_callback(query_results, topn): - return [] - - reranker = CallbackReRanker(callback=my_callback) - assert reranker._get_object() is not None - # ---------------------------- # QwenReRanker Test Case @@ -200,7 +198,7 @@ def test_rerank_empty_results(self): reranker = QwenReRanker( query="test", api_key="test_key", rerank_field="content" ) - results = reranker.rerank([], topn=10) + results = reranker.rerank({}) assert results == [] def test_rerank_no_valid_documents(self): @@ -208,22 +206,22 @@ def test_rerank_no_valid_documents(self): query="test", api_key="test_key", rerank_field="content" ) # Document without the rerank_field - query_results = [[Doc(id="1")]] + query_results = {"vector1": [Doc(id="1")]} with pytest.raises(ValueError, match="No documents to rerank"): - reranker.rerank(query_results, topn=10) + reranker.rerank(query_results) def test_rerank_skip_empty_content(self): reranker = QwenReRanker( query="test", api_key="test_key", rerank_field="content" ) - query_results = [ - [ + query_results = { + "vector1": [ Doc(id="1", fields={"content": ""}), Doc(id="2", fields={"content": " "}), ] - ] + } with pytest.raises(ValueError, match="No documents to rerank"): - reranker.rerank(query_results, topn=10) + reranker.rerank(query_results) @patch("zvec.extension.qwen_function.require_module") def test_rerank_success(self, mock_require_module): @@ -246,12 +244,12 @@ def test_rerank_success(self, mock_require_module): query="test query", api_key="test_key", rerank_field="content" ) - query_results = [ - [ + query_results = { + "vector1": [ Doc(id="1", fields={"content": "Document 1"}), Doc(id="2", fields={"content": "Document 2"}), ] - ] + } results = reranker.rerank(query_results, topn=2) @@ -292,7 +290,7 @@ def test_rerank_deduplicate_documents(self, mock_require_module): # Same document in multiple vector results doc1 = Doc(id="1", fields={"content": "Document 1"}) - query_results = [[doc1], [doc1]] + query_results = {"vector1": [doc1], "vector2": [doc1]} results = reranker.rerank(query_results, topn=5) @@ -317,10 +315,10 @@ def test_rerank_api_error(self, mock_require_module): query="test", api_key="test_key", rerank_field="content" ) - query_results = [[Doc(id="1", fields={"content": "Document 1"})]] + query_results = {"vector1": [Doc(id="1", fields={"content": "Document 1"})]} with pytest.raises(ValueError, match="DashScope API error"): - reranker.rerank(query_results, topn=10) + reranker.rerank(query_results) @patch("zvec.extension.qwen_function.require_module") def test_rerank_runtime_error(self, mock_require_module): @@ -333,10 +331,10 @@ def test_rerank_runtime_error(self, mock_require_module): query="test", api_key="test_key", rerank_field="content" ) - query_results = [[Doc(id="1", fields={"content": "Document 1"})]] + query_results = {"vector1": [Doc(id="1", fields={"content": "Document 1"})]} with pytest.raises(RuntimeError, match="Failed to call DashScope API"): - reranker.rerank(query_results, topn=10) + reranker.rerank(query_results) @pytest.mark.skipif( not RUN_INTEGRATION_TESTS, @@ -357,8 +355,8 @@ def test_real_qwen_rerank(self): ) # Prepare test documents - query_results = [ - [ + query_results = { + "vector1": [ Doc( id="1", score=0.8, @@ -381,7 +379,7 @@ def test_real_qwen_rerank(self): }, ), ], - [ + "vector2": [ Doc( id="4", score=0.6, @@ -397,7 +395,7 @@ def test_real_qwen_rerank(self): }, ), ], - ] + } # Call real API results = reranker.rerank(query_results, topn=3) @@ -582,7 +580,7 @@ def test_rerank_empty_results(self): return_value=mock_st, ): reranker = DefaultLocalReRanker(query="test", rerank_field="content") - results = reranker.rerank([], topn=10) + results = reranker.rerank({}) assert results == [] def test_rerank_no_valid_documents(self): @@ -600,9 +598,9 @@ def test_rerank_no_valid_documents(self): reranker = DefaultLocalReRanker(query="test", rerank_field="content") # Document without the rerank_field - query_results = [[Doc(id="1")]] + query_results = {"vector1": [Doc(id="1")]} with pytest.raises(ValueError, match="No documents to rerank"): - reranker.rerank(query_results, topn=10) + reranker.rerank(query_results) def test_rerank_skip_empty_content(self): """Test rerank skips documents with empty content.""" @@ -618,14 +616,14 @@ def test_rerank_skip_empty_content(self): ): reranker = DefaultLocalReRanker(query="test", rerank_field="content") - query_results = [ - [ + query_results = { + "vector1": [ Doc(id="1", fields={"content": ""}), Doc(id="2", fields={"content": " "}), ] - ] + } with pytest.raises(ValueError, match="No documents to rerank"): - reranker.rerank(query_results, topn=10) + reranker.rerank(query_results) def test_rerank_success(self): """Test successful rerank with mocked model.""" @@ -649,13 +647,13 @@ def test_rerank_success(self): ): reranker = DefaultLocalReRanker(query="test query", rerank_field="content") - query_results = [ - [ + query_results = { + "vector1": [ Doc(id="1", score=0.8, fields={"content": "Document 1"}), Doc(id="2", score=0.7, fields={"content": "Document 2"}), Doc(id="3", score=0.6, fields={"content": "Document 3"}), ] - ] + } results = reranker.rerank(query_results, topn=3) @@ -698,15 +696,15 @@ def test_rerank_with_topn_limit(self): ): reranker = DefaultLocalReRanker(query="test", rerank_field="content") - query_results = [ - [ + query_results = { + "vector1": [ Doc(id="1", fields={"content": "Doc 1"}), Doc(id="2", fields={"content": "Doc 2"}), Doc(id="3", fields={"content": "Doc 3"}), Doc(id="4", fields={"content": "Doc 4"}), Doc(id="5", fields={"content": "Doc 5"}), ] - ] + } results = reranker.rerank(query_results, topn=2) @@ -740,10 +738,10 @@ def test_rerank_deduplicate_documents(self): doc1 = Doc(id="1", fields={"content": "Document 1"}) doc2 = Doc(id="2", fields={"content": "Document 2"}) - query_results = [ - [doc1, doc2], - [doc1], # doc1 appears in both - ] + query_results = { + "vector1": [doc1, doc2], + "vector2": [doc1], # doc1 appears in both + } results = reranker.rerank(query_results, topn=5) @@ -775,13 +773,13 @@ def test_rerank_sorting(self): ): reranker = DefaultLocalReRanker(query="test", rerank_field="content") - query_results = [ - [ + query_results = { + "vector1": [ Doc(id="1", fields={"content": "Doc 1"}), Doc(id="2", fields={"content": "Doc 2"}), Doc(id="3", fields={"content": "Doc 3"}), ] - ] + } results = reranker.rerank(query_results, topn=3) @@ -811,10 +809,10 @@ def test_rerank_model_error(self): ): reranker = DefaultLocalReRanker(query="test", rerank_field="content") - query_results = [[Doc(id="1", fields={"content": "Document 1"})]] + query_results = {"vector1": [Doc(id="1", fields={"content": "Document 1"})]} with pytest.raises(RuntimeError, match="Failed to compute rerank scores"): - reranker.rerank(query_results, topn=10) + reranker.rerank(query_results) def test_rerank_with_custom_batch_size(self): """Test rerank uses custom batch_size.""" @@ -837,14 +835,14 @@ def test_rerank_with_custom_batch_size(self): query="test", rerank_field="content", batch_size=64 ) - query_results = [ - [ + query_results = { + "vector1": [ Doc(id="1", fields={"content": "Doc 1"}), Doc(id="2", fields={"content": "Doc 2"}), ] - ] + } - reranker.rerank(query_results, topn=10) + reranker.rerank(query_results) # Verify batch_size is passed to predict call_args = mock_model.predict.call_args @@ -870,8 +868,8 @@ def test_real_sentence_transformer_rerank(self): ) # Prepare test documents - query_results = [ - [ + query_results = { + "vector1": [ Doc( id="1", score=0.8, @@ -894,7 +892,7 @@ def test_real_sentence_transformer_rerank(self): }, ), ], - [ + "vector2": [ Doc( id="4", score=0.6, @@ -910,7 +908,7 @@ def test_real_sentence_transformer_rerank(self): }, ), ], - ] + } # Call real model results = reranker.rerank(query_results, topn=3) @@ -948,49 +946,3 @@ def test_real_sentence_transformer_rerank(self): content = doc.field("content") if content: print(f" Content: {content[:80]}...") - - -# ---------------------------- -# DocList Type and Delegation Tests -# ---------------------------- -class TestDocList: - def test_type_alias(self): - """DocList is list[Doc].""" - from zvec.model.doc import DocList - from zvec import Doc, DocList as QR - - assert DocList == list[Doc] - assert QR == list[Doc] - - def test_rrf_reranker_delegates_to_cpp(self): - """RrfReRanker.rerank() delegates to C++ (raises TypeError with Python Docs).""" - reranker = RrfReRanker() - with pytest.raises(TypeError): - reranker.rerank([[Doc(id="1", score=0.5)]], topn=5) - - def test_weighted_reranker_delegates_to_cpp(self): - """WeightedReRanker.rerank() delegates to C++ (raises TypeError with Python Docs).""" - reranker = WeightedReRanker(weights=[0.7, 0.3]) - with pytest.raises(TypeError): - reranker.rerank( - [[Doc(id="1", score=0.5)], [Doc(id="2", score=0.3)]], topn=5 - ) - - def test_single_route_query_results(self): - """CallbackReRanker works with single-route (one element list).""" - - def cb(query_results, topn): - return query_results[0][:topn] - - reranker = CallbackReRanker(callback=cb) - results = reranker.rerank( - [ - [ - Doc(id="1", score=0.9), - Doc(id="2", score=0.8), - Doc(id="3", score=0.7), - ] - ], - topn=2, - ) - assert len(results) == 2 diff --git a/python/zvec/executor/query_executor.py b/python/zvec/executor/query_executor.py index 08c98ca31..62bc9e332 100644 --- a/python/zvec/executor/query_executor.py +++ b/python/zvec/executor/query_executor.py @@ -19,7 +19,7 @@ from _zvec import _Collection, _MultiQuery from _zvec.param import _Fts, _SearchQuery, _SubQuery -from ..extension import ReRanker +from ..extension import CallbackReRanker, ReRanker, RrfReRanker, WeightedReRanker from ..model.convert import convert_to_py_doc from ..model.doc import DocList from ..model.param.query import Query @@ -143,12 +143,18 @@ def _execute_multi_query( ) -> DocList: """Multiple queries: send a ``_MultiQuery`` to C++. - A Python-only reranker (``_get_object()`` returns None) cannot run - inside the C++ MultiQuery, so each route is executed individually and - merged by the reranker in Python. + A Python-only reranker (e.g. a model/API-based one) cannot run inside + the C++ MultiQuery, so each route is executed individually and merged by + the reranker in Python. The built-in RRF/Weighted/Callback rerankers use + the C++ variant-based fast path. """ reranker = ctx.reranker - if reranker is not None and reranker._get_object() is None: + if reranker is None: + raise ValueError( + "A reranker is required to merge results from multiple queries; " + "specify the 'reranker' argument." + ) + if not isinstance(reranker, (RrfReRanker, WeightedReRanker, CallbackReRanker)): docs_list = self._execute_python_pipeline(queries, collection) return self._merge_and_rerank(ctx, docs_list) @@ -162,14 +168,26 @@ def _build_multi_query( """Assemble a C++ ``_MultiQuery`` from per-route ``_SearchQuery`` objects.""" multi_query = _MultiQuery() multi_query.queries = [_SubQuery.from_search_query(query) for query in queries] + # num_candidates controls per-sub-query candidate count for reranking pool. + # It must NOT be limited to the final output topk; use at least the C++ + # SubQuery default of 10 to ensure sufficient candidates for reranking. + _DEFAULT_NUM_CANDIDATES = 10 + for sub in multi_query.queries: + sub.num_candidates = max(ctx.topk, _DEFAULT_NUM_CANDIDATES) multi_query.topk = ctx.topk if ctx.filter: multi_query.filter = ctx.filter multi_query.include_vector = ctx.include_vector if ctx.output_fields is not None: multi_query.output_fields = ctx.output_fields - if ctx.reranker is not None: - multi_query.reranker = ctx.reranker._get_object() + # Set rerank strategy via the C++ variant-based API. + reranker = ctx.reranker + if isinstance(reranker, RrfReRanker): + multi_query.set_rerank_rrf(reranker.rank_constant) + elif isinstance(reranker, WeightedReRanker): + multi_query.set_rerank_weighted(reranker.weights) + elif isinstance(reranker, CallbackReRanker): + multi_query.set_rerank_callback(reranker._callback) return multi_query def _execute_python_pipeline( diff --git a/python/zvec/extension/__init__.py b/python/zvec/extension/__init__.py index f738c6c90..a1f3a8cde 100644 --- a/python/zvec/extension/__init__.py +++ b/python/zvec/extension/__init__.py @@ -24,6 +24,7 @@ from .qwen_embedding_function import QwenDenseEmbedding, QwenSparseEmbedding from .qwen_function import QwenFunctionBase from .qwen_rerank_function import QwenReRanker +from .rerank_function import RerankFunction from .rerank_function import RerankFunction as ReRanker from .sentence_transformer_embedding_function import ( DefaultLocalDenseEmbedding, @@ -49,6 +50,7 @@ "QwenReRanker", "QwenSparseEmbedding", "ReRanker", + "RerankFunction", "RrfReRanker", "SentenceTransformerFunctionBase", "SparseEmbeddingFunction", diff --git a/python/zvec/extension/multi_vector_reranker.py b/python/zvec/extension/multi_vector_reranker.py index e96de178c..acee984b1 100644 --- a/python/zvec/extension/multi_vector_reranker.py +++ b/python/zvec/extension/multi_vector_reranker.py @@ -14,134 +14,178 @@ from __future__ import annotations from collections.abc import Callable -from typing import Optional +from typing import TYPE_CHECKING -from _zvec import _CallbackReranker, _RrfReranker, _WeightedReranker +from _zvec import _CallbackParams, _Doc, _reranker_rerank, _RrfParams, _WeightedParams -from ..model.doc import DocList +from ..model.doc import Doc, DocList from .rerank_function import RerankFunction +if TYPE_CHECKING: + from ..model.schema import FieldSchema, VectorSchema + + +def _to_cpp_doc_lists( + query_results: list[list[Doc]], +) -> tuple[list[list], dict[str, Doc]]: + """Convert Python Doc lists to C++ _Doc lists for reranker input.""" + id_to_doc: dict[str, Doc] = {} + cpp_results: list[list] = [] + for query_result in query_results: + cpp_list: list = [] + for doc in query_result: + _doc = _Doc() + _doc.set_pk(doc.id) + _doc.set_score(doc.score if doc.score is not None else 0.0) + cpp_list.append(_doc) + if doc.id not in id_to_doc: + id_to_doc[doc.id] = doc + cpp_results.append(cpp_list) + return cpp_results, id_to_doc + + +def _from_cpp_docs(cpp_docs: list, id_to_doc: dict[str, Doc]) -> DocList: + """Convert C++ rerank result _Doc list back to Python DocList.""" + results: DocList = [] + for _doc in cpp_docs: + doc_id = _doc.pk() + new_score = _doc.score() + original = id_to_doc.get(doc_id) + if original is not None: + results.append(original._replace(score=new_score)) + else: + results.append(Doc(id=doc_id, score=new_score)) + return results + class RrfReRanker(RerankFunction): """Re-ranker using Reciprocal Rank Fusion (RRF) for multi-vector search. - RRF combines results from multiple vector queries without requiring relevance scores. - It assigns higher weight to documents that appear early in multiple result lists. - - The RRF score for a document at rank ``r`` is: ``1 / (k + r + 1)``, - where ``k`` is the rank constant. + RRF combines results from multiple vector queries without requiring + relevance scores. The RRF score for a document at rank r is: + score = 1 / (k + r + 1) + where k is the rank constant. Args: - rank_constant (int, optional): Smoothing constant ``k`` in RRF formula. - Larger values reduce the impact of early ranks. Defaults to 60. + rank_constant: RRF smoothing constant (default: 60). + Higher values reduce the influence of rank position. + + Example: + >>> reranker = RrfReRanker(rank_constant=60) + >>> merged = reranker.rerank([results_a, results_b], topn=10) """ - def __init__( - self, - rank_constant: int = 60, - ): + def __init__(self, rank_constant: int = 60): self._rank_constant = rank_constant - # Use C++ implementation for performance - self._cpp_reranker = _RrfReranker(rank_constant) @property def rank_constant(self) -> int: + """int: RRF rank constant.""" return self._rank_constant - def _get_object(self): - """Return the underlying C++ RrfReranker instance.""" - return self._cpp_reranker - - def rerank(self, query_results: list[DocList], topn: int) -> DocList: - """Re-rank using C++ RRF implementation. + def _to_cpp_params(self): + return _RrfParams(self._rank_constant) - Args: - query_results (list[DocList]): Multi-route recall results, - positionally aligned with queries. - topn (int): Number of top documents to return. - - Returns: - DocList: Re-ranked documents. - """ - return self._cpp_reranker.rerank(query_results, topn) + def rerank( + self, + query_results: list[list[Doc]], + topn: int = 10, + *, + fields: list[FieldSchema | VectorSchema] | None = None, # noqa: ARG002 + ) -> DocList: + """Apply RRF to combine multiple query results via C++ reranker.""" + cpp_results, id_to_doc = _to_cpp_doc_lists(query_results) + cpp_docs = _reranker_rerank(self._to_cpp_params(), cpp_results, [], topn) + return _from_cpp_docs(cpp_docs, id_to_doc) class WeightedReRanker(RerankFunction): - """Re-ranker that combines scores from multiple vector fields using weights. + """Re-ranker that combines scores using per-sub-query weights. - Each vector field's relevance score is normalized based on its own metric - type, then scaled by a user-provided weight. Final scores are summed across - fields. The actual re-ranking logic lives in the C++ implementation. + Each sub-query's score is normalized by metric type (automatic when used + via collection.multi_query), then multiplied by the corresponding weight. Args: - weights (Optional[list[float]], optional): Weight per vector field, - aligned by position with the queries supplied to ``collection.query()``. - Defaults to None (treated as an empty list). + weights: Per-sub-query weights. Length must match the number of + sub-queries. + + Example: + >>> reranker = WeightedReRanker([0.7, 0.3]) + >>> merged = reranker.rerank([results_a, results_b], topn=10, + ... fields=field_schemas) """ - def __init__( - self, - weights: Optional[list[float]] = None, - ): - self._cpp_reranker = _WeightedReranker(weights or []) + def __init__(self, weights: list[float]): + self._weights = list(weights) @property def weights(self) -> list[float]: - """list[float]: Weight list for vector fields, aligned with queries.""" - return self._cpp_reranker.weights + """list[float]: Per-sub-query weights.""" + return self._weights - def _get_object(self): - """Return the underlying C++ WeightedReranker instance.""" - return self._cpp_reranker + def _to_cpp_params(self): + return _WeightedParams(self._weights) - def rerank(self, query_results: list[DocList], topn: int) -> DocList: - """Re-rank using C++ Weighted implementation. + def rerank( + self, + query_results: list[list[Doc]], + topn: int = 10, + *, + fields: list[FieldSchema | VectorSchema] | None = None, + ) -> DocList: + """Combine scores from multiple sub-queries using weighted sum via C++ reranker. Args: - query_results (list[DocList]): Multi-route recall results, - positionally aligned with queries. - topn (int): Number of top documents to return. + query_results: Per-sub-query document lists. + topn: Maximum results to return. + fields: Per-sub-query Python FieldSchema/VectorSchema objects + (required for score normalization by metric type). - Returns: - DocList: Re-ranked documents. + Raises: + ValueError: If fields is None (required for normalization). """ - return self._cpp_reranker.rerank(query_results, topn) + if not fields: + raise ValueError( + "WeightedReRanker.rerank() requires 'fields' for score normalization. " + "Pass field schemas via fields= parameter." + ) + cpp_fields = [f._get_object() for f in fields] + cpp_results, id_to_doc = _to_cpp_doc_lists(query_results) + cpp_docs = _reranker_rerank( + self._to_cpp_params(), cpp_results, cpp_fields, topn + ) + return _from_cpp_docs(cpp_docs, id_to_doc) class CallbackReRanker(RerankFunction): - """Re-ranker that delegates to a user-provided Python callback. - - This bridges a Python callable into the C++ reranker interface, enabling - custom re-ranking logic to be executed within the C++ MultiQuery path. + """Re-ranker that delegates to a user-provided callback. - The callback receives raw C++ ``_Doc`` objects grouped per query (as a - ``list[list[_Doc]]``) and must return a ``list[_Doc]``. + The callback receives sub-query results, field schemas, and topn. Args: callback: A callable with signature - ``(query_results: list[list[_Doc]], topn: int) -> list[_Doc]``. + (results: list[list[Doc]], fields: list, topn: int) -> list[Doc] + + Example: + >>> def my_rerank(results, fields, topn): + ... # custom logic + ... return merged[:topn] + >>> reranker = CallbackReRanker(my_rerank) + >>> merged = reranker.rerank([results_a, results_b], topn=10) """ - def __init__( - self, - callback: Callable, - ): + def __init__(self, callback: Callable): self._callback = callback - self._cpp_reranker = _CallbackReranker(callback) - - def _get_object(self): - """Return the underlying C++ CallbackReranker instance.""" - return self._cpp_reranker - def rerank(self, query_results: list[DocList], topn: int) -> DocList: - """Invoke the callback to re-rank documents. + def _to_cpp_params(self): + return _CallbackParams(self._callback) - Args: - query_results (list[DocList]): Multi-route recall results, - positionally aligned with queries. - topn (int): Number of top documents to return. - - Returns: - DocList: Re-ranked documents. - """ - return self._callback(query_results, topn) + def rerank( + self, + query_results: list[list[Doc]], + topn: int = 10, + *, + fields: list[FieldSchema | VectorSchema] | None = None, + ) -> DocList: + """Invoke the callback to re-rank documents.""" + return self._callback(query_results, fields, topn) diff --git a/python/zvec/extension/qwen_rerank_function.py b/python/zvec/extension/qwen_rerank_function.py index ead1f9e04..0a06eb203 100644 --- a/python/zvec/extension/qwen_rerank_function.py +++ b/python/zvec/extension/qwen_rerank_function.py @@ -13,12 +13,15 @@ # limitations under the License. from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING, Optional from ..model.doc import Doc, DocList from .qwen_function import QwenFunctionBase from .rerank_function import RerankFunction +if TYPE_CHECKING: + from ..model.schema import FieldSchema, VectorSchema + class QwenReRanker(QwenFunctionBase, RerankFunction): """Re-ranker using Qwen (DashScope) cross-encoder API for semantic re-ranking. @@ -77,51 +80,60 @@ def __init__( ValueError: If query is empty or API key is unavailable. """ QwenFunctionBase.__init__(self, model=model, api_key=api_key) - RerankFunction.__init__(self) + self._rerank_field = rerank_field if not query: raise ValueError("Query is required for QwenReRanker") self._query = query - self._rerank_field = rerank_field - - @property - def query(self) -> str: - """str: Query text used for semantic re-ranking.""" - return self._query @property def rerank_field(self) -> Optional[str]: """Optional[str]: Field name used as re-ranking input.""" return self._rerank_field - def rerank(self, query_results: list[DocList], topn: int) -> DocList: + @property + def query(self) -> str: + """str: Query text used for semantic re-ranking.""" + return self._query + + def rerank( + self, + query_results: list[list[Doc]], + topn: int = 10, + *, + fields: list[FieldSchema | VectorSchema] | None = None, # noqa: ARG002 + ) -> DocList: """Re-rank documents using Qwen's TextReRank API. Sends document texts to DashScope TextReRank service along with the query. Returns documents sorted by relevance scores from the cross-encoder model. Args: - query_results (list[DocList]): Multi-route recall results, - positionally aligned with the queries supplied to - ``collection.query()``. Documents from all routes are - deduplicated and re-ranked together. - topn (int): Maximum number of documents to return after re-ranking. + query_results (list[list[Doc]]): Per-sub-query lists of retrieved + documents. Documents from all lists are deduplicated and + re-ranked together. + topn (int): Maximum number of documents to return. + fields: Unused; present for interface compatibility. Returns: - DocList: Re-ranked documents (up to ``topn``) with updated - ``score`` fields containing relevance scores from the API. + list[Doc]: Re-ranked documents (up to ``topn``) with updated ``score`` + fields containing relevance scores from the API. Raises: ValueError: If no valid documents are found or API call fails. Note: - - Duplicate documents (same ID) across routes are processed once + - Duplicate documents (same ID) across lists are processed once - Documents with empty/missing ``rerank_field`` content are skipped - Returned scores are relevance scores from the cross-encoder model """ if not query_results: return [] + # Accept both dict (legacy) and list formats + if isinstance(query_results, dict): + query_results = list(query_results.values()) + # Collect and deduplicate documents id_to_doc: dict[str, Doc] = {} doc_ids: list[str] = [] diff --git a/python/zvec/extension/rerank_function.py b/python/zvec/extension/rerank_function.py index 54fe16124..09a26c41e 100644 --- a/python/zvec/extension/rerank_function.py +++ b/python/zvec/extension/rerank_function.py @@ -14,45 +14,43 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import TYPE_CHECKING -from ..model.doc import DocList +from ..model.doc import Doc, DocList +if TYPE_CHECKING: + from ..model.schema import FieldSchema, VectorSchema -class RerankFunction(ABC): - """Abstract base class for re-ranking search results. - Re-rankers refine the output of one or more vector queries by applying - a secondary scoring strategy. They are used in the ``query()`` method of - ``Collection`` via the ``reranker`` parameter. +class RerankFunction(ABC): + """Abstract base class for reranker parameter containers. - Note: - Subclasses must implement the ``rerank()`` method. + Subclasses define rerank parameters and implement _to_cpp_params() + for conversion to C++ parameter structs (used by collection fast path). + Each subclass also provides a standalone rerank() implementation. """ + def _to_cpp_params(self): + """Return C++ reranker params. Override in subclasses that use C++ path.""" + raise NotImplementedError + @abstractmethod - def rerank(self, query_results: list[DocList], topn: int) -> DocList: - """Re-rank documents from multi-route recall results. + def rerank( + self, + query_results: list[list[Doc]], + topn: int = 10, + *, + fields: list[FieldSchema | VectorSchema] | None = None, + ) -> DocList: + """Execute rerank on sub-query results. Args: - query_results (list[DocList]): List of query results from - multi-route recall. Each element corresponds to a Query in the - collection.query(queries=List[Query]) call, aligned by position. - topn (int): Number of top documents to return after re-ranking. + query_results: List of per-sub-query document lists. + topn: Maximum number of results to return. + fields: Per-sub-query Python FieldSchema/VectorSchema objects + (required for WeightedReRanker score normalization). Returns: - DocList: Re-ranked list of documents (length ≤ ``topn``), - with updated ``score`` fields. + Re-ranked document list. """ ... - - def _get_object(self): - """Return the underlying C++ Reranker instance, if available. - - This is used internally by the query executor to pass the reranker - to the C++ MultiQuery method. Subclasses that wrap a C++ reranker - should override this method. - - Returns: - The C++ Reranker shared pointer, or None if not available. - """ - return None # noqa: RET501 diff --git a/python/zvec/extension/sentence_transformer_rerank_function.py b/python/zvec/extension/sentence_transformer_rerank_function.py index 2e22d7c0d..bc84242af 100644 --- a/python/zvec/extension/sentence_transformer_rerank_function.py +++ b/python/zvec/extension/sentence_transformer_rerank_function.py @@ -13,13 +13,16 @@ # limitations under the License. from __future__ import annotations -from typing import Literal, Optional +from typing import TYPE_CHECKING, Literal, Optional from ..model.doc import Doc, DocList from ..tool import require_module from .rerank_function import RerankFunction from .sentence_transformer_function import SentenceTransformerFunctionBase +if TYPE_CHECKING: + from ..model.schema import FieldSchema, VectorSchema + class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction): """Re-ranker using Sentence Transformer cross-encoder models for semantic re-ranking. @@ -137,13 +140,13 @@ class DefaultLocalReRanker(SentenceTransformerFunctionBase, RerankFunction): ... ) >>> # Direct rerank call (for testing) - >>> query_results = [ - ... [ + >>> query_results = { + ... "vector1": [ ... Doc(id="1", score=0.9, fields={"content": "Machine learning is..."}), ... Doc(id="2", score=0.8, fields={"content": "Deep learning is..."}), ... ] - ... ] - >>> reranked = reranker.rerank(query_results, topn=5) + ... } + >>> reranked = reranker.rerank(query_results) >>> for doc in reranked: ... print(f"ID: {doc.id}, Score: {doc.score:.4f}") ID: 2, Score: 0.9234 @@ -188,14 +191,13 @@ def __init__( self, model_name=model_name, model_source=model_source, device=device ) - # Initialize rerank function - RerankFunction.__init__(self) + # Initialize rerank parameters + self._rerank_field = rerank_field # Validate query if not query: raise ValueError("Query is required for DefaultLocalReRanker") self._query = query - self._rerank_field = rerank_field self._batch_size = batch_size # Load and validate cross-encoder model @@ -261,22 +263,28 @@ def _get_model(self): f"from {self._model_source}: {e!s}" ) from e - @property - def query(self) -> str: - """str: Query text used for semantic re-ranking.""" - return self._query - @property def rerank_field(self) -> Optional[str]: """Optional[str]: Field name used as re-ranking input.""" return self._rerank_field + @property + def query(self) -> str: + """str: Query text used for semantic re-ranking.""" + return self._query + @property def batch_size(self) -> int: """int: Batch size for processing query-document pairs.""" return self._batch_size - def rerank(self, query_results: list[DocList], topn: int) -> DocList: + def rerank( + self, + query_results: list[list[Doc]], + topn: int = 10, + *, + fields: list[FieldSchema | VectorSchema] | None = None, # noqa: ARG002 + ) -> DocList: """Re-rank documents using Sentence Transformer cross-encoder model. Evaluates each query-document pair using the cross-encoder model to compute @@ -284,22 +292,21 @@ def rerank(self, query_results: list[DocList], topn: int) -> DocList: results are returned. Args: - query_results (list[DocList]): Multi-route recall results, - positionally aligned with the queries supplied to - ``collection.query()``. Documents from all routes are - deduplicated and re-ranked together. - topn (int): Maximum number of documents to return after re-ranking. + query_results (list[list[Doc]]): Per-sub-query lists of retrieved + documents. Documents from all lists are deduplicated and + re-ranked together. + topn (int): Maximum number of documents to return. + fields: Unused; present for interface compatibility. Returns: - DocList: Re-ranked documents (up to ``topn``) with updated - ``score`` fields containing relevance scores from the - cross-encoder model. + list[Doc]: Re-ranked documents (up to ``topn``) with updated ``score`` + fields containing relevance scores from the cross-encoder model. Raises: ValueError: If no valid documents are found or model inference fails. Note: - - Duplicate documents (same ID) across routes are processed once + - Duplicate documents (same ID) across fields are processed once - Documents with empty/missing ``rerank_field`` content are skipped - Returned scores are logits from the cross-encoder model - Higher scores indicate higher relevance @@ -311,19 +318,23 @@ def rerank(self, query_results: list[DocList], topn: int) -> DocList: ... topn=3, ... rerank_field="content" ... ) - >>> query_results = [ - ... [ + >>> query_results = { + ... "vector1": [ ... Doc(id="1", score=0.9, fields={"content": "ML basics"}), ... Doc(id="2", score=0.8, fields={"content": "DL tutorial"}), ... ] - ... ] - >>> reranked = reranker.rerank(query_results, topn=3) + ... } + >>> reranked = reranker.rerank(query_results) >>> len(reranked) <= 3 True """ if not query_results: return [] + # Accept both dict (legacy) and list formats + if isinstance(query_results, dict): + query_results = list(query_results.values()) + # Collect and deduplicate documents id_to_doc: dict[str, Doc] = {} doc_ids: list[str] = [] diff --git a/src/binding/c/c_api.cc b/src/binding/c/c_api.cc index df26bf939..f2372abd2 100644 --- a/src/binding/c/c_api.cc +++ b/src/binding/c/c_api.cc @@ -5622,43 +5622,31 @@ zvec_error_code_t zvec_group_by_vector_query_set_flat_params( // Reranker Implementation // ============================================================================= -zvec_reranker_t *zvec_create_rrf_reranker(int rank_constant) { - ZVEC_TRY_RETURN_NULL("Failed to create RRF Reranker", - auto *reranker = - new zvec::Reranker::Ptr( - std::make_shared( - rank_constant)); - return reinterpret_cast(reranker);) - return nullptr; -} - -zvec_reranker_t *zvec_create_weighted_reranker(const double *weights, - size_t weight_count) { - if (!weights && weight_count > 0) { - set_last_error("Weights pointer cannot be null when weight_count > 0"); - return nullptr; +zvec_error_code_t zvec_multi_query_set_rerank_rrf( + zvec_multi_query_t *query, int rank_constant) { + if (!query) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, "Query pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; } - - ZVEC_TRY_RETURN_NULL( - "Failed to create Weighted Reranker", - auto *reranker = new zvec::Reranker::Ptr( - std::make_shared( - std::vector(weights, weights + weight_count))); - return reinterpret_cast(reranker);) - return nullptr; + auto *mq = reinterpret_cast(query); + mq->rerank = zvec::reranker::RrfParams{rank_constant}; + return ZVEC_OK; } -void zvec_destroy_reranker(zvec_reranker_t *reranker) { - if (reranker) { - delete reinterpret_cast(reranker); +zvec_error_code_t zvec_multi_query_set_rerank_weighted( + zvec_multi_query_t *query, const double *weights, size_t weight_count) { + if (!query) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, "Query pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; } -} - -int zvec_get_reranker_rank_constant(const zvec_reranker_t *reranker) { - if (!reranker) return -1; - auto *ptr = reinterpret_cast(reranker); - auto *rrf = dynamic_cast(ptr->get()); - return rrf ? rrf->rank_constant() : -1; + if (!weights && weight_count > 0) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, "Weights pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *mq = reinterpret_cast(query); + mq->rerank = zvec::reranker::WeightedParams{ + std::vector(weights, weights + weight_count)}; + return ZVEC_OK; } // ============================================================================= @@ -5812,22 +5800,6 @@ zvec_error_code_t zvec_multi_query_get_output_fields( return ZVEC_OK; } -zvec_error_code_t zvec_multi_query_set_reranker( - zvec_multi_query_t *query, zvec_reranker_t *reranker) { - if (!query || !reranker) { - SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, - "Query or reranker pointer is null"); - return ZVEC_ERROR_INVALID_ARGUMENT; - } - - auto *mvq = reinterpret_cast(query); - auto *reranker_ptr = - reinterpret_cast(reranker); - mvq->reranker = *reranker_ptr; - - return ZVEC_OK; -} - // ============================================================================= // SubVectorQuery Implementation // ============================================================================= diff --git a/src/binding/python/model/python_reranker.cc b/src/binding/python/model/python_reranker.cc index 6543ecf63..ff5753f4f 100644 --- a/src/binding/python/model/python_reranker.cc +++ b/src/binding/python/model/python_reranker.cc @@ -13,69 +13,52 @@ // limitations under the License. #include "python_reranker.h" -#include #include #include #include -#include +#include namespace zvec { -namespace { - -inline void reranker_throw_if_error(const Status &status) { - switch (status.code()) { - case StatusCode::OK: - return; - case StatusCode::NOT_FOUND: - throw py::key_error(status.message()); - case StatusCode::INVALID_ARGUMENT: - throw py::value_error(status.message()); - default: - throw std::runtime_error(status.message()); - } -} - -inline DocPtrList unwrap_rerank_result(Result result) { - if (!result.has_value()) { - reranker_throw_if_error(result.error()); - } - return std::move(result).value(); -} - -} // namespace - void ZVecPyReranker::Initialize(py::module_ &m) { - // Bind Reranker base class (abstract, cannot be instantiated directly) - py::class_(m, "_Reranker") - .def( - "rerank", - [](const Reranker &self, const std::vector &query_results, - int topn) { - return unwrap_rerank_result(self.rerank(query_results, topn)); - }, - py::arg("query_results"), py::arg("topn") = 10); - - // Bind ScoreBasedReranker intermediate class - py::class_>( - m, "_ScoreBasedReranker"); - - // Bind RrfReranker - py::class_>( - m, "_RrfReranker") + // Bind RrfParams + py::class_(m, "_RrfParams") .def(py::init(), py::arg("rank_constant") = 60) - .def_property_readonly("rank_constant", &RrfReranker::rank_constant); + .def_readwrite("rank_constant", &reranker::RrfParams::rank_constant); - // Bind WeightedReranker - py::class_>(m, "_WeightedReranker") + // Bind WeightedParams + py::class_(m, "_WeightedParams") .def(py::init>(), py::arg("weights")) - .def_property_readonly("weights", &WeightedReranker::weights); + .def_readwrite("weights", &reranker::WeightedParams::weights); + + // Bind CallbackParams + py::class_(m, "_CallbackParams") + .def(py::init(), py::arg("callback")); - // Bind CallbackReranker - py::class_>( - m, "_CallbackReranker") - .def(py::init(), py::arg("callback")); + // Standalone rerank execution function + m.def( + "_reranker_rerank", + [](py::object params, const std::vector &results, + const std::vector &fields, int topn) -> DocPtrList { + reranker::RerankParams strategy; + if (py::isinstance(params)) { + strategy = params.cast(); + } else if (py::isinstance(params)) { + strategy = params.cast(); + } else if (py::isinstance(params)) { + strategy = params.cast(); + } else { + throw py::type_error( + "params must be _RrfParams, _WeightedParams, or _CallbackParams"); + } + auto result = reranker::rerank(strategy, results, fields, topn); + if (!result.has_value()) { + throw std::runtime_error(result.error().message()); + } + return std::move(result).value(); + }, + py::arg("params"), py::arg("results"), py::arg("fields"), + py::arg("topn")); // Bind MultiQuery struct py::class_(m, "_MultiQuery") @@ -85,7 +68,24 @@ void ZVecPyReranker::Initialize(py::module_ &m) { .def_readwrite("filter", &MultiQuery::filter) .def_readwrite("include_vector", &MultiQuery::include_vector) .def_readwrite("output_fields", &MultiQuery::output_fields) - .def_readwrite("reranker", &MultiQuery::reranker); + .def( + "set_rerank_rrf", + [](MultiQuery &q, int rank_constant) { + q.rerank = reranker::RrfParams{rank_constant}; + }, + py::arg("rank_constant") = 60) + .def( + "set_rerank_weighted", + [](MultiQuery &q, std::vector weights) { + q.rerank = reranker::WeightedParams{std::move(weights)}; + }, + py::arg("weights")) + .def( + "set_rerank_callback", + [](MultiQuery &q, reranker::CallbackParams::Callback callback) { + q.rerank = reranker::CallbackParams{std::move(callback)}; + }, + py::arg("callback")); } } // namespace zvec diff --git a/src/db/collection.cc b/src/db/collection.cc index e45b00144..bab103e5b 100644 --- a/src/db/collection.cc +++ b/src/db/collection.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -1695,26 +1696,25 @@ Result CollectionImpl::Query(const MultiQuery &query) const { query.queries.size())); } - if (!query.reranker) { - return tl::make_unexpected(Status::InvalidArgument( - "Invalid query: MultiQuery requires a reranker")); - } - auto segments = get_all_segments(); if (segments.empty()) { return DocPtrList(); } // Convert each SubQuery to a SearchQuery and validate. - std::vector search_queries; - std::vector field_names; - search_queries.reserve(query.queries.size()); - field_names.reserve(query.queries.size()); + std::vector pending_queries; + std::vector field_schemas; + pending_queries.reserve(query.queries.size()); + field_schemas.reserve(query.queries.size()); for (const auto &sub : query.queries) { const auto &target = sub.target_; - - auto *field_schema = schema_->get_field(target.field_name_); + auto field_ptr = schema_->get_field_ptr(target.field_name_); + if (!field_ptr) { + return tl::make_unexpected(Status::InvalidArgument( + "Invalid query: field ", target.field_name_, " not found")); + } + auto *field_schema = field_ptr.get(); SearchQuery sq; sq.target_ = target; @@ -1726,45 +1726,44 @@ Result CollectionImpl::Query(const MultiQuery &query) const { auto s = sq.validate_and_sanitize(field_schema); CHECK_RETURN_STATUS_EXPECTED(s); - field_names.push_back(target.field_name_); - search_queries.push_back(std::move(sq)); + pending_queries.push_back(std::move(sq)); + field_schemas.push_back(std::move(field_ptr)); } - // Execute sub-queries. - auto execute_query = [&](SearchQuery &sq) -> Result { + auto execute_query = [&](SearchQuery &pending) -> Result { auto engine = sqlengine::SQLEngine::create(std::make_shared()); - return engine->execute(schema_, std::move(sq), segments); + return engine->execute(schema_, std::move(pending), segments); }; - std::vector> results(search_queries.size()); + std::vector> results(pending_queries.size()); // Single-segment queries have no segment-level fanout; multi-segment queries // already use the query pool per sub-query. if (segments.size() == 1) { auto group = GlobalResource::Instance().query_thread_pool()->make_group(); - for (size_t i = 0; i < search_queries.size(); ++i) { + for (size_t i = 0; i < pending_queries.size(); ++i) { group->execute( - [&, i]() { results[i] = execute_query(search_queries[i]); }); + [&, i]() { results[i] = execute_query(pending_queries[i]); }); } group->wait_finish(); } else { - for (size_t i = 0; i < search_queries.size(); ++i) { - results[i] = execute_query(search_queries[i]); + for (size_t i = 0; i < pending_queries.size(); ++i) { + results[i] = execute_query(pending_queries[i]); } } - // Collect results and rerank. std::vector query_results; - query_results.reserve(results.size()); - for (auto &result : results) { - if (!result) { - return tl::make_unexpected(result.error()); + query_results.reserve(pending_queries.size()); + for (size_t i = 0; i < pending_queries.size(); ++i) { + if (!results[i]) { + return tl::make_unexpected(results[i].error()); } - query_results.push_back(std::move(result.value())); + query_results.push_back(std::move(results[i].value())); } - query.reranker->bind_schema(schema_, field_names); - return query.reranker->rerank(query_results, query.topk); + // Dispatch rerank — schema info injected via field_schemas + return reranker::rerank(query.rerank, query_results, field_schemas, + query.topk); } Result CollectionImpl::GroupByQuery( diff --git a/src/db/reranker/reranker.cc b/src/db/reranker/reranker.cc index 9fb49be7d..684130123 100644 --- a/src/db/reranker/reranker.cc +++ b/src/db/reranker/reranker.cc @@ -19,16 +19,21 @@ #include #include #include +#include #include #include #include namespace zvec { +namespace { -// ==================== ScoreBasedReranker ==================== +// Shared score-based rerank logic used by RRF and Weighted. +// score_fn(doc_score, rank, field_index) -> contribution score +using ScoreFn = std::function(double, int, size_t)>; -Result ScoreBasedReranker::rerank( - const std::vector &query_results, int topn) const { +Result score_based_rerank(const ScoreFn &score_fn, + const std::vector &results, + int topn) { if (topn <= 0) { return DocPtrList(); } @@ -36,14 +41,13 @@ Result ScoreBasedReranker::rerank( std::unordered_map scores; std::unordered_map id_to_doc; - for (size_t query_index = 0; query_index < query_results.size(); - ++query_index) { - const auto &docs = query_results[query_index]; + for (size_t field_idx = 0; field_idx < results.size(); ++field_idx) { + const auto &docs = results[field_idx]; for (size_t rank = 0; rank < docs.size(); ++rank) { const auto &doc = docs[rank]; const std::string &doc_id = doc->pk(); - auto rs = rescore(static_cast(doc->score()), - static_cast(rank), static_cast(query_index)); + auto rs = score_fn(static_cast(doc->score()), + static_cast(rank), field_idx); if (!rs.has_value()) { return tl::make_unexpected(rs.error()); } @@ -69,52 +73,26 @@ Result ScoreBasedReranker::rerank( } } - DocPtrList results; - results.reserve(pq.size()); + DocPtrList result; + result.reserve(pq.size()); while (!pq.empty()) { const auto &[doc_id, score] = pq.top(); auto doc = std::move(id_to_doc[doc_id]); doc->set_score(static_cast(score)); - results.push_back(std::move(doc)); + result.push_back(std::move(doc)); pq.pop(); } - std::reverse(results.begin(), results.end()); - return results; + std::reverse(result.begin(), result.end()); + return result; } -// ==================== RrfReranker ==================== - -Result RrfReranker::rescore(double /*score*/, int rank, - int /*query_index*/) const { - return 1.0 / (static_cast(rank_constant_) + - static_cast(rank) + 1.0); -} - -// ==================== WeightedReranker ==================== - -WeightedReranker::WeightedReranker(const std::vector &weights) - : weights_(weights) {} - -void WeightedReranker::bind_schema( - CollectionSchema::Ptr schema, const std::vector &field_names) { - schema_ = std::move(schema); - field_names_ = field_names; -} - -Result WeightedReranker::normalize_score(double score, - const FieldSchema &field) { - // FTS field: BM25 scores are non-negative; normalize via arctan to [0, 1). +Result normalize_score(double score, const FieldSchema &field) { if (field.index_type() == IndexType::FTS) { + // Non-vector FTS/BM25 fields: map positive scores to [0, 1). return 2.0 * std::atan(score) / M_PI; } - auto *vip = dynamic_cast(field.index_params().get()); - if (!vip) { - return tl::make_unexpected( - Status::InvalidArgument("WeightedReranker: field '", field.name(), - "' has no vector index params")); - } switch (vip->metric_type()) { case MetricType::L2: return 1.0 - 2.0 * std::atan(score) / M_PI; @@ -129,33 +107,61 @@ Result WeightedReranker::normalize_score(double score, } } -Result WeightedReranker::rescore(double score, int /*rank*/, - int query_index) const { - if (!schema_) { - return tl::make_unexpected( - Status::InvalidArgument("WeightedReranker: schema is null")); - } - if (query_index < 0 || - static_cast(query_index) >= field_names_.size()) { - return tl::make_unexpected( - Status::InvalidArgument("WeightedReranker: query_index out of range: ", - std::to_string(query_index))); - } - const auto &field_name = field_names_[query_index]; - const auto *field = schema_->get_field(field_name); - if (!field) { - return tl::make_unexpected(Status::InvalidArgument( - "WeightedReranker: field not found: '", field_name + "'")); - } - auto normalized = normalize_score(score, *field); - if (!normalized.has_value()) { - return tl::make_unexpected(normalized.error()); - } - double weight = 1.0; - if (static_cast(query_index) < weights_.size()) { - weight = weights_[query_index]; - } - return normalized.value() * weight; +} // anonymous namespace + +namespace reranker { + +Result rerank(const RerankParams ¶ms, + const std::vector &results, + const std::vector &fields, + int topn) { + return std::visit( + [&](const auto &p) -> Result { + using T = std::decay_t; + + if constexpr (std::is_same_v) { + auto score_fn = [&p](double /*score*/, int rank, + size_t /*field_idx*/) -> Result { + return 1.0 / (static_cast(p.rank_constant) + + static_cast(rank) + 1.0); + }; + return score_based_rerank(score_fn, results, topn); + + } else if constexpr (std::is_same_v) { + if (p.weights.size() != results.size()) { + return tl::make_unexpected(Status::InvalidArgument( + "WeightedParams: weights count (", p.weights.size(), + ") != results count (", results.size(), ")")); + } + if (fields.size() != results.size()) { + return tl::make_unexpected(Status::InvalidArgument( + "WeightedParams: fields count (", fields.size(), + ") != results count (", results.size(), ")")); + } + auto score_fn = [&p, &fields](double score, int /*rank*/, + size_t field_idx) -> Result { + if (!fields[field_idx]) { + return tl::make_unexpected(Status::InvalidArgument( + "WeightedParams: null field schema at index ", field_idx)); + } + auto normalized = normalize_score(score, *fields[field_idx]); + if (!normalized.has_value()) { + return tl::make_unexpected(normalized.error()); + } + return normalized.value() * p.weights[field_idx]; + }; + return score_based_rerank(score_fn, results, topn); + + } else if constexpr (std::is_same_v) { + if (!p.callback) { + return tl::make_unexpected( + Status::InvalidArgument("CallbackParams: callback is empty")); + } + return p.callback(results, fields, topn); + } + }, + params); } +} // namespace reranker } // namespace zvec diff --git a/src/include/zvec/c_api.h b/src/include/zvec/c_api.h index 4adfaf9ea..500265fa2 100644 --- a/src/include/zvec/c_api.h +++ b/src/include/zvec/c_api.h @@ -1137,13 +1137,6 @@ typedef struct zvec_fts_t zvec_fts_t; */ typedef struct zvec_doc_t zvec_doc_t; -/** - * @brief Reranker structure (opaque pointer) - * Aligned with zvec::Reranker - * Use zvec_create_rrf_reranker() or zvec_create_weighted_reranker() to create - * and zvec_destroy_reranker() to destroy - */ -typedef struct zvec_reranker_t zvec_reranker_t; typedef struct zvec_collection_schema_t zvec_collection_schema_t; /** @@ -1952,39 +1945,27 @@ zvec_group_by_vector_query_set_flat_params( zvec_group_by_vector_query_t *query, zvec_flat_query_params_t *flat_params); // ----------------------------------------------------------------------------- -// zvec_reranker_t (Reranker) +// Rerank Strategy (set on MultiQuery) // ----------------------------------------------------------------------------- /** - * @brief Create an RRF (Reciprocal Rank Fusion) reranker + * @brief Set RRF rerank strategy on a multi-query. + * @param query Multi-query pointer * @param rank_constant RRF rank constant (default: 60) - * @return zvec_reranker_t* Pointer to the newly created reranker - */ -ZVEC_EXPORT zvec_reranker_t *ZVEC_CALL -zvec_create_rrf_reranker(int rank_constant); - -/** - * @brief Create a Weighted reranker - * @param weights Array of weights for each query - * @param weight_count Number of weight entries - * @return zvec_reranker_t* Pointer to the newly created reranker - */ -ZVEC_EXPORT zvec_reranker_t *ZVEC_CALL -zvec_create_weighted_reranker(const double *weights, size_t weight_count); - -/** - * @brief Destroy reranker - * @param reranker Reranker pointer + * @return Error code */ -ZVEC_EXPORT void ZVEC_CALL zvec_destroy_reranker(zvec_reranker_t *reranker); +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL +zvec_multi_query_set_rerank_rrf(zvec_multi_query_t *query, int rank_constant); /** - * @brief Get RRF rank constant (only valid for RRF reranker) - * @param reranker Reranker pointer - * @return int Rank constant, or -1 if not an RRF reranker + * @brief Set Weighted rerank strategy on a multi-query. + * @param query Multi-query pointer + * @param weights Array of per-sub-query weights + * @param weight_count Number of weights + * @return Error code */ -ZVEC_EXPORT int ZVEC_CALL -zvec_get_reranker_rank_constant(const zvec_reranker_t *reranker); +ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_multi_query_set_rerank_weighted( + zvec_multi_query_t *query, const double *weights, size_t weight_count); // ----------------------------------------------------------------------------- // zvec_multi_query_t (Multi Query) @@ -2094,17 +2075,6 @@ ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_multi_query_set_output_fields( ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_multi_query_get_output_fields( zvec_multi_query_t *query, const char ***fields, size_t *count); -/** - * @brief Set reranker (copies shared pointer, caller must still destroy - * reranker) - * @param query Multi-vector query pointer - * @param reranker Reranker pointer (remains valid, caller must call - * zvec_destroy_reranker after use) - * @return zvec_error_code_t Error code - */ -ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_multi_query_set_reranker( - zvec_multi_query_t *query, zvec_reranker_t *reranker); - // ----------------------------------------------------------------------------- // zvec_sub_query_t (Sub-Query for Multi Query) // ----------------------------------------------------------------------------- diff --git a/src/include/zvec/db/query.h b/src/include/zvec/db/query.h index 2dd42f0c0..abfa66f22 100644 --- a/src/include/zvec/db/query.h +++ b/src/include/zvec/db/query.h @@ -137,7 +137,7 @@ struct MultiQuery { // empty -> select no field // non-empty -> select only the listed fields std::optional> output_fields; - std::shared_ptr reranker{nullptr}; + reranker::RerankParams rerank; // Value semantics, defaults to RRF k=60 }; diff --git a/src/include/zvec/db/reranker.h b/src/include/zvec/db/reranker.h index c138f42ed..e6e9256be 100644 --- a/src/include/zvec/db/reranker.h +++ b/src/include/zvec/db/reranker.h @@ -14,132 +14,62 @@ #pragma once #include -#include -#include +#include #include #include #include -#include -#include "zvec/db/status.h" +#include namespace zvec { +namespace reranker { -//! Reranker abstract base class for re-ranking search results -class Reranker { - public: - using Ptr = std::shared_ptr; +// =========================================================================== +// Rerank parameter types (stateless, value semantics) +// =========================================================================== - Reranker() = default; - virtual ~Reranker() = default; - - virtual void bind_schema(CollectionSchema::Ptr /*schema*/, - const std::vector & /*field_names*/) {} - - //! Re-rank documents from one or more vector queries. - //! \param query_results Per-query lists of retrieved documents (sorted by - //! relevance), in the same order as the sub-queries supplied by the caller. - //! \param topn Maximum number of documents to return. - //! \return Re-ranked list of documents (length <= topn), with updated scores. - virtual Result rerank( - const std::vector &query_results, int topn = 10) const = 0; -}; - -//! Intermediate base for rerankers that compute per-document scores. -//! -//! Implements the common rerank() logic: iterate docs, call rescore() for each, -//! accumulate scores by doc_id, and return topn results in descending order. -//! Subclasses only need to implement rescore(). -class ScoreBasedReranker : public Reranker { - public: - Result rerank(const std::vector &query_results, - int topn = 10) const override; - - private: - //! Compute the contribution score for a single document. - //! \param score The document's raw relevance score from the vector query. - //! \param rank The document's position (0-based) in the per-query result - //! list. \param query_index The index (0-based) of the sub-query this result - //! came from. \return The score contribution to be accumulated for this - //! document. - virtual Result rescore(double score, int rank, - int query_index) const = 0; -}; - -//! Re-ranker using Reciprocal Rank Fusion (RRF) for multi-vector search. -//! -//! RRF combines results from multiple vector queries without requiring -//! relevance scores. The RRF score for a document at rank r is: -//! score = 1 / (k + r + 1) -//! where k is the rank constant. -class RrfReranker : public ScoreBasedReranker { - public: - explicit RrfReranker(int rank_constant = 60) - : rank_constant_(rank_constant) {} - - int rank_constant() const { - return rank_constant_; - } - - private: - Result rescore(double score, int rank, - int query_index) const override; - - int rank_constant_; +/// RRF (Reciprocal Rank Fusion) parameters. +/// Score formula: 1 / (rank_constant + rank + 1) +struct RrfParams { + int rank_constant = 60; }; -//! Re-ranker that combines scores from multiple vector fields using weights. -//! -//! Each vector field's relevance score is normalized based on its own metric -//! type, then scaled by a user-provided weight. Final scores are summed across -//! fields. Supported metrics: L2, IP, COSINE. -//! -//! @note NOT thread-safe. The bind_schema() and rerank() calls share mutable -//! state. Each concurrent query must use its own WeightedReranker instance or -//! serialize access externally. -class WeightedReranker : public ScoreBasedReranker { - public: - explicit WeightedReranker(const std::vector &weights = {}); - - void bind_schema(CollectionSchema::Ptr schema, - const std::vector &field_names) override; - - const std::vector &weights() const { - return weights_; - } - - private: - Result rescore(double score, int rank, - int query_index) const override; - - static Result normalize_score(double score, const FieldSchema &field); - - CollectionSchema::Ptr schema_; - std::vector field_names_; - std::vector weights_; +/// Weighted score fusion parameters. +/// Each sub-query's score is normalized by metric_type (handled internally), +/// then multiplied by the corresponding weight. +struct WeightedParams { + std::vector weights; }; -//! Callback-based re-ranker for cross-language bridging. -//! -//! Wraps a user-provided callback (e.g., a Python callable) as a Reranker. -//! When the callback is a Python function, GIL must be managed by the caller. -class CallbackReranker : public Reranker { - public: +/// Custom callback reranker parameters. +/// The callback receives all sub-query results, field schemas, and topn. +struct CallbackParams { using Callback = - std::function &, int)>; - - explicit CallbackReranker(Callback fn) : callback_(std::move(fn)) {} - - Result rerank(const std::vector &query_results, - int topn = 10) const override { - if (!callback_) { - return tl::make_unexpected( - Status::InvalidArgument("CallbackReranker: callback is empty")); - } - return callback_(query_results, topn); - } - - private: - Callback callback_; + std::function &, + const std::vector &, int)>; + Callback callback; }; +/// Type-safe rerank strategy — a tagged union of parameter types. +/// Defaults to RrfParams (first variant type) — works out of the box. +using RerankParams = std::variant; + +// =========================================================================== +// Public: Rerank execution API (stateless free function) +// =========================================================================== + +/// Unified rerank entry point. +/// Dispatches to the appropriate algorithm based on the variant type. +/// +/// @param params User-specified rerank params (variant value) +/// @param results Per-sub-query document lists (parallel to fields) +/// @param fields Per-sub-query FieldSchema::Ptr (for metric_type +/// normalization) +/// @param topn Maximum number of results to return +/// @return Re-ranked document list (length <= topn) +Result rerank(const RerankParams ¶ms, + const std::vector &results, + const std::vector &fields, + int topn); + +} // namespace reranker } // namespace zvec diff --git a/src/include/zvec/db/schema.h b/src/include/zvec/db/schema.h index 56ad4a064..c899f47fe 100644 --- a/src/include/zvec/db/schema.h +++ b/src/include/zvec/db/schema.h @@ -341,6 +341,11 @@ class CollectionSchema { const FieldSchema *get_field(const std::string &column) const; FieldSchema *get_field(const std::string &column); + + FieldSchema::Ptr get_field_ptr(const std::string &column) const { + auto it = fields_map_.find(column); + return it != fields_map_.end() ? it->second : nullptr; + } const FieldSchema *get_forward_field(const std::string &column) const; FieldSchema *get_forward_field(const std::string &column); const FieldSchema *get_vector_field(const std::string &column) const; diff --git a/tests/c/c_api_test.c b/tests/c/c_api_test.c index 80013da66..5a0c3a144 100644 --- a/tests/c/c_api_test.c +++ b/tests/c/c_api_test.c @@ -4366,48 +4366,6 @@ void test_fts_end_to_end(void) { TEST_END(); } -void test_reranker_functions(void) { - TEST_START(); - - // Test 1: Create RRF reranker - zvec_reranker_t *rrf = zvec_create_rrf_reranker(60); - TEST_ASSERT(rrf != NULL); - if (rrf) { - TEST_ASSERT(zvec_get_reranker_rank_constant(rrf) == 60); - zvec_destroy_reranker(rrf); - } - - // Test 2: Create RRF reranker with different rank constant - zvec_reranker_t *rrf2 = zvec_create_rrf_reranker(100); - TEST_ASSERT(rrf2 != NULL); - if (rrf2) { - TEST_ASSERT(zvec_get_reranker_rank_constant(rrf2) == 100); - zvec_destroy_reranker(rrf2); - } - - // Test 3: Create Weighted reranker - double weights[] = {0.7, 0.3}; - zvec_reranker_t *weighted = zvec_create_weighted_reranker(weights, 2); - TEST_ASSERT(weighted != NULL); - if (weighted) { - TEST_ASSERT(zvec_get_reranker_rank_constant(weighted) == -1); - zvec_destroy_reranker(weighted); - } - - // Test 4: Create Weighted reranker with no weights - zvec_reranker_t *weighted2 = zvec_create_weighted_reranker(NULL, 0); - TEST_ASSERT(weighted2 != NULL); - if (weighted2) { - zvec_destroy_reranker(weighted2); - } - - // Test 5: NULL reranker operations - TEST_ASSERT(zvec_get_reranker_rank_constant(NULL) == -1); - zvec_destroy_reranker(NULL); // Should not crash - - TEST_END(); -} - // ==================== Multi-query reranker test helpers ==================== typedef struct { @@ -4480,9 +4438,14 @@ static void teardown_multi_query_fixture(multi_query_fixture_t *f) { cleanup_temp_directory(f->temp_dir); } -static int execute_multi_query_with_reranker(const multi_query_fixture_t *f, - zvec_reranker_t *reranker, - int topk, int num_candidates) { +typedef enum { + MQ_RERANK_RRF, + MQ_RERANK_WEIGHTED, +} mq_rerank_kind_t; + +static int execute_multi_query_with_rerank( + const multi_query_fixture_t *f, mq_rerank_kind_t kind, int rank_constant, + const double *weights, size_t weight_count, int topk, int num_candidates) { zvec_multi_query_t *mvq = zvec_multi_query_create(); if (!mvq) return -1; zvec_multi_query_set_topk(mvq, topk); @@ -4500,7 +4463,11 @@ static int execute_multi_query_with_reranker(const multi_query_fixture_t *f, zvec_sub_query_set_num_candidates(vq2, num_candidates); zvec_multi_query_add_sub_query(mvq, vq2); - zvec_multi_query_set_reranker(mvq, reranker); + if (kind == MQ_RERANK_WEIGHTED) { + zvec_multi_query_set_rerank_weighted(mvq, weights, weight_count); + } else { + zvec_multi_query_set_rerank_rrf(mvq, rank_constant); + } zvec_doc_t **results = NULL; size_t result_count = 0; @@ -4527,15 +4494,11 @@ void test_multi_vector_query_with_rrf_reranker(void) { multi_query_fixture_t f; TEST_ASSERT(setup_multi_query_fixture(&f, "zvec_test_mq_rrf", "mq_rrf")); - zvec_reranker_t *rrf = zvec_create_rrf_reranker(60); - TEST_ASSERT(rrf != NULL); - - int count = execute_multi_query_with_reranker(&f, rrf, 3, 3); + int count = + execute_multi_query_with_rerank(&f, MQ_RERANK_RRF, 60, NULL, 0, 3, 3); TEST_ASSERT(count > 0); TEST_ASSERT(count <= 3); - zvec_destroy_reranker(rrf); - // MultiQuery property setters/getters zvec_multi_query_t *mvq2 = zvec_multi_query_create(); TEST_ASSERT(mvq2 != NULL); @@ -4589,14 +4552,12 @@ void test_multi_vector_query_with_weighted_reranker(void) { setup_multi_query_fixture(&f, "zvec_test_mq_weighted", "mq_weighted")); double weights[] = {0.7, 0.3}; - zvec_reranker_t *weighted = zvec_create_weighted_reranker(weights, 2); - TEST_ASSERT(weighted != NULL); - int count = execute_multi_query_with_reranker(&f, weighted, 3, 3); + int count = execute_multi_query_with_rerank(&f, MQ_RERANK_WEIGHTED, 0, + weights, 2, 3, 3); TEST_ASSERT(count > 0); TEST_ASSERT(count <= 3); - zvec_destroy_reranker(weighted); teardown_multi_query_fixture(&f); TEST_END(); @@ -5930,7 +5891,6 @@ int main(void) { test_fts_wiring_on_vector_query(); test_fts_end_to_end(); - test_reranker_functions(); test_multi_vector_query_with_rrf_reranker(); test_multi_vector_query_with_weighted_reranker(); // Performance tests diff --git a/tests/db/collection_test.cc b/tests/db/collection_test.cc index fd9e62711..4805df32f 100644 --- a/tests/db/collection_test.cc +++ b/tests/db/collection_test.cc @@ -3747,7 +3747,7 @@ TEST_F(CollectionTest, Feature_MultiQuery_Validate) { { MultiQuery mvq; mvq.topk = 10; - mvq.reranker = std::make_shared(60); + mvq.rerank = reranker::RrfParams{60}; auto result = collection->Query(mvq); ASSERT_FALSE(result.has_value()); EXPECT_EQ(result.error().code(), StatusCode::INVALID_ARGUMENT); @@ -3788,7 +3788,7 @@ TEST_F(CollectionTest, Feature_MultiQuery_Validate) { { MultiQuery mvq; mvq.topk = 10; - mvq.reranker = std::make_shared(60); + mvq.rerank = reranker::RrfParams{60}; SubQuery vq1; vq1.num_candidates_ = 10; @@ -3809,6 +3809,30 @@ TEST_F(CollectionTest, Feature_MultiQuery_Validate) { EXPECT_EQ(result.error().code(), StatusCode::INVALID_ARGUMENT); } + // Test 4: Duplicate field names should succeed (same field, different + // vectors) + { + MultiQuery mvq; + mvq.topk = 10; + mvq.rerank = reranker::RrfParams{60}; + + SubQuery vq1; + vq1.num_candidates_ = 10; + vq1.target_.field_name_ = "dense_fp32"; + std::get(vq1.target_.clause_) + .query_vector_.assign(128 * sizeof(float), '\0'); + mvq.queries.push_back(vq1); + + SubQuery vq2; + vq2.num_candidates_ = 10; + vq2.target_.field_name_ = "dense_fp32"; + std::get(vq2.target_.clause_) + .query_vector_.assign(128 * sizeof(float), '\0'); + mvq.queries.push_back(vq2); + + auto result = collection->Query(mvq); + ASSERT_TRUE(result.has_value()); + } } TEST_F(CollectionTest, Feature_MultiQuery_SingleFieldWithReranker) { @@ -3826,7 +3850,7 @@ TEST_F(CollectionTest, Feature_MultiQuery_SingleFieldWithReranker) { MultiQuery mvq; mvq.topk = 10; - mvq.reranker = std::make_shared(60); + mvq.rerank = reranker::RrfParams{60}; SubQuery vq; vq.num_candidates_ = 10; @@ -3857,7 +3881,7 @@ TEST_F(CollectionTest, Feature_MultiQuery_MultiFieldRRF) { MultiQuery mvq; mvq.topk = 10; - mvq.reranker = std::make_shared(60); + mvq.rerank = reranker::RrfParams{60}; // Query dense_fp32 and dense_fp16 fields with different vectors auto vector1 = query_doc.get>("dense_fp32"); @@ -3917,8 +3941,9 @@ TEST_F(CollectionTest, Feature_MultiQuery_MultiFieldWeighted) { MultiQuery mvq; mvq.topk = 10; - mvq.reranker = - std::make_shared(std::vector{0.7, 0.3}); + // Weights are positional, parallel to the sub-query order below + // (dense_fp32 first, sparse_fp32 second). + mvq.rerank = reranker::WeightedParams{{0.7, 0.3}}; // Query dense_fp32 field { @@ -3972,7 +3997,7 @@ TEST_F(CollectionTest, Feature_MultiQuery_WithFilter) { MultiQuery mvq; mvq.topk = 10; mvq.filter = "int32 > 50"; - mvq.reranker = std::make_shared(60); + mvq.rerank = reranker::RrfParams{60}; SubQuery vq1; vq1.num_candidates_ = 10; @@ -4022,7 +4047,7 @@ TEST_F(CollectionTest, Feature_MultiQuery_WithOutputFields) { mvq.include_vector = false; mvq.output_fields = std::make_optional>( std::vector{"int32", "string"}); - mvq.reranker = std::make_shared(60); + mvq.rerank = reranker::RrfParams{60}; SubQuery vq1; vq1.num_candidates_ = 10; @@ -4067,10 +4092,12 @@ TEST_F(CollectionTest, Feature_MultiQuery_CallbackReranker) { auto query_doc = TestHelper::CreateDoc(1, *schema); - // Use CallbackReranker with a lambda that merges and sorts by score + // Use a callback rerank strategy with a lambda that merges and sorts by + // score. bool callback_invoked = false; auto callback_fn = [&callback_invoked]( const std::vector &query_results, + const std::vector & /*fields*/, int topn) -> DocPtrList { callback_invoked = true; DocPtrList all_docs; @@ -4091,7 +4118,7 @@ TEST_F(CollectionTest, Feature_MultiQuery_CallbackReranker) { MultiQuery mvq; mvq.topk = 10; - mvq.reranker = std::make_shared(callback_fn); + mvq.rerank = reranker::CallbackParams{callback_fn}; // Query dense_fp32 field { diff --git a/tests/db/reranker_test.cc b/tests/db/reranker_test.cc index b41c123aa..9c0c4a01b 100644 --- a/tests/db/reranker_test.cc +++ b/tests/db/reranker_test.cc @@ -13,6 +13,7 @@ // limitations under the License. #define _USE_MATH_DEFINES +#include #include #include #include @@ -35,222 +36,245 @@ Doc::Ptr MakeDoc(const std::string &id, float score) { return doc; } -CollectionSchema::Ptr MakeSchema( - const std::vector> &fields) { - auto schema = std::make_shared("test"); - for (const auto &[name, metric] : fields) { - auto field = std::make_shared( - name, DataType::VECTOR_FP16, /*dimension=*/4, /*nullable=*/false, - std::make_shared(metric)); - schema->add_field(field); - } - return schema; +FieldSchema::Ptr MakeField(const std::string &name, MetricType metric) { + return std::make_shared( + name, DataType::VECTOR_FP16, /*dimension=*/4, /*nullable=*/false, + std::make_shared(metric)); } } // namespace -// ==================== RrfReranker Tests ==================== +// ==================== RRF Tests ==================== -TEST(RrfRerankerTest, BasicRRF) { - RrfReranker reranker(/*rank_constant=*/60); - - // Two vector fields, each returning 3 documents with some overlap - std::vector query_results; - query_results.push_back( +TEST(RerankRrfTest, BasicRRF) { + // Two sub-queries, each returning 3 documents with some overlap. + std::vector results; + results.push_back( {MakeDoc("a", 0.9f), MakeDoc("b", 0.8f), MakeDoc("c", 0.7f)}); - query_results.push_back( + results.push_back( {MakeDoc("b", 0.95f), MakeDoc("a", 0.85f), MakeDoc("d", 0.75f)}); - auto result = reranker.rerank(query_results, /*topn=*/10); + auto result = + reranker::rerank(reranker::RrfParams{/*rank_constant=*/60}, results, + /*fields=*/{}, /*topn=*/10); ASSERT_TRUE(result.has_value()); - auto &results = result.value(); + auto &out = result.value(); - // "a" appears at rank 0 in vec1 and rank 1 in vec2: + // "a" appears at rank 0 in sub-query 0 and rank 1 in sub-query 1: // rrf_score = 1/(60+0+1) + 1/(60+1+1) = 1/61 + 1/62 - // "b" appears at rank 1 in vec1 and rank 0 in vec2: + // "b" appears at rank 1 in sub-query 0 and rank 0 in sub-query 1: // rrf_score = 1/(60+1+1) + 1/(60+0+1) = 1/62 + 1/61 - // So a and b should have equal scores and be at the top - ASSERT_GE(results.size(), 3u); - - // "a" and "b" should have the highest RRF scores (equal, order unspecified) - std::set top2{results[0]->pk(), results[1]->pk()}; + // So a and b should have equal scores and occupy the top two slots. + ASSERT_GE(out.size(), 3u); + std::set top2 = {out[0]->pk(), out[1]->pk()}; EXPECT_EQ(top2, (std::set{"a", "b"})); - // Verify scores are close (a and b have same RRF score) - EXPECT_NEAR(results[0]->score(), results[1]->score(), 1e-10); + EXPECT_NEAR(out[0]->score(), out[1]->score(), 1e-10); } -TEST(RrfRerankerTest, Topn) { - RrfReranker reranker(/*rank_constant=*/60); - - std::vector query_results; - query_results.push_back( +TEST(RerankRrfTest, Topn) { + std::vector results; + results.push_back( {MakeDoc("a", 0.9f), MakeDoc("b", 0.8f), MakeDoc("c", 0.7f)}); - auto result = reranker.rerank(query_results, /*topn=*/2); + auto result = + reranker::rerank(reranker::RrfParams{/*rank_constant=*/60}, results, + /*fields=*/{}, /*topn=*/2); ASSERT_TRUE(result.has_value()); ASSERT_EQ(result.value().size(), 2u); } -TEST(RrfRerankerTest, SingleField) { - RrfReranker reranker(/*rank_constant=*/60); - - std::vector query_results; - query_results.push_back({MakeDoc("a", 0.9f), MakeDoc("b", 0.8f)}); +TEST(RerankRrfTest, SingleField) { + std::vector results; + results.push_back({MakeDoc("a", 0.9f), MakeDoc("b", 0.8f)}); - auto result = reranker.rerank(query_results); + auto result = + reranker::rerank(reranker::RrfParams{/*rank_constant=*/60}, results, + /*fields=*/{}, /*topn=*/10); ASSERT_TRUE(result.has_value()); - auto &results = result.value(); - ASSERT_EQ(results.size(), 2u); - // With single field, RRF score for rank 0 > rank 1 - EXPECT_GT(results[0]->score(), results[1]->score()); + auto &out = result.value(); + ASSERT_EQ(out.size(), 2u); + // With single sub-query, RRF score for rank 0 > rank 1. + EXPECT_GT(out[0]->score(), out[1]->score()); } -TEST(RrfRerankerTest, EmptyResults) { - RrfReranker reranker(/*rank_constant=*/60); - - std::vector query_results; - auto result = reranker.rerank(query_results); +TEST(RerankRrfTest, EmptyResults) { + std::vector results; + auto result = + reranker::rerank(reranker::RrfParams{/*rank_constant=*/60}, results, + /*fields=*/{}, /*topn=*/10); ASSERT_TRUE(result.has_value()); EXPECT_TRUE(result.value().empty()); } -// ==================== WeightedReranker Tests ==================== +TEST(RerankRrfTest, DefaultParams) { + // RrfParams (and therefore RerankParams) defaults to rank_constant = 60. + std::vector results; + results.push_back({MakeDoc("a", 0.9f), MakeDoc("b", 0.8f)}); -TEST(WeightedRerankerTest, BasicWeighted) { - auto schema = - MakeSchema({{"vec1", MetricType::L2}, {"vec2", MetricType::L2}}); - WeightedReranker reranker({0.7, 0.3}); - reranker.bind_schema(schema, {"vec1", "vec2"}); + auto result = + reranker::rerank(reranker::RerankParams{}, results, /*fields=*/{}, + /*topn=*/10); + ASSERT_TRUE(result.has_value()); + ASSERT_EQ(result.value().size(), 2u); +} + +// ==================== Weighted Tests ==================== - std::vector query_results; - query_results.push_back({MakeDoc("a", 0.5f), MakeDoc("b", 0.3f)}); - query_results.push_back({MakeDoc("a", 0.8f), MakeDoc("c", 0.6f)}); +TEST(RerankWeightedTest, BasicWeighted) { + std::vector results; + results.push_back({MakeDoc("a", 0.5f), MakeDoc("b", 0.3f)}); + results.push_back({MakeDoc("a", 0.8f), MakeDoc("c", 0.6f)}); + std::vector fields = {MakeField("vec1", MetricType::L2), + MakeField("vec2", MetricType::L2)}; - auto result = reranker.rerank(query_results); + auto result = + reranker::rerank(reranker::WeightedParams{{0.7, 0.3}}, results, fields, + /*topn=*/10); ASSERT_TRUE(result.has_value()); - auto &results = result.value(); - ASSERT_GE(results.size(), 2u); - // "a" appears in both fields, should have highest combined score - EXPECT_EQ(results[0]->pk(), "a"); + auto &out = result.value(); + ASSERT_GE(out.size(), 2u); + // "a" appears in both sub-queries, should have highest combined score. + EXPECT_EQ(out[0]->pk(), "a"); } -TEST(WeightedRerankerTest, MixedMetrics) { - auto schema = - MakeSchema({{"vec1", MetricType::L2}, {"vec2", MetricType::COSINE}}); - WeightedReranker reranker({0.5, 0.5}); - reranker.bind_schema(schema, {"vec1", "vec2"}); - - std::vector query_results; - query_results.push_back({MakeDoc("a", 0.5f)}); - query_results.push_back({MakeDoc("a", 0.4f)}); +TEST(RerankWeightedTest, MixedMetrics) { + std::vector results; + results.push_back({MakeDoc("a", 0.5f)}); + results.push_back({MakeDoc("a", 0.4f)}); + std::vector fields = { + MakeField("vec1", MetricType::L2), MakeField("vec2", MetricType::COSINE)}; - auto result = reranker.rerank(query_results); + auto result = + reranker::rerank(reranker::WeightedParams{{0.5, 0.5}}, results, fields, + /*topn=*/10); ASSERT_TRUE(result.has_value()); - auto &results = result.value(); - ASSERT_EQ(results.size(), 1u); - EXPECT_EQ(results[0]->pk(), "a"); - // L2 normalize(0.5) = 1 - 2*atan(0.5)/pi ≈ 0.7048 + auto &out = result.value(); + ASSERT_EQ(out.size(), 1u); + EXPECT_EQ(out[0]->pk(), "a"); + // L2 normalize(0.5) = 1 - 2*atan(0.5)/pi // COSINE normalize(0.4) = 1 - 0.4/2 = 0.8 - // weighted = 0.7048 * 0.5 + 0.8 * 0.5 ≈ 0.7524 + // weighted = l2_norm * 0.5 + cos_norm * 0.5 double l2_norm = 1.0 - 2.0 * std::atan(0.5) / M_PI; double cos_norm = 1.0 - 0.4 / 2.0; double expected = l2_norm * 0.5 + cos_norm * 0.5; - EXPECT_NEAR(results[0]->score(), expected, 1e-5); + EXPECT_NEAR(out[0]->score(), expected, 1e-5); +} + +TEST(RerankWeightedTest, WeightsCountMismatch) { + std::vector results; + results.push_back({MakeDoc("a", 0.5f)}); + results.push_back({MakeDoc("b", 0.3f)}); + std::vector fields = {MakeField("vec1", MetricType::L2), + MakeField("vec2", MetricType::L2)}; + + // Only one weight provided for two sub-queries. + auto result = + reranker::rerank(reranker::WeightedParams{{1.0}}, results, fields, + /*topn=*/10); + ASSERT_FALSE(result.has_value()); } -TEST(WeightedRerankerTest, MissingMetricError) { - auto schema = MakeSchema({{"vec1", MetricType::L2}}); - WeightedReranker reranker; - // Binding a field that is absent from the schema should fail at rerank time. - reranker.bind_schema(schema, {"vec1", "vec2"}); +TEST(RerankWeightedTest, FieldsCountMismatch) { + std::vector results; + results.push_back({MakeDoc("a", 0.5f)}); + results.push_back({MakeDoc("b", 0.3f)}); + std::vector fields = {MakeField("vec1", MetricType::L2)}; - std::vector query_results; - query_results.push_back({MakeDoc("a", 0.5f)}); - query_results.push_back({MakeDoc("b", 0.3f)}); - auto result = reranker.rerank(query_results); + auto result = + reranker::rerank(reranker::WeightedParams{{0.5, 0.5}}, results, fields, + /*topn=*/10); ASSERT_FALSE(result.has_value()); } -TEST(WeightedRerankerTest, NormalizeL2) { - auto schema = MakeSchema({{"vec1", MetricType::L2}}); - WeightedReranker reranker; - reranker.bind_schema(schema, {"vec1"}); +TEST(RerankWeightedTest, NullFieldError) { + std::vector results; + results.push_back({MakeDoc("a", 0.5f)}); + std::vector fields = {nullptr}; - std::vector query_results; - query_results.push_back({MakeDoc("a", 0.0f), MakeDoc("b", 1.0f)}); + auto result = + reranker::rerank(reranker::WeightedParams{{1.0}}, results, fields, + /*topn=*/10); + ASSERT_FALSE(result.has_value()); +} + +TEST(RerankWeightedTest, NormalizeL2) { + std::vector results; + results.push_back({MakeDoc("a", 0.0f), MakeDoc("b", 1.0f)}); + std::vector fields = {MakeField("vec1", MetricType::L2)}; - auto result = reranker.rerank(query_results); + auto result = + reranker::rerank(reranker::WeightedParams{{1.0}}, results, fields, + /*topn=*/10); ASSERT_TRUE(result.has_value()); - auto &results = result.value(); - ASSERT_EQ(results.size(), 2u); - // L2 normalize(0.0) = 1.0, normalize(1.0) ∈ (0, 1) - EXPECT_NEAR(results[0]->score(), 1.0, 1e-10); - EXPECT_EQ(results[0]->pk(), "a"); - EXPECT_GT(results[1]->score(), 0.0); - EXPECT_LT(results[1]->score(), 1.0); + auto &out = result.value(); + ASSERT_EQ(out.size(), 2u); + // L2 normalize(0.0) = 1.0, normalize(1.0) in (0, 1) + EXPECT_NEAR(out[0]->score(), 1.0, 1e-10); + EXPECT_EQ(out[0]->pk(), "a"); + EXPECT_GT(out[1]->score(), 0.0); + EXPECT_LT(out[1]->score(), 1.0); } -TEST(WeightedRerankerTest, NormalizeIP) { - auto schema = MakeSchema({{"vec1", MetricType::IP}}); - WeightedReranker reranker; - reranker.bind_schema(schema, {"vec1"}); +TEST(RerankWeightedTest, NormalizeIP) { + std::vector results; + results.push_back({MakeDoc("a", 0.0f), MakeDoc("b", 1.0f)}); + std::vector fields = {MakeField("vec1", MetricType::IP)}; - std::vector query_results; - query_results.push_back({MakeDoc("a", 0.0f), MakeDoc("b", 1.0f)}); - - auto result = reranker.rerank(query_results); + auto result = + reranker::rerank(reranker::WeightedParams{{1.0}}, results, fields, + /*topn=*/10); ASSERT_TRUE(result.has_value()); - auto &results = result.value(); - ASSERT_EQ(results.size(), 2u); - // IP normalize(1.0) > 0.5 > normalize(0.0) = 0.5... but b scores higher - EXPECT_EQ(results[0]->pk(), "b"); - EXPECT_GT(results[0]->score(), 0.5); - EXPECT_NEAR(results[1]->score(), 0.5, 1e-10); + auto &out = result.value(); + ASSERT_EQ(out.size(), 2u); + // IP normalize(1.0) > 0.5 > normalize(0.0) = 0.5 + EXPECT_EQ(out[0]->pk(), "b"); + EXPECT_GT(out[0]->score(), 0.5); + EXPECT_NEAR(out[1]->score(), 0.5, 1e-10); } -TEST(WeightedRerankerTest, NormalizeCosine) { - auto schema = MakeSchema({{"vec1", MetricType::COSINE}}); - WeightedReranker reranker; - reranker.bind_schema(schema, {"vec1"}); - - std::vector query_results; - query_results.push_back( +TEST(RerankWeightedTest, NormalizeCosine) { + std::vector results; + results.push_back( {MakeDoc("a", 0.0f), MakeDoc("b", 1.0f), MakeDoc("c", 2.0f)}); + std::vector fields = { + MakeField("vec1", MetricType::COSINE)}; - auto result = reranker.rerank(query_results); + auto result = + reranker::rerank(reranker::WeightedParams{{1.0}}, results, fields, + /*topn=*/10); ASSERT_TRUE(result.has_value()); - auto &results = result.value(); - ASSERT_EQ(results.size(), 3u); + auto &out = result.value(); + ASSERT_EQ(out.size(), 3u); // COSINE normalize(0.0) = 1.0, normalize(1.0) = 0.5, normalize(2.0) = 0.0 - EXPECT_NEAR(results[0]->score(), 1.0, 1e-10); - EXPECT_NEAR(results[1]->score(), 0.5, 1e-10); - EXPECT_NEAR(results[2]->score(), 0.0, 1e-10); + EXPECT_NEAR(out[0]->score(), 1.0, 1e-10); + EXPECT_NEAR(out[1]->score(), 0.5, 1e-10); + EXPECT_NEAR(out[2]->score(), 0.0, 1e-10); } -TEST(WeightedRerankerTest, Topn) { - auto schema = MakeSchema({{"vec1", MetricType::L2}}); - WeightedReranker reranker; - reranker.bind_schema(schema, {"vec1"}); - - std::vector query_results; - query_results.push_back( +TEST(RerankWeightedTest, Topn) { + std::vector results; + results.push_back( {MakeDoc("a", 0.1f), MakeDoc("b", 0.2f), MakeDoc("c", 0.3f)}); + std::vector fields = {MakeField("vec1", MetricType::L2)}; - auto result = reranker.rerank(query_results, /*topn=*/2); + auto result = + reranker::rerank(reranker::WeightedParams{{1.0}}, results, fields, + /*topn=*/2); ASSERT_TRUE(result.has_value()); ASSERT_EQ(result.value().size(), 2u); } +// ==================== Callback Tests ==================== -// ==================== CallbackReranker Tests ==================== - -TEST(CallbackRerankerTest, BasicCallback) { +TEST(RerankCallbackTest, BasicCallback) { // Simple callback that returns docs sorted by score descending, limited to - // topn - CallbackReranker::Callback cb = - [](const std::vector &query_results, int topn) -> DocPtrList { + // topn. + reranker::CallbackParams::Callback cb = + [](const std::vector &results, + const std::vector & /*fields*/, + int topn) -> DocPtrList { DocPtrList all_docs; - for (const auto &docs : query_results) { + for (const auto &docs : results) { for (const auto &doc : docs) { all_docs.push_back(doc); } @@ -265,18 +289,27 @@ TEST(CallbackRerankerTest, BasicCallback) { return all_docs; }; - CallbackReranker reranker(cb); - - std::vector query_results; - query_results.push_back({MakeDoc("a", 0.5f), MakeDoc("b", 0.9f)}); - query_results.push_back({MakeDoc("c", 0.7f)}); + std::vector results; + results.push_back({MakeDoc("a", 0.5f), MakeDoc("b", 0.9f)}); + results.push_back({MakeDoc("c", 0.7f)}); - auto result = reranker.rerank(query_results, /*topn=*/10); + auto result = + reranker::rerank(reranker::CallbackParams{cb}, results, /*fields=*/{}, + /*topn=*/10); ASSERT_TRUE(result.has_value()); - auto &results = result.value(); - ASSERT_EQ(results.size(), 3u); - // Should be sorted by score descending - EXPECT_EQ(results[0]->pk(), "b"); - EXPECT_EQ(results[1]->pk(), "c"); - EXPECT_EQ(results[2]->pk(), "a"); + auto &out = result.value(); + ASSERT_EQ(out.size(), 3u); + // Should be sorted by score descending. + EXPECT_EQ(out[0]->pk(), "b"); + EXPECT_EQ(out[1]->pk(), "c"); + EXPECT_EQ(out[2]->pk(), "a"); +} + +TEST(RerankCallbackTest, EmptyCallbackError) { + reranker::CallbackParams params; // callback is empty + std::vector results; + results.push_back({MakeDoc("a", 0.5f)}); + + auto result = reranker::rerank(params, results, /*fields=*/{}, /*topn=*/10); + ASSERT_FALSE(result.has_value()); }