Skip to content
This repository has been archived by the owner on Sep 9, 2024. It is now read-only.

Commit

Permalink
[Models] Add sklearn.NearestNeighbors model
Browse files Browse the repository at this point in the history
  • Loading branch information
nshaud committed Sep 25, 2018
1 parent 55df732 commit 44e0417
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,18 @@ def convert_from_color(x):
save_model(clf, MODEL, DATASET)
prediction = clf.predict(scaler.transform(img.reshape(-1, N_BANDS)))
prediction = prediction.reshape(img.shape[:2])
elif MODEL == 'nearest':
X_train, y_train = build_dataset(img, train_gt,
ignored_labels=IGNORED_LABELS)
X_train, y_train = sklearn.utils.shuffle(X_train, y_train)
class_weight = 'balanced' if CLASS_BALANCING else None
clf = sklearn.neighbors.KNeighborsClassifier(weights='distance')
clf = sklearn.model_selection.GridSearchCV(clf, {'n_neighbors': [1, 3, 5, 10, 20]}, verbose=5, n_jobs=4)
clf.fit(X_train, y_train)
clf.fit(X_train, y_train)
save_model(clf, MODEL, DATASET)
prediction = clf.predict(img.reshape(-1, N_BANDS))
prediction = prediction.reshape(img.shape[:2])
else:
# Neural network
model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams)
Expand Down

0 comments on commit 44e0417

Please sign in to comment.