Skip to content

Commit fea0204

Browse files
fix(optimizer)!: query schema directly when type annotation fails for processing UNNEST source
1 parent f7458a4 commit fea0204

File tree

2 files changed

+121
-2
lines changed

2 files changed

+121
-2
lines changed

sqlglot/optimizer/resolver.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,22 @@ def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequenc
144144
# in bigquery, unnest structs are automatically scoped as tables, so you can
145145
# directly select a struct field in a query.
146146
# this handles the case where the unnest is statically defined.
147-
if self.dialect.UNNEST_COLUMN_ONLY:
148-
if source.expression.is_type(exp.DataType.Type.STRUCT):
147+
if self.dialect.UNNEST_COLUMN_ONLY and isinstance(source.expression, exp.Unnest):
148+
unnest_type = source.expression.type
149+
150+
# if type is not annotated yet, try to get it from the schema
151+
if not unnest_type or unnest_type.is_type(exp.DataType.Type.UNKNOWN):
152+
unnest_expr = seq_get(source.expression.expressions, 0)
153+
if isinstance(unnest_expr, exp.Column) and self.scope.parent:
154+
unnest_type = self._get_unnest_column_type(unnest_expr)
155+
156+
# check if unnesting an ARRAY of STRUCTs - extract struct field names
157+
if unnest_type and unnest_type.is_type(exp.DataType.Type.ARRAY):
158+
element_types = unnest_type.expressions
159+
if element_types and element_types[0].is_type(exp.DataType.Type.STRUCT):
160+
for field in element_types[0].expressions: # type: ignore
161+
columns.append(field.name)
162+
elif source.expression.is_type(exp.DataType.Type.STRUCT):
149163
for k in source.expression.type.expressions: # type: ignore
150164
columns.append(k.name)
151165
elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation):
@@ -299,3 +313,66 @@ def _get_unambiguous_columns(
299313
unambiguous_columns[column] = table
300314

301315
return unambiguous_columns
316+
317+
def _get_unnest_column_type(self, column: exp.Column) -> t.Optional[exp.DataType]:
318+
"""
319+
Get the type of a column being unnested, tracing through CTEs/subqueries to find the base table.
320+
321+
Args:
322+
column: The column expression being unnested.
323+
324+
Returns:
325+
The DataType of the column, or None if not found.
326+
"""
327+
# start from parent scope and trace through sources to find the actual table
328+
scope = self.scope.parent
329+
if not scope:
330+
return None
331+
332+
# try each source in the parent scope to find which one contains this column
333+
for source_name in scope.sources:
334+
source = scope.sources[source_name]
335+
col_type: t.Optional[exp.DataType]
336+
337+
if isinstance(source, exp.Table):
338+
# found a base table - get the column type from schema
339+
col_type = self.schema.get_column_type(source, column)
340+
if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
341+
return col_type
342+
elif isinstance(source, Scope):
343+
# CTE or subquery - recursively check its sources
344+
col_type = self._get_column_type_from_scope(source, column.name)
345+
if col_type:
346+
return col_type
347+
348+
return None
349+
350+
def _get_column_type_from_scope(self, scope: Scope, col_name: str) -> t.Optional[exp.DataType]:
351+
"""
352+
Recursively find a column's type by tracing through nested scopes to the base table.
353+
354+
Args:
355+
scope: The scope to search.
356+
col_name: The column name to find.
357+
358+
Returns:
359+
The DataType of the column, or None if not found.
360+
"""
361+
for source_name in scope.sources:
362+
source = scope.sources[source_name]
363+
col_type: t.Optional[exp.DataType]
364+
365+
if isinstance(source, exp.Table):
366+
# found a base table - try to get the column type
367+
col_type = self.schema.get_column_type(
368+
source, exp.Column(this=exp.to_identifier(col_name))
369+
)
370+
if col_type and not col_type.is_type(exp.DataType.Type.UNKNOWN):
371+
return col_type
372+
elif isinstance(source, Scope):
373+
# nested scope - recurse
374+
col_type = self._get_column_type_from_scope(source, col_name)
375+
if col_type:
376+
return col_type
377+
378+
return None

tests/test_optimizer.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,48 @@ def test_qualify_columns(self, logger):
516516
"SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.b_id = c.b_id JOIN d AS d ON b.d_id = d.d_id",
517517
)
518518

519+
self.assertEqual(
520+
optimizer.qualify.qualify(
521+
parse_one(
522+
"""
523+
SELECT
524+
(SELECT SUM(c.amount)
525+
FROM UNNEST(credits) AS c
526+
WHERE type != 'promotion') as total
527+
FROM billing
528+
""",
529+
read="bigquery",
530+
),
531+
schema={"billing": {"credits": "ARRAY<STRUCT<amount FLOAT64, type STRING>>"}},
532+
dialect="bigquery",
533+
).sql(dialect="bigquery"),
534+
"SELECT (SELECT SUM(`c`.`amount`) AS `_col_0` FROM UNNEST(`billing`.`credits`) AS `c` WHERE `type` <> 'promotion') AS `total` FROM `billing` AS `billing`",
535+
)
536+
537+
self.assertEqual(
538+
optimizer.qualify.qualify(
539+
parse_one(
540+
"""
541+
WITH cte AS (SELECT * FROM base_table)
542+
SELECT
543+
(SELECT SUM(item.price)
544+
FROM UNNEST(items) AS item
545+
WHERE category = 'electronics') as electronics_total
546+
FROM cte
547+
""",
548+
read="bigquery",
549+
),
550+
schema={
551+
"base_table": {
552+
"id": "INT64",
553+
"items": "ARRAY<STRUCT<price FLOAT64, category STRING>>",
554+
}
555+
},
556+
dialect="bigquery",
557+
).sql(dialect="bigquery"),
558+
"WITH `cte` AS (SELECT `base_table`.`id` AS `id`, `base_table`.`items` AS `items` FROM `base_table` AS `base_table`) SELECT (SELECT SUM(`item`.`price`) AS `_col_0` FROM UNNEST(`cte`.`items`) AS `item` WHERE `category` = 'electronics') AS `electronics_total` FROM `cte` AS `cte`",
559+
)
560+
519561
self.check_file(
520562
"qualify_columns",
521563
qualify_columns,

0 commit comments

Comments
 (0)