Skip to content
Merged
9 changes: 7 additions & 2 deletions python/tests/detail/test_collection_dql.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,7 @@ def test_query_multivector_rrf(self, full_collection: Collection, doc_num):
rrf_reranker = RrfReRanker()
multi_query_result = full_collection.query(
multi_query_vectors,
topk=3,
reranker=rrf_reranker,
)
assert len(multi_query_result) > 0, (
Expand Down Expand Up @@ -876,8 +877,11 @@ def test_query_multivector_weighted(
batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert")
doc_fields, doc_vectors = generate_vectordict_random(full_collection.schema)

weight_list = [weights[v] for v in DEFAULT_VECTOR_FIELD_NAME.values()]
weighted_reranker = WeightedReRanker(weights=weight_list)
# Weights are positional, aligned with the multi_query_vectors order
# (DEFAULT_VECTOR_FIELD_NAME insertion order). Metric normalization is
# automatic from each field's schema.
weights_list = [weights[v] for v in DEFAULT_VECTOR_FIELD_NAME.values()]
weighted_reranker = WeightedReRanker(weights_list)

single_query_results = {}
for k, v in DEFAULT_VECTOR_FIELD_NAME.items():
Expand All @@ -894,6 +898,7 @@ def test_query_multivector_weighted(

multi_query_result = full_collection.query(
multi_query_vectors,
topk=3,
reranker=weighted_reranker,
)
assert len(multi_query_result) > 0, (
Expand Down
14 changes: 8 additions & 6 deletions python/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
InvertIndexParam,
LogLevel,
LogType,
MetricType,
OptimizeOption,
StatusCode,
Query,
Expand Down Expand Up @@ -1105,7 +1104,8 @@ def test_collection_query_with_weighted_reranker_by_multi_dense_vector(
self, collection_with_multiple_docs: Collection, multiple_docs
):
"""Test multi-vector query with Weighted reranker on multiple dense vectors."""
reranker = WeightedReRanker(weights=[0.6, 0.4])
weights = [0.6, 0.4]
reranker = WeightedReRanker(weights=weights)
result = collection_with_multiple_docs.query(
[
Query(field_name="dense", vector=multiple_docs[0].vector("dense")),
Expand All @@ -1121,7 +1121,8 @@ def test_collection_query_with_weighted_reranker_by_multi_sparse_vector(
self, collection_with_multiple_docs: Collection, multiple_docs
):
"""Test multi-vector query with Weighted reranker on multiple sparse vectors."""
reranker = WeightedReRanker(weights=[0.6, 0.4])
weights = [0.6, 0.4]
reranker = WeightedReRanker(weights=weights)
result = collection_with_multiple_docs.query(
[
Query(field_name="sparse", vector=multiple_docs[0].vector("sparse")),
Expand All @@ -1140,7 +1141,8 @@ def test_collection_query_with_weighted_reranker_by_hybrid_vector(
self, collection_with_multiple_docs: Collection, multiple_docs
):
"""Test multi-vector query with Weighted reranker combining dense + sparse."""
reranker = WeightedReRanker(weights=[0.7, 0.3])
weights = [0.7, 0.3]
reranker = WeightedReRanker(weights=weights)
result = collection_with_multiple_docs.query(
[
Query(field_name="dense", vector=multiple_docs[0].vector("dense")),
Expand All @@ -1158,7 +1160,7 @@ def test_collection_query_with_callback_reranker_by_multi_dense_vector(
"""Test multi-vector query with CallbackReRanker (Python callback via C++)."""
callback_invoked = []

def my_rerank_callback(query_results, topn):
def my_rerank_callback(query_results, fields, topn):
callback_invoked.append(True)
all_docs = []
for docs in query_results:
Expand Down Expand Up @@ -1190,7 +1192,7 @@ def test_collection_query_with_callback_reranker_by_hybrid_vector(
):
"""Test multi-vector query with CallbackReRanker combining dense + sparse."""

def my_rerank_callback(query_results, topn):
def my_rerank_callback(query_results, fields, topn):
all_docs = []
for docs in query_results:
all_docs.extend(docs)
Expand Down
Loading
Loading