-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
43 lines (33 loc) · 1.22 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""Main module for GATech grading purposes"""
from typing import Tuple
import numpy as np
from image import rescale_image, scale_image
from k_means_numpy import KMeans
from k_medoids_numpy import KMedoids
from make_clusters import SAMPLE_DATA
def test_kmeans(pixels: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Grading entry point that clusters image data with k-means.

    The input is passed through ``scale_image`` before fitting. The fitted
    model's centroids are passed through ``rescale_image`` before being
    returned (presumably undoing the scaling; confirm in ``image``).

    :param pixels: A numpy array of image data
    :param k: Number of clusters to use
    :return: A ``(clusters, centroids)`` tuple: the model's ``clusters``
        attribute and its rescaled ``centroids``
    """
    normalized = scale_image(pixels)
    model = KMeans(k=k)
    # verbose=0 is passed through to fit(); assumed to suppress progress output.
    model.fit(normalized, verbose=0)
    labels = model.clusters
    return labels, rescale_image(model.centroids)
def test_kmedoids(pixels: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Grading entry point that clusters image data with k-medoids.

    The input is passed through ``scale_image`` before fitting. The fitted
    model's ``centroids`` (its medoids) are passed through ``rescale_image``
    before being returned (presumably undoing the scaling; confirm in
    ``image``).

    :param pixels: A numpy array of image data
    :param k: Number of clusters to use
    :return: A ``(clusters, centroids)`` tuple: the model's ``clusters``
        attribute and its rescaled ``centroids``
    """
    normalized = scale_image(pixels)
    model = KMedoids(k=k)
    # verbose=0 is passed through to fit(); assumed to suppress progress output.
    model.fit(normalized, verbose=0)
    labels = model.clusters
    return labels, rescale_image(model.centroids)
if __name__ == '__main__':
    # Demo run: cluster the bundled sample data into 5 groups with k-medoids
    # and print the cluster assignments followed by the centroids.
    cluster_labels, cluster_centers = test_kmedoids(SAMPLE_DATA, 5)
    for report in (cluster_labels, cluster_centers):
        print(report)