From 706c70619f0c6679e4e5da1bfc040f1bcc465e16 Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Thu, 8 Jun 2017 19:19:37 -0700 Subject: [PATCH 01/15] added sklearn wrapper for author topic model --- .../sklearn_wrapper_gensim_atmodel.py | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py diff --git a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py new file mode 100644 index 0000000000..67cfaa92e6 --- /dev/null +++ b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (C) 2011 Radim Rehurek +# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html +# +""" +Scikit learn interface for gensim for easy use of gensim with scikit-learn +Follows scikit-learn API conventions +""" +import numpy as np + +from gensim import models +from gensim.sklearn_integration import base_sklearn_wrapper +from sklearn.base import TransformerMixin, BaseEstimator + + +class SklearnWrapperATModel(models.AuthorTopicModel, base_sklearn_wrapper.BaseSklearnWrapper, TransformerMixin, BaseEstimator): + """ + Base AuthorTopic module + """ + + def __init__(self, num_topics=100, id2word=None, author2doc=None, doc2author=None, + chunksize=2000, passes=1, iterations=50, decay=0.5, offset=1.0, + alpha='symmetric', eta='symmetric', update_every=1, eval_every=10, + gamma_threshold=0.001, serialized=False, serialization_path=None, + minimum_probability=0.01, random_state=None): + """ + Sklearn wrapper for AuthorTopic model. Class derived from gensim.models.AuthorTopicModel + """ + self.corpus = None + self.num_topics = num_topics + self.id2word = id2word + self.author2doc = author2doc + self.doc2author = doc2author + self.chunksize = chunksize + self.passes = passes + self.iterations = iterations + self.decay = decay + self.offset = offset + self.alpha = alpha + self.eta = eta + self.update_every = update_every + self.eval_every = eval_every + self.gamma_threshold = gamma_threshold + self.serialized = serialized + self.serialization_path = serialization_path + self.minimum_probability = minimum_probability + self.random_state = random_state + + def get_params(self, deep=True): + """ + Returns all parameters as dictionary. + """ + return {"corpus": self.corpus, "num_topics": self.num_topics, "id2word": self.id2word, + "author2doc": self.author2doc, "doc2author": self.doc2author, "chunksize": self.chunksize, + "passes": self.passes, "iterations": self.iterations, "decay": self.decay, + "offset": self.offset, "alpha": self.alpha, "eta": self.eta, "update_every": self.update_every, + "eval_every": self.eval_every, "gamma_threshold": self.gamma_threshold, + "serialized": self.serialized, "serialization_path": self.serialization_path, + "minimum_probability": self.minimum_probability, "random_state": self.random_state} + + def set_params(self, **parameters): + """ + Set all parameters. + """ + super(SklearnWrapperATModel, self).set_params(**parameters) + + def fit(self, X, y=None): + """ + Fit the model according to the given training data. + """ + pass + + def transform(self, docs): + """ + """ + pass + + def partial_fit(self, X): + """ + Train model over X. + """ + pass From 441e1c68988464e7c2cead7184036c9e4073557f Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Thu, 8 Jun 2017 20:09:22 -0700 Subject: [PATCH 02/15] added 'fit', 'partial_fit', 'transform' functions --- .../sklearn_wrapper_gensim_atmodel.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py index 67cfaa92e6..4bf6f7fed2 100644 --- a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py +++ b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py @@ -8,8 +8,6 @@ Scikit learn interface for gensim for easy use of gensim with scikit-learn Follows scikit-learn API conventions """ -import numpy as np - from gensim import models from gensim.sklearn_integration import base_sklearn_wrapper from sklearn.base import TransformerMixin, BaseEstimator @@ -69,16 +67,30 @@ def set_params(self, **parameters): def fit(self, X, y=None): """ Fit the model according to the given training data. + Calls gensim.models.AuthorTopicModel: + >>> gensim.models.AuthorTopicModel(corpus=self.corpus, num_topics=self.num_topics, id2word=self.id2word, author2doc=self.author2doc, doc2author=self.doc2author, + chunksize=self.chunksize, passes=self.passes, iterations=self.iterations, decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta update_every=self.update_every, + eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized, serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state) """ - pass + self.corpus = X + + super(SklearnWrapperATModel, self).__init__( + corpus=self.corpus, num_topics=self.num_topics, id2word=self.id2word, author2doc=self.author2doc, + doc2author=self.doc2author, chunksize=self.chunksize, passes=self.passes, iterations=self.iterations, + decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta, update_every=self.update_every, + eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized, + serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state + ) - def transform(self, docs): + def transform(self, author_names): """ + Return topic distribution for input author as a list of + (topic_id, topic_probabiity) 2-tuples. """ - pass + return self[author_names] - def partial_fit(self, X): + def partial_fit(self, X, author2doc=None, doc2author=None): """ Train model over X. """ - pass + self.update(corpus=X, author2doc=author2doc, doc2author=doc2author) From 921cd8600d122151395e88225311bd2de4cccd4a Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Thu, 8 Jun 2017 20:41:02 -0700 Subject: [PATCH 03/15] added unit-tests for ATModel --- gensim/test/test_sklearn_integration.py | 39 +++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/gensim/test/test_sklearn_integration.py b/gensim/test/test_sklearn_integration.py index 08e2ae9fe7..5ee6435482 100644 --- a/gensim/test/test_sklearn_integration.py +++ b/gensim/test/test_sklearn_integration.py @@ -16,6 +16,7 @@ from gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel import SklearnWrapperLdaModel from gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel import SklearnWrapperLsiModel +from gensim.sklearn_integration.sklearn_wrapper_gensim_atmodel import SklearnWrapperATModel from gensim.corpora import Dictionary from gensim import matutils @@ -35,6 +36,12 @@ ] dictionary = Dictionary(texts) corpus = [dictionary.doc2bow(text) for text in texts] +author2doc = {'john': [0, 1, 2, 3, 4, 5, 6], 'jane': [2, 3, 4, 5, 6, 7, 8], 'jack': [0, 2, 4, 6, 8], 'jill': [1, 3, 5, 7]} + +texts_new = texts[0:3] +author2doc_new = {'jill': [0], 'bob': [0, 1], 'sally': [1, 2]} +dictionary_new = Dictionary(texts_new) +corpus_new = [dictionary_new.doc2bow(text) for text in texts_new] class TestSklearnLDAWrapper(unittest.TestCase): @@ -192,5 +199,37 @@ def testSetGetParams(self): self.assertEqual(model_params[key], param_dict[key]) +class TestSklearnATModelWrapper(unittest.TestCase): + def setUp(self): + self.model = SklearnWrapperATModel(id2word=dictionary, author2doc=author2doc, num_topics=2, passes=100) + self.model.fit(corpus) + + def testTransform(self): + jill_topics = self.model['jill'] + jill_topics = matutils.sparse2full(jill_topics, self.model.num_topics) + self.assertTrue(all(jill_topics > 0)) + + def testPartialFit(self): + self.model.partial_fit(corpus_new, author2doc=author2doc_new) + + # Did we learn something about Sally? + sally_topics = self.model.get_author_topics('sally') + sally_topics = matutils.sparse2full(sally_topics, self.model.num_topics) + self.assertTrue(all(sally_topics > 0)) + + def testSetGetParams(self): + # updating only one param + self.model.set_params(num_topics=3) + model_params = self.model.get_params() + self.assertEqual(model_params["num_topics"], 3) + + # updating multiple params + param_dict = {"passes": 5, "iterations": 10} + self.model.set_params(**param_dict) + model_params = self.model.get_params() + for key in param_dict.keys(): + self.assertEqual(model_params[key], param_dict[key]) + + if __name__ == '__main__': unittest.main() From ed6f5a357098cf3da90e5d4b2b2d8e14c7ffc9c6 Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Tue, 13 Jun 2017 02:38:12 -0700 Subject: [PATCH 04/15] refactored code acc. to composite design pattern --- .../sklearn_wrapper_gensim_atmodel.py | 26 +++++++++---------- gensim/test/test_sklearn_integration.py | 4 +-- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py index 4bf6f7fed2..c213c486b9 100644 --- a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py +++ b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py @@ -3,17 +3,19 @@ # # Copyright (C) 2011 Radim Rehurek # Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html -# + """ Scikit learn interface for gensim for easy use of gensim with scikit-learn Follows scikit-learn API conventions """ + +from sklearn.base import TransformerMixin, BaseEstimator + from gensim import models from gensim.sklearn_integration import base_sklearn_wrapper -from sklearn.base import TransformerMixin, BaseEstimator -class SklearnWrapperATModel(models.AuthorTopicModel, base_sklearn_wrapper.BaseSklearnWrapper, TransformerMixin, BaseEstimator): +class SklearnWrapperATModel(base_sklearn_wrapper.BaseSklearnWrapper, TransformerMixin, BaseEstimator): """ Base AuthorTopic module """ @@ -27,6 +29,7 @@ def __init__(self, num_topics=100, id2word=None, author2doc=None, doc2author=Non Sklearn wrapper for AuthorTopic model. Class derived from gensim.models.AuthorTopicModel """ self.corpus = None + self.model = None self.num_topics = num_topics self.id2word = id2word self.author2doc = author2doc @@ -73,24 +76,21 @@ def fit(self, X, y=None): eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized, serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state) """ self.corpus = X - - super(SklearnWrapperATModel, self).__init__( - corpus=self.corpus, num_topics=self.num_topics, id2word=self.id2word, author2doc=self.author2doc, - doc2author=self.doc2author, chunksize=self.chunksize, passes=self.passes, iterations=self.iterations, - decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta, update_every=self.update_every, - eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized, - serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state - ) + self.model = models.AuthorTopicModel(corpus=self.corpus, num_topics=self.num_topics, id2word=self.id2word, + author2doc=self.author2doc, doc2author=self.doc2author, chunksize=self.chunksize, passes=self.passes, + iterations=self.iterations, decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta, + update_every=self.update_every, eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized, + serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state) def transform(self, author_names): """ Return topic distribution for input author as a list of (topic_id, topic_probabiity) 2-tuples. """ - return self[author_names] + return self.model[author_names] def partial_fit(self, X, author2doc=None, doc2author=None): """ Train model over X. """ - self.update(corpus=X, author2doc=author2doc, doc2author=doc2author) + self.model.update(corpus=X, author2doc=author2doc, doc2author=doc2author) diff --git a/gensim/test/test_sklearn_integration.py b/gensim/test/test_sklearn_integration.py index 5ee6435482..4942c35699 100644 --- a/gensim/test/test_sklearn_integration.py +++ b/gensim/test/test_sklearn_integration.py @@ -205,7 +205,7 @@ def setUp(self): self.model.fit(corpus) def testTransform(self): - jill_topics = self.model['jill'] + jill_topics = self.model.transform('jill') jill_topics = matutils.sparse2full(jill_topics, self.model.num_topics) self.assertTrue(all(jill_topics > 0)) @@ -213,7 +213,7 @@ def testPartialFit(self): self.model.partial_fit(corpus_new, author2doc=author2doc_new) # Did we learn something about Sally? - sally_topics = self.model.get_author_topics('sally') + sally_topics = self.model.model.get_author_topics('sally') sally_topics = matutils.sparse2full(sally_topics, self.model.num_topics) self.assertTrue(all(sally_topics > 0)) From 451536b1ba5c3260b108cb5ba1fa6ee842fc8a28 Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Wed, 14 Jun 2017 00:57:04 -0700 Subject: [PATCH 05/15] refactored wrapper and tests --- .../sklearn_wrapper_gensim_atmodel.py | 12 ++++++------ gensim/test/test_sklearn_integration.py | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py index c213c486b9..97ee3c2db7 100644 --- a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py +++ b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py @@ -15,7 +15,7 @@ from gensim.sklearn_integration import base_sklearn_wrapper -class SklearnWrapperATModel(base_sklearn_wrapper.BaseSklearnWrapper, TransformerMixin, BaseEstimator): +class SklATModel(base_sklearn_wrapper.BaseSklearnWrapper, TransformerMixin, BaseEstimator): """ Base AuthorTopic module """ @@ -28,8 +28,8 @@ def __init__(self, num_topics=100, id2word=None, author2doc=None, doc2author=Non """ Sklearn wrapper for AuthorTopic model. Class derived from gensim.models.AuthorTopicModel """ + self.__model = None self.corpus = None - self.model = None self.num_topics = num_topics self.id2word = id2word self.author2doc = author2doc @@ -65,7 +65,7 @@ def set_params(self, **parameters): """ Set all parameters. """ - super(SklearnWrapperATModel, self).set_params(**parameters) + super(SklATModel, self).set_params(**parameters) def fit(self, X, y=None): """ @@ -76,7 +76,7 @@ def fit(self, X, y=None): eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized, serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state) """ self.corpus = X - self.model = models.AuthorTopicModel(corpus=self.corpus, num_topics=self.num_topics, id2word=self.id2word, + self.__model = models.AuthorTopicModel(corpus=self.corpus, num_topics=self.num_topics, id2word=self.id2word, author2doc=self.author2doc, doc2author=self.doc2author, chunksize=self.chunksize, passes=self.passes, iterations=self.iterations, decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta, update_every=self.update_every, eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized, @@ -87,10 +87,10 @@ def transform(self, author_names): Return topic distribution for input author as a list of (topic_id, topic_probabiity) 2-tuples. """ - return self.model[author_names] + return self.__model[author_names] def partial_fit(self, X, author2doc=None, doc2author=None): """ Train model over X. """ - self.model.update(corpus=X, author2doc=author2doc, doc2author=doc2author) + self.__model.update(corpus=X, author2doc=author2doc, doc2author=doc2author) diff --git a/gensim/test/test_sklearn_integration.py b/gensim/test/test_sklearn_integration.py index 4942c35699..3173d546f8 100644 --- a/gensim/test/test_sklearn_integration.py +++ b/gensim/test/test_sklearn_integration.py @@ -16,7 +16,7 @@ from gensim.sklearn_integration.sklearn_wrapper_gensim_ldamodel import SklearnWrapperLdaModel from gensim.sklearn_integration.sklearn_wrapper_gensim_lsimodel import SklearnWrapperLsiModel -from gensim.sklearn_integration.sklearn_wrapper_gensim_atmodel import SklearnWrapperATModel +from gensim.sklearn_integration.sklearn_wrapper_gensim_atmodel import SklATModel from gensim.corpora import Dictionary from gensim import matutils @@ -199,9 +199,9 @@ def testSetGetParams(self): self.assertEqual(model_params[key], param_dict[key]) -class TestSklearnATModelWrapper(unittest.TestCase): +class TestSklATModelWrapper(unittest.TestCase): def setUp(self): - self.model = SklearnWrapperATModel(id2word=dictionary, author2doc=author2doc, num_topics=2, passes=100) + self.model = SklATModel(id2word=dictionary, author2doc=author2doc, num_topics=2, passes=100) self.model.fit(corpus) def testTransform(self): @@ -213,7 +213,7 @@ def testPartialFit(self): self.model.partial_fit(corpus_new, author2doc=author2doc_new) # Did we learn something about Sally? - sally_topics = self.model.model.get_author_topics('sally') + sally_topics = self.model.transform('sally') sally_topics = matutils.sparse2full(sally_topics, self.model.num_topics) self.assertTrue(all(sally_topics > 0)) From 497de9b2aa92b926409fc930930d90a26557e3ba Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Wed, 14 Jun 2017 15:55:24 +0530 Subject: [PATCH 06/15] removed 'self.corpus' attribute --- .../sklearn_wrapper_gensim_atmodel.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py index 97ee3c2db7..0206613022 100644 --- a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py +++ b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py @@ -29,7 +29,6 @@ def __init__(self, num_topics=100, id2word=None, author2doc=None, doc2author=Non Sklearn wrapper for AuthorTopic model. Class derived from gensim.models.AuthorTopicModel """ self.__model = None - self.corpus = None self.num_topics = num_topics self.id2word = id2word self.author2doc = author2doc @@ -53,7 +52,7 @@ def get_params(self, deep=True): """ Returns all parameters as dictionary. """ - return {"corpus": self.corpus, "num_topics": self.num_topics, "id2word": self.id2word, + return {"num_topics": self.num_topics, "id2word": self.id2word, "author2doc": self.author2doc, "doc2author": self.doc2author, "chunksize": self.chunksize, "passes": self.passes, "iterations": self.iterations, "decay": self.decay, "offset": self.offset, "alpha": self.alpha, "eta": self.eta, "update_every": self.update_every, @@ -70,13 +69,9 @@ def set_params(self, **parameters): def fit(self, X, y=None): """ Fit the model according to the given training data. - Calls gensim.models.AuthorTopicModel: - >>> gensim.models.AuthorTopicModel(corpus=self.corpus, num_topics=self.num_topics, id2word=self.id2word, author2doc=self.author2doc, doc2author=self.doc2author, - chunksize=self.chunksize, passes=self.passes, iterations=self.iterations, decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta update_every=self.update_every, - eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized, serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state) + Calls gensim.models.AuthorTopicModel """ - self.corpus = X - self.__model = models.AuthorTopicModel(corpus=self.corpus, num_topics=self.num_topics, id2word=self.id2word, + self.__model = models.AuthorTopicModel(corpus=X, num_topics=self.num_topics, id2word=self.id2word, author2doc=self.author2doc, doc2author=self.doc2author, chunksize=self.chunksize, passes=self.passes, iterations=self.iterations, decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta, update_every=self.update_every, eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized, From 6437ac75e0c543172ffa981bea36e3fdc24deb4e Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Wed, 14 Jun 2017 20:47:27 -0700 Subject: [PATCH 07/15] updates 'self.model' to 'self.gensim_model' --- .../sklearn_integration/sklearn_wrapper_gensim_atmodel.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py index 0206613022..984b0950a3 100644 --- a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py +++ b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py @@ -28,7 +28,7 @@ def __init__(self, num_topics=100, id2word=None, author2doc=None, doc2author=Non """ Sklearn wrapper for AuthorTopic model. Class derived from gensim.models.AuthorTopicModel """ - self.__model = None + self.gensim_model = None self.num_topics = num_topics self.id2word = id2word self.author2doc = author2doc @@ -71,7 +71,7 @@ def fit(self, X, y=None): Fit the model according to the given training data. Calls gensim.models.AuthorTopicModel """ - self.__model = models.AuthorTopicModel(corpus=X, num_topics=self.num_topics, id2word=self.id2word, + self.gensim_model = models.AuthorTopicModel(corpus=X, num_topics=self.num_topics, id2word=self.id2word, author2doc=self.author2doc, doc2author=self.doc2author, chunksize=self.chunksize, passes=self.passes, iterations=self.iterations, decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta, update_every=self.update_every, eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized, @@ -82,10 +82,10 @@ def transform(self, author_names): Return topic distribution for input author as a list of (topic_id, topic_probabiity) 2-tuples. """ - return self.__model[author_names] + return self.gensim_model[author_names] def partial_fit(self, X, author2doc=None, doc2author=None): """ Train model over X. """ - self.__model.update(corpus=X, author2doc=author2doc, doc2author=doc2author) + self.gensim_model.update(corpus=X, author2doc=author2doc, doc2author=doc2author) From 45acaf9c2bd91e72704f618662eee2f72d45a59d Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Thu, 15 Jun 2017 02:11:55 -0700 Subject: [PATCH 08/15] updated 'fit' and 'transform' functions --- .../sklearn_wrapper_gensim_atmodel.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py index 984b0950a3..8fb97d3d2e 100644 --- a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py +++ b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py @@ -8,7 +8,7 @@ Scikit learn interface for gensim for easy use of gensim with scikit-learn Follows scikit-learn API conventions """ - +import numpy as np from sklearn.base import TransformerMixin, BaseEstimator from gensim import models @@ -76,13 +76,27 @@ def fit(self, X, y=None): iterations=self.iterations, decay=self.decay, offset=self.offset, alpha=self.alpha, eta=self.eta, update_every=self.update_every, eval_every=self.eval_every, gamma_threshold=self.gamma_threshold, serialized=self.serialized, serialization_path=self.serialization_path, minimum_probability=self.minimum_probability, random_state=self.random_state) + return self def transform(self, author_names): """ - Return topic distribution for input author as a list of + Return topic distribution for input authors as a list of (topic_id, topic_probabiity) 2-tuples. """ - return self.gensim_model[author_names] + # The input as array of array + check = lambda x: [x] if isinstance(x[0], tuple) else x + author_names = check(author_names) + X = [[] for _ in range(0, len(author_names))] + + for k, v in enumerate(author_names): + transformed_author = self.gensim_model[v] + probs_author = list(map(lambda x: x[1], transformed_author)) + # Everything should be equal in length + if len(probs_author) != self.num_topics: + probs_author.extend([1e-12]*(self.num_topics - len(probs_author))) + X[k] = probs_author + + return np.reshape(np.array(X), (len(author_names), self.num_topics)) def partial_fit(self, X, author2doc=None, doc2author=None): """ From e50ffe492b1289247a9d54bf71137a90c9c7e964 Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Thu, 15 Jun 2017 02:12:56 -0700 Subject: [PATCH 09/15] updated 'testTransform' test --- gensim/test/test_sklearn_integration.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/gensim/test/test_sklearn_integration.py b/gensim/test/test_sklearn_integration.py index 3173d546f8..01df819c4e 100644 --- a/gensim/test/test_sklearn_integration.py +++ b/gensim/test/test_sklearn_integration.py @@ -205,9 +205,10 @@ def setUp(self): self.model.fit(corpus) def testTransform(self): - jill_topics = self.model.transform('jill') - jill_topics = matutils.sparse2full(jill_topics, self.model.num_topics) - self.assertTrue(all(jill_topics > 0)) + author_list = ['jill', 'jack'] + author_topics = self.model.transform(author_list) + self.assertEqual(author_topics.shape[0], 2) + self.assertEqual(author_topics.shape[1], self.model.num_topics) def testPartialFit(self): self.model.partial_fit(corpus_new, author2doc=author2doc_new) From ba1330fcf66f0c104bdff4e0226770be18f90b7a Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Thu, 15 Jun 2017 02:44:52 -0700 Subject: [PATCH 10/15] updated 'tranform' function slightly --- gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py index 8fb97d3d2e..c106ec04ee 100644 --- a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py +++ b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py @@ -84,7 +84,7 @@ def transform(self, author_names): (topic_id, topic_probabiity) 2-tuples. """ # The input as array of array - check = lambda x: [x] if isinstance(x[0], tuple) else x + check = lambda x: [x] if not isinstance(x, list) else x author_names = check(author_names) X = [[] for _ in range(0, len(author_names))] From c7c697b6216b077b45af09acfb629d4624ce5f49 Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Thu, 15 Jun 2017 02:45:30 -0700 Subject: [PATCH 11/15] updated 'testTransform' test --- gensim/test/test_sklearn_integration.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gensim/test/test_sklearn_integration.py b/gensim/test/test_sklearn_integration.py index 01df819c4e..032d6a33f9 100644 --- a/gensim/test/test_sklearn_integration.py +++ b/gensim/test/test_sklearn_integration.py @@ -205,6 +205,7 @@ def setUp(self): self.model.fit(corpus) def testTransform(self): + # transforming multiple authors author_list = ['jill', 'jack'] author_topics = self.model.transform(author_list) self.assertEqual(author_topics.shape[0], 2) @@ -214,8 +215,8 @@ def testPartialFit(self): self.model.partial_fit(corpus_new, author2doc=author2doc_new) # Did we learn something about Sally? - sally_topics = self.model.transform('sally') - sally_topics = matutils.sparse2full(sally_topics, self.model.num_topics) + output_topics = self.model.transform('sally') + sally_topics = output_topics[0] # getting the topics corresponding to 'sally' (from the list of lists) self.assertTrue(all(sally_topics > 0)) def testSetGetParams(self): From 4cba42031b67401d75ce9f1617be448f6b3d8fb7 Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Thu, 15 Jun 2017 02:46:53 -0700 Subject: [PATCH 12/15] PEP8 change --- gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py index c106ec04ee..5e14dc7a40 100644 --- a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py +++ b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py @@ -93,7 +93,7 @@ def transform(self, author_names): probs_author = list(map(lambda x: x[1], transformed_author)) # Everything should be equal in length if len(probs_author) != self.num_topics: - probs_author.extend([1e-12]*(self.num_topics - len(probs_author))) + probs_author.extend([1e-12] * (self.num_topics - len(probs_author))) X[k] = probs_author return np.reshape(np.array(X), (len(author_names), self.num_topics)) From 2ff7e85de6f9cc9d605daaf9f72d35051964b463 Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Thu, 15 Jun 2017 03:17:20 -0700 Subject: [PATCH 13/15] updated 'testTransform' test --- gensim/test/test_sklearn_integration.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gensim/test/test_sklearn_integration.py b/gensim/test/test_sklearn_integration.py index 032d6a33f9..ade83a76b5 100644 --- a/gensim/test/test_sklearn_integration.py +++ b/gensim/test/test_sklearn_integration.py @@ -211,6 +211,11 @@ def testTransform(self): self.assertEqual(author_topics.shape[0], 2) self.assertEqual(author_topics.shape[1], self.model.num_topics) + # transforming one author + jill_topics = self.model.transform('jill') + self.assertEqual(jill_topics.shape[0], 1) + self.assertEqual(jill_topics.shape[1], self.model.num_topics) + def testPartialFit(self): self.model.partial_fit(corpus_new, author2doc=author2doc_new) From ac5fa8198312e5d1834ef418e5d8cbedf6c1dd6b Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Thu, 22 Jun 2017 03:43:57 -0700 Subject: [PATCH 14/15] included 'NotFittedError' error --- gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py index 63d18a9208..2ef6db3c7a 100644 --- a/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py +++ b/gensim/sklearn_integration/sklearn_wrapper_gensim_atmodel.py @@ -10,6 +10,7 @@ """ import numpy as np from sklearn.base import TransformerMixin, BaseEstimator +from sklearn.exceptions import NotFittedError from gensim import models from gensim.sklearn_integration import BaseSklearnWrapper From 1ec40f90f8a48cd90d866b10216ddb188340a1b8 Mon Sep 17 00:00:00 2001 From: Chinmaya Pancholi Date: Fri, 23 Jun 2017 03:27:18 -0700 Subject: [PATCH 15/15] added pipeline unittest for ATModel --- gensim/test/test_sklearn_integration.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/gensim/test/test_sklearn_integration.py b/gensim/test/test_sklearn_integration.py index 0d003c3f6a..09e65dcc8b 100644 --- a/gensim/test/test_sklearn_integration.py +++ b/gensim/test/test_sklearn_integration.py @@ -10,7 +10,7 @@ from sklearn.pipeline import Pipeline from sklearn.feature_extraction.text import CountVectorizer from sklearn.datasets import load_files - from sklearn import linear_model + from sklearn import linear_model, cluster from sklearn.exceptions import NotFittedError except ImportError: raise unittest.SkipTest("Test requires scikit-learn to be installed, which is not available") @@ -441,6 +441,22 @@ def testSetGetParams(self): for key in param_dict.keys(): self.assertEqual(model_params[key], param_dict[key]) + def testPipeline(self): + # train the AuthorTopic model first + model = SklATModel(id2word=dictionary, author2doc=author2doc, num_topics=10, passes=100) + model.fit(corpus) + + # create and train clustering model + clstr = cluster.MiniBatchKMeans(n_clusters=2) + authors_full = ['john', 'jane', 'jack', 'jill'] + clstr.fit(model.transform(authors_full)) + + # stack together the two models in a pipeline + text_atm = Pipeline((('features', model,), ('cluster', clstr))) + author_list = ['jane', 'jack', 'jill'] + ret_val = text_atm.predict(author_list) + self.assertEqual(len(ret_val), len(author_list)) + if __name__ == '__main__': unittest.main()