diff --git a/pandas/plotting/_matplotlib/style.py b/pandas/plotting/_matplotlib/style.py index b919728971505..b2c7b2610845c 100644 --- a/pandas/plotting/_matplotlib/style.py +++ b/pandas/plotting/_matplotlib/style.py @@ -1,4 +1,14 @@ -# being a bit too dynamic +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterator, + List, + Optional, + Sequence, + Union, + cast, +) import warnings import matplotlib.cm as cm @@ -9,92 +19,256 @@ import pandas.core.common as com +if TYPE_CHECKING: + from matplotlib.colors import Colormap + + +Color = Union[str, Sequence[float]] + def get_standard_colors( - num_colors: int, colormap=None, color_type: str = "default", color=None + num_colors: int, + colormap: Optional["Colormap"] = None, + color_type: str = "default", + color: Optional[Union[Dict[str, Color], Color, Collection[Color]]] = None, ): - import matplotlib.pyplot as plt + """ + Get standard colors based on `colormap`, `color_type` or `color` inputs. + + Parameters + ---------- + num_colors : int + Minimum number of colors to be returned. + Ignored if `color` is a dictionary. + colormap : :py:class:`matplotlib.colors.Colormap`, optional + Matplotlib colormap. + When provided, the resulting colors will be derived from the colormap. + color_type : {"default", "random"}, optional + Type of colors to derive. Used if provided `color` and `colormap` are None. + Ignored if either `color` or `colormap` are not None. + color : dict or str or sequence, optional + Color(s) to be used for deriving sequence of colors. + Can be either be a dictionary, or a single color (single color string, + or sequence of floats representing a single color), + or a sequence of colors. + + Returns + ------- + dict or list + Standard colors. Can either be a mapping if `color` was a dictionary, + or a list of colors with a length of `num_colors` or more. + + Warns + ----- + UserWarning + If both `colormap` and `color` are provided. + Parameter `color` will override. + """ + if isinstance(color, dict): + return color + + colors = _derive_colors( + color=color, + colormap=colormap, + color_type=color_type, + num_colors=num_colors, + ) + + return _cycle_colors(colors, num_colors=num_colors) + + +def _derive_colors( + *, + color: Optional[Union[Color, Collection[Color]]], + colormap: Optional[Union[str, "Colormap"]], + color_type: str, + num_colors: int, +) -> List[Color]: + """ + Derive colors from either `colormap`, `color_type` or `color` inputs. + + Get a list of colors either from `colormap`, or from `color`, + or from `color_type` (if both `colormap` and `color` are None). + + Parameters + ---------- + color : str or sequence, optional + Color(s) to be used for deriving sequence of colors. + Can be either be a single color (single color string, or sequence of floats + representing a single color), or a sequence of colors. + colormap : :py:class:`matplotlib.colors.Colormap`, optional + Matplotlib colormap. + When provided, the resulting colors will be derived from the colormap. + color_type : {"default", "random"}, optional + Type of colors to derive. Used if provided `color` and `colormap` are None. + Ignored if either `color` or `colormap`` are not None. + num_colors : int + Number of colors to be extracted. + Returns + ------- + list + List of colors extracted. + + Warns + ----- + UserWarning + If both `colormap` and `color` are provided. + Parameter `color` will override. + """ if color is None and colormap is not None: - if isinstance(colormap, str): - cmap = colormap - colormap = cm.get_cmap(colormap) - if colormap is None: - raise ValueError(f"Colormap {cmap} is not recognized") - colors = [colormap(num) for num in np.linspace(0, 1, num=num_colors)] + return _get_colors_from_colormap(colormap, num_colors=num_colors) elif color is not None: if colormap is not None: warnings.warn( "'color' and 'colormap' cannot be used simultaneously. Using 'color'" ) - colors = ( - list(color) - if is_list_like(color) and not isinstance(color, dict) - else color - ) + return _get_colors_from_color(color) else: - if color_type == "default": - # need to call list() on the result to copy so we don't - # modify the global rcParams below - try: - colors = [c["color"] for c in list(plt.rcParams["axes.prop_cycle"])] - except KeyError: - colors = list(plt.rcParams.get("axes.color_cycle", list("bgrcmyk"))) - if isinstance(colors, str): - colors = list(colors) - - colors = colors[0:num_colors] - elif color_type == "random": - - def random_color(column): - """ Returns a random color represented as a list of length 3""" - # GH17525 use common._random_state to avoid resetting the seed - rs = com.random_state(column) - return rs.rand(3).tolist() - - colors = [random_color(num) for num in range(num_colors)] - else: - raise ValueError("color_type must be either 'default' or 'random'") + return _get_colors_from_color_type(color_type, num_colors=num_colors) - if isinstance(colors, str) and _is_single_color(colors): - # GH #36972 - colors = [colors] - # Append more colors by cycling if there is not enough color. - # Extra colors will be ignored by matplotlib if there are more colors - # than needed and nothing needs to be done here. +def _cycle_colors(colors: List[Color], num_colors: int) -> List[Color]: + """Append more colors by cycling if there is not enough color. + + Extra colors will be ignored by matplotlib if there are more colors + than needed and nothing needs to be done here. + """ if len(colors) < num_colors: - try: - multiple = num_colors // len(colors) - 1 - except ZeroDivisionError: - raise ValueError("Invalid color argument: ''") + multiple = num_colors // len(colors) - 1 mod = num_colors % len(colors) - colors += multiple * colors colors += colors[:mod] return colors -def _is_single_color(color: str) -> bool: - """Check if ``color`` is a single color. +def _get_colors_from_colormap( + colormap: Union[str, "Colormap"], + num_colors: int, +) -> List[Color]: + """Get colors from colormap.""" + colormap = _get_cmap_instance(colormap) + return [colormap(num) for num in np.linspace(0, 1, num=num_colors)] + + +def _get_cmap_instance(colormap: Union[str, "Colormap"]) -> "Colormap": + """Get instance of matplotlib colormap.""" + if isinstance(colormap, str): + cmap = colormap + colormap = cm.get_cmap(colormap) + if colormap is None: + raise ValueError(f"Colormap {cmap} is not recognized") + return colormap + + +def _get_colors_from_color( + color: Union[Color, Collection[Color]], +) -> List[Color]: + """Get colors from user input color.""" + if len(color) == 0: + raise ValueError(f"Invalid color argument: {color}") + + if _is_single_color(color): + color = cast(Color, color) + return [color] + + color = cast(Collection[Color], color) + return list(_gen_list_of_colors_from_iterable(color)) + + +def _is_single_color(color: Union[Color, Collection[Color]]) -> bool: + """Check if `color` is a single color, not a sequence of colors. + + Single color is of these kinds: + - Named color "red", "C0", "firebrick" + - Alias "g" + - Sequence of floats, such as (0.1, 0.2, 0.3) or (0.1, 0.2, 0.3, 0.4). + + See Also + -------- + _is_single_string_color + """ + if isinstance(color, str) and _is_single_string_color(color): + # GH #36972 + return True + + if _is_floats_color(color): + return True + + return False + + +def _gen_list_of_colors_from_iterable(color: Collection[Color]) -> Iterator[Color]: + """ + Yield colors from string of several letters or from collection of colors. + """ + for x in color: + if _is_single_color(x): + yield x + else: + raise ValueError(f"Invalid color {x}") + + +def _is_floats_color(color: Union[Color, Collection[Color]]) -> bool: + """Check if color comprises a sequence of floats representing color.""" + return bool( + is_list_like(color) + and (len(color) == 3 or len(color) == 4) + and all(isinstance(x, (int, float)) for x in color) + ) + + +def _get_colors_from_color_type(color_type: str, num_colors: int) -> List[Color]: + """Get colors from user input color type.""" + if color_type == "default": + return _get_default_colors(num_colors) + elif color_type == "random": + return _get_random_colors(num_colors) + else: + raise ValueError("color_type must be either 'default' or 'random'") + + +def _get_default_colors(num_colors: int) -> List[Color]: + """Get `num_colors` of default colors from matplotlib rc params.""" + import matplotlib.pyplot as plt + + colors = [c["color"] for c in plt.rcParams["axes.prop_cycle"]] + return colors[0:num_colors] + + +def _get_random_colors(num_colors: int) -> List[Color]: + """Get `num_colors` of random colors.""" + return [_random_color(num) for num in range(num_colors)] + + +def _random_color(column: int) -> List[float]: + """Get a random color represented as a list of length 3""" + # GH17525 use common._random_state to avoid resetting the seed + rs = com.random_state(column) + return rs.rand(3).tolist() + + +def _is_single_string_color(color: Color) -> bool: + """Check if `color` is a single string color. - Examples of single colors: + Examples of single string colors: - 'r' - 'g' - 'red' - 'green' - 'C3' + - 'firebrick' Parameters ---------- - color : string - Color string. + color : Color + Color string or sequence of floats. Returns ------- bool - True if ``color`` looks like a valid color. + True if `color` looks like a valid color. False otherwise. """ conv = matplotlib.colors.ColorConverter() diff --git a/pandas/tests/plotting/test_style.py b/pandas/tests/plotting/test_style.py new file mode 100644 index 0000000000000..665bda15724fd --- /dev/null +++ b/pandas/tests/plotting/test_style.py @@ -0,0 +1,157 @@ +import pytest + +from pandas import Series + +pytest.importorskip("matplotlib") +from pandas.plotting._matplotlib.style import get_standard_colors + + +class TestGetStandardColors: + @pytest.mark.parametrize( + "num_colors, expected", + [ + (3, ["red", "green", "blue"]), + (5, ["red", "green", "blue", "red", "green"]), + (7, ["red", "green", "blue", "red", "green", "blue", "red"]), + (2, ["red", "green"]), + (1, ["red"]), + ], + ) + def test_default_colors_named_from_prop_cycle(self, num_colors, expected): + import matplotlib as mpl + from matplotlib.pyplot import cycler + + mpl_params = { + "axes.prop_cycle": cycler(color=["red", "green", "blue"]), + } + with mpl.rc_context(rc=mpl_params): + result = get_standard_colors(num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize( + "num_colors, expected", + [ + (1, ["b"]), + (3, ["b", "g", "r"]), + (4, ["b", "g", "r", "y"]), + (5, ["b", "g", "r", "y", "b"]), + (7, ["b", "g", "r", "y", "b", "g", "r"]), + ], + ) + def test_default_colors_named_from_prop_cycle_string(self, num_colors, expected): + import matplotlib as mpl + from matplotlib.pyplot import cycler + + mpl_params = { + "axes.prop_cycle": cycler(color="bgry"), + } + with mpl.rc_context(rc=mpl_params): + result = get_standard_colors(num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize( + "num_colors, expected_name", + [ + (1, ["C0"]), + (3, ["C0", "C1", "C2"]), + ( + 12, + [ + "C0", + "C1", + "C2", + "C3", + "C4", + "C5", + "C6", + "C7", + "C8", + "C9", + "C0", + "C1", + ], + ), + ], + ) + def test_default_colors_named_undefined_prop_cycle(self, num_colors, expected_name): + import matplotlib as mpl + import matplotlib.colors as mcolors + + with mpl.rc_context(rc={}): + expected = [mcolors.to_hex(x) for x in expected_name] + result = get_standard_colors(num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize( + "num_colors, expected", + [ + (1, ["red", "green", (0.1, 0.2, 0.3)]), + (2, ["red", "green", (0.1, 0.2, 0.3)]), + (3, ["red", "green", (0.1, 0.2, 0.3)]), + (4, ["red", "green", (0.1, 0.2, 0.3), "red"]), + ], + ) + def test_user_input_color_sequence(self, num_colors, expected): + color = ["red", "green", (0.1, 0.2, 0.3)] + result = get_standard_colors(color=color, num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize( + "num_colors, expected", + [ + (1, ["r", "g", "b", "k"]), + (2, ["r", "g", "b", "k"]), + (3, ["r", "g", "b", "k"]), + (4, ["r", "g", "b", "k"]), + (5, ["r", "g", "b", "k", "r"]), + (6, ["r", "g", "b", "k", "r", "g"]), + ], + ) + def test_user_input_color_string(self, num_colors, expected): + color = "rgbk" + result = get_standard_colors(color=color, num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize( + "num_colors, expected", + [ + (1, [(0.1, 0.2, 0.3)]), + (2, [(0.1, 0.2, 0.3), (0.1, 0.2, 0.3)]), + (3, [(0.1, 0.2, 0.3), (0.1, 0.2, 0.3), (0.1, 0.2, 0.3)]), + ], + ) + def test_user_input_color_floats(self, num_colors, expected): + color = (0.1, 0.2, 0.3) + result = get_standard_colors(color=color, num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize( + "color, num_colors, expected", + [ + ("Crimson", 1, ["Crimson"]), + ("DodgerBlue", 2, ["DodgerBlue", "DodgerBlue"]), + ("firebrick", 3, ["firebrick", "firebrick", "firebrick"]), + ], + ) + def test_user_input_named_color_string(self, color, num_colors, expected): + result = get_standard_colors(color=color, num_colors=num_colors) + assert result == expected + + @pytest.mark.parametrize("color", ["", [], (), Series([], dtype="object")]) + def test_empty_color_raises(self, color): + with pytest.raises(ValueError, match="Invalid color argument"): + get_standard_colors(color=color, num_colors=1) + + @pytest.mark.parametrize( + "color", + [ + "bad_color", + ("red", "green", "bad_color"), + (0.1,), + (0.1, 0.2), + (0.1, 0.2, 0.3, 0.4, 0.5), # must be either 3 or 4 floats + ], + ) + def test_bad_color_raises(self, color): + with pytest.raises(ValueError, match="Invalid color"): + get_standard_colors(color=color, num_colors=5)