Skip to content

Commit 1df91da

Browse files
More careful change
1 parent c08c278 commit 1df91da

File tree

2 files changed

+35
-11
lines changed

2 files changed

+35
-11
lines changed

pandas/plotting/_matplotlib/boxplot.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -315,21 +315,20 @@ def plot_group(keys, values, ax):
315315
ax = plt.gca()
316316
data = data._get_numeric_data()
317317

318-
# if columns is None, use all numeric columns of data; if data columns
319-
# is MultiIndex, which means a Groupby has been applied before, select
320-
# data using new grouped column names; if data columns is Index, select
321-
# data simply using columns
318+
# if columns is None, use all numeric columns of data, so there is nothing to select
319+
# if the given 'column' is a subset of the column index, no matter whether data.columns
320+
# is a MultiIndex or a plain Index, select that subset of the data directly
321+
# if the given 'column' is not a subset and data.columns is a MultiIndex, then
322+
# keep every column whose tuple shares at least one element with 'columns'
322323
if columns is None:
323-
columns = data.columns
324+
pass
325+
elif set(column).issubset(data.columns):
326+
data = data.loc[:, columns]
324327
elif isinstance(data.columns, pd.MultiIndex):
325-
326-
# reselect columns with after-groupby multi-index columns
327-
data = data.loc[:, pd.IndexSlice[:, columns]]
328-
columns = data.columns
329-
elif isinstance(data.columns, pd.Index):
328+
columns = [col for col in data.columns if set(columns).intersection(col)]
330329
data = data.loc[:, columns]
331-
columns = data.columns
332330

331+
columns = data.columns
333332
result = plot_group(columns, data.values.T, ax)
334333
ax.grid(grid)
335334

pandas/tests/plotting/test_boxplot_method.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,3 +455,28 @@ def test_groupby_boxplot_subplots_false(self, col, expected_xticklabel):
455455
# check if xticks labels are plotted correctly
456456
result_xticklabel = [x.get_text() for x in axes.get_xticklabels()]
457457
assert expected_xticklabel == result_xticklabel
458+
459+
@pytest.mark.parametrize(
460+
"col, expected_xticklabel",
461+
[
462+
([("bar", "one"), ("bar", "two")], ["(bar, one)", "(bar, two)"]),
463+
("bar", ["(bar, one)", "(bar, two)"]),
464+
(["two"], ["(bar, two)", "(baz, two)", "(foo, two)", "(qux, two)"]),
465+
],
466+
)
467+
def test_boxplot_multiindex_column(self, col, expected_xticklabel):
468+
# test that boxplot works on DataFrames with MultiIndex columns
469+
arrays = [
470+
["bar", "bar", "baz", "baz", "foo", "foo", "qux", "qux"],
471+
["one", "two", "one", "two", "one", "two", "one", "two"],
472+
]
473+
tuples = list(zip(*arrays))
474+
index = MultiIndex.from_tuples(tuples, names=["first", "second"])
475+
df = DataFrame(np.random.randn(3, 8), index=["A", "B", "C"], columns=index)
476+
477+
# check if df.boxplot works
478+
axes = _check_plot_works(df.boxplot, column=col, return_type="axes")
479+
480+
# check if xticks labels are plotted correctly
481+
result_xticklabel = [x.get_text() for x in axes.get_xticklabels()]
482+
assert expected_xticklabel == result_xticklabel

0 commit comments

Comments
 (0)