diff --git a/kmodes/kmodes.py b/kmodes/kmodes.py index 2bb6b7a..56c3581 100644 --- a/kmodes/kmodes.py +++ b/kmodes/kmodes.py @@ -127,7 +127,8 @@ def fit(self, X, y=None, sample_weight=None, **kwargs): X = pandas_to_numpy(X) random_state = check_random_state(self.random_state) - _validate_sample_weight(sample_weight, n_samples=X.shape[0]) + _validate_sample_weight(sample_weight, n_samples=X.shape[0], + n_clusters=self.n_clusters) self._enc_cluster_centroids, self._enc_map, self.labels_, self.cost_, \ self.n_iter_, self.epoch_costs_ = k_modes( @@ -407,7 +408,7 @@ def _move_point_cat(point, ipoint, to_clust, from_clust, cl_attr_freq, return cl_attr_freq, membship, centroids -def _validate_sample_weight(sample_weight, n_samples): +def _validate_sample_weight(sample_weight, n_samples, n_clusters): if sample_weight is not None: if len(sample_weight) != n_samples: raise ValueError("sample_weight should be of equal size as samples.") @@ -418,3 +419,6 @@ def _validate_sample_weight(sample_weight, n_samples): raise ValueError("sample_weight elements should either be int or floats.") if any(sample < 0 for sample in sample_weight): raise ValueError("sample_weight elements should be positive.") + if sum([x > 0 for x in sample_weight]) < n_clusters: + raise ValueError("Number of non-zero sample_weight elements should be " + "larger than the number of clusters.") diff --git a/kmodes/kprototypes.py b/kmodes/kprototypes.py index c159d32..3d4b57f 100644 --- a/kmodes/kprototypes.py +++ b/kmodes/kprototypes.py @@ -152,7 +152,8 @@ def fit(self, X, y=None, categorical=None, sample_weight=None): X = pandas_to_numpy(X) random_state = check_random_state(self.random_state) - kmodes._validate_sample_weight(sample_weight, n_samples=X.shape[0]) + kmodes._validate_sample_weight(sample_weight, n_samples=X.shape[0], + n_clusters=self.n_clusters) # If self.gamma is None, gamma will be automatically determined from # the data. The function below returns its value. @@ -175,7 +176,7 @@ def fit(self, X, y=None, categorical=None, sample_weight=None): return self - def predict(self, X, categorical=None): + def predict(self, X, categorical=None, **kwargs): """Predict the closest cluster each sample in X belongs to. Parameters diff --git a/kmodes/tests/test_kmodes.py b/kmodes/tests/test_kmodes.py index 7fc5966..2188bee 100644 --- a/kmodes/tests/test_kmodes.py +++ b/kmodes/tests/test_kmodes.py @@ -567,3 +567,11 @@ def test_k_modes_sample_weight_unchanged(self): tuple_pairs = zip(sorted(expected), sorted(factual)) for tuple_expected, tuple_factual in tuple_pairs: self.assertAlmostEqual(tuple_expected, tuple_factual) + + def test_kmodes_fit_predict(self): + """Test whether fit_predict interface works the same as fit and predict.""" + kmodes = KModes(n_clusters=4, init='Cao', random_state=42) + sample_weight = [0.5] * TEST_DATA.shape[0] + data1 = kmodes.fit_predict(TEST_DATA, sample_weight=sample_weight) + data2 = kmodes.fit(TEST_DATA, sample_weight=sample_weight).predict(TEST_DATA) + assert_cluster_splits_equal(data1, data2) diff --git a/kmodes/tests/test_kprototypes.py b/kmodes/tests/test_kprototypes.py index a57299b..5fde199 100644 --- a/kmodes/tests/test_kprototypes.py +++ b/kmodes/tests/test_kprototypes.py @@ -337,17 +337,26 @@ def test_kprototypes_ninit(self): def test_kprototypes_sample_weights_validation(self): kproto = kprototypes.KPrototypes(n_clusters=4, init='Cao', verbose=2) sample_weight_too_few = [1] * 11 - with self.assertRaisesRegex(ValueError, "sample_weight should be of equal size as samples."): + with self.assertRaisesRegex( + ValueError, + "sample_weight should be of equal size as samples." + ): kproto.fit_predict( STOCKS, categorical=[1, 2], sample_weight=sample_weight_too_few ) sample_weight_negative = [-1] + [1] * 11 - with self.assertRaisesRegex(ValueError, "sample_weight elements should be positive."): + with self.assertRaisesRegex( + ValueError, + "sample_weight elements should be positive." + ): kproto.fit_predict( STOCKS, categorical=[1, 2], sample_weight=sample_weight_negative ) sample_weight_non_numerical = [None] + [1] * 11 - with self.assertRaisesRegex(ValueError, "sample_weight elements should either be int or floats."): + with self.assertRaisesRegex( + ValueError, + "sample_weight elements should either be int or floats." + ): kproto.fit_predict( STOCKS, categorical=[1, 2], sample_weight=sample_weight_non_numerical ) @@ -362,7 +371,21 @@ def test_k_prototypes_sample_weight_all_but_one_zero(self): model = kproto.fit( STOCKS[:n_samples, :], categorical=[1, 2], sample_weight=sample_weight ) - self.assertTrue((model.cluster_centroids_[0, :] == STOCKS[indicator, :]).all()) + np.testing.assert_array_equal( + model.cluster_centroids_[0, :], + STOCKS[indicator, :] + ) + + def test_k_prototypes_sample_weight_not_enough_non_zero(self): + kproto = kprototypes.KPrototypes(n_clusters=2, init='Cao', random_state=42) + sample_weight = np.zeros(STOCKS.shape[0]) + sample_weight[0] = 1 + with self.assertRaisesRegex( + ValueError, + "Number of non-zero sample_weight elements should be larger " + "than the number of clusters." + ): + kproto.fit(STOCKS, categorical=[1, 2], sample_weight=sample_weight) def test_k_prototypes_sample_weight_unchanged(self): """Test whether centroid definition remains unchanged when scaling uniformly.""" @@ -390,3 +413,11 @@ def test_k_prototypes_sample_weight_unchanged(self): for index in categorical: self.assertTrue(tuple_expected[index] == tuple_factual[index]) + def test_kmodes_fit_predict_equality(self): + """Test whether fit_predict interface works the same as fit and predict.""" + kproto = kprototypes.KPrototypes(n_clusters=3, init='Cao', random_state=42) + sample_weight = [0.5] * STOCKS.shape[0] + model1 = kproto.fit(STOCKS, categorical=[1, 2], sample_weight=sample_weight) + data1 = model1.predict(STOCKS, categorical=[1, 2]) + data2 = kproto.fit_predict(STOCKS, categorical=[1, 2], sample_weight=sample_weight) + assert_cluster_splits_equal(data1, data2)