if self.dice_loss_weight > 0.0:
    assert mask_A is not None and mask_B is not None, "mask_A and mask_B must be provided when dice_loss_weight>0"
    unique_A = torch.unique(mask_A.long())
    unique_B = torch.unique(mask_B.long())
    common_labels = unique_A[torch.isin(unique_A, unique_B)]
    common_labels = common_labels[common_labels != 0] # exclude background
    num_classes = len(common_labels)
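The label-counting steps in the quoted lines can be reproduced on their own to see what they yield for different inputs. This is a standalone illustration, not code from the repository: the helper name `common_foreground_labels` and the toy tensors are hypothetical. With binary masks the result can contain at most one label, so `num_classes` is at most 1, while multi-class label maps give one entry per structure present in both images.

```python
import torch


def common_foreground_labels(seg_A: torch.Tensor, seg_B: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the quoted steps:
    # unique labels per image, keep those present in both, drop background (0).
    unique_A = torch.unique(seg_A.long())
    unique_B = torch.unique(seg_B.long())
    common = unique_A[torch.isin(unique_A, unique_B)]
    return common[common != 0]


# Binary masks: 0 = background, 1 = foreground.
binary_A = torch.tensor([[0, 1], [1, 1]])
binary_B = torch.tensor([[1, 1], [0, 0]])

# Multi-class label maps with labels 1, 2, 3.
labels_A = torch.tensor([[0, 1], [2, 3]])
labels_B = torch.tensor([[3, 2], [1, 0]])

print(len(common_foreground_labels(binary_A, binary_B)))  # 1
print(len(common_foreground_labels(labels_A, labels_B)))  # 3
```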
The quoted snippet is uniGradICON/src/unigradicon/__init__.py, lines 43 to 49 at commit a96ec53.

I am unsure what form is expected for mask_a / mask_b and label_a / label_b. I assume that label_a / label_b are label maps with multiple classes, used to compute the Dice scores, and that mask_a / mask_b are binary images used to mask the image loss computation. However, as the quoted snippet shows, num_classes is computed from the nonzero labels shared by the masks, and mask_a / mask_b must be provided whenever dice_loss_weight > 0. It seems that label_a / label_b should instead be required when dice_loss_weight > 0, and that num_classes should be computed from label_a / label_b.
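Under the interpretation proposed here, the Dice term would be computed on multi-class label maps, and the binary masks would only restrict the image loss. A minimal sketch of such a Dice term, assuming two integer label maps of the same shape (for example, after warping one of them with nearest-neighbour interpolation). The function name `multiclass_dice_loss` and its `eps` parameter are hypothetical and are not uniGradICON's API; the repository's actual Dice implementation may differ.

```python
import torch


def multiclass_dice_loss(label_A: torch.Tensor, label_B: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical sketch: 1 - mean Dice over the foreground labels shared by
    two integer label maps of the same shape."""
    label_A = label_A.long()
    label_B = label_B.long()

    # Same label selection as the quoted code, but applied to label maps.
    unique_A = torch.unique(label_A)
    unique_B = torch.unique(label_B)
    common = unique_A[torch.isin(unique_A, unique_B)]
    common = common[common != 0]  # exclude background

    if len(common) == 0:
        # No shared foreground structure, so there is nothing to compare.
        return torch.zeros((), dtype=torch.float32)

    dices = []
    for lab in common:
        a = (label_A == lab).float()
        b = (label_B == lab).float()
        intersection = (a * b).sum()
        dices.append(2 * intersection / (a.sum() + b.sum() + eps))

    # num_classes in this setting would be len(common).
    return 1 - torch.stack(dices).mean()


label_A = torch.tensor([[1, 1], [2, 2]])
label_B = torch.tensor([[1, 0], [2, 0]])
print(float(multiclass_dice_loss(label_A, label_B)))  # approximately 1/3
```

Here each label overlaps in one of its two voxels in label_A, giving a per-label Dice of 2/3 and a loss of about 1/3.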