From a54ca13f73554aeb6b4db5e0eeeaab28b2070cca Mon Sep 17 00:00:00 2001 From: Nicolas Audebert Date: Fri, 18 May 2018 16:41:30 +0200 Subject: [PATCH] Add inference capabilities. --- main.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/main.py b/main.py index 74c2947..6da9099 100644 --- a/main.py +++ b/main.py @@ -27,6 +27,7 @@ # Visualization import seaborn as sns +import os from utils import metrics, convert_to_color_, convert_from_color_,\ display_dataset, display_predictions, explore_spectrums, plot_spectrums,\ sample_gt, build_dataset, show_results, compute_imf_weights @@ -314,3 +315,29 @@ def convert_from_color(x): if N_RUNS > 1: show_results(results, label_values=LABEL_VALUES, display=viz, agregated=True) + +if INFERENCE is not None: + img = open_file(INFERENCE)[:,:,:-2] + # Normalization + img = np.asarray(img, dtype='float32') + img = (img - np.min(img)) / (np.max(img) - np.min(img)) + if MODEL in ['SVM', 'SVM_grid', 'SGD']: + from sklearn.externals import joblib + model = joblib.load(CHECKPOINT) + X = scaler.transform(img.reshape(-1, N_BANDS)) + prediction = model.predict(X) + prediction = prediction.reshape(img.shape[:2]) + else: + model = get_model(MODEL, **kwargs)[0] + model.load_state_dict(torch.load(CHECKPOINT)) + probabilities = test(model, img, hyperparams) + prediction = np.argmax(probabilities, axis=-1) + + basename = os.path.basename(INFERENCE) + basename = str(model.__class__.__name__) + basename + dirname = os.path.dirname(INFERENCE) + filename = dirname + '/' + basename + '.tif' + io.imsave(filename, prediction) + basename = 'color_' + basename + filename = dirname + '/' + basename + '.tif' + io.imsave(filename, convert_to_color(prediction))