Skip to content
This repository has been archived by the owner on Sep 9, 2024. It is now read-only.

Commit

Permalink
Add inference capabilities.
Browse files Browse the repository at this point in the history
  • Loading branch information
nshaud committed May 18, 2018
1 parent e812c44 commit a54ca13
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# Visualization
import seaborn as sns

import os
from utils import metrics, convert_to_color_, convert_from_color_,\
display_dataset, display_predictions, explore_spectrums, plot_spectrums,\
sample_gt, build_dataset, show_results, compute_imf_weights
Expand Down Expand Up @@ -314,3 +315,29 @@ def convert_from_color(x):
if N_RUNS > 1:
show_results(results, label_values=LABEL_VALUES,
display=viz, agregated=True)

if INFERENCE is not None:
img = open_file(INFERENCE)[:,:,:-2]
# Normalization
img = np.asarray(img, dtype='float32')
img = (img - np.min(img)) / (np.max(img) - np.min(img))
if MODEL in ['SVM', 'SVM_grid', 'SGD']:
from sklearn.externals import joblib
model = joblib.load(CHECKPOINT)
X = scaler.transform(img.reshape(-1, N_BANDS))
prediction = model.predict(X)
prediction = prediction.reshape(img.shape[:2])
else:
model = get_model(MODEL, **kwargs)[0]
model.load_state_dict(torch.load(CHECKPOINT))
probabilities = test(model, img, hyperparams)
prediction = np.argmax(probabilities, axis=-1)

basename = os.path.basename(INFERENCE)
basename = str(model.__class__.__name__) + basename
dirname = os.path.dirname(INFERENCE)
filename = dirname + '/' + basename + '.tif'
io.imsave(filename, prediction)
basename = 'color_' + basename
filename = dirname + '/' + basename + '.tif'
io.imsave(filename, convert_to_color(prediction))

0 comments on commit a54ca13

Please sign in to comment.