-
Notifications
You must be signed in to change notification settings - Fork 26
feat: Support cross-database evaluation with SQLite ground truth #465
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,7 +21,7 @@ | |
| from scorers import setmatcher | ||
| import logging | ||
|
|
||
| from scorers import comparator | ||
| from scorers import comparator, sqlite_bridge | ||
| from .util import make_hashable, with_cache_execute | ||
| from databases.util import get_cache_client | ||
|
|
||
|
|
@@ -252,7 +252,15 @@ def compare( | |
| return 100, "Skipped. Exact Match was found." | ||
|
|
||
| if golden_error: | ||
| return 0, "Golden query failed to execute." | ||
| # If using hybrid judge, fetch ground truth from SQLite when BQ | ||
| # fails on golden query syntax (e.g. SQLite functions in reference | ||
| # queries). | ||
| if sqlite_bridge.is_hybrid_cross_db_enabled(): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Silent fallback — no log line when it fires. When hybrid mode kicks in and resolves a golden answer from SQLite, nothing in the logs records that this happened. The eval row just has a normal-looking score. Debugging "why does my BQ run report fewer golden errors than expected" becomes a source-dive. Add one log line:

    if sqlite_bridge.is_hybrid_cross_db_enabled():
        logging.info(
            "Hybrid ground truth: BQ golden query failed, resolving from "
            "SQLite reference. query=%s", golden_query
        )
        golden_execution_result = ... |
||
| golden_execution_result = ( | ||
| sqlite_bridge.get_sqlite_ground_truth(golden_query) | ||
| ) | ||
| else: | ||
| return 0, "Golden query failed to execute." | ||
| if generated_error: | ||
| return 0, "Generated query failed to execute." | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| """SQLite Ground Truth Resolution Adapter for EvalBench.""" | ||
|
|
||
| import os | ||
| import sqlite3 | ||
| import sys | ||
|
|
||
| import pandas as pd | ||
| import yaml | ||
|
|
||
|
|
||
def get_sqlite_ground_truth(
    query: str,
    database: "str | None" = None,
    db_dir: "str | None" = None,
) -> list:
    """Execute ``query`` against a local SQLite database and return its rows.

    Args:
        query: SQL text to run.
        database: Base name (without the ``.sqlite`` suffix) of the database
            file to use. When omitted, every ``*.sqlite`` file in ``db_dir``
            is tried in sorted name order and the first one on which the
            query executes successfully wins.
        db_dir: Directory holding the ``.sqlite`` files. Defaults to
            ``<two levels above this module>/db_connections/bird``.

    Returns:
        The result rows as a list of ``{column: value}`` dicts, or ``[]``
        when the directory is missing, the requested database file does not
        exist, or the query fails on every candidate.
    """
    # Function-scope imports keep this block self-contained; the module
    # header does not import these names.
    import contextlib
    import logging

    if db_dir is None:
        parent_dir = os.path.dirname(__file__)
        root_dir = os.path.abspath(os.path.join(parent_dir, "..", ".."))
        db_dir = os.path.join(root_dir, "db_connections", "bird")

    if not os.path.isdir(db_dir):
        return []

    if database is not None:
        candidates = [database]
    else:
        # Sorted so the "first database that accepts the query" is the same
        # on every machine; os.listdir() order is arbitrary.
        candidates = sorted(
            os.path.splitext(f)[0]
            for f in os.listdir(db_dir)
            if f.endswith(".sqlite")
        )

    for cand in candidates:
        sqlite_path = os.path.join(db_dir, f"{cand}.sqlite")
        # sqlite3.connect() silently creates an empty database for a path
        # that does not exist, so skip missing files up front.
        if not os.path.isfile(sqlite_path):
            continue
        try:
            # closing() guarantees the handle is released even when the
            # query raises, which is the normal outcome for non-matching
            # candidates during trial-and-error resolution.
            with contextlib.closing(sqlite3.connect(sqlite_path)) as conn:
                df_cand = pd.read_sql_query(query, conn)
        except Exception as exc:  # best-effort: move on to next candidate
            logging.getLogger(__name__).debug(
                "SQLite ground truth: query failed on %s: %s",
                sqlite_path,
                exc,
            )
            continue
        return df_cand.to_dict(orient="records")

    return []
|
|
||
|
|
||
def is_hybrid_cross_db_enabled() -> bool:
    """Report whether the experiment config selects the hybrid cross-DB judge.

    Scans ``sys.argv`` for the experiment config path, accepting both the
    ``--experiment_config=PATH`` and ``--experiment_config PATH`` spellings,
    loads that YAML file and inspects ``scorers.python_scorer``.

    Returns:
        True when ``script_path`` contains ``hybrid_xa_judge.py`` or
        ``scorer_name`` contains ``hybrid_cross_db``; False otherwise,
        including when no config flag is present or the file cannot be read
        or parsed.
    """
    flag = "--experiment_config"
    argv = sys.argv

    config_paths = []
    for idx, arg in enumerate(argv):
        if arg.startswith(flag + "="):
            config_paths.append(arg.split("=", 1)[1])
        elif arg == flag and idx + 1 < len(argv):
            # Space-separated form: the value is the next argument.
            config_paths.append(argv[idx + 1])

    for config_path in config_paths:
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                cfg = yaml.safe_load(f)
        except OSError:
            # Missing or unreadable config: treat as "not enabled".
            continue
        except (ValueError, yaml.YAMLError):
            # Undecodable or malformed YAML: treat as "not enabled".
            continue

        # Any of these levels may be absent or null in the YAML, so check
        # the shape explicitly instead of relying on a blanket except.
        scorers = cfg.get("scorers") if isinstance(cfg, dict) else None
        py_scorer = (
            scorers.get("python_scorer") if isinstance(scorers, dict) else None
        )
        if not isinstance(py_scorer, dict):
            continue

        script = str(py_scorer.get("script_path", ""))
        name = str(py_scorer.get("scorer_name", ""))
        if "hybrid_xa_judge.py" in script or "hybrid_cross_db" in name:
            return True

    return False
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| """Hybrid Execution Accuracy (XA) Cross-Database Evaluator for EvalBench.""" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Script lives at the repo root. All other Python under this repo lives under |
||
|
|
||
| from decimal import Decimal | ||
| import json | ||
| import sqlite3 | ||
| import sys | ||
|
|
||
| import pandas as pd | ||
|
|
||
| from evalbench.scorers.sqlite_bridge import get_sqlite_ground_truth | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This import will fail at runtime. command = ["uv", "run", "--isolated", self.script_path]
The PR description mentions functional end-to-end testing, so either I'm misreading the invocation OR the testing didn't actually exercise this code path. Worth confirming with a fresh-venv repro. Fix: inline the ~10 lines of |
||
|
|
||
|
|
||
def _normalize_cell(val):
    """Map one result cell to a canonical value for cross-engine comparison."""
    # pd.isna() on a list/array cell returns an array whose truth value is
    # ambiguous, so only ask it about scalars.
    if pd.api.types.is_scalar(val) and pd.isna(val):
        return None
    if isinstance(val, (int, float, Decimal)):
        try:
            # "+ 0.0" folds -0.0 into 0.0 so both print (and sort) the same.
            return round(float(val), 4) + 0.0
        except (ValueError, TypeError, OverflowError):
            return str(val)
    text = str(val).strip().lower()
    if text.endswith(".0"):
        text = text[:-2]
    return text


def _normalize_rows(df: pd.DataFrame) -> list[tuple]:
    """Return the frame's rows as normalized tuples in a canonical order."""
    rows = [
        tuple(_normalize_cell(val) for val in record)
        for record in df.itertuples(index=False, name=None)
    ]
    # Rows may mix None, floats and strings, which are not mutually
    # orderable; sorting on the string form gives a total order.
    rows.sort(key=str)
    return rows


def compare_result_sets(
    df_bq: pd.DataFrame, df_sqlite: pd.DataFrame
) -> bool:
    """Compare two result sets, ignoring column names and row order.

    Cells are normalized before comparison:
      * missing scalars (None / NaN / NaT) become ``None``;
      * ``int``, ``float`` and ``Decimal`` values (``bool`` included, since
        it is an ``int``) become ``float`` rounded to 4 decimal places;
      * everything else is stringified, stripped, lower-cased, and a
        trailing ``".0"`` is removed.

    Columns are matched by position, so column order and column count
    still matter.

    Args:
        df_bq: Rows produced by the generated query.
        df_sqlite: Rows produced by the reference query.

    Returns:
        True when both frames are empty, or when their normalized rows are
        identical as multisets. False when either argument is ``None``,
        exactly one frame is empty, normalization fails, or the rows differ.
    """
    if df_bq is None or df_sqlite is None:
        return False

    if df_bq.empty or df_sqlite.empty:
        # Equal only when both sides are empty.
        return df_bq.empty and df_sqlite.empty

    try:
        bq_rows = _normalize_rows(df_bq)
        sqlite_rows = _normalize_rows(df_sqlite)
    except Exception:
        # An un-normalizable cell is treated as a mismatch, not a crash.
        return False

    # List equality checks row count, row width and every cell in one go.
    return bq_rows == sqlite_rows
|
|
||
|
|
||
def main():
    """Score one prediction read from stdin and print a JSON verdict.

    Reads a JSON object from stdin and uses its
    ``generated_execution_result``, ``generated_error`` and ``golden_query``
    entries. The golden query is resolved through
    ``get_sqlite_ground_truth`` and compared with the prediction via
    ``compare_result_sets``.

    Exactly one JSON object ``{"score": ..., "reason": ...}`` is printed to
    stdout: score ``100.0`` on a match, ``0.0`` on a mismatch, on a
    failed/invalid prediction, or on any exception raised while judging.
    """
    try:
        input_data = json.load(sys.stdin)
        pred_rows = input_data.get("generated_execution_result")
        ref_sql = input_data.get("golden_query", "")

        sqlite_records = get_sqlite_ground_truth(ref_sql)
        df_sqlite = pd.DataFrame(sqlite_records)
        # default=str: this string is diagnostic only, and result rows can
        # hold values json cannot encode natively (e.g. bytes); without it
        # a serialization error would zero the score before any comparison.
        sqlite_res_str = json.dumps(sqlite_records, default=str)

        gen_err = input_data.get("generated_error")
        if pred_rows is None or isinstance(pred_rows, str) or gen_err:
            err_msg = gen_err or "Invalid prediction object"
            reason = (
                f"FAIL | BigQuery Error: {err_msg} | "
                f"SQLite Ground Truth Result: {sqlite_res_str}"
            )
            print(json.dumps({"score": 0.0, "reason": reason}))
            return

        # Anything that is not a list of rows is treated as an empty result.
        if isinstance(pred_rows, list):
            df_bq = pd.DataFrame(pred_rows)
        else:
            df_bq = pd.DataFrame()

        match = compare_result_sets(df_bq, df_sqlite)
        score = 100.0 if match else 0.0
        if match:
            reason = f"PASS | SQLite Ground Truth Result: {sqlite_res_str}"
        else:
            bq_res_str = json.dumps(
                df_bq.to_dict(orient="records"), default=str
            )
            reason = (
                f"FAIL | BQ Prediction: {bq_res_str} vs "
                f"SQLite Ground Truth: {sqlite_res_str}"
            )
        print(json.dumps({"score": score, "reason": reason}))

    except Exception as e:
        # Top-level boundary: the caller expects a JSON verdict on stdout
        # no matter what, so report the failure instead of propagating it.
        err_reason = f"FAIL: Exception in hybrid judge: {e}"
        print(json.dumps({"score": 0.0, "reason": err_reason}))
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Signature violates the abstract base class.
databases/db.py:182 declares: All 8 other implementations (mysql, postgres, sqlite, spanner, bigtable, mongodb, sqlserver) match this signature. The new BQ override drops the parameter entirely. Any generic caller that does
db.ensure_database_exists("foo") — which is the documented contract — will hit TypeError: ensure_database_exists() takes 1 positional argument but 2 were given.