diff --git a/gensim/models/wrappers/ldamallet.py b/gensim/models/wrappers/ldamallet.py
index e1bf2a85e8..eee1a542a6 100644
--- a/gensim/models/wrappers/ldamallet.py
+++ b/gensim/models/wrappers/ldamallet.py
@@ -78,7 +78,7 @@ class LdaMallet(utils.SaveLoad, basemodel.BaseTopicModel):
 
     """
     def __init__(self, mallet_path, corpus=None, num_topics=100, alpha=50, id2word=None, workers=4, prefix=None,
-                 optimize_interval=0, iterations=1000, topic_threshold=0.0):
+                 optimize_interval=0, iterations=1000, topic_threshold=0.0, random_seed=0):
         """
 
         Parameters
@@ -104,6 +104,8 @@ def __init__(self, mallet_path, corpus=None, num_topics=100, alpha=50, id2word=N
             Number of training iterations.
         topic_threshold : float, optional
            Threshold of the probability above which we consider a topic.
+        random_seed: int, optional
+            Random seed to ensure consistent results, if 0 - use system clock.
 
         """
         self.mallet_path = mallet_path
@@ -126,6 +128,7 @@ def __init__(self, mallet_path, corpus=None, num_topics=100, alpha=50, id2word=N
         self.workers = workers
         self.optimize_interval = optimize_interval
         self.iterations = iterations
+        self.random_seed = random_seed
         if corpus is not None:
             self.train(corpus)
 
@@ -271,11 +274,12 @@ def train(self, corpus):
         self.convert_input(corpus, infer=False)
         cmd = self.mallet_path + ' train-topics --input %s --num-topics %s --alpha %s --optimize-interval %s '\
             '--num-threads %s --output-state %s --output-doc-topics %s --output-topic-keys %s '\
-            '--num-iterations %s --inferencer-filename %s --doc-topics-threshold %s'
+            '--num-iterations %s --inferencer-filename %s --doc-topics-threshold %s --random-seed %s'
+
         cmd = cmd % (
             self.fcorpusmallet(), self.num_topics, self.alpha, self.optimize_interval,
             self.workers, self.fstate(), self.fdoctopics(), self.ftopickeys(), self.iterations,
-            self.finferencer(), self.topic_threshold
+            self.finferencer(), self.topic_threshold, str(self.random_seed)
         )
         # NOTE "--keep-sequence-bigrams" / "--use-ngrams true" poorer results + runs out of memory
         logger.info("training MALLET LDA with %s", cmd)
@@ -312,10 +316,10 @@ def __getitem__(self, bow, iterations=100):
         self.convert_input(bow, infer=True)
         cmd = \
             self.mallet_path + ' infer-topics --input %s --inferencer %s ' \
-            '--output-doc-topics %s --num-iterations %s --doc-topics-threshold %s'
+            '--output-doc-topics %s --num-iterations %s --doc-topics-threshold %s --random-seed %s'
         cmd = cmd % (
             self.fcorpusmallet() + '.infer', self.finferencer(),
-            self.fdoctopics() + '.infer', iterations, self.topic_threshold
+            self.fdoctopics() + '.infer', iterations, self.topic_threshold, str(self.random_seed)
         )
         logger.info("inferring topics with MALLET LDA '%s'", cmd)
         check_output(args=cmd, shell=True)
@@ -565,6 +569,17 @@ def read_doctopics(self, fname, eps=1e-6, renorm=True):
                         doc = [(id_, float(weight) / total_weight) for id_, weight in doc]
                 yield doc
 
+    @classmethod
+    def load(cls, *args, **kwargs):
+        """Load a previously saved LdaMallet class. Handles backwards compatibility from
+        older LdaMallet versions which did not use random_seed parameter.
+        """
+        model = super(LdaMallet, cls).load(*args, **kwargs)
+        if not hasattr(model, 'random_seed'):
+            model.random_seed = 0
+
+        return model
+
 
 def malletmodel2ldamodel(mallet_model, gamma_threshold=0.001, iterations=50):
     """Convert :class:`~gensim.models.wrappers.ldamallet.LdaMallet` to :class:`~gensim.models.ldamodel.LdaModel`.
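For context, a minimal usage sketch of the new parameter follows (not part of the patch; the MALLET path and the toy corpus are placeholders). Passing the same non-zero random_seed to two otherwise identical models makes training reproducible, the default of 0 keeps the old behaviour of seeding MALLET from the system clock, and the patched LdaMallet.load() backfills random_seed = 0 for models saved by older gensim versions.

# A minimal sketch, not part of the patch: mallet_path and the toy corpus are placeholders.
from gensim.corpora import Dictionary
from gensim.models.wrappers import LdaMallet

texts = [['human', 'interface', 'computer'], ['survey', 'user', 'computer', 'system']]
dictionary = Dictionary(texts)
corpus = [dictionary.doc2bow(text) for text in texts]

mallet_path = '/path/to/mallet-2.0.8/bin/mallet'  # placeholder path to the mallet binary

# Same seed -> identical topic assignments across runs; random_seed=0 (the default)
# preserves the previous behaviour of seeding from the system clock.
lda_a = LdaMallet(mallet_path, corpus=corpus, num_topics=2, id2word=dictionary, random_seed=42)
lda_b = LdaMallet(mallet_path, corpus=corpus, num_topics=2, id2word=dictionary, random_seed=42)

# Models pickled by older gensim versions still load; the new classmethod
# fills in random_seed = 0 when the attribute is missing.
# old_model = LdaMallet.load('/path/to/old_lda_mallet_model')  # placeholder path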
diff --git a/gensim/test/test_ldamallet_wrapper.py b/gensim/test/test_ldamallet_wrapper.py
index f8f432bc96..2ad1ccb9c9 100644
--- a/gensim/test/test_ldamallet_wrapper.py
+++ b/gensim/test/test_ldamallet_wrapper.py
@@ -175,6 +175,40 @@ def testLargeMmapCompressed(self):
         # test loading the large model arrays with mmap
         self.assertRaises(IOError, ldamodel.LdaModel.load, fname, mmap='r')
 
+    def test_random_seed(self):
+        if not self.mallet_path:
+            return
+
+        # test that 2 models created with the same random_seed are equal in their topics treatment
+        SEED = 10
+        NUM_TOPICS = 10
+        ITER = 500
+
+        tm1 = ldamallet.LdaMallet(
+            self.mallet_path,
+            corpus=corpus,
+            num_topics=NUM_TOPICS,
+            id2word=dictionary,
+            random_seed=SEED,
+            iterations=ITER,
+        )
+
+        tm2 = ldamallet.LdaMallet(
+            self.mallet_path,
+            corpus=corpus,
+            num_topics=NUM_TOPICS,
+            id2word=dictionary,
+            random_seed=SEED,
+            iterations=ITER,
+        )
+        self.assertTrue(np.allclose(tm1.word_topics, tm2.word_topics))
+
+        for doc in corpus:
+            tm1_vector = matutils.sparse2full(tm1[doc], NUM_TOPICS)
+            tm2_vector = matutils.sparse2full(tm2[doc], NUM_TOPICS)
+
+            self.assertTrue(np.allclose(tm1_vector, tm2_vector))
+
 
 if __name__ == '__main__':
     logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
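The new test only runs when a MALLET installation is found (otherwise it returns early). A sketch of running just this test locally is below; it assumes the test class is named TestLdaMallet and that the suite resolves the MALLET binary from the MALLET_HOME environment variable, both of which are assumptions rather than something stated in this patch.

# A sketch of a local invocation, assuming the gensim test suite resolves MALLET
# from the MALLET_HOME environment variable and the test class is TestLdaMallet.
import os
import unittest

os.environ.setdefault('MALLET_HOME', '/path/to/mallet-2.0.8')  # placeholder install path

from gensim.test import test_ldamallet_wrapper

suite = unittest.defaultTestLoader.loadTestsFromName(
    'TestLdaMallet.test_random_seed', module=test_ldamallet_wrapper
)
unittest.TextTestRunner(verbosity=2).run(suite)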