diff --git a/gensim/models/ldamodel.py b/gensim/models/ldamodel.py
index 048e5d4c51..c2324904e3 100755
--- a/gensim/models/ldamodel.py
+++ b/gensim/models/ldamodel.py
@@ -34,6 +34,7 @@
 import logging
 import numpy as np  # for arrays, array broadcasting etc.
 import numbers
+import os
 
 from gensim import interfaces, utils, matutils
 from gensim.models import basemodel
@@ -239,11 +240,11 @@ def __init__(self, corpus=None, num_topics=100, id2word=None,
         prior directly from your data.
 
         `eta` can be a scalar for a symmetric prior over topic/word
-        distributions, or a vector of shape num_words, which can be used to 
-        impose (user defined) asymmetric priors over the word distribution. 
+        distributions, or a vector of shape num_words, which can be used to
+        impose (user defined) asymmetric priors over the word distribution.
         It also supports the special value 'auto', which learns an asymmetric
         prior over words directly from your data. `eta` can also be a matrix
-        of shape num_topics x num_words, which can be used to impose 
+        of shape num_topics x num_words, which can be used to impose
         asymmetric priors over the word distribution on a per-topic basis
         (can not be learned from data).
 
@@ -995,7 +996,7 @@ def __getitem__(self, bow, eps=None):
         """
         return self.get_document_topics(bow, eps, self.minimum_phi_value, self.per_word_topics)
 
-    def save(self, fname, ignore=['state', 'dispatcher'], *args, **kwargs):
+    def save(self, fname, ignore=['state', 'dispatcher'], separately=None, *args, **kwargs):
         """
         Save the model to file.
 
@@ -1024,18 +1025,39 @@ def save(self, fname, ignore=['state', 'dispatcher'], *args, **kwargs):
         """
         if self.state is not None:
             self.state.save(utils.smart_extension(fname, '.state'), *args, **kwargs)
+        # Save the dictionary separately if not in 'ignore'.
+        if 'id2word' not in ignore:
+            utils.pickle(self.id2word, utils.smart_extension(fname, '.id2word'))
 
-        # make sure 'state' and 'dispatcher' are ignored from the pickled object, even if
+        # make sure 'state', 'id2word' and 'dispatcher' are ignored from the pickled object, even if
         # someone sets the ignore list themselves
         if ignore is not None and ignore:
             if isinstance(ignore, six.string_types):
                 ignore = [ignore]
             ignore = [e for e in ignore if e]  # make sure None and '' are not in the list
-            ignore = list(set(['state', 'dispatcher']) | set(ignore))
+            ignore = list(set(['state', 'dispatcher', 'id2word']) | set(ignore))
         else:
-            ignore = ['state', 'dispatcher']
-        super(LdaModel, self).save(fname, *args, ignore=ignore, **kwargs)
-
+            ignore = ['state', 'dispatcher', 'id2word']
+
+        # make sure 'expElogbeta' and 'sstats' are ignored from the pickled object, even if
+        # someone sets the separately list themselves.
+        separately_explicit = ['expElogbeta', 'sstats']
+        # Also add 'alpha' and 'eta' to separately list if they are set 'auto' or some
+        # array manually.
+        if (isinstance(self.alpha, six.string_types) and self.alpha == 'auto') or len(self.alpha.shape) != 1:
+            separately_explicit.append('alpha')
+        if (isinstance(self.eta, six.string_types) and self.eta == 'auto') or len(self.eta.shape) != 1:
+            separately_explicit.append('eta')
+        # Merge separately_explicit with separately.
+        if separately:
+            if isinstance(separately, six.string_types):
+                separately = [separately]
+            separately = [e for e in separately if e]  # make sure None and '' are not in the list
+            separately = list(set(separately_explicit) | set(separately))
+        else:
+            separately = separately_explicit
+        super(LdaModel, self).save(fname, ignore=ignore, separately=separately, *args, **kwargs)
+
     @classmethod
     def load(cls, fname, *args, **kwargs):
         """
@@ -1053,5 +1075,13 @@ def load(cls, fname, *args, **kwargs):
             result.state = super(LdaModel, cls).load(state_fname, *args, **kwargs)
         except Exception as e:
             logging.warning("failed to load state from %s: %s", state_fname, e)
+        id2word_fname = utils.smart_extension(fname, '.id2word')
+        if (os.path.isfile(id2word_fname)):
+            try:
+                result.id2word = utils.unpickle(id2word_fname)
+            except Exception as e:
+                logging.warning("failed to load id2word dictionary from %s: %s", id2word_fname, e)
+        else:
+            result.id2word = None
         return result
 # endclass LdaModel
diff --git a/gensim/test/test_data/ldamodel_python_2_7 b/gensim/test/test_data/ldamodel_python_2_7
new file mode 100644
index 0000000000..f8ee3514e1
Binary files /dev/null and b/gensim/test/test_data/ldamodel_python_2_7 differ
diff --git a/gensim/test/test_data/ldamodel_python_2_7.expElogbeta.npy b/gensim/test/test_data/ldamodel_python_2_7.expElogbeta.npy
new file mode 100644
index 0000000000..1971e44b14
Binary files /dev/null and b/gensim/test/test_data/ldamodel_python_2_7.expElogbeta.npy differ
diff --git a/gensim/test/test_data/ldamodel_python_2_7.id2word b/gensim/test/test_data/ldamodel_python_2_7.id2word
new file mode 100644
index 0000000000..5fad7912fb
Binary files /dev/null and b/gensim/test/test_data/ldamodel_python_2_7.id2word differ
diff --git a/gensim/test/test_data/ldamodel_python_2_7.state b/gensim/test/test_data/ldamodel_python_2_7.state
new file mode 100644
index 0000000000..424cf0096d
Binary files /dev/null and b/gensim/test/test_data/ldamodel_python_2_7.state differ
diff --git a/gensim/test/test_data/ldamodel_python_3_5 b/gensim/test/test_data/ldamodel_python_3_5
new file mode 100644
index 0000000000..0733b35fd5
Binary files /dev/null and b/gensim/test/test_data/ldamodel_python_3_5 differ
diff --git a/gensim/test/test_data/ldamodel_python_3_5.expElogbeta.npy b/gensim/test/test_data/ldamodel_python_3_5.expElogbeta.npy
new file mode 100644
index 0000000000..1971e44b14
Binary files /dev/null and b/gensim/test/test_data/ldamodel_python_3_5.expElogbeta.npy differ
diff --git a/gensim/test/test_data/ldamodel_python_3_5.id2word b/gensim/test/test_data/ldamodel_python_3_5.id2word
new file mode 100644
index 0000000000..8c8a2b7af0
Binary files /dev/null and b/gensim/test/test_data/ldamodel_python_3_5.id2word differ
diff --git a/gensim/test/test_data/ldamodel_python_3_5.state b/gensim/test/test_data/ldamodel_python_3_5.state
new file mode 100644
index 0000000000..62d7635c32
Binary files /dev/null and b/gensim/test/test_data/ldamodel_python_3_5.state differ
diff --git a/gensim/test/test_ldamodel.py b/gensim/test/test_ldamodel.py
index 7f1b4ba44e..041724dad2 100644
--- a/gensim/test/test_ldamodel.py
+++ b/gensim/test/test_ldamodel.py
@@ -44,9 +44,10 @@
 corpus = [dictionary.doc2bow(text) for text in texts]
 
 
-def testfile():
+def testfile(test_fname=''):
     # temporary data will be stored to this file
-    return os.path.join(tempfile.gettempdir(), 'gensim_models.tst')
+    fname = 'gensim_models_' + test_fname + '.tst'
+    return os.path.join(tempfile.gettempdir(), fname)
 
 
 def testRandomState():
@@ -61,6 +62,7 @@ def setUp(self):
         self.class_ = ldamodel.LdaModel
         self.model = self.class_(corpus, id2word=dictionary, num_topics=2, passes=100)
 
+
     def testTransform(self):
         passed = False
         # sometimes, LDA training gets stuck at a local minimum
@@ -408,8 +410,22 @@ def testPersistence(self):
         tstvec = []
         self.assertTrue(np.allclose(model[tstvec], model2[tstvec]))  # try projecting an empty vector
 
+    def testModelCompatibilityWithPythonVersions(self):
+        fname_model_2_7 = datapath('ldamodel_python_2_7')
+        model_2_7 = self.class_.load(fname_model_2_7)
+        fname_model_3_5 = datapath('ldamodel_python_3_5')
+        model_3_5 = self.class_.load(fname_model_3_5)
+        self.assertEqual(model_2_7.num_topics, model_3_5.num_topics)
+        self.assertTrue(np.allclose(model_2_7.expElogbeta, model_3_5.expElogbeta))
+        tstvec = []
+        self.assertTrue(np.allclose(model_2_7[tstvec], model_3_5[tstvec]))  # try projecting an empty vector
+        id2word_2_7 = dict((k, v) for k, v in model_2_7.id2word.iteritems())
+        id2word_3_5 = dict((k, v) for k, v in model_3_5.id2word.iteritems())
+        self.assertEqual(set(id2word_2_7.keys()), set(id2word_3_5.keys()))
+
+
     def testPersistenceIgnore(self):
-        fname = testfile()
+        fname = testfile('testPersistenceIgnore')
         model = ldamodel.LdaModel(self.corpus, num_topics=2)
         model.save(fname, ignore='id2word')
         model2 = ldamodel.LdaModel.load(fname)
diff --git a/gensim/utils.py b/gensim/utils.py
index 5db16a92e5..907e1e657f 100644
--- a/gensim/utils.py
+++ b/gensim/utils.py
@@ -906,10 +906,12 @@ def pickle(obj, fname, protocol=2):
 
 def unpickle(fname):
     """Load pickled object from `fname`"""
-    with smart_open(fname) as f:
+    with smart_open(fname, 'rb') as f:
         # Because of loading from S3 load can't be used (missing readline in smart_open)
-        return _pickle.loads(f.read())
-
+        if sys.version_info > (3, 0):
+            return _pickle.load(f, encoding='latin1')
+        else:
+            return _pickle.loads(f.read())
 
 def revdict(d):
     """
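
Usage sketch (commentary, not part of the patch): with the modified `save()`/`load()` above, the dictionary goes to a separate `.id2word` pickle, the state to `.state`, and the attributes listed in `separately` (e.g. `expElogbeta`) to `.npy` side files, keeping the main pickle small and portable between Python 2 and 3. A minimal round trip, assuming the patched gensim is importable; the toy corpus and temp-file name are illustrative only:

    import os
    import tempfile

    from gensim.corpora import Dictionary
    from gensim.models import LdaModel

    # tiny illustrative corpus
    texts = [['human', 'interface', 'computer'],
             ['survey', 'user', 'computer', 'system', 'response', 'time']]
    dictionary = Dictionary(texts)
    corpus = [dictionary.doc2bow(text) for text in texts]

    model = LdaModel(corpus, id2word=dictionary, num_topics=2, passes=10)

    fname = os.path.join(tempfile.gettempdir(), 'lda_roundtrip.model')
    model.save(fname)  # also writes .state, .id2word and .expElogbeta.npy side files

    model2 = LdaModel.load(fname)
    assert model2.id2word is not None            # restored from the .id2word pickle
    assert model2.num_topics == model.num_topics

Passing `ignore='id2word'` to `save()` skips the extra pickle, and `load()` then leaves `id2word` as None, which is what `testPersistenceIgnore` exercises.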
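On the `utils.unpickle` change: pickles written by Python 2 contain byte strings (including numpy buffers) that Python 3's `pickle` typically cannot decode with its default ASCII encoding, so the loader branches on the interpreter version and passes `encoding='latin1'`, which maps bytes one-to-one. A standalone sketch of the same idea, independent of gensim's `smart_open` wrapper; the helper name is hypothetical:

    import pickle
    import sys

    def load_cross_version_pickle(fname):
        # hypothetical helper mirroring the branch added to gensim.utils.unpickle:
        # Python 3 needs encoding='latin1' to read pickles produced under Python 2,
        # while Python 2's pickle.load() accepts no encoding argument at all.
        with open(fname, 'rb') as f:
            if sys.version_info > (3, 0):
                return pickle.load(f, encoding='latin1')
            return pickle.load(f)

    # e.g. under Python 3:
    #   d = load_cross_version_pickle('gensim/test/test_data/ldamodel_python_2_7.id2word')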