Skip to content

Commit 0fffc49

Browse files
authored
chore: remove squash_selections from sqlglot_ir (#1833)
1 parent be0a3cf commit 0fffc49

File tree

4 files changed

+14
-81
lines changed

4 files changed

+14
-81
lines changed

bigframes/core/compile/googlesql/query.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,13 @@ def _select_field(self, field) -> SelectExpression:
8383
return SelectExpression(expression=expr.ColumnExpression(name=field))
8484

8585
else:
86-
alias = field[1] if (field[0] != field[1]) else None
86+
alias = (
87+
expr.AliasExpression(field[1])
88+
if isinstance(field[1], str)
89+
else field[1]
90+
if (field[0] != field[1])
91+
else None
92+
)
8793
return SelectExpression(
8894
expression=expr.ColumnExpression(name=field[0]), alias=alias
8995
)

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,7 @@ def _compile_result_node(self, root: nodes.ResultNode) -> str:
125125
(name, scalar_compiler.compile_scalar_expression(ref))
126126
for ref, name in root.output_cols
127127
)
128-
# Skip squashing selections to ensure the right ordering and limit keys
129-
sqlglot_ir = self.compile_node(root.child).select(
130-
selected_cols, squash_selections=False
131-
)
128+
sqlglot_ir = self.compile_node(root.child).select(selected_cols)
132129

133130
if root.order_by is not None:
134131
ordering_cols = tuple(

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,6 @@ def from_union(
203203
def select(
204204
self,
205205
selected_cols: tuple[tuple[str, sge.Expression], ...],
206-
squash_selections: bool = True,
207206
) -> SQLGlotIR:
208207
selections = [
209208
sge.Alias(
@@ -213,15 +212,6 @@ def select(
213212
for id, expr in selected_cols
214213
]
215214

216-
# If squashing is enabled, we try to simplify the selections
217-
# by checking if the new selections are simply aliases of the
218-
# original columns.
219-
if squash_selections:
220-
new_selections = _squash_selections(self.expr.expressions, selections)
221-
if new_selections != []:
222-
new_expr = self.expr.select(*new_selections, append=False)
223-
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
224-
225215
new_expr = self._encapsulate_as_cte().select(*selections, append=False)
226216
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
227217

@@ -361,63 +351,3 @@ def _table(table: bigquery.TableReference) -> sge.Table:
361351
db=sg.to_identifier(table.dataset_id, quoted=True),
362352
catalog=sg.to_identifier(table.project, quoted=True),
363353
)
364-
365-
366-
def _squash_selections(
367-
old_expr: list[sge.Expression], new_expr: list[sge.Alias]
368-
) -> list[sge.Alias]:
369-
"""
370-
TODO: Reanble this function to optimize the SQL.
371-
Simplifies the select column expressions if existing (old_expr) and
372-
new (new_expr) selected columns are both simple aliases of column definitions.
373-
374-
Example:
375-
old_expr: [A AS X, B AS Y]
376-
new_expr: [X AS P, Y AS Q]
377-
Result: [A AS P, B AS Q]
378-
"""
379-
old_alias_map: typing.Dict[str, str] = {}
380-
for selected in old_expr:
381-
column_alias_pair = _get_column_alias_pair(selected)
382-
if column_alias_pair is None:
383-
return []
384-
else:
385-
old_alias_map[column_alias_pair[1]] = column_alias_pair[0]
386-
387-
new_selected_cols: typing.List[sge.Alias] = []
388-
for selected in new_expr:
389-
column_alias_pair = _get_column_alias_pair(selected)
390-
if column_alias_pair is None or column_alias_pair[0] not in old_alias_map:
391-
return []
392-
else:
393-
new_alias_expr = sge.Alias(
394-
this=sge.ColumnDef(
395-
this=sge.to_identifier(
396-
old_alias_map[column_alias_pair[0]], quoted=True
397-
)
398-
),
399-
alias=sg.to_identifier(column_alias_pair[1], quoted=True),
400-
)
401-
new_selected_cols.append(new_alias_expr)
402-
return new_selected_cols
403-
404-
405-
def _get_column_alias_pair(
406-
expr: sge.Expression,
407-
) -> typing.Optional[typing.Tuple[str, str]]:
408-
"""Checks if an expression is a simple alias of a column definition
409-
(e.g., "column_name AS alias_name").
410-
If it is, returns a tuple containing the alias name and original column name.
411-
Returns `None` otherwise.
412-
"""
413-
if not isinstance(expr, sge.Alias):
414-
return None
415-
if not isinstance(expr.this, sge.ColumnDef):
416-
return None
417-
418-
column_def_expr: sge.ColumnDef = expr.this
419-
if not isinstance(column_def_expr.this, sge.Identifier):
420-
return None
421-
422-
original_identifier: sge.Identifier = column_def_expr.this
423-
return (original_identifier.this, expr.alias)

bigframes/core/rewrite/pruning.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import dataclasses
1515
import functools
16-
from typing import AbstractSet
16+
import typing
1717

1818
from bigframes.core import identifiers, nodes
1919

@@ -143,7 +143,7 @@ def prune_selection_child(
143143

144144
def prune_node(
145145
node: nodes.BigFrameNode,
146-
ids: AbstractSet[identifiers.ColumnId],
146+
ids: typing.AbstractSet[identifiers.ColumnId],
147147
):
148148
# This clause is important, ensures idempotency, so can reach fixed point
149149
if not (set(node.ids) - ids):
@@ -157,7 +157,7 @@ def prune_node(
157157

158158
def prune_aggregate(
159159
node: nodes.AggregateNode,
160-
used_cols: AbstractSet[identifiers.ColumnId],
160+
used_cols: typing.AbstractSet[identifiers.ColumnId],
161161
) -> nodes.AggregateNode:
162162
pruned_aggs = (
163163
tuple(agg for agg in node.aggregations if agg[1] in used_cols)
@@ -169,15 +169,15 @@ def prune_aggregate(
169169
@functools.singledispatch
170170
def prune_leaf(
171171
node: nodes.BigFrameNode,
172-
used_cols: AbstractSet[identifiers.ColumnId],
172+
used_cols: typing.AbstractSet[identifiers.ColumnId],
173173
):
174174
...
175175

176176

177177
@prune_leaf.register
178178
def prune_readlocal(
179179
node: nodes.ReadLocalNode,
180-
selection: AbstractSet[identifiers.ColumnId],
180+
selection: typing.AbstractSet[identifiers.ColumnId],
181181
) -> nodes.ReadLocalNode:
182182
new_scan_list = node.scan_list.filter_cols(selection)
183183
return dataclasses.replace(
@@ -190,7 +190,7 @@ def prune_readlocal(
190190
@prune_leaf.register
191191
def prune_readtable(
192192
node: nodes.ReadTableNode,
193-
selection: AbstractSet[identifiers.ColumnId],
193+
selection: typing.AbstractSet[identifiers.ColumnId],
194194
) -> nodes.ReadTableNode:
195195
new_scan_list = node.scan_list.filter_cols(selection)
196196
return dataclasses.replace(node, scan_list=new_scan_list)

0 commit comments

Comments
 (0)