Skip to content

feat: Can cast locally in hybrid engine #1944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions bigframes/core/compile/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
import bigframes.operations.aggregations as agg_ops
import bigframes.operations.bool_ops as bool_ops
import bigframes.operations.comparison_ops as comp_ops
import bigframes.operations.datetime_ops as dt_ops
import bigframes.operations.generic_ops as gen_ops
import bigframes.operations.json_ops as json_ops
import bigframes.operations.numeric_ops as num_ops
import bigframes.operations.string_ops as string_ops

Expand Down Expand Up @@ -280,6 +282,30 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
assert isinstance(op, string_ops.StrConcatOp)
return pl.concat_str(l_input, r_input)

@compile_op.register(dt_ops.StrftimeOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
assert isinstance(op, dt_ops.StrftimeOp)
return input.dt.strftime(op.date_format)

@compile_op.register(dt_ops.ParseDatetimeOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
assert isinstance(op, dt_ops.ParseDatetimeOp)
return input.str.to_datetime(
time_unit="us", time_zone=None, ambiguous="earliest"
)

@compile_op.register(dt_ops.ParseTimestampOp)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
assert isinstance(op, dt_ops.ParseTimestampOp)
return input.str.to_datetime(
time_unit="us", time_zone="UTC", ambiguous="earliest"
)

@compile_op.register(json_ops.JSONDecode)
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
assert isinstance(op, json_ops.JSONDecode)
return input.str.json_decode(_DTYPE_MAPPING[op.to_type])

@dataclasses.dataclass(frozen=True)
class PolarsAggregateCompiler:
scalar_compiler = PolarsExpressionCompiler()
Expand Down
60 changes: 58 additions & 2 deletions bigframes/core/compile/polars/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from bigframes import dtypes
from bigframes.core import bigframe_node, expression
from bigframes.core.rewrite import op_lowering
from bigframes.operations import comparison_ops, numeric_ops
from bigframes.operations import comparison_ops, datetime_ops, json_ops, numeric_ops
import bigframes.operations as ops

# TODO: Would be more precise to actually have separate op set for polars ops (where they diverge from the original ops)
Expand Down Expand Up @@ -278,6 +278,16 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
return wo_bools


class LowerAsTypeRule(op_lowering.OpLoweringRule):
@property
def op(self) -> type[ops.ScalarOp]:
return ops.AsTypeOp

def lower(self, expr: expression.OpExpression) -> expression.Expression:
assert isinstance(expr.op, ops.AsTypeOp)
return _lower_cast(expr.op, expr.inputs[0])


def _coerce_comparables(
expr1: expression.Expression,
expr2: expression.Expression,
Expand All @@ -299,12 +309,57 @@ def _coerce_comparables(
return expr1, expr2


# TODO: Need to handle bool->string cast to get capitalization correct
def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
if arg.output_type == cast_op.to_type:
return arg

if arg.output_type == dtypes.JSON_DTYPE:
return json_ops.JSONDecode(cast_op.to_type).as_expr(arg)
if (
arg.output_type == dtypes.STRING_DTYPE
and cast_op.to_type == dtypes.DATETIME_DTYPE
):
return datetime_ops.ParseDatetimeOp().as_expr(arg)
if (
arg.output_type == dtypes.STRING_DTYPE
and cast_op.to_type == dtypes.TIMESTAMP_DTYPE
):
return datetime_ops.ParseTimestampOp().as_expr(arg)
# date -> string casting
if (
arg.output_type == dtypes.DATETIME_DTYPE
and cast_op.to_type == dtypes.STRING_DTYPE
):
return datetime_ops.StrftimeOp("%Y-%m-%d %H:%M:%S").as_expr(arg)
if arg.output_type == dtypes.TIME_DTYPE and cast_op.to_type == dtypes.STRING_DTYPE:
return datetime_ops.StrftimeOp("%H:%M:%S.%6f").as_expr(arg)
if (
arg.output_type == dtypes.TIMESTAMP_DTYPE
and cast_op.to_type == dtypes.STRING_DTYPE
):
return datetime_ops.StrftimeOp("%Y-%m-%d %H:%M:%S%.6f%:::z").as_expr(arg)
if arg.output_type == dtypes.BOOL_DTYPE and cast_op.to_type == dtypes.STRING_DTYPE:
# bool -> decimal needs two-step cast
new_arg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg)
is_true_cond = ops.eq_op.as_expr(arg, expression.const(True))
is_false_cond = ops.eq_op.as_expr(arg, expression.const(False))
return ops.CaseWhenOp().as_expr(
is_true_cond,
expression.const("True"),
is_false_cond,
expression.const("False"),
)
if arg.output_type == dtypes.BOOL_DTYPE and dtypes.is_numeric(cast_op.to_type):
# bool -> decimal needs two-step cast
new_arg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg)
return cast_op.as_expr(new_arg)
if arg.output_type == dtypes.TIME_DTYPE and dtypes.is_numeric(cast_op.to_type):
# polars cast gives nanoseconds, so convert to microseconds
return numeric_ops.floordiv_op.as_expr(
cast_op.as_expr(arg), expression.const(1000)
)
if dtypes.is_numeric(arg.output_type) and cast_op.to_type == dtypes.TIME_DTYPE:
return cast_op.as_expr(ops.mul_op.as_expr(expression.const(1000), arg))
return cast_op.as_expr(arg)


Expand All @@ -329,6 +384,7 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
LowerDivRule(),
LowerFloorDivRule(),
LowerModRule(),
LowerAsTypeRule(),
)


Expand Down
22 changes: 22 additions & 0 deletions bigframes/operations/datetime_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,28 @@
time_op = TimeOp()


@dataclasses.dataclass(frozen=True)
class ParseDatetimeOp(base_ops.UnaryOp):
# TODO: Support strict format
name: typing.ClassVar[str] = "parse_datetime"

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
if input_types[0] != dtypes.STRING_DTYPE:
raise TypeError("expected string input")
return pd.ArrowDtype(pa.timestamp("us", tz=None))


@dataclasses.dataclass(frozen=True)
class ParseTimestampOp(base_ops.UnaryOp):
# TODO: Support strict format
name: typing.ClassVar[str] = "parse_timestamp"

def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
if input_types[0] != dtypes.STRING_DTYPE:
raise TypeError("expected string input")
return pd.ArrowDtype(pa.timestamp("us", tz="UTC"))


@dataclasses.dataclass(frozen=True)
class ToDatetimeOp(base_ops.UnaryOp):
name: typing.ClassVar[str] = "to_datetime"
Expand Down
Loading