diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py
index e1531ee9e5..4206beab2f 100644
--- a/bigframes/core/compile/polars/compiler.py
+++ b/bigframes/core/compile/polars/compiler.py
@@ -33,7 +33,9 @@
 import bigframes.operations.aggregations as agg_ops
 import bigframes.operations.bool_ops as bool_ops
 import bigframes.operations.comparison_ops as comp_ops
+import bigframes.operations.datetime_ops as dt_ops
 import bigframes.operations.generic_ops as gen_ops
+import bigframes.operations.json_ops as json_ops
 import bigframes.operations.numeric_ops as num_ops
 import bigframes.operations.string_ops as string_ops
@@ -280,6 +282,30 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
         assert isinstance(op, string_ops.StrConcatOp)
         return pl.concat_str(l_input, r_input)
 
+    @compile_op.register(dt_ops.StrftimeOp)
+    def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+        assert isinstance(op, dt_ops.StrftimeOp)
+        return input.dt.strftime(op.date_format)
+
+    @compile_op.register(dt_ops.ParseDatetimeOp)
+    def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+        assert isinstance(op, dt_ops.ParseDatetimeOp)
+        return input.str.to_datetime(
+            time_unit="us", time_zone=None, ambiguous="earliest"
+        )
+
+    @compile_op.register(dt_ops.ParseTimestampOp)
+    def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+        assert isinstance(op, dt_ops.ParseTimestampOp)
+        return input.str.to_datetime(
+            time_unit="us", time_zone="UTC", ambiguous="earliest"
+        )
+
+    @compile_op.register(json_ops.JSONDecode)
+    def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
+        assert isinstance(op, json_ops.JSONDecode)
+        return input.str.json_decode(_DTYPE_MAPPING[op.to_type])
+
 
 @dataclasses.dataclass(frozen=True)
 class PolarsAggregateCompiler:
     scalar_compiler = PolarsExpressionCompiler()
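Note: the four registrations above lean on stock polars expression APIs (dt.strftime, str.to_datetime, str.json_decode, the same calls the patch itself uses). A minimal sketch of what they emit, assuming a recent polars and an illustrative string column "s":

    import polars as pl

    df = pl.DataFrame({"s": ["2014-08-15 08:15:12"]})
    # ParseDatetimeOp: naive, microsecond-precision timestamps
    df.select(
        pl.col("s").str.to_datetime(time_unit="us", time_zone=None, ambiguous="earliest")
    )
    # StrftimeOp: format a datetime back into a string
    df.select(pl.col("s").str.to_datetime(time_unit="us").dt.strftime("%Y-%m-%d %H:%M:%S"))
    # JSONDecode: parse a JSON-encoded string into a concrete dtype
    pl.Series(["5"]).str.json_decode(pl.Int64)
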
diff --git a/bigframes/core/compile/polars/lowering.py b/bigframes/core/compile/polars/lowering.py
index ee0933b450..013651ff17 100644
--- a/bigframes/core/compile/polars/lowering.py
+++ b/bigframes/core/compile/polars/lowering.py
@@ -17,7 +17,7 @@
 from bigframes import dtypes
 from bigframes.core import bigframe_node, expression
 from bigframes.core.rewrite import op_lowering
-from bigframes.operations import comparison_ops, numeric_ops
+from bigframes.operations import comparison_ops, datetime_ops, json_ops, numeric_ops
 import bigframes.operations as ops
 
 # TODO: Would be more precise to actually have separate op set for polars ops (where they
@@ -278,6 +278,16 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
         return wo_bools
 
 
+class LowerAsTypeRule(op_lowering.OpLoweringRule):
+    @property
+    def op(self) -> type[ops.ScalarOp]:
+        return ops.AsTypeOp
+
+    def lower(self, expr: expression.OpExpression) -> expression.Expression:
+        assert isinstance(expr.op, ops.AsTypeOp)
+        return _lower_cast(expr.op, expr.inputs[0])
+
+
 def _coerce_comparables(
     expr1: expression.Expression,
     expr2: expression.Expression,
@@ -299,12 +309,56 @@ def _coerce_comparables(
     return expr1, expr2
 
 
-# TODO: Need to handle bool->string cast to get capitalization correct
 def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
+    if arg.output_type == cast_op.to_type:
+        return arg
+
+    if arg.output_type == dtypes.JSON_DTYPE:
+        return json_ops.JSONDecode(cast_op.to_type).as_expr(arg)
+    if (
+        arg.output_type == dtypes.STRING_DTYPE
+        and cast_op.to_type == dtypes.DATETIME_DTYPE
+    ):
+        return datetime_ops.ParseDatetimeOp().as_expr(arg)
+    if (
+        arg.output_type == dtypes.STRING_DTYPE
+        and cast_op.to_type == dtypes.TIMESTAMP_DTYPE
+    ):
+        return datetime_ops.ParseTimestampOp().as_expr(arg)
+    # temporal -> string casts
+    if (
+        arg.output_type == dtypes.DATETIME_DTYPE
+        and cast_op.to_type == dtypes.STRING_DTYPE
+    ):
+        return datetime_ops.StrftimeOp("%Y-%m-%d %H:%M:%S").as_expr(arg)
+    if arg.output_type == dtypes.TIME_DTYPE and cast_op.to_type == dtypes.STRING_DTYPE:
+        return datetime_ops.StrftimeOp("%H:%M:%S.%6f").as_expr(arg)
+    if (
+        arg.output_type == dtypes.TIMESTAMP_DTYPE
+        and cast_op.to_type == dtypes.STRING_DTYPE
+    ):
+        return datetime_ops.StrftimeOp("%Y-%m-%d %H:%M:%S%.6f%:::z").as_expr(arg)
+    if arg.output_type == dtypes.BOOL_DTYPE and cast_op.to_type == dtypes.STRING_DTYPE:
+        # bool -> string needs an explicit mapping to get "True"/"False" capitalization
+        is_true_cond = ops.eq_op.as_expr(arg, expression.const(True))
+        is_false_cond = ops.eq_op.as_expr(arg, expression.const(False))
+        return ops.CaseWhenOp().as_expr(
+            is_true_cond,
+            expression.const("True"),
+            is_false_cond,
+            expression.const("False"),
+        )
     if arg.output_type == dtypes.BOOL_DTYPE and dtypes.is_numeric(cast_op.to_type):
         # bool -> decimal needs two-step cast
         new_arg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg)
         return cast_op.as_expr(new_arg)
+    if arg.output_type == dtypes.TIME_DTYPE and dtypes.is_numeric(cast_op.to_type):
+        # polars cast gives nanoseconds, so convert to microseconds
+        return numeric_ops.floordiv_op.as_expr(
+            cast_op.as_expr(arg), expression.const(1000)
+        )
+    if dtypes.is_numeric(arg.output_type) and cast_op.to_type == dtypes.TIME_DTYPE:
+        return cast_op.as_expr(ops.mul_op.as_expr(expression.const(1000), arg))
     return cast_op.as_expr(arg)
@@ -329,6 +383,7 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
     LowerDivRule(),
     LowerFloorDivRule(),
     LowerModRule(),
+    LowerAsTypeRule(),
 )
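Note: the x1000 / floordiv-1000 pair above exists because polars represents TIME values as nanoseconds since midnight, while BigFrames uses microseconds. A quick illustration of the raw polars behavior being corrected for (illustrative values; assumes polars permits these physical casts):

    import polars as pl
    from datetime import time

    pl.Series([time(0, 0, 1)]).cast(pl.Int64)  # [1_000_000_000], i.e. nanoseconds
    pl.Series([1_000_000_000]).cast(pl.Time)   # [00:00:01]
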
diff --git a/bigframes/operations/datetime_ops.py b/bigframes/operations/datetime_ops.py
index 6f44952488..9988e8ed7b 100644
--- a/bigframes/operations/datetime_ops.py
+++ b/bigframes/operations/datetime_ops.py
@@ -39,6 +39,28 @@
 time_op = TimeOp()
 
 
+@dataclasses.dataclass(frozen=True)
+class ParseDatetimeOp(base_ops.UnaryOp):
+    # TODO: Support strict format
+    name: typing.ClassVar[str] = "parse_datetime"
+
+    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
+        if input_types[0] != dtypes.STRING_DTYPE:
+            raise TypeError("expected string input")
+        return pd.ArrowDtype(pa.timestamp("us", tz=None))
+
+
+@dataclasses.dataclass(frozen=True)
+class ParseTimestampOp(base_ops.UnaryOp):
+    # TODO: Support strict format
+    name: typing.ClassVar[str] = "parse_timestamp"
+
+    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
+        if input_types[0] != dtypes.STRING_DTYPE:
+            raise TypeError("expected string input")
+        return pd.ArrowDtype(pa.timestamp("us", tz="UTC"))
+
+
 @dataclasses.dataclass(frozen=True)
 class ToDatetimeOp(base_ops.UnaryOp):
     name: typing.ClassVar[str] = "to_datetime"
diff --git a/bigframes/operations/generic_ops.py b/bigframes/operations/generic_ops.py
index 3c3f9653b4..152de543db 100644
--- a/bigframes/operations/generic_ops.py
+++ b/bigframes/operations/generic_ops.py
@@ -53,6 +53,108 @@
 )
 hash_op = HashOp()
 
+# (source type, destination type) pairs
+_VALID_CASTS = set(
+    (
+        # INT casts
+        (dtypes.BOOL_DTYPE, dtypes.INT_DTYPE),
+        (dtypes.FLOAT_DTYPE, dtypes.INT_DTYPE),
+        (dtypes.NUMERIC_DTYPE, dtypes.INT_DTYPE),
+        (dtypes.BIGNUMERIC_DTYPE, dtypes.INT_DTYPE),
+        (dtypes.TIME_DTYPE, dtypes.INT_DTYPE),
+        (dtypes.DATETIME_DTYPE, dtypes.INT_DTYPE),
+        (dtypes.TIMESTAMP_DTYPE, dtypes.INT_DTYPE),
+        (dtypes.TIMEDELTA_DTYPE, dtypes.INT_DTYPE),
+        (dtypes.STRING_DTYPE, dtypes.INT_DTYPE),
+        (dtypes.JSON_DTYPE, dtypes.INT_DTYPE),
+        # Float casts
+        (dtypes.BOOL_DTYPE, dtypes.FLOAT_DTYPE),
+        (dtypes.NUMERIC_DTYPE, dtypes.FLOAT_DTYPE),
+        (dtypes.BIGNUMERIC_DTYPE, dtypes.FLOAT_DTYPE),
+        (dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE),
+        (dtypes.STRING_DTYPE, dtypes.FLOAT_DTYPE),
+        (dtypes.JSON_DTYPE, dtypes.FLOAT_DTYPE),
+        # Bool casts
+        (dtypes.INT_DTYPE, dtypes.BOOL_DTYPE),
+        (dtypes.FLOAT_DTYPE, dtypes.BOOL_DTYPE),
+        (dtypes.JSON_DTYPE, dtypes.BOOL_DTYPE),
+        # String casts
+        (dtypes.BYTES_DTYPE, dtypes.STRING_DTYPE),
+        (dtypes.BOOL_DTYPE, dtypes.STRING_DTYPE),
+        (dtypes.FLOAT_DTYPE, dtypes.STRING_DTYPE),
+        (dtypes.TIME_DTYPE, dtypes.STRING_DTYPE),
+        (dtypes.INT_DTYPE, dtypes.STRING_DTYPE),
+        (dtypes.DATETIME_DTYPE, dtypes.STRING_DTYPE),
+        (dtypes.TIMESTAMP_DTYPE, dtypes.STRING_DTYPE),
+        (dtypes.DATE_DTYPE, dtypes.STRING_DTYPE),
+        (dtypes.JSON_DTYPE, dtypes.STRING_DTYPE),
+        # Bytes casts
+        (dtypes.STRING_DTYPE, dtypes.BYTES_DTYPE),
+        # Decimal casts
+        (dtypes.STRING_DTYPE, dtypes.NUMERIC_DTYPE),
+        (dtypes.INT_DTYPE, dtypes.NUMERIC_DTYPE),
+        (dtypes.FLOAT_DTYPE, dtypes.NUMERIC_DTYPE),
+        (dtypes.BIGNUMERIC_DTYPE, dtypes.NUMERIC_DTYPE),
+        # Big decimal casts
+        (dtypes.STRING_DTYPE, dtypes.BIGNUMERIC_DTYPE),
+        (dtypes.INT_DTYPE, dtypes.BIGNUMERIC_DTYPE),
+        (dtypes.FLOAT_DTYPE, dtypes.BIGNUMERIC_DTYPE),
+        (dtypes.NUMERIC_DTYPE, dtypes.BIGNUMERIC_DTYPE),
+        # Time casts
+        (dtypes.INT_DTYPE, dtypes.TIME_DTYPE),
+        (dtypes.DATETIME_DTYPE, dtypes.TIME_DTYPE),
+        (dtypes.TIMESTAMP_DTYPE, dtypes.TIME_DTYPE),
+        # Date casts
+        (dtypes.STRING_DTYPE, dtypes.DATE_DTYPE),
+        (dtypes.DATETIME_DTYPE, dtypes.DATE_DTYPE),
+        (dtypes.TIMESTAMP_DTYPE, dtypes.DATE_DTYPE),
+        # Datetime casts
+        (dtypes.DATE_DTYPE, dtypes.DATETIME_DTYPE),
+        (dtypes.STRING_DTYPE, dtypes.DATETIME_DTYPE),
+        (dtypes.TIMESTAMP_DTYPE, dtypes.DATETIME_DTYPE),
+        (dtypes.INT_DTYPE, dtypes.DATETIME_DTYPE),
+        # Timestamp casts
+        (dtypes.DATE_DTYPE, dtypes.TIMESTAMP_DTYPE),
+        (dtypes.STRING_DTYPE, dtypes.TIMESTAMP_DTYPE),
+        (dtypes.DATETIME_DTYPE, dtypes.TIMESTAMP_DTYPE),
+        (dtypes.INT_DTYPE, dtypes.TIMESTAMP_DTYPE),
+        # Timedelta casts
+        (dtypes.INT_DTYPE, dtypes.TIMEDELTA_DTYPE),
+        # JSON casts
+        (dtypes.BOOL_DTYPE, dtypes.JSON_DTYPE),
+        (dtypes.FLOAT_DTYPE, dtypes.JSON_DTYPE),
+        (dtypes.STRING_DTYPE, dtypes.JSON_DTYPE),
+        (dtypes.INT_DTYPE, dtypes.JSON_DTYPE),
+    )
+)
+
+
+def _valid_scalar_cast(src: dtypes.Dtype, dst: dtypes.Dtype) -> bool:
+    return src == dst or (src, dst) in _VALID_CASTS
+
+
+def _valid_cast(src: dtypes.Dtype, dst: dtypes.Dtype) -> bool:
+    if src == dst:
+        return True
+    # TODO: Might need to be more strict within list/array context
+    if dtypes.is_array_like(src) and dtypes.is_array_like(dst):
+        src_inner = dtypes.get_array_inner_type(src)
+        dst_inner = dtypes.get_array_inner_type(dst)
+        return _valid_cast(src_inner, dst_inner)
+    if dtypes.is_struct_like(src) and dtypes.is_struct_like(dst):
+        src_fields = dtypes.get_struct_fields(src)
+        dst_fields = dtypes.get_struct_fields(dst)
+        if len(src_fields) != len(dst_fields):
+            return False
+        for (_, src_dtype), (_, dst_dtype) in zip(
+            src_fields.items(), dst_fields.items()
+        ):
+            if not _valid_cast(src_dtype, dst_dtype):
+                return False
+        return True
+
+    return _valid_scalar_cast(src, dst)
+
 
 @dataclasses.dataclass(frozen=True)
 class AsTypeOp(base_ops.UnaryOp):
@@ -62,6 +164,9 @@ class AsTypeOp(base_ops.UnaryOp):
     to_type: dtypes.Dtype
     safe: bool = False
 
     def output_type(self, *input_types):
+        if not _valid_cast(input_types[0], self.to_type):
+            raise TypeError(f"Cannot cast {input_types[0]} to {self.to_type}")
+
         return self.to_type
 
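Note: with the allowlist in place, an unsupported cast now fails eagerly when the expression is type-checked instead of surfacing later at execution time. For instance (illustrative; BOOL -> DATE is not in _VALID_CASTS):

    import bigframes.dtypes as dtypes
    import bigframes.operations as ops

    op = ops.AsTypeOp(to_type=dtypes.DATE_DTYPE)
    op.output_type(dtypes.BOOL_DTYPE)  # raises TypeError: Cannot cast ...
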
diff --git a/bigframes/operations/json_ops.py b/bigframes/operations/json_ops.py
index 81f00c39ce..b1f4f2f689 100644
--- a/bigframes/operations/json_ops.py
+++ b/bigframes/operations/json_ops.py
@@ -183,3 +183,18 @@ def output_type(self, *input_types):
             + f" Received type: {input_type}"
         )
         return input_type
+
+
+@dataclasses.dataclass(frozen=True)
+class JSONDecode(base_ops.UnaryOp):
+    name: typing.ClassVar[str] = "json_decode"
+    to_type: dtypes.Dtype
+
+    def output_type(self, *input_types):
+        input_type = input_types[0]
+        if not dtypes.is_json_like(input_type):
+            raise TypeError(
+                "Input type must be a valid JSON object or JSON-formatted string type."
+                + f" Received type: {input_type}"
+            )
+        return self.to_type
diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py
index 2c04a0016b..ccc577deae 100644
--- a/bigframes/session/polars_executor.py
+++ b/bigframes/session/polars_executor.py
@@ -21,7 +21,7 @@
 from bigframes.core import array_value, bigframe_node, expression, local_data, nodes
 import bigframes.operations
 from bigframes.operations import aggregations as agg_ops
-from bigframes.operations import comparison_ops, numeric_ops
+from bigframes.operations import comparison_ops, generic_ops, numeric_ops
 from bigframes.session import executor, semi_executor
 
 if TYPE_CHECKING:
@@ -57,6 +57,7 @@
     numeric_ops.DivOp,
     numeric_ops.FloorDivOp,
     numeric_ops.ModOp,
+    generic_ops.AsTypeOp,
 )
 _COMPATIBLE_AGG_OPS = (
     agg_ops.SizeOp,
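Note: _COMPATIBLE_OPS is the allowlist the polars semi-executor consults before claiming a plan, so registering generic_ops.AsTypeOp here is what actually routes astype through the new compile and lowering rules. Roughly the shape of that gate (hypothetical helper and attribute names; the real check lives elsewhere in this module):

    def _is_node_supported(node) -> bool:
        # Accept a plan only if every scalar op in it is allowlisted;
        # anything else falls back to the BigQuery engine.
        return all(isinstance(op, _COMPATIBLE_OPS) for op in node.scalar_ops)
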
diff --git a/tests/system/small/engines/test_generic_ops.py b/tests/system/small/engines/test_generic_ops.py
new file mode 100644
index 0000000000..af114991eb
--- /dev/null
+++ b/tests/system/small/engines/test_generic_ops.py
@@ -0,0 +1,268 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+
+import pytest
+
+from bigframes.core import array_value, expression
+import bigframes.dtypes
+import bigframes.operations as ops
+from bigframes.session import polars_executor
+from bigframes.testing.engine_utils import assert_equivalence_execution
+
+pytest.importorskip("polars")
+
+# Polars is used as the reference engine, as it's fast and local. Generally though, prefer the gbq engine where they disagree.
+REFERENCE_ENGINE = polars_executor.PolarsExecutor()
+
+
+def apply_op(
+    array: array_value.ArrayValue, op: ops.AsTypeOp, excluded_cols=[]
+) -> array_value.ArrayValue:
+    exprs = []
+    labels = []
+    for arg in array.column_ids:
+        if arg in excluded_cols:
+            continue
+        try:
+            _ = op.output_type(array.get_column_type(arg))
+            expr = op.as_expr(arg)
+            exprs.append(expr)
+            type_string = re.sub(r"[^a-zA-Z\d]", "_", str(op.to_type))
+            labels.append(f"{arg}_as_{type_string}")
+        except TypeError:
+            continue
+    assert len(exprs) > 0
+    new_arr, ids = array.compute_values(exprs)
+    new_arr = new_arr.rename_columns(
+        {new_col: label for new_col, label in zip(ids, labels)}
+    )
+    return new_arr
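Note: apply_op fans a single cast op across every column whose dtype passes the op's output_type check (which now raises TypeError for disallowed casts, per generic_ops.py above), so each test below exercises all castable columns at once. Illustrative usage:

    arr = apply_op(
        scalars_array_value, ops.AsTypeOp(to_type=bigframes.dtypes.STRING_DTYPE)
    )
    # yields columns named like "bool_col_as_string", "int64_col_as_string", ...
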
vals = ["1", "1.1", ".1", "23428975070235903.209", "-23428975070235903.209"] + arr, _ = scalars_array_value.compute_values( + [ + ops.AsTypeOp(to_type=bigframes.dtypes.NUMERIC_DTYPE).as_expr( + expression.const(val) + ) + for val in vals + ] + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_astype_date(scalars_array_value: array_value.ArrayValue, engine): + arr = apply_op( + scalars_array_value, + ops.AsTypeOp(to_type=bigframes.dtypes.DATE_DTYPE), + excluded_cols=["string_col"], + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_astype_string_date( + scalars_array_value: array_value.ArrayValue, engine +): + vals = ["2014-08-15", "2215-08-15", "2016-02-29"] + arr, _ = scalars_array_value.compute_values( + [ + ops.AsTypeOp(to_type=bigframes.dtypes.DATE_DTYPE).as_expr( + expression.const(val) + ) + for val in vals + ] + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_astype_datetime(scalars_array_value: array_value.ArrayValue, engine): + arr = apply_op( + scalars_array_value, + ops.AsTypeOp(to_type=bigframes.dtypes.DATETIME_DTYPE), + excluded_cols=["string_col"], + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_astype_string_datetime( + scalars_array_value: array_value.ArrayValue, engine +): + vals = ["2014-08-15 08:15:12", "2015-08-15 08:15:12.654754", "2016-02-29 00:00:00"] + arr, _ = scalars_array_value.compute_values( + [ + ops.AsTypeOp(to_type=bigframes.dtypes.DATETIME_DTYPE).as_expr( + expression.const(val) + ) + for val in vals + ] + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_astype_timestamp(scalars_array_value: array_value.ArrayValue, engine): + arr = apply_op( + scalars_array_value, + ops.AsTypeOp(to_type=bigframes.dtypes.TIMESTAMP_DTYPE), + excluded_cols=["string_col"], + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_astype_string_timestamp( + scalars_array_value: array_value.ArrayValue, engine +): + vals = [ + "2014-08-15 08:15:12+00:00", + "2015-08-15 08:15:12.654754+05:00", + "2016-02-29 00:00:00+08:00", + ] + arr, _ = scalars_array_value.compute_values( + [ + ops.AsTypeOp(to_type=bigframes.dtypes.TIMESTAMP_DTYPE).as_expr( + expression.const(val) + ) + for val in vals + ] + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_astype_time(scalars_array_value: array_value.ArrayValue, engine): + arr = apply_op( + scalars_array_value, + ops.AsTypeOp(to_type=bigframes.dtypes.TIME_DTYPE), + excluded_cols=["string_col", "int64_col", "int64_too"], + ) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_astype_from_json(scalars_array_value: array_value.ArrayValue, engine): + exprs = [ + ops.AsTypeOp(to_type=bigframes.dtypes.INT_DTYPE).as_expr( + expression.const("5", bigframes.dtypes.JSON_DTYPE) + ), + 
+
+
+@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, engine):
+    arr = apply_op(
+        scalars_array_value,
+        ops.AsTypeOp(to_type=bigframes.dtypes.TIMEDELTA_DTYPE),
+    )
+    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
diff --git a/tests/system/small/test_series.py b/tests/system/small/test_series.py
index 3f64234293..e94250e98f 100644
--- a/tests/system/small/test_series.py
+++ b/tests/system/small/test_series.py
@@ -3685,8 +3685,12 @@ def test_astype_numeric_to_int(scalars_df_index, scalars_pandas_df_index):
     column = "numeric_col"
     to_type = "Int64"
     bf_result = scalars_df_index[column].astype(to_type).to_pandas()
-    # Round to the nearest whole number to avoid TypeError
-    pd_result = scalars_pandas_df_index[column].round(0).astype(to_type)
+    # Truncate to int to avoid TypeError
+    pd_result = (
+        scalars_pandas_df_index[column]
+        .apply(lambda x: None if pd.isna(x) else math.trunc(x))
+        .astype(to_type)
+    )
     pd.testing.assert_series_equal(bf_result, pd_result)
diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py
index acccd7ea6c..cbc51e59d6 100644
--- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py
+++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/base.py
@@ -537,7 +537,7 @@ def if_(self, condition, true, false: sge.Expression | None = None) -> sge.If:
             false=None if false is None else sge.convert(false),
         )
 
-    def cast(self, arg, to: dt.DataType) -> sge.Cast:
+    def cast(self, arg, to: dt.DataType, format=None) -> sge.Cast:
         return sge.Cast(
             this=sge.convert(arg), to=self.type_mapper.from_ibis(to), copy=False
         )
diff --git a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py
index be8f9fc555..08bf0d7650 100644
--- a/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py
+++ b/third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py
@@ -544,7 +544,7 @@ def visit_Cast(self, op, *, arg, to):
                 f"BigQuery does not allow extracting date part `{from_.unit}` from intervals"
             )
             return self.f.extract(self.v[to.resolution.upper()], arg)
-        elif from_.is_floating() and to.is_integer():
+        elif (from_.is_floating() or from_.is_decimal()) and to.is_integer():
             return self.cast(self.f.trunc(arg), dt.int64)
         return super().visit_Cast(op, arg=arg, to=to)
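Note: the two vendored-ibis changes at the end make decimal-to-integer casts compile to TRUNC followed by CAST, rather than BigQuery's plain CAST, which rounds halfway cases away from zero. Truncation matches pandas astype semantics, which is why the test_series.py expectation above switched from round(0) to math.trunc. The difference in pandas terms (illustrative values):

    import math

    import pandas as pd

    s = pd.Series([1.5, -2.5, 2.4])
    s.round(0)           # 2.0, -2.0, 2.0  (halves round to even)
    s.apply(math.trunc)  # 1, -2, 2        (always toward zero)
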