Skip to content

Commit

Permalink
feat: Add percentile function to tensorlib (#817)
Browse files Browse the repository at this point in the history
* Add percentile function to the tensor backends
* Add tests for percentile and its interpolation methods
   - JAX requires additional dtype support with the 'linear' interpolation method
     c.f. jax-ml/jax#8513
   - PyTorch has yet to implement interpolation method options
   - c.f. #1693
  • Loading branch information
matthewfeickert authored Nov 11, 2021
1 parent 902052f commit 39a9e92
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 0 deletions.
41 changes: 41 additions & 0 deletions src/pyhf/tensor/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,47 @@ def log(self, tensor_in):
def exp(self, tensor_in):
return jnp.exp(tensor_in)

def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
r"""
Compute the :math:`q`-th percentile of the tensor along the specified axis.
Example:
>>> import pyhf
>>> import jax.numpy as jnp
>>> pyhf.set_backend("jax")
>>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
>>> pyhf.tensorlib.percentile(a, jnp.float64(50))
DeviceArray(3.5, dtype=float64)
>>> pyhf.tensorlib.percentile(a, 50, axis=1)
DeviceArray([7., 2.], dtype=float64)
Args:
tensor_in (`tensor`): The tensor containing the data
q (:obj:`float` or `tensor`): The :math:`q`-th percentile to compute
axis (`number` or `tensor`): The dimensions along which to compute
interpolation (:obj:`str`): The interpolation method to use when the
desired percentile lies between two data points ``i < j``:
- ``'linear'``: ``i + (j - i) * fraction``, where ``fraction`` is the
fractional part of the index surrounded by ``i`` and ``j``.
- ``'lower'``: ``i``.
- ``'higher'``: ``j``.
- ``'midpoint'``: ``(i + j) / 2``.
- ``'nearest'``: ``i`` or ``j``, whichever is nearest.
Returns:
JAX ndarray: The value of the :math:`q`-th percentile of the tensor along the specified axis.
"""
# TODO: Monitor future JAX releases for changes to percentile dtype promotion
# c.f. https://github.com/google/jax/issues/8513
return jnp.percentile(tensor_in, q, axis=axis, interpolation=interpolation)

def stack(self, sequence, axis=0):
return jnp.stack(sequence, axis=axis)

Expand Down
38 changes: 38 additions & 0 deletions src/pyhf/tensor/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,44 @@ def log(self, tensor_in):
def exp(self, tensor_in):
return np.exp(tensor_in)

def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
r"""
Compute the :math:`q`-th percentile of the tensor along the specified axis.
Example:
>>> import pyhf
>>> pyhf.set_backend("numpy")
>>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
>>> pyhf.tensorlib.percentile(a, 50)
3.5
>>> pyhf.tensorlib.percentile(a, 50, axis=1)
array([7., 2.])
Args:
tensor_in (`tensor`): The tensor containing the data
q (:obj:`float` or `tensor`): The :math:`q`-th percentile to compute
axis (`number` or `tensor`): The dimensions along which to compute
interpolation (:obj:`str`): The interpolation method to use when the
desired percentile lies between two data points ``i < j``:
- ``'linear'``: ``i + (j - i) * fraction``, where ``fraction`` is the
fractional part of the index surrounded by ``i`` and ``j``.
- ``'lower'``: ``i``.
- ``'higher'``: ``j``.
- ``'midpoint'``: ``(i + j) / 2``.
- ``'nearest'``: ``i`` or ``j``, whichever is nearest.
Returns:
NumPy ndarray: The value of the :math:`q`-th percentile of the tensor along the specified axis.
"""
return np.percentile(tensor_in, q, axis=axis, interpolation=interpolation)

def stack(self, sequence, axis=0):
return np.stack(sequence, axis=axis)

Expand Down
41 changes: 41 additions & 0 deletions src/pyhf/tensor/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,47 @@ def log(self, tensor_in):
def exp(self, tensor_in):
return torch.exp(tensor_in)

def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
r"""
Compute the :math:`q`-th percentile of the tensor along the specified axis.
Example:
>>> import pyhf
>>> pyhf.set_backend("pytorch")
>>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
>>> pyhf.tensorlib.percentile(a, 50)
tensor(3.5000)
>>> pyhf.tensorlib.percentile(a, 50, axis=1)
tensor([7., 2.])
Args:
tensor_in (`tensor`): The tensor containing the data
q (:obj:`float` or `tensor`): The :math:`q`-th percentile to compute
axis (`number` or `tensor`): The dimensions along which to compute
interpolation (:obj:`str`): The interpolation method to use when the
desired percentile lies between two data points ``i < j``:
- ``'linear'``: ``i + (j - i) * fraction``, where ``fraction`` is the
fractional part of the index surrounded by ``i`` and ``j``.
- ``'lower'``: Not yet implemented in PyTorch.
- ``'higher'``: Not yet implemented in PyTorch.
- ``'midpoint'``: Not yet implemented in PyTorch.
- ``'nearest'``: Not yet implemented in PyTorch.
Returns:
PyTorch tensor: The value of the :math:`q`-th percentile of the tensor along the specified axis.
"""
# Interpolation options not yet supported
# c.f. https://github.com/pytorch/pytorch/pull/49267
# c.f. https://github.com/pytorch/pytorch/pull/59397
return torch.quantile(tensor_in, q / 100, dim=axis)

def stack(self, sequence, axis=0):
return torch.stack(sequence, dim=axis)

Expand Down
42 changes: 42 additions & 0 deletions src/pyhf/tensor/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,48 @@ def log(self, tensor_in):
def exp(self, tensor_in):
return tf.exp(tensor_in)

def percentile(self, tensor_in, q, axis=None, interpolation="linear"):
r"""
Compute the :math:`q`-th percentile of the tensor along the specified axis.
Example:
>>> import pyhf
>>> pyhf.set_backend("tensorflow")
>>> a = pyhf.tensorlib.astensor([[10, 7, 4], [3, 2, 1]])
>>> t = pyhf.tensorlib.percentile(a, 50)
>>> print(t)
tf.Tensor(3.5, shape=(), dtype=float64)
>>> t = pyhf.tensorlib.percentile(a, 50, axis=1)
>>> print(t)
tf.Tensor([7. 2.], shape=(2,), dtype=float64)
Args:
tensor_in (`tensor`): The tensor containing the data
q (:obj:`float` or `tensor`): The :math:`q`-th percentile to compute
axis (`number` or `tensor`): The dimensions along which to compute
interpolation (:obj:`str`): The interpolation method to use when the
desired percentile lies between two data points ``i < j``:
- ``'linear'``: ``i + (j - i) * fraction``, where ``fraction`` is the
fractional part of the index surrounded by ``i`` and ``j``.
- ``'lower'``: ``i``.
- ``'higher'``: ``j``.
- ``'midpoint'``: ``(i + j) / 2``.
- ``'nearest'``: ``i`` or ``j``, whichever is nearest.
Returns:
TensorFlow Tensor: The value of the :math:`q`-th percentile of the tensor along the specified axis.
"""
return tfp.stats.percentile(
tensor_in, q, axis=axis, interpolation=interpolation
)

def stack(self, sequence, axis=0):
return tf.stack(sequence, axis=axis)

Expand Down
55 changes: 55 additions & 0 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,61 @@ def test_boolean_mask(backend):
)


