-
- Notifications
You must be signed in to change notification settings - Fork 19.4k
ENH: DataFrame.plot.scatter argument c now accepts a column of strings, where rows with the same string are colored identically #59239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
b91e635 8609ea5 b4440c1 571c0c8 e9511d0 1ca57ed fb0d6e4 4bcdbfc 7972138 45886d9 1713727 62427ad 609fe40 6e86858 5223f2a d97606c 7e5a02a File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -10,6 +10,7 @@ | |
| Iterator, | ||
| Sequence, | ||
| ) | ||
| from random import shuffle | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| Any, | ||
| | @@ -1337,6 +1338,13 @@ def _make_plot(self, fig: Figure) -> None: | |
| norm, cmap = self._get_norm_and_cmap(c_values, color_by_categorical) | ||
| cb = self._get_colorbar(c_values, c_is_column) | ||
| | ||
| # if a list of non color strings is passed in as c, generate a list | ||
| # colored by uniqueness of the strings, such same strings get same color | ||
| create_colors = not self._are_valid_colors(c_values) | ||
| if create_colors: | ||
| custom_color_mapping, c_values = self._uniquely_color_strs(c_values) | ||
| cb = False # no colorbar; opt for legend | ||
| | ||
| if self.legend: | ||
| label = self.label | ||
| else: | ||
| | @@ -1367,6 +1375,15 @@ def _make_plot(self, fig: Figure) -> None: | |
| label, # type: ignore[arg-type] | ||
| ) | ||
| | ||
| # build legend for labeling custom colors | ||
| if create_colors: | ||
| ax.legend( | ||
| handles=[ | ||
| mpl.patches.Circle((0, 0), facecolor=color, label=string) | ||
| for string, color in custom_color_mapping.items() | ||
| ] | ||
| ) | ||
| | ||
| errors_x = self._get_errorbars(label=x, index=0, yerr=False) | ||
| errors_y = self._get_errorbars(label=y, index=0, xerr=False) | ||
| if len(errors_x) > 0 or len(errors_y) > 0: | ||
| | @@ -1390,6 +1407,38 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool): | |
| c_values = c | ||
| return c_values | ||
| | ||
| def _are_valid_colors(self, c_values: np.ndarray | list): | ||
| # check if c_values contains strings and if these strings are valid mpl colors. | ||
| # no need to check numerics as these (and mpl colors) will be validated for us | ||
| # in .Axes.scatter._parse_scatter_color_args(...) | ||
| try: | ||
| if len(c_values) and all(isinstance(c, str) for c in c_values): | ||
| mpl.colors.to_rgba_array(c_values) | ||
| | ||
| return True | ||
| | ||
| except (TypeError, ValueError) as _: | ||
| return False | ||
| | ||
| def _uniquely_color_strs( | ||
| ||
| self, c_values: np.ndarray | list | ||
| ) -> tuple[dict, np.ndarray]: | ||
| # well, almost uniquely color them (up to 949) | ||
| unique = np.unique(c_values) | ||
| | ||
| # for up to 7, lets keep colors consistent | ||
| if len(unique) <= 7: | ||
| possible_colors = list(mpl.colors.BASE_COLORS.values()) # Hex | ||
| # explore better ways to handle this case | ||
| else: | ||
| possible_colors = list(mpl.colors.XKCD_COLORS.values()) # Hex | ||
| shuffle(possible_colors) | ||
| | ||
| colors = [possible_colors[i % len(possible_colors)] for i in range(len(unique))] | ||
| color_mapping = dict(zip(unique, colors)) | ||
| | ||
| return color_mapping, np.array(list(map(color_mapping.get, c_values))) | ||
| | ||
| def _get_norm_and_cmap(self, c_values, color_by_categorical: bool): | ||
| c = self.c | ||
| if self.colormap is not None: | ||
| | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -207,6 +207,21 @@ def test_scatter_with_c_column_name_with_colors(self, cmap): | |
| ax = df.plot.scatter(x=0, y=1, c="species", cmap=cmap) | ||
| assert ax.collections[0].colorbar is None | ||
| | ||
| def test_scatter_with_c_column_name_without_colors(self): | ||
| df = DataFrame( | ||
| { | ||
| "dataX": range(100), | ||
| "dataY": range(100), | ||
| "state": ["NY", "MD", "MA", "CA"] * 25, | ||
| } | ||
| ) | ||
| df.plot.scatter("dataX", "dataY", c="state") | ||
| | ||
| with tm.assert_produces_warning(None): | ||
| ax = df.plot.scatter(x=0, y=1, c="state") | ||
| ||
| | ||
| assert len(np.unique(ax.collections[0].get_facecolor())) == 4 # 4 states | ||
| | ||
| def test_scatter_colors(self): | ||
| df = DataFrame({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]}) | ||
| with pytest.raises(TypeError, match="Specify exactly one of `c` and `color`"): | ||
| | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In what instances is
c_valuesa list? Might be misreading but would be better if we only worked with a pd.Series and could call .unique on that, instead of checking every single value in a loopThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to take a
pd.Series, notnp.ndarray | list