Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions evalbench/databases/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def __init__(self, db_config):
self.client = bigquery.Client(project=self.project_id)
self.tmp_users = []

def ensure_database_exists(self):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Signature violates the abstract base class. databases/db.py:182 declares:

@abstractmethod
def ensure_database_exists(self, database_name: str) -> None:

All other implementations (including mysql, postgres, sqlite, spanner, bigtable, mongodb, sqlserver) match this signature. The new BQ override drops the parameter entirely. Any generic caller that does db.ensure_database_exists("foo") — which is the documented contract — will hit TypeError: ensure_database_exists() takes 1 positional argument but 2 were given.

def ensure_database_exists(self, database_name: str) -> None:
    # BigQuery datasets are project-scoped; no per-database creation needed.
    pass

pass

#####################################################
#####################################################
# Database Specific Execution Logic and Handling
Expand Down
12 changes: 10 additions & 2 deletions evalbench/scorers/llmrater.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from scorers import setmatcher
import logging

from scorers import comparator
from scorers import comparator, sqlite_bridge
from .util import make_hashable, with_cache_execute
from databases.util import get_cache_client

Expand Down Expand Up @@ -252,7 +252,15 @@ def compare(
return 100, "Skipped. Exact Match was found."

if golden_error:
return 0, "Golden query failed to execute."
# If using hybrid judge, fetch ground truth from SQLite when BQ
# fails on golden query syntax (e.g. SQLite functions in reference
# queries).
if sqlite_bridge.is_hybrid_cross_db_enabled():

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Silent fallback — no log line when it fires. When hybrid mode kicks in and resolves a golden answer from SQLite, nothing in the logs records that this happened. The eval row just has a normal-looking score. Debugging "why does my BQ run report fewer golden errors than expected" becomes a source-dive.

Add one log line:

if sqlite_bridge.is_hybrid_cross_db_enabled():
    logging.info(
        "Hybrid ground truth: BQ golden query failed, resolving from "
        "SQLite reference. query=%s", golden_query
    )
    golden_execution_result = ...

golden_execution_result = (
sqlite_bridge.get_sqlite_ground_truth(golden_query)
)
else:
return 0, "Golden query failed to execute."
if generated_error:
return 0, "Generated query failed to execute."

Expand Down
52 changes: 52 additions & 0 deletions evalbench/scorers/sqlite_bridge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""SQLite Ground Truth Resolution Adapter for EvalBench."""

import os
import sqlite3
import sys

import pandas as pd
import yaml


def get_sqlite_ground_truth(query: str) -> list:
"""Resolves candidate SQLite database files and executes query."""
parent_dir = os.path.dirname(__file__)
root_dir = os.path.abspath(os.path.join(parent_dir, "..", ".."))
db_dir = os.path.join(root_dir, "db_connections", "bird")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hardcoded path. db_connections/bird is specific to one dataset and won't work for other suites. The experiment_config already carries database_configs; derive the SQLite path from there instead of hardcoding the directory. As written, this scorer is silently bird-only — and nothing in the docstring says so.

if not os.path.exists(db_dir):
return []

candidates = [
f[:-7] for f in os.listdir(db_dir) if f.endswith(".sqlite")
]
for cand in candidates:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Picks an arbitrary SQLite file by trial-and-error. Iterating every .sqlite file in the directory and using whichever one accepts the query is wrong in two ways:

  1. Non-deterministic. Returns whatever file os.listdir yielded first that didn't raise. If two BIRD databases (california_schools, card_games) happen to have tables/columns with overlapping names, the resolved "truth" depends on the OS's directory ordering. Run the same scenario on two machines, get two different scores.
  2. Silent wrong-answer risk. A query may "succeed" against the wrong DB by returning bogus rows that don't match the scenario's intent. The judge then scores against those rows and reports a misleading PASS/FAIL.

The scenario record already carries the database name (the database field in BIRD dataset entries — california_schools, etc.). Pass it through:

def get_sqlite_ground_truth(query: str, database: str) -> list:
    sqlite_path = os.path.join(db_dir, f"{database}.sqlite")
    if not os.path.exists(sqlite_path):
        return []
    conn = sqlite3.connect(sqlite_path)
    try:
        return pd.read_sql_query(query, conn).to_dict(orient="records")
    finally:
        conn.close()

Bonus: dramatically faster — one connect instead of N.

sqlite_path = os.path.join(db_dir, f"{cand}.sqlite")
try:
conn = sqlite3.connect(sqlite_path)
df_cand = pd.read_sql_query(query, conn)
conn.close()
return df_cand.to_dict(orient="records")
except Exception:
continue

return []


def is_hybrid_cross_db_enabled() -> bool:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sniffing sys.argv to detect mode is fragile and creates several real problems:

  • Only works for python evalbench.py --experiment_config=foo.yaml. Other invocation paths don't have this argv pattern: the gRPC server (eval_server.py) constructs configs programmatically, run_suite passes config paths via a different mechanism, and any programmatic / in-process caller doesn't have argv at all. Hybrid mode silently disables in these cases.
  • Re-parses the YAML on every call. compare() runs once per scenario per trial; this is O(N) file reads of the same YAML.
  • Cross-scorer coupling via string match. Checks the PythonScorer's script_path to decide what LLMRater does. If someone renames hybrid_xa_judge.py or moves it, LLMRater silently goes back to returning 0 on golden errors.

The signal that hybrid mode is active is conceptually a scorer config, not a global mode. Lift it to LLMRater.__init__:

def __init__(self, config, global_models):
    ...
    self.hybrid_ground_truth = config.get("hybrid_ground_truth", False)

And in YAML:

scorers:
  llm_rater:
    hybrid_ground_truth: true

Then the compare() branch checks self.hybrid_ground_truth. No argv sniffing, no cross-scorer coupling, no file I/O in the hot path.

"""Checks if hybrid_cross_db is supplied in experiment config."""
for arg in sys.argv:
if arg.startswith("--experiment_config="):
config_path = arg.split("=", 1)[1]
try:
with open(config_path, "r") as f:
cfg = yaml.safe_load(f)
py_scorer = cfg.get("scorers", {}).get("python_scorer", {})
script = str(py_scorer.get("script_path", ""))
name = str(py_scorer.get("scorer_name", ""))
is_judge = "hybrid_xa_judge.py" in script
is_name = "hybrid_cross_db" in name
if is_judge or is_name:
return True
except Exception:
pass
return False
109 changes: 109 additions & 0 deletions hybrid_xa_judge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Hybrid Execution Accuracy (XA) Cross-Database Evaluator for EvalBench."""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Script lives at the repo root. All other Python under this repo lives under evalbench/ or datasets/. Convention break. Suggest moving to something like evalbench/scorers/judges/hybrid_xa_judge.py and updating the script_path in any example configs that reference it.


from decimal import Decimal
import json
import sqlite3
import sys

import pandas as pd

from evalbench.scorers.sqlite_bridge import get_sqlite_ground_truth

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This import will fail at runtime. pythonscorer.py:57 invokes the script as:

command = ["uv", "run", "--isolated", self.script_path]

--isolated strips the parent's PYTHONPATH and runs in a clean venv with no access to the evalbench source tree. from evalbench.scorers.sqlite_bridge import ... will raise ModuleNotFoundError. The script returns nonzero, the scorer reports a generic FAIL: Script failed with exit code..., and the user never sees a real result.

The PR description mentions functional end-to-end testing, so either I'm misreading the invocation OR the testing didn't actually exercise this code path. Worth confirming with a fresh-venv repro.

Fix: inline the ~10 lines of get_sqlite_ground_truth directly into hybrid_xa_judge.py so the script has no project-local imports. Together with the fix from review comment #2 (pass the database name into get_sqlite_ground_truth), the inlined function becomes very small.



def compare_result_sets(
df_bq: pd.DataFrame, df_sqlite: pd.DataFrame
) -> bool:
"""Compares two DataFrames ignoring column names and row order."""
if df_bq is None or df_sqlite is None:
return False

if df_bq.empty and df_sqlite.empty:
return True

if df_bq.empty != df_sqlite.empty:
return False

def normalize_df(df: pd.DataFrame) -> list[tuple]:
rows = []
for _, r in df.iterrows():
normalized_row = []
for val in r:
if pd.isna(val):
normalized_row.append(None)
elif isinstance(val, (int, float, Decimal)):
try:
normalized_row.append(round(float(val), 4))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

compare_result_sets has three lossy behaviors worth pinning down in the docstring so the next reader doesn't second-guess them:

  1. Floats rounded to 4 decimals (round(float(val), 4)). For some BIRD queries — AVG/SUM over money or rate fields — the reference may exceed 4 decimal places, turning a real mismatch into a false PASS. Either widen the precision or document the choice.
  2. Sort key lambda x: str(x) sorts [1, 10, 2] as ["(1,)", "(10,)", "(2,)"]. Fine for equality (both sides sort identically) but unreadable in FAIL logs. Consider key=lambda r: tuple(str(v) for v in r) for slightly less surprising debug output.
  3. .0 suffix stripped from string values. If BQ returns the literal string "3.0" (e.g., a version label) and SQLite returns "3", they're considered equal. Almost certainly intentional for numeric-as-string coercion but worth a one-line comment.

A short docstring listing these explicitly makes the contract clear and lets future maintainers reason about edge cases without re-deriving them.

except (ValueError, TypeError):
normalized_row.append(str(val))
else:
s = str(val).strip().lower()
if s.endswith(".0"):
s = s[:-2]
normalized_row.append(s)
rows.append(tuple(normalized_row))
rows.sort(key=lambda x: str(x))
return rows

try:
bq_rows = normalize_df(df_bq)
sqlite_rows = normalize_df(df_sqlite)
except Exception:
return False

if len(bq_rows) != len(sqlite_rows):
return False

for r_bq, r_sqlite in zip(bq_rows, sqlite_rows):
if len(r_bq) != len(r_sqlite):
return False
for val_bq, val_sqlite in zip(r_bq, r_sqlite):
if val_bq != val_sqlite:
return False

return True


def main():
    """Score a generated query result against the SQLite ground truth.

    Reads a single JSON object from stdin and uses these keys:

    * ``generated_execution_result`` -- rows produced by the generated
      query (expected to be a list of records).
    * ``golden_query`` -- reference SQL, executed through
      ``get_sqlite_ground_truth``.
    * ``generated_error`` -- error text from the generated query, if any.

    Prints exactly one JSON object ``{"score": ..., "reason": ...}`` to
    stdout: 100.0 when ``compare_result_sets`` reports a match, otherwise
    0.0. Any exception is reported as a 0.0 score rather than propagated,
    so the caller always receives a parseable verdict.
    """
    try:
        input_data = json.load(sys.stdin)
        pred_rows = input_data.get("generated_execution_result")
        ref_sql = input_data.get("golden_query", "")

        sqlite_records = get_sqlite_ground_truth(ref_sql)
        df_sqlite = pd.DataFrame(sqlite_records)
        # This string is diagnostic only. ``default=str`` keeps a value the
        # JSON encoder cannot handle natively (for example ``bytes``) from
        # raising TypeError and turning a comparable result into a 0.0.
        sqlite_res_str = json.dumps(sqlite_records, default=str)

        gen_err = input_data.get("generated_error")
        # A missing result, a string in place of rows, or a reported error
        # all mean the prediction cannot be compared.
        if pred_rows is None or isinstance(pred_rows, str) or gen_err:
            err_msg = gen_err or "Invalid prediction object"
            reason = (
                f"FAIL | BigQuery Error: {err_msg} | "
                f"SQLite Ground Truth Result: {sqlite_res_str}"
            )
            print(json.dumps({"score": 0.0, "reason": reason}))
            return

        # Anything that is not a list of rows is treated as an empty result.
        if isinstance(pred_rows, list):
            df_bq = pd.DataFrame(pred_rows)
        else:
            df_bq = pd.DataFrame()

        match = compare_result_sets(df_bq, df_sqlite)
        score = 100.0 if match else 0.0
        if match:
            reason = f"PASS | SQLite Ground Truth Result: {sqlite_res_str}"
        else:
            # Same reasoning as above: never let the diagnostic dump of the
            # prediction replace the real FAIL reason with an exception.
            bq_res_str = json.dumps(
                df_bq.to_dict(orient="records"), default=str
            )
            reason = (
                f"FAIL | BQ Prediction: {bq_res_str} vs "
                f"SQLite Ground Truth: {sqlite_res_str}"
            )
        print(json.dumps({"score": score, "reason": reason}))

    except Exception as e:
        # Top-level boundary of the script: always emit a JSON verdict.
        err_reason = f"FAIL: Exception in hybrid judge: {e}"
        print(json.dumps({"score": 0.0, "reason": err_reason}))


if __name__ == "__main__":
    main()
Loading