@pytest.mark.skip_jax
def test_percentile(backend):
tb = pyhf.tensorlib
a = tb.astensor([[10, 7, 4], [3, 2, 1]])
assert tb.tolist(tb.percentile(a, 0)) == 1

assert tb.tolist(tb.percentile(a, 50)) == 3.5
assert tb.tolist(tb.percentile(a, 100)) == 10
assert tb.tolist(tb.percentile(a, 50, axis=1)) == [7.0, 2.0]


# FIXME: PyTorch doesn't yet support interpolation schemes other than "linear"
# c.f. https://github.com/pytorch/pytorch/pull/59397
# c.f. https://github.com/scikit-hep/pyhf/issues/1693
@pytest.mark.skip_pytorch
@pytest.mark.skip_pytorch64
@pytest.mark.skip_jax
def test_percentile_interpolation(backend):
tb = pyhf.tensorlib
a = tb.astensor([[10, 7, 4], [3, 2, 1]])

assert tb.tolist(tb.percentile(a, 50, interpolation="linear")) == 3.5
assert tb.tolist(tb.percentile(a, 50, interpolation="nearest")) == 3.0
assert tb.tolist(tb.percentile(a, 50, interpolation="lower")) == 3.0
assert tb.tolist(tb.percentile(a, 50, interpolation="midpoint")) == 3.5
assert tb.tolist(tb.percentile(a, 50, interpolation="higher")) == 4.0


@pytest.mark.only_jax
def test_percentile_jax(backend):
tb = pyhf.tensorlib
a = tb.astensor([[10, 7, 4], [3, 2, 1]])
assert tb.tolist(tb.percentile(a, 0)) == 1

# TODO: Monitor future JAX releases for changes to percentile dtype promotion
# c.f. https://github.com/scikit-hep/pyhf/issues/1693
assert tb.tolist(tb.percentile(a, np.float64(50))) == 3.5
assert tb.tolist(tb.percentile(a, np.float64(100))) == 10
assert tb.tolist(tb.percentile(a, 50, axis=1)) == [7.0, 2.0]


@pytest.mark.only_jax
def test_percentile_interpolation_jax(backend):
tb = pyhf.tensorlib
a = tb.astensor([[10, 7, 4], [3, 2, 1]])

# TODO: Monitor future JAX releases for changes to percentile dtype promotion
# c.f. https://github.com/scikit-hep/pyhf/issues/1693
assert tb.tolist(tb.percentile(a, np.float64(50), interpolation="linear")) == 3.5
assert tb.tolist(tb.percentile(a, 50, interpolation="nearest")) == 3.0
assert tb.tolist(tb.percentile(a, 50, interpolation="lower")) == 3.0
assert tb.tolist(tb.percentile(a, 50, interpolation="midpoint")) == 3.5
assert tb.tolist(tb.percentile(a, 50, interpolation="higher")) == 4.0


def test_tensor_tile(backend):
a = [[1], [2], [3]]
tb = pyhf.tensorlib
Expand Down

0 comments on commit 39a9e92

Please sign in to comment.