From 17395f29cf3a07e2e534a286f5187a82ea3b1330 Mon Sep 17 00:00:00 2001 From: Basar Demir Date: Wed, 1 Apr 2026 16:51:30 -0400 Subject: [PATCH 1/8] Add segmentation_A/B parameters to register_pair_with_mask --- src/icon_registration/itk_wrapper.py | 41 +++++++++++++--------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/src/icon_registration/itk_wrapper.py b/src/icon_registration/itk_wrapper.py index 81a0d3f..ebad840 100644 --- a/src/icon_registration/itk_wrapper.py +++ b/src/icon_registration/itk_wrapper.py @@ -9,6 +9,13 @@ from icon_registration.losses import to_floats +def _resize_itk_mask(itk_image, shape): + """Convert an itk mask/segmentation image to a resized torch tensor (nearest interpolation).""" + assert isinstance(itk_image, itk.Image) + trch = torch.Tensor(np.array(itk_image)).to(config.device)[None, None] + return F.interpolate(trch, size=shape[2:], mode="nearest") + + def finetune_execute(model, image_A, image_B, steps): state_dict = copy.deepcopy(model.state_dict()) optimizer = torch.optim.Adam(model.parameters(), lr=0.00002) @@ -24,17 +31,17 @@ def finetune_execute(model, image_A, image_B, steps): return loss -def finetune_execute_mask(model, image_A, image_B, mask_A, mask_B, steps): +def finetune_execute_mask(model, image_A, image_B, mask_A, mask_B, steps, segmentation_A=None, segmentation_B=None): state_dict = copy.deepcopy(model.state_dict()) optimizer = torch.optim.Adam(model.parameters(), lr=0.00002) for _ in range(steps): optimizer.zero_grad() - loss_tuple = model(image_A, image_B, mask_A=mask_A, mask_B=mask_B) + loss_tuple = model(image_A, image_B, mask_A=mask_A, mask_B=mask_B, segmentation_A=segmentation_A, segmentation_B=segmentation_B) print(loss_tuple) loss_tuple[0].backward() optimizer.step() with torch.no_grad(): - loss = model(image_A, image_B, mask_A=mask_A, mask_B=mask_B) + loss = model(image_A, image_B, mask_A=mask_A, mask_B=mask_B, segmentation_A=segmentation_A, segmentation_B=segmentation_B) model.load_state_dict(state_dict) return loss @@ -97,19 +104,13 @@ def register_pair( else: return itk_transforms + (to_floats(loss),) -def register_pair_with_mask(model, image_A, image_B, mask_A, mask_B, finetune_steps=None, return_artifacts=False): +def register_pair_with_mask(model, image_A, image_B, mask_A=None, mask_B=None, finetune_steps=None, return_artifacts=False, segmentation_A=None, segmentation_B=None): assert isinstance(image_A, itk.Image) assert isinstance(image_B, itk.Image) - - assert isinstance(mask_A, itk.Image) - assert isinstance(mask_B, itk.Image) A_npy = np.array(image_A) B_npy = np.array(image_B) - - A_mask_npy = np.array(mask_A) - B_mask_npy = np.array(mask_B) assert(np.max(A_npy) != np.min(A_npy)) assert(np.max(B_npy) != np.min(B_npy)) @@ -117,8 +118,6 @@ def register_pair_with_mask(model, image_A, image_B, mask_A, mask_B, finetune_st # turn images into torch Tensors: add feature and batch dimensions (each of length 1) A_trch = torch.Tensor(A_npy).to(config.device)[None, None] B_trch = torch.Tensor(B_npy).to(config.device)[None, None] - A_mask_trch = torch.Tensor(A_mask_npy).to(config.device)[None, None] - B_mask_trch = torch.Tensor(B_mask_npy).to(config.device)[None, None] shape = model.identity_map.shape @@ -131,22 +130,20 @@ def register_pair_with_mask(model, image_A, image_B, mask_A, mask_B, finetune_st B_resized = F.interpolate( B_trch, size=shape[2:], mode="trilinear", align_corners=False ) - - A_mask_resized = F.interpolate( - A_mask_trch, size=shape[2:], mode="nearest" - ) - B_mask_resized = F.interpolate( - B_mask_trch, size=shape[2:], mode="nearest" - ) - + + A_mask_resized = _resize_itk_mask(mask_A, shape) if mask_A is not None else None + B_mask_resized = _resize_itk_mask(mask_B, shape) if mask_B is not None else None + A_seg_resized = _resize_itk_mask(segmentation_A, shape) if segmentation_A is not None else None + B_seg_resized = _resize_itk_mask(segmentation_B, shape) if segmentation_B is not None else None + if finetune_steps == 0: raise Exception("To indicate no finetune_steps, pass finetune_steps=None") if finetune_steps == None: with torch.no_grad(): - loss = model(A_resized, B_resized, mask_A=A_mask_resized, mask_B=B_mask_resized) + loss = model(A_resized, B_resized, mask_A=A_mask_resized, mask_B=B_mask_resized, segmentation_A=A_seg_resized, segmentation_B=B_seg_resized) else: - loss = finetune_execute_mask(model, A_resized, B_resized, A_mask_resized, B_mask_resized, finetune_steps) + loss = finetune_execute_mask(model, A_resized, B_resized, A_mask_resized, B_mask_resized, finetune_steps, segmentation_A=A_seg_resized, segmentation_B=B_seg_resized) # phi_AB and phi_BA are [1, 3, H, W, D] pytorch tensors representing the forward and backward # maps computed by the model From a32bb5b63fe4388b709c4c05f5fb92341cf68192 Mon Sep 17 00:00:00 2001 From: Basar Demir Date: Mon, 27 Apr 2026 03:53:34 -0400 Subject: [PATCH 2/8] correct model device handling and generalize to_floats --- src/icon_registration/itk_wrapper.py | 8 +++++--- src/icon_registration/losses.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/icon_registration/itk_wrapper.py b/src/icon_registration/itk_wrapper.py index ebad840..8addb4f 100644 --- a/src/icon_registration/itk_wrapper.py +++ b/src/icon_registration/itk_wrapper.py @@ -27,7 +27,7 @@ def finetune_execute(model, image_A, image_B, steps): optimizer.step() with torch.no_grad(): loss = model(image_A, image_B) - #model.load_state_dict(state_dict) + model.load_state_dict(state_dict) return loss @@ -41,7 +41,7 @@ def finetune_execute_mask(model, image_A, image_B, mask_A, mask_B, steps, segmen loss_tuple[0].backward() optimizer.step() with torch.no_grad(): - loss = model(image_A, image_B, mask_A=mask_A, mask_B=mask_B, segmentation_A=segmentation_A, segmentation_B=segmentation_B) + loss = model(image_A, image_B, mask_A=mask_A, mask_B=mask_B, segmentation_A=segmentation_A, segmentation_B=segmentation_B) model.load_state_dict(state_dict) return loss @@ -82,7 +82,6 @@ def register_pair( if finetune_steps == None: with torch.no_grad(): loss = model(A_resized, B_resized) - print(loss) else: loss = finetune_execute(model, A_resized, B_resized, finetune_steps) @@ -109,6 +108,9 @@ def register_pair_with_mask(model, image_A, image_B, mask_A=None, mask_B=None, f assert isinstance(image_A, itk.Image) assert isinstance(image_B, itk.Image) + # send model to cpu or gpu depending on config- auto detects capability + model.to(config.device) + A_npy = np.array(image_A) B_npy = np.array(image_B) diff --git a/src/icon_registration/losses.py b/src/icon_registration/losses.py index 0f3320d..8109cb6 100644 --- a/src/icon_registration/losses.py +++ b/src/icon_registration/losses.py @@ -15,7 +15,7 @@ def to_floats(stats): if isinstance(v, torch.Tensor): v = torch.mean(v).cpu().item() out.append(v) - return ICONLoss(*out) + return type(stats)(*out) ICONLoss = namedtuple( From b0d1afe7c1acbc1a149414c0808dfcf93ecbf7d9 Mon Sep 17 00:00:00 2001 From: Basar Demir Date: Mon, 27 Apr 2026 03:56:39 -0400 Subject: [PATCH 3/8] require segmentation args in finetune_execute_mask --- src/icon_registration/itk_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/icon_registration/itk_wrapper.py b/src/icon_registration/itk_wrapper.py index 8addb4f..624b71c 100644 --- a/src/icon_registration/itk_wrapper.py +++ b/src/icon_registration/itk_wrapper.py @@ -31,7 +31,7 @@ def finetune_execute(model, image_A, image_B, steps): return loss -def finetune_execute_mask(model, image_A, image_B, mask_A, mask_B, steps, segmentation_A=None, segmentation_B=None): +def finetune_execute_mask(model, image_A, image_B, mask_A, mask_B, steps, segmentation_A, segmentation_B): state_dict = copy.deepcopy(model.state_dict()) optimizer = torch.optim.Adam(model.parameters(), lr=0.00002) for _ in range(steps): From 74607017f171c3436528f3b4ad318c6dd9854f93 Mon Sep 17 00:00:00 2001 From: Basar Demir Date: Mon, 27 Apr 2026 14:27:37 -0400 Subject: [PATCH 4/8] reorder masked finetune helper arguments --- src/icon_registration/itk_wrapper.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/icon_registration/itk_wrapper.py b/src/icon_registration/itk_wrapper.py index faa08cd..3aadfea 100644 --- a/src/icon_registration/itk_wrapper.py +++ b/src/icon_registration/itk_wrapper.py @@ -32,7 +32,17 @@ def finetune_execute(model, image_A, image_B, steps, learning_rate): return loss -def finetune_execute_mask(model, image_A, image_B, mask_A, mask_B, steps, learning_rate, segmentation_A, segmentation_B): +def finetune_execute_mask( + model, + image_A, + image_B, + mask_A, + mask_B, + segmentation_A, + segmentation_B, + steps, + learning_rate, +): state_dict = copy.deepcopy(model.state_dict()) optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) for _ in range(steps): @@ -163,10 +173,10 @@ def register_pair_with_mask( B_resized, A_mask_resized, B_mask_resized, + A_seg_resized, + B_seg_resized, finetune_steps, learning_rate, - segmentation_A=A_seg_resized, - segmentation_B=B_seg_resized, ) # phi_AB and phi_BA are [1, 3, H, W, D] pytorch tensors representing the forward and backward From 35525826d24380bbd664e1a82b4ade8853d97ef0 Mon Sep 17 00:00:00 2001 From: Basar Demir Date: Wed, 29 Apr 2026 20:21:53 -0400 Subject: [PATCH 5/8] assert positive learning rate and reorder register_pair_with_mask args --- src/icon_registration/itk_wrapper.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/icon_registration/itk_wrapper.py b/src/icon_registration/itk_wrapper.py index 3aadfea..f788ac7 100644 --- a/src/icon_registration/itk_wrapper.py +++ b/src/icon_registration/itk_wrapper.py @@ -61,6 +61,8 @@ def register_pair( model, image_A, image_B, finetune_steps=None, return_artifacts=False, learning_rate=DEFAULT_FINETUNE_LEARNING_RATE ) -> "(itk.CompositeTransform, itk.CompositeTransform)": + assert learning_rate > 0 + assert isinstance(image_A, itk.Image) assert isinstance(image_B, itk.Image) @@ -120,13 +122,15 @@ def register_pair_with_mask( image_B, mask_A=None, mask_B=None, - finetune_steps=None, - return_artifacts=False, - learning_rate=DEFAULT_FINETUNE_LEARNING_RATE, segmentation_A=None, segmentation_B=None, + finetune_steps=None, + learning_rate=DEFAULT_FINETUNE_LEARNING_RATE, + return_artifacts=False, ): + assert learning_rate > 0 + assert isinstance(image_A, itk.Image) assert isinstance(image_B, itk.Image) @@ -202,6 +206,8 @@ def register_pair_with_multimodalities( model, image_A: list, image_B: list, finetune_steps=None, return_artifacts=False, learning_rate=DEFAULT_FINETUNE_LEARNING_RATE ) -> "(itk.CompositeTransform, itk.CompositeTransform)": + assert learning_rate > 0 + assert len(image_A) == len(image_B), "image_A and image_B should have the same number of modalities." # send model to cpu or gpu depending on config- auto detects capability From 8f051d109136b3b26ff1e0676ffc53eda3352f12 Mon Sep 17 00:00:00 2001 From: Basar Demir Date: Mon, 4 May 2026 05:44:30 -0400 Subject: [PATCH 6/8] 1.1.8 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 3b38953..162b38d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = icon_registration -version = 1.1.7 +version = 1.1.8 author = Hastings Greer author_email = t@hgreer.com description = A package for image registration regularized by inverse consistency From eb752e7b59d9a8638632c7084c16e6b9361ff733 Mon Sep 17 00:00:00 2001 From: Basar Demir Date: Mon, 4 May 2026 06:22:01 -0400 Subject: [PATCH 7/8] print loss after no-IO forward --- src/icon_registration/itk_wrapper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/icon_registration/itk_wrapper.py b/src/icon_registration/itk_wrapper.py index f788ac7..79f9358 100644 --- a/src/icon_registration/itk_wrapper.py +++ b/src/icon_registration/itk_wrapper.py @@ -95,6 +95,7 @@ def register_pair( if finetune_steps == None: with torch.no_grad(): loss = model(A_resized, B_resized) + print(loss) else: loss = finetune_execute(model, A_resized, B_resized, finetune_steps, learning_rate) @@ -170,6 +171,7 @@ def register_pair_with_mask( if finetune_steps == None: with torch.no_grad(): loss = model(A_resized, B_resized, mask_A=A_mask_resized, mask_B=B_mask_resized, segmentation_A=A_seg_resized, segmentation_B=B_seg_resized) + print(loss) else: loss = finetune_execute_mask( model, From 669817d94849f4cb34ef06755a0313eb191dd4d6 Mon Sep 17 00:00:00 2001 From: Basar Demir Date: Mon, 4 May 2026 07:10:30 -0400 Subject: [PATCH 8/8] preserve register_pair_with_mask compatibility for optional segmentations --- src/icon_registration/itk_wrapper.py | 64 ++++++++-------------------- 1 file changed, 17 insertions(+), 47 deletions(-) diff --git a/src/icon_registration/itk_wrapper.py b/src/icon_registration/itk_wrapper.py index 79f9358..5412ccf 100644 --- a/src/icon_registration/itk_wrapper.py +++ b/src/icon_registration/itk_wrapper.py @@ -17,42 +17,17 @@ def _resize_itk_mask(itk_image, shape): return F.interpolate(trch, size=shape[2:], mode="nearest") -def finetune_execute(model, image_A, image_B, steps, learning_rate): +def finetune_execute(model, image_A, image_B, steps, learning_rate, **model_kwargs): state_dict = copy.deepcopy(model.state_dict()) optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) for _ in range(steps): optimizer.zero_grad() - loss_tuple = model(image_A, image_B) + loss_tuple = model(image_A, image_B, **model_kwargs) print(loss_tuple) loss_tuple[0].backward() optimizer.step() with torch.no_grad(): - loss = model(image_A, image_B) - model.load_state_dict(state_dict) - return loss - - -def finetune_execute_mask( - model, - image_A, - image_B, - mask_A, - mask_B, - segmentation_A, - segmentation_B, - steps, - learning_rate, -): - state_dict = copy.deepcopy(model.state_dict()) - optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) - for _ in range(steps): - optimizer.zero_grad() - loss_tuple = model(image_A, image_B, mask_A=mask_A, mask_B=mask_B, segmentation_A=segmentation_A, segmentation_B=segmentation_B) - print(loss_tuple) - loss_tuple[0].backward() - optimizer.step() - with torch.no_grad(): - loss = model(image_A, image_B, mask_A=mask_A, mask_B=mask_B, segmentation_A=segmentation_A, segmentation_B=segmentation_B) + loss = model(image_A, image_B, **model_kwargs) model.load_state_dict(state_dict) return loss @@ -123,11 +98,11 @@ def register_pair_with_mask( image_B, mask_A=None, mask_B=None, - segmentation_A=None, - segmentation_B=None, finetune_steps=None, - learning_rate=DEFAULT_FINETUNE_LEARNING_RATE, return_artifacts=False, + learning_rate=DEFAULT_FINETUNE_LEARNING_RATE, + segmentation_A=None, + segmentation_B=None, ): assert learning_rate > 0 @@ -160,30 +135,25 @@ def register_pair_with_mask( B_trch, size=shape[2:], mode="trilinear", align_corners=False ) - A_mask_resized = _resize_itk_mask(mask_A, shape) if mask_A is not None else None - B_mask_resized = _resize_itk_mask(mask_B, shape) if mask_B is not None else None - A_seg_resized = _resize_itk_mask(segmentation_A, shape) if segmentation_A is not None else None - B_seg_resized = _resize_itk_mask(segmentation_B, shape) if segmentation_B is not None else None + model_kwargs = {} + if mask_A is not None: + model_kwargs["mask_A"] = _resize_itk_mask(mask_A, shape) + if mask_B is not None: + model_kwargs["mask_B"] = _resize_itk_mask(mask_B, shape) + if segmentation_A is not None: + model_kwargs["segmentation_A"] = _resize_itk_mask(segmentation_A, shape) + if segmentation_B is not None: + model_kwargs["segmentation_B"] = _resize_itk_mask(segmentation_B, shape) if finetune_steps == 0: raise Exception("To indicate no finetune_steps, pass finetune_steps=None") if finetune_steps == None: with torch.no_grad(): - loss = model(A_resized, B_resized, mask_A=A_mask_resized, mask_B=B_mask_resized, segmentation_A=A_seg_resized, segmentation_B=B_seg_resized) + loss = model(A_resized, B_resized, **model_kwargs) print(loss) else: - loss = finetune_execute_mask( - model, - A_resized, - B_resized, - A_mask_resized, - B_mask_resized, - A_seg_resized, - B_seg_resized, - finetune_steps, - learning_rate, - ) + loss = finetune_execute(model, A_resized, B_resized, finetune_steps, learning_rate, **model_kwargs) # phi_AB and phi_BA are [1, 3, H, W, D] pytorch tensors representing the forward and backward # maps computed by the model