[WIP] Added sklearn wrapper for LDASeq model #1405
Merged: menshikh-iv merged 21 commits into piskvorky:develop from chinmayapancholi13:ldaseq_sklearn_wrapper on Jun 20, 2017
Changes from 13 commits
Commits (21, all by chinmayapancholi13):
73cd770 added new file for LDASeq model's sklearn wrapper
4744c7b PEP8 changes
d79f125 added 'transform' and 'partial_fit' methods
07efa33 added unit_tests for ldaseq model
d73838e PEP8 changes
6e57c5f PEP8 changes
c969c8b refactored code acc. to composite design pattern
8b0cced refactored wrapper and tests
ea9922e removed 'self.corpus' attribute
8f88a10 updated 'self.__model' to 'self.gensim_model'
4f33248 updated 'fit' and 'transform' functions
8aa6898 updated 'testTransform' test
77a8672 updated 'testTransform' test
ad895a2 added 'NotFittedError' in 'transform' function
6f9929a added 'testPersistence' and 'testModelNotFitted' tests
05b63e3 added description for 'docs' in docstring of 'transform'
3452e80 added 'testPipeline' test
492fbc6 PEP8 change
dec60e1 replaced 'text_lda' variable with 'text_ldaseq'
fd5fc90 updated 'testPersistence' test
e041431 set fixed seed in 'testPipeline' test
gensim/sklearn_integration/sklearn_wrapper_gensim_ldaseqmodel.py (95 additions, 0 deletions)
@@ -0,0 +1,95 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <[email protected]>
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Scikit learn interface for gensim for easy use of gensim with scikit-learn
Follows scikit-learn API conventions
"""

import numpy as np
from sklearn.base import TransformerMixin, BaseEstimator

from gensim import models
from gensim.sklearn_integration import base_sklearn_wrapper


class SklLdaSeqModel(base_sklearn_wrapper.BaseSklearnWrapper, TransformerMixin, BaseEstimator):
    """
    Base LdaSeq module
    """

    def __init__(self, time_slice=None, id2word=None, alphas=0.01, num_topics=10,
            initialize='gensim', sstats=None, lda_model=None, obs_variance=0.5, chain_variance=0.005, passes=10,
            random_state=None, lda_inference_max_iter=25, em_min_iter=6, em_max_iter=20, chunksize=100):
        """
        Sklearn wrapper for LdaSeq model. Class derived from gensim.models.LdaSeqModel
        """
        self.gensim_model = None
        self.time_slice = time_slice
        self.id2word = id2word
        self.alphas = alphas
        self.num_topics = num_topics
        self.initialize = initialize
        self.sstats = sstats
        self.lda_model = lda_model
        self.obs_variance = obs_variance
        self.chain_variance = chain_variance
        self.passes = passes
        self.random_state = random_state
        self.lda_inference_max_iter = lda_inference_max_iter
        self.em_min_iter = em_min_iter
        self.em_max_iter = em_max_iter
        self.chunksize = chunksize

    def get_params(self, deep=True):
        """
        Returns all parameters as dictionary.
        """
        return {"time_slice": self.time_slice, "id2word": self.id2word,
            "alphas": self.alphas, "num_topics": self.num_topics, "initialize": self.initialize,
            "sstats": self.sstats, "lda_model": self.lda_model, "obs_variance": self.obs_variance,
            "chain_variance": self.chain_variance, "passes": self.passes, "random_state": self.random_state,
            "lda_inference_max_iter": self.lda_inference_max_iter, "em_min_iter": self.em_min_iter,
            "em_max_iter": self.em_max_iter, "chunksize": self.chunksize}

    def set_params(self, **parameters):
        """
        Set all parameters.
        """
        super(SklLdaSeqModel, self).set_params(**parameters)

    def fit(self, X, y=None):
        """
        Fit the model according to the given training data.
        Calls gensim.models.LdaSeqModel
        """
        self.gensim_model = models.LdaSeqModel(corpus=X, time_slice=self.time_slice, id2word=self.id2word,
            alphas=self.alphas, num_topics=self.num_topics, initialize=self.initialize, sstats=self.sstats,
            lda_model=self.lda_model, obs_variance=self.obs_variance, chain_variance=self.chain_variance,
            passes=self.passes, random_state=self.random_state, lda_inference_max_iter=self.lda_inference_max_iter,
            em_min_iter=self.em_min_iter, em_max_iter=self.em_max_iter, chunksize=self.chunksize)
        return self

    def transform(self, docs):
        """
        Return the topic proportions for the documents passed.
        """
        # The input is an array of arrays; wrap a single document into a one-element list
        check = lambda x: [x] if isinstance(x[0], tuple) else x
        docs = check(docs)
        X = [[] for _ in range(0, len(docs))]

        for k, v in enumerate(docs):
            transformed_author = self.gensim_model[v]
            # Everything should be equal in length
            if len(transformed_author) != self.num_topics:
                transformed_author.extend([1e-12] * (self.num_topics - len(transformed_author)))
            X[k] = transformed_author

        return np.reshape(np.array(X), (len(docs), self.num_topics))

    def partial_fit(self, X):
        raise NotImplementedError("'partial_fit' has not been implemented for the LDA Seq model")
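As a side note, the padding step inside transform can be sketched in isolation (the proportion values below are made-up stand-ins for model output, not real LdaSeqModel results):

```python
import numpy as np

num_topics = 4
# Hypothetical topic proportions for one document, shorter than num_topics;
# the wrapper pads with a tiny epsilon so every row has the same length
# before the final reshape into a (n_docs, num_topics) matrix.
transformed = [0.7, 0.3]
if len(transformed) != num_topics:
    transformed.extend([1e-12] * (num_topics - len(transformed)))

X = np.reshape(np.array([transformed]), (1, num_topics))
print(X)
```

The epsilon keeps downstream consumers (e.g. sklearn pipelines) from seeing ragged rows while leaving the padded topics effectively at zero probability.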
Check the case when you create an instance and call transform immediately (without fit); you need to raise an exception, like sklearn does.
Also, please add an example of the docs param in the docstring.
@menshikh-iv For checking if the model has been fitted, would it be a good idea to check whether self.gensim_model is None? This approach would clearly give an error when fit hasn't been called before calling transform, but it also allows the user to set the value of self.gensim_model through the set_params function (or even as wrapper.gensim_model = ...) and then call transform, which makes sense for us to allow.
I completely forgot about set_params, so I think if you disable gensim_model in set_params, you can check model is None (it does not cover all cases, but it covers the most obvious one).
Could you elaborate on the meaning of "disabling" the gensim_model param in the set_params function? Actually, gensim_model is a public attribute of the model, so it can be set like ldaseq_wrapper.gensim_model = some_model, which is almost the same as using the set_params function to set this value. So checking whether self.gensim_model is None should be enough, right? This would be like:
Ok, as a temporary option.