Skip to content
This repository has been archived by the owner on Sep 9, 2024. It is now read-only.

Commit

Permalink
Simplify visdom management by removing overkill code
Browse files Browse the repository at this point in the history
  • Loading branch information
nshaud committed May 25, 2018
1 parent 1dc2c76 commit de19b6c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 63 deletions.
7 changes: 3 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,15 +297,14 @@ def convert_from_color(x):
prediction[mask] = 0

color_prediction = convert_to_color(prediction)
display_predictions(color_prediction, color_gt, display=viz)
display_predictions(color_prediction, color_gt, viz)

run_results = metrics(prediction, test_gt, ignored_labels=IGNORED_LABELS, n_classes=N_CLASSES)
results.append(run_results)
show_results(run_results, label_values=LABEL_VALUES, display=viz)
show_results(run_results, viz, label_values=LABEL_VALUES)

if N_RUNS > 1:
show_results(results, label_values=LABEL_VALUES,
display=viz, agregated=True)
show_results(results, viz, label_values=LABEL_VALUES, agregated=True)

if INFERENCE is not None:
img = open_file(INFERENCE)[:,:,:-2]
Expand Down
83 changes: 24 additions & 59 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import seaborn as sns
import itertools
import spectral
try:
import visdom
except:
pass
import visdom

# Torch
import torch
import torch.nn as nn
Expand Down Expand Up @@ -62,15 +60,13 @@ def convert_from_color_(arr_3d, palette=None):
return arr_2d


def display_predictions(pred, gt, display=None):
d_type = get_display_type(display)
if d_type == 'visdom':
display.images([np.transpose(pred, (2, 0, 1)),
np.transpose(gt, (2, 0, 1))],
nrow=2,
opts={'caption': "Prediction vs. ground truth"})
def display_predictions(pred, gt, vis):
vis.images([np.transpose(pred, (2, 0, 1)),
np.transpose(gt, (2, 0, 1))],
nrow=2,
opts={'caption': "Prediction vs. ground truth"})

def display_dataset(img, gt, bands, labels, palette, display=None):
def display_dataset(img, gt, bands, labels, palette, vis):
"""Display the specified dataset.
Args:
Expand All @@ -82,22 +78,19 @@ def display_dataset(img, gt, bands, labels, palette, display=None):
display (optional): type of display, if any
"""
d_type = get_display_type(display)
print("Image has dimensions {}x{} and {} channels".format(*img.shape))
rgb = spectral.get_rgb(img, bands)
rgb /= np.max(rgb)
rgb = np.asarray(255 * rgb, dtype='uint8')

# Display the RGB composite image
if d_type == 'visdom':
caption = "RGB (bands {}, {}, {}) and ground truth".format(*bands)
# send to visdom server
display.images([np.transpose(rgb, (2, 0, 1)),
np.transpose(convert_to_color_(gt, palette=palette),
(2, 0, 1))
],
nrow=2,
opts={'caption': caption})
caption = "RGB (bands {}, {}, {}) and ground truth".format(*bands)
# send to visdom server
vis.images([np.transpose(rgb, (2, 0, 1)),
np.transpose(convert_to_color_(gt, palette=palette), (2, 0, 1))
],
nrow=2,
opts={'caption': caption})

def explore_spectrums(img, complete_gt, class_names,
ignored_labels=None, display=None):
Expand All @@ -114,8 +107,7 @@ def explore_spectrums(img, complete_gt, class_names,
"""
mean_spectrums = {}
d_type = get_display_type(display)

d_type = 'visdom'
for c in np.unique(complete_gt):
if c in ignored_labels:
continue
Expand Down Expand Up @@ -155,7 +147,7 @@ def plot_spectrums(spectrums, display=None):
palette = sns.color_palette("hls", len(spectrums.keys()))
sns.set_palette(palette)

d_type = get_display_type(display)
d_type = 'visdom'
if d_type == 'visdom':
pass
elif d_type == 'plt':
Expand Down Expand Up @@ -339,22 +331,7 @@ def metrics(prediction, target, ignored_labels=[], n_classes=None):
return results


def get_display_type(display):
if display:
display_type = 'plt'
try:
if isinstance(display, visdom.Visdom):
display_type = 'visdom'
except NameError:
pass
else:
display_type = 'print'
return display_type


def show_results(results, label_values=None,
display=None, agregated=False):
d_type = get_display_type(display)
def show_results(results, vis, label_values=None, agregated=False):
text = ""

if agregated:
Expand All @@ -372,19 +349,10 @@ def show_results(results, label_values=None,
F1scores = results["F1 scores"]
kappa = results["Kappa"]

if d_type == 'visdom':
display.heatmap(cm, opts={'rownames': label_values,
'columnnames': label_values})
elif d_type == 'plt':
plt.rcParams.update({'font.size': 10})
sns.heatmap(cm, annot=True, square=True)
plt.title("Confusion matrix")
plt.show()
plt.rcParams.update({'font.size': 22})
elif d_type == 'print':
text += "Confusion matrix :\n"
text += str(cm)
text += "---\n"
vis.heatmap(cm, opts={'rownames': label_values, 'columnnames': label_values})
text += "Confusion matrix :\n"
text += str(cm)
text += "---\n"

if agregated:
text += ("Accuracy: {:.03f} +- {:.03f}\n".format(np.mean(accuracies),
Expand All @@ -409,11 +377,8 @@ def show_results(results, label_values=None,
else:
text += "Kappa: {:.03f}\n".format(kappa)

if d_type == 'visdom':
text = text.replace('\n', '<br/>')
display.text(text)
else:
print(text)
vis.text(text.replace('\n', '<br/>'))
print(text)


def sample_gt(gt, percentage, mode='random'):
Expand Down

0 comments on commit de19b6c

Please sign in to comment.