Update sklearn API for Gensim models #1473

Merged

Changes from 12 commits (of 27 total)

Commits
c55fcc9  renamed sklearn wrapper classes (Jul 6, 2017)
dde234f  added newline for flake8 check (Jul 7, 2017)
721806b  renamed sklearn api files (Jul 10, 2017)
3cdcfde  updated tests for sklearn api (Jul 10, 2017)
99dba85  updated ipynb for sklearn api (Jul 10, 2017)
4d1eaf4  PEP8 changes (Jul 10, 2017)
155a1ec  updated docstrings for sklearn wrappers (Jul 11, 2017)
ae6c0f3  added 'testPersistence' and 'testModelNotFitted' tests for author top… (Jul 11, 2017)
3c78873  removed 'set_params' function from all wrappers (chinmayapancholi13, Jul 13, 2017)
341ed1f  removed 'get_params' function from base class (chinmayapancholi13, Jul 14, 2017)
9113f82  removed 'get_params' function from all api classes (chinmayapancholi13, Jul 14, 2017)
2935680  removed 'partial_fit()' from base class (chinmayapancholi13, Jul 19, 2017)
9628f99  updated error message (chinmayapancholi13, Jul 19, 2017)
3849d06  updated error message for 'partial_fit' function in W2VTransformer (chinmayapancholi13, Jul 20, 2017)
6097349  removed 'BaseTransformer' class (chinmayapancholi13, Jul 26, 2017)
5b21875  updated error message for 'partial_fit' in 'W2VTransformer' (chinmayapancholi13, Jul 26, 2017)
6bfdb4d  added checks for setting attributes after calling 'fit' (chinmayapancholi13, Jul 27, 2017)
9f0be87  flake8 fix (chinmayapancholi13, Jul 27, 2017)
6004eee  using 'sparse2full' in 'transform' function (chinmayapancholi13, Jul 27, 2017)
3262ec2  added missing imports (chinmayapancholi13, Jul 27, 2017)
d4e560e  added comment about returning dense representation in 'transform' fun… (chinmayapancholi13, Jul 27, 2017)
ad3f1f7  added 'testConsistencyWithGensimModel' for ldamodel (chinmayapancholi13, Jul 27, 2017)
877632e  updated ipynb (chinmayapancholi13, Jul 27, 2017)
0871b50  updated 'testPartialFit' for Lda and Lsi transformers (chinmayapancholi13, Jul 28, 2017)
3f363a1  added author info (chinmayapancholi13, Jul 28, 2017)
c0894bc  added 'testConsistencyWithGensimModel' for w2v transformer (chinmayapancholi13, Jul 28, 2017)
9b7402d  removed merge conflicts (chinmayapancholi13, Aug 4, 2017)
562 changes: 562 additions & 0 deletions docs/notebooks/sklearn_api.ipynb

Large diffs are not rendered by default.

683 changes: 0 additions & 683 deletions docs/notebooks/sklearn_wrapper.ipynb

This file was deleted.

19 changes: 19 additions & 0 deletions gensim/sklearn_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2011 Radim Rehurek <[email protected]>
Owner

@chinmayapancholi13 please add yourself as the author. This is your baby :)

Contributor Author

Hehe. Ok! 😄

# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
"""Scikit learn wrapper for gensim.
Contains various gensim based implementations which match with scikit-learn standards.
See [1] for complete set of conventions.
[1] http://scikit-learn.org/stable/developers/
"""


from .basemodel import BaseTransformer # noqa: F401
from .ldamodel import LdaTransformer # noqa: F401
Owner

What is all that noqa: F401 for? Is it really necessary?

Contributor Author

A line having noqa doesn't trigger flake8 errors. Writing noqa: F401 saves us from flake8 errors complaining that we have imports (e.g. LdaTransformer) which we haven't used anywhere in the file.
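For illustration, a hypothetical one-liner showing the pattern (the real imports are in the diff below):

from .ldamodel import LdaTransformer  # noqa: F401  (re-exported name; suppresses flake8's "imported but unused" warning)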

Owner

It feels wrong to litter the gensim code with such constructs -- the gensim code is correct, this is essentially working around some idiosyncrasy (bug?) of an unrelated library.

By the way, how come we don't get these errors from all the other __init__ imports in gensim? Or do we? CC @menshikh-iv

Contributor

@piskvorky because flake8 analyzes only the diff each time; if you change the same lines in any other __init__ file, you will get the same error.

Owner

@piskvorky piskvorky Jul 28, 2017

OK, so the fake warnings will go away immediately? Why did we add these comments then? Let's remove them.

from .lsimodel import LsiTransformer # noqa: F401
from .rpmodel import RpTransformer # noqa: F401
from .ldaseqmodel import LdaSeqTransformer # noqa: F401
from .w2vmodel import W2VTransformer # noqa: F401
from .atmodel import AuthorTopicTransformer # noqa: F401
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from sklearn.exceptions import NotFittedError

from gensim import models
from gensim.sklearn_integration import BaseSklearnWrapper
from gensim.sklearn_api import BaseTransformer


class SklATModel(BaseSklearnWrapper, TransformerMixin, BaseEstimator):
class AuthorTopicTransformer(BaseTransformer, TransformerMixin, BaseEstimator):
"""
Base AuthorTopic module
"""
Expand All @@ -27,7 +27,7 @@ def __init__(self, num_topics=100, id2word=None, author2doc=None, doc2author=Non
gamma_threshold=0.001, serialized=False, serialization_path=None,
minimum_probability=0.01, random_state=None):
"""
Sklearn wrapper for AuthorTopic model. Class derived from gensim.models.AuthorTopicModel
Sklearn wrapper for AuthorTopic model. See gensim.models.AuthorTopicModel for parameter details.
"""
self.gensim_model = None
self.num_topics = num_topics
Expand All @@ -49,25 +49,6 @@ def __init__(self, num_topics=100, id2word=None, author2doc=None, doc2author=Non
self.minimum_probability = minimum_probability
self.random_state = random_state

def get_params(self, deep=True):
"""
Returns all parameters as dictionary.
"""
return {"num_topics": self.num_topics, "id2word": self.id2word,
"author2doc": self.author2doc, "doc2author": self.doc2author, "chunksize": self.chunksize,
"passes": self.passes, "iterations": self.iterations, "decay": self.decay,
"offset": self.offset, "alpha": self.alpha, "eta": self.eta, "update_every": self.update_every,
"eval_every": self.eval_every, "gamma_threshold": self.gamma_threshold,
"serialized": self.serialized, "serialization_path": self.serialization_path,
"minimum_probability": self.minimum_probability, "random_state": self.random_state}

def set_params(self, **parameters):
"""
Set all parameters.
"""
super(SklATModel, self).set_params(**parameters)
return self

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,16 @@
from abc import ABCMeta, abstractmethod


class BaseSklearnWrapper(object):
class BaseTransformer(object):
Collaborator

With this refactoring... I'm again wondering, is this class even helpful? Its set_params() does less than the implementation that would be inherited from sklearn's own BaseEstimator. The other methods just serve to enforce that any subclass override certain methods – a rigor that is possibly useful, but that sklearn does not, itself, impose within its own superclass-hierarchy.

Contributor Author

@gojomo scikit-learn indeed does not impose such a hierarchy among its classes. However, as you stated above, in addition to not having to implement the set_params() every time, we know which functions are to be implemented for each model. Also, if in the future, we decide to refine/improve the set_params() function and make it more sophisticated (e.g. as in sklearn's BaseEstimator), we would only have to do it once in the base class rather than having to make the same change in the sklearn-api class of each Gensim model. Hence, having an abstract base class and deriving classes for Gensim models from it would help us in writing new sklearn-api classes in the future as well.

Collaborator

But, the existing set_params() looks worse than just doing nothing (and inheriting from BaseEstimator as all current subclasses of this class already do) – unless I'm missing something. So there's no need to plan for a future where "we decide to refine/improve" it, and do such improvements in the base class. We can make that decision now, and delete 8 lines of code from the base class, and all the subclasses automatically get the better & more standard (for sklearn) implementation of set_params().

And then, the only benefit of this shared superclass is forcing an error as a reminder of what to implement. But if sklearn doesn't do that itself, is it that important to do so? It's also over-enforcing – it is not a strict requirement that all transformers support partial_fit() - not every algorithm will necessarily support that.

Contributor

@tmylk tmylk Jul 12, 2017

In my head I expect all sklearn models to have partial_fit, fit, transform and be used in a pipeline. We constantly get questions on the mailing list about how to update model X, for almost every sklearn algo. For word2vec and doc2vec, raise a NotImplementedError with an explanation pointing to the gensim API's update function, while noting it's not recommended as it hasn't been researched. And throw an explicit error that RandomProj can't do it because of algo limitations.

Also, the set_params method shouldn't be abstract and shouldn't be defined in all the children - it's the same code. Having a common set_params makes the base useful.
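A minimal sketch of the convention suggested here (hypothetical code, not part of this PR): keep the method but raise with an explanation.

def partial_fit(self, X):
    # explicit signal that incremental updates are unsupported for this transformer
    raise NotImplementedError(
        "'partial_fit' is not supported for this transformer; "
        "train on the full corpus with 'fit' instead."
    )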

Collaborator

OK, providing a place to include an error/warning is a potential reason to implement partial_fit() even where the underlying algorithm can't (or shouldn't casually) be used that way. BUT, simply not implementing a method is also a pretty good way of signalling that something isn't implemented – and indeed that seems to be what scikit-learn itself does. For that reason, it's even possible other code will use introspection about whether an sklearn pipeline object implements partial_fit() as a test of support, in which case having an implemented method that just errors could be suboptimal.

But with set_params(), again, why have any lesser-featured implementation here, when every subclass (so far) inherits from scikit-learn's own BaseEstimator, which provides a more fully-featured implementation? Why wouldn't we want fewer limitations (and match the source library's behavior) for zero lines of code, rather than more limitations with more code?
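For example, calling code that follows the scikit-learn convention would typically probe for the method rather than catch an exception (hypothetical sketch):

def supports_incremental_training(estimator):
    # scikit-learn's own signal: absence of 'partial_fit' means no incremental training
    return hasattr(estimator, "partial_fit")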

hi,

jumping in the conversation a little late here, but I do agree that the old BaseSklearnWrapper does very little since

  • the default get_params and set_params from sklearn's BaseEstimator are very reasonable and should only be overridden when necessary; if all the concrete wrapper classes inherit from BaseEstimator (as they do now) then there's no need to force them to re-implement these methods (see the sketch after this list)
  • @tmylk partial_fit is absolutely not a requirement for a class to be compatible with sklearn Pipeline since lots of algorithms are not online in nature. If there is a particular algorithm that most people are confused about, then the warning should be raised for that class only.
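A rough sketch of that first point (hypothetical toy class, not part of this PR): inheriting from sklearn's BaseEstimator already yields working get_params/set_params derived from the __init__ signature, so the wrappers need no hand-written versions.

from sklearn.base import BaseEstimator, TransformerMixin

class ToyTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, num_topics=100, id2word=None):
        self.num_topics = num_topics
        self.id2word = id2word

    def fit(self, X, y=None):
        return self

    def transform(self, docs):
        return docs

# Provided automatically by BaseEstimator:
# ToyTransformer(num_topics=50).get_params()  ->  {'id2word': None, 'num_topics': 50}
# ToyTransformer().set_params(num_topics=20)  ->  validates the parameter name and sets the attribute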

Contributor

I would struggle to find a gensim algorithm for which I haven't seen a "is it possible to update the model post-training?" question on the mailing list...

Collaborator

Right, it's always asked, but not always supported. The scikit-learn convention for not supporting incremental training is to simply not implement partial_fit(). Why not adopt the same convention for scikit-learn compatible classes here?

Owner

@piskvorky piskvorky Jul 21, 2017

I agree with @gojomo 's point of view here. The point of these adapters is to assume the "philosophy" of another package, whatever structure or idiosyncrasies that may have. The least amount of surprises and novelty for users of that package.

It is not the place of the adapters to introduce their own philosophy on class structures or missing methods or whatever.

Contributor Author

I have removed the BaseTransformer class now. Thanks for the useful feedback and suggestions. :)

"""
Base sklearn wrapper module
"""
__metaclass__ = ABCMeta

@abstractmethod
def get_params(self, deep=True):
pass

@abstractmethod
def set_params(self, **parameters):
"""
Set all parameters.
"""
for parameter, value in parameters.items():
setattr(self, parameter, value)
return self

@abstractmethod
def fit(self, X, y=None):
pass

@abstractmethod
def transform(self, docs, minimum_probability=None):
pass

@abstractmethod
def partial_fit(self, X):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

from gensim import models
from gensim import matutils
from gensim.sklearn_integration import BaseSklearnWrapper
from gensim.sklearn_api import BaseTransformer


class SklLdaModel(BaseSklearnWrapper, TransformerMixin, BaseEstimator):
class LdaTransformer(BaseTransformer, TransformerMixin, BaseEstimator):
"""
Base LDA module
"""
Expand All @@ -31,7 +31,7 @@ def __init__(
eval_every=10, iterations=50, gamma_threshold=0.001,
minimum_probability=0.01, random_state=None):
"""
Sklearn wrapper for LDA model. derived class for gensim.model.LdaModel .
Sklearn wrapper for LDA model. See gensim.model.LdaModel for parameter details.
"""
self.gensim_model = None
self.num_topics = num_topics
Expand All @@ -49,23 +49,6 @@ def __init__(
self.minimum_probability = minimum_probability
self.random_state = random_state

def get_params(self, deep=True):
"""
Returns all parameters as dictionary.
"""
return {"num_topics": self.num_topics, "id2word": self.id2word, "chunksize": self.chunksize,
"passes": self.passes, "update_every": self.update_every, "alpha": self.alpha, "eta": self.eta,
"decay": self.decay, "offset": self.offset, "eval_every": self.eval_every, "iterations": self.iterations,
"gamma_threshold": self.gamma_threshold, "minimum_probability": self.minimum_probability,
"random_state": self.random_state}

def set_params(self, **parameters):
"""
Set all parameters.
"""
super(SklLdaModel, self).set_params(**parameters)
return self

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
Expand All @@ -86,8 +69,8 @@ def fit(self, X, y=None):

def transform(self, docs):
"""
Takes as an list of input a documents (documents).
Returns matrix of topic distribution for the given document bow, where a_ij
Takes a list of documents as input ('docs').
Returns a matrix of topic distribution for the given document bow, where a_ij
indicates (topic_i, topic_probability_j).
The input `docs` should be in BOW format and can be a list of documents like : [ [(4, 1), (7, 1)], [(9, 1), (13, 1)], [(2, 1), (6, 1)] ]
or a single document like : [(4, 1), (7, 1)]
Expand All @@ -105,7 +88,7 @@ def transform(self, docs):
probs_docs = list(map(lambda x: x[1], doc_topics))
# Everything should be equal in length
if len(probs_docs) != self.num_topics:
probs_docs.extend([1e-12]*(self.num_topics - len(probs_docs)))
probs_docs.extend([1e-12] * (self.num_topics - len(probs_docs)))
X[k] = probs_docs
return np.reshape(np.array(X), (len(docs), self.num_topics))

Expand Down
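A hypothetical usage sketch of the transform interface described above (toy corpus; the exact parameter values are illustrative only):

from gensim.corpora import Dictionary
from gensim.sklearn_api import LdaTransformer

texts = [["graph", "trees"], ["graph", "minors"], ["trees", "minors"]]
id2word = Dictionary(texts)
corpus = [id2word.doc2bow(text) for text in texts]  # BOW format, as expected by fit/transform

lda = LdaTransformer(num_topics=2, id2word=id2word, iterations=20, random_state=1)
lda.fit(corpus)
docvecs = lda.transform(corpus)  # numpy array of shape (len(corpus), num_topics)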
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from sklearn.exceptions import NotFittedError

from gensim import models
from gensim.sklearn_integration import BaseSklearnWrapper
from gensim.sklearn_api import BaseTransformer


class SklLdaSeqModel(BaseSklearnWrapper, TransformerMixin, BaseEstimator):
class LdaSeqTransformer(BaseTransformer, TransformerMixin, BaseEstimator):
"""
Base LdaSeq module
"""
Expand All @@ -26,7 +26,7 @@ def __init__(self, time_slice=None, id2word=None, alphas=0.01, num_topics=10,
initialize='gensim', sstats=None, lda_model=None, obs_variance=0.5, chain_variance=0.005, passes=10,
random_state=None, lda_inference_max_iter=25, em_min_iter=6, em_max_iter=20, chunksize=100):
"""
Sklearn wrapper for LdaSeq model. Class derived from gensim.models.LdaSeqModel
Sklearn wrapper for LdaSeq model. See gensim.models.LdaSeqModel for parameter details.
"""
self.gensim_model = None
self.time_slice = time_slice
Expand All @@ -45,24 +45,6 @@ def __init__(self, time_slice=None, id2word=None, alphas=0.01, num_topics=10,
self.em_max_iter = em_max_iter
self.chunksize = chunksize

def get_params(self, deep=True):
"""
Returns all parameters as dictionary.
"""
return {"time_slice": self.time_slice, "id2word": self.id2word,
"alphas": self.alphas, "num_topics": self.num_topics, "initialize": self.initialize,
"sstats": self.sstats, "lda_model": self.lda_model, "obs_variance": self.obs_variance,
"chain_variance": self.chain_variance, "passes": self.passes, "random_state": self.random_state,
"lda_inference_max_iter": self.lda_inference_max_iter, "em_min_iter": self.em_min_iter,
"em_max_iter": self.em_max_iter, "chunksize": self.chunksize}

def set_params(self, **parameters):
"""
Set all parameters.
"""
super(SklLdaSeqModel, self).set_params(**parameters)
return self

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
Expand Down Expand Up @@ -97,6 +79,3 @@ def transform(self, docs):
X[k] = transformed_author

return np.reshape(np.array(X), (len(docs), self.num_topics))

def partial_fit(self, X):
raise NotImplementedError("'partial_fit' has not been implemented for SklLdaSeqModel")
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@

from gensim import models
from gensim import matutils
from gensim.sklearn_integration import BaseSklearnWrapper
from gensim.sklearn_api import BaseTransformer


class SklLsiModel(BaseSklearnWrapper, TransformerMixin, BaseEstimator):
class LsiTransformer(BaseTransformer, TransformerMixin, BaseEstimator):
"""
Base LSI module
"""

def __init__(self, num_topics=200, id2word=None, chunksize=20000,
decay=1.0, onepass=True, power_iters=2, extra_samples=100):
"""
Sklearn wrapper for LSI model. Class derived from gensim.model.LsiModel.
Sklearn wrapper for LSI model. See gensim.model.LsiModel for parameter details.
"""
self.gensim_model = None
self.num_topics = num_topics
Expand All @@ -38,21 +38,6 @@ def __init__(self, num_topics=200, id2word=None, chunksize=20000,
self.extra_samples = extra_samples
self.power_iters = power_iters

def get_params(self, deep=True):
"""
Returns all parameters as dictionary.
"""
return {"num_topics": self.num_topics, "id2word": self.id2word,
"chunksize": self.chunksize, "decay": self.decay, "onepass": self.onepass,
"extra_samples": self.extra_samples, "power_iters": self.power_iters}

def set_params(self, **parameters):
"""
Set all parameters.
"""
super(SklLsiModel, self).set_params(**parameters)
return self

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
Expand Down Expand Up @@ -81,13 +66,13 @@ def transform(self, docs):
# The input as array of array
check = lambda x: [x] if isinstance(x[0], tuple) else x
docs = check(docs)
X = [[] for i in range(0,len(docs))];
for k,v in enumerate(docs):
X = [[] for i in range(0, len(docs))]
for k, v in enumerate(docs):
doc_topics = self.gensim_model[v]
probs_docs = list(map(lambda x: x[1], doc_topics))
# Everything should be equal in length
if len(probs_docs) != self.num_topics:
probs_docs.extend([1e-12]*(self.num_topics - len(probs_docs)))
probs_docs.extend([1e-12] * (self.num_topics - len(probs_docs)))
X[k] = probs_docs
probs_docs = []
return np.reshape(np.array(X), (len(docs), self.num_topics))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,22 @@
from sklearn.exceptions import NotFittedError

from gensim import models
from gensim.sklearn_integration import BaseSklearnWrapper
from gensim.sklearn_api import BaseTransformer


class SklRpModel(BaseSklearnWrapper, TransformerMixin, BaseEstimator):
class RpTransformer(BaseTransformer, TransformerMixin, BaseEstimator):
"""
Base RP module
"""

def __init__(self, id2word=None, num_topics=300):
"""
Sklearn wrapper for RP model. Class derived from gensim.models.RpModel.
Sklearn wrapper for RP model. See gensim.models.RpModel for parameter details.
"""
self.gensim_model = None
self.id2word = id2word
self.num_topics = num_topics

def get_params(self, deep=True):
"""
Returns all parameters as dictionary.
"""
return {"id2word": self.id2word, "num_topics": self.num_topics}

def set_params(self, **parameters):
"""
Set all parameters.
"""
super(SklRpModel, self).set_params(**parameters)
return self

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
Expand Down Expand Up @@ -75,6 +62,3 @@ def transform(self, docs):
X[k] = probs_docs

return np.reshape(np.array(X), (len(docs), self.num_topics))

def partial_fit(self, X):
raise NotImplementedError("'partial_fit' has not been implemented for SklRpModel")
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from sklearn.exceptions import NotFittedError

from gensim import models
from gensim.sklearn_integration import BaseSklearnWrapper
from gensim.sklearn_api import BaseTransformer


class SklW2VModel(BaseSklearnWrapper, TransformerMixin, BaseEstimator):
class W2VTransformer(BaseTransformer, TransformerMixin, BaseEstimator):
"""
Base Word2Vec module
"""
Expand All @@ -28,7 +28,7 @@ def __init__(self, size=100, alpha=0.025, window=5, min_count=5,
sg=0, hs=0, negative=5, cbow_mean=1, hashfxn=hash, iter=5, null_word=0,
trim_rule=None, sorted_vocab=1, batch_words=10000):
"""
Sklearn wrapper for Word2Vec model. Class derived from gensim.models.Word2Vec
Sklearn wrapper for Word2Vec model. See gensim.models.Word2Vec for parameter details.
"""
self.gensim_model = None
self.size = size
Expand All @@ -51,24 +51,6 @@ def __init__(self, size=100, alpha=0.025, window=5, min_count=5,
self.sorted_vocab = sorted_vocab
self.batch_words = batch_words

def get_params(self, deep=True):
"""
Returns all parameters as dictionary.
"""
return {"size": self.size, "alpha": self.alpha, "window": self.window, "min_count": self.min_count,
"max_vocab_size": self.max_vocab_size, "sample": self.sample, "seed": self.seed,
"workers": self.workers, "min_alpha": self.min_alpha, "sg": self.sg, "hs": self.hs,
"negative": self.negative, "cbow_mean": self.cbow_mean, "hashfxn": self.hashfxn,
"iter": self.iter, "null_word": self.null_word, "trim_rule": self.trim_rule,
"sorted_vocab": self.sorted_vocab, "batch_words": self.batch_words}

def set_params(self, **parameters):
"""
Set all parameters.
"""
super(SklW2VModel, self).set_params(**parameters)
return self

def fit(self, X, y=None):
"""
Fit the model according to the given training data.
Expand Down Expand Up @@ -101,4 +83,5 @@ def transform(self, words):
return np.reshape(np.array(X), (len(words), self.size))

def partial_fit(self, X):
raise NotImplementedError("'partial_fit' has not been implemented for SklW2VModel")
raise NotImplementedError("'partial_fit' has not been implemented for W2VTransformer since 'update()' function for Word2Vec class is experimental. "
Owner

Is this correct? Why would the update() method of word2vec be experimental?

Contributor

Yes, it's correct.

Owner

@piskvorky piskvorky Jul 19, 2017

I don't even see any update() method. Can you elaborate?

Contributor Author

@piskvorky The original error message was not correct (since there is no update() function in Word2Vec class as you correctly pointed out) and I have updated the message now. My apologies for the confusion.
This change was made in reference to this comment by @gojomo in one of the older PRs.

Owner

@piskvorky piskvorky Jul 20, 2017

@chinmayapancholi13 thanks for the clarification.

So the concern is about initializing the vocabulary. That's a separate concern to incremental training, which is fully supported and not experimental (in fact, it's the same code as non-incremental training).

I don't think partial_fit is a high priority, but at the same time, users ask for it all the time, especially in combination with incremental vocab updates (which are indeed experimental).

So I see two directions (both non-critical, nice-to-have) here:

  1. support incremental training with a fixed, pre-defined vocabulary
  • implies finding a natural way to initialize the vocabulary beforehand
  • implies finding a natural way to control the learning rate (alpha) during the incremental calls.
  2. support fully online training, including updating the vocabulary incrementally
  • implies changing the w2v algo(s) to support this (hashing trick with fixed hash space? what would this do to HS? or do we not support HS in this online mode?)
  • still implies finding a way to control alpha

Contributor Author

Ah, I see. So the issue is actually "updating the vocab" (and not "training") being experimental. I have updated the error message in partial_fit for W2VTransformer for this accordingly.
Thanks for your explanation and suggestions about resolving this concern. I guess it would be nice to create an issue for this. :)

"See usage and documentation of Word2Vec's 'update()' function if you need to train your Word2Vec model incrementally.")