-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix return dtype for matutils.unitvec
according to input dtype. Fix #1722
#1992
Changes from 5 commits
82b8d17
834b042
cf463b2
5656835
406ed66
e71afcd
fe36408
50d011c
f1a40ac
ead451f
ae03291
cab90a8
141833d
218fe42
5fd1004
80628c0
768226b
438f763
f73076a
cd50529
11b0dde
2e86529
30d6284
d9cfb0d
2e8eca1
6beec75
365a722
9327f68
a42cf38
1078bdd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -667,45 +667,51 @@ def ret_log_normalize_vec(vec, axis=1): | |
|
||
def unitvec(vec, norm='l2'): | ||
"""Scale a vector to unit length. | ||
|
||
Parameters | ||
---------- | ||
vec : {numpy.ndarray, scipy.sparse, list of (int, float)} | ||
Input vector in any format | ||
norm : {'l1', 'l2'}, optional | ||
Normalization that will be used. | ||
|
||
Returns | ||
------- | ||
{numpy.ndarray, scipy.sparse, list of (int, float)} | ||
Normalized vector in same format as `vec`. | ||
|
||
Notes | ||
----- | ||
Zero-vector will be unchanged. | ||
|
||
""" | ||
if norm not in ('l1', 'l2'): | ||
raise ValueError("'%s' is not a supported norm. Currently supported norms are 'l1' and 'l2'." % norm) | ||
|
||
if scipy.sparse.issparse(vec): | ||
vec = vec.tocsr() | ||
if norm == 'l1': | ||
veclen = np.sum(np.abs(vec.data)) | ||
if norm == 'l2': | ||
veclen = np.sqrt(np.sum(vec.data ** 2)) | ||
if veclen > 0.0: | ||
return vec / veclen | ||
if np.issubdtype(vec.dtype, np.int) == True: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to |
||
vec = vec.astype(np.float) | ||
return vec / veclen | ||
else: | ||
vec /= veclen | ||
return vec.astype(vec.dtype) | ||
else: | ||
return vec | ||
|
||
if isinstance(vec, np.ndarray): | ||
vec = np.asarray(vec, dtype=float) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that this line is needed especially for |
||
vec = np.asarray(vec, dtype=vec.dtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe this have no sense (because later you'll cast it again) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree - this pretty much seems like a no-op, effectively. Does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah yes, I forgot about this - I will remove it on the next commit. |
||
if norm == 'l1': | ||
veclen = np.sum(np.abs(vec)) | ||
if norm == 'l2': | ||
veclen = blas_nrm2(vec) | ||
if veclen > 0.0: | ||
return blas_scal(1.0 / veclen, vec) | ||
if np.issubdtype(vec.dtype, np.int) == True: | ||
vec = vec.astype(np.float) | ||
return blas_scal(1.0 / veclen, vec).astype(vec.dtype) | ||
else: | ||
return blas_scal(1.0 / veclen, vec).astype(vec.dtype) | ||
else: | ||
return vec | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import numpy as np | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. https://github.com/RaRe-Technologies/gensim/blob/develop/gensim/test/test_matutils.py is more suitable place for this test There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Created pull request there |
||
from scipy import sparse | ||
import unittest | ||
import matutils | ||
|
||
class UnitvecTestCase(unittest.TestCase): | ||
|
||
def manual_unitvec(self, vec): | ||
self.vec = vec | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please use 4 spaces for indentation |
||
if sparse.issparse(self.vec): | ||
vec_sum_of_squares = self.vec.multiply(self.vec) | ||
unit = 1. / np.sqrt(vec_sum_of_squares.sum()) | ||
return self.vec.multiply(unit) | ||
elif not sparse.issparse(self.vec): | ||
sum_vec_squared = np.sum(self.vec ** 2) | ||
self.vec /= np.sqrt(sum_vec_squared) | ||
return self.vec | ||
|
||
def test_unitvec(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's about different vectors (sparse) + different types (more floats + int too)? |
||
input_vector = np.random.uniform(size=(5,)).astype(np.float32) | ||
unit_vector = matutils.unitvec(input_vector) | ||
self.assertEqual(input_vector.dtype, unit_vector.dtype) | ||
self.assertTrue(np.allclose(unit_vector, self.manual_unitvec(input_vector))) | ||
|
||
if __name__ == '__main__': | ||
|
||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please return empty lines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed