def median(input, axis=None):
    """
    Compute the median of a tensor `input` along the given axis(es).

    Parameters
    ----------
    input: tensor-like
        Value converted with `as_tensor_variable`.
    axis: None or int or (list of int)
        Axis or axes to reduce over. Negative values count from the last
        axis. None (the default) reduces over all axes (like numpy).

    Returns
    -------
    TensorVariable
        The median with the reduced axes removed. When the number of
        reduced elements is even, the two middle values are averaged.

    Raises
    ------
    ValueError
        If an axis is out of bounds for `input` or is given more than once.

    Notes
    -----
    The graph sorts the reduced elements and indexes the middle one(s).
    When the number of reduced elements is not known while the graph is
    built, both the odd and the even candidate are computed and the right
    one is selected at run time with `switch`; the result then always has
    the dtype of the averaged (even) candidate.
    """
    input = as_tensor_variable(input)
    ndim = input.type.ndim

    if axis is None:
        # Reduce over everything. Flattening also turns a 0-d input into a
        # length-1 vector, so the indexing below stays valid.
        collapsed = input.flatten()
    else:
        requested = axis if isinstance(axis, (list, tuple)) else [axis]
        axes = []
        for ax in requested:
            ax = int(ax)
            if not -ndim <= ax < ndim:
                raise ValueError(
                    f"axis {ax} is out of bounds for tensor of dimension {ndim}"
                )
            ax %= ndim
            if ax in axes:
                raise ValueError("repeated axis")
            axes.append(ax)

        kept = [i for i in range(ndim) if i not in axes]
        # Move the reduced axes to the end so they can be handled as one
        # trailing axis regardless of how many were requested.
        collapsed = input.dimshuffle(*kept, *axes)
        if len(axes) != 1:
            kept_shape = [input.shape[i] for i in kept]
            collapsed = collapsed.reshape((*kept_shape, -1))

    ordered = collapsed.sort(axis=-1)
    static_size = collapsed.type.shape[-1]

    if static_size is not None:
        # Size known at graph-construction time: pick the branch in Python.
        mid = static_size // 2
        if static_size % 2:
            result = ordered[..., mid]
        else:
            result = (ordered[..., mid - 1] + ordered[..., mid]) / 2.0
    else:
        # Size only known at run time. A Python-level parity test on a
        # symbolic value would silently pick the wrong branch, so the
        # choice has to be part of the graph.
        size = collapsed.shape[-1]
        mid = size // 2
        upper = ordered[..., mid]
        # For size == 1 the index is -1, which wraps to the only element;
        # that candidate is discarded by the switch anyway.
        lower = ordered[..., mid - 1]
        result = switch(size % 2, upper, (lower + upper) / 2.0)

    result.name = "median"
    return result
# Shared fixtures for `test_median`; each is reduced along several axes.
_MEDIAN_VECTOR = [1, 7, 3, 6, 5, 2, 4]
_MEDIAN_MATRIX = [[6, 2], [4, 3], [1, 5]]
_MEDIAN_CUBE = [
    [[6, 2, 3], [1, 5, 8], [4, 7, 9]],
    [[5, 3, 4], [8, 6, 2], [7, 1, 9]],
]
_MEDIAN_4D = [
    [[[3, 1], [4, 3]], [[0, 5], [6, 2]], [[7, 8], [9, 4]]],
    [[[10, 11], [12, 13]], [[14, 15], [16, 17]], [[18, 19], [20, 21]]],
]


@pytest.mark.parametrize(
    "data, axis",
    [
        # 1D array
        (_MEDIAN_VECTOR, 0),
        # 2D array, both axes
        *((_MEDIAN_MATRIX, ax) for ax in (0, 1)),
        # 3D array, every axis
        *((_MEDIAN_CUBE, ax) for ax in (0, 1, 2)),
        # 4D array, last axis only
        (_MEDIAN_4D, 3),
    ],
)
def test_median(data, axis):
    """Check that the compiled `median` graph agrees with `np.median`."""
    symbolic_input = tensor(shape=np.array(data).shape)
    compiled = function([symbolic_input], median(symbolic_input, axis=axis))
    assert np.allclose(compiled(data), np.median(data, axis=axis))