Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = icon_registration
version = 1.1.7
version = 1.1.8
author = Hastings Greer
author_email = t@hgreer.com
description = A package for image registration regularized by inverse consistency
Expand Down
79 changes: 43 additions & 36 deletions src/icon_registration/itk_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,24 @@

DEFAULT_FINETUNE_LEARNING_RATE = 0.00002

def _resize_itk_mask(itk_image, shape):
"""Convert an itk mask/segmentation image to a resized torch tensor (nearest interpolation)."""
assert isinstance(itk_image, itk.Image)
trch = torch.Tensor(np.array(itk_image)).to(config.device)[None, None]
return F.interpolate(trch, size=shape[2:], mode="nearest")

def finetune_execute(model, image_A, image_B, steps, learning_rate):
state_dict = copy.deepcopy(model.state_dict())
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for _ in range(steps):
optimizer.zero_grad()
loss_tuple = model(image_A, image_B)
print(loss_tuple)
loss_tuple[0].backward()
optimizer.step()
with torch.no_grad():
loss = model(image_A, image_B)
#model.load_state_dict(state_dict)
return loss


def finetune_execute_mask(model, image_A, image_B, mask_A, mask_B, steps, learning_rate):
def finetune_execute(model, image_A, image_B, steps, learning_rate, **model_kwargs):
state_dict = copy.deepcopy(model.state_dict())
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for _ in range(steps):
optimizer.zero_grad()
loss_tuple = model(image_A, image_B, mask_A=mask_A, mask_B=mask_B)
loss_tuple = model(image_A, image_B, **model_kwargs)
print(loss_tuple)
loss_tuple[0].backward()
optimizer.step()
with torch.no_grad():
loss = model(image_A, image_B, mask_A=mask_A, mask_B=mask_B)
loss = model(image_A, image_B, **model_kwargs)
model.load_state_dict(state_dict)
return loss

Expand All @@ -45,6 +36,8 @@ def register_pair(
model, image_A, image_B, finetune_steps=None, return_artifacts=False, learning_rate=DEFAULT_FINETUNE_LEARNING_RATE
) -> "(itk.CompositeTransform, itk.CompositeTransform)":

assert learning_rate > 0

assert isinstance(image_A, itk.Image)
assert isinstance(image_B, itk.Image)

Expand Down Expand Up @@ -99,28 +92,36 @@ def register_pair(
else:
return itk_transforms + (to_floats(loss),)

def register_pair_with_mask(model, image_A, image_B, mask_A, mask_B, finetune_steps=None, return_artifacts=False, learning_rate=DEFAULT_FINETUNE_LEARNING_RATE):
def register_pair_with_mask(
model,
image_A,
image_B,
mask_A=None,
mask_B=None,
finetune_steps=None,
return_artifacts=False,
learning_rate=DEFAULT_FINETUNE_LEARNING_RATE,
segmentation_A=None,
segmentation_B=None,
):

assert learning_rate > 0

assert isinstance(image_A, itk.Image)
assert isinstance(image_B, itk.Image)
assert isinstance(mask_A, itk.Image)
assert isinstance(mask_B, itk.Image)

# send model to cpu or gpu depending on config- auto detects capability
model.to(config.device)

A_npy = np.array(image_A)
B_npy = np.array(image_B)

A_mask_npy = np.array(mask_A)
B_mask_npy = np.array(mask_B)

assert(np.max(A_npy) != np.min(A_npy))
assert(np.max(B_npy) != np.min(B_npy))

# turn images into torch Tensors: add feature and batch dimensions (each of length 1)
A_trch = torch.Tensor(A_npy).to(config.device)[None, None]
B_trch = torch.Tensor(B_npy).to(config.device)[None, None]
A_mask_trch = torch.Tensor(A_mask_npy).to(config.device)[None, None]
B_mask_trch = torch.Tensor(B_mask_npy).to(config.device)[None, None]

shape = model.identity_map.shape

Expand All @@ -133,22 +134,26 @@ def register_pair_with_mask(model, image_A, image_B, mask_A, mask_B, finetune_st
B_resized = F.interpolate(
B_trch, size=shape[2:], mode="trilinear", align_corners=False
)

A_mask_resized = F.interpolate(
A_mask_trch, size=shape[2:], mode="nearest"
)
B_mask_resized = F.interpolate(
B_mask_trch, size=shape[2:], mode="nearest"
)


model_kwargs = {}
if mask_A is not None:
model_kwargs["mask_A"] = _resize_itk_mask(mask_A, shape)
if mask_B is not None:
model_kwargs["mask_B"] = _resize_itk_mask(mask_B, shape)
if segmentation_A is not None:
model_kwargs["segmentation_A"] = _resize_itk_mask(segmentation_A, shape)
if segmentation_B is not None:
model_kwargs["segmentation_B"] = _resize_itk_mask(segmentation_B, shape)

if finetune_steps == 0:
raise Exception("To indicate no finetune_steps, pass finetune_steps=None")

if finetune_steps == None:
with torch.no_grad():
loss = model(A_resized, B_resized, mask_A=A_mask_resized, mask_B=B_mask_resized)
loss = model(A_resized, B_resized, **model_kwargs)
print(loss)
else:
loss = finetune_execute_mask(model, A_resized, B_resized, A_mask_resized, B_mask_resized, finetune_steps, learning_rate)
loss = finetune_execute(model, A_resized, B_resized, finetune_steps, learning_rate, **model_kwargs)

# phi_AB and phi_BA are [1, 3, H, W, D] pytorch tensors representing the forward and backward
# maps computed by the model
Expand All @@ -173,6 +178,8 @@ def register_pair_with_multimodalities(
model, image_A: list, image_B: list, finetune_steps=None, return_artifacts=False, learning_rate=DEFAULT_FINETUNE_LEARNING_RATE
) -> "(itk.CompositeTransform, itk.CompositeTransform)":

assert learning_rate > 0

assert len(image_A) == len(image_B), "image_A and image_B should have the same number of modalities."

# send model to cpu or gpu depending on config- auto detects capability
Expand Down
2 changes: 1 addition & 1 deletion src/icon_registration/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def to_floats(stats):
if isinstance(v, torch.Tensor):
v = torch.mean(v).cpu().item()
out.append(v)
return ICONLoss(*out)
return type(stats)(*out)


ICONLoss = namedtuple(
Expand Down
Loading