Inoutseg fix #25
Conversation
…to change it, also using argmax instead of sum, and inverting in the correct order to avoid multiplying the effect of the transformation out of proportion
Pull request overview
This PR fixes issues with the in/out segmentation region selection logic and adds the `mix_prob` parameter to transform classes. The fix addresses the incorrect handling of multi-channel segmentation masks and reorders operations in transform application.
Changes:
- Fixed the `_apply_region_mode` function to correctly handle multi-channel segmentation masks by using `argmax` before applying mode inversion (see the sketch after this list)
- Added the `mix_prob` parameter to the `RandomInverseGPU` and `RandomHistogramEqualizationGPU` classes
- Reordered operations in `RandomConvTransformGPU` to apply `mix_prob` mixing before region selection
- Updated configuration values for various transform parameters
- Code formatting improvements (trailing whitespace removal, function signature formatting)
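
As a reading aid, here is a minimal sketch of the collapse-then-invert ordering the first bullet describes. The function name comes from the PR summary, but the signature, the mode names, and the background-channel assumption are guesses; the real code in `auglab/transforms/gpu/contrast.py` may differ:

```python
import torch

def _apply_region_mode_sketch(transformed, original, seg_mask, mode="in_seg"):
    # Collapse the multi-channel mask FIRST: argmax picks the dominant channel
    # per voxel, and `> 0` binarizes it (channel 0 assumed to be background).
    m = (torch.argmax(seg_mask, dim=1, keepdim=True) > 0).to(transformed.dtype)
    if mode == "out_seg":
        # Invert AFTER collapsing, so the inversion is applied exactly once
        # instead of being compounded per channel.
        m = 1.0 - m
    return m * transformed + (1.0 - m) * original
```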
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| auglab/transforms/gpu/contrast.py | Fixed segmentation mask handling logic, added mix_prob parameter support, reordered transform operations |
| auglab/transforms/gpu/transforms.py | Added mix_prob parameter to transform instantiations, formatting improvements |
| auglab/configs/transform_params_gpu.json | Updated transform configuration parameters including probabilities and in_seg/out_seg values |
Pull request overview
Copilot reviewed 4 out of 4 changed files in this pull request and generated 16 comments.
```python
if not isinstance(num_transforms, int) or num_transforms < 0:
    raise ValueError(f"num_transforms must be a non-negative int. Got {num_transforms!r}.")
```
Copilot AI commented on Jan 22, 2026
The validation raises a ValueError for negative num_transforms but doesn't handle the case where num_transforms exceeds the length of transforms_list. While the code handles this gracefully at runtime with k = min(self.num_transforms, len(self.transforms_list)), it might be better to warn users at initialization if they provide a num_transforms that exceeds the available transforms.
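
A minimal sketch of the init-time warning Copilot suggests (placement inside `__init__` and the variable names are assumptions, not the PR's code):

```python
import warnings

if num_transforms > len(transforms_list):
    warnings.warn(
        f"num_transforms={num_transforms} exceeds the {len(transforms_list)} "
        "available transforms; the selection will be capped at runtime.",
        UserWarning,
    )
```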
```python
for j in idx.tolist():
    t = self.transforms_list[j]
    if torch.rand(1, device=x.device, dtype=x.dtype) > t.p:
```
Copilot AI commented on Jan 22, 2026
The check `torch.rand(1, device=x.device, dtype=x.dtype) > t.p` uses the image dtype for the random number, which could be problematic if `x.dtype` is an integer type. This should use `torch.float32` or a similar floating-point dtype to ensure a proper probability comparison.
Suggested change:
```diff
-if torch.rand(1, device=x.device, dtype=x.dtype) > t.p:
+if torch.rand(1, device=x.device, dtype=torch.float32) > t.p:
```
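
A quick way to see the failure mode, assuming a current PyTorch build: `torch.rand` only supports floating-point dtypes, so inheriting an integer image dtype makes the gate raise rather than merely lose precision:

```python
import torch

try:
    torch.rand(1, dtype=torch.int64)  # integer dtypes are rejected by torch.rand
except RuntimeError as err:
    print(f"RuntimeError: {err}")

# Pinning the draw to float32 keeps the gate independent of the image dtype:
t_p = 0.5  # stand-in for t.p
apply_transform = torch.rand(1, dtype=torch.float32).item() <= t_p
```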
```python
m = torch.argmax(m, axis=1) > 0
m = m.to(transformed.dtype)
```
Copilot AI commented on Jan 22, 2026
The logic change from `torch.sum(m, axis=1)` to `torch.argmax(m, axis=1) > 0` fundamentally alters the behavior. Previously, the sum allowed multiple overlapping segmentation channels to contribute. Now, `argmax` selects only the channel with the maximum value, and `> 0` converts the result to a binary mask. This change is significant and could affect how multiple overlapping segmentation regions are handled. Ensure this is the intended behavior and that it has been tested with multi-channel segmentation masks.
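
A toy comparison of the two reductions, assuming the `(B, C, H, W)` layout used above. Pixel 0 is covered only by channel 0; the channels overlap on pixel 1:

```python
import torch

m = torch.tensor([[[[1.0, 1.0]],    # channel 0: covers both pixels
                   [[0.0, 2.0]]]])  # channel 1: dominates pixel 1; shape (1, 2, 1, 2)

union = torch.sum(m, dim=1) > 0        # old behavior: tensor([[[True, True]]])
dominant = torch.argmax(m, dim=1) > 0  # new behavior: tensor([[[False, True]]])

# `argmax(...) > 0` implicitly treats channel 0 as background, so the region
# covered only by channel 0 (pixel 0) drops out of the mask.
```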
This is fine, correct? @NathanMolinier
Technically we should not have overlapping segmentations, so both should work, I think.
```diff
 o = torch.randint(0, 2, (seg_mask.shape[0],), device=seg_mask.device, dtype=seg_mask.dtype)
 m = m * o.view(-1, 1, 1, 1)  # Broadcasting o to match the dimensions of m
-m = torch.sum(m, axis=0)
+m = torch.argmax(m, axis=0) > 0
```
Copilot AI commented on Jan 22, 2026
The same logic change from `torch.sum(m, axis=0)` to `torch.argmax(m, axis=0) > 0` appears here for 3D images. This fundamentally changes how overlapping segmentation channels are handled, converting from a sum of all channels to selecting only the dominant channel. Ensure this behavior change is intentional and properly tested.
Suggested change:
```diff
-m = torch.argmax(m, axis=0) > 0
+m = torch.sum(m, axis=0) > 0
```
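
The suggested `sum(...) > 0` keeps union semantics while still binarizing. Note that the two reductions also differ even without overlap, since `argmax` treats channel 0 as background; a small `(C, D, H, W)` example:

```python
import torch

m = torch.zeros(2, 1, 1, 3)  # (C, D, H, W): two structure channels
m[0, ..., 0] = 1.0           # structure A at voxel 0
m[1, ..., 2] = 1.0           # structure B at voxel 2

union = torch.sum(m, dim=0) > 0        # tensor([[[ True, False,  True]]])
dominant = torch.argmax(m, dim=0) > 0  # tensor([[[False, False,  True]]])
# argmax drops structure A entirely, because its channel index is 0.
```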
| """Randomly choose X transforms to apply from a given list of ImageOnlyTransform transforms (GPU version). | ||
| Args: | ||
| transforms_list: List of initialized ImageOnlyTransform to choose from. | ||
| num_transforms: Number of transforms to randomly select and apply. | ||
| same_on_batch: apply the same transformation across the batch. | ||
| p: probability for applying the X transforms to a batch. This param controls the augmentation | ||
| probabilities batch-wise. | ||
| keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch | ||
| form ``False``. |
Copilot AI commented on Jan 22, 2026
The new class `RandomChooseXTransformsGPU` is missing comprehensive documentation. The docstring should explain the behavior when `same_on_batch=False` (a different random selection is applied per batch item), how the transforms are selected (without replacement, using `randperm`), and what happens when a selected transform's probability check fails (it is skipped). Additionally, document the `kwargs` parameter or remove it if unused.
| """Randomly choose X transforms to apply from a given list of ImageOnlyTransform transforms (GPU version). | |
| Args: | |
| transforms_list: List of initialized ImageOnlyTransform to choose from. | |
| num_transforms: Number of transforms to randomly select and apply. | |
| same_on_batch: apply the same transformation across the batch. | |
| p: probability for applying the X transforms to a batch. This param controls the augmentation | |
| probabilities batch-wise. | |
| keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch | |
| form ``False``. | |
| """Randomly choose and apply a subset of transforms from a given list (GPU version). | |
| This transform samples up to ``num_transforms`` transforms from ``transforms_list`` **without | |
| replacement** for each application. The sampling is implemented by taking a random | |
| permutation of the available transforms via :func:`torch.randperm` and selecting the | |
| first ``num_transforms`` indices. | |
| For each selected transform ``t``, a separate probability check is performed against | |
| ``t.p``. If the random draw for ``t`` is greater than ``t.p``, that transform is | |
| **skipped** and not applied, and no additional transform is selected to replace it. | |
| This means that in practice fewer than ``num_transforms`` transforms may be applied | |
| to a given sample. | |
| When ``same_on_batch=True``, the same subset of transforms (with the same random | |
| outcomes of their internal sampling logic) is used for the entire batch: one | |
| selection is made and applied to all batch elements. | |
| When ``same_on_batch=False``, each element in the batch is processed independently: | |
| for every item, a new random subset of transforms is sampled (still without | |
| replacement), and each selected transform performs its own probability check and | |
| random parameter sampling. | |
| Args: | |
| transforms_list: List of initialized :class:`ImageOnlyTransform` instances to | |
| choose from. | |
| num_transforms: Maximum number of transforms to randomly select (without | |
| replacement) and attempt to apply to each sample. | |
| same_on_batch: If ``True``, apply the same randomly selected subset of transforms | |
| to every element in the batch. If ``False``, sample a separate subset for | |
| each batch element. | |
| p: Probability of applying this *composite* transform to a batch. This controls | |
| whether the selection-and-application process runs at all for a given call, | |
| independently from the per-transform probabilities ``t.p``. | |
| keepdim: Whether to keep the output shape the same as the input (``True``) or | |
| broadcast it to the batch form (``False``). | |
| **kwargs: Additional keyword arguments accepted for API compatibility. They are | |
| currently ignored and do not affect the behavior of this transform. |
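
For context, a hypothetical usage of the new class, assuming the constructor matches the docstring above and that the GPU transforms named elsewhere in this PR accept `p` and `mix_prob` (the import paths are guesses based on the changed-files list, not confirmed API):

```python
import torch
# Assumed import paths; adjust to the actual auglab layout.
from auglab.transforms.gpu.transforms import RandomChooseXTransformsGPU
from auglab.transforms.gpu.contrast import RandomInverseGPU, RandomHistogramEqualizationGPU

chooser = RandomChooseXTransformsGPU(
    transforms_list=[
        RandomInverseGPU(p=0.5, mix_prob=0.3),
        RandomHistogramEqualizationGPU(p=0.5, mix_prob=0.3),
    ],
    num_transforms=1,     # pick at most one transform per application
    same_on_batch=False,  # resample the subset for every batch element
    p=0.8,                # batch-wise probability of running the chooser at all
)

out = chooser(torch.rand(4, 1, 64, 64))  # (B, C, H, W) image batch
```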
```python
if seg is not None and isinstance(seg, torch.Tensor) and seg.shape[0] == batch_size:
    seg_i = seg[i : i + 1]
else:
    seg_i = seg
```
Copilot AI commented on Jan 22, 2026
In the per-batch iteration (lines 351-359), when `seg` is not `None` and its batch size matches `input.shape[0]`, individual slices are extracted correctly. However, when `seg` exists but has a different batch size (line 357), the code falls back to using the entire `seg` for each batch item. This could lead to shape mismatches or incorrect segmentation application. Consider adding validation to ensure `seg` has the correct shape, or raising an error if the shapes don't match.
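
One way to implement the validation Copilot suggests; a drop-in sketch that reuses the variable names from the snippet above (hypothetical, not the PR's code):

```python
if seg is not None and isinstance(seg, torch.Tensor):
    if seg.shape[0] != batch_size:
        raise ValueError(
            f"seg batch size {seg.shape[0]} does not match input batch size "
            f"{batch_size}; cannot slice a per-sample segmentation."
        )
    seg_i = seg[i : i + 1]
else:
    seg_i = seg
```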
```python
if mix_in_out:
    for i in range(seg_mask.shape[0]):
        # Create a tensor with random one and zero
        o = torch.randint(0, 2, (seg_mask.shape[1],), device=seg_mask.device, dtype=seg_mask.dtype)
        m[i] = m[i] * o.view(-1, 1, 1, 1)  # Broadcasting o to match the dimensions of m
```
Copilot AI commented on Jan 22, 2026
When `mix_in_out=True`, the code randomly zeros out some segmentation channels before applying `argmax`. However, if all channels get zeroed out for a sample (which happens with probability `2^-num_channels`), then `argmax(m, axis=1) > 0` will be `False` everywhere, effectively making the entire mask zero. This edge case should be handled, or the sampling should be adjusted so that at least one channel remains active.
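
A sketch of one possible guard for this edge case, reusing the names from the snippet above (hypothetical, not the PR's fix): if the random 0/1 vector comes up all zeros, re-enable one channel at random so the mask cannot collapse entirely:

```python
num_channels = seg_mask.shape[1]
o = torch.randint(0, 2, (num_channels,), device=seg_mask.device, dtype=seg_mask.dtype)
if not bool(o.any()):
    # All channels were zeroed (probability 2**-num_channels): keep one active.
    o[torch.randint(0, num_channels, (1,), device=o.device)] = 1
m[i] = m[i] * o.view(-1, 1, 1, 1)
```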