Skip to content

Commit 5f1d670

Browse files
Genesis929 and tswast authored
feat: Add pivot_table for DataFrame. (#473)
* feat: Add pivot_table for DataFrame. * Update logic * Update comments * Remove code unused after merge. * Code update. * Update code example. * Update for Tuple type. * Update code logic * Update format --------- Co-authored-by: Tim Sweña (Swast) <swast@google.com>
1 parent edef48f commit 5f1d670

File tree

3 files changed

+170
-0
lines changed

3 files changed

+170
-0
lines changed

bigframes/dataframe.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2132,6 +2132,66 @@ def pivot(
21322132
) -> DataFrame:
21332133
return self._pivot(columns=columns, index=index, values=values)
21342134

2135+
def pivot_table(
    self,
    values: typing.Optional[
        typing.Union[blocks.Label, Sequence[blocks.Label]]
    ] = None,
    index: typing.Optional[
        typing.Union[blocks.Label, Sequence[blocks.Label]]
    ] = None,
    columns: typing.Optional[
        typing.Union[blocks.Label, Sequence[blocks.Label]]
    ] = None,
    aggfunc: str = "mean",
) -> DataFrame:
    # Overview: group the frame by the `index` + `columns` labels, aggregate
    # the `values` columns with `aggfunc`, then pivot the aggregated frame so
    # that `columns` become column levels and `index` becomes the row index.
    # The result has its rows sorted by index and its columns sorted too.
    #
    # Deliberately documented with comments rather than a docstring: the
    # neighbouring `pivot` method carries no docstring either, and the
    # user-facing text presumably lives on the vendored pandas stub of the
    # same name -- TODO confirm before adding a docstring here.

    def _as_label_list(arg):
        # Normalise one `values` / `index` / `columns` argument to a list of
        # column labels.
        if isinstance(arg, str):
            # A string is always exactly one label.  Without this guard a
            # misspelled column name (iterable, but not found in
            # self.columns) would be exploded into single characters and the
            # later groupby error would name the wrong key.
            return [arg]
        if not isinstance(arg, Iterable):
            # Scalars are wrapped as-is.
            # NOTE(review): this includes None, which becomes [None] exactly
            # as before this rewrite; confirm whether omitting an argument is
            # meant to be supported.
            return [arg]
        if isinstance(arg, blocks.Label) and arg in self.columns:
            # An iterable that is itself an existing column label (e.g. a
            # tuple label) is a single column, not a sequence of labels.
            return [arg]
        return list(arg)

    index_labels = _as_label_list(index)
    column_labels = _as_label_list(columns)
    # Unlike pivot, pivot_table has values always ordered.
    value_labels = sorted(_as_label_list(values))

    group_keys = index_labels + column_labels
    agged = self.groupby(group_keys, dropna=True)[value_labels].agg(aggfunc)

    if isinstance(agged, bigframes.series.Series):
        agged = agged.to_frame()

    # Drop groups for which every aggregated value is null.
    agged = agged.dropna(how="all")

    single_value = len(value_labels) == 1
    if single_value:
        # Make sure the lone aggregated column carries the requested label.
        agged = agged.rename(columns={agged.columns[0]: value_labels[0]})

    # Turn the group keys back into ordinary columns so pivot() can use them.
    agged = agged.reset_index()

    pivoted = agged.pivot(
        columns=column_labels,
        index=index_labels,
        # With a single value column, `values=None` is passed instead of a
        # one-element list (same as the original code).
        values=None if single_value else value_labels,
    ).sort_index()

    # TODO: Remove the reordering step once the issue is resolved.
    # The pivot_table method results in multi-index columns that are always ordered.
    # However, the order of the pivoted result columns is not guaranteed to be sorted.
    # Sort and reorder.
    return pivoted[pivoted.columns.sort_values()]
2194+
21352195
def stack(self, level: LevelsType = -1):
21362196
if not isinstance(self.columns, pandas.MultiIndex):
21372197
if level not in [0, -1, self.columns.name]:

tests/system/small/test_dataframe.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2606,6 +2606,34 @@ def test_df_pivot_hockey(hockey_df, hockey_pandas_df, values, index, columns):
26062606
pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
26072607

26082608

2609+
@pytest.mark.parametrize(
    ("values", "index", "columns", "aggfunc"),
    [
        (("culmen_length_mm", "body_mass_g"), "species", "sex", "std"),
        (["body_mass_g", "culmen_length_mm"], ("species", "island"), "sex", "sum"),
        ("body_mass_g", "sex", ["island", "species"], "mean"),
        ("culmen_depth_mm", "island", "species", "max"),
    ],
)
def test_df_pivot_table(
    penguins_df_default_index,
    penguins_pandas_df_default_index,
    values,
    index,
    columns,
    aggfunc,
):
    # pivot_table on the BigQuery DataFrame must produce the same frame as
    # pandas for scalar, list and tuple arguments and several aggregations.
    # Both libraries receive exactly the same keyword arguments.
    pivot_kwargs = {
        "values": values,
        "index": index,
        "columns": columns,
        "aggfunc": aggfunc,
    }

    actual = penguins_df_default_index.pivot_table(**pivot_kwargs).to_pandas()
    expected = penguins_pandas_df_default_index.pivot_table(**pivot_kwargs)

    # Dtypes and column-index types are not compared.
    pd.testing.assert_frame_equal(
        actual,
        expected,
        check_dtype=False,
        check_column_type=False,
    )
2635+
2636+
26092637
def test_ipython_key_completions_with_drop(scalars_dfs):
26102638
scalars_df, scalars_pandas_df = scalars_dfs
26112639
col_names = "string_col"

third_party/bigframes_vendored/pandas/core/frame.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4711,6 +4711,88 @@ def pivot(self, *, columns, index=None, values=None):
47114711
"""
47124712
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
47134713

4714+
def pivot_table(self, values=None, index=None, columns=None, aggfunc="mean"):
    """
    Create a spreadsheet-style pivot table as a DataFrame.

    The levels in the pivot table will be stored in MultiIndex objects (hierarchical indexes)
    on the index and columns of the result DataFrame.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> bpd.options.display.progress_bar = None

        >>> df = bpd.DataFrame({
        ...     'Product': ['Product A', 'Product B', 'Product A', 'Product B', 'Product A', 'Product B'],
        ...     'Region': ['East', 'West', 'East', 'West', 'West', 'East'],
        ...     'Sales': [100, 200, 150, 100, 200, 150],
        ...     'Rating': [3, 5, 4, 3, 3, 5]
        ... })
        >>> df
             Product Region  Sales  Rating
        0  Product A   East    100       3
        1  Product B   West    200       5
        2  Product A   East    150       4
        3  Product B   West    100       3
        4  Product A   West    200       3
        5  Product B   East    150       5
        <BLANKLINE>
        [6 rows x 4 columns]

    Using `pivot_table` with default aggfunc "mean":

        >>> pivot_table = df.pivot_table(
        ...     values=['Sales', 'Rating'],
        ...     index='Product',
        ...     columns='Region'
        ... )
        >>> pivot_table
                  Rating       Sales
        Region      East West   East   West
        Product
        Product A    3.5  3.0  125.0  200.0
        Product B    5.0  4.0  150.0  150.0
        <BLANKLINE>
        [2 rows x 4 columns]

    Using `pivot_table` with specified aggfunc "max":

        >>> pivot_table = df.pivot_table(
        ...     values=['Sales', 'Rating'],
        ...     index='Product',
        ...     columns='Region',
        ...     aggfunc="max"
        ... )
        >>> pivot_table
                  Rating      Sales
        Region      East West  East West
        Product
        Product A      4    3   150  200
        Product B      5    5   150  200
        <BLANKLINE>
        [2 rows x 4 columns]

    Args:
        values (str, object or a list of the previous, optional):
            Column(s) to aggregate and use for populating the new frame's
            values. If not specified, all remaining columns will be used and
            the result will have hierarchically indexed columns.

        index (str or object or a list of str, optional):
            Column(s) to use to make the new frame's index. If not given,
            uses existing index.

        columns (str or object or a list of str):
            Column(s) to use to make the new frame's columns.

        aggfunc (str, default "mean"):
            Name of the aggregation function used to compute summary
            statistics for each group (e.g., 'sum', 'mean').

    Returns:
        DataFrame: An Excel style pivot table.
    """
    # NOTE(review): the "If not specified ..." / "If not given ..." defaults
    # documented above for `values` and `index` read as if carried over from
    # pandas. The bigframes DataFrame.pivot_table implementation added in the
    # same change wraps a None argument as [None] before grouping, so confirm
    # that omitting them behaves as documented.
    #
    # Abstract stub: always raises; the signature and docstring are what this
    # definition contributes.
    raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
4795+
47144796
def stack(self, level=-1):
47154797
"""
47164798
Stack the prescribed level(s) from columns to index.

0 commit comments

Comments
 (0)