Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions src/icon_registration/itk_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from icon_registration import config
from icon_registration.losses import to_floats

DEFAULT_FINETUNE_LEARNING_RATE = 0.00002

def finetune_execute(model, image_A, image_B, steps):

def finetune_execute(model, image_A, image_B, steps, learning_rate):
state_dict = copy.deepcopy(model.state_dict())
optimizer = torch.optim.Adam(model.parameters(), lr=0.00002)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for _ in range(steps):
optimizer.zero_grad()
loss_tuple = model(image_A, image_B)
Expand All @@ -24,9 +26,9 @@ def finetune_execute(model, image_A, image_B, steps):
return loss


def finetune_execute_mask(model, image_A, image_B, mask_A, mask_B, steps):
def finetune_execute_mask(model, image_A, image_B, mask_A, mask_B, steps, learning_rate):
state_dict = copy.deepcopy(model.state_dict())
optimizer = torch.optim.Adam(model.parameters(), lr=0.00002)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for _ in range(steps):
optimizer.zero_grad()
loss_tuple = model(image_A, image_B, mask_A=mask_A, mask_B=mask_B)
Expand All @@ -40,7 +42,7 @@ def finetune_execute_mask(model, image_A, image_B, mask_A, mask_B, steps):


def register_pair(
model, image_A, image_B, finetune_steps=None, return_artifacts=False
model, image_A, image_B, finetune_steps=None, return_artifacts=False, learning_rate=DEFAULT_FINETUNE_LEARNING_RATE
) -> "(itk.CompositeTransform, itk.CompositeTransform)":

assert isinstance(image_A, itk.Image)
Expand Down Expand Up @@ -76,7 +78,7 @@ def register_pair(
with torch.no_grad():
loss = model(A_resized, B_resized)
else:
loss = finetune_execute(model, A_resized, B_resized, finetune_steps)
loss = finetune_execute(model, A_resized, B_resized, finetune_steps, learning_rate)

# phi_AB and phi_BA are [1, 3, H, W, D] pytorch tensors representing the forward and backward
# maps computed by the model
Expand All @@ -96,7 +98,7 @@ def register_pair(
else:
return itk_transforms + (to_floats(loss),)

def register_pair_with_mask(model, image_A, image_B, mask_A, mask_B, finetune_steps=None, return_artifacts=False):
def register_pair_with_mask(model, image_A, image_B, mask_A, mask_B, finetune_steps=None, return_artifacts=False, learning_rate=DEFAULT_FINETUNE_LEARNING_RATE):

assert isinstance(image_A, itk.Image)
assert isinstance(image_B, itk.Image)
Expand Down Expand Up @@ -145,7 +147,7 @@ def register_pair_with_mask(model, image_A, image_B, mask_A, mask_B, finetune_st
with torch.no_grad():
loss = model(A_resized, B_resized, mask_A=A_mask_resized, mask_B=B_mask_resized)
else:
loss = finetune_execute_mask(model, A_resized, B_resized, A_mask_resized, B_mask_resized, finetune_steps)
loss = finetune_execute_mask(model, A_resized, B_resized, A_mask_resized, B_mask_resized, finetune_steps, learning_rate)

# phi_AB and phi_BA are [1, 3, H, W, D] pytorch tensors representing the forward and backward
# maps computed by the model
Expand All @@ -167,7 +169,7 @@ def register_pair_with_mask(model, image_A, image_B, mask_A, mask_B, finetune_st


def register_pair_with_multimodalities(
model, image_A: list, image_B: list, finetune_steps=None, return_artifacts=False
model, image_A: list, image_B: list, finetune_steps=None, return_artifacts=False, learning_rate=DEFAULT_FINETUNE_LEARNING_RATE
) -> "(itk.CompositeTransform, itk.CompositeTransform)":

assert len(image_A) == len(image_B), "image_A and image_B should have the same number of modalities."
Expand Down Expand Up @@ -209,7 +211,7 @@ def register_pair_with_multimodalities(
with torch.no_grad():
loss = model(A_trch, B_trch)
else:
loss = finetune_execute(model, A_trch, B_trch, finetune_steps)
loss = finetune_execute(model, A_trch, B_trch, finetune_steps, learning_rate)

# phi_AB and phi_BA are [1, 3, H, W, D] pytorch tensors representing the forward and backward
# maps computed by the model
Expand Down
Loading