Skip to content

Conversation

@Hendrik-code
Copy link
Collaborator

No description provided.

…to change it, and also using argmax instead of sum and inverting at the correct order to avoid multiplying the effect of the transformation out of proportions
@Hendrik-code Hendrik-code self-assigned this Jan 14, 2026
@Hendrik-code Hendrik-code added bug Something isn't working enhancement New feature or request labels Jan 14, 2026
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes issues with the in/out segmentation region selection logic and adds the mix_prob parameter to transform classes. The fix addresses the incorrect handling of multi-channel segmentation masks and reorders operations in transform application.

Changes:

  • Fixed _apply_region_mode function to correctly handle multi-channel segmentation masks by using argmax before applying mode inversion
  • Added mix_prob parameter to RandomInverseGPU and RandomHistogramEqualizationGPU classes
  • Reordered operations in RandomConvTransformGPU to apply mix_prob mixing before region selection
  • Updated configuration values for various transform parameters
  • Code formatting improvements (trailing whitespace removal, function signature formatting)

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

File Description
auglab/transforms/gpu/contrast.py Fixed segmentation mask handling logic, added mix_prob parameter support, reordered transform operations
auglab/transforms/gpu/transforms.py Added mix_prob parameter to transform instantiations, formatting improvements
auglab/configs/transform_params_gpu.json Updated transform configuration parameters including probabilities and in_seg/out_seg values

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 16 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +310 to +311
if not isinstance(num_transforms, int) or num_transforms < 0:
raise ValueError(f"num_transforms must be a non-negative int. Got {num_transforms!r}.")
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The validation raises a ValueError for negative num_transforms but doesn't handle the case where num_transforms exceeds the length of transforms_list. While the code handles this gracefully at runtime with k = min(self.num_transforms, len(self.transforms_list)), it might be better to warn users at initialization if they provide a num_transforms that exceeds the available transforms.

Copilot uses AI. Check for mistakes.

for j in idx.tolist():
t = self.transforms_list[j]
if torch.rand(1, device=x.device, dtype=x.dtype) > t.p:
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check torch.rand(1, device=x.device, dtype=x.dtype) > t.p uses the image dtype for the random number, which could be problematic if x.dtype is an integer type. This should use torch.float32 or similar floating-point dtype to ensure proper probability comparison.

Suggested change
if torch.rand(1, device=x.device, dtype=x.dtype) > t.p:
if torch.rand(1, device=x.device, dtype=torch.float32) > t.p:

Copilot uses AI. Check for mistakes.
Comment on lines +70 to +71
m = torch.argmax(m, axis=1) > 0
m = m.to(transformed.dtype)
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic change from torch.sum(m, axis=1) to torch.argmax(m, axis=1) > 0 fundamentally alters the behavior. Previously, the sum would allow multiple overlapping segmentation channels to contribute. Now, argmax selects only the channel with the maximum value, and > 0 converts it to a binary mask. This change is significant and could affect how multiple overlapping segmentation regions are handled. Ensure this is the intended behavior and that it has been tested with multi-channel segmentation masks.

Copilot uses AI. Check for mistakes.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine, correct? @NathanMolinier

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically we should not have overlapping segmentations so both should work I think

o = torch.randint(0, 2, (seg_mask.shape[0],), device=seg_mask.device, dtype=seg_mask.dtype)
m = m * o.view(-1, 1, 1, 1) # Broadcasting o to match the dimensions of m
m = torch.sum(m, axis=0)
m = torch.argmax(m, axis=0) > 0
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same logic change from torch.sum(m, axis=0) to torch.argmax(m, axis=0) > 0 appears here for 3D images. This fundamentally changes how overlapping segmentation channels are handled, converting from a sum of all channels to selecting only the dominant channel. Ensure this behavior change is intentional and properly tested.

Suggested change
m = torch.argmax(m, axis=0) > 0
m = torch.sum(m, axis=0) > 0

