diff --git a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py index a6eb7182e9..c46019d909 100644 --- a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py @@ -14,6 +14,7 @@ from __future__ import annotations +import bigframes_vendored.constants as constants import sqlglot.expressions as sge from bigframes import dtypes @@ -35,8 +36,83 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: # String addition return sge.Concat(expressions=[left.expr, right.expr]) - # Numerical addition - return sge.Add(this=left.expr, expression=right.expr) + if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): + left_expr = left.expr + if left.dtype == dtypes.BOOL_DTYPE: + left_expr = sge.Cast(this=left_expr, to="INT64") + right_expr = right.expr + if right.dtype == dtypes.BOOL_DTYPE: + right_expr = sge.Cast(this=right_expr, to="INT64") + return sge.Add(this=left_expr, expression=right_expr) + + if ( + dtypes.is_time_or_date_like(left.dtype) + and right.dtype == dtypes.TIMEDELTA_DTYPE + ): + left_expr = left.expr + if left.dtype == dtypes.DATE_DTYPE: + left_expr = sge.Cast(this=left_expr, to="DATETIME") + return sge.TimestampAdd( + this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND") + ) + if ( + dtypes.is_time_or_date_like(right.dtype) + and left.dtype == dtypes.TIMEDELTA_DTYPE + ): + right_expr = right.expr + if right.dtype == dtypes.DATE_DTYPE: + right_expr = sge.Cast(this=right_expr, to="DATETIME") + return sge.TimestampAdd( + this=right_expr, expression=left.expr, unit=sge.Var(this="MICROSECOND") + ) + if left.dtype == dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE: + return sge.Add(this=left.expr, expression=right.expr) + + raise TypeError( + f"Cannot add type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}" + ) + + +@BINARY_OP_REGISTRATION.register(ops.sub_op) +def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: + if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): + left_expr = left.expr + if left.dtype == dtypes.BOOL_DTYPE: + left_expr = sge.Cast(this=left_expr, to="INT64") + right_expr = right.expr + if right.dtype == dtypes.BOOL_DTYPE: + right_expr = sge.Cast(this=right_expr, to="INT64") + return sge.Sub(this=left_expr, expression=right_expr) + + if ( + dtypes.is_time_or_date_like(left.dtype) + and right.dtype == dtypes.TIMEDELTA_DTYPE + ): + left_expr = left.expr + if left.dtype == dtypes.DATE_DTYPE: + left_expr = sge.Cast(this=left_expr, to="DATETIME") + return sge.TimestampSub( + this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND") + ) + if dtypes.is_time_or_date_like(left.dtype) and dtypes.is_time_or_date_like( + right.dtype + ): + left_expr = left.expr + if left.dtype == dtypes.DATE_DTYPE: + left_expr = sge.Cast(this=left_expr, to="DATETIME") + right_expr = right.expr + if right.dtype == dtypes.DATE_DTYPE: + right_expr = sge.Cast(this=right_expr, to="DATETIME") + return sge.TimestampDiff( + this=left_expr, expression=right_expr, unit=sge.Var(this="MICROSECOND") + ) + + if left.dtype == dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE: + return sge.Sub(this=left.expr, expression=right.expr) + + raise TypeError( + f"Cannot subtract type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}" + ) @BINARY_OP_REGISTRATION.register(ops.ge_op) diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py index a58619dc21..ef1b9e7871 100644 --- a/bigframes/dtypes.py +++ b/bigframes/dtypes.py @@ -289,6 +289,10 @@ def is_time_like(type_: ExpressionType) -> bool: return type_ in (DATETIME_DTYPE, TIMESTAMP_DTYPE, TIME_DTYPE) +def is_time_or_date_like(type_: ExpressionType) -> bool: + return type_ in (DATE_DTYPE, DATETIME_DTYPE, TIME_DTYPE, TIMESTAMP_DTYPE) + + def is_geo_like(type_: ExpressionType) -> bool: return type_ in (GEO_DTYPE,) diff --git a/tests/system/small/engines/test_numeric_ops.py b/tests/system/small/engines/test_numeric_ops.py index b53da977f5..7e5b85857b 100644 --- a/tests/system/small/engines/test_numeric_ops.py +++ b/tests/system/small/engines/test_numeric_ops.py @@ -53,7 +53,7 @@ def apply_op_pairwise( return new_arr -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_project_add( scalars_array_value: array_value.ArrayValue, engine, @@ -62,7 +62,7 @@ def test_engines_project_add( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_project_sub( scalars_array_value: array_value.ArrayValue, engine, diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_numeric/out.sql index e8dc2edb80..44335805e4 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_numeric/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_numeric/out.sql @@ -1,13 +1,54 @@ WITH `bfcte_0` AS ( SELECT - `int64_col` AS `bfcol_0` + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `rowindex` AS `bfcol_2` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT *, - `bfcol_0` + `bfcol_0` AS `bfcol_1` + `bfcol_2` AS `bfcol_6`, + `bfcol_1` AS `bfcol_7`, + `bfcol_0` AS `bfcol_8`, + `bfcol_1` + `bfcol_1` AS `bfcol_9` FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + *, + `bfcol_6` AS `bfcol_14`, + `bfcol_7` AS `bfcol_15`, + `bfcol_8` AS `bfcol_16`, + `bfcol_9` AS `bfcol_17`, + `bfcol_7` + 1 AS `bfcol_18` + FROM `bfcte_1` +), `bfcte_3` AS ( + SELECT + *, + `bfcol_14` AS `bfcol_24`, + `bfcol_15` AS `bfcol_25`, + `bfcol_16` AS `bfcol_26`, + `bfcol_17` AS `bfcol_27`, + `bfcol_18` AS `bfcol_28`, + `bfcol_15` + CAST(`bfcol_16` AS INT64) AS `bfcol_29` + FROM `bfcte_2` +), `bfcte_4` AS ( + SELECT + *, + `bfcol_24` AS `bfcol_36`, + `bfcol_25` AS `bfcol_37`, + `bfcol_26` AS `bfcol_38`, + `bfcol_27` AS `bfcol_39`, + `bfcol_28` AS `bfcol_40`, + `bfcol_29` AS `bfcol_41`, + CAST(`bfcol_26` AS INT64) + `bfcol_25` AS `bfcol_42` + FROM `bfcte_3` ) SELECT - `bfcol_1` AS `int64_col` -FROM `bfcte_1` \ No newline at end of file + `bfcol_36` AS `rowindex`, + `bfcol_37` AS `int64_col`, + `bfcol_38` AS `bool_col`, + `bfcol_39` AS `int_add_int`, + `bfcol_40` AS `int_add_1`, + `bfcol_41` AS `int_add_bool`, + `bfcol_42` AS `bool_add_int` +FROM `bfcte_4` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_numeric_w_scalar/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_numeric_w_scalar/out.sql deleted file mode 100644 index 7c4cc2c770..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_numeric_w_scalar/out.sql +++ /dev/null @@ -1,13 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `int64_col` AS `bfcol_0` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - `bfcol_0` + 1 AS `bfcol_1` - FROM `bfcte_0` -) -SELECT - `bfcol_1` AS `int64_col` -FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_timedelta/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_timedelta/out.sql new file mode 100644 index 0000000000..a47531999b --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_timedelta/out.sql @@ -0,0 +1,60 @@ +WITH `bfcte_0` AS ( + SELECT + `date_col` AS `bfcol_0`, + `rowindex` AS `bfcol_1`, + `timestamp_col` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_1` AS `bfcol_6`, + `bfcol_2` AS `bfcol_7`, + `bfcol_0` AS `bfcol_8`, + TIMESTAMP_ADD(CAST(`bfcol_0` AS DATETIME), INTERVAL 86400000000 MICROSECOND) AS `bfcol_9` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + *, + `bfcol_6` AS `bfcol_14`, + `bfcol_7` AS `bfcol_15`, + `bfcol_8` AS `bfcol_16`, + `bfcol_9` AS `bfcol_17`, + TIMESTAMP_ADD(`bfcol_7`, INTERVAL 86400000000 MICROSECOND) AS `bfcol_18` + FROM `bfcte_1` +), `bfcte_3` AS ( + SELECT + *, + `bfcol_14` AS `bfcol_24`, + `bfcol_15` AS `bfcol_25`, + `bfcol_16` AS `bfcol_26`, + `bfcol_17` AS `bfcol_27`, + `bfcol_18` AS `bfcol_28`, + TIMESTAMP_ADD(CAST(`bfcol_16` AS DATETIME), INTERVAL 86400000000 MICROSECOND) AS `bfcol_29` + FROM `bfcte_2` +), `bfcte_4` AS ( + SELECT + *, + `bfcol_24` AS `bfcol_36`, + `bfcol_25` AS `bfcol_37`, + `bfcol_26` AS `bfcol_38`, + `bfcol_27` AS `bfcol_39`, + `bfcol_28` AS `bfcol_40`, + `bfcol_29` AS `bfcol_41`, + TIMESTAMP_ADD(`bfcol_25`, INTERVAL 86400000000 MICROSECOND) AS `bfcol_42` + FROM `bfcte_3` +), `bfcte_5` AS ( + SELECT + *, + 172800000000 AS `bfcol_50` + FROM `bfcte_4` +) +SELECT + `bfcol_36` AS `rowindex`, + `bfcol_37` AS `timestamp_col`, + `bfcol_38` AS `date_col`, + `bfcol_39` AS `date_add_timedelta`, + `bfcol_40` AS `timestamp_add_timedelta`, + `bfcol_41` AS `timedelta_add_date`, + `bfcol_42` AS `timedelta_add_timestamp`, + `bfcol_50` AS `timedelta_add_timedelta` +FROM `bfcte_5` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_numeric/out.sql new file mode 100644 index 0000000000..a43fa2df67 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_numeric/out.sql @@ -0,0 +1,54 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `rowindex` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_2` AS `bfcol_6`, + `bfcol_1` AS `bfcol_7`, + `bfcol_0` AS `bfcol_8`, + `bfcol_1` - `bfcol_1` AS `bfcol_9` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + *, + `bfcol_6` AS `bfcol_14`, + `bfcol_7` AS `bfcol_15`, + `bfcol_8` AS `bfcol_16`, + `bfcol_9` AS `bfcol_17`, + `bfcol_7` - 1 AS `bfcol_18` + FROM `bfcte_1` +), `bfcte_3` AS ( + SELECT + *, + `bfcol_14` AS `bfcol_24`, + `bfcol_15` AS `bfcol_25`, + `bfcol_16` AS `bfcol_26`, + `bfcol_17` AS `bfcol_27`, + `bfcol_18` AS `bfcol_28`, + `bfcol_15` - CAST(`bfcol_16` AS INT64) AS `bfcol_29` + FROM `bfcte_2` +), `bfcte_4` AS ( + SELECT + *, + `bfcol_24` AS `bfcol_36`, + `bfcol_25` AS `bfcol_37`, + `bfcol_26` AS `bfcol_38`, + `bfcol_27` AS `bfcol_39`, + `bfcol_28` AS `bfcol_40`, + `bfcol_29` AS `bfcol_41`, + CAST(`bfcol_26` AS INT64) - `bfcol_25` AS `bfcol_42` + FROM `bfcte_3` +) +SELECT + `bfcol_36` AS `rowindex`, + `bfcol_37` AS `int64_col`, + `bfcol_38` AS `bool_col`, + `bfcol_39` AS `int_add_int`, + `bfcol_40` AS `int_add_1`, + `bfcol_41` AS `int_add_bool`, + `bfcol_42` AS `bool_add_int` +FROM `bfcte_4` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_timedelta/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_timedelta/out.sql new file mode 100644 index 0000000000..41e45d3333 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_timedelta/out.sql @@ -0,0 +1,60 @@ +WITH `bfcte_0` AS ( + SELECT + `date_col` AS `bfcol_0`, + `rowindex` AS `bfcol_1`, + `timestamp_col` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_1` AS `bfcol_6`, + `bfcol_2` AS `bfcol_7`, + `bfcol_0` AS `bfcol_8`, + TIMESTAMP_SUB(CAST(`bfcol_0` AS DATETIME), INTERVAL 86400000000 MICROSECOND) AS `bfcol_9` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + *, + `bfcol_6` AS `bfcol_14`, + `bfcol_7` AS `bfcol_15`, + `bfcol_8` AS `bfcol_16`, + `bfcol_9` AS `bfcol_17`, + TIMESTAMP_SUB(`bfcol_7`, INTERVAL 86400000000 MICROSECOND) AS `bfcol_18` + FROM `bfcte_1` +), `bfcte_3` AS ( + SELECT + *, + `bfcol_14` AS `bfcol_24`, + `bfcol_15` AS `bfcol_25`, + `bfcol_16` AS `bfcol_26`, + `bfcol_17` AS `bfcol_27`, + `bfcol_18` AS `bfcol_28`, + TIMESTAMP_DIFF(CAST(`bfcol_16` AS DATETIME), CAST(`bfcol_16` AS DATETIME), MICROSECOND) AS `bfcol_29` + FROM `bfcte_2` +), `bfcte_4` AS ( + SELECT + *, + `bfcol_24` AS `bfcol_36`, + `bfcol_25` AS `bfcol_37`, + `bfcol_26` AS `bfcol_38`, + `bfcol_27` AS `bfcol_39`, + `bfcol_28` AS `bfcol_40`, + `bfcol_29` AS `bfcol_41`, + TIMESTAMP_DIFF(`bfcol_25`, `bfcol_25`, MICROSECOND) AS `bfcol_42` + FROM `bfcte_3` +), `bfcte_5` AS ( + SELECT + *, + 0 AS `bfcol_50` + FROM `bfcte_4` +) +SELECT + `bfcol_36` AS `rowindex`, + `bfcol_37` AS `timestamp_col`, + `bfcol_38` AS `date_col`, + `bfcol_39` AS `date_sub_timedelta`, + `bfcol_40` AS `timestamp_sub_timedelta`, + `bfcol_41` AS `timestamp_sub_date`, + `bfcol_42` AS `date_sub_timestamp`, + `bfcol_50` AS `timedelta_sub_timedelta` +FROM `bfcte_5` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py index a78a41fdbf..05d9c26945 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py @@ -14,6 +14,7 @@ import typing +import pandas as pd import pytest from bigframes import operations as ops @@ -42,17 +43,15 @@ def _apply_binary_op( def test_add_numeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col"]] - sql = _apply_binary_op(bf_df, ops.add_op, "int64_col", "int64_col") - - snapshot.assert_match(sql, "out.sql") + bf_df = scalar_types_df[["int64_col", "bool_col"]] + bf_df["int_add_int"] = bf_df["int64_col"] + bf_df["int64_col"] + bf_df["int_add_1"] = bf_df["int64_col"] + 1 -def test_add_numeric_w_scalar(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col"]] - sql = _apply_binary_op(bf_df, ops.add_op, "int64_col", ex.const(1)) + bf_df["int_add_bool"] = bf_df["int64_col"] + bf_df["bool_col"] + bf_df["bool_add_int"] = bf_df["bool_col"] + bf_df["int64_col"] - snapshot.assert_match(sql, "out.sql") + snapshot.assert_match(bf_df.sql, "out.sql") def test_add_string(scalar_types_df: bpd.DataFrame, snapshot): @@ -62,6 +61,27 @@ def test_add_string(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_add_timedelta(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["timestamp_col", "date_col"]] + timedelta = pd.Timedelta(1, unit="d") + + bf_df["date_add_timedelta"] = bf_df["date_col"] + timedelta + bf_df["timestamp_add_timedelta"] = bf_df["timestamp_col"] + timedelta + bf_df["timedelta_add_date"] = timedelta + bf_df["date_col"] + bf_df["timedelta_add_timestamp"] = timedelta + bf_df["timestamp_col"] + bf_df["timedelta_add_timedelta"] = timedelta + timedelta + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_add_unsupported_raises(scalar_types_df: bpd.DataFrame): + with pytest.raises(TypeError): + _apply_binary_op(scalar_types_df, ops.add_op, "timestamp_col", "date_col") + + with pytest.raises(TypeError): + _apply_binary_op(scalar_types_df, ops.add_op, "int64_col", "string_col") + + def test_json_set(json_types_df: bpd.DataFrame, snapshot): bf_df = json_types_df[["json_col"]] sql = _apply_binary_op( @@ -69,3 +89,36 @@ def test_json_set(json_types_df: bpd.DataFrame, snapshot): ) snapshot.assert_match(sql, "out.sql") + + +def test_sub_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_add_int"] = bf_df["int64_col"] - bf_df["int64_col"] + bf_df["int_add_1"] = bf_df["int64_col"] - 1 + + bf_df["int_add_bool"] = bf_df["int64_col"] - bf_df["bool_col"] + bf_df["bool_add_int"] = bf_df["bool_col"] - bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_sub_timedelta(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["timestamp_col", "date_col"]] + timedelta = pd.Timedelta(1, unit="d") + + bf_df["date_sub_timedelta"] = bf_df["date_col"] - timedelta + bf_df["timestamp_sub_timedelta"] = bf_df["timestamp_col"] - timedelta + bf_df["timestamp_sub_date"] = bf_df["date_col"] - bf_df["date_col"] + bf_df["date_sub_timestamp"] = bf_df["timestamp_col"] - bf_df["timestamp_col"] + bf_df["timedelta_sub_timedelta"] = timedelta - timedelta + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_sub_unsupported_raises(scalar_types_df: bpd.DataFrame): + with pytest.raises(TypeError): + _apply_binary_op(scalar_types_df, ops.sub_op, "string_col", "string_col") + + with pytest.raises(TypeError): + _apply_binary_op(scalar_types_df, ops.sub_op, "int64_col", "string_col")