Skip to content

Commit

Permalink
TYP: improve type annotations for take_cmap_colors
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Jan 24, 2025
1 parent 550ed13 commit f2ebc8f
Showing 1 changed file with 44 additions and 18 deletions.
62 changes: 44 additions & 18 deletions src/cmasher/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from itertools import chain
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING, NewType
from typing import TYPE_CHECKING, NewType, overload

import matplotlib as mpl
import numpy as np
Expand All @@ -33,7 +33,7 @@
import os
import sys
from collections.abc import Callable, Iterator
from typing import Literal, Protocol, TypeAlias
from typing import Literal, Protocol, TypeAlias, TypeVar

from matplotlib.artist import Artist
from numpy.typing import NDArray
Expand All @@ -43,6 +43,9 @@
else:
from typing_extensions import Self

T = TypeVar("T", int, float)
RGB: TypeAlias = tuple[T, T, T]

class SupportsDunderLT(Protocol):
def __lt__(self, other: Self, /) -> bool: ...

Expand All @@ -51,6 +54,7 @@ def __gt__(self, other: Self, /) -> bool: ...

SupportsOrdering: TypeAlias = SupportsDunderLT | SupportsDunderGT


_HAS_VISCM = find_spec("viscm") is not None

# All declaration
Expand Down Expand Up @@ -78,12 +82,6 @@ def __gt__(self, other: Self, /) -> bool: ...
Category = NewType("Category", str)
Name = NewType("Name", str)

# Type aliases
RED: TypeAlias = float
GREEN: TypeAlias = float
BLUE: TypeAlias = float
RGB: TypeAlias = list[tuple[RED, GREEN, BLUE]]


# %% HELPER FUNCTIONS
# Define function for obtaining the sorting order for lightness ranking
Expand Down Expand Up @@ -1436,13 +1434,43 @@ def set_cmap_legend_entry(artist: Artist, label: str) -> None:


# Function to take N equally spaced colors from a colormap
@overload
def take_cmap_colors(
cmap: Colormap | Name,
N: int | None,
*,
cmap_range: tuple[float, float] = (0, 1),
return_fmt: Literal["float", "norm"] = "float",
) -> RGB[float]: ...


@overload
def take_cmap_colors(
cmap: Colormap | Name,
N: int | None,
*,
cmap_range: tuple[float, float] = (0, 1),
return_fmt: str = "float",
) -> RGB:
return_fmt: Literal["int", "8bit"],
) -> RGB[int]: ...


@overload
def take_cmap_colors(
cmap: Colormap | Name,
N: int | None,
*,
cmap_range: tuple[float, float] = (0, 1),
return_fmt: Literal["str", "hex"],
) -> list[str]: ...


def take_cmap_colors(
cmap: Colormap | Name,
N: int | None,
*,
cmap_range: tuple[float, float] = (0, 1),
return_fmt: Literal["float", "norm", "int", "8bit", "str", "hex"] = "float",
) -> RGB[float] | RGB[int] | list[str]:
"""
Takes `N` equally spaced colors from the provided colormap `cmap` and
returns them.
Expand Down Expand Up @@ -1514,9 +1542,6 @@ def take_cmap_colors(
that describe the same property, but have a different initial state.
"""
# Convert provided fmt to lowercase
return_fmt = return_fmt.lower()

# Obtain the colormap
if isinstance(cmap, str):
cmap = mpl.colormaps[cmap]
Expand Down Expand Up @@ -1544,12 +1569,13 @@ def take_cmap_colors(
colors = np.apply_along_axis(to_rgb, 1, colors) # type: ignore [call-overload]
if return_fmt in ("int", "8bit"):
colors = np.array(np.rint(colors * 255), dtype=int)
colors = list(map(tuple, colors))
return [(int(c[0]), int(c[1]), int(c[2])) for c in colors] # type: ignore [misc]
else:
return [(float(c[0]), float(c[1]), float(c[2])) for c in colors] # type: ignore [misc]
elif return_fmt in ("str", "hex"):
return [to_hex(x).upper() for x in colors]
else:
colors = [to_hex(x).upper() for x in colors]

# Return colors
return colors
raise ValueError(return_fmt)


# Function to view what a colormap looks like
Expand Down

0 comments on commit f2ebc8f

Please sign in to comment.