-
-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix return dtype for matutils.unitvec
according to input dtype. Fix #1722
#1992
Changes from 12 commits
82b8d17
834b042
cf463b2
5656835
406ed66
e71afcd
fe36408
50d011c
f1a40ac
ead451f
ae03291
cab90a8
141833d
218fe42
5fd1004
80628c0
768226b
438f763
f73076a
cd50529
11b0dde
2e86529
30d6284
d9cfb0d
2e8eca1
6beec75
365a722
9327f68
a42cf38
1078bdd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -668,7 +668,7 @@ def ret_log_normalize_vec(vec, axis=1): | |
|
||
def unitvec(vec, norm='l2', return_norm=False): | ||
"""Scale a vector to unit length. | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no leading spaces (here and everywhere) |
||
Parameters | ||
---------- | ||
vec : {numpy.ndarray, scipy.sparse, list of (int, float)} | ||
|
@@ -677,49 +677,53 @@ def unitvec(vec, norm='l2', return_norm=False): | |
Normalization that will be used. | ||
return_norm : bool, optional | ||
If True - returns the length of vector `vec`. | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please return empty lines There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||
Returns | ||
------- | ||
numpy.ndarray, scipy.sparse, list of (int, float)} | ||
Normalized vector in same format as `vec`. | ||
float | ||
Length of `vec` before normalization. | ||
|
||
Notes | ||
----- | ||
Zero-vector will be unchanged. | ||
|
||
""" | ||
if norm not in ('l1', 'l2'): | ||
raise ValueError("'%s' is not a supported norm. Currently supported norms are 'l1' and 'l2'." % norm) | ||
|
||
if scipy.sparse.issparse(vec): | ||
vec = vec.tocsr() | ||
if norm == 'l1': | ||
veclen = np.sum(np.abs(vec.data)) | ||
if norm == 'l2': | ||
veclen = np.sqrt(np.sum(vec.data ** 2)) | ||
if veclen > 0.0: | ||
if return_norm: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert existing code please (you shouldn't remove it) |
||
return vec / veclen, veclen | ||
else: | ||
if np.issubdtype(vec.dtype, np.int): | ||
vec = vec.astype(np.float) | ||
return vec / veclen | ||
else: | ||
vec /= veclen | ||
return vec.astype(vec.dtype) | ||
else: | ||
if return_norm: | ||
return vec, 1. | ||
else: | ||
return vec | ||
|
||
if isinstance(vec, np.ndarray): | ||
vec = np.asarray(vec, dtype=float) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that this line is needed especially for |
||
vec = np.asarray(vec, dtype=vec.dtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe this have no sense (because later you'll cast it again) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree - this pretty much seems like a no-op, effectively. Does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah yes, I forgot about this - I will remove it on the next commit. |
||
if norm == 'l1': | ||
veclen = np.sum(np.abs(vec)) | ||
if norm == 'l2': | ||
veclen = blas_nrm2(vec) | ||
if veclen > 0.0: | ||
if return_norm: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This entire construct (and a similar construct above) seems to have some unnecessary redundant code. We could simplify to something like - if veclen > 0.0:
if np.issubdtype(vec.dtype, np.int):
vec = vec.astype(np.float)
if return_norm:
return blas_scal(1.0 / veclen, vec).astype(vec.dtype), veclen
else:
return blas_scal(1.0 / veclen, vec).astype(vec.dtype) |
||
return blas_scal(1.0 / veclen, vec), veclen | ||
if np.issubdtype(vec.dtype, np.int): | ||
vec = vec.astype(np.float) | ||
return blas_scal(1.0 / veclen, vec).astype(vec.dtype) | ||
else: | ||
return blas_scal(1.0 / veclen, vec) | ||
return blas_scal(1.0 / veclen, vec).astype(vec.dtype) | ||
else: | ||
if return_norm: | ||
return vec, 1 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -140,6 +140,46 @@ def testDirichletExpectation(self): | |
msg = "dirichlet_expectation_2d failed for dtype={}".format(dtype) | ||
self.assertTrue(np.allclose(known_good, test_values), msg) | ||
|
||
class UnitvecTestCase(unittest.TestCase): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for reorganizing the tests! Looks much better now IMO |
||
# test unitvec | ||
def manual_unitvec(self, vec): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Definitely should be simplified - why use vec = vec.astype(np.float)
if sparse.issparse(vec):
vec_sum_of_squares = vec.multiply(vec)
unit = 1. / np.sqrt(vec_sum_of_squares.sum())
return vec.multiply(unit)
else:
sum_vec_squared = np.sum(vec ** 2)
vec /= np.sqrt(sum_vec_squared)
return vec |
||
self.vec = vec | ||
self.vec = self.vec.astype(np.float) | ||
if sparse.issparse(self.vec): | ||
vec_sum_of_squares = self.vec.multiply(self.vec) | ||
unit = 1. / np.sqrt(vec_sum_of_squares.sum()) | ||
return self.vec.multiply(unit) | ||
elif not sparse.issparse(self.vec): | ||
sum_vec_squared = np.sum(self.vec ** 2) | ||
self.vec /= np.sqrt(sum_vec_squared) | ||
return self.vec | ||
|
||
def test_inputs(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO we should split this test into multiple tests (one per combination of |
||
input_dtypes = [np.float32, np.float64, np.int32, np.int64, float, int] | ||
input_arrtypes = ['sparse', 'dense'] | ||
for dtype_ in input_dtypes: | ||
for arrtype in input_arrtypes: | ||
if arrtype == 'dense': | ||
if dtype_ == np.float32 or dtype_ == np.float64: | ||
input_vector = np.random.uniform(size=(5,)).astype(dtype_) | ||
unit_vector = unitvec_with_bug.unitvec(input_vector) | ||
man_unit_vector = self.manual_unitvec(input_vector) | ||
self.assertEqual(input_vector.dtype, unit_vector.dtype) | ||
self.assertTrue(np.allclose(unit_vector, man_unit_vector)) | ||
else: | ||
input_vector = np.random.randint(10, size=5).astype(dtype_) | ||
unit_vector = unitvec_with_bug.unitvec(input_vector) | ||
man_unit_vector = self.manual_unitvec(input_vector) | ||
self.assertTrue(np.allclose(unit_vector, man_unit_vector)) | ||
else: | ||
input_vector = sparse.csr_matrix(np.asarray([[1, 0, 0, 0, 0, 3, 0, 0], [0, 0, 4, 3, 0, 0, 0, 0]]).astype(dtype_)) | ||
unit_vector = unitvec_with_bug.unitvec(input_vector) | ||
man_unit_vector = self.manual_unitvec(input_vector) | ||
if dtype_ == np.float32 or dtype_ == np.float64: | ||
self.assertEqual(input_vector.dtype, unit_vector.dtype) | ||
self.assertTrue(np.allclose(unit_vector.data, man_unit_vector.data, atol=1e-3)) | ||
else: | ||
self.assertTrue(np.allclose(unit_vector.data, man_unit_vector.data, atol=1e-3)) | ||
|
||
if __name__ == '__main__': | ||
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please don't make unrelated changes, this empty line(s) is correct by docstring convention.