diff --git a/faiss/IndexIVFRaBitQ.cpp b/faiss/IndexIVFRaBitQ.cpp index f514633e30..58df501d85 100644 --- a/faiss/IndexIVFRaBitQ.cpp +++ b/faiss/IndexIVFRaBitQ.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include @@ -321,6 +322,44 @@ void IndexIVFRaBitQ::sa_decode(idx_t n, const uint8_t* codes, float* x) const { } } +void IndexIVFRaBitQ::compute_distance_to_codes_for_list( + const idx_t list_no, + const float* x, + idx_t n, + const uint8_t* codes, + float* dists, + float* /*dist_table*/) const { + FAISS_THROW_IF_NOT(n >= 0); + FAISS_THROW_IF_NOT(list_no >= 0 && (size_t)list_no < nlist); + FAISS_THROW_IF_NOT(x != nullptr); + FAISS_THROW_IF_NOT(codes != nullptr); + FAISS_THROW_IF_NOT(dists != nullptr); + FAISS_THROW_IF_NOT(code_size > 0); + if (n == 0) { + return; + } + FAISS_THROW_IF_NOT( + (size_t)n <= (std::numeric_limits::max)() / code_size); + + // RaBitQ uses per-vector correction factors stored in the codes, so we + // must use the RaBitQuantizer distance computer. + std::vector centroid(d); + quantizer->reconstruct(list_no, centroid.data()); + + // Note: "centered" query quantization is a per-search parameter in Faiss. + // compute_distance_to_codes_for_list does not take IVFSearchParameters, so + // we use centered=false here (consistent with get_distance_computer()). In + // future, we can look into setting centered and qb per call if needed. + std::unique_ptr dc( + rabitq.get_distance_computer(qb, centroid.data(), /*centered=*/false)); + dc->set_query(x); + + const uint8_t* code = codes; + for (idx_t i = 0; i < n; i++, code += code_size) { + dists[i] = dc->distance_to_code(code); + } +} + struct IVFRaBitDistanceComputer : DistanceComputer { const float* q = nullptr; const IndexIVFRaBitQ* parent = nullptr; diff --git a/faiss/IndexIVFRaBitQ.h b/faiss/IndexIVFRaBitQ.h index 7d59df7d33..74666259b1 100644 --- a/faiss/IndexIVFRaBitQ.h +++ b/faiss/IndexIVFRaBitQ.h @@ -72,6 +72,14 @@ struct IndexIVFRaBitQ : IndexIVF { void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override; + void compute_distance_to_codes_for_list( + const idx_t list_no, + const float* x, + idx_t n, + const uint8_t* codes, + float* dists, + float* dist_table) const override; + // unfortunately DistanceComputer* get_distance_computer() const override; };