Skip to content

Commit 9eba5c1

Browse files
Fix!: Add model default audits in the model preserving their args (TobikoData#5106)
1 parent 4e80a93 commit 9eba5c1

File tree

7 files changed

+431
-54
lines changed

7 files changed

+431
-54
lines changed

sqlmesh/core/dialect.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,6 +1408,41 @@ def extract_func_call(
14081408
return func.lower(), kwargs
14091409

14101410

1411+
def extract_function_calls(func_calls: t.Any, allow_tuples: bool = False) -> t.Any:
    """Normalize signal or audit declarations into ``(name, kwargs)`` pairs.

    Args:
        func_calls: The raw declaration. Handled shapes:
            * a SQL tuple/array expression: each element is passed to ``extract_func_call``;
            * a parenthesized expression: its inner expression is passed to ``extract_func_call``;
            * any other expression: passed to ``extract_func_call`` as-is;
            * a Python list whose entries are dicts or ``(name, args)`` pairs;
            * anything else: returned unchanged, or ``[]`` when falsy.
        allow_tuples: Forwarded to ``extract_func_call``. For dict entries it also means
            the entry is unnamed: the name becomes ``""`` and no ``name`` key is removed.

    Returns:
        A list of ``(lowercased_name, kwargs)`` tuples for the handled shapes. String
        argument values of list entries are parsed with ``parse_one``.

    Raises:
        ConfigError: If a list entry is neither a dict nor a tuple/list.
        KeyError: If a dict entry has no ``name`` key while ``allow_tuples`` is false.
    """
    if isinstance(func_calls, (exp.Tuple, exp.Array)):
        return [extract_func_call(i, allow_tuples=allow_tuples) for i in func_calls.expressions]
    if isinstance(func_calls, exp.Paren):
        return [extract_func_call(func_calls.this, allow_tuples=allow_tuples)]
    if isinstance(func_calls, exp.Expression):
        return [extract_func_call(func_calls, allow_tuples=allow_tuples)]
    if isinstance(func_calls, list):
        function_calls = []

        for entry in func_calls:
            if isinstance(entry, dict):
                # Work on a shallow copy: popping "name" from the caller's dict would
                # corrupt configs that are normalized more than once (e.g. defaults
                # shared between several models).
                args = dict(entry)
                name = "" if allow_tuples else args.pop("name")
            elif isinstance(entry, (tuple, list)):
                name, args = entry
            else:
                raise ConfigError(f"Audit must be a dictionary or named tuple. Got {entry}.")

            function_calls.append(
                (
                    name.lower(),
                    {
                        key: parse_one(value) if isinstance(value, str) else value
                        for key, value in args.items()
                    },
                )
            )

        return function_calls

    return func_calls or []
1444+
1445+
14111446
def is_meta_expression(v: t.Any) -> bool:
14121447
return isinstance(v, (Audit, Metric, Model))
14131448

sqlmesh/core/loader.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,6 @@ def _load_sql_models(
594594
macros=macros,
595595
jinja_macros=jinja_macros,
596596
audit_definitions=audits,
597-
default_audits=self.config.model_defaults.audits,
598597
module_path=self.config_path,
599598
dialect=self.config.model_defaults.dialect,
600599
time_column_format=self.config.time_column_format,

sqlmesh/core/model/definition.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
sorted_python_env_payloads,
3333
validate_extra_and_required_fields,
3434
)
35-
from sqlmesh.core.model.meta import ModelMeta, FunctionCall
35+
from sqlmesh.core.model.meta import ModelMeta
3636
from sqlmesh.core.model.kind import (
3737
ModelKindName,
3838
SeedKind,
@@ -2038,7 +2038,6 @@ def load_sql_based_model(
20382038
macros: t.Optional[MacroRegistry] = None,
20392039
jinja_macros: t.Optional[JinjaMacroRegistry] = None,
20402040
audits: t.Optional[t.Dict[str, ModelAudit]] = None,
2041-
default_audits: t.Optional[t.List[FunctionCall]] = None,
20422041
python_env: t.Optional[t.Dict[str, Executable]] = None,
20432042
dialect: t.Optional[str] = None,
20442043
physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None,
@@ -2211,7 +2210,6 @@ def load_sql_based_model(
22112210
physical_schema_mapping=physical_schema_mapping,
22122211
default_catalog=default_catalog,
22132212
variables=variables,
2214-
default_audits=default_audits,
22152213
inline_audits=inline_audits,
22162214
blueprint_variables=blueprint_variables,
22172215
**meta_fields,
@@ -2431,7 +2429,6 @@ def _create_model(
24312429
physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None,
24322430
python_env: t.Optional[t.Dict[str, Executable]] = None,
24332431
audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None,
2434-
default_audits: t.Optional[t.List[FunctionCall]] = None,
24352432
inline_audits: t.Optional[t.Dict[str, ModelAudit]] = None,
24362433
module_path: Path = Path(),
24372434
macros: t.Optional[MacroRegistry] = None,
@@ -2541,6 +2538,10 @@ def _create_model(
25412538
for jinja_macro in jinja_macros.root_macros.values():
25422539
used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1])
25432540

2541+
# Merge model-specific audits with default audits
2542+
if default_audits := defaults.pop("audits", None):
2543+
kwargs["audits"] = default_audits + d.extract_function_calls(kwargs.pop("audits", []))
2544+
25442545
model = klass(
25452546
name=name,
25462547
**{
@@ -2558,12 +2559,7 @@ def _create_model(
25582559
**(inline_audits or {}),
25592560
}
25602561

2561-
# TODO: default_audits needs to be merged with model.audits; the former's arguments
2562-
# are silently dropped today because we add them in audit_definitions. We also need
2563-
# to check for duplicates when we implement this merging logic.
2564-
used_audits: t.Set[str] = set()
2565-
used_audits.update(audit_name for audit_name, _ in default_audits or [])
2566-
used_audits.update(audit_name for audit_name, _ in model.audits)
2562+
used_audits: t.Set[str] = {audit_name for audit_name, _ in model.audits}
25672563

25682564
audit_definitions = {
25692565
audit_name: audit_definitions[audit_name]

sqlmesh/core/model/meta.py

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from sqlmesh.core import dialect as d
1313
from sqlmesh.core.config.linter import LinterConfig
14-
from sqlmesh.core.dialect import normalize_model_name, extract_func_call
14+
from sqlmesh.core.dialect import normalize_model_name
1515
from sqlmesh.core.model.common import (
1616
bool_validator,
1717
default_catalog_validator,
@@ -94,37 +94,7 @@ class ModelMeta(_Node):
9494
def _func_call_validator(cls, v: t.Any, field: t.Any) -> t.Any:
9595
is_signal = getattr(field, "name" if hasattr(field, "name") else "field_name") == "signals"
9696

97-
if isinstance(v, (exp.Tuple, exp.Array)):
98-
return [extract_func_call(i, allow_tuples=is_signal) for i in v.expressions]
99-
if isinstance(v, exp.Paren):
100-
return [extract_func_call(v.this, allow_tuples=is_signal)]
101-
if isinstance(v, exp.Expression):
102-
return [extract_func_call(v, allow_tuples=is_signal)]
103-
if isinstance(v, list):
104-
audits = []
105-
106-
for entry in v:
107-
if isinstance(entry, dict):
108-
args = entry
109-
name = "" if is_signal else entry.pop("name")
110-
elif isinstance(entry, (tuple, list)):
111-
name, args = entry
112-
else:
113-
raise ConfigError(f"Audit must be a dictionary or named tuple. Got {entry}.")
114-
115-
audits.append(
116-
(
117-
name.lower(),
118-
{
119-
key: d.parse_one(value) if isinstance(value, str) else value
120-
for key, value in args.items()
121-
},
122-
)
123-
)
124-
125-
return audits
126-
127-
return v or []
97+
return d.extract_function_calls(v, allow_tuples=is_signal)
12898

12999
@field_validator("tags", mode="before")
130100
def _value_or_tuple_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any:

tests/core/test_audit.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from sqlglot import exp, parse_one
44

55
from sqlmesh.core import constants as c
6+
from sqlmesh.core.config.model import ModelDefaultsConfig
67
from sqlmesh.core.context import Context
78
from sqlmesh.core.audit import (
89
ModelAudit,
@@ -962,6 +963,117 @@ def test_multiple_audits_with_same_name():
962963
assert model.audits[1][1] == model.audits[2][1]
963964

964965

966+
def test_default_audits_included_when_no_model_audits():
    """Default audits are attached, with their arguments, to a model that declares no audits."""
    expressions = parse("""
        MODEL (
            name test.basic_model
        );
        SELECT 1 as id, 'test' as name;
    """)

    model_defaults = ModelDefaultsConfig(
        dialect="duckdb", audits=["not_null(columns := ['id'])", "unique_values(columns := ['id'])"]
    )
    model = load_sql_based_model(expressions, defaults=model_defaults.dict())

    expected_names = {"not_null", "unique_values"}

    assert len(model.audits) == 2
    audit_names = [audit[0] for audit in model.audits]
    assert "not_null" in audit_names
    assert "unique_values" in audit_names

    # Verify arguments are preserved on the raw (name, args) pairs.
    for audit_name, audit_args in model.audits:
        if audit_name in expected_names:
            assert "columns" in audit_args
            assert audit_args["columns"].expressions[0].this == "id"

    # audits_with_args yields (audit, args) pairs, so match on the audit's name.
    # Comparing the audit object itself to a string would skip every assertion.
    checked_names = set()
    for audit, audit_args in model.audits_with_args:
        if audit.name in expected_names:
            assert "columns" in audit_args
            assert audit_args["columns"].expressions[0].this == "id"
            checked_names.add(audit.name)

    # Guard against the loop above passing vacuously.
    assert checked_names == expected_names
1000+
1001+
1002+
def test_model_defaults_audits_with_same_name():
    """Default audits come first, then model audits; repeated names keep their own arguments."""

    def threshold_audit(column_name, limit):
        # Expected (name, args) pair for one does_not_exceed_threshold invocation.
        return (
            "does_not_exceed_threshold",
            {"column": exp.column(column_name), "threshold": exp.Literal.number(limit)},
        )

    def assert_args_match(actual_args, expected_args):
        # Comparing the SQL representation is easier than comparing expression trees.
        for arg_name, expected_value in expected_args.items():
            assert actual_args[arg_name].sql() == expected_value.sql()

    expressions = parse(
        """
        MODEL (
            name db.table,
            dialect spark,
            audits(
                does_not_exceed_threshold(column := id, threshold := 1000),
                does_not_exceed_threshold(column := price, threshold := 100),
                unique_values(columns := ['id'])
            )
        );

        SELECT id, price FROM tbl;

        AUDIT (
            name does_not_exceed_threshold,
        );
        SELECT * FROM @this_model
        WHERE @column >= @threshold;
        """
    )

    model_defaults = ModelDefaultsConfig(
        dialect="duckdb",
        audits=[
            "does_not_exceed_threshold(column := price, threshold := 33)",
            "does_not_exceed_threshold(column := id, threshold := 65)",
            "not_null(columns := ['id'])",
        ],
    )
    model = load_sql_based_model(expressions, defaults=model_defaults.dict())

    # Three default audits plus three model audits, all backed by one custom definition.
    assert len(model.audits) == 6
    assert len(model.audits_with_args) == 6
    assert len(model.audit_definitions) == 1

    expected_audits = [
        # From the model defaults.
        threshold_audit("price", 33),
        threshold_audit("id", 65),
        ("not_null", {"columns": exp.convert(["id"])}),
        # From the model definition itself.
        threshold_audit("id", 1000),
        threshold_audit("price", 100),
        ("unique_values", {"columns": exp.convert(["id"])}),
    ]

    # Names and arguments of the raw (name, args) pairs are preserved in order.
    for (actual_name, actual_args), (expected_name, expected_args) in zip(
        model.audits, expected_audits
    ):
        assert actual_name == expected_name
        assert_args_match(actual_args, expected_args)

    # The resolved audits carry the same names and arguments.
    for (actual_audit, actual_args), (expected_name, expected_args) in zip(
        model.audits_with_args, expected_audits
    ):
        assert actual_audit.name == expected_name
        assert_args_match(actual_args, expected_args)
1075+
1076+
9651077
def test_audit_formatting_flag_serde():
9661078
expressions = parse(
9671079
"""

0 commit comments

Comments
 (0)