Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2498,7 +2498,7 @@ def distance_transform_edt(
if return_indices:
dtype = torch.int32
if indices is None:
indices = torch.zeros((img.dim(),) + img.shape, dtype=dtype) # type: ignore
indices = torch.zeros((img.shape[0],) + (img.dim() - 1,) + img.shape[1:], dtype=dtype) # type: ignore
else:
if not isinstance(indices, torch.Tensor) and indices.device != img.device:
raise TypeError("indices must be a torch.Tensor on the same device as img")
Expand Down Expand Up @@ -2532,7 +2532,7 @@ def distance_transform_edt(
raise TypeError("distances must be a numpy.ndarray of dtype float64")
if return_indices:
if indices is None:
indices = np.zeros((img_.ndim,) + img_.shape, dtype=np.int32)
indices = np.zeros((img_.shape[0],) + (img_.ndim - 1,) + img_.shape[1:], dtype=np.int32)
else:
if not isinstance(indices, np.ndarray):
raise TypeError("indices must be a numpy.ndarray")
Expand Down
Loading