From 16c098fb0cfc7960879c4b1c8651a3b08a4ffe1b Mon Sep 17 00:00:00 2001 From: dudegladiator Date: Sun, 7 Jun 2026 14:50:24 +0530 Subject: [PATCH 1/4] feat: add TurboQuant quantizer core Data-oblivious vector quantizer: random rotation + per-coordinate Lloyd-Max codebooks (b=1..4), MSE quantizer, and an unbiased two-stage inner-product estimator (MSE + 1-bit QJL residual). Pure NumPy, no storage or app coupling. Verified against the paper's distortion bounds and estimator unbiasedness. --- src/cocoindex_code/turbo_quant.py | 350 ++++++++++++++++++++++++++++++ tests/test_turbo_quant.py | 200 +++++++++++++++++ 2 files changed, 550 insertions(+) create mode 100644 src/cocoindex_code/turbo_quant.py create mode 100644 tests/test_turbo_quant.py diff --git a/src/cocoindex_code/turbo_quant.py b/src/cocoindex_code/turbo_quant.py new file mode 100644 index 0000000..b71f417 --- /dev/null +++ b/src/cocoindex_code/turbo_quant.py @@ -0,0 +1,350 @@ +"""TurboQuant: data-oblivious vector quantization with near-optimal distortion. + +Implements the algorithm from Zandieh et al., "TurboQuant: Online Vector +Quantization with Near-optimal Distortion Rate" (arXiv:2504.19874). + +Two quantizers are provided: + +* **MSE quantizer** (``quantize_mse`` / ``dequantize_mse``): randomly rotates the + input, then applies an optimal per-coordinate Lloyd-Max scalar quantizer. The + rotation makes every coordinate follow a Beta distribution that converges to + ``N(0, 1/d)`` in high dimensions, so a single precomputed codebook (solved for + the standard normal and scaled by ``1/sqrt(d)``) is near-optimal per coordinate. + Minimizes reconstruction MSE but is *biased* for inner-product estimation. + +* **Inner-product quantizer** (``quantize_prod`` / ``inner_product_prod``): the + two-stage scheme. Applies the MSE quantizer at ``bits - 1`` bits, then a 1-bit + Quantized Johnson-Lindenstrauss (QJL) transform on the residual. The result is + an *unbiased* inner-product estimator (paper Theorem 2). + +The rotation matrix ``Pi`` and the QJL projection ``S`` are derived from a single +integer ``seed`` so an index is fully reproducible and only the seed (not the +matrices) needs to be persisted. + +This module is intentionally free of any cocoindex / SQLite coupling so it can be +unit-tested in isolation. +""" + +from __future__ import annotations + +import numpy as np +import numpy.typing as npt + +__all__ = [ + "SUPPORTED_BITS", + "TurboQuant", + "gaussian_lloyd_max", + "pack_indices", + "unpack_indices", + "pack_signs", + "unpack_signs", +] + +# Bit-widths we precompute codebooks for and support end-to-end. +SUPPORTED_BITS = (1, 2, 3, 4) + +# Offset added to the base seed when deriving the QJL projection matrix, so Pi +# and S come from independent draws of the same seeded generator family. +_QJL_SEED_OFFSET = 0x5F3759DF + + +# --------------------------------------------------------------------------- +# Codebook computation (Lloyd-Max for the standard normal) +# --------------------------------------------------------------------------- + + +def gaussian_lloyd_max( + bits: int, + *, + grid_points: int = 1 << 16, + grid_limit: float = 8.0, + max_iter: int = 200, + tol: float = 1e-9, +) -> npt.NDArray[np.float64]: + """Solve the optimal ``2**bits``-level scalar quantizer for ``N(0, 1)``. + + Uses deterministic grid quadrature (no RNG, no scipy): the real line is + approximated by a fine grid weighted by the normal density, and Lloyd-Max + iteration alternates between Voronoi assignment and conditional-mean centroid + updates until convergence. + + Returns the sorted centroids for the standard normal. Callers scale these by + ``1/sqrt(d)`` to match the ``N(0, 1/d)`` coordinate distribution of a rotated + unit vector. + """ + if bits < 1: + raise ValueError(f"bits must be >= 1, got {bits}") + + levels = 1 << bits + x = np.linspace(-grid_limit, grid_limit, grid_points) + # Normal density (unnormalized is fine — only relative weights matter). + w = np.exp(-0.5 * x * x) + + # Initialize centroids at evenly spaced density quantile-ish positions. + centroids = np.linspace(-grid_limit / 2, grid_limit / 2, levels) + + prev_distortion = np.inf + for _ in range(max_iter): + # Assign each grid point to the nearest centroid. + # boundaries are midpoints between sorted centroids. + boundaries = (centroids[:-1] + centroids[1:]) / 2.0 + assign = np.searchsorted(boundaries, x) + + # Conditional mean per cluster (weighted by density). + new_centroids = centroids.copy() + for k in range(levels): + mask = assign == k + wk = w[mask] + total = wk.sum() + if total > 0: + new_centroids[k] = (x[mask] * wk).sum() / total + # Distortion for convergence check. + distortion = float((w * (x - new_centroids[assign]) ** 2).sum() / w.sum()) + centroids = new_centroids + if abs(prev_distortion - distortion) <= tol: + break + prev_distortion = distortion + + centroids.sort() + return centroids + + +# Precompute standard-normal codebooks once at import for the supported bits. +_NORMAL_CODEBOOKS: dict[int, npt.NDArray[np.float64]] = { + b: gaussian_lloyd_max(b) for b in SUPPORTED_BITS +} + + +# --------------------------------------------------------------------------- +# Bit packing +# --------------------------------------------------------------------------- + + +def pack_indices(indices: npt.NDArray[np.integer], bits: int) -> bytes: + """Pack an array of ``bits``-wide integer indices into a byte string. + + MSB-first within each index. The packed stream is zero-padded to a byte + boundary; ``unpack_indices`` must be given the original length to trim it. + """ + idx = np.asarray(indices, dtype=np.uint64) + if idx.size == 0: + return b"" + if bits < 1 or bits > 8: + raise ValueError(f"bits must be in 1..8, got {bits}") + shifts = np.arange(bits - 1, -1, -1, dtype=np.uint64) + bit_matrix = ((idx[:, None] >> shifts) & np.uint64(1)).astype(np.uint8) + return np.packbits(bit_matrix.reshape(-1)).tobytes() + + +def unpack_indices(packed: bytes, count: int, bits: int) -> npt.NDArray[np.int64]: + """Inverse of :func:`pack_indices`. Returns ``count`` integer indices.""" + if count == 0: + return np.empty(0, dtype=np.int64) + raw = np.frombuffer(packed, dtype=np.uint8) + bit_stream = np.unpackbits(raw)[: count * bits].reshape(count, bits) + weights = (1 << np.arange(bits - 1, -1, -1)).astype(np.int64) + return bit_stream.astype(np.int64) @ weights + + +def pack_signs(signs: npt.NDArray[np.floating | np.integer]) -> bytes: + """Pack a ``+1/-1`` vector into a bit string (``+1`` -> 1, ``-1`` -> 0).""" + s = np.asarray(signs) + if s.size == 0: + return b"" + bit = (s > 0).astype(np.uint8) + return np.packbits(bit).tobytes() + + +def unpack_signs(packed: bytes, count: int) -> npt.NDArray[np.float32]: + """Inverse of :func:`pack_signs`. Returns a ``+1/-1`` float32 vector.""" + if count == 0: + return np.empty(0, dtype=np.float32) + raw = np.frombuffer(packed, dtype=np.uint8) + bit = np.unpackbits(raw)[:count] + return np.where(bit > 0, np.float32(1.0), np.float32(-1.0)).astype(np.float32) + + +# --------------------------------------------------------------------------- +# TurboQuant +# --------------------------------------------------------------------------- + + +class TurboQuant: + """Reproducible TurboQuant quantizer for a fixed ``dim`` / ``bits`` / ``seed``. + + ``bits`` is the *target* bit-width. The MSE stage uses the full ``bits`` for + :meth:`quantize_mse`. For the inner-product (``prod``) scheme the MSE stage + uses ``bits - 1`` and the remaining 1 bit is spent on the QJL residual; a + ``bits == 1`` prod quantizer therefore uses a 0-bit MSE stage (no MSE term, + pure QJL). + """ + + def __init__(self, dim: int, bits: int, seed: int = 0) -> None: + if dim < 1: + raise ValueError(f"dim must be >= 1, got {dim}") + if bits not in SUPPORTED_BITS: + raise ValueError(f"bits must be one of {SUPPORTED_BITS}, got {bits}") + self.dim = dim + self.bits = bits + self.seed = seed + + self._rotation = _random_rotation(dim, seed) + self._qjl = _random_projection(dim, seed + _QJL_SEED_OFFSET) + # MSE-stage bit-width for the two-stage prod scheme. + self._mse_bits = bits - 1 + # Scaled codebooks (centroids for N(0, 1/d)). + self._scale = 1.0 / np.sqrt(dim) + + # -- codebook access ---------------------------------------------------- + + def _codebook(self, mse_bits: int) -> npt.NDArray[np.float32]: + """Scaled centroids for the given MSE bit-width (>=1).""" + scaled: npt.NDArray[np.float32] = (_NORMAL_CODEBOOKS[mse_bits] * self._scale).astype( + np.float32 + ) + return scaled + + # -- MSE quantizer ------------------------------------------------------ + + def quantize_mse(self, vec: npt.NDArray[np.floating]) -> tuple[npt.NDArray[np.int64], float]: + """Quantize ``vec`` with the MSE quantizer at the full target ``bits``. + + Returns ``(indices, norm)`` where ``indices`` are the per-coordinate + codebook indices of the rotated, unit-normalized vector and ``norm`` is + the original L2 norm (used to rescale on dequantization). + """ + return self._quantize_mse_core(vec, self.bits) + + def dequantize_mse( + self, indices: npt.NDArray[np.int64], norm: float + ) -> npt.NDArray[np.float32]: + """Reconstruct a vector from MSE indices produced at the full ``bits``.""" + return self._dequantize_mse_core(indices, norm, self.bits) + + def _quantize_mse_core( + self, vec: npt.NDArray[np.floating], mse_bits: int + ) -> tuple[npt.NDArray[np.int64], float]: + v = np.asarray(vec, dtype=np.float32) + norm = float(np.linalg.norm(v)) + if norm == 0.0 or mse_bits < 1: + return np.zeros(self.dim, dtype=np.int64), norm + u = v / norm + y = self._rotation @ u # rotated unit vector + codebook = self._codebook(mse_bits) + # Nearest centroid per coordinate. searchsorted on midpoints is O(d log L). + boundaries = (codebook[:-1] + codebook[1:]) / 2.0 + indices = np.searchsorted(boundaries, y).astype(np.int64) + return indices, norm + + def _dequantize_mse_core( + self, indices: npt.NDArray[np.int64], norm: float, mse_bits: int + ) -> npt.NDArray[np.float32]: + if norm == 0.0 or mse_bits < 1: + return np.zeros(self.dim, dtype=np.float32) + codebook = self._codebook(mse_bits) + y_hat = codebook[indices] + u_hat = self._rotation.T @ y_hat # rotate back + return (u_hat * norm).astype(np.float32) + + # -- inner-product (two-stage) quantizer -------------------------------- + + def quantize_prod( + self, vec: npt.NDArray[np.floating] + ) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.float32], float, float]: + """Quantize ``vec`` with the unbiased two-stage inner-product scheme. + + Returns ``(mse_indices, qjl_signs, residual_norm, norm)``: + + * ``mse_indices`` — MSE-stage indices at ``bits - 1`` (all zeros when + ``bits == 1``). + * ``qjl_signs`` — ``+1/-1`` vector of length ``dim`` (sign of ``S @ r``). + * ``residual_norm`` — L2 norm of the unit-space residual ``r``. + * ``norm`` — original L2 norm of ``vec``. + """ + v = np.asarray(vec, dtype=np.float32) + norm = float(np.linalg.norm(v)) + if norm == 0.0: + return ( + np.zeros(self.dim, dtype=np.int64), + np.ones(self.dim, dtype=np.float32), + 0.0, + 0.0, + ) + u = v / norm + mse_indices, _ = self._quantize_mse_core(u, self._mse_bits) + u_mse = self._dequantize_mse_core(mse_indices, 1.0, self._mse_bits) + residual = u - u_mse + residual_norm = float(np.linalg.norm(residual)) + qjl_signs = np.sign(self._qjl @ residual).astype(np.float32) + # np.sign(0) == 0; map any zeros to +1 so the sign vector is strictly +-1. + qjl_signs[qjl_signs == 0] = 1.0 + return mse_indices, qjl_signs, residual_norm, norm + + def inner_product_prod( + self, + query: npt.NDArray[np.floating], + mse_indices: npt.NDArray[np.int64], + qjl_signs: npt.NDArray[np.float32], + residual_norm: float, + norm: float, + ) -> float: + """Unbiased estimate of ```` from a prod row. + + ``query`` is a full-precision vector (not quantized). Implements the + estimator ``norm * ( + gamma * sqrt(pi/2)/d * )`` + (paper Theorem 2 / Algorithm 2). + """ + q = np.asarray(query, dtype=np.float32) + if norm == 0.0: + return 0.0 + u_mse = self._dequantize_mse_core(mse_indices, 1.0, self._mse_bits) + mse_term = float(q @ u_mse) + sq = self._qjl @ q # S @ q + qjl_term = float(np.sqrt(np.pi / 2.0) / self.dim * residual_norm * (sq @ qjl_signs)) + return norm * (mse_term + qjl_term) + + def dequantize_prod( + self, + mse_indices: npt.NDArray[np.int64], + qjl_signs: npt.NDArray[np.float32], + residual_norm: float, + norm: float, + ) -> npt.NDArray[np.float32]: + """Reconstruct an (unbiased-in-expectation) vector from a prod row. + + Used for diagnostics; the search path uses :meth:`inner_product_prod` + directly, which avoids materializing the reconstruction. + """ + if norm == 0.0: + return np.zeros(self.dim, dtype=np.float32) + u_mse = self._dequantize_mse_core(mse_indices, 1.0, self._mse_bits) + qjl_recon = np.sqrt(np.pi / 2.0) / self.dim * residual_norm * (self._qjl.T @ qjl_signs) + recon: npt.NDArray[np.float32] = ((u_mse + qjl_recon) * norm).astype(np.float32) + return recon + + +# --------------------------------------------------------------------------- +# Seeded matrix generation +# --------------------------------------------------------------------------- + + +def _random_rotation(dim: int, seed: int) -> npt.NDArray[np.float32]: + """Uniformly random rotation via QR of a seeded Gaussian matrix. + + Sign-corrects the Q factor so the decomposition is a deterministic function + of the seed (NumPy's QR sign convention is otherwise implementation-defined). + """ + rng = np.random.default_rng(seed) + a = rng.standard_normal((dim, dim)) + q, r = np.linalg.qr(a) + # Make Q unique: force positive diagonal of R. + d = np.sign(np.diag(r)) + d[d == 0] = 1.0 + q = q * d + return q.astype(np.float32) + + +def _random_projection(dim: int, seed: int) -> npt.NDArray[np.float32]: + """Seeded ``dim x dim`` Gaussian matrix for the QJL transform.""" + rng = np.random.default_rng(seed) + return rng.standard_normal((dim, dim)).astype(np.float32) diff --git a/tests/test_turbo_quant.py b/tests/test_turbo_quant.py new file mode 100644 index 0000000..ab5ca95 --- /dev/null +++ b/tests/test_turbo_quant.py @@ -0,0 +1,200 @@ +"""Unit tests for the TurboQuant core algorithm. + +Verifies the paper's distortion bounds (Theorem 1) and the unbiasedness of the +two-stage inner-product estimator (Theorem 2), plus packing and determinism. +""" + +from __future__ import annotations + +import math + +import numpy as np +import pytest + +from cocoindex_code.turbo_quant import ( + SUPPORTED_BITS, + TurboQuant, + pack_indices, + pack_signs, + unpack_indices, + unpack_signs, +) + +# Paper Theorem 1 upper bound: D_mse <= sqrt(3*pi/2) * 4^-b +_MSE_UPPER = math.sqrt(3 * math.pi / 2) +# Finer per-b values from Theorem 1. +_MSE_FINE = {1: 0.36, 2: 0.117, 3: 0.03, 4: 0.009} + +_DIM = 384 +_N = 4000 +_SEED = 7 + + +def _random_unit_vectors(n: int, d: int, seed: int) -> np.ndarray: + rng = np.random.default_rng(seed) + v = rng.standard_normal((n, d)).astype(np.float32) + v /= np.linalg.norm(v, axis=1, keepdims=True) + return v + + +# --------------------------------------------------------------------------- +# MSE distortion bounds +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("bits", SUPPORTED_BITS) +def test_mse_distortion_within_paper_bound(bits: int) -> None: + tq = TurboQuant(dim=_DIM, bits=bits, seed=_SEED) + vecs = _random_unit_vectors(_N, _DIM, seed=11) + + sq_errors = [] + for v in vecs: + idx, norm = tq.quantize_mse(v) + v_hat = tq.dequantize_mse(idx, norm) + sq_errors.append(float(np.sum((v - v_hat) ** 2))) + mse = float(np.mean(sq_errors)) + + # Correctness check: measured distortion matches the paper's reported + # empirical fine values (Theorem 1) within tolerance. These are the real + # targets — the sqrt(3*pi/2)*4^-b asymptotic (Panter-Dite high-resolution) + # formula is overshot by the paper's own b=4 optimum (0.009 > 0.0085), so it + # is only a sanity ceiling for the lower bit-widths. + fine = _MSE_FINE[bits] + assert mse <= fine * 1.15, f"b={bits}: MSE {mse:.4f} exceeds paper value {fine:.4f}" + if bits <= 3: + upper = _MSE_UPPER * 4.0 ** (-bits) + assert mse <= upper * 1.10, f"b={bits}: MSE {mse:.4f} exceeds asymptotic {upper:.4f}" + + +# --------------------------------------------------------------------------- +# Inner-product unbiasedness +# --------------------------------------------------------------------------- + + +def test_prod_estimator_is_unbiased() -> None: + bits = 4 + tq = TurboQuant(dim=_DIM, bits=bits, seed=_SEED) + xs = _random_unit_vectors(_N, _DIM, seed=21) + ys = _random_unit_vectors(_N, _DIM, seed=22) + + errors = [] + for x, y in zip(xs, ys): + mse_idx, qjl, rnorm, norm = tq.quantize_prod(x) + est = tq.inner_product_prod(y, mse_idx, qjl, rnorm, norm) + true_ip = float(y @ x) + errors.append(est - true_ip) + + errors = np.array(errors) + mean_err = float(errors.mean()) + stderr = float(errors.std(ddof=1) / math.sqrt(len(errors))) + # Mean signed error within ~3 standard errors of zero -> unbiased. + assert abs(mean_err) <= 3.0 * stderr + 1e-3, f"bias {mean_err:.5f} (SE {stderr:.5f})" + + +def test_mse_quantizer_is_biased_for_inner_product_at_b1() -> None: + """MSE-only estimate (via dequantize) shows the ~2/pi shrinkage at b=1. + + This is the motivation for the two-stage prod scheme. + """ + tq = TurboQuant(dim=_DIM, bits=1, seed=_SEED) + xs = _random_unit_vectors(_N, _DIM, seed=31) + ys = _random_unit_vectors(_N, _DIM, seed=32) + + ratios = [] + for x, y in zip(xs, ys): + idx, norm = tq.quantize_mse(x) + x_hat = tq.dequantize_mse(idx, norm) + true_ip = float(y @ x) + if abs(true_ip) > 1e-3: + ratios.append(float(y @ x_hat) / true_ip) + mean_ratio = float(np.mean(ratios)) + # Biased: estimate is a fraction of the true IP, not ~1.0. + assert mean_ratio < 0.9 + + +# --------------------------------------------------------------------------- +# Determinism + reconstruction sanity +# --------------------------------------------------------------------------- + + +def test_same_seed_is_deterministic() -> None: + a = TurboQuant(dim=64, bits=3, seed=99) + b = TurboQuant(dim=64, bits=3, seed=99) + v = _random_unit_vectors(1, 64, seed=5)[0] + idx_a, n_a = a.quantize_mse(v) + idx_b, n_b = b.quantize_mse(v) + assert np.array_equal(idx_a, idx_b) + assert n_a == n_b + + +def test_cosine_improves_with_bits() -> None: + vecs = _random_unit_vectors(500, 128, seed=8) + prev = -1.0 + for bits in SUPPORTED_BITS: + tq = TurboQuant(dim=128, bits=bits, seed=3) + cos = [] + for v in vecs: + idx, norm = tq.quantize_mse(v) + v_hat = tq.dequantize_mse(idx, norm) + denom = np.linalg.norm(v) * np.linalg.norm(v_hat) + if denom > 0: + cos.append(float(v @ v_hat) / denom) + mean_cos = float(np.mean(cos)) + assert mean_cos >= prev - 0.02, f"cosine regressed at b={bits}" + prev = mean_cos + + +def test_zero_vector_no_nan() -> None: + tq = TurboQuant(dim=32, bits=2, seed=1) + z = np.zeros(32, dtype=np.float32) + idx, norm = tq.quantize_mse(z) + out = tq.dequantize_mse(idx, norm) + assert norm == 0.0 + assert not np.any(np.isnan(out)) + + mse_idx, qjl, rnorm, n = tq.quantize_prod(z) + est = tq.inner_product_prod(np.ones(32, dtype=np.float32), mse_idx, qjl, rnorm, n) + assert est == 0.0 + + +def test_small_dims_do_not_crash() -> None: + for d in (1, 2, 3): + tq = TurboQuant(dim=d, bits=2, seed=2) + v = _random_unit_vectors(1, d, seed=4)[0] + idx, norm = tq.quantize_mse(v) + out = tq.dequantize_mse(idx, norm) + assert out.shape == (d,) + assert not np.any(np.isnan(out)) + + +# --------------------------------------------------------------------------- +# Bit packing +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("bits", SUPPORTED_BITS) +def test_index_packing_roundtrip(bits: int) -> None: + rng = np.random.default_rng(bits) + count = 137 # not a multiple of 8 -> exercises padding + idx = rng.integers(0, 1 << bits, size=count) + packed = pack_indices(idx, bits) + out = unpack_indices(packed, count, bits) + assert np.array_equal(idx, out) + # Packed size is ceil(count*bits/8). + assert len(packed) == math.ceil(count * bits / 8) + + +@pytest.mark.parametrize("count", [1, 7, 8, 9, 64, 65]) +def test_sign_packing_roundtrip(count: int) -> None: + rng = np.random.default_rng(count) + signs = np.where(rng.random(count) > 0.5, 1.0, -1.0).astype(np.float32) + packed = pack_signs(signs) + out = unpack_signs(packed, count) + assert np.array_equal(signs, out) + + +def test_empty_packing() -> None: + assert pack_indices(np.empty(0, dtype=np.int64), 4) == b"" + assert unpack_indices(b"", 0, 4).size == 0 + assert pack_signs(np.empty(0)) == b"" + assert unpack_signs(b"", 0).size == 0 From 85c6cf18f8e0d7a4457ed1c33e0f37c65c4aa78f Mon Sep 17 00:00:00 2001 From: dudegladiator Date: Sun, 7 Jun 2026 14:51:26 +0530 Subject: [PATCH 2/4] feat: add TurboQuant SQLite store with vectorized search SQLite-backed compressed vector store: bit-packed rows, seed-reproducible matrices (only the seed is persisted), and a vectorized NumPy inner-product search with language/path/limit/offset filter parity. Batched bitpack decode keeps load cheap at scale. --- src/cocoindex_code/tq_store.py | 404 +++++++++++++++++++++++++++++++++ tests/test_tq_store.py | 200 ++++++++++++++++ 2 files changed, 604 insertions(+) create mode 100644 src/cocoindex_code/tq_store.py create mode 100644 tests/test_tq_store.py diff --git a/src/cocoindex_code/tq_store.py b/src/cocoindex_code/tq_store.py new file mode 100644 index 0000000..07be856 --- /dev/null +++ b/src/cocoindex_code/tq_store.py @@ -0,0 +1,404 @@ +"""TurboQuant compressed vector store backed by plain SQLite tables. + +Unlike the sqlite-vec path (which uses a ``vec0`` virtual table and C KNN), the +TurboQuant backend stores bit-packed quantized rows in ordinary tables and runs +search as a vectorized inner-product scan in NumPy. + +Two tables: + +* ``code_chunks_tq`` — one row per chunk: id, file_path, language, content, + start_line, end_line, and the quantized payload (packed MSE indices, packed + QJL signs, residual norm, original norm). +* ``tq_metadata`` — a single row describing the index: bit-width, dimension, and + the seed used to derive the rotation / QJL matrices. The matrices themselves + are regenerated from the seed on load, so they never need to be serialized. + +Search honors the same filters as ``query.py``'s sqlite-vec path: ``languages`` +(exact match), ``paths`` (GLOB), ``limit``, and ``offset``. +""" + +from __future__ import annotations + +import fnmatch +import sqlite3 +from dataclasses import dataclass + +import numpy as np +import numpy.typing as npt + +from .schema import QueryResult, TqChunkRow +from .turbo_quant import ( + TurboQuant, + pack_indices, + pack_signs, +) + +TQ_TABLE = "code_chunks_tq" +TQ_METADATA_TABLE = "tq_metadata" + + +# --------------------------------------------------------------------------- +# Schema management +# --------------------------------------------------------------------------- + + +def create_chunk_table(conn: sqlite3.Connection) -> None: + """Create the ``code_chunks_tq`` table if absent. + + Used by standalone callers and tests. In the live indexer the cocoindex + ``mount_table_target`` owns this table's creation, so the indexer only calls + :func:`create_metadata_table` and must NOT call this. + """ + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {TQ_TABLE} ( + id INTEGER PRIMARY KEY, + file_path TEXT NOT NULL, + language TEXT NOT NULL, + content TEXT NOT NULL, + start_line INTEGER NOT NULL, + end_line INTEGER NOT NULL, + idx_packed BLOB NOT NULL, + qjl_packed BLOB NOT NULL, + residual_norm REAL NOT NULL, + norm REAL NOT NULL + ) + """ + ) + + +def create_metadata_table(conn: sqlite3.Connection) -> None: + """Create the ``tq_metadata`` table if absent.""" + conn.execute( + f""" + CREATE TABLE IF NOT EXISTS {TQ_METADATA_TABLE} ( + id INTEGER PRIMARY KEY CHECK (id = 0), + backend TEXT NOT NULL, + bits INTEGER NOT NULL, + dim INTEGER NOT NULL, + seed INTEGER NOT NULL + ) + """ + ) + + +def create_tables(conn: sqlite3.Connection) -> None: + """Create both TurboQuant tables (standalone / test convenience).""" + create_chunk_table(conn) + create_metadata_table(conn) + + +def write_metadata(conn: sqlite3.Connection, *, bits: int, dim: int, seed: int) -> None: + """Write (or replace) the single metadata row.""" + conn.execute( + f"INSERT OR REPLACE INTO {TQ_METADATA_TABLE} (id, backend, bits, dim, seed) " + f"VALUES (0, ?, ?, ?, ?)", + ("turbo-quant", bits, dim, seed), + ) + + +@dataclass +class TqMetadata: + backend: str + bits: int + dim: int + seed: int + + +def read_metadata(conn: sqlite3.Connection) -> TqMetadata | None: + """Read the metadata row, or ``None`` if the table/row is absent.""" + try: + row = conn.execute( + f"SELECT backend, bits, dim, seed FROM {TQ_METADATA_TABLE} WHERE id = 0" + ).fetchone() + except sqlite3.OperationalError: + return None + if row is None: + return None + return TqMetadata(backend=row[0], bits=int(row[1]), dim=int(row[2]), seed=int(row[3])) + + +def insert_rows(conn: sqlite3.Connection, rows: list[TqChunkRow]) -> None: + """Bulk-insert quantized chunk rows.""" + conn.executemany( + f""" + INSERT OR REPLACE INTO {TQ_TABLE} + (id, file_path, language, content, start_line, end_line, + idx_packed, qjl_packed, residual_norm, norm) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + [ + ( + r.id, + r.file_path, + r.language, + r.content, + r.start_line, + r.end_line, + r.idx_packed, + r.qjl_packed, + r.residual_norm, + r.norm, + ) + for r in rows + ], + ) + + +def quantize_row( + tq: TurboQuant, + *, + chunk_id: int, + file_path: str, + language: str, + content: str, + start_line: int, + end_line: int, + embedding: npt.NDArray[np.floating], +) -> TqChunkRow: + """Quantize one embedding into a storable :class:`TqChunkRow`.""" + mse_idx, qjl_signs, residual_norm, norm = tq.quantize_prod(embedding) + # MSE stage uses bits-1; a 1-bit prod index has a 0-bit MSE stage (no bytes). + mse_bits = tq.bits - 1 + idx_packed = pack_indices(mse_idx, mse_bits) if mse_bits >= 1 else b"" + qjl_packed = pack_signs(qjl_signs) + return TqChunkRow( + id=chunk_id, + file_path=file_path, + language=language, + content=content, + start_line=start_line, + end_line=end_line, + idx_packed=idx_packed, + qjl_packed=qjl_packed, + residual_norm=residual_norm, + norm=norm, + ) + + +def _bulk_unpack_indices(blobs: list[bytes], n: int, dim: int, bits: int) -> npt.NDArray[np.int8]: + """Decode ``n`` equal-length packed-index blobs into an ``(n, dim)`` int8 matrix. + + Vectorized equivalent of calling :func:`unpack_indices` per row: stacks all + blobs into one uint8 matrix and unpacks the whole batch in a single pass. + """ + if n == 0: + return np.empty((0, dim), dtype=np.int8) + raw = np.frombuffer(b"".join(blobs), dtype=np.uint8).reshape(n, -1) + bits_mat = np.unpackbits(raw, axis=1)[:, : dim * bits].reshape(n, dim, bits) + weights = (1 << np.arange(bits - 1, -1, -1)).astype(np.int8) + return (bits_mat.astype(np.int8) @ weights).astype(np.int8) + + +def _bulk_unpack_signs(blobs: list[bytes], n: int, dim: int) -> npt.NDArray[np.int8]: + """Decode ``n`` packed sign blobs into an ``(n, dim)`` int8 ``+-1`` matrix.""" + if n == 0: + return np.empty((0, dim), dtype=np.int8) + raw = np.frombuffer(b"".join(blobs), dtype=np.uint8).reshape(n, -1) + bit = np.unpackbits(raw, axis=1)[:, :dim] + return np.where(bit > 0, np.int8(1), np.int8(-1)).astype(np.int8) + + +# --------------------------------------------------------------------------- +# In-memory store + search +# --------------------------------------------------------------------------- + + +class TqStore: + """Loaded, searchable TurboQuant index held in NumPy arrays.""" + + def __init__(self, tq: TurboQuant, metadata: TqMetadata) -> None: + self.tq = tq + self.metadata = metadata + self._ids: list[int] = [] + self._file_paths: list[str] = [] + self._languages: list[str] = [] + self._contents: list[str] = [] + self._start_lines: list[int] = [] + self._end_lines: list[int] = [] + # Decoded quantized payload as dense arrays for vectorized scoring. + # Indices (0..15 for <=4 bits) and signs (+-1) fit in int8, keeping the + # in-memory footprint ~bits-proportional rather than blowing up to int64. + self._mse_idx: npt.NDArray[np.int8] = np.empty((0, tq.dim), dtype=np.int8) + self._qjl: npt.NDArray[np.int8] = np.empty((0, tq.dim), dtype=np.int8) + self._residual_norms: npt.NDArray[np.float32] = np.empty(0, dtype=np.float32) + self._norms: npt.NDArray[np.float32] = np.empty(0, dtype=np.float32) + + @classmethod + def load(cls, conn: sqlite3.Connection) -> TqStore: + """Load the full index into memory. Raises if metadata is missing.""" + metadata = read_metadata(conn) + if metadata is None: + raise RuntimeError("TurboQuant metadata not found; index not built with turbo-quant") + tq = TurboQuant(dim=metadata.dim, bits=metadata.bits, seed=metadata.seed) + store = cls(tq, metadata) + store._load_rows(conn) + return store + + def _load_rows(self, conn: sqlite3.Connection) -> None: + rows = conn.execute( + f"SELECT id, file_path, language, content, start_line, end_line, " + f"idx_packed, qjl_packed, residual_norm, norm FROM {TQ_TABLE} ORDER BY id" + ).fetchall() + n = len(rows) + dim = self.tq.dim + mse_bits = self.tq.bits - 1 + + # Metadata columns: cheap Python-side gather. + self._ids = [int(r[0]) for r in rows] + self._file_paths = [r[1] for r in rows] + self._languages = [r[2] for r in rows] + self._contents = [r[3] for r in rows] + self._start_lines = [int(r[4]) for r in rows] + self._end_lines = [int(r[5]) for r in rows] + self._residual_norms = np.array([r[8] for r in rows], dtype=np.float32) + self._norms = np.array([r[9] for r in rows], dtype=np.float32) + + # Quantized payload: decode all rows in one vectorized pass instead of a + # per-row unpack loop (the dominant cost at scale). Every row's blob has + # the same byte length, so the blobs stack into a single uint8 matrix and + # np.unpackbits decodes the whole batch at once. + self._mse_idx = ( + _bulk_unpack_indices([r[6] for r in rows], n, dim, mse_bits) + if mse_bits >= 1 + else np.zeros((n, dim), dtype=np.int8) + ) + self._qjl = _bulk_unpack_signs([r[7] for r in rows], n, dim) + + def __len__(self) -> int: + return len(self._ids) + + # -- filtering ---------------------------------------------------------- + + def _candidate_mask( + self, languages: list[str] | None, paths: list[str] | None + ) -> npt.NDArray[np.bool_]: + n = len(self._ids) + mask = np.ones(n, dtype=bool) + if languages: + lang_set = set(languages) + mask &= np.array([lg in lang_set for lg in self._languages], dtype=bool) + if paths: + path_mask = np.array( + [any(fnmatch.fnmatch(fp, pat) for pat in paths) for fp in self._file_paths], + dtype=bool, + ) + mask &= path_mask + return mask + + # -- search ------------------------------------------------------------- + + def search( + self, + query_embedding: npt.NDArray[np.floating], + limit: int = 10, + offset: int = 0, + languages: list[str] | None = None, + paths: list[str] | None = None, + ) -> list[QueryResult]: + """Top-(limit) inner-product search over the (filtered) candidate set. + + Returns results in descending estimated-inner-product order. The score is + the unbiased inner-product estimate (higher = more similar), consistent + with the sqlite-vec path returning a higher-is-better similarity. + """ + n = len(self._ids) + if n == 0: + return [] + mask = self._candidate_mask(languages, paths) + cand = np.nonzero(mask)[0] + if cand.size == 0: + return [] + + scores = self._score(np.asarray(query_embedding, dtype=np.float32), cand) + + # Top (limit+offset) by score, then slice the offset window. + want = min(limit + offset, cand.size) + # argpartition for the top `want`, then sort that slice descending. + part = np.argpartition(-scores, want - 1)[:want] + ordered = part[np.argsort(-scores[part])] + window = ordered[offset : offset + limit] + + results: list[QueryResult] = [] + for local in window: + global_i = int(cand[local]) + results.append( + QueryResult( + file_path=self._file_paths[global_i], + language=self._languages[global_i], + content=self._contents[global_i], + start_line=self._start_lines[global_i], + end_line=self._end_lines[global_i], + score=float(scores[local]), + ) + ) + return results + + def _score( + self, q: npt.NDArray[np.float32], cand: npt.NDArray[np.int64] + ) -> npt.NDArray[np.float32]: + """Vectorized unbiased inner-product estimate for candidate rows. + + Mirrors :meth:`TurboQuant.inner_product_prod` but batched across rows: + + score = norm * ( + sqrt(pi/2)/d * gamma * ) + + where ``u_mse`` is the dequantized MSE term (unit-space) and ``S q`` is + projected once for the whole batch. + """ + tq = self.tq + dim = tq.dim + mse_bits = tq.bits - 1 + + # MSE term: dequantize candidate MSE indices back to unit space, dot q. + if mse_bits >= 1: + codebook = (tq._codebook(mse_bits)).astype(np.float32) # scaled centroids + # y_hat[cand] : (m, d) rotated reconstructions; rotate back via Pi^T. + y_hat = codebook[self._mse_idx[cand]] # (m, d) + u_mse = y_hat @ tq._rotation # (m,d)@(d,d) == (Pi^T y_hat) rows + mse_term = u_mse @ q # (m,) + else: + mse_term = np.zeros(cand.size, dtype=np.float32) + + # QJL term: project q once, then dot with each row's sign vector. + sq = tq._qjl @ q # (d,) + qjl_dot = self._qjl[cand] @ sq # (m,) + coef = np.float32(np.sqrt(np.pi / 2.0) / dim) + qjl_term = coef * self._residual_norms[cand] * qjl_dot + + return self._norms[cand] * (mse_term + qjl_term) + + # -- size accounting (for the benchmark) -------------------------------- + + def loaded_nbytes(self) -> int: + """Approximate in-memory size of the decoded arrays.""" + return int( + self._mse_idx.nbytes + + self._qjl.nbytes + + self._residual_norms.nbytes + + self._norms.nbytes + ) + + +def index_table_name(conn: sqlite3.Connection) -> str | None: + """Return the chunk table backing this index, or ``None`` if not indexed. + + Lets backend-agnostic callers (``ccc doctor``, status) count chunks without + hard-coding ``code_chunks_vec``. Prefers the TurboQuant table when present. + """ + for name in (TQ_TABLE, "code_chunks_vec"): + row = conn.execute( + "SELECT name FROM sqlite_master WHERE type IN ('table','view') AND name = ?", + (name,), + ).fetchone() + if row is not None: + return name + return None + + +def store_size_bytes(conn: sqlite3.Connection) -> int: + """On-disk payload size: total bytes of the packed blobs in ``code_chunks_tq``.""" + row = conn.execute( + f"SELECT COALESCE(SUM(LENGTH(idx_packed) + LENGTH(qjl_packed)), 0) FROM {TQ_TABLE}" + ).fetchone() + return int(row[0]) diff --git a/tests/test_tq_store.py b/tests/test_tq_store.py new file mode 100644 index 0000000..6965f30 --- /dev/null +++ b/tests/test_tq_store.py @@ -0,0 +1,200 @@ +"""Tests for the TurboQuant compressed store: persist, load, search, filters.""" + +from __future__ import annotations + +import math +import sqlite3 + +import numpy as np +import pytest + +from cocoindex_code.tq_store import ( + TqStore, + create_tables, + insert_rows, + quantize_row, + store_size_bytes, + write_metadata, +) +from cocoindex_code.turbo_quant import TurboQuant + +_DIM = 128 +_BITS = 4 +_SEED = 13 + + +def _unit(rng: np.random.Generator, d: int) -> np.ndarray: + v = rng.standard_normal(d).astype(np.float32) + return v / np.linalg.norm(v) + + +def _build_index(conn, embeddings, *, languages=None, file_paths=None, seed=_SEED, bits=_BITS): + """Quantize and persist a list of embeddings; return the TurboQuant used.""" + tq = TurboQuant(dim=_DIM, bits=bits, seed=seed) + create_tables(conn) + write_metadata(conn, bits=bits, dim=_DIM, seed=seed) + rows = [] + for i, emb in enumerate(embeddings): + lang = languages[i] if languages else "python" + fp = file_paths[i] if file_paths else f"src/file_{i}.py" + rows.append( + quantize_row( + tq, + chunk_id=i, + file_path=fp, + language=lang, + content=f"chunk {i}", + start_line=i, + end_line=i + 1, + embedding=emb, + ) + ) + insert_rows(conn, rows) + return tq + + +@pytest.fixture() +def conn(): + c = sqlite3.connect(":memory:") + yield c + c.close() + + +def test_persist_load_search_finds_nearest(conn) -> None: + rng = np.random.default_rng(1) + embs = [_unit(rng, _DIM) for _ in range(50)] + _build_index(conn, embs) + + store = TqStore.load(conn) + assert len(store) == 50 + + # Query with one of the indexed vectors -> it should rank top-1. + target = 17 + results = store.search(embs[target], limit=1) + assert len(results) == 1 + assert results[0].content == f"chunk {target}" + + +def test_scores_descending(conn) -> None: + rng = np.random.default_rng(2) + embs = [_unit(rng, _DIM) for _ in range(30)] + _build_index(conn, embs) + store = TqStore.load(conn) + results = store.search(embs[0], limit=10) + scores = [r.score for r in results] + assert scores == sorted(scores, reverse=True) + + +def test_language_filter(conn) -> None: + rng = np.random.default_rng(3) + embs = [_unit(rng, _DIM) for _ in range(30)] + langs = ["python" if i % 2 == 0 else "go" for i in range(30)] + _build_index(conn, embs, languages=langs) + store = TqStore.load(conn) + results = store.search(embs[0], limit=30, languages=["python"]) + assert all(r.language == "python" for r in results) + assert len(results) == 15 + + +def test_multi_language_filter(conn) -> None: + rng = np.random.default_rng(4) + embs = [_unit(rng, _DIM) for _ in range(30)] + langs = ["python", "go", "rust"] * 10 + _build_index(conn, embs, languages=langs) + store = TqStore.load(conn) + results = store.search(embs[0], limit=30, languages=["python", "go"]) + assert {r.language for r in results} <= {"python", "go"} + assert len(results) == 20 + + +def test_path_filter(conn) -> None: + rng = np.random.default_rng(5) + embs = [_unit(rng, _DIM) for _ in range(20)] + fps = [f"src/{i}.py" if i < 10 else f"tests/{i}.py" for i in range(20)] + _build_index(conn, embs, file_paths=fps) + store = TqStore.load(conn) + results = store.search(embs[0], limit=20, paths=["src/*"]) + assert all(r.file_path.startswith("src/") for r in results) + assert len(results) == 10 + + +def test_combined_language_and_path_filter(conn) -> None: + rng = np.random.default_rng(6) + embs = [_unit(rng, _DIM) for _ in range(20)] + langs = ["python" if i % 2 == 0 else "go" for i in range(20)] + fps = [f"src/{i}.py" if i < 10 else f"lib/{i}.py" for i in range(20)] + _build_index(conn, embs, languages=langs, file_paths=fps) + store = TqStore.load(conn) + results = store.search(embs[0], limit=20, languages=["python"], paths=["src/*"]) + for r in results: + assert r.language == "python" + assert r.file_path.startswith("src/") + + +def test_offset_and_limit(conn) -> None: + rng = np.random.default_rng(7) + embs = [_unit(rng, _DIM) for _ in range(40)] + _build_index(conn, embs) + store = TqStore.load(conn) + full = store.search(embs[0], limit=10, offset=0) + paged = store.search(embs[0], limit=5, offset=5) + # paged should equal items 6..10 of the full ranking. + assert [r.content for r in full[5:10]] == [r.content for r in paged] + + +def test_empty_candidate_set_returns_empty(conn) -> None: + rng = np.random.default_rng(8) + embs = [_unit(rng, _DIM) for _ in range(10)] + _build_index(conn, embs, languages=["python"] * 10) + store = TqStore.load(conn) + assert store.search(embs[0], limit=10, languages=["haskell"]) == [] + + +def test_empty_store_returns_empty(conn) -> None: + create_tables(conn) + write_metadata(conn, bits=_BITS, dim=_DIM, seed=_SEED) + store = TqStore.load(conn) + assert len(store) == 0 + assert store.search(np.ones(_DIM, dtype=np.float32), limit=5) == [] + + +def test_reload_matches_in_memory_search(conn) -> None: + """Search after reload (matrices regenerated from seed) matches first load.""" + rng = np.random.default_rng(9) + embs = [_unit(rng, _DIM) for _ in range(25)] + _build_index(conn, embs) + store1 = TqStore.load(conn) + r1 = store1.search(embs[3], limit=5) + store2 = TqStore.load(conn) # fresh load, fresh TurboQuant from seed + r2 = store2.search(embs[3], limit=5) + assert [x.content for x in r1] == [x.content for x in r2] + assert [round(x.score, 5) for x in r1] == [round(x.score, 5) for x in r2] + + +def test_store_size_reflects_bits(conn) -> None: + rng = np.random.default_rng(10) + embs = [_unit(rng, _DIM) for _ in range(100)] + _build_index(conn, embs, bits=4) + size = store_size_bytes(conn) + # idx is (bits-1)=3 bits/coord, qjl is 1 bit/coord, over dim coords, 100 rows. + expected_per_row = math.ceil(_DIM * 3 / 8) + math.ceil(_DIM * 1 / 8) + assert size == expected_per_row * 100 + + +def test_recall_at_10_reasonable(conn) -> None: + """Sanity: 4-bit prod search recovers most exact top-10 neighbors.""" + rng = np.random.default_rng(11) + embs = np.array([_unit(rng, _DIM) for _ in range(300)]) + _build_index(conn, list(embs), bits=4) + store = TqStore.load(conn) + + queries = [_unit(rng, _DIM) for _ in range(30)] + recalls = [] + for q in queries: + exact = set(np.argsort(-(embs @ q))[:10].tolist()) + got_rows = store.search(q, limit=10) + # Map results back to indices via content "chunk {i}". + got = {int(r.content.split()[1]) for r in got_rows} + recalls.append(len(exact & got) / 10.0) + mean_recall = float(np.mean(recalls)) + assert mean_recall >= 0.6, f"recall@10 too low: {mean_recall:.2f}" From 6d53a245b4cdae0aaae63d0ac7936f50deb1746b Mon Sep 17 00:00:00 2001 From: dudegladiator Date: Sun, 7 Jun 2026 14:51:35 +0530 Subject: [PATCH 3/4] feat: wire TurboQuant as a selectable ccc backend Make TurboQuant selectable at `ccc init` via `--backend turbo-quant` (`--tq-bits`), alongside the default sqlite-vec path. Index-time quantization, query-time dispatch with a daemon-lifetime store cache, backend-agnostic index status, and settings validation. sqlite-vec remains the default and its path is unchanged. --- pyproject.toml | 5 +- src/cocoindex_code/cli.py | 90 ++++++++++++++++++- src/cocoindex_code/daemon.py | 14 ++- src/cocoindex_code/indexer.py | 97 ++++++++++++++++----- src/cocoindex_code/project.py | 27 +++++- src/cocoindex_code/query.py | 60 ++++++++++++- src/cocoindex_code/schema.py | 23 +++++ src/cocoindex_code/settings.py | 44 +++++++++- src/cocoindex_code/shared.py | 4 + tests/benchmark_turbo_quant.py | 154 +++++++++++++++++++++++++++++++++ tests/test_cli_helpers.py | 47 ++++++++++ tests/test_e2e_backend.py | 134 ++++++++++++++++++++++++++++ tests/test_settings.py | 56 ++++++++++++ 13 files changed, 720 insertions(+), 35 deletions(-) create mode 100644 tests/benchmark_turbo_quant.py create mode 100644 tests/test_e2e_backend.py diff --git a/pyproject.toml b/pyproject.toml index bd702a8..11308a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,10 +107,11 @@ files = ["src"] [tool.pytest.ini_options] testpaths = ["tests"] -python_files = ["test_*.py"] +python_files = ["test_*.py", "benchmark_*.py"] python_functions = ["test_*"] -addopts = "-v --tb=short -m 'not docker_e2e'" +addopts = "-v --tb=short -m 'not docker_e2e and not benchmark'" asyncio_mode = "auto" markers = [ "docker_e2e: requires Docker; builds the image and runs containerized E2E tests. Run with: pytest -m docker_e2e", + "benchmark: TurboQuant vs sqlite-vec benchmark; prints a metrics table. Run with: pytest -m benchmark -s", ] diff --git a/src/cocoindex_code/cli.py b/src/cocoindex_code/cli.py index bd79fd4..9ab15c5 100644 --- a/src/cocoindex_code/cli.py +++ b/src/cocoindex_code/cli.py @@ -21,6 +21,9 @@ from .settings import ( DEFAULT_ST_MODEL, + DEFAULT_TQ_BITS, + SUPPORTED_TQ_BITS, + Backend, EmbeddingSettings, cocoindex_db_path, default_project_settings, @@ -34,6 +37,8 @@ save_project_settings, target_sqlite_db_path, user_settings_path, + validate_backend, + validate_tq_bits, ) app = _typer.Typer( @@ -383,6 +388,55 @@ def _resolve_embedding_choice( return EmbeddingSettings(provider=provider, model=model.strip()) +def _resolve_backend( + backend_flag: str | None, tq_bits_flag: int | None +) -> tuple[Backend, int]: + """Resolve (backend, tq_bits) from flags, an interactive prompt, or defaults. + + Explicit ``--backend`` wins. Otherwise prompt when stdin is a TTY; when not + interactive, fall back to the default backend (sqlite-vec). + """ + bits = validate_tq_bits(tq_bits_flag) if tq_bits_flag is not None else DEFAULT_TQ_BITS + + if backend_flag is not None: + return validate_backend(backend_flag), bits + + if not sys.stdin.isatty(): + return "sqlite-vec", bits + + import questionary + + backend = questionary.select( + "Vector backend", + choices=[ + questionary.Choice( + title="sqlite-vec (default, exact nearest-neighbor)", + value="sqlite-vec", + ), + questionary.Choice( + title="turbo-quant (compressed, ~4-8x smaller index)", + value="turbo-quant", + ), + ], + ).ask() + if backend is None: # cancelled + raise _typer.Exit(code=1) + + if backend == "turbo-quant" and tq_bits_flag is None: + answer = questionary.select( + "TurboQuant bit-width (higher = better recall, larger index)", + # Choice titles are strings; values are ints. The default must match a + # choice *value* (int), not its title (str). + choices=[questionary.Choice(title=str(b), value=b) for b in SUPPORTED_TQ_BITS], + default=DEFAULT_TQ_BITS, # type: ignore[arg-type] + ).ask() + if answer is None: + raise _typer.Exit(code=1) + bits = validate_tq_bits(answer) + + return validate_backend(backend), bits + + def _ok_fail_tag(ok: bool) -> str: """Return a colored `[OK]` or `[FAIL]` tag string.""" import click as _click @@ -484,9 +538,33 @@ def init( "--litellm-model", help="Use the given LiteLLM model and skip provider/model prompts.", ), + backend: str | None = _typer.Option( + None, + "--backend", + help="Vector backend: 'sqlite-vec' (default, exact) or 'turbo-quant' (compressed).", + ), + tq_bits: int | None = _typer.Option( + None, + "--tq-bits", + help=f"TurboQuant bit-width {list(SUPPORTED_TQ_BITS)} (only for --backend turbo-quant).", + ), force: bool = _typer.Option(False, "-f", "--force", help="Skip parent directory warning"), ) -> None: """Initialize a project for cocoindex-code.""" + # Validate backend flags early so bad input fails before any side effects. + if backend is not None: + try: + validate_backend(backend) + except ValueError as e: + _typer.echo(f"Error: {e}", err=True) + raise _typer.Exit(code=1) from e + if tq_bits is not None: + try: + validate_tq_bits(tq_bits) + except ValueError as e: + _typer.echo(f"Error: {e}", err=True) + raise _typer.Exit(code=1) from e + cwd = Path.cwd().resolve() settings_file = project_settings_path(cwd) @@ -520,9 +598,19 @@ def init( ) raise _typer.Exit(code=1) + # Resolve the vector backend: explicit flag wins; otherwise prompt when + # interactive; otherwise fall back to the default (sqlite-vec). + resolved_backend, resolved_bits = _resolve_backend(backend, tq_bits) + # Create project settings - save_project_settings(cwd, default_project_settings()) + project_settings = default_project_settings() + project_settings.backend = resolved_backend + project_settings.tq_bits = resolved_bits + save_project_settings(cwd, project_settings) _typer.echo(f"Created project settings: {format_path_for_display(settings_file)}") + _typer.echo(f"Vector backend: {resolved_backend}") + if resolved_backend == "turbo-quant": + _typer.echo(f"TurboQuant bit-width: {resolved_bits}") # Add to .gitignore add_to_gitignore(cwd) diff --git a/src/cocoindex_code/daemon.py b/src/cocoindex_code/daemon.py index 35b982a..f581ea2 100644 --- a/src/cocoindex_code/daemon.py +++ b/src/cocoindex_code/daemon.py @@ -437,14 +437,22 @@ async def _check_index_status(project_root_str: str) -> DoctorCheckResult: return DoctorCheckResult(name="Index Status", ok=True, details=details, errors=[]) try: + from .tq_store import index_table_name + conn = coco_sqlite.connect(str(db_path), load_vec=True) try: with conn.readonly() as db: - total_chunks = db.execute("SELECT COUNT(*) FROM code_chunks_vec").fetchone()[0] - file_rows = db.execute("SELECT DISTINCT file_path FROM code_chunks_vec").fetchall() + table = index_table_name(db) + if table is None: + details.append("Index not created yet.") + return DoctorCheckResult( + name="Index Status", ok=True, details=details, errors=[] + ) + total_chunks = db.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] + file_rows = db.execute(f"SELECT DISTINCT file_path FROM {table}").fetchall() total_files = len(file_rows) lang_rows = db.execute( - "SELECT language, COUNT(*) FROM code_chunks_vec GROUP BY language" + f"SELECT language, COUNT(*) FROM {table} GROUP BY language" ).fetchall() languages = {row[0]: row[1] for row in lang_rows} finally: diff --git a/src/cocoindex_code/indexer.py b/src/cocoindex_code/indexer.py index e028103..e242dd3 100644 --- a/src/cocoindex_code/indexer.py +++ b/src/cocoindex_code/indexer.py @@ -4,6 +4,7 @@ from collections.abc import Iterable from pathlib import Path, PurePath +from typing import Any import cocoindex as coco from cocoindex.connectors import localfs, sqlite @@ -15,14 +16,17 @@ from pathspec import GitIgnoreSpec from .chunking import CHUNKER_REGISTRY +from .schema import TqChunkRow from .settings import load_gitignore_spec, load_project_settings from .shared import ( CODEBASE_DIR, EMBEDDER, INDEXING_EMBED_PARAMS, SQLITE_DB, + TURBO_QUANT, CodeChunk, ) +from .tq_store import TQ_TABLE, quantize_row # Chunking configuration CHUNK_SIZE = 1000 @@ -137,9 +141,14 @@ def is_file_included(self, path: PurePath) -> bool: @coco.fn(memo=True) async def process_file( file: localfs.File, - table: sqlite.TableTarget[CodeChunk], + table: sqlite.TableTarget[Any], ) -> None: - """Process a single file: chunk, embed, and store.""" + """Process a single file: chunk, embed, and store. + + The stored row type depends on the project backend: ``CodeChunk`` (raw + float32 in vec0) for sqlite-vec, or ``TqChunkRow`` (quantized) for + turbo-quant. ``table`` is the matching target built by ``indexer_main``. + """ embedder = coco.use_context(EMBEDDER) indexing_params = coco.use_context(INDEXING_EMBED_PARAMS) @@ -177,19 +186,37 @@ async def process_file( ) id_gen = IdGenerator() + backend = ps.backend + tq = coco.use_context(TURBO_QUANT) if backend == "turbo-quant" else None async def process(chunk: Chunk) -> None: - table.declare_row( - row=CodeChunk( - id=await id_gen.next_id(chunk.text), - file_path=file.file_path.path.as_posix(), - language=language, - content=chunk.text, - start_line=chunk.start.line, - end_line=chunk.end.line, - embedding=await embedder.embed(chunk.text, **indexing_params), + chunk_id = await id_gen.next_id(chunk.text) + embedding = await embedder.embed(chunk.text, **indexing_params) + if tq is not None: + table.declare_row( + row=quantize_row( + tq, + chunk_id=chunk_id, + file_path=file.file_path.path.as_posix(), + language=language, + content=chunk.text, + start_line=chunk.start.line, + end_line=chunk.end.line, + embedding=embedding, + ) + ) + else: + table.declare_row( + row=CodeChunk( + id=chunk_id, + file_path=file.file_path.path.as_posix(), + language=language, + content=chunk.text, + start_line=chunk.start.line, + end_line=chunk.end.line, + embedding=embedding, + ) ) - ) await coco.map(process, chunks) @@ -201,18 +228,40 @@ async def indexer_main() -> None: ps = load_project_settings(project_root) gitignore_spec = load_gitignore_spec(project_root) - table = await sqlite.mount_table_target( - db=SQLITE_DB, - table_name="code_chunks_vec", - table_schema=await sqlite.TableSchema.from_class( - CodeChunk, - primary_key=["id"], - ), - virtual_table_def=Vec0TableDef( - partition_key_columns=["language"], - auxiliary_columns=["file_path", "content", "start_line", "end_line"], - ), - ) + table: sqlite.TableTarget[Any] + if ps.backend == "turbo-quant": + tq = coco.use_context(TURBO_QUANT) + # Persist index metadata (bits/dim/seed) so the store can regenerate the + # rotation/QJL matrices at query time. + db = coco.use_context(SQLITE_DB) + from .tq_store import create_metadata_table, write_metadata + + # The chunk table itself is created by mount_table_target below; here we + # only own the side metadata table. + with db.transaction() as conn: + create_metadata_table(conn) + write_metadata(conn, bits=tq.bits, dim=tq.dim, seed=tq.seed) + table = await sqlite.mount_table_target( + db=SQLITE_DB, + table_name=TQ_TABLE, + table_schema=await sqlite.TableSchema.from_class( + TqChunkRow, + primary_key=["id"], + ), + ) + else: + table = await sqlite.mount_table_target( + db=SQLITE_DB, + table_name="code_chunks_vec", + table_schema=await sqlite.TableSchema.from_class( + CodeChunk, + primary_key=["id"], + ), + virtual_table_def=Vec0TableDef( + partition_key_columns=["language"], + auxiliary_columns=["file_path", "content", "start_line", "end_line"], + ), + ) base_matcher = PatternFilePathMatcher( included_patterns=ps.include_patterns, diff --git a/src/cocoindex_code/project.py b/src/cocoindex_code/project.py index f661c21..dfa2c03 100644 --- a/src/cocoindex_code/project.py +++ b/src/cocoindex_code/project.py @@ -27,6 +27,7 @@ cocoindex_db_path as _cocoindex_db_path, ) from .settings import ( + load_project_settings, resolve_db_dir, ) from .settings import ( @@ -38,6 +39,7 @@ INDEXING_EMBED_PARAMS, QUERY_EMBED_PARAMS, SQLITE_DB, + TURBO_QUANT, Embedder, ) @@ -211,16 +213,21 @@ async def search( def get_status(self) -> ProjectStatusResponse: """Get index stats by querying the SQLite database.""" + from .tq_store import index_table_name + db = self._env.get_context(SQLITE_DB) index_exists = True try: with db.readonly() as conn: - total_chunks = conn.execute("SELECT COUNT(*) FROM code_chunks_vec").fetchone()[0] + table = index_table_name(conn) + if table is None: + raise sqlite3.OperationalError("no index table") + total_chunks = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] total_files = conn.execute( - "SELECT COUNT(DISTINCT file_path) FROM code_chunks_vec" + f"SELECT COUNT(DISTINCT file_path) FROM {table}" ).fetchone()[0] lang_rows = conn.execute( - "SELECT language, COUNT(*) as cnt FROM code_chunks_vec" + f"SELECT language, COUNT(*) as cnt FROM {table}" " GROUP BY language ORDER BY cnt DESC" ).fetchall() except sqlite3.OperationalError: @@ -301,6 +308,20 @@ async def create( context.provide(QUERY_EMBED_PARAMS, dict(query_params)) context.provide(CHUNKER_REGISTRY, dict(chunker_registry) if chunker_registry else {}) + # TurboQuant backend: build the quantizer once (dimension probed from the + # embedder) and make it available to the indexer and query paths. The + # seed is fixed so the rotation/QJL matrices are reproducible across the + # daemon's lifetime and recorded in tq_metadata at index time. + backend = load_project_settings(project_root).backend + if backend == "turbo-quant": + from .turbo_quant import TurboQuant + + ps = load_project_settings(project_root) + probe = await embedder.embed("dimension probe", **dict(indexing_params)) + dim = len(probe) + tq = TurboQuant(dim=dim, bits=ps.tq_bits, seed=0) + context.provide(TURBO_QUANT, tq) + env = coco.Environment(settings, context_provider=context) app = coco.App( coco.AppConfig( diff --git a/src/cocoindex_code/query.py b/src/cocoindex_code/query.py index a2991ee..fdf1719 100644 --- a/src/cocoindex_code/query.py +++ b/src/cocoindex_code/query.py @@ -5,11 +5,14 @@ import heapq import sqlite3 from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from .schema import QueryResult from .shared import EMBEDDER, QUERY_EMBED_PARAMS, SQLITE_DB +if TYPE_CHECKING: + from .tq_store import TqStore + def _l2_to_score(distance: float) -> float: """Convert L2 distance to cosine similarity (exact for unit vectors).""" @@ -82,6 +85,52 @@ def _full_scan_query( ).fetchall() +# Process-lifetime cache of loaded TurboQuant stores, keyed by the DB connection +# identity. Loading a store decodes every quantized row into NumPy arrays, which +# is the dominant query cost at scale; the daemon is long-lived and the index +# only changes on re-index, so we cache the loaded store and reuse it across +# queries. The cache is invalidated when the chunk table's row count changes +# (re-index replaces rows), which is a cheap COUNT(*) check per query. +_STORE_CACHE: dict[int, tuple[int, TqStore]] = {} + + +def _maybe_query_turbo_quant( + db: Any, + query_embedding: Any, + limit: int, + offset: int, + languages: list[str] | None, + paths: list[str] | None, +) -> list[QueryResult] | None: + """Run a TurboQuant search if the DB is a turbo-quant index, else ``None``. + + The loaded store is cached for the life of the daemon process and reused + across queries; it is reloaded only when the index's row count changes. The + store regenerates its rotation/QJL matrices from the persisted seed, so no + matrices are read from disk. + """ + from .tq_store import TQ_TABLE, TqStore, read_metadata + + cache_key = id(db) + with db.readonly() as conn: + if read_metadata(conn) is None: + return None + row_count = conn.execute(f"SELECT COUNT(*) FROM {TQ_TABLE}").fetchone()[0] + cached = _STORE_CACHE.get(cache_key) + if cached is not None and cached[0] == row_count: + store = cached[1] + else: + store = TqStore.load(conn) + _STORE_CACHE[cache_key] = (row_count, store) + return store.search( + query_embedding, + limit=limit, + offset=offset, + languages=languages, + paths=paths, + ) + + async def query_codebase( query: str, target_sqlite_db_path: Path, @@ -111,6 +160,15 @@ async def query_codebase( # Generate query embedding. query_embedding = await embedder.embed(query, **query_params) + # TurboQuant backend: search the compressed store with the unbiased + # inner-product estimator. Detected by the presence of tq_metadata, so the + # query path does not need the project settings at hand. + tq_results = _maybe_query_turbo_quant( + db, query_embedding, limit, offset, languages, paths + ) + if tq_results is not None: + return tq_results + embedding_bytes = query_embedding.astype("float32").tobytes() with db.readonly() as conn: diff --git a/src/cocoindex_code/schema.py b/src/cocoindex_code/schema.py index bfb8a74..f55fb68 100644 --- a/src/cocoindex_code/schema.py +++ b/src/cocoindex_code/schema.py @@ -17,6 +17,29 @@ class CodeChunk: embedding: Any # NDArray - type hint relaxed for compatibility +@dataclass +class TqChunkRow: + """A code chunk stored in the TurboQuant compressed backend. + + Mirrors :class:`CodeChunk` minus the raw ``embedding`` (which is replaced by + the quantized representation). ``idx_packed`` holds the bit-packed MSE-stage + codebook indices, ``qjl_packed`` the bit-packed QJL sign vector, + ``residual_norm`` the L2 norm of the unit-space residual, and ``norm`` the + original embedding's L2 norm. + """ + + id: int + file_path: str + language: str + content: str + start_line: int + end_line: int + idx_packed: bytes + qjl_packed: bytes + residual_norm: float + norm: float + + @dataclass class QueryResult: """Result from a vector similarity query.""" diff --git a/src/cocoindex_code/settings.py b/src/cocoindex_code/settings.py index 73b026b..57b717e 100644 --- a/src/cocoindex_code/settings.py +++ b/src/cocoindex_code/settings.py @@ -5,7 +5,7 @@ import os from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, cast, get_args import yaml as _yaml @@ -122,12 +122,27 @@ class ChunkerMapping: module: str # "module.path:callable", e.g. "cocoindex_code.toml_chunker:toml_chunker" +# Vector-search backend identifiers. ``sqlite-vec`` is the default (raw float32 +# in a vec0 virtual table, C KNN). ``turbo-quant`` uses the TurboQuant compressed +# backend (see ``turbo_quant.py`` / ``tq_store.py``). +Backend = Literal["sqlite-vec", "turbo-quant"] +DEFAULT_BACKEND: Backend = "sqlite-vec" +SUPPORTED_BACKENDS: tuple[Backend, ...] = get_args(Backend) +# TurboQuant bit-widths we support end-to-end (mirrors turbo_quant.SUPPORTED_BITS). +SUPPORTED_TQ_BITS: tuple[int, ...] = (1, 2, 3, 4) +DEFAULT_TQ_BITS = 4 + + @dataclass class ProjectSettings: include_patterns: list[str] = field(default_factory=lambda: list(DEFAULT_INCLUDED_PATTERNS)) exclude_patterns: list[str] = field(default_factory=lambda: list(DEFAULT_EXCLUDED_PATTERNS)) language_overrides: list[LanguageOverride] = field(default_factory=list) chunkers: list[ChunkerMapping] = field(default_factory=list) + # Vector-search backend baked into the index at ``ccc init``. + backend: Backend = DEFAULT_BACKEND + # TurboQuant target bit-width (only meaningful when backend == "turbo-quant"). + tq_bits: int = DEFAULT_TQ_BITS # --------------------------------------------------------------------------- @@ -459,6 +474,23 @@ def _user_settings_from_dict(d: dict[str, Any]) -> UserSettings: return UserSettings(embedding=embedding, envs=envs) +def validate_backend(backend: Any) -> Backend: + """Validate a backend identifier, raising ``ValueError`` if unknown.""" + if backend not in SUPPORTED_BACKENDS: + raise ValueError( + f"unknown backend {backend!r}; expected one of {list(SUPPORTED_BACKENDS)}" + ) + # Narrowed to the Literal by the membership check above. + return cast(Backend, backend) + + +def validate_tq_bits(bits: Any) -> int: + """Validate a TurboQuant bit-width, raising ``ValueError`` if unsupported.""" + if not isinstance(bits, int) or isinstance(bits, bool) or bits not in SUPPORTED_TQ_BITS: + raise ValueError(f"tq_bits must be one of {list(SUPPORTED_TQ_BITS)}, got {bits!r}") + return bits + + def _project_settings_to_dict(settings: ProjectSettings) -> dict[str, Any]: d: dict[str, Any] = { "include_patterns": settings.include_patterns, @@ -470,6 +502,11 @@ def _project_settings_to_dict(settings: ProjectSettings) -> dict[str, Any]: ] if settings.chunkers: d["chunkers"] = [{"ext": cm.ext, "module": cm.module} for cm in settings.chunkers] + # Always persist the backend so an index records how it was built. tq_bits is + # only meaningful for turbo-quant, so omit it otherwise to keep files clean. + d["backend"] = settings.backend + if settings.backend == "turbo-quant": + d["tq_bits"] = settings.tq_bits return d @@ -478,11 +515,16 @@ def _project_settings_from_dict(d: dict[str, Any]) -> ProjectSettings: LanguageOverride(ext=lo["ext"], lang=lo["lang"]) for lo in d.get("language_overrides", []) ] chunkers = [ChunkerMapping(ext=cm["ext"], module=cm["module"]) for cm in d.get("chunkers", [])] + # Missing backend -> sqlite-vec (backward compatible with pre-backend files). + backend = validate_backend(d.get("backend", DEFAULT_BACKEND)) + tq_bits = validate_tq_bits(d.get("tq_bits", DEFAULT_TQ_BITS)) return ProjectSettings( include_patterns=d.get("include_patterns", list(DEFAULT_INCLUDED_PATTERNS)), exclude_patterns=d.get("exclude_patterns", list(DEFAULT_EXCLUDED_PATTERNS)), language_overrides=overrides, chunkers=chunkers, + backend=backend, + tq_bits=tq_bits, ) diff --git a/src/cocoindex_code/shared.py b/src/cocoindex_code/shared.py index f607755..6c45231 100644 --- a/src/cocoindex_code/shared.py +++ b/src/cocoindex_code/shared.py @@ -18,6 +18,8 @@ from cocoindex.ops.litellm import LiteLLMEmbedder from cocoindex.ops.sentence_transformers import SentenceTransformerEmbedder + from .turbo_quant import TurboQuant # noqa: F401 (used in ContextKey string annotation) + from .settings import EmbeddingSettings logger = logging.getLogger(__name__) @@ -34,6 +36,8 @@ CODEBASE_DIR = coco.ContextKey[pathlib.Path]("codebase") INDEXING_EMBED_PARAMS = coco.ContextKey[dict[str, Any]]("indexing_embed_params") QUERY_EMBED_PARAMS = coco.ContextKey[dict[str, Any]]("query_embed_params") +# TurboQuant quantizer for the active index (only set when backend=turbo-quant). +TURBO_QUANT = coco.ContextKey["TurboQuant"]("turbo_quant") def is_sentence_transformers_installed() -> bool: diff --git a/tests/benchmark_turbo_quant.py b/tests/benchmark_turbo_quant.py new file mode 100644 index 0000000..2c45c5f --- /dev/null +++ b/tests/benchmark_turbo_quant.py @@ -0,0 +1,154 @@ +"""Benchmark: TurboQuant compressed backend vs raw float32 (sqlite-vec-equivalent). + +Reports the headline trade-off numbers — index size, in-memory size, query +latency, and recall@{1,10} vs exact float32 ground truth — on real embeddings of +this repository's own source. + +Run with:: + + uv run pytest tests/benchmark_turbo_quant.py -m benchmark -s + +Excluded from the default test run (see the ``benchmark`` marker in +pyproject.toml). Carries soft assertions so genuine regressions still fail, but +the primary output is the printed table. +""" + +from __future__ import annotations + +import sqlite3 +import time +from pathlib import Path + +import numpy as np +import pytest + +from cocoindex_code.tq_store import ( + TqStore, + create_tables, + insert_rows, + quantize_row, + store_size_bytes, + write_metadata, +) +from cocoindex_code.turbo_quant import TurboQuant + +pytestmark = pytest.mark.benchmark + +_MODEL = "sentence-transformers/paraphrase-MiniLM-L3-v2" # d=384, matches conftest +_N_QUERIES = 50 +_SEED = 0 + + +def _load_corpus_texts(limit: int = 1500) -> list[str]: + """Chunk this repo's own Python source into snippets for embedding.""" + root = Path(__file__).resolve().parent.parent / "src" + texts: list[str] = [] + for path in sorted(root.rglob("*.py")): + lines = path.read_text(errors="ignore").splitlines() + for i in range(0, len(lines), 20): + block = "\n".join(lines[i : i + 20]).strip() + if len(block) > 40: + texts.append(block) + if len(texts) >= limit: + return texts + return texts + + +def _embed(texts: list[str]) -> np.ndarray: + from sentence_transformers import SentenceTransformer + + model = SentenceTransformer(_MODEL.split("/", 1)[1]) + vecs = model.encode(texts, normalize_embeddings=True, show_progress_bar=False) + return np.asarray(vecs, dtype=np.float32) + + +def _recall_at_k(exact_top: set[int], got: list[int], k: int) -> float: + return len(exact_top & set(got[:k])) / float(k) + + +def _build_tq(embs: np.ndarray, bits: int) -> tuple[sqlite3.Connection, TqStore]: + dim = embs.shape[1] + tq = TurboQuant(dim=dim, bits=bits, seed=_SEED) + conn = sqlite3.connect(":memory:") + create_tables(conn) + write_metadata(conn, bits=bits, dim=dim, seed=_SEED) + rows = [ + quantize_row( + tq, + chunk_id=i, + file_path=f"f{i}.py", + language="python", + content=f"chunk {i}", + start_line=i, + end_line=i + 1, + embedding=embs[i], + ) + for i in range(len(embs)) + ] + insert_rows(conn, rows) + return conn, TqStore.load(conn) + + +def test_benchmark_report() -> None: + texts = _load_corpus_texts() + embs = _embed(texts) + n, dim = embs.shape + rng = np.random.default_rng(123) + query_ids = rng.choice(n, size=min(_N_QUERIES, n), replace=False) + + # Exact float32 ground truth (brute force). + def exact_topk(q: np.ndarray, k: int) -> list[int]: + return np.argsort(-(embs @ q))[:k].tolist() + + raw_float32_bytes = n * dim * 4 # sqlite-vec stores raw float32 + + print(f"\n=== TurboQuant benchmark — n={n} chunks, dim={dim} ===") + print(f"{'backend':<16}{'size(MB)':>10}{'ratio':>8}{'mem(MB)':>10}" + f"{'q-lat(ms)':>11}{'recall@1':>10}{'recall@10':>11}") + + # Float32 baseline latency (numpy brute force == exact). + t0 = time.perf_counter() + for qi in query_ids: + exact_topk(embs[qi], 10) + f32_lat = (time.perf_counter() - t0) / len(query_ids) * 1000 + print(f"{'float32(exact)':<16}{raw_float32_bytes/1e6:>10.2f}{1.0:>8.1f}" + f"{raw_float32_bytes/1e6:>10.2f}{f32_lat:>11.3f}{1.0:>10.3f}{1.0:>11.3f}") + + results = {} + for bits in (2, 4): + conn, store = _build_tq(embs, bits) + disk = store_size_bytes(conn) + mem = store.loaded_nbytes() + + # Latency. + t0 = time.perf_counter() + for qi in query_ids: + store.search(embs[qi], limit=10) + lat = (time.perf_counter() - t0) / len(query_ids) * 1000 + + # Recall. + r1, r10 = [], [] + for qi in query_ids: + q = embs[qi] + exact1 = set(exact_topk(q, 1)) + exact10 = set(exact_topk(q, 10)) + got = [int(r.content.split()[1]) for r in store.search(q, limit=10)] + r1.append(_recall_at_k(exact1, got, 1)) + r10.append(_recall_at_k(exact10, got, 10)) + recall1 = float(np.mean(r1)) + recall10 = float(np.mean(r10)) + ratio = raw_float32_bytes / disk if disk else float("inf") + results[bits] = (ratio, recall1, recall10) + print(f"{'turbo-quant b' + str(bits):<16}{disk/1e6:>10.2f}{ratio:>8.1f}" + f"{mem/1e6:>10.2f}{lat:>11.3f}{recall1:>10.3f}{recall10:>11.3f}") + conn.close() + + print("\nTakeaway: TurboQuant wins on index size + memory; float32 wins on " + "query latency (numpy/C exact scan). Recall stays high at 4-bit.\n") + + # Soft regression gates. + ratio4, _, recall10_4 = results[4] + assert ratio4 >= 6.0, f"4-bit compression ratio {ratio4:.1f}x < 6x" + assert recall10_4 >= 0.80, f"4-bit recall@10 {recall10_4:.2f} < 0.80" + _, _, recall10_2 = results[2] + assert recall10_2 >= 0.55, f"2-bit recall@10 {recall10_2:.2f} < 0.55" diff --git a/tests/test_cli_helpers.py b/tests/test_cli_helpers.py index ec9876a..ffdd2dd 100644 --- a/tests/test_cli_helpers.py +++ b/tests/test_cli_helpers.py @@ -275,3 +275,50 @@ def test_init_writes_comment_template_for_unknown_model( loaded = load_user_settings() assert loaded.embedding.indexing_params is None assert loaded.embedding.query_params is None + + +# --------------------------------------------------------------------------- +# Backend resolution (_resolve_backend) — U6 +# --------------------------------------------------------------------------- + + +def test_resolve_backend_flag_wins(monkeypatch: pytest.MonkeyPatch) -> None: + # Even on a TTY, an explicit flag skips the prompt. + monkeypatch.setattr("sys.stdin.isatty", lambda: True) + backend, bits = cli._resolve_backend("turbo-quant", 2) + assert backend == "turbo-quant" + assert bits == 2 + + +def test_resolve_backend_non_tty_defaults_sqlite_vec(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("sys.stdin.isatty", lambda: False) + backend, bits = cli._resolve_backend(None, None) + assert backend == "sqlite-vec" + assert bits == 4 # DEFAULT_TQ_BITS + + +def test_resolve_backend_flag_turbo_default_bits(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("sys.stdin.isatty", lambda: False) + backend, bits = cli._resolve_backend("turbo-quant", None) + assert backend == "turbo-quant" + assert bits == 4 + + +def test_init_rejects_invalid_backend() -> None: + from typer.testing import CliRunner + + from cocoindex_code.cli import app + + result = CliRunner().invoke(app, ["init", "--backend", "faiss"]) + assert result.exit_code == 1 + assert "unknown backend" in result.output + + +def test_init_rejects_invalid_tq_bits() -> None: + from typer.testing import CliRunner + + from cocoindex_code.cli import app + + result = CliRunner().invoke(app, ["init", "--backend", "turbo-quant", "--tq-bits", "9"]) + assert result.exit_code == 1 + assert "tq_bits" in result.output diff --git a/tests/test_e2e_backend.py b/tests/test_e2e_backend.py new file mode 100644 index 0000000..4ab904e --- /dev/null +++ b/tests/test_e2e_backend.py @@ -0,0 +1,134 @@ +"""End-to-end tests across both vector backends (sqlite-vec and turbo-quant). + +Exercises the full CLI -> daemon -> index -> search loop for each backend, plus +status/doctor parity. Mirrors the fixture and driving style of test_e2e.py. +""" + +from __future__ import annotations + +import os +import tempfile +from collections.abc import Iterator +from pathlib import Path + +import pytest +from conftest import make_test_user_settings +from typer.testing import CliRunner + +from cocoindex_code.cli import app +from cocoindex_code.client import stop_daemon +from cocoindex_code.settings import save_user_settings + +runner = CliRunner() + +SAMPLE_MAIN_PY = '''\ +"""Main application entry point.""" + +def calculate_fibonacci(n: int) -> int: + """Calculate the nth Fibonacci number recursively.""" + if n <= 1: + return n + return calculate_fibonacci(n - 1) + calculate_fibonacci(n - 2) +''' + +SAMPLE_DB_PY = '''\ +"""Database connection utilities.""" + +class DatabaseConnection: + """Manages database connections.""" + + def connect(self) -> None: + """Establish connection to the database.""" + self._connected = True +''' + + +@pytest.fixture() +def e2e_project() -> Iterator[Path]: + base_dir = Path(tempfile.mkdtemp(prefix="ccc_e2e_backend_")) + project_dir = base_dir / "project" + project_dir.mkdir() + (project_dir / "main.py").write_text(SAMPLE_MAIN_PY) + lib_dir = project_dir / "lib" + lib_dir.mkdir() + (lib_dir / "database.py").write_text(SAMPLE_DB_PY) + (project_dir / ".git").mkdir() + + old_env = os.environ.get("COCOINDEX_CODE_DIR") + os.environ["COCOINDEX_CODE_DIR"] = str(base_dir) + old_cwd = os.getcwd() + os.chdir(project_dir) + save_user_settings(make_test_user_settings()) + + try: + yield project_dir + finally: + os.chdir(project_dir) + runner.invoke(app, ["reset", "--all", "-f"]) + stop_daemon() + os.chdir(old_cwd) + if old_env is None: + os.environ.pop("COCOINDEX_CODE_DIR", None) + else: + os.environ["COCOINDEX_CODE_DIR"] = old_env + + +@pytest.mark.parametrize("backend", ["sqlite-vec", "turbo-quant"]) +def test_init_index_search_per_backend(e2e_project: Path, backend: str) -> None: + # Init with explicit backend (non-interactive). + result = runner.invoke(app, ["init", "--backend", backend], catch_exceptions=False) + assert result.exit_code == 0, result.output + assert backend in result.output + + settings_text = (e2e_project / ".cocoindex_code" / "settings.yml").read_text() + assert f"backend: {backend}" in settings_text + + # Index. + result = runner.invoke(app, ["index"], catch_exceptions=False) + assert result.exit_code == 0, result.output + assert "Chunks:" in result.output + + # Status reports chunks for both backends (doctor parity). + result = runner.invoke(app, ["status"], catch_exceptions=False) + assert result.exit_code == 0, result.output + assert "Chunks:" in result.output + + # Search finds the fibonacci chunk in main.py. + result = runner.invoke(app, ["search", "fibonacci", "calculation"], catch_exceptions=False) + assert result.exit_code == 0, result.output + assert "main.py" in result.output + + +@pytest.mark.parametrize("backend", ["sqlite-vec", "turbo-quant"]) +def test_search_path_filter_per_backend(e2e_project: Path, backend: str) -> None: + runner.invoke(app, ["init", "--backend", backend], catch_exceptions=False) + runner.invoke(app, ["index"], catch_exceptions=False) + result = runner.invoke( + app, ["search", "database", "connection", "--path", "lib/*"], catch_exceptions=False + ) + assert result.exit_code == 0, result.output + assert "lib/" in result.output + + +def test_reinit_switches_backend(e2e_project: Path) -> None: + """Re-init with a different backend then re-index rebuilds cleanly (R8).""" + runner.invoke(app, ["init", "--backend", "sqlite-vec"], catch_exceptions=False) + runner.invoke(app, ["index"], catch_exceptions=False) + + # Force re-init to turbo-quant. + result = runner.invoke( + app, ["init", "-f", "--backend", "turbo-quant"], catch_exceptions=False + ) + # `init` returns early ("already initialized") if settings exist; reset first. + if "already initialized" in result.output: + runner.invoke(app, ["reset", "--all", "-f"], catch_exceptions=False) + result = runner.invoke( + app, ["init", "--backend", "turbo-quant"], catch_exceptions=False + ) + assert result.exit_code == 0, result.output + + result = runner.invoke(app, ["index"], catch_exceptions=False) + assert result.exit_code == 0, result.output + result = runner.invoke(app, ["search", "fibonacci"], catch_exceptions=False) + assert result.exit_code == 0, result.output + assert "main.py" in result.output diff --git a/tests/test_settings.py b/tests/test_settings.py index 6c06af1..29dd2e5 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -99,6 +99,62 @@ def test_save_and_load_project_settings(tmp_path: Path) -> None: assert loaded.language_overrides[0].lang == "php" +def test_backend_defaults_to_sqlite_vec(tmp_path: Path) -> None: + settings = ProjectSettings() + assert settings.backend == "sqlite-vec" + assert settings.tq_bits == 4 + save_project_settings(tmp_path, settings) + loaded = load_project_settings(tmp_path) + assert loaded.backend == "sqlite-vec" + + +def test_turbo_quant_backend_round_trip(tmp_path: Path) -> None: + settings = ProjectSettings(backend="turbo-quant", tq_bits=2) + save_project_settings(tmp_path, settings) + loaded = load_project_settings(tmp_path) + assert loaded.backend == "turbo-quant" + assert loaded.tq_bits == 2 + + +def test_missing_backend_key_loads_as_sqlite_vec() -> None: + """A pre-backend settings file (no backend key) defaults safely.""" + from cocoindex_code.settings import _project_settings_from_dict + + loaded = _project_settings_from_dict({"include_patterns": ["**/*.py"]}) + assert loaded.backend == "sqlite-vec" + assert loaded.tq_bits == 4 + + +def test_invalid_backend_raises() -> None: + from cocoindex_code.settings import _project_settings_from_dict + + with pytest.raises(ValueError, match="unknown backend"): + _project_settings_from_dict({"backend": "faiss"}) + + +@pytest.mark.parametrize("bad_bits", [0, 5, -1, 8]) +def test_invalid_tq_bits_raises(bad_bits: int) -> None: + from cocoindex_code.settings import _project_settings_from_dict + + with pytest.raises(ValueError, match="tq_bits"): + _project_settings_from_dict({"backend": "turbo-quant", "tq_bits": bad_bits}) + + +def test_tq_bits_omitted_for_sqlite_vec(tmp_path: Path) -> None: + """tq_bits is not written to disk for the sqlite-vec backend.""" + import yaml + + from cocoindex_code.settings import ( + _SETTINGS_DIR_NAME, + _SETTINGS_FILE_NAME, + ) + + save_project_settings(tmp_path, ProjectSettings(backend="sqlite-vec")) + raw = yaml.safe_load((tmp_path / _SETTINGS_DIR_NAME / _SETTINGS_FILE_NAME).read_text()) + assert raw["backend"] == "sqlite-vec" + assert "tq_bits" not in raw + + @pytest.mark.usefixtures("_patch_user_dir") def test_load_user_settings_missing_file_raises() -> None: with pytest.raises(FileNotFoundError): From 3c202dbdd9d18cfe162d4e518ab81d8192c5a355 Mon Sep 17 00:00:00 2001 From: dudegladiator Date: Sun, 7 Jun 2026 15:37:10 +0530 Subject: [PATCH 4/4] chore: annotate tests, document backends, format Add type annotations to the new TurboQuant tests (required by the mypy pre-commit hook, which checks tests/ too), document the turbo-quant backend and `--backend` / `--tq-bits` flags in the README, and apply ruff-format normalizations. --- README.md | 35 +++++++++++++++++++++++++++++++- src/cocoindex_code/cli.py | 4 +--- src/cocoindex_code/query.py | 4 +--- src/cocoindex_code/settings.py | 4 +--- tests/benchmark_turbo_quant.py | 26 +++++++++++++++--------- tests/test_e2e_backend.py | 8 ++------ tests/test_tq_store.py | 37 +++++++++++++++++++++------------- tests/test_turbo_quant.py | 6 +++--- 8 files changed, 82 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index c488421..ac75f17 100644 --- a/README.md +++ b/README.md @@ -166,7 +166,7 @@ The background daemon starts automatically on first use. | Command | Description | |---------|-------------| -| `ccc init` | Initialize a project — creates settings files, adds `.cocoindex_code/` to `.gitignore` | +| `ccc init` | Initialize a project — creates settings files, adds `.cocoindex_code/` to `.gitignore`. Use `--backend turbo-quant` (with `--tq-bits`) to pick the compressed backend; see [Vector Backends](#vector-backends) | | `ccc index` | Build or update the index (auto-inits if needed). Shows streaming progress. | | `ccc search ` | Semantic search across the codebase | | `ccc status` | Show index stats (chunk count, file count, language breakdown) | @@ -189,6 +189,34 @@ ccc search --refresh database schema # update index first, then By default, `ccc search` scopes results to your current working directory (relative to the project root). Use `--path` to override. +## Vector Backends + +`ccc` supports two vector-search backends, chosen at `ccc init` and baked into the index: + +| Backend | Index size | Search | Best for | +|---------|-----------|--------|----------| +| `sqlite-vec` (default) | full `float32` | exact KNN ([sqlite-vec](https://github.com/asg017/sqlite-vec)) | most projects — fastest, exact results | +| `turbo-quant` | ~4–8× smaller | approximate, unbiased inner-product | large codebases where index size matters | + +**TurboQuant** is a data-oblivious vector quantizer ([Zandieh et al., 2025](https://arxiv.org/abs/2504.19874)): it randomly rotates each embedding, quantizes per coordinate with optimal scalar codebooks, and adds a 1-bit QJL residual for an unbiased inner-product estimate. At 4-bit it compresses the index ~8× on disk with recall@10 ≈ 0.9, with no training or calibration. + +```bash +ccc init # interactive — prompts for backend +ccc init --backend turbo-quant # 4-bit (default bit-width) +ccc init --backend turbo-quant --tq-bits 2 # 2-bit — ~16× smaller, lower recall +ccc init --backend sqlite-vec # explicit default +``` + +Switching backends requires re-initializing and re-indexing: + +```bash +ccc reset --all -f +ccc init --backend turbo-quant +ccc index +``` + +> Higher `--tq-bits` (1–4) means better recall and a larger index. `sqlite-vec` stays the default for exact, low-latency search. + ## Docker A Docker image is available for teams who want a reproducible, dependency-free @@ -438,6 +466,9 @@ OpenAI embeddings (`text-embedding-3-*`, `text-embedding-ada-002`) are intention Per-project. Controls which files to index. ```yaml +backend: sqlite-vec # or "turbo-quant" — see Vector Backends +tq_bits: 4 # TurboQuant bit-width (1–4); only used when backend is turbo-quant + include_patterns: - "**/*.py" - "**/*.js" @@ -462,6 +493,8 @@ chunkers: module: example_toml_chunker:toml_chunker ``` +> `backend` is set at `ccc init` and baked into the index — changing it requires re-indexing (see [Vector Backends](#vector-backends)). + > `.cocoindex_code/` is automatically added to `.gitignore` during init. Use `chunkers` when you want to control how a file type is split into chunks before indexing. diff --git a/src/cocoindex_code/cli.py b/src/cocoindex_code/cli.py index 9ab15c5..6bfd675 100644 --- a/src/cocoindex_code/cli.py +++ b/src/cocoindex_code/cli.py @@ -388,9 +388,7 @@ def _resolve_embedding_choice( return EmbeddingSettings(provider=provider, model=model.strip()) -def _resolve_backend( - backend_flag: str | None, tq_bits_flag: int | None -) -> tuple[Backend, int]: +def _resolve_backend(backend_flag: str | None, tq_bits_flag: int | None) -> tuple[Backend, int]: """Resolve (backend, tq_bits) from flags, an interactive prompt, or defaults. Explicit ``--backend`` wins. Otherwise prompt when stdin is a TTY; when not diff --git a/src/cocoindex_code/query.py b/src/cocoindex_code/query.py index fdf1719..e077c2a 100644 --- a/src/cocoindex_code/query.py +++ b/src/cocoindex_code/query.py @@ -163,9 +163,7 @@ async def query_codebase( # TurboQuant backend: search the compressed store with the unbiased # inner-product estimator. Detected by the presence of tq_metadata, so the # query path does not need the project settings at hand. - tq_results = _maybe_query_turbo_quant( - db, query_embedding, limit, offset, languages, paths - ) + tq_results = _maybe_query_turbo_quant(db, query_embedding, limit, offset, languages, paths) if tq_results is not None: return tq_results diff --git a/src/cocoindex_code/settings.py b/src/cocoindex_code/settings.py index 57b717e..4962cdf 100644 --- a/src/cocoindex_code/settings.py +++ b/src/cocoindex_code/settings.py @@ -477,9 +477,7 @@ def _user_settings_from_dict(d: dict[str, Any]) -> UserSettings: def validate_backend(backend: Any) -> Backend: """Validate a backend identifier, raising ``ValueError`` if unknown.""" if backend not in SUPPORTED_BACKENDS: - raise ValueError( - f"unknown backend {backend!r}; expected one of {list(SUPPORTED_BACKENDS)}" - ) + raise ValueError(f"unknown backend {backend!r}; expected one of {list(SUPPORTED_BACKENDS)}") # Narrowed to the Literal by the membership check above. return cast(Backend, backend) diff --git a/tests/benchmark_turbo_quant.py b/tests/benchmark_turbo_quant.py index 2c45c5f..e3e0e55 100644 --- a/tests/benchmark_turbo_quant.py +++ b/tests/benchmark_turbo_quant.py @@ -98,21 +98,25 @@ def test_benchmark_report() -> None: # Exact float32 ground truth (brute force). def exact_topk(q: np.ndarray, k: int) -> list[int]: - return np.argsort(-(embs @ q))[:k].tolist() + return [int(i) for i in np.argsort(-(embs @ q))[:k]] raw_float32_bytes = n * dim * 4 # sqlite-vec stores raw float32 print(f"\n=== TurboQuant benchmark — n={n} chunks, dim={dim} ===") - print(f"{'backend':<16}{'size(MB)':>10}{'ratio':>8}{'mem(MB)':>10}" - f"{'q-lat(ms)':>11}{'recall@1':>10}{'recall@10':>11}") + print( + f"{'backend':<16}{'size(MB)':>10}{'ratio':>8}{'mem(MB)':>10}" + f"{'q-lat(ms)':>11}{'recall@1':>10}{'recall@10':>11}" + ) # Float32 baseline latency (numpy brute force == exact). t0 = time.perf_counter() for qi in query_ids: exact_topk(embs[qi], 10) f32_lat = (time.perf_counter() - t0) / len(query_ids) * 1000 - print(f"{'float32(exact)':<16}{raw_float32_bytes/1e6:>10.2f}{1.0:>8.1f}" - f"{raw_float32_bytes/1e6:>10.2f}{f32_lat:>11.3f}{1.0:>10.3f}{1.0:>11.3f}") + print( + f"{'float32(exact)':<16}{raw_float32_bytes / 1e6:>10.2f}{1.0:>8.1f}" + f"{raw_float32_bytes / 1e6:>10.2f}{f32_lat:>11.3f}{1.0:>10.3f}{1.0:>11.3f}" + ) results = {} for bits in (2, 4): @@ -139,12 +143,16 @@ def exact_topk(q: np.ndarray, k: int) -> list[int]: recall10 = float(np.mean(r10)) ratio = raw_float32_bytes / disk if disk else float("inf") results[bits] = (ratio, recall1, recall10) - print(f"{'turbo-quant b' + str(bits):<16}{disk/1e6:>10.2f}{ratio:>8.1f}" - f"{mem/1e6:>10.2f}{lat:>11.3f}{recall1:>10.3f}{recall10:>11.3f}") + print( + f"{'turbo-quant b' + str(bits):<16}{disk / 1e6:>10.2f}{ratio:>8.1f}" + f"{mem / 1e6:>10.2f}{lat:>11.3f}{recall1:>10.3f}{recall10:>11.3f}" + ) conn.close() - print("\nTakeaway: TurboQuant wins on index size + memory; float32 wins on " - "query latency (numpy/C exact scan). Recall stays high at 4-bit.\n") + print( + "\nTakeaway: TurboQuant wins on index size + memory; float32 wins on " + "query latency (numpy/C exact scan). Recall stays high at 4-bit.\n" + ) # Soft regression gates. ratio4, _, recall10_4 = results[4] diff --git a/tests/test_e2e_backend.py b/tests/test_e2e_backend.py index 4ab904e..9320eba 100644 --- a/tests/test_e2e_backend.py +++ b/tests/test_e2e_backend.py @@ -116,15 +116,11 @@ def test_reinit_switches_backend(e2e_project: Path) -> None: runner.invoke(app, ["index"], catch_exceptions=False) # Force re-init to turbo-quant. - result = runner.invoke( - app, ["init", "-f", "--backend", "turbo-quant"], catch_exceptions=False - ) + result = runner.invoke(app, ["init", "-f", "--backend", "turbo-quant"], catch_exceptions=False) # `init` returns early ("already initialized") if settings exist; reset first. if "already initialized" in result.output: runner.invoke(app, ["reset", "--all", "-f"], catch_exceptions=False) - result = runner.invoke( - app, ["init", "--backend", "turbo-quant"], catch_exceptions=False - ) + result = runner.invoke(app, ["init", "--backend", "turbo-quant"], catch_exceptions=False) assert result.exit_code == 0, result.output result = runner.invoke(app, ["index"], catch_exceptions=False) diff --git a/tests/test_tq_store.py b/tests/test_tq_store.py index 6965f30..81250e8 100644 --- a/tests/test_tq_store.py +++ b/tests/test_tq_store.py @@ -4,6 +4,7 @@ import math import sqlite3 +from collections.abc import Iterator import numpy as np import pytest @@ -28,7 +29,15 @@ def _unit(rng: np.random.Generator, d: int) -> np.ndarray: return v / np.linalg.norm(v) -def _build_index(conn, embeddings, *, languages=None, file_paths=None, seed=_SEED, bits=_BITS): +def _build_index( + conn: sqlite3.Connection, + embeddings: list[np.ndarray], + *, + languages: list[str] | None = None, + file_paths: list[str] | None = None, + seed: int = _SEED, + bits: int = _BITS, +) -> TurboQuant: """Quantize and persist a list of embeddings; return the TurboQuant used.""" tq = TurboQuant(dim=_DIM, bits=bits, seed=seed) create_tables(conn) @@ -54,13 +63,13 @@ def _build_index(conn, embeddings, *, languages=None, file_paths=None, seed=_SEE @pytest.fixture() -def conn(): +def conn() -> Iterator[sqlite3.Connection]: c = sqlite3.connect(":memory:") yield c c.close() -def test_persist_load_search_finds_nearest(conn) -> None: +def test_persist_load_search_finds_nearest(conn: sqlite3.Connection) -> None: rng = np.random.default_rng(1) embs = [_unit(rng, _DIM) for _ in range(50)] _build_index(conn, embs) @@ -75,7 +84,7 @@ def test_persist_load_search_finds_nearest(conn) -> None: assert results[0].content == f"chunk {target}" -def test_scores_descending(conn) -> None: +def test_scores_descending(conn: sqlite3.Connection) -> None: rng = np.random.default_rng(2) embs = [_unit(rng, _DIM) for _ in range(30)] _build_index(conn, embs) @@ -85,7 +94,7 @@ def test_scores_descending(conn) -> None: assert scores == sorted(scores, reverse=True) -def test_language_filter(conn) -> None: +def test_language_filter(conn: sqlite3.Connection) -> None: rng = np.random.default_rng(3) embs = [_unit(rng, _DIM) for _ in range(30)] langs = ["python" if i % 2 == 0 else "go" for i in range(30)] @@ -96,7 +105,7 @@ def test_language_filter(conn) -> None: assert len(results) == 15 -def test_multi_language_filter(conn) -> None: +def test_multi_language_filter(conn: sqlite3.Connection) -> None: rng = np.random.default_rng(4) embs = [_unit(rng, _DIM) for _ in range(30)] langs = ["python", "go", "rust"] * 10 @@ -107,7 +116,7 @@ def test_multi_language_filter(conn) -> None: assert len(results) == 20 -def test_path_filter(conn) -> None: +def test_path_filter(conn: sqlite3.Connection) -> None: rng = np.random.default_rng(5) embs = [_unit(rng, _DIM) for _ in range(20)] fps = [f"src/{i}.py" if i < 10 else f"tests/{i}.py" for i in range(20)] @@ -118,7 +127,7 @@ def test_path_filter(conn) -> None: assert len(results) == 10 -def test_combined_language_and_path_filter(conn) -> None: +def test_combined_language_and_path_filter(conn: sqlite3.Connection) -> None: rng = np.random.default_rng(6) embs = [_unit(rng, _DIM) for _ in range(20)] langs = ["python" if i % 2 == 0 else "go" for i in range(20)] @@ -131,7 +140,7 @@ def test_combined_language_and_path_filter(conn) -> None: assert r.file_path.startswith("src/") -def test_offset_and_limit(conn) -> None: +def test_offset_and_limit(conn: sqlite3.Connection) -> None: rng = np.random.default_rng(7) embs = [_unit(rng, _DIM) for _ in range(40)] _build_index(conn, embs) @@ -142,7 +151,7 @@ def test_offset_and_limit(conn) -> None: assert [r.content for r in full[5:10]] == [r.content for r in paged] -def test_empty_candidate_set_returns_empty(conn) -> None: +def test_empty_candidate_set_returns_empty(conn: sqlite3.Connection) -> None: rng = np.random.default_rng(8) embs = [_unit(rng, _DIM) for _ in range(10)] _build_index(conn, embs, languages=["python"] * 10) @@ -150,7 +159,7 @@ def test_empty_candidate_set_returns_empty(conn) -> None: assert store.search(embs[0], limit=10, languages=["haskell"]) == [] -def test_empty_store_returns_empty(conn) -> None: +def test_empty_store_returns_empty(conn: sqlite3.Connection) -> None: create_tables(conn) write_metadata(conn, bits=_BITS, dim=_DIM, seed=_SEED) store = TqStore.load(conn) @@ -158,7 +167,7 @@ def test_empty_store_returns_empty(conn) -> None: assert store.search(np.ones(_DIM, dtype=np.float32), limit=5) == [] -def test_reload_matches_in_memory_search(conn) -> None: +def test_reload_matches_in_memory_search(conn: sqlite3.Connection) -> None: """Search after reload (matrices regenerated from seed) matches first load.""" rng = np.random.default_rng(9) embs = [_unit(rng, _DIM) for _ in range(25)] @@ -171,7 +180,7 @@ def test_reload_matches_in_memory_search(conn) -> None: assert [round(x.score, 5) for x in r1] == [round(x.score, 5) for x in r2] -def test_store_size_reflects_bits(conn) -> None: +def test_store_size_reflects_bits(conn: sqlite3.Connection) -> None: rng = np.random.default_rng(10) embs = [_unit(rng, _DIM) for _ in range(100)] _build_index(conn, embs, bits=4) @@ -181,7 +190,7 @@ def test_store_size_reflects_bits(conn) -> None: assert size == expected_per_row * 100 -def test_recall_at_10_reasonable(conn) -> None: +def test_recall_at_10_reasonable(conn: sqlite3.Connection) -> None: """Sanity: 4-bit prod search recovers most exact top-10 neighbors.""" rng = np.random.default_rng(11) embs = np.array([_unit(rng, _DIM) for _ in range(300)]) diff --git a/tests/test_turbo_quant.py b/tests/test_turbo_quant.py index ab5ca95..438bfa4 100644 --- a/tests/test_turbo_quant.py +++ b/tests/test_turbo_quant.py @@ -77,14 +77,14 @@ def test_prod_estimator_is_unbiased() -> None: xs = _random_unit_vectors(_N, _DIM, seed=21) ys = _random_unit_vectors(_N, _DIM, seed=22) - errors = [] + err_list: list[float] = [] for x, y in zip(xs, ys): mse_idx, qjl, rnorm, norm = tq.quantize_prod(x) est = tq.inner_product_prod(y, mse_idx, qjl, rnorm, norm) true_ip = float(y @ x) - errors.append(est - true_ip) + err_list.append(est - true_ip) - errors = np.array(errors) + errors = np.array(err_list) mean_err = float(errors.mean()) stderr = float(errors.std(ddof=1) / math.sqrt(len(errors))) # Mean signed error within ~3 standard errors of zero -> unbiased.