-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add smart information retrieval system for TfidfModel. Fix #1785
#1791
Changes from 3 commits
5e1830b
6cef4b1
e8a3f16
648bf21
a6f1afb
d091138
951c549
40c0558
b35344c
0917e75
bef79cc
d3d431c
0e6f21e
7ee7560
f2251a4
b2def84
5b2d37a
ac4b154
0bacc08
51e0eb9
3039732
99e6a6f
7d63d9c
e5140f8
4afbadd
d2fe235
52ee3c4
48e84f7
6d2f47b
607ba61
d0878a4
b544c9c
c4e3656
98ffde5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,24 +11,27 @@ | |
from gensim import interfaces, matutils, utils | ||
from six import iteritems | ||
|
||
import numpy as np | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def df2idf(docfreq, totaldocs, log_base=2.0, add=0.0): | ||
""" | ||
Compute default inverse-document-frequency for a term with document frequency `doc_freq`:: | ||
def resolve_weights(smartirs): | ||
if not isinstance(smartirs, str) or len(smartirs) != 3: | ||
raise ValueError('Expected a string of length 3 except got ' + smartirs) | ||
|
||
idf = add + log(totaldocs / doc_freq) | ||
""" | ||
return add + math.log(1.0 * totaldocs / docfreq, log_base) | ||
w_tf, w_df, w_n = smartirs | ||
|
||
if w_tf not in 'nlabL': | ||
raise ValueError('Expected term frequency weight to be one of nlabL, except got ' + w_tf) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Better to use |
||
|
||
if w_df not in 'ntp': | ||
raise ValueError('Expected inverse document frequency weight to be one of ntp, except got ' + w_df) | ||
|
||
def precompute_idfs(wglobal, dfs, total_docs): | ||
"""Precompute the inverse document frequency mapping for all terms.""" | ||
# not strictly necessary and could be computed on the fly in TfidfModel__getitem__. | ||
# this method is here just to speed things up a little. | ||
return {termid: wglobal(df, total_docs) for termid, df in iteritems(dfs)} | ||
if w_n not in 'ncb': | ||
raise ValueError('Expected normalization weight to be one of ncb, except got ' + w_n) | ||
|
||
return w_tf, w_df, w_n | ||
|
||
|
||
class TfidfModel(interfaces.TransformationABC): | ||
|
@@ -49,8 +52,8 @@ class TfidfModel(interfaces.TransformationABC): | |
Model persistency is achieved via its load/save methods. | ||
""" | ||
|
||
def __init__(self, corpus=None, id2word=None, dictionary=None, | ||
wlocal=utils.identity, wglobal=df2idf, normalize=True): | ||
def __init__(self, corpus=None, id2word=None, dictionary=None, smartirs="ntc", | ||
wlocal=None, wglobal=None, normalize=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Better to support backward compatibility — why did you change the default values? A good solution: support the default behavior by default, and There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
""" | ||
Compute tf-idf by multiplying a local component (term frequency) with a | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you convert all docstrings in this file to numpy-style, according to my previous comment #1780 (comment)? |
||
global component (inverse document frequency), and normalizing | ||
|
@@ -78,10 +81,41 @@ def __init__(self, corpus=None, id2word=None, dictionary=None, | |
and it will be used to directly construct the inverse document frequency | ||
mapping (then `corpus`, if specified, is ignored). | ||
""" | ||
self.normalize = normalize | ||
self.id2word = id2word | ||
self.wlocal, self.wglobal = wlocal, wglobal | ||
self.wlocal, self.wglobal, self.normalize = wlocal, wglobal, normalize | ||
self.num_docs, self.num_nnz, self.idfs = None, None, None | ||
n_tf, n_df, n_n = smartirs | ||
self.smartirs = smartirs | ||
|
||
if self.wlocal is None: | ||
if n_tf == "n": | ||
self.wlocal = lambda tf, mean=None, _max=None: tf | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Better to use a simple function definition (instead of a lambda) to avoid pickle problems (here and everywhere). |
||
elif n_tf == "l": | ||
self.wlocal = lambda tf, mean=None, _max=None: 1 + math.log(tf) | ||
elif n_tf == "a": | ||
self.wlocal = lambda tf, mean=None, _max=None: 0.5 + (0.5 * tf / _max) | ||
elif n_tf == "b": | ||
self.wlocal = lambda tf, mean=None, _max=None: 1 if tf > 0 else 0 | ||
elif n_tf == "L": | ||
self.wlocal = lambda tf, mean=None, _max=None: (1 + math.log(tf)) / (1 + math.log(mean)) | ||
|
||
if self.wglobal is None: | ||
if n_df == "n": | ||
self.wglobal = utils.identity | ||
elif n_df == "t": | ||
self.wglobal = lambda docfreq, totaldocs: math.log(1.0 * totaldocs / docfreq, 10) | ||
elif n_tf == "p": | ||
self.wglobal = lambda docfreq, totaldocs: math.log((float(totaldocs) - docfreq) / docfreq) | ||
|
||
if self.normalize is None or isinstance(self.normalize, bool): | ||
if n_n == "n" or self.normalize is False: | ||
self.normalize = lambda x: x | ||
elif n_n == "c" or self.normalize is True: | ||
self.normalize = matutils.unitvec | ||
# TODO write byte-size normalisation | ||
# elif n_n == "b": | ||
# self.normalize = matutils.unitvec | ||
|
||
if dictionary is not None: | ||
# user supplied a Dictionary object, which already contains all the | ||
# statistics we need to construct the IDF mapping. we can skip the | ||
|
@@ -92,7 +126,7 @@ def __init__(self, corpus=None, id2word=None, dictionary=None, | |
) | ||
self.num_docs, self.num_nnz = dictionary.num_docs, dictionary.num_nnz | ||
self.dfs = dictionary.dfs.copy() | ||
self.idfs = precompute_idfs(self.wglobal, self.dfs, self.num_docs) | ||
|
||
if id2word is None: | ||
self.id2word = dictionary | ||
elif corpus is not None: | ||
|
@@ -113,6 +147,7 @@ def initialize(self, corpus): | |
logger.info("collecting document frequencies") | ||
dfs = {} | ||
numnnz, docno = 0, -1 | ||
|
||
for docno, bow in enumerate(corpus): | ||
if docno % 10000 == 0: | ||
logger.info("PROGRESS: processing document #%i", docno) | ||
|
@@ -127,11 +162,6 @@ def initialize(self, corpus): | |
|
||
# and finally compute the idf weights | ||
n_features = max(dfs) if dfs else 0 | ||
logger.info( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why did you remove this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This showed the progress of the |
||
"calculating IDF weights for %i documents and %i features (%i matrix non-zeros)", | ||
self.num_docs, n_features, self.num_nnz | ||
) | ||
self.idfs = precompute_idfs(self.wglobal, self.dfs, self.num_docs) | ||
|
||
def __getitem__(self, bow, eps=1e-12): | ||
""" | ||
|
@@ -144,17 +174,16 @@ def __getitem__(self, bow, eps=1e-12): | |
|
||
# unknown (new) terms will be given zero weight (NOT infinity/huge weight, | ||
# as strict application of the IDF formula would dictate) | ||
|
||
vector = [ | ||
(termid, self.wlocal(tf) * self.idfs.get(termid)) | ||
for termid, tf in bow if self.idfs.get(termid, 0.0) != 0.0 | ||
(termid, self.wlocal(tf, mean=np.mean(np.array(bow), axis=1), _max=np.max(bow, axis=1)) * self.wglobal(self.dfs[termid], self.num_docs)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This looks wasteful (creating arrays, only to throw them away). What are the performance implications of these changes? Do you have a benchmark before/after? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Changed the approach. |
||
for termid, tf in bow if self.wglobal(self.dfs[termid], self.num_docs) != 0.0 | ||
] | ||
|
||
# and finally, normalize the vector either to unit length, or use a | ||
# user-defined normalization function | ||
if self.normalize is True: | ||
vector = matutils.unitvec(vector) | ||
elif self.normalize: | ||
vector = self.normalize(vector) | ||
|
||
vector = self.normalize(vector) | ||
|
||
# make sure there are no explicit zeroes in the vector (must be sparse) | ||
vector = [(termid, weight) for termid, weight in vector if abs(weight) > eps] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -498,7 +498,6 @@ def testPersistence(self): | |
original_matrix = self.model.transform(original_bow) | ||
passed = numpy.allclose(loaded_matrix, original_matrix, atol=1e-1) | ||
self.assertTrue(passed) | ||
|
||
def testModelNotFitted(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Need to add more tests (for the new functionality). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have that in my checklist, but before that I need to pass the already-present tests. |
||
lsi_wrapper = LsiTransformer(id2word=dictionary, num_topics=2) | ||
texts_new = ['graph', 'eulerian'] | ||
|
@@ -973,13 +972,13 @@ def testTransform(self): | |
|
||
def testSetGetParams(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't forget to add more tests (also, check situations when you pass |
||
# updating only one param | ||
self.model.set_params(normalize=False) | ||
self.model.set_params(smartirs='nnn') | ||
model_params = self.model.get_params() | ||
self.assertEqual(model_params["normalize"], False) | ||
self.assertEqual(model_params["smartirs"], 'nnn') | ||
|
||
# verify that the attributes values are also changed for `gensim_model` after fitting | ||
self.model.fit(self.corpus) | ||
self.assertEqual(getattr(self.model.gensim_model, 'normalize'), False) | ||
self.assertEqual(getattr(self.model.gensim_model, 'smartirs'), 'nnn') | ||
|
||
def testPipeline(self): | ||
with open(datapath('mini_newsgroup'), 'rb') as f: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docstrings needed too (for all stuff here)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that
Checks for validity of smartirs parameter.
is enough. Do you have anything else in mind as well? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@markroxor need to add "Parameters" (type, description), "Raises" (type, reason), "Returns" (type, description)