Skip to content
Merged
Prev Previous commit
Next Next commit
create labels for custom colors
  • Loading branch information
Michael Vincent Mannino authored and Michael Vincent Mannino committed Jul 12, 2024
commit 571c0c8c8269b4b81ca37e23613db4b1d048487f
57 changes: 31 additions & 26 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Iterator,
Sequence,
)
from random import shuffle
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -1337,10 +1338,12 @@ def _make_plot(self, fig: Figure) -> None:
norm, cmap = self._get_norm_and_cmap(c_values, color_by_categorical)
cb = self._get_colorbar(c_values, c_is_column)

orig_invalid_colors = not self._are_valid_colors(c_values)
if orig_invalid_colors:
unique_color_labels, c_values = self._convert_str_to_colors(c_values)
cb = False
# if a list of non color strings is passed in as c, generate a list
# colored by uniqueness of the strings, such same strings get same color
create_colors = not self._are_valid_colors(c_values)
if create_colors:
color_mapping, c_values = self._uniquely_color_strs(c_values)
cb = False # no colorbar; opt for legend

if self.legend:
label = self.label
Expand Down Expand Up @@ -1372,14 +1375,14 @@ def _make_plot(self, fig: Figure) -> None:
label, # type: ignore[arg-type]
)

if orig_invalid_colors:
for s in unique_color_labels:
self._append_legend_handles_labels(
# error: Argument 2 to "_append_legend_handles_labels" of
# "MPLPlot" has incompatible type "Hashable"; expected "str"
scatter,
s, # type: ignore[arg-type]
)
# build legend for labeling custom colors
if create_colors:
ax.legend(
handles=[
mpl.patches.Circle((0, 0), facecolor=color, label=string)
for string, color in color_mapping.items()
]
)

errors_x = self._get_errorbars(label=x, index=0, yerr=False)
errors_y = self._get_errorbars(label=y, index=0, xerr=False)
Expand All @@ -1404,29 +1407,31 @@ def _get_c_values(self, color, color_by_categorical: bool, c_is_column: bool):
c_values = c
return c_values

def _are_valid_colors(self, c_values):
# check if c_values contains strings. no need to check numerics as these
# will be validated for us in .Axes.scatter._parse_scatter_color_args(...)
if not (
np.iterable(c_values) and len(c_values) > 0 and isinstance(c_values[0], str)
):
return True

def _are_valid_colors(self, c_values: np.ndarray | list):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In what instances is c_values a list? Might be misreading but would be better if we only worked with a pd.Series and could call .unique on that, instead of checking every single value in a loop

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to take a pd.Series, not np.ndarray | list

# check if c_values contains strings and if these strings are valid mpl colors
# no need to check numerics as these (and mpl colors) will be validated for us
# in .Axes.scatter._parse_scatter_color_args(...)
try:
# similar to above, if this conversion is successful, remaining validation
# will be done in .Axes.scatter._parse_scatter_color_args(...)
_ = mpl.colors.to_rgba_array(c_values)
if len(c_values) and all(isinstance(c, str) for c in c_values):
mpl.colors.to_rgba_array(c_values)

return True

except (TypeError, ValueError) as _:
return False

def _convert_str_to_colors(self, c_values):
def _uniquely_color_strs(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am no matplotlib expert but I think we need to defer to that somehow to get the desired colors, instead of trying to write this out ourselves

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did some research looking into how other libs which support this functionality, including seaborn, handle this workflow. Each utilize the current mpl.colormap and draw colors from a linear space across this map. This allows users to change the choice colors of the same way they would with all other graphs

Additionally, I've added an automatic legend to this type of plot since the chosen colors are not exposed to the user, similar to how a colorbar is drawn in some cases of the same function

self, c_values: np.ndarray | list
) -> tuple[dict, np.ndarray]:
# well, almost uniquely color them (up to 949)
possible_colors = list(mpl.colors.XKCD_COLORS.values()) # Hex representations
shuffle(possible_colors) # TODO: find better way of getting colors

unique = np.unique(c_values)
colors = np.linspace(0, 1, len(unique))
colors = [possible_colors[i % len(possible_colors)] for i in range(len(unique))]
color_mapping = dict(zip(unique, colors))

return unique, np.array(list(map(color_mapping.get, c_values)))
return color_mapping, np.array(list(map(color_mapping.get, c_values)))

def _get_norm_and_cmap(self, c_values, color_by_categorical: bool):
c = self.c
Expand Down