Skip to content

Commit

Permalink
6066 pad mode (#6076)
Browse files Browse the repository at this point in the history
Fixes #6066

### Description

prefer the pytorch backend as much as possible

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Feb 28, 2023
1 parent c2fc083 commit ab800d8
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
25 changes: 19 additions & 6 deletions monai/transforms/croppad/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from __future__ import annotations

import warnings

import numpy as np
import torch
from torch.nn.functional import pad as pad_pt
Expand All @@ -29,7 +31,12 @@


def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor:
img_np = img.detach().cpu().numpy() if isinstance(img, torch.Tensor) else img
if isinstance(img, torch.Tensor):
if img.is_cuda:
warnings.warn(f"Padding: moving img {img.shape} from cuda to cpu for dtype={img.dtype} mode={mode}.")
img_np = img.detach().cpu().numpy()
else:
img_np = img
mode = convert_pad_mode(dst=img_np, mode=mode).value
if mode == "constant" and "value" in kwargs:
kwargs["constant_values"] = kwargs.pop("value")
Expand All @@ -40,9 +47,15 @@ def _np_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kw


def _pt_pad(img: torch.Tensor, pad_width: list[tuple[int, int]], mode: str, **kwargs) -> torch.Tensor:
mode = convert_pad_mode(dst=img, mode=mode).value
if mode == "constant" and "constant_values" in kwargs:
_kwargs = kwargs.copy()
_kwargs["value"] = _kwargs.pop("constant_values")
else:
_kwargs = kwargs
pt_pad_width = [val for sublist in pad_width[1:] for val in sublist[::-1]][::-1]
# torch.pad expects `[B, C, H, W, [D]]` shape
return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0)
return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **_kwargs).squeeze(0)


def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs):
Expand All @@ -68,14 +81,14 @@ def pad_nd(img: torch.Tensor, to_pad: list[tuple[int, int]], mode: str, **kwargs
mode = convert_pad_mode(dst=img, mode=mode).value
try:
_pad = (
_pt_pad
if mode in {"reflect", "replicate"} and img.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8}
else _np_pad
_np_pad
if mode in {"reflect", "replicate"} and img.dtype in {torch.int16, torch.int64, torch.bool, torch.uint8}
else _pt_pad
)
return _pad(img, pad_width=to_pad, mode=mode, **kwargs)
except (ValueError, TypeError, RuntimeError) as err:
if isinstance(err, NotImplementedError) or any(
k in str(err) for k in ("supported", "unexpected keyword", "implemented")
k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value")
):
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
raise ValueError(f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device}") from err
Expand Down
4 changes: 2 additions & 2 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1628,13 +1628,13 @@ def convert_pad_mode(dst: NdarrayOrTensor, mode: str | None):
if isinstance(dst, torch.Tensor):
if mode == "wrap":
mode = "circular"
if mode == "edge":
elif mode == "edge":
mode = "replicate"
return look_up_option(mode, PytorchPadMode)
if isinstance(dst, np.ndarray):
if mode == "circular":
mode = "wrap"
if mode == "replicate":
elif mode == "replicate":
mode = "edge"
return look_up_option(mode, NumpyPadMode)
raise ValueError(f"unsupported data type: {type(dst)}.")
Expand Down

0 comments on commit ab800d8

Please sign in to comment.