Skip to content

Commit 4819a3d

Browse files
perf: Produce simpler sql
1 parent 37666e4 commit 4819a3d

File tree

6 files changed

+140
-3
lines changed

6 files changed

+140
-3
lines changed

bigframes/core/compile/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
6565
ordering: Optional[bf_ordering.RowOrdering] = result_node.order_by
6666
result_node = dataclasses.replace(result_node, order_by=None)
6767
result_node = cast(nodes.ResultNode, rewrites.column_pruning(result_node))
68+
result_node = cast(nodes.ResultNode, rewrites.defer_selection(result_node))
6869
sql = compile_result_node(result_node)
6970
# Return the ordering iff no extra columns are needed to define the row order
7071
if ordering is not None:

bigframes/core/compile/googlesql/query.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,13 @@ def _select_field(self, field) -> SelectExpression:
8383
return SelectExpression(expression=expr.ColumnExpression(name=field))
8484

8585
else:
86-
alias = field[1] if (field[0] != field[1]) else None
86+
alias = (
87+
expr.AliasExpression(field[1])
88+
if isinstance(field[1], str)
89+
else field[1]
90+
if (field[0] != field[1])
91+
else None
92+
)
8793
return SelectExpression(
8894
expression=expr.ColumnExpression(name=field[0]), alias=alias
8995
)
@@ -119,7 +125,7 @@ def sql(self) -> str:
119125
return "\n".join(text)
120126

121127

122-
@dataclasses.dataclass
128+
@dataclasses.dataclass(frozen=True)
123129
class SelectExpression(abc.SQLSyntax):
124130
"""This class represents `select_expression`."""
125131

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
8888
)
8989
result_node = self._remap_variables(result_node)
9090
sql = self._compile_result_node(result_node)
91+
result_node = typing.cast(
92+
nodes.ResultNode, rewrite.defer_selection(result_node)
93+
)
9194
return configs.CompileResult(
9295
sql, result_node.schema.to_bigquery(), result_node.order_by
9396
)
@@ -97,6 +100,9 @@ def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult
97100
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
98101

99102
result_node = self._remap_variables(result_node)
103+
result_node = typing.cast(
104+
nodes.ResultNode, rewrite.defer_selection(result_node)
105+
)
100106
sql = self._compile_result_node(result_node)
101107
# Return the ordering iff no extra columns are needed to define the row order
102108
if ordering is not None:

bigframes/core/nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def additive_base(self) -> BigFrameNode:
7575
...
7676

7777
@abc.abstractmethod
78-
def replace_additive_base(self, BigFrameNode):
78+
def replace_additive_base(self, BigFrameNode) -> BigFrameNode:
7979
...
8080

8181

bigframes/core/rewrite/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
try_reduce_to_local_scan,
2323
try_reduce_to_table_scan,
2424
)
25+
from bigframes.core.rewrite.select_pullup import defer_selection
2526
from bigframes.core.rewrite.slices import pull_out_limit, pull_up_limits, rewrite_slice
2627
from bigframes.core.rewrite.timedeltas import rewrite_timedelta_expressions
2728
from bigframes.core.rewrite.windows import pull_out_window_order, rewrite_range_rolling
@@ -42,4 +43,5 @@
4243
"try_reduce_to_local_scan",
4344
"fold_row_counts",
4445
"pull_out_window_order",
46+
"defer_selection",
4547
]
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import dataclasses
16+
from typing import cast
17+
18+
from bigframes.core import expression, nodes
19+
20+
21+
def defer_selection(
22+
root: nodes.BigFrameNode,
23+
) -> nodes.BigFrameNode:
24+
return nodes.bottom_up(root, pull_up_select)
25+
26+
27+
def pull_up_select(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
28+
if isinstance(node, nodes.LeafNode):
29+
return node
30+
if isinstance(node, nodes.JoinNode):
31+
return pull_up_selects_under_join(node)
32+
if isinstance(node, nodes.ConcatNode):
33+
return handle_selects_under_concat(node)
34+
if isinstance(node, nodes.UnaryNode):
35+
return pull_up_select_unary(node)
36+
# shouldn't hit this, but not worth crashing over
37+
return node
38+
39+
40+
def pull_up_select_unary(node: nodes.UnaryNode) -> nodes.BigFrameNode:
41+
child = node.child
42+
if not isinstance(child, nodes.SelectionNode):
43+
return node
44+
45+
# Schema-preserving nodes
46+
if isinstance(
47+
node,
48+
(
49+
nodes.ReversedNode,
50+
nodes.OrderByNode,
51+
nodes.SliceNode,
52+
nodes.FilterNode,
53+
nodes.RandomSampleNode,
54+
),
55+
):
56+
pushed_down_node: nodes.BigFrameNode = node.remap_refs(
57+
{id: ref.id for ref, id in child.input_output_pairs}
58+
).replace_child(child.child)
59+
pulled_up_select = cast(
60+
nodes.SelectionNode, child.replace_child(pushed_down_node)
61+
)
62+
return pulled_up_select
63+
elif isinstance(node, (nodes.SelectionNode, nodes.ResultNode, nodes.AggregateNode)):
64+
return node.remap_refs(
65+
{id: ref.id for ref, id in child.input_output_pairs}
66+
).replace_child(child.child)
67+
elif isinstance(node, nodes.ExplodeNode):
68+
pushed_down_node = node.remap_refs(
69+
{id: ref.id for ref, id in child.input_output_pairs}
70+
).replace_child(child.child)
71+
pulled_up_select = cast(
72+
nodes.SelectionNode, child.replace_child(pushed_down_node)
73+
)
74+
if node.offsets_col:
75+
pulled_up_select = dataclasses.replace(
76+
pulled_up_select,
77+
input_output_pairs=(
78+
*pulled_up_select.input_output_pairs,
79+
nodes.AliasedRef(
80+
expression.DerefOp(node.offsets_col), node.offsets_col
81+
),
82+
),
83+
)
84+
return pulled_up_select
85+
elif isinstance(node, nodes.AdditiveNode):
86+
pushed_down_node = node.replace_additive_base(child.child).remap_refs(
87+
{id: ref.id for ref, id in child.input_output_pairs}
88+
)
89+
new_selection = (
90+
*child.input_output_pairs,
91+
*(
92+
nodes.AliasedRef(expression.DerefOp(col.id), col.id)
93+
for col in node.added_fields
94+
),
95+
)
96+
pulled_up_select = dataclasses.replace(
97+
child, child=pushed_down_node, input_output_pairs=new_selection
98+
)
99+
return pulled_up_select
100+
# shouldn't hit this, but not worth crashing over
101+
return node
102+
103+
104+
def pull_up_selects_under_join(node: nodes.JoinNode) -> nodes.JoinNode:
105+
# Can in theory pull up selects here, but it is a bit dangerous, in particular or self-joins, when there are more transforms to do.
106+
# TODO: Safely pull up selects above join
107+
return node
108+
109+
110+
def handle_selects_under_concat(node: nodes.ConcatNode) -> nodes.ConcatNode:
111+
new_children = []
112+
for child in node.child_nodes:
113+
# remove select if no-op
114+
if not isinstance(child, nodes.SelectionNode):
115+
new_children.append(child)
116+
else:
117+
inputs = (ref.id for ref in child.input_output_pairs)
118+
if inputs == tuple(child.child.ids):
119+
new_children.append(child.child)
120+
else:
121+
new_children.append(child)
122+
return dataclasses.replace(node, children=tuple(new_children))

0 commit comments

Comments
 (0)