Skip to content

Commit

Permalink
Merge branch 'dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
KumoLiu authored Aug 30, 2024
2 parents cb4b646 + b209347 commit edc7d8e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
8 changes: 6 additions & 2 deletions monai/networks/layers/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True):
ctx.cs = color_sigma
ctx.fa = fast_approx
output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx)
if torch.cuda.is_available():
torch.cuda.synchronize()
return output_data

@staticmethod
Expand Down Expand Up @@ -139,7 +141,8 @@ def forward(ctx, input_img, sigma_x, sigma_y, sigma_z, color_sigma):
do_dsig_y,
do_dsig_z,
)

if torch.cuda.is_available():
torch.cuda.synchronize()
return output_tensor

@staticmethod
Expand Down Expand Up @@ -301,7 +304,8 @@ def forward(ctx, input_img, guidance_img, sigma_x, sigma_y, sigma_z, color_sigma
do_dsig_z,
guidance_img,
)

if torch.cuda.is_available():
torch.cuda.synchronize()
return output_tensor

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2512,6 +2512,7 @@ def distance_transform_edt(
block_params=block_params,
float64_distances=float64_distances,
)
torch.cuda.synchronize()
else:
if not has_ndimage:
raise RuntimeError("scipy.ndimage required if cupy is not available")
Expand Down Expand Up @@ -2545,7 +2546,7 @@ def distance_transform_edt(

r_vals = []
if return_distances and distances_original is None:
r_vals.append(distances)
r_vals.append(distances_ if use_cp else distances)
if return_indices and indices_original is None:
r_vals.append(indices)
if not r_vals:
Expand Down

0 comments on commit edc7d8e

Please sign in to comment.