From 44e0417faa3c0212118f2286f955cd972e1f2551 Mon Sep 17 00:00:00 2001 From: Nicolas Audebert Date: Tue, 25 Sep 2018 15:40:20 +0200 Subject: [PATCH] [Models] Add sklearn.NearestNeighbors model --- main.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/main.py b/main.py index 9c819e5..7e12066 100644 --- a/main.py +++ b/main.py @@ -269,6 +269,18 @@ def convert_from_color(x): save_model(clf, MODEL, DATASET) prediction = clf.predict(scaler.transform(img.reshape(-1, N_BANDS))) prediction = prediction.reshape(img.shape[:2]) + elif MODEL == 'nearest': + X_train, y_train = build_dataset(img, train_gt, + ignored_labels=IGNORED_LABELS) + X_train, y_train = sklearn.utils.shuffle(X_train, y_train) + class_weight = 'balanced' if CLASS_BALANCING else None + clf = sklearn.neighbors.KNeighborsClassifier(weights='distance') + clf = sklearn.model_selection.GridSearchCV(clf, {'n_neighbors': [1, 3, 5, 10, 20]}, verbose=5, n_jobs=4) + clf.fit(X_train, y_train) + clf.fit(X_train, y_train) + save_model(clf, MODEL, DATASET) + prediction = clf.predict(img.reshape(-1, N_BANDS)) + prediction = prediction.reshape(img.shape[:2]) else: # Neural network model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams)