Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add colormap arg to draw_label #217

Merged
merged 1 commit into from
Aug 14, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions labelme/utils/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,28 @@ def bitget(byteval, idx):
return cmap


def _validate_colormap(colormap, n_labels):
if colormap is None:
colormap = label_colormap(n_labels)
else:
assert colormap.shape == (colormap.shape[0], 3), \
'colormap must be sequence of RGB values'
assert 0 <= colormap.min() and colormap.max() <= 1, \
'colormap must ranges 0 to 1'
return colormap


# similar function as skimage.color.label2rgb
def label2rgb(lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0):
def label2rgb(
lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0, colormap=None,
):
if n_labels is None:
n_labels = len(np.unique(lbl))

cmap = label_colormap(n_labels)
cmap = (cmap * 255).astype(np.uint8)
colormap = _validate_colormap(colormap, n_labels)
colormap = (colormap * 255).astype(np.uint8)

lbl_viz = cmap[lbl]
lbl_viz = colormap[lbl]
lbl_viz[lbl == -1] = (0, 0, 0) # unlabeled

if img is not None:
Expand All @@ -48,8 +61,18 @@ def label2rgb(lbl, img=None, n_labels=None, alpha=0.5, thresh_suppress=0):
return lbl_viz


def draw_label(label, img=None, label_names=None, colormap=None):
def draw_label(label, img=None, label_names=None, colormap=None, **kwargs):
"""Draw pixel-wise label with colorization and label names.

label: ndarray, (H, W)
Pixel-wise labels to colorize.
img: ndarray, (H, W, 3), optional
Image on which the colorized label will be drawn.
label_names: iterable
List of label names.
"""
import matplotlib.pyplot as plt

backend_org = plt.rcParams['backend']
plt.switch_backend('agg')

Expand All @@ -62,10 +85,11 @@ def draw_label(label, img=None, label_names=None, colormap=None):
if label_names is None:
label_names = [str(l) for l in range(label.max() + 1)]

if colormap is None:
colormap = label_colormap(len(label_names))
colormap = _validate_colormap(colormap, len(label_names))

label_viz = label2rgb(label, img, n_labels=len(label_names))
label_viz = label2rgb(
label, img, n_labels=len(label_names), colormap=colormap, **kwargs
)
plt.imshow(label_viz)
plt.axis('off')

Expand Down