LDA model load/save backward compatibility across Python versions #1039

Merged: 16 commits, Dec 22, 2016
48 changes: 39 additions & 9 deletions gensim/models/ldamodel.py
@@ -34,6 +34,7 @@
import logging
import numpy as np # for arrays, array broadcasting etc.
import numbers
import os

from gensim import interfaces, utils, matutils
from gensim.models import basemodel
@@ -239,11 +240,11 @@ def __init__(self, corpus=None, num_topics=100, id2word=None,
prior directly from your data.

`eta` can be a scalar for a symmetric prior over topic/word
distributions, or a vector of shape num_words, which can be used to
impose (user defined) asymmetric priors over the word distribution.
It also supports the special value 'auto', which learns an asymmetric
prior over words directly from your data. `eta` can also be a matrix
of shape num_topics x num_words, which can be used to impose
asymmetric priors over the word distribution on a per-topic basis
(can not be learned from data).
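As an aside, not part of the diff: a minimal sketch of the `eta` forms described in this docstring. The toy corpus and prior values are purely illustrative.

import numpy as np
from gensim.corpora import Dictionary
from gensim.models import LdaModel

# tiny illustrative corpus
texts = [['human', 'computer', 'interface'], ['graph', 'trees', 'minors']]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]
num_topics, num_words = 2, len(dictionary)

# scalar eta: symmetric prior over topic/word distributions
lda_scalar = LdaModel(corpus, id2word=dictionary, num_topics=num_topics, eta=0.1)

# vector of shape (num_words,): user-defined asymmetric prior over words
lda_vector = LdaModel(corpus, id2word=dictionary, num_topics=num_topics,
                      eta=np.full(num_words, 0.01))

# 'auto': learn an asymmetric prior over words from the data
lda_auto = LdaModel(corpus, id2word=dictionary, num_topics=num_topics, eta='auto')

# matrix of shape (num_topics, num_words): per-topic priors (not learned from data)
lda_matrix = LdaModel(corpus, id2word=dictionary, num_topics=num_topics,
                      eta=np.full((num_topics, num_words), 0.01))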

@@ -995,7 +996,7 @@ def __getitem__(self, bow, eps=None):
"""
return self.get_document_topics(bow, eps, self.minimum_phi_value, self.per_word_topics)

def save(self, fname, ignore=['state', 'dispatcher'], *args, **kwargs):
def save(self, fname, ignore=['state', 'dispatcher'], separately=None, *args, **kwargs):
"""
Save the model to file.

@@ -1024,18 +1025,39 @@ def save(self, fname, ignore=['state', 'dispatcher'], *args, **kwargs):
"""
if self.state is not None:
self.state.save(utils.smart_extension(fname, '.state'), *args, **kwargs)
# Save the dictionary separately if not in 'ignore'.
if 'id2word' not in ignore:
utils.pickle(self.id2word, utils.smart_extension(fname, '.id2word'))

# make sure 'state' and 'dispatcher' are ignored from the pickled object, even if
# make sure 'state', 'id2word' and 'dispatcher' are ignored from the pickled object, even if
# someone sets the ignore list themselves
if ignore is not None and ignore:
if isinstance(ignore, six.string_types):
ignore = [ignore]
ignore = [e for e in ignore if e] # make sure None and '' are not in the list
ignore = list(set(['state', 'dispatcher']) | set(ignore))
ignore = list(set(['state', 'dispatcher', 'id2word']) | set(ignore))
else:
ignore = ['state', 'dispatcher']
super(LdaModel, self).save(fname, *args, ignore=ignore, **kwargs)

ignore = ['state', 'dispatcher', 'id2word']

# make sure 'expElogbeta' and 'sstats' are always stored separately from the pickled
# object, even if someone sets the separately list themselves.
separately_explicit = ['expElogbeta', 'sstats']
# Also add 'alpha' and 'eta' to the separately list if they are set to 'auto'
# or set manually to an array.
if (isinstance(self.alpha, six.string_types) and self.alpha == 'auto') or len(self.alpha.shape) != 1:
separately_explicit.append('alpha')
if (isinstance(self.eta, six.string_types) and self.eta == 'auto') or len(self.eta.shape) != 1:
separately_explicit.append('eta')
# Merge separately_explicit with separately.
if separately:
if isinstance(separately, six.string_types):
separately = [separately]
separately = [e for e in separately if e] # make sure None and '' are not in the list
separately = list(set(separately_explicit) | set(separately))
else:
separately = separately_explicit
super(LdaModel, self).save(fname, ignore=ignore, separately=separately, *args, **kwargs)

@classmethod
def load(cls, fname, *args, **kwargs):
"""
@@ -1053,5 +1075,13 @@ def load(cls, fname, *args, **kwargs):
result.state = super(LdaModel, cls).load(state_fname, *args, **kwargs)
except Exception as e:
logging.warning("failed to load state from %s: %s", state_fname, e)
id2word_fname = utils.smart_extension(fname, '.id2word')
if os.path.isfile(id2word_fname):
try:
result.id2word = utils.unpickle(id2word_fname)
except Exception as e:
logging.warning("failed to load id2word dictionary from %s: %s", id2word_fname, e)
else:
result.id2word = None
return result
# endclass LdaModel
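As a usage sketch, not taken from the PR itself: with the changes above, a model saved under one Python version can be loaded under another. The corpus, file name and assertions below are illustrative.

import os
import tempfile

from gensim.corpora import Dictionary
from gensim.models import LdaModel

# tiny illustrative corpus
texts = [['human', 'computer', 'interface'], ['graph', 'trees', 'minors']]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]

model = LdaModel(corpus, id2word=dictionary, num_topics=2)

# save() writes the main pickle plus '<fname>.state' and '<fname>.id2word';
# arrays listed in `separately` (e.g. expElogbeta) go into their own files
fname = os.path.join(tempfile.gettempdir(), 'lda_model_example')
model.save(fname)

# load() restores the state and the separately pickled id2word dictionary;
# id2word is set to None if the '.id2word' file is missing
model2 = LdaModel.load(fname)
assert model2.num_topics == model.num_topics
assert model2.id2word is not None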
Binary file added gensim/test/test_data/ldamodel_python_2_7
Binary file not shown.
Binary file not shown.
Binary file added gensim/test/test_data/ldamodel_python_2_7.id2word
Binary file not shown.
Binary file added gensim/test/test_data/ldamodel_python_2_7.state
Binary file not shown.
Binary file added gensim/test/test_data/ldamodel_python_3_5
Binary file not shown.
Binary file not shown.
Binary file added gensim/test/test_data/ldamodel_python_3_5.id2word
Binary file not shown.
Binary file added gensim/test/test_data/ldamodel_python_3_5.state
Binary file not shown.
22 changes: 19 additions & 3 deletions gensim/test/test_ldamodel.py
@@ -44,9 +44,10 @@
corpus = [dictionary.doc2bow(text) for text in texts]


def testfile():
def testfile(test_fname=''):
# temporary data will be stored to this file
return os.path.join(tempfile.gettempdir(), 'gensim_models.tst')
fname = 'gensim_models_' + test_fname + '.tst'
return os.path.join(tempfile.gettempdir(), fname)


def testRandomState():
@@ -61,6 +62,7 @@ def setUp(self):
self.class_ = ldamodel.LdaModel
self.model = self.class_(corpus, id2word=dictionary, num_topics=2, passes=100)


def testTransform(self):
passed = False
# sometimes, LDA training gets stuck at a local minimum
@@ -408,8 +410,22 @@ def testPersistence(self):
tstvec = []
self.assertTrue(np.allclose(model[tstvec], model2[tstvec])) # try projecting an empty vector

def testModelCompatibilityWithPythonVersions(self):
fname_model_2_7 = datapath('ldamodel_python_2_7')
model_2_7 = self.class_.load(fname_model_2_7)
fname_model_3_5 = datapath('ldamodel_python_3_5')
model_3_5 = self.class_.load(fname_model_3_5)
self.assertEqual(model_2_7.num_topics, model_3_5.num_topics)
self.assertTrue(np.allclose(model_2_7.expElogbeta, model_3_5.expElogbeta))
tstvec = []
self.assertTrue(np.allclose(model_2_7[tstvec], model_3_5[tstvec])) # try projecting an empty vector
id2word_2_7 = dict((k, v) for k, v in model_2_7.id2word.items())  # items() works under both Python 2 and 3
id2word_3_5 = dict((k, v) for k, v in model_3_5.id2word.items())
self.assertEqual(set(id2word_2_7.keys()), set(id2word_3_5.keys()))


def testPersistenceIgnore(self):
fname = testfile()
fname = testfile('testPersistenceIgnore')
model = ldamodel.LdaModel(self.corpus, num_topics=2)
model.save(fname, ignore='id2word')
model2 = ldamodel.LdaModel.load(fname)
8 changes: 5 additions & 3 deletions gensim/utils.py
@@ -906,10 +906,12 @@ def pickle(obj, fname, protocol=2):

def unpickle(fname):
"""Load pickled object from `fname`"""
with smart_open(fname) as f:
with smart_open(fname, 'rb') as f:
# Because of loading from S3, load() can't be used directly (smart_open is missing readline)
return _pickle.loads(f.read())

if sys.version_info > (3, 0):
return _pickle.load(f, encoding='latin1')
Review discussion on the line above:

@piskvorky (Owner), Dec 8, 2016:
Absolutely not! What is this latin1? The content is (and should be read as) binary.

@anmolgulati (Contributor, Author), Dec 15, 2016:
This works as a fix for loading objects in Python 3 that were pickled in Python 2, which otherwise raises an exception. Python 3 attempts to convert the pickled Python 2 object into a str object when we need it to be bytes, and fails. I used the latin1 encoding as a workaround for that (asked on Stack Overflow).

Contributor:
Can you add a code comment to explain this?

@piskvorky (Owner), Dec 27, 2016:
Yes, this hack needs to be marked and explained thoroughly in a comment. I'm not familiar with such py2/py3 pickling workarounds, but isn't there a cleaner way to achieve the same effect? This sticks out like a sore thumb. @tmylk @anmol01gulati

@anmolgulati (Contributor, Author), Dec 27, 2016:
@piskvorky I actually searched quite a lot and tried various things on my system, and this is the only way (a hack, really) I found that works. By the way, we may not want this functionality forever and could drop the backward compatibility once the majority of users have moved to Python 3 (which is not the case right now). I'll open a new PR to add a comment in the code.

Comment (author not shown):
I am coding something entirely different and this solution is the only thing that worked for loading Python 2 pickles in Python 3. The creators claim that pickle is backwards compatible, but apparently only if I pass latin1; any other way just breaks and burns.
else:
return _pickle.loads(f.read())

def revdict(d):
"""
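For reference, a self-contained sketch (not part of this PR) of the Python 2 / Python 3 pickle issue discussed in the review thread above; the function name and file path are illustrative.

import pickle
import sys


def load_possibly_py2_pickle(fname):
    """Minimal sketch of the workaround used in gensim.utils.unpickle.

    Python 3's unpickler decodes Python 2 str objects as ASCII by default and
    fails on arbitrary bytes (e.g. inside pickled numpy arrays). 'latin1' maps
    bytes 0-255 one-to-one, so that decoding step cannot fail, though genuine
    utf-8 text pickled under Python 2 may come back garbled.
    """
    with open(fname, 'rb') as f:
        if sys.version_info[0] >= 3:
            return pickle.load(f, encoding='latin1')
        return pickle.load(f)


# illustrative usage; the path is hypothetical
# obj = load_possibly_py2_pickle('/tmp/object_pickled_under_python2.pkl')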