Added Support for Median #907
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅
Additional details and impacted files:
@@ Coverage Diff @@
## main #907 +/- ##
==========================================
+ Coverage 81.89% 81.90% +0.01%
==========================================
Files 182 182
Lines 47778 47799 +21
Branches 8597 8599 +2
==========================================
+ Hits 39126 39150 +24
+ Misses 6487 6485 -2
+ Partials 2165 2164 -1
Hi @ricardoV94, is there something left here?
Some comments/questions
pytensor/tensor/math.py
Outdated
elif isinstance(axis, int | np.integer):
    axis = [axis]
elif isinstance(axis, np.ndarray) and axis.ndim == 0:
    axis = [int(axis)]
else:
    axis = [int(a) for a in axis]
Can we use normalize_axis_tuple for these 3 cases like we do elsewhere?
Sure. It won't work in the axis=None case, right? Stuff like var and mean use this pattern too, so should we open an issue to replace those?
Yeah, we should use it everywhere, and no, I don't think it handles axis=None, but we can double-check.
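For reference, a quick check of normalize_axis_tuple's behavior (a sketch using NumPy's helper; the import path moved to a public location in NumPy 2.0, so both locations are tried here):

```python
try:  # public location since NumPy 2.0
    from numpy.lib.array_utils import normalize_axis_tuple
except ImportError:  # private location in older NumPy versions
    from numpy.core.numeric import normalize_axis_tuple

# A plain int, a negative int, and a tuple all normalize the same way:
print(normalize_axis_tuple(1, 3))        # (1,)
print(normalize_axis_tuple(-1, 3))       # (2,)
print(normalize_axis_tuple((0, -1), 3))  # (0, 2)

# axis=None is NOT handled -- it raises, so it still needs its own branch:
try:
    normalize_axis_tuple(None, 3)
except TypeError:
    print("axis=None needs special-casing")
```

So the three isinstance branches collapse into one call, with axis=None handled separately beforehand.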
pytensor/tensor/math.py
Outdated
indices1 = expand_dims(full_like(sorted_input.take(0, axis=axis), k - 1), axis)
indices2 = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
What is happening here? Can't quite follow why we need all this
This was a tricky part. The problem started when indices was used in take_along_axis. We need to ensure that the shape of the indices tensor (indices1) matches the shape required for broadcasting during the selection of elements with take_along_axis (i.e. the shape of sorted_input). The full_like call essentially builds an array filled with k or k - 1 (the central element indices) with the same shape as the tensor with that axis dropped (computed via sorted_input.take(0, axis=axis)).
I think this will be clearer with a concrete NumPy example (using axis=1 throughout):

import numpy as np

sorted_input = np.array([[1, 3, 5],
                         [2, 4, 6],
                         [7, 8, 9]])

k = sorted_input.shape[1] // 2  # middle index for axis 1; here k = 3 // 2 = 1

first_elements = sorted_input.take(0, axis=1)
# first_elements = [1, 2, 7]

full_tensor = np.full_like(first_elements, k - 1)  # k - 1 = 0
# full_tensor = [0, 0, 0]
indices1 = np.expand_dims(full_tensor, axis=1)
# indices1 = [[0],
#             [0],
#             [0]]
ans1 = np.take_along_axis(sorted_input, indices1, axis=1)
# ans1 = [[1],
#         [2],
#         [7]]

full_tensor = np.full_like(first_elements, k)  # k = 1
# full_tensor = [1, 1, 1]
indices2 = np.expand_dims(full_tensor, axis=1)
# indices2 = [[1],
#             [1],
#             [1]]
ans2 = np.take_along_axis(sorted_input, indices2, axis=1)
# ans2 = [[3],
#         [4],
#         [8]]

And hence the median is calculated from ans1 and ans2.
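To spell out that last step, here is a minimal NumPy sketch of the even/odd median logic (hedged: the PR operates on pytensor tensors symbolically, this just mirrors the idea with a hypothetical helper name):

```python
import numpy as np

def median_along_axis(x, axis):
    """Median via sort + central-element selection (illustrative only)."""
    x_sorted = np.sort(x, axis=axis)
    n = x.shape[axis]
    k = n // 2
    if n % 2 == 1:  # odd length: the single middle element
        return np.take(x_sorted, k, axis=axis)
    # even length: average the two central elements (indices k - 1 and k)
    lo = np.take(x_sorted, k - 1, axis=axis)
    hi = np.take(x_sorted, k, axis=axis)
    return (lo + hi) / 2

a = np.array([[1, 3, 5], [2, 4, 6], [7, 8, 9]])
print(median_along_axis(a, axis=1))  # [3 4 8], matching np.median(a, axis=1)
```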
Okay, so take_along_axis is not what we want, since we don't need advanced (array) indexing. indices is a scalar and always the same. We can do sorted_inputs[:, :, ..., :, k] to get the value we want, where the : are empty slices before the axis we want to index. I thought take_along_axis would do this, but apparently it's only for advanced (list-of-numbers) indexing.
For what basic vs advanced indexing means check out the numpy docs: https://numpy.org/doc/stable/user/basics.indexing.html
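A quick NumPy illustration of the distinction (a sketch; see the linked docs for the full rules):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)

# Basic indexing: scalars and slices return a *view*, no copy is made
col = x[:, 1]                       # the same scalar index for every row
print(np.shares_memory(col, x))     # True: a view into x

# Advanced indexing: integer-array indices return a *copy*
picked = x[np.array([0, 2]), np.array([1, 3])]  # elements (0,1) and (2,3)
print(np.shares_memory(picked, x))  # False: a new array

# With a constant index k, take_along_axis is overkill -- basic indexing
# gives the same result:
k = 1
via_take = np.take_along_axis(x, np.full((3, 1), k), axis=1).squeeze(1)
assert np.array_equal(x[..., k], via_take)
```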
pytensor/tensor/math.py
Outdated
indices1 = expand_dims(full_like(sorted_input.take(0, axis=axis), k - 1), axis)
indices2 = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
ans1 = take_along_axis(sorted_input, indices1, axis=axis)
We should check whether take_along_axis does basic indexing in this case; it shouldn't require advanced indexing. We just need sorted_input[..., k] and sorted_input[..., k + 1], right?
Can you elaborate on basic indexing and advanced indexing?
More comments
Hi @ricardoV94, is there something left here?
Co-authored-by: Ricardo Vieira <[email protected]>
k_values = x_sorted[..., k]
km1_values = x_sorted[..., k - 1]
I simplified the indexing; we can use basic indexing instead of take_along_axis.
Looks great! I did not know we could use simple indexing so conveniently.
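The simplified version boils down to this pattern: after sorting and putting the reduced axis last, both central elements come from plain `[..., k]` basic indexing, no take_along_axis needed (a NumPy sketch of the idea, shown here for an even-length axis):

```python
import numpy as np

x = np.array([[1, 3, 5, 7], [2, 4, 6, 8]])
x_sorted = np.sort(x, axis=-1)
k = x.shape[-1] // 2              # 2

k_values = x_sorted[..., k]       # [5 6]
km1_values = x_sorted[..., k - 1] # [3 4]

# even-length axis: the median averages the two central values
print((k_values + km1_values) / 2)  # [4. 5.]
assert np.array_equal((k_values + km1_values) / 2, np.median(x, axis=-1))
```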
# Put axis at the end and unravel them
x_raveled = x.transpose(*non_axis, *axis)
if len(axis) > 1:
    x_raveled = x_raveled.reshape((*non_axis_shape, -1))
Added a small optimization to avoid reshaping when not needed
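The transpose-then-ravel step (and the skip-the-reshape optimization for a single reduced axis) can be sketched in NumPy like so; the helper name is hypothetical:

```python
import numpy as np

def move_and_ravel(x, axis):
    """Move the reduced axes to the end, raveling them only when needed."""
    axis = tuple(a % x.ndim for a in axis)  # normalize negative axes
    non_axis = tuple(i for i in range(x.ndim) if i not in axis)
    non_axis_shape = tuple(x.shape[i] for i in non_axis)
    x_raveled = x.transpose(*non_axis, *axis)
    if len(axis) > 1:  # a single trailing axis needs no reshape
        x_raveled = x_raveled.reshape((*non_axis_shape, -1))
    return x_raveled

x = np.zeros((2, 3, 4, 5))
print(move_and_ravel(x, (1, 3)).shape)  # (2, 4, 15): axes 1 and 3 raveled
print(move_and_ravel(x, (2,)).shape)    # (2, 3, 5, 4): transpose only
```

Skipping the reshape when only one axis is reduced avoids an unnecessary copy/graph node, which is the small optimization mentioned above.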
Description
Added support and a test for median without creating a separate op for median, as discussed in #53. Hence, our implementation of median is standalone and not dependent on numpy.

Related Issue
Checklist
Type of change