Skip to content

Commit e6f2cc1

Browse files
committed
feat: add support for the 'right' parameter in 'pandas.cut'
1 parent c382a44 commit e6f2cc1

File tree

5 files changed

+166
-64
lines changed

5 files changed

+166
-64
lines changed

bigframes/core/compile/aggregate_compiler.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -364,8 +364,13 @@ def _(
364364

365365
if op.labels is False:
366366
for this_bin in range(op.bins - 1):
367+
if op.right:
368+
case_expr = x <= (col_min + (this_bin + 1) * bin_width)
369+
370+
else:
371+
case_expr = x < (col_min + (this_bin + 1) * bin_width)
367372
out = out.when(
368-
x <= (col_min + (this_bin + 1) * bin_width),
373+
case_expr,
369374
compile_ibis_types.literal_to_ibis_scalar(
370375
this_bin, force_dtype=pd.Int64Dtype()
371376
),
@@ -375,32 +380,49 @@ def _(
375380
interval_struct = None
376381
adj = (col_max - col_min) * 0.001
377382
for this_bin in range(op.bins):
378-
left_edge = (
379-
col_min + this_bin * bin_width - (0 if this_bin > 0 else adj)
380-
)
381-
right_edge = col_min + (this_bin + 1) * bin_width
382-
interval_struct = ibis_types.struct(
383-
{
384-
"left_exclusive": left_edge,
385-
"right_inclusive": right_edge,
386-
}
387-
)
383+
left_edge_adj = adj if this_bin == 0 and op.right else 0
384+
right_edge_adj = adj if this_bin == op.bins - 1 and not op.right else 0
385+
386+
left_edge = col_min + this_bin * bin_width - left_edge_adj
387+
right_edge = col_min + (this_bin + 1) * bin_width + right_edge_adj
388+
389+
if op.right:
390+
interval_struct = ibis_types.struct(
391+
{
392+
"left_exclusive": left_edge,
393+
"right_inclusive": right_edge,
394+
}
395+
)
396+
else:
397+
interval_struct = ibis_types.struct(
398+
{
399+
"left_inclusive": left_edge,
400+
"right_exclusive": right_edge,
401+
}
402+
)
388403

389404
if this_bin < op.bins - 1:
390-
out = out.when(
391-
x <= (col_min + (this_bin + 1) * bin_width),
392-
interval_struct,
393-
)
405+
if op.right:
406+
case_expr = x <= (col_min + (this_bin + 1) * bin_width)
407+
else:
408+
case_expr = x < (col_min + (this_bin + 1) * bin_width)
409+
out = out.when(case_expr, interval_struct)
394410
else:
395411
out = out.when(x.notnull(), interval_struct)
396412
else: # Interpret as intervals
397413
for interval in op.bins:
398414
left = compile_ibis_types.literal_to_ibis_scalar(interval[0])
399415
right = compile_ibis_types.literal_to_ibis_scalar(interval[1])
400-
condition = (x > left) & (x <= right)
401-
interval_struct = ibis_types.struct(
402-
{"left_exclusive": left, "right_inclusive": right}
403-
)
416+
if op.right:
417+
condition = (x > left) & (x <= right)
418+
interval_struct = ibis_types.struct(
419+
{"left_exclusive": left, "right_inclusive": right}
420+
)
421+
else:
422+
condition = (x >= left) & (x < right)
423+
interval_struct = ibis_types.struct(
424+
{"left_inclusive": left, "right_exclusive": right}
425+
)
404426
out = out.when(condition, interval_struct)
405427
return out.end()
406428

bigframes/core/reshape/tile.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from __future__ import annotations
1616

1717
import typing
18-
from typing import Iterable, Optional, Union
1918

2019
import bigframes_vendored.constants as constants
2120
import bigframes_vendored.pandas.core.reshape.tile as vendored_pandas_tile
@@ -33,26 +32,34 @@
3332

3433
def cut(
3534
x: bigframes.series.Series,
36-
bins: Union[
35+
bins: typing.Union[
3736
int,
3837
pd.IntervalIndex,
39-
Iterable,
38+
typing.Iterable,
4039
],
4140
*,
42-
labels: Union[Iterable[str], bool, None] = None,
41+
right: typing.Optional[bool] = True,
42+
labels: typing.Union[typing.Iterable[str], bool, None] = None,
4343
) -> bigframes.series.Series:
4444
if isinstance(bins, int) and bins <= 0:
4545
raise ValueError("`bins` should be a positive integer.")
4646

47-
if isinstance(bins, Iterable):
47+
# TODO: Verify that `right` has no effect when `bins` is an IntervalIndex.
48+
49+
if isinstance(bins, typing.Iterable):
4850
if isinstance(bins, pd.IntervalIndex):
51+
# TODO: test an empty interval index
4952
as_index: pd.IntervalIndex = bins
5053
bins = tuple((bin.left.item(), bin.right.item()) for bin in bins)
54+
# To maintain consistency with pandas' behavior
55+
right = True
5156
elif len(list(bins)) == 0:
5257
raise ValueError("`bins` iterable should have at least one item")
5358
elif isinstance(list(bins)[0], tuple):
5459
as_index = pd.IntervalIndex.from_tuples(list(bins))
5560
bins = tuple(bins)
61+
# To maintain consistency with pandas' behavior
62+
right = True
5663
elif pd.api.types.is_number(list(bins)[0]):
5764
bins_list = list(bins)
5865
if len(bins_list) < 2:
@@ -82,7 +89,8 @@ def cut(
8289
)
8390

8491
return x._apply_window_op(
85-
agg_ops.CutOp(bins, labels=labels), window_spec=window_specs.unbound()
92+
agg_ops.CutOp(bins, right=right, labels=labels),
93+
window_spec=window_specs.unbound(),
8694
)
8795

8896

@@ -93,7 +101,7 @@ def qcut(
93101
x: bigframes.series.Series,
94102
q: typing.Union[int, typing.Sequence[float]],
95103
*,
96-
labels: Optional[bool] = None,
104+
labels: typing.Optional[bool] = None,
97105
duplicates: typing.Literal["drop", "error"] = "error",
98106
) -> bigframes.series.Series:
99107
if isinstance(q, int) and q <= 0:

bigframes/operations/aggregations.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
339339
class CutOp(UnaryWindowOp):
340340
# TODO: Unintuitive, refactor into multiple ops?
341341
bins: typing.Union[int, Iterable]
342+
right: Optional[bool]
342343
labels: Optional[bool]
343344

344345
@property
@@ -355,12 +356,21 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
355356
if isinstance(self.bins, int)
356357
else dtypes.infer_literal_arrow_type(list(self.bins)[0][0])
357358
)
358-
pa_type = pa.struct(
359-
[
360-
pa.field("left_exclusive", interval_dtype, nullable=True),
361-
pa.field("right_inclusive", interval_dtype, nullable=True),
362-
]
363-
)
359+
if self.right:
360+
pa_type = pa.struct(
361+
[
362+
pa.field("left_exclusive", interval_dtype, nullable=True),
363+
pa.field("right_inclusive", interval_dtype, nullable=True),
364+
]
365+
)
366+
else:
367+
pa_type = pa.struct(
368+
[
369+
pa.field("left_inclusive", interval_dtype, nullable=True),
370+
pa.field("right_exclusive", interval_dtype, nullable=True),
371+
]
372+
)
373+
364374
return pd.ArrowDtype(pa_type)
365375

366376
@property

tests/system/small/test_pandas.py

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -386,33 +386,53 @@ def test_merge_series(scalars_dfs, merge_how):
386386

387387
assert_pandas_df_equal(bf_result, pd_result, ignore_order=True)
388388

389-
390-
def test_cut(scalars_dfs):
389+
@pytest.mark.parametrize(
390+
("right"),
391+
[
392+
pytest.param(True),
393+
pytest.param(False),
394+
],
395+
)
396+
def test_cut(scalars_dfs, right):
391397
scalars_df, scalars_pandas_df = scalars_dfs
392398

393-
pd_result = pd.cut(scalars_pandas_df["float64_col"], 5, labels=False)
394-
bf_result = bpd.cut(scalars_df["float64_col"], 5, labels=False)
399+
pd_result = pd.cut(scalars_pandas_df["float64_col"], 5, labels=False, right=right)
400+
bf_result = bpd.cut(scalars_df["float64_col"], 5, labels=False, right=right)
395401

396402
# make sure the result is a supported dtype
397403
assert bf_result.dtype == bpd.Int64Dtype()
398404
pd_result = pd_result.astype("Int64")
399405
pd.testing.assert_series_equal(bf_result.to_pandas(), pd_result)
400406

401407

402-
def test_cut_default_labels(scalars_dfs):
408+
@pytest.mark.parametrize(
409+
("right"),
410+
[
411+
pytest.param(True),
412+
pytest.param(False),
413+
],
414+
)
415+
def test_cut_default_labels(scalars_dfs, right):
403416
scalars_df, scalars_pandas_df = scalars_dfs
404417

405-
pd_result = pd.cut(scalars_pandas_df["float64_col"], 5)
406-
bf_result = bpd.cut(scalars_df["float64_col"], 5).to_pandas()
418+
pd_result = pd.cut(scalars_pandas_df["float64_col"], 5, right=right)
419+
bf_result = bpd.cut(scalars_df["float64_col"], 5, right=right).to_pandas()
407420

408421
# Convert to match data format
422+
pd_interval = pd_result.cat.categories[pd_result.cat.codes]
423+
if pd_interval.closed == "left":
424+
left_key = "left_inclusive"
425+
right_key = "right_exclusive"
426+
else:
427+
left_key = "left_exclusive"
428+
right_key = "right_inclusive"
409429
pd_result_converted = pd.Series(
410430
[
411-
{"left_exclusive": interval.left, "right_inclusive": interval.right}
431+
{left_key: interval.left, right_key: interval.right}
412432
if pd.notna(val)
413433
else pd.NA
414434
for val, interval in zip(
415-
pd_result, pd_result.cat.categories[pd_result.cat.codes]
435+
pd_result, pd_interval
416436
)
417437
],
418438
name=pd_result.name,
@@ -424,27 +444,36 @@ def test_cut_default_labels(scalars_dfs):
424444

425445

426446
@pytest.mark.parametrize(
427-
("breaks",),
447+
("breaks", "right"),
428448
[
429-
([0, 5, 10, 15, 20, 100, 1000],), # ints
430-
([0.5, 10.5, 15.5, 20.5, 100.5, 1000.5],), # floats
431-
([0, 5, 10.5, 15.5, 20, 100, 1000.5],), # mixed
449+
pytest.param([0, 5, 10, 15, 20, 100, 1000], True, id="int_right"),
450+
pytest.param([0, 5, 10, 15, 20, 100, 1000], False, id="int_left"),
451+
pytest.param([0.5, 10.5, 15.5, 20.5, 100.5, 1000.5], False, id="float_left"),
452+
pytest.param([0, 5, 10.5, 15.5, 20, 100, 1000.5], True, id="mixed_right"),
432453
],
433454
)
434-
def test_cut_numeric_breaks(scalars_dfs, breaks):
455+
def test_cut_numeric_breaks(scalars_dfs, breaks, right):
435456
scalars_df, scalars_pandas_df = scalars_dfs
436457

437-
pd_result = pd.cut(scalars_pandas_df["float64_col"], breaks)
438-
bf_result = bpd.cut(scalars_df["float64_col"], breaks).to_pandas()
458+
pd_result = pd.cut(scalars_pandas_df["float64_col"], breaks, right=right)
459+
bf_result = bpd.cut(scalars_df["float64_col"], breaks, right=right).to_pandas()
439460

440461
# Convert to match data format
462+
pd_interval = pd_result.cat.categories[pd_result.cat.codes]
463+
if pd_interval.closed == "left":
464+
left_key = "left_inclusive"
465+
right_key = "right_exclusive"
466+
else:
467+
left_key = "left_exclusive"
468+
right_key = "right_inclusive"
469+
441470
pd_result_converted = pd.Series(
442471
[
443-
{"left_exclusive": interval.left, "right_inclusive": interval.right}
472+
{left_key: interval.left, right_key: interval.right}
444473
if pd.notna(val)
445474
else pd.NA
446475
for val, interval in zip(
447-
pd_result, pd_result.cat.categories[pd_result.cat.codes]
476+
pd_result, pd_interval
448477
)
449478
],
450479
name=pd_result.name,
@@ -476,28 +505,38 @@ def test_cut_errors(scalars_dfs, bins):
476505

477506

478507
@pytest.mark.parametrize(
479-
("bins",),
508+
("bins", "right"),
480509
[
481-
([(-5, 2), (2, 3), (-3000, -10)],),
482-
(pd.IntervalIndex.from_tuples([(1, 2), (2, 3), (4, 5)]),),
510+
pytest.param([(-5, 2), (2, 3), (-3000, -10)], True, id="tuple_right"),
511+
pytest.param([(-5, 2), (2, 3), (-3000, -10)], False, id="tuple_left"),
512+
pytest.param(pd.IntervalIndex.from_tuples([(1, 2), (2, 3), (4, 5)]), True, id="interval_right"),
513+
pytest.param(pd.IntervalIndex.from_tuples([(1, 2), (2, 3), (4, 5)]), False, id="interval_left"),
483514
],
484515
)
485-
def test_cut_with_interval(scalars_dfs, bins):
516+
def test_cut_with_interval(scalars_dfs, bins, right):
486517
scalars_df, scalars_pandas_df = scalars_dfs
487-
bf_result = bpd.cut(scalars_df["int64_too"], bins, labels=False).to_pandas()
518+
bf_result = bpd.cut(scalars_df["int64_too"], bins, labels=False, right=right).to_pandas()
488519

489520
if isinstance(bins, list):
490521
bins = pd.IntervalIndex.from_tuples(bins)
491-
pd_result = pd.cut(scalars_pandas_df["int64_too"], bins, labels=False)
522+
pd_result = pd.cut(scalars_pandas_df["int64_too"], bins, labels=False, right=right)
492523

493524
# Convert to match data format
525+
pd_interval = pd_result.cat.categories[pd_result.cat.codes]
526+
if pd_interval.closed == "left":
527+
left_key = "left_inclusive"
528+
right_key = "right_exclusive"
529+
else:
530+
left_key = "left_exclusive"
531+
right_key = "right_inclusive"
532+
494533
pd_result_converted = pd.Series(
495534
[
496-
{"left_exclusive": interval.left, "right_inclusive": interval.right}
535+
{left_key: interval.left, right_key: interval.right}
497536
if pd.notna(val)
498537
else pd.NA
499538
for val, interval in zip(
500-
pd_result, pd_result.cat.categories[pd_result.cat.codes]
539+
pd_result, pd_interval
501540
)
502541
],
503542
name=pd_result.name,

0 commit comments

Comments
 (0)