Copilot uses AI. Check for mistakes.
Comment on lines +287 to +296
"""Randomly choose X transforms to apply from a given list of ImageOnlyTransform transforms (GPU version).
Args:
transforms_list: List of initialized ImageOnlyTransform to choose from.
num_transforms: Number of transforms to randomly select and apply.
same_on_batch: apply the same transformation across the batch.
p: probability for applying the X transforms to a batch. This param controls the augmentation
probabilities batch-wise.
keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch
form ``False``.
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new class RandomChooseXTransformsGPU is missing comprehensive documentation. The docstring should explain the behavior when same_on_batch=False (applies different random selections per batch item), how the transforms are selected (without replacement using randperm), and what happens when a selected transform's probability check fails (it's skipped). Additionally, document the kwargs parameter or remove it if unused.

Suggested change
"""Randomly choose X transforms to apply from a given list of ImageOnlyTransform transforms (GPU version).
Args:
transforms_list: List of initialized ImageOnlyTransform to choose from.
num_transforms: Number of transforms to randomly select and apply.
same_on_batch: apply the same transformation across the batch.
p: probability for applying the X transforms to a batch. This param controls the augmentation
probabilities batch-wise.
keepdim: whether to keep the output shape the same as input ``True`` or broadcast it to the batch
form ``False``.
"""Randomly choose and apply a subset of transforms from a given list (GPU version).
This transform samples up to ``num_transforms`` transforms from ``transforms_list`` **without
replacement** for each application. The sampling is implemented by taking a random
permutation of the available transforms via :func:`torch.randperm` and selecting the
first ``num_transforms`` indices.
For each selected transform ``t``, a separate probability check is performed against
``t.p``. If the random draw for ``t`` is greater than ``t.p``, that transform is
**skipped** and not applied, and no additional transform is selected to replace it.
This means that in practice fewer than ``num_transforms`` transforms may be applied
to a given sample.
When ``same_on_batch=True``, the same subset of transforms (with the same random
outcomes of their internal sampling logic) is used for the entire batch: one
selection is made and applied to all batch elements.
When ``same_on_batch=False``, each element in the batch is processed independently:
for every item, a new random subset of transforms is sampled (still without
replacement), and each selected transform performs its own probability check and
random parameter sampling.
Args:
transforms_list: List of initialized :class:`ImageOnlyTransform` instances to
choose from.
num_transforms: Maximum number of transforms to randomly select (without
replacement) and attempt to apply to each sample.
same_on_batch: If ``True``, apply the same randomly selected subset of transforms
to every element in the batch. If ``False``, sample a separate subset for
each batch element.
p: Probability of applying this *composite* transform to a batch. This controls
whether the selection-and-application process runs at all for a given call,
independently from the per-transform probabilities ``t.p``.
keepdim: Whether to keep the output shape the same as the input (``True``) or
broadcast it to the batch form (``False``).
**kwargs: Additional keyword arguments accepted for API compatibility. They are
currently ignored and do not affect the behavior of this transform.

Copilot uses AI. Check for mistakes.
Comment on lines +354 to +357
if seg is not None and isinstance(seg, torch.Tensor) and seg.shape[0] == batch_size:
seg_i = seg[i : i + 1]
else:
seg_i = seg
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the per-batch iteration (lines 351-359), when seg is not None and its batch size matches input.shape[0], individual slices are extracted correctly. However, when seg exists but has a different batch size (line 357), the code falls back to using the entire seg for each batch item. This could lead to shape mismatches or incorrect segmentation application. Consider adding validation to ensure seg has the correct shape or raising an error if the shapes don't match.

Copilot uses AI. Check for mistakes.
Comment on lines 63 to 68
if mix_in_out:
for i in range(seg_mask.shape[0]):
# Create a tensor with random one and zero

o = torch.randint(0, 2, (seg_mask.shape[1],), device=seg_mask.device, dtype=seg_mask.dtype)
m[i] = m[i] * o.view(-1, 1, 1, 1) # Broadcasting o to match the dimensions of m
Copy link

Copilot AI Jan 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When mix_in_out=True, the code randomly zeros out some segmentation channels before applying argmax. However, if all channels get zeroed out for a sample (which has probability 2^(-num_channels)), then argmax(m, axis=1) > 0 will be False everywhere, effectively making the entire mask zero. This edge case should be handled, or the probability should be adjusted to ensure at least one channel remains active.

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants