From 1c07b41071080edf8f65c73418f95fb425ecc836 Mon Sep 17 00:00:00 2001
From: Ariel Jassan
Date: Sun, 28 Jun 2026 17:17:16 +0000
Subject: [PATCH] feat(scorers): add hybrid execution accuracy judging and
 SQLite ground truth resolution

---
Review notes:
  * Line structure was restored from a whitespace-collapsed copy of this
    patch. Blank context lines and context indentation in the bigquery.py
    and llmrater.py hunks are inferred; verify them against the target
    tree before applying.
  * The "index" lines of the two new files were dropped because their
    contents changed during review (see below) and the recorded blob ids
    no longer match.
  * sqlite_bridge.py: connections are always closed, candidate databases
    are tried in sorted order, the experiment config is parsed once, and
    resolve_sqlite_ground_truth() reports "could not resolve" as None.
  * hybrid_xa_judge.py: an unresolved ground truth is a FAIL instead of
    matching an empty prediction, and report serialization can no longer
    turn a matching result into an exception.

 evalbench/databases/bigquery.py    |   3 +
 evalbench/scorers/llmrater.py      |  12 +++-
 evalbench/scorers/sqlite_bridge.py |  96 +++++++++++++++++++++++
 hybrid_xa_judge.py                 | 122 +++++++++++++++++++++++++++++
 4 files changed, 231 insertions(+), 2 deletions(-)
 create mode 100644 evalbench/scorers/sqlite_bridge.py
 create mode 100644 hybrid_xa_judge.py

diff --git a/evalbench/databases/bigquery.py b/evalbench/databases/bigquery.py
index b8cad5f5..838b6a8c 100644
--- a/evalbench/databases/bigquery.py
+++ b/evalbench/databases/bigquery.py
@@ -27,6 +27,9 @@ def __init__(self, db_config):
         self.client = bigquery.Client(project=self.project_id)
         self.tmp_users = []

+    def ensure_database_exists(self):
+        pass
+
     #####################################################
     #####################################################
     # Database Specific Execution Logic and Handling
diff --git a/evalbench/scorers/llmrater.py b/evalbench/scorers/llmrater.py
index 42112ca3..85e4f4b7 100644
--- a/evalbench/scorers/llmrater.py
+++ b/evalbench/scorers/llmrater.py
@@ -21,7 +21,7 @@ from scorers import setmatcher

 import logging

-from scorers import comparator
+from scorers import comparator, sqlite_bridge
 from .util import make_hashable, with_cache_execute
 from databases.util import get_cache_client

@@ -252,7 +252,15 @@ def compare(
             return 100, "Skipped. Exact Match was found."

         if golden_error:
-            return 0, "Golden query failed to execute."
+            # If using hybrid judge, fetch ground truth from SQLite when BQ
+            # fails on golden query syntax (e.g. SQLite functions in reference
+            # queries).
+            if sqlite_bridge.is_hybrid_cross_db_enabled():
+                golden_execution_result = (
+                    sqlite_bridge.get_sqlite_ground_truth(golden_query)
+                )
+            else:
+                return 0, "Golden query failed to execute."

         if generated_error:
             return 0, "Generated query failed to execute."
diff --git a/evalbench/scorers/sqlite_bridge.py b/evalbench/scorers/sqlite_bridge.py
new file mode 100644
--- /dev/null
+++ b/evalbench/scorers/sqlite_bridge.py
@@ -0,0 +1,96 @@
+"""SQLite Ground Truth Resolution Adapter for EvalBench."""
+
+import contextlib
+import functools
+import os
+import sqlite3
+import sys
+from typing import Optional
+
+import pandas as pd
+import yaml
+
+
+def resolve_sqlite_ground_truth(query: str) -> Optional[list]:
+    """Executes query on the first local SQLite database that accepts it.
+
+    Every *.sqlite file in <repo root>/db_connections/bird is tried in
+    sorted order, so the outcome does not depend on directory listing
+    order. The query is not tied to a database: if several schemas accept
+    it, the alphabetically first database wins.
+
+    Args:
+      query: SQL text to execute.
+
+    Returns:
+      Result rows as a list of column-name -> value dicts, or None when
+      the query is blank, the directory is missing, or every candidate
+      database rejects the query.
+    """
+    if not query or not query.strip():
+        return None
+
+    parent_dir = os.path.dirname(__file__)
+    root_dir = os.path.abspath(os.path.join(parent_dir, "..", ".."))
+    db_dir = os.path.join(root_dir, "db_connections", "bird")
+    if not os.path.isdir(db_dir):
+        return None
+
+    db_files = sorted(f for f in os.listdir(db_dir) if f.endswith(".sqlite"))
+    for db_file in db_files:
+        sqlite_path = os.path.join(db_dir, db_file)
+        try:
+            # closing() releases the handle even when the query fails.
+            with contextlib.closing(sqlite3.connect(sqlite_path)) as conn:
+                df_cand = pd.read_sql_query(query, conn)
+                return df_cand.to_dict(orient="records")
+        except Exception:
+            # Best effort: an error means this is not the right database.
+            continue
+
+    return None
+
+
+def get_sqlite_ground_truth(query: str) -> list:
+    """Resolves candidate SQLite database files and executes query.
+
+    Same as resolve_sqlite_ground_truth, except that a failed resolution
+    is reported as an empty list.
+    """
+    records = resolve_sqlite_ground_truth(query)
+    return [] if records is None else records
+
+
+@functools.lru_cache(maxsize=None)
+def _config_selects_hybrid_judge(config_path: str) -> bool:
+    """Returns whether the config's python_scorer is the hybrid judge.
+
+    Cached per path so the YAML file is parsed once instead of on every
+    scorer call.
+    """
+    try:
+        with open(config_path, "r") as f:
+            cfg = yaml.safe_load(f)
+    except (OSError, ValueError, yaml.YAMLError):
+        return False
+
+    # Empty or oddly shaped configs (None, lists, scalars) are "disabled".
+    scorers = cfg.get("scorers") if isinstance(cfg, dict) else None
+    if not isinstance(scorers, dict):
+        return False
+    py_scorer = scorers.get("python_scorer")
+    if not isinstance(py_scorer, dict):
+        return False
+
+    script = str(py_scorer.get("script_path", ""))
+    name = str(py_scorer.get("scorer_name", ""))
+    return "hybrid_xa_judge.py" in script or "hybrid_cross_db" in name
+
+
+def is_hybrid_cross_db_enabled() -> bool:
+    """Checks if hybrid_cross_db is supplied in experiment config."""
+    return any(
+        _config_selects_hybrid_judge(arg.split("=", 1)[1])
+        for arg in sys.argv
+        if arg.startswith("--experiment_config=")
+    )
diff --git a/hybrid_xa_judge.py b/hybrid_xa_judge.py
new file mode 100644
--- /dev/null
+++ b/hybrid_xa_judge.py
@@ -0,0 +1,122 @@
+"""Hybrid Execution Accuracy (XA) Cross-Database Evaluator for EvalBench."""
+
+from decimal import Decimal
+import json
+import sys
+
+import pandas as pd
+
+from evalbench.scorers.sqlite_bridge import resolve_sqlite_ground_truth
+
+
+def compare_result_sets(
+    df_bq: pd.DataFrame, df_sqlite: pd.DataFrame
+) -> bool:
+    """Compares two DataFrames ignoring column names and row order.
+
+    Cells are normalized before comparison: missing values become None,
+    numbers become floats rounded to 4 decimals, and everything else is
+    compared as a stripped, lower-cased string without a trailing ".0".
+    Columns are matched by position.
+    """
+    if df_bq is None or df_sqlite is None:
+        return False
+
+    if df_bq.empty and df_sqlite.empty:
+        return True
+
+    if df_bq.empty != df_sqlite.empty:
+        return False
+
+    def normalize_df(df: pd.DataFrame) -> list[tuple]:
+        rows = []
+        for _, r in df.iterrows():
+            normalized_row = []
+            for val in r:
+                if pd.isna(val):
+                    normalized_row.append(None)
+                elif isinstance(val, (int, float, Decimal)):
+                    try:
+                        normalized_row.append(round(float(val), 4))
+                    except (ValueError, TypeError):
+                        normalized_row.append(str(val))
+                else:
+                    s = str(val).strip().lower()
+                    if s.endswith(".0"):
+                        s = s[:-2]
+                    normalized_row.append(s)
+            rows.append(tuple(normalized_row))
+        rows.sort(key=lambda x: str(x))
+        return rows
+
+    try:
+        bq_rows = normalize_df(df_bq)
+        sqlite_rows = normalize_df(df_sqlite)
+    except Exception:
+        return False
+
+    # Equal lists of tuples means same row count, same width, same cells.
+    return bq_rows == sqlite_rows
+
+
+def main():
+    """Scores one prediction read as JSON from stdin.
+
+    Reads "generated_execution_result", "generated_error" and
+    "golden_query" from the input object and prints a JSON object with a
+    "score" (100.0 or 0.0) and a human-readable "reason" to stdout.
+    """
+    try:
+        input_data = json.load(sys.stdin)
+        pred_rows = input_data.get("generated_execution_result")
+        ref_sql = input_data.get("golden_query", "")
+
+        ground_truth = resolve_sqlite_ground_truth(ref_sql)
+        sqlite_records = [] if ground_truth is None else ground_truth
+        # default=str: SQLite may return values json cannot encode (for
+        # example BLOB bytes); building the report must not fail the item.
+        sqlite_res_str = json.dumps(sqlite_records, default=str)
+
+        gen_err = input_data.get("generated_error")
+        if pred_rows is None or isinstance(pred_rows, str) or gen_err:
+            err_msg = gen_err or "Invalid prediction object"
+            reason = (
+                f"FAIL | BigQuery Error: {err_msg} | "
+                f"SQLite Ground Truth Result: {sqlite_res_str}"
+            )
+            print(json.dumps({"score": 0.0, "reason": reason}))
+            return
+
+        if ground_truth is None:
+            # Without this check an empty prediction would compare equal
+            # to the empty fallback and be scored as a PASS.
+            reason = "FAIL | SQLite ground truth could not be resolved"
+            print(json.dumps({"score": 0.0, "reason": reason}))
+            return
+
+        if isinstance(pred_rows, list):
+            df_bq = pd.DataFrame(pred_rows)
+        else:
+            df_bq = pd.DataFrame()
+
+        match = compare_result_sets(df_bq, pd.DataFrame(sqlite_records))
+        score = 100.0 if match else 0.0
+        if match:
+            reason = f"PASS | SQLite Ground Truth Result: {sqlite_res_str}"
+        else:
+            bq_res_str = json.dumps(
+                df_bq.to_dict(orient="records"), default=str
+            )
+            reason = (
+                f"FAIL | BQ Prediction: {bq_res_str} vs "
+                f"SQLite Ground Truth: {sqlite_res_str}"
+            )
+        print(json.dumps({"score": score, "reason": reason}))
+
+    except Exception as e:
+        err_reason = f"FAIL: Exception in hybrid judge: {e}"
+        print(json.dumps({"score": 0.0, "reason": err_reason}))
+
+
+if __name__ == "__main__":
+    main()