Skip to content

Commit

Permalink
Add logic for median from scratch
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Jul 9, 2024
1 parent 56c30e0 commit 1c11dbe
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
36 changes: 36 additions & 0 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@
concatenate,
constant,
expand_dims,
extract_constant,
full_like,
stack,
switch,
take_along_axis,
)
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.elemwise import (
Expand Down Expand Up @@ -1571,6 +1574,38 @@ def std(input, axis=None, ddof=0, keepdims=False, corrected=False):
return ret


def median(input, axis=None):
"""
Computes the median along the given axis(es) of a tensor `input`.
Parameters
----------
axis: None or int or (list of int) (see `Sum`)
Compute the median along this axis of the tensor.
None means all axes (like numpy).
Notes
-----
This function uses the numpy implementation of median.
"""

input = as_tensor_variable(input)
sorted_input = input.sort(axis=axis)
shape = input.shape[axis]
k = extract_constant(shape) // 2
if extract_constant(shape % 2) == 0:
indices1 = expand_dims(full_like(sorted_input.take(0, axis=axis), k - 1), axis)
indices2 = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
ans1 = take_along_axis(sorted_input, indices1, axis=axis)
ans2 = take_along_axis(sorted_input, indices2, axis=axis)
median_val = (ans1 + ans2) / 2.0
else:
indices = expand_dims(full_like(sorted_input.take(0, axis=axis), k), axis)
median_val = take_along_axis(sorted_input, indices, axis=axis)
median_val.name = "median"
return median_val.squeeze(axis=axis)


@scalar_elemwise(symbolname="scalar_maximum")
def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor"""
Expand Down Expand Up @@ -3006,6 +3041,7 @@ def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
"sum",
"prod",
"mean",
"median",
"var",
"std",
"std",
Expand Down
31 changes: 31 additions & 0 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
max_and_argmax,
maximum,
mean,
median,
min,
minimum,
mod,
Expand Down Expand Up @@ -3731,3 +3732,33 @@ def test_nan_to_num(nan, posinf, neginf):
out,
np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
)


@pytest.mark.parametrize(
"data, axis",
[
# 1D array
([1, 7, 3, 6, 5, 2, 4], 0),
# 2D array
([[6, 2], [4, 3], [1, 5]], 0),
([[6, 2], [4, 3], [1, 5]], 1),
# 3D array
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 0),
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 1),
([[[6, 2, 3], [1, 5, 8], [4, 7, 9]], [[5, 3, 4], [8, 6, 2], [7, 1, 9]]], 2),
# 4D array
(
[
[[[3, 1], [4, 3]], [[0, 5], [6, 2]], [[7, 8], [9, 4]]],
[[[10, 11], [12, 13]], [[14, 15], [16, 17]], [[18, 19], [20, 21]]],
],
3,
),
],
)
def test_median(data, axis):
x = tensor(shape=np.array(data).shape)
f = function([x], median(x, axis=axis))
result = f(data)
expected = np.median(data, axis=axis)
assert np.allclose(result, expected)

0 comments on commit 1c11dbe

Please sign in to comment.