Test error in test_soft_clipping_one_sided_high #7616

Closed · KumoLiu opened this issue Apr 10, 2024 · 7 comments · Fixed by #7624

KumoLiu (Contributor) commented Apr 10, 2024

[2024-04-10T01:27:20.324Z] ======================================================================
[2024-04-10T01:27:20.324Z] FAIL: test_soft_clipping_one_sided_high_2 (tests.test_clip_intensity_percentiles.TestClipIntensityPercentiles2D)
[2024-04-10T01:27:20.324Z] ----------------------------------------------------------------------
[2024-04-10T01:27:20.324Z] Traceback (most recent call last):
[2024-04-10T01:27:20.324Z]   File "/usr/local/lib/python3.10/dist-packages/parameterized/parameterized.py", line 620, in standalone_func
[2024-04-10T01:27:20.324Z]     return func(*(a + p.args), **p.kwargs, **kw)
[2024-04-10T01:27:20.324Z]   File "/opt/monai/tests/test_clip_intensity_percentiles.py", line 71, in test_soft_clipping_one_sided_high
[2024-04-10T01:27:20.324Z]     assert_allclose(result, p(expected), type_test="tensor", rtol=5e-5, atol=0)
[2024-04-10T01:27:20.324Z]   File "/opt/monai/tests/utils.py", line 135, in assert_allclose
[2024-04-10T01:27:20.324Z]     np.testing.assert_allclose(actual, desired, *args, **kwargs)
[2024-04-10T01:27:20.324Z]   File "/usr/local/lib/python3.10/dist-packages/numpy/testing/_private/utils.py", line 1592, in assert_allclose
[2024-04-10T01:27:20.324Z]     assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
[2024-04-10T01:27:20.324Z]   File "/usr/lib/python3.10/contextlib.py", line 79, in inner
[2024-04-10T01:27:20.324Z]     return func(*args, **kwds)
[2024-04-10T01:27:20.324Z]   File "/usr/local/lib/python3.10/dist-packages/numpy/testing/_private/utils.py", line 862, in assert_array_compare
[2024-04-10T01:27:20.324Z]     raise AssertionError(msg)
[2024-04-10T01:27:20.324Z] AssertionError: 
[2024-04-10T01:27:20.324Z] Not equal to tolerance rtol=5e-05, atol=0
[2024-04-10T01:27:20.324Z] 
[2024-04-10T01:27:20.324Z] Mismatched elements: 8192 / 8192 (100%)
[2024-04-10T01:27:20.324Z] Max absolute difference: 0.00011533
[2024-04-10T01:27:20.324Z] Max relative difference: 0.00107984
[2024-04-10T01:27:20.324Z]  x: array([[[[-0.425871, -0.425871, -0.425871, ..., -0.425871, -0.425871,
[2024-04-10T01:27:20.324Z]           -0.425871],
[2024-04-10T01:27:20.324Z]          [-0.425871, -0.425871, -0.425871, ..., -0.425871, -0.425871,...
[2024-04-10T01:27:20.324Z]  y: array([[[[-0.425803, -0.425803, -0.425803, ..., -0.425803, -0.425803,
[2024-04-10T01:27:20.324Z]           -0.425803],
[2024-04-10T01:27:20.324Z]          [-0.425803, -0.425803, -0.425803, ..., -0.425803, -0.425803,...
[2024-04-10T01:27:20.324Z] 
KumoLiu (Contributor, Author) commented Apr 10, 2024

cc @Lucas-rbnt

Lucas-rbnt (Contributor) commented Apr 10, 2024

Hello, I think this is because torch.logaddexp and numpy.logaddexp give slightly different results.
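For reference, the discrepancy can be seen directly on logaddexp itself (a minimal sketch, assuming float32 inputs):

import numpy as np
import torch

x = np.random.randn(128).astype(np.float32)
# numpy and torch implement logaddexp differently, so float32 results can diverge slightly
diff = np.abs(np.logaddexp(0.0, x) - torch.logaddexp(torch.zeros(128), torch.from_numpy(x)).numpy())
print(diff.max())  # small but typically nonzero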

One workaround would be, when the input is a numpy array, to convert it to a tensor, apply torch.logaddexp, and then convert the result back to numpy. Not very elegant, but simple, and it should solve the problem.

A small test on this yields convincing results:

import numpy as np
import torch

from monai.transforms.utils_pytorch_numpy_unification import softplus

x = torch.randn(128, 128)
np.array_equal(softplus(x.numpy()), softplus(x).numpy())
:> False

while the modified version works as expected:

import numpy as np
import torch

def softplus(x):
    """Stable softplus via `torch.logaddexp`, applied to numpy arrays and torch tensors alike
    so that both input types produce identical results.

    Args:
        x: array/tensor.

    Returns:
        Softplus of the input.
    """
    if isinstance(x, np.ndarray):
        # round-trip through torch so numpy inputs use the same implementation
        x = torch.from_numpy(x)
        return torch.logaddexp(torch.zeros_like(x), x).numpy()
    return torch.logaddexp(torch.zeros_like(x), x)

x = torch.randn(128, 128)
np.array_equal(softplus(x.numpy()), softplus(x).numpy())
:> True

Let me know if you think this is relevant and sorry about that!

KumoLiu (Contributor, Author) commented Apr 10, 2024

Hi @Lucas-rbnt, I think we could adjust this line:

expected = soft_clip(self.imt, sharpness_factor=1.0, minv=None, maxv=upper)
to:

expected = soft_clip(im, sharpness_factor=1.0, minv=None, maxv=upper, dtype=im.dtype)

Using im in place of self.imt and setting dtype explicitly means the expected value is computed from the same input type as the result, instead of always being checked against the NumPy version. I think this change would also benefit the other test cases in the file. Do you think this would be a viable solution?
Also, would you mind creating a PR to fix this one? Thanks!
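For reference, a minimal sketch of the adjusted expected computation (the monai.transforms.utils import path and the concrete upper value are my assumptions):

import torch
from monai.transforms.utils import soft_clip  # import path assumed

im = torch.rand(1, 64, 64)
upper = 0.9  # hypothetical clip ceiling standing in for the test's percentile value
# computing `expected` from the torch input with an explicit dtype keeps the
# reference value on torch.logaddexp instead of the numpy implementation
expected = soft_clip(im, sharpness_factor=1.0, minv=None, maxv=upper, dtype=im.dtype)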

Lucas-rbnt (Contributor) commented
It's definitely one way of doing it, but it won't be enough on its own, because in the __call__ of the ClipIntensityPercentiles class the input img is converted to a tensor, as in many other transforms:

    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Apply the transform to `img`.
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())

The softplus is then applied to a tensor rather than a numpy array, and we are back to the difference in results between np.logaddexp and torch.logaddexp, so this would also need to change.
Both options are possible, but I wonder if it wouldn't be better to settle on one implementation so that results are consistent between numpy and torch, i.e. a tensor and an array holding the same values give the same output from the softplus function.
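To make this concrete, a small sketch (the import locations of convert_to_tensor and get_track_meta are my assumption):

import numpy as np
from monai.data import get_track_meta
from monai.utils import convert_to_tensor

img = np.random.rand(1, 8, 8).astype(np.float32)
converted = convert_to_tensor(img, track_meta=get_track_meta())
# even a numpy input comes back as a torch.Tensor (a MetaTensor when metadata
# tracking is enabled), so the softplus afterwards always takes the torch branch
print(type(converted))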

KumoLiu added a commit to KumoLiu/MONAI that referenced this issue Apr 11, 2024
Signed-off-by: YunLiu <[email protected]>
KumoLiu (Contributor, Author) commented Apr 11, 2024

Hi @Lucas-rbnt,

I understand your point. Since MetaTensor was introduced in v0.9.1, our transforms have consistently converted input data to MetaTensor.
More importantly, there is no general requirement that numpy and PyTorch produce identical outputs, particularly for logaddexp; PyTorch's own test suite already checks that its results stay close to numpy's:
https://github.com/pytorch/pytorch/blob/793df52dc52f5f5f657744abfd7681eaba7a21f9/test/test_binary_ufuncs.py#L3461
Given that, it doesn't seem necessary for users to cross-verify the consistency of results between torch and numpy.

What do you think?
Thanks

Lucas-rbnt (Contributor) commented

Yes, I agree, that makes sense to me!
However, I'm not sure the PR will entirely fix the problem, since the MetaTensor conversion is still there.
To make the tests pass, perhaps we should compare the arrays with the same tolerance as in the torch tests you mentioned:

self.assertEqual(ref, v.float(), atol=0.01, rtol=0.01)

(line 3461, for torch.bfloat16)

and it seems that https://github.com/pytorch/pytorch/blob/793df52dc52f5f5f657744abfd7681eaba7a21f9/torch/testing/_comparison.py#L1183 gives the default parameters for the other dtypes:

+---------------------------+------------+----------+
| ``dtype``                 | ``rtol``   | ``atol`` |
+===========================+============+==========+
| :attr:`~torch.float16`    | ``1e-3``   | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.bfloat16`   | ``1.6e-2`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.float32`    | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.float64`    | ``1e-7``   | ``1e-7`` |
+---------------------------+------------+----------+
| :attr:`~torch.complex32`  | ``1e-3``   | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.complex64`  | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.complex128` | ``1e-7``   | ``1e-7`` |
+---------------------------+------------+----------+
| :attr:`~torch.quint8`     | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.quint2x4`   | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.quint4x2`   | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.qint8`      | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| :attr:`~torch.qint32`     | ``1.3e-6`` | ``1e-5`` |
+---------------------------+------------+----------+
| other                     | ``0.0``    | ``0.0``  |
+---------------------------+------------+----------+
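For comparison, torch.testing.assert_close applies these per-dtype defaults automatically when rtol/atol are not given (a minimal sketch):

import torch

a = torch.randn(16)
b = a + 1e-6
# with rtol/atol left unset, assert_close falls back to the float32 defaults
# from the table above (rtol=1.3e-6, atol=1e-5), so a 1e-6 offset passes
torch.testing.assert_close(a, b)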


However, the default values for float32 are much stricter than the error observed above, so I'm not sure they would be sufficient.
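As a purely illustrative sketch, the failing assertion from the traceback could be loosened enough to absorb the observed discrepancy (the tolerances below are my assumption, not a value agreed in this thread):

# hypothetical relaxation of the assertion in tests/test_clip_intensity_percentiles.py:
# the max relative difference reported above is ~1.1e-3, so rtol must exceed that
assert_allclose(result, p(expected), type_test="tensor", rtol=2e-3, atol=1e-5)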

KumoLiu (Contributor, Author) commented Apr 11, 2024

Given that PyTorch already runs a comprehensive set of internal consistency tests against NumPy, independently verifying that consistency in our codebase would be redundant.

Moreover, I suggest we prioritize merging this fix, since the failure has repeatedly broken tests in other PRs:
https://github.com/Project-MONAI/MONAI/actions/runs/8644324879/job/23699221155?pr=7625#step:8:12528
#7604

Let's continue the discussion here to gather more opinions. cc @ericspod @atbenmurray @Nic-Ma

KumoLiu added a commit that referenced this issue Apr 11, 2024

Fixes #7616

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>