From 370d64b1067331b413d641103a52bd4c636ac702 Mon Sep 17 00:00:00 2001 From: Nico de Vos Date: Thu, 10 Mar 2022 22:33:00 -0800 Subject: [PATCH] don't keep re-using initial frequencies and memberships --- kmodes/kmodes.py | 4 ++-- kmodes/kprototypes.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/kmodes/kmodes.py b/kmodes/kmodes.py index 10941de..d2f6985 100644 --- a/kmodes/kmodes.py +++ b/kmodes/kmodes.py @@ -309,7 +309,7 @@ def _k_modes_single(X, n_clusters, n_points, n_attrs, max_iter, dissim, init, in epoch_costs = [cost] while itr < max_iter and not converged: itr += 1 - centroids, moves = _k_modes_iter( + centroids, cl_attr_freq, membship, moves = _k_modes_iter( X, centroids, cl_attr_freq, @@ -357,7 +357,7 @@ def _k_modes_iter(X, centroids, cl_attr_freq, membship, dissim, random_state): X[rindx], rindx, old_clust, from_clust, cl_attr_freq, membship, centroids ) - return centroids, moves + return centroids, cl_attr_freq, membship, moves def _move_point_cat(point, ipoint, to_clust, from_clust, cl_attr_freq, diff --git a/kmodes/kprototypes.py b/kmodes/kprototypes.py index e0528d2..1c84f88 100644 --- a/kmodes/kprototypes.py +++ b/kmodes/kprototypes.py @@ -417,10 +417,10 @@ def _k_prototypes_single(Xnum, Xcat, nnumattrs, ncatattrs, n_clusters, n_points, epoch_costs = [cost] while itr < max_iter and not converged: itr += 1 - centroids, moves = _k_prototypes_iter(Xnum, Xcat, centroids, - cl_attr_sum, cl_memb_sum, cl_attr_freq, - membship, num_dissim, cat_dissim, gamma, - random_state) + centroids, cl_attr_sum, cl_memb_sum, cl_attr_freq, membship, moves = \ + _k_prototypes_iter(Xnum, Xcat, centroids, cl_attr_sum, cl_memb_sum, + cl_attr_freq, membship, num_dissim, cat_dissim, + gamma, random_state) # All points seen in this iteration labels, ncost = labels_cost(Xnum, Xcat, centroids, @@ -487,7 +487,7 @@ def _k_prototypes_iter(Xnum, Xcat, centroids, cl_attr_sum, cl_memb_sum, cl_attr_ cl_attr_freq, membship, centroids[1] ) - return centroids, moves + return centroids, cl_attr_sum, cl_memb_sum, cl_attr_freq, membship, moves def _move_point_num(point, to_clust, from_clust, cl_attr_sum, cl_memb_sum):