diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index d2b796b0aa..68b572f911 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -125,9 +125,25 @@ def _compile_result_node(self, root: nodes.ResultNode) -> str: (name, scalar_compiler.compile_scalar_expression(ref)) for ref, name in root.output_cols ) - sqlglot_ir = sqlglot_ir.select(selected_cols) + # Skip squashing selections to ensure the right ordering and limit keys + sqlglot_ir = sqlglot_ir.select(selected_cols, squash_selections=False) + + if root.order_by is not None: + ordering_cols = tuple( + sge.Ordered( + this=scalar_compiler.compile_scalar_expression( + ordering.scalar_expression + ), + desc=ordering.direction.is_ascending is False, + nulls_first=ordering.na_last is False, + ) + for ordering in root.order_by.all_ordering_columns + ) + sqlglot_ir = sqlglot_ir.order_by(ordering_cols) + + if root.limit is not None: + sqlglot_ir = sqlglot_ir.limit(root.limit) - # TODO: add order_by, limit to sqlglot_expr return sqlglot_ir.sql @functools.lru_cache(maxsize=5000) diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 43bdc6b06b..77ee0ccb78 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -28,7 +28,7 @@ from bigframes.core import guid import bigframes.core.compile.sqlglot.sqlglot_types as sgt import bigframes.core.local_data as local_data -import bigframes.core.schema as schemata +import bigframes.core.schema as bf_schema # shapely.wkt.dumps was moved to shapely.io.to_wkt in 2.0. 
try: @@ -67,7 +67,7 @@ def sql(self) -> str: def from_pyarrow( cls, pa_table: pa.Table, - schema: schemata.ArraySchema, + schema: bf_schema.ArraySchema, uid_gen: guid.SequentialUIDGenerator, ) -> SQLGlotIR: """Builds SQLGlot expression from pyarrow table.""" @@ -203,6 +203,7 @@ def from_union( def select( self, selected_cols: tuple[tuple[str, sge.Expression], ...], + squash_selections: bool = True, ) -> SQLGlotIR: selections = [ sge.Alias( @@ -211,15 +212,39 @@ def select( ) for id, expr in selected_cols ] - # Attempts to simplify selected columns when the original and new column - # names are simply aliases of each other. - squashed_selections = _squash_selections(self.expr.expressions, selections) - if squashed_selections != []: - new_expr = self.expr.select(*squashed_selections, append=False) - return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + + # If squashing is enabled, we try to simplify the selections + # by checking if the new selections are simply aliases of the + # original columns. 
+ if squash_selections: + new_selections = _squash_selections(self.expr.expressions, selections) + if new_selections != []: + new_expr = self.expr.select(*new_selections, append=False) + return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + + new_expr = self._encapsulate_as_cte().select(*selections, append=False) + return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + + def order_by( + self, + ordering: tuple[sge.Ordered, ...], + ) -> SQLGlotIR: + """Adds ORDER BY clause to the query.""" + if len(ordering) == 0: + return SQLGlotIR(expr=self.expr.copy(), uid_gen=self.uid_gen) + new_expr = self.expr.order_by(*ordering) + return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + + def limit( + self, + limit: int | None, + ) -> SQLGlotIR: + """Adds LIMIT clause to the query.""" + if limit is not None: + new_expr = self.expr.limit(limit) else: - new_expr = self._encapsulate_as_cte().select(*selections, append=False) - return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + new_expr = self.expr.copy() + return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) def project( self, @@ -342,6 +367,7 @@ def _squash_selections( old_expr: list[sge.Expression], new_expr: list[sge.Alias] ) -> list[sge.Alias]: """ + TODO: Re-enable this function to optimize the SQL. Simplifies the select column expressions if existing (old_expr) and new (new_expr) selected columns are both simple aliases of column definitions. 
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_concat/test_compile_concat/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_concat/test_compile_concat/out.sql index 4b6b2617ac..855e5874c2 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_concat/test_compile_concat/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_concat/test_compile_concat/out.sql @@ -104,4 +104,7 @@ SELECT `bfcol_47` AS `rowindex_1`, `bfcol_48` AS `int64_col`, `bfcol_49` AS `string_col` -FROM `bfcte_12` \ No newline at end of file +FROM `bfcte_12` +ORDER BY + `bfcol_50` ASC NULLS LAST, + `bfcol_51` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_projection/test_compile_projection/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_projection/test_compile_projection/out.sql index db470e3ba3..2804925b2d 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_projection/test_compile_projection/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_projection/test_compile_projection/out.sql @@ -15,11 +15,19 @@ WITH `bfcte_0` AS ( `bfcol_4` AS `bfcol_8`, `bfcol_1` + 1 AS `bfcol_9` FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + `bfcol_5` AS `bfcol_10`, + `bfcol_9` AS `bfcol_11`, + `bfcol_6` AS `bfcol_12`, + `bfcol_7` AS `bfcol_13`, + `bfcol_8` AS `bfcol_14` + FROM `bfcte_1` ) SELECT - `bfcol_5` AS `rowindex`, - `bfcol_9` AS `int64_col`, - `bfcol_6` AS `string_col`, - `bfcol_7` AS `float64_col`, - `bfcol_8` AS `bool_col` -FROM `bfcte_1` \ No newline at end of file + `bfcol_10` AS `rowindex`, + `bfcol_11` AS `int64_col`, + `bfcol_12` AS `string_col`, + `bfcol_13` AS `float64_col`, + `bfcol_14` AS `bool_col` +FROM `bfcte_2` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql 
b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql index a34f3526d6..89c51b346d 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql @@ -155,21 +155,42 @@ WITH `bfcte_0` AS ( CAST(NULL AS TIMESTAMP), 8 )]) +), `bfcte_1` AS ( + SELECT + `bfcol_0` AS `bfcol_16`, + `bfcol_1` AS `bfcol_17`, + `bfcol_2` AS `bfcol_18`, + `bfcol_3` AS `bfcol_19`, + `bfcol_4` AS `bfcol_20`, + `bfcol_5` AS `bfcol_21`, + `bfcol_6` AS `bfcol_22`, + `bfcol_7` AS `bfcol_23`, + `bfcol_8` AS `bfcol_24`, + `bfcol_9` AS `bfcol_25`, + `bfcol_10` AS `bfcol_26`, + `bfcol_11` AS `bfcol_27`, + `bfcol_12` AS `bfcol_28`, + `bfcol_13` AS `bfcol_29`, + `bfcol_14` AS `bfcol_30`, + `bfcol_15` AS `bfcol_31` + FROM `bfcte_0` ) SELECT - `bfcol_0` AS `rowindex`, - `bfcol_1` AS `bool_col`, - `bfcol_2` AS `bytes_col`, - `bfcol_3` AS `date_col`, - `bfcol_4` AS `datetime_col`, - `bfcol_5` AS `geography_col`, - `bfcol_6` AS `int64_col`, - `bfcol_7` AS `int64_too`, - `bfcol_8` AS `numeric_col`, - `bfcol_9` AS `float64_col`, - `bfcol_10` AS `rowindex_1`, - `bfcol_11` AS `rowindex_2`, - `bfcol_12` AS `string_col`, - `bfcol_13` AS `time_col`, - `bfcol_14` AS `timestamp_col` -FROM `bfcte_0` \ No newline at end of file + `bfcol_16` AS `rowindex`, + `bfcol_17` AS `bool_col`, + `bfcol_18` AS `bytes_col`, + `bfcol_19` AS `date_col`, + `bfcol_20` AS `datetime_col`, + `bfcol_21` AS `geography_col`, + `bfcol_22` AS `int64_col`, + `bfcol_23` AS `int64_too`, + `bfcol_24` AS `numeric_col`, + `bfcol_25` AS `float64_col`, + `bfcol_26` AS `rowindex_1`, + `bfcol_27` AS `rowindex_2`, + `bfcol_28` AS `string_col`, + `bfcol_29` AS `time_col`, + `bfcol_30` AS `timestamp_col` +FROM `bfcte_1` +ORDER BY + `bfcol_31` ASC NULLS LAST \ No newline at end of file diff --git 
a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_json_df/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_json_df/out.sql index 31b46e6c70..76cbde7c64 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_json_df/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_json_df/out.sql @@ -2,7 +2,14 @@ WITH `bfcte_0` AS ( SELECT * FROM UNNEST(ARRAY<STRUCT<`bfcol_0` JSON, `bfcol_1` INT64>>[STRUCT(PARSE_JSON('null'), 0), STRUCT(PARSE_JSON('true'), 1), STRUCT(PARSE_JSON('100'), 2), STRUCT(PARSE_JSON('0.98'), 3), STRUCT(PARSE_JSON('"a string"'), 4), STRUCT(PARSE_JSON('[]'), 5), STRUCT(PARSE_JSON('[1,2,3]'), 6), STRUCT(PARSE_JSON('[{"a":1},{"a":2},{"a":null},{}]'), 7), STRUCT(PARSE_JSON('"100"'), 8), STRUCT(PARSE_JSON('{"date":"2024-07-16"}'), 9), STRUCT(PARSE_JSON('{"int_value":2,"null_filed":null}'), 10), STRUCT(PARSE_JSON('{"list_data":[10,20,30]}'), 11)]) +), `bfcte_1` AS ( + SELECT + `bfcol_0` AS `bfcol_2`, + `bfcol_1` AS `bfcol_3` + FROM `bfcte_0` ) SELECT - `bfcol_0` AS `json_col` -FROM `bfcte_0` \ No newline at end of file + `bfcol_2` AS `json_col` +FROM `bfcte_1` +ORDER BY + `bfcol_3` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_lists_df/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_lists_df/out.sql index 1ba602f205..6363739d9d 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_lists_df/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_lists_df/out.sql @@ -32,14 +32,28 @@ WITH `bfcte_0` AS ( ['', 'a'], 2 )]) +), `bfcte_1` AS ( + SELECT + `bfcol_0` AS `bfcol_9`, + `bfcol_1` AS `bfcol_10`, + `bfcol_2` AS `bfcol_11`, + `bfcol_3` AS `bfcol_12`, + `bfcol_4` AS 
`bfcol_13`, + `bfcol_5` AS `bfcol_14`, + `bfcol_6` AS `bfcol_15`, + `bfcol_7` AS `bfcol_16`, + `bfcol_8` AS `bfcol_17` + FROM `bfcte_0` ) SELECT - `bfcol_0` AS `rowindex`, - `bfcol_1` AS `int_list_col`, - `bfcol_2` AS `bool_list_col`, - `bfcol_3` AS `float_list_col`, - `bfcol_4` AS `date_list_col`, - `bfcol_5` AS `date_time_list_col`, - `bfcol_6` AS `numeric_list_col`, - `bfcol_7` AS `string_list_col` -FROM `bfcte_0` \ No newline at end of file + `bfcol_9` AS `rowindex`, + `bfcol_10` AS `int_list_col`, + `bfcol_11` AS `bool_list_col`, + `bfcol_12` AS `float_list_col`, + `bfcol_13` AS `date_list_col`, + `bfcol_14` AS `date_time_list_col`, + `bfcol_15` AS `numeric_list_col`, + `bfcol_16` AS `string_list_col` +FROM `bfcte_1` +ORDER BY + `bfcol_17` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_structs_df/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_structs_df/out.sql index 54d1a1bb2b..af7206b759 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_structs_df/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal_w_structs_df/out.sql @@ -18,8 +18,16 @@ WITH `bfcte_0` AS ( ), 1 )]) +), `bfcte_1` AS ( + SELECT + `bfcol_0` AS `bfcol_3`, + `bfcol_1` AS `bfcol_4`, + `bfcol_2` AS `bfcol_5` + FROM `bfcte_0` ) SELECT - `bfcol_0` AS `id`, - `bfcol_1` AS `person` -FROM `bfcte_0` \ No newline at end of file + `bfcol_3` AS `id`, + `bfcol_4` AS `person` +FROM `bfcte_1` +ORDER BY + `bfcol_5` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readtable/test_compile_readtable_w_limit/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readtable/test_compile_readtable_w_limit/out.sql new file mode 100644 index 0000000000..837b805ca4 --- /dev/null +++ 
b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readtable/test_compile_readtable_w_limit/out.sql @@ -0,0 +1,24 @@ +WITH `bfcte_0` AS ( + SELECT + `rowindex` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `string_col` AS `bfcol_2`, + `float64_col` AS `bfcol_3`, + `bool_col` AS `bfcol_4` + FROM `test-project`.`test_dataset`.`test_table` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_1` AS `bfcol_5` + FROM `bfcte_0` +) +SELECT + `bfcol_0` AS `rowindex`, + `bfcol_1` AS `int64_col`, + `bfcol_2` AS `string_col`, + `bfcol_3` AS `float64_col`, + `bfcol_4` AS `bool_col` +FROM `bfcte_1` +ORDER BY + `bfcol_5` ASC NULLS LAST +LIMIT 10 \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readtable/test_compile_readtable_w_ordering/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readtable/test_compile_readtable_w_ordering/out.sql new file mode 100644 index 0000000000..9376691572 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readtable/test_compile_readtable_w_ordering/out.sql @@ -0,0 +1,40 @@ +WITH `bfcte_0` AS ( + SELECT + `rowindex` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `string_col` AS `bfcol_2`, + `float64_col` AS `bfcol_3`, + `bool_col` AS `bfcol_4` + FROM `test-project`.`test_dataset`.`test_table` +), `bfcte_1` AS ( + SELECT + `bfcol_0` AS `bfcol_5`, + `bfcol_1` AS `bfcol_6`, + `bfcol_2` AS `bfcol_7`, + `bfcol_3` AS `bfcol_8`, + `bfcol_4` AS `bfcol_9` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + *, + `bfcol_5` AS `bfcol_10` + FROM `bfcte_1` +), `bfcte_3` AS ( + SELECT + `bfcol_5` AS `bfcol_11`, + `bfcol_6` AS `bfcol_12`, + `bfcol_7` AS `bfcol_13`, + `bfcol_8` AS `bfcol_14`, + `bfcol_9` AS `bfcol_15`, + `bfcol_10` AS `bfcol_16` + FROM `bfcte_2` +) +SELECT + `bfcol_11` AS `rowindex`, + `bfcol_12` AS `int64_col`, + `bfcol_13` AS `string_col`, + `bfcol_14` AS `float64_col`, + `bfcol_15` AS `bool_col` +FROM `bfcte_3` +ORDER BY + `bfcol_16` ASC NULLS LAST \ No newline at end of file 
diff --git a/tests/unit/core/compile/sqlglot/test_compile_readtable.py b/tests/unit/core/compile/sqlglot/test_compile_readtable.py index 848ace58f3..41e01e9b25 100644 --- a/tests/unit/core/compile/sqlglot/test_compile_readtable.py +++ b/tests/unit/core/compile/sqlglot/test_compile_readtable.py @@ -22,3 +22,15 @@ def test_compile_readtable(compiler_session: bigframes.Session, snapshot): bf_df = compiler_session.read_gbq_table("test-project.test_dataset.test_table") snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_compile_readtable_w_ordering(compiler_session: bigframes.Session, snapshot): + bf_df = compiler_session.read_gbq_table("test-project.test_dataset.test_table") + bf_df = bf_df.set_index("rowindex").sort_index() + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_compile_readtable_w_limit(compiler_session: bigframes.Session, snapshot): + bf_df = compiler_session.read_gbq_table("test-project.test_dataset.test_table") + bf_df = bf_df.sort_values("int64_col").head(10) + snapshot.assert_match(bf_df.sql, "out.sql")