From d00b6c1da5788b631e71d00e494bb3a0b1e7bf51 Mon Sep 17 00:00:00 2001 From: Oliver Eberle Date: Mon, 27 Mar 2017 23:18:15 +0200 Subject: [PATCH] Fix #1230. Fix word2vec reset_from bug in v1.0.1 and added unittest (#1234) * Update CHANGELOG.txt * Update CHANGELOG.txt * Release version typo fix * Typo in version * Upgraded to match word2vec class structure * Added unittest for reset_from * Fixed typo * Positive reset_from() unittest * Change unittest to check if attributes are shared * Formatting fixed --- gensim/models/word2vec.py | 4 ++-- gensim/test/test_word2vec.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py index 000eee6976..c29d61126c 100644 --- a/gensim/models/word2vec.py +++ b/gensim/models/word2vec.py @@ -738,8 +738,8 @@ def reset_from(self, other_model): Borrow shareable pre-built structures (like vocab) from the other_model. Useful if testing multiple models in parallel on the same corpus. """ - self.wv.vocab = other_model.vocab - self.wv.index2word = other_model.index2word + self.wv.vocab = other_model.wv.vocab + self.wv.index2word = other_model.wv.index2word self.cum_table = other_model.cum_table self.corpus_count = other_model.corpus_count self.reset_weights() diff --git a/gensim/test/test_word2vec.py b/gensim/test/test_word2vec.py index 8c15b9d9a5..e37968218c 100644 --- a/gensim/test/test_word2vec.py +++ b/gensim/test/test_word2vec.py @@ -678,6 +678,16 @@ def test_sentences_should_not_be_a_generator(self): def testLoadOnClassError(self): """Test if exception is raised when loading word2vec model on instance""" self.assertRaises(AttributeError, load_on_instance) + + def test_reset_from(self): + """Test if reset_from() uses pre-built structures from other model""" + model = word2vec.Word2Vec(sentences, min_count=1) + other_model = word2vec.Word2Vec(new_sentences, min_count=1) + other_vocab = other_model.wv.vocab + model.reset_from(other_model) + self.assertEqual(model.wv.vocab, other_vocab) + + #endclass TestWord2VecModel class TestWMD(unittest.TestCase):