1717from bigframes import dtypes
1818from bigframes .core import bigframe_node , expression
1919from bigframes .core .rewrite import op_lowering
20- from bigframes .operations import comparison_ops , numeric_ops
20+ from bigframes .operations import comparison_ops , datetime_ops , json_ops , numeric_ops
2121import bigframes .operations as ops
2222
2323# TODO: Would be more precise to actually have separate op set for polars ops (where they diverge from the original ops)
@@ -278,6 +278,16 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
278278 return wo_bools
279279
280280
class LowerAsTypeRule(op_lowering.OpLoweringRule):
    """Lowering rule that rewrites ``AsTypeOp`` expressions via ``_lower_cast``."""

    @property
    def op(self) -> type[ops.ScalarOp]:
        """The op class this rule applies to (``ops.AsTypeOp``)."""
        return ops.AsTypeOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Lower a cast expression by delegating to ``_lower_cast``.

        The expression's op must be an ``AsTypeOp``; its first input is the
        value being cast.
        """
        cast_op = expr.op
        # Narrow the op type for the type checker before handing it on.
        assert isinstance(cast_op, ops.AsTypeOp)
        operand = expr.inputs[0]
        return _lower_cast(cast_op, operand)
289+
290+
281291def _coerce_comparables (
282292 expr1 : expression .Expression ,
283293 expr2 : expression .Expression ,
@@ -299,12 +309,57 @@ def _coerce_comparables(
299309 return expr1 , expr2
300310
301311
302- # TODO: Need to handle bool->string cast to get capitalization correct
def _lower_cast(
    cast_op: ops.AsTypeOp, arg: expression.Expression
) -> expression.Expression:
    """Rewrite a cast of ``arg`` to ``cast_op.to_type`` into a lowered expression.

    Several source/target type pairs are replaced with dedicated ops (JSON
    decoding, datetime parsing, strftime formatting, an explicit bool -> string
    mapping) or wrapped with a unit adjustment (time <-> numeric). Any pair
    without a special case falls through to a plain ``cast_op`` on ``arg``.

    Args:
        cast_op: The cast operation; only its ``to_type`` is inspected.
        arg: The expression being cast; its ``output_type`` selects the branch.

    Returns:
        ``arg`` itself when it already has the target type, otherwise an
        expression producing the cast value.
    """
    from_type = arg.output_type
    to_type = cast_op.to_type

    # No-op cast: nothing to lower.
    if from_type == to_type:
        return arg

    # JSON sources are decoded rather than cast, whatever the target type.
    if from_type == dtypes.JSON_DTYPE:
        return json_ops.JSONDecode(to_type).as_expr(arg)

    # string -> datetime/timestamp goes through the dedicated parse ops.
    if from_type == dtypes.STRING_DTYPE and to_type == dtypes.DATETIME_DTYPE:
        return datetime_ops.ParseDatetimeOp().as_expr(arg)
    if from_type == dtypes.STRING_DTYPE and to_type == dtypes.TIMESTAMP_DTYPE:
        return datetime_ops.ParseTimestampOp().as_expr(arg)

    # datetime/time/timestamp -> string uses explicit format strings.
    if from_type == dtypes.DATETIME_DTYPE and to_type == dtypes.STRING_DTYPE:
        return datetime_ops.StrftimeOp("%Y-%m-%d %H:%M:%S").as_expr(arg)
    if from_type == dtypes.TIME_DTYPE and to_type == dtypes.STRING_DTYPE:
        return datetime_ops.StrftimeOp("%H:%M:%S.%6f").as_expr(arg)
    if from_type == dtypes.TIMESTAMP_DTYPE and to_type == dtypes.STRING_DTYPE:
        return datetime_ops.StrftimeOp("%Y-%m-%d %H:%M:%S%.6f%:::z").as_expr(arg)

    if from_type == dtypes.BOOL_DTYPE and to_type == dtypes.STRING_DTYPE:
        # bool -> string: map each value to an explicit "True"/"False" literal
        # (presumably to control capitalization -- TODO confirm). No branch is
        # supplied for values that are neither True nor False.
        is_true_cond = ops.eq_op.as_expr(arg, expression.const(True))
        is_false_cond = ops.eq_op.as_expr(arg, expression.const(False))
        return ops.CaseWhenOp().as_expr(
            is_true_cond,
            expression.const("True"),
            is_false_cond,
            expression.const("False"),
        )

    if from_type == dtypes.BOOL_DTYPE and dtypes.is_numeric(to_type):
        # bool -> decimal needs two-step cast
        new_arg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg)
        return cast_op.as_expr(new_arg)

    if from_type == dtypes.TIME_DTYPE and dtypes.is_numeric(to_type):
        # polars cast gives nanoseconds, so convert to microseconds
        return numeric_ops.floordiv_op.as_expr(
            cast_op.as_expr(arg), expression.const(1000)
        )

    if dtypes.is_numeric(from_type) and to_type == dtypes.TIME_DTYPE:
        # Inverse of the branch above: scale the input by 1000 before casting
        # (presumably microseconds -> nanoseconds -- TODO confirm).
        return cast_op.as_expr(ops.mul_op.as_expr(expression.const(1000), arg))

    return cast_op.as_expr(arg)
309364
310365
@@ -329,6 +384,7 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
329384 LowerDivRule (),
330385 LowerFloorDivRule (),
331386 LowerModRule (),
387+ LowerAsTypeRule (),
332388)
333389
334390
0 commit comments