diff --git a/setup.cfg b/setup.cfg index 3b38953..162b38d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = icon_registration -version = 1.1.7 +version = 1.1.8 author = Hastings Greer author_email = t@hgreer.com description = A package for image registration regularized by inverse consistency diff --git a/src/icon_registration/itk_wrapper.py b/src/icon_registration/itk_wrapper.py index ab4d331..5412ccf 100644 --- a/src/icon_registration/itk_wrapper.py +++ b/src/icon_registration/itk_wrapper.py @@ -10,33 +10,24 @@ DEFAULT_FINETUNE_LEARNING_RATE = 0.00002 +def _resize_itk_mask(itk_image, shape): + """Convert an itk mask/segmentation image to a resized torch tensor (nearest interpolation).""" + assert isinstance(itk_image, itk.Image) + trch = torch.Tensor(np.array(itk_image)).to(config.device)[None, None] + return F.interpolate(trch, size=shape[2:], mode="nearest") -def finetune_execute(model, image_A, image_B, steps, learning_rate): - state_dict = copy.deepcopy(model.state_dict()) - optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) - for _ in range(steps): - optimizer.zero_grad() - loss_tuple = model(image_A, image_B) - print(loss_tuple) - loss_tuple[0].backward() - optimizer.step() - with torch.no_grad(): - loss = model(image_A, image_B) - #model.load_state_dict(state_dict) - return loss - -def finetune_execute_mask(model, image_A, image_B, mask_A, mask_B, steps, learning_rate): +def finetune_execute(model, image_A, image_B, steps, learning_rate, **model_kwargs): state_dict = copy.deepcopy(model.state_dict()) optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) for _ in range(steps): optimizer.zero_grad() - loss_tuple = model(image_A, image_B, mask_A=mask_A, mask_B=mask_B) + loss_tuple = model(image_A, image_B, **model_kwargs) print(loss_tuple) loss_tuple[0].backward() optimizer.step() with torch.no_grad(): - loss = model(image_A, image_B, mask_A=mask_A, mask_B=mask_B) + loss = model(image_A, image_B, **model_kwargs) model.load_state_dict(state_dict) return loss @@ -45,6 +36,8 @@ def register_pair( model, image_A, image_B, finetune_steps=None, return_artifacts=False, learning_rate=DEFAULT_FINETUNE_LEARNING_RATE ) -> "(itk.CompositeTransform, itk.CompositeTransform)": + assert learning_rate > 0 + assert isinstance(image_A, itk.Image) assert isinstance(image_B, itk.Image) @@ -99,19 +92,29 @@ def register_pair( else: return itk_transforms + (to_floats(loss),) -def register_pair_with_mask(model, image_A, image_B, mask_A, mask_B, finetune_steps=None, return_artifacts=False, learning_rate=DEFAULT_FINETUNE_LEARNING_RATE): +def register_pair_with_mask( + model, + image_A, + image_B, + mask_A=None, + mask_B=None, + finetune_steps=None, + return_artifacts=False, + learning_rate=DEFAULT_FINETUNE_LEARNING_RATE, + segmentation_A=None, + segmentation_B=None, +): + + assert learning_rate > 0 assert isinstance(image_A, itk.Image) assert isinstance(image_B, itk.Image) - - assert isinstance(mask_A, itk.Image) - assert isinstance(mask_B, itk.Image) + + # send model to cpu or gpu depending on config- auto detects capability + model.to(config.device) A_npy = np.array(image_A) B_npy = np.array(image_B) - - A_mask_npy = np.array(mask_A) - B_mask_npy = np.array(mask_B) assert(np.max(A_npy) != np.min(A_npy)) assert(np.max(B_npy) != np.min(B_npy)) @@ -119,8 +122,6 @@ def register_pair_with_mask(model, image_A, image_B, mask_A, mask_B, finetune_st # turn images into torch Tensors: add feature and batch dimensions (each of length 1) A_trch = torch.Tensor(A_npy).to(config.device)[None, None] B_trch = torch.Tensor(B_npy).to(config.device)[None, None] - A_mask_trch = torch.Tensor(A_mask_npy).to(config.device)[None, None] - B_mask_trch = torch.Tensor(B_mask_npy).to(config.device)[None, None] shape = model.identity_map.shape @@ -133,22 +134,26 @@ def register_pair_with_mask(model, image_A, image_B, mask_A, mask_B, finetune_st B_resized = F.interpolate( B_trch, size=shape[2:], mode="trilinear", align_corners=False ) - - A_mask_resized = F.interpolate( - A_mask_trch, size=shape[2:], mode="nearest" - ) - B_mask_resized = F.interpolate( - B_mask_trch, size=shape[2:], mode="nearest" - ) - + + model_kwargs = {} + if mask_A is not None: + model_kwargs["mask_A"] = _resize_itk_mask(mask_A, shape) + if mask_B is not None: + model_kwargs["mask_B"] = _resize_itk_mask(mask_B, shape) + if segmentation_A is not None: + model_kwargs["segmentation_A"] = _resize_itk_mask(segmentation_A, shape) + if segmentation_B is not None: + model_kwargs["segmentation_B"] = _resize_itk_mask(segmentation_B, shape) + if finetune_steps == 0: raise Exception("To indicate no finetune_steps, pass finetune_steps=None") if finetune_steps == None: with torch.no_grad(): - loss = model(A_resized, B_resized, mask_A=A_mask_resized, mask_B=B_mask_resized) + loss = model(A_resized, B_resized, **model_kwargs) + print(loss) else: - loss = finetune_execute_mask(model, A_resized, B_resized, A_mask_resized, B_mask_resized, finetune_steps, learning_rate) + loss = finetune_execute(model, A_resized, B_resized, finetune_steps, learning_rate, **model_kwargs) # phi_AB and phi_BA are [1, 3, H, W, D] pytorch tensors representing the forward and backward # maps computed by the model @@ -173,6 +178,8 @@ def register_pair_with_multimodalities( model, image_A: list, image_B: list, finetune_steps=None, return_artifacts=False, learning_rate=DEFAULT_FINETUNE_LEARNING_RATE ) -> "(itk.CompositeTransform, itk.CompositeTransform)": + assert learning_rate > 0 + assert len(image_A) == len(image_B), "image_A and image_B should have the same number of modalities." # send model to cpu or gpu depending on config- auto detects capability diff --git a/src/icon_registration/losses.py b/src/icon_registration/losses.py index 0f3320d..8109cb6 100644 --- a/src/icon_registration/losses.py +++ b/src/icon_registration/losses.py @@ -15,7 +15,7 @@ def to_floats(stats): if isinstance(v, torch.Tensor): v = torch.mean(v).cpu().item() out.append(v) - return ICONLoss(*out) + return type(stats)(*out) ICONLoss = namedtuple(