Skip to content

Commit

Permalink
fixes variable names
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli committed Jan 23, 2023
1 parent f4e71c0 commit 7518371
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
16 changes: 8 additions & 8 deletions monai/transforms/croppad/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
__all__ = ["pad_func", "crop_func"]


def pad_func(img_t, to_pad_, mode_, kwargs_, transform_info):
def pad_func(img_t, to_pad_, mode, kwargs, transform_info):
extra_info = {"padded": to_pad_}
img_size = img_t.peek_pending_shape() if isinstance(img_t, MetaTensor) else img_t.shape[1:]
_affine = (
Expand All @@ -55,25 +55,25 @@ def pad_func(img_t, to_pad_, mode_, kwargs_, transform_info):
extra_info=extra_info,
transform_info=transform_info,
)
if mode_ in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}:
out = monai.transforms.Pad._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_)
if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}:
out = monai.transforms.Pad._np_pad(img_t, pad_width=to_pad_, mode=mode, **kwargs)
else:
mode_ = convert_pad_mode(dst=img_t, mode=mode_).value
mode = convert_pad_mode(dst=img_t, mode=mode).value
try:
_pad = (
monai.transforms.Pad._pt_pad
if mode_ in {"reflect", "replicate"}
if mode in {"reflect", "replicate"}
and img_t.dtype not in {torch.int16, torch.int64, torch.bool, torch.uint8}
else monai.transforms.Pad._np_pad
)
out = _pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_)
out = _pad(img_t, pad_width=to_pad_, mode=mode, **kwargs)
except (ValueError, TypeError, RuntimeError) as err:
if isinstance(err, NotImplementedError) or any(
k in str(err) for k in ("supported", "unexpected keyword", "implemented")
):
out = monai.transforms.Pad._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_)
out = monai.transforms.Pad._np_pad(img_t, pad_width=to_pad_, mode=mode, **kwargs)
else:
raise ValueError(f"{img_t.shape} {to_pad_} {mode_} {kwargs_} {img_t.dtype} {img_t.device}") from err
raise ValueError(f"{img_t.shape} {to_pad_} {mode} {kwargs} {img_t.dtype} {img_t.device}") from err
else:
out = img_t
if get_track_meta():
Expand Down
8 changes: 4 additions & 4 deletions monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def orientation(data_array, original_affine, spatial_ornt, transform_info):
if get_track_meta():
new_affine = to_affine_nd(len(spatial_shape), original_affine) @ affine_x
new_affine = to_affine_nd(original_affine, new_affine)
new_affine, *_ = convert_data_type(new_affine, torch.Tensor, dtype=torch.float32, device=data_array.device)
new_affine, *_ = convert_data_type(new_affine, torch.Tensor, dtype=torch.float64, device=data_array.device)
data_array.affine = new_affine
return TraceableTransform.track_transform(data_array, extra_info=extra_info, transform_info=transform_info)

Expand Down Expand Up @@ -418,8 +418,8 @@ def update_meta(img, spatial_size, new_spatial_size, axes, k):
return TraceableTransform.track_transform(out, extra_info=extra_info, transform_info=transform_info)


def affine_func(img, affine, grid, resampler, sp_size, _mode, _padding_mode, do_resampling, image_only, transform_info):
extra_info = {"affine": affine, "mode": _mode, "padding_mode": _padding_mode, "do_resampling": do_resampling}
def affine_func(img, affine, grid, resampler, sp_size, mode, padding_mode, do_resampling, image_only, transform_info):
extra_info = {"affine": affine, "mode": mode, "padding_mode": padding_mode, "do_resampling": do_resampling}
img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
if transform_info.get(TraceKeys.LAZY_EVALUATION):
if not get_track_meta():
Expand All @@ -436,7 +436,7 @@ def affine_func(img, affine, grid, resampler, sp_size, _mode, _padding_mode, do_
)
return img if image_only else (img, affine)
if do_resampling:
out = resampler(img=img, grid=grid, mode=_mode, padding_mode=_padding_mode)
out = resampler(img=img, grid=grid, mode=mode, padding_mode=padding_mode)
else:
out = convert_data_type(img, dtype=torch.float32, device=resampler.device)[0]

Expand Down

0 comments on commit 7518371

Please sign in to comment.