The hybrid loss from eq. 5 is defined (and implemented accordingly in the official implementation) as:
`losses = vb_losses + self.hybrid_coeff * ce_losses`
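For reference, eq. 5 of the D3PM paper (Austin et al., 2021) reads, as far as we can tell, as follows. The weight $\lambda$ multiplies only the auxiliary cross-entropy term, and $L_{\mathrm{vb}}$ enters with weight 1:

$$
L_\lambda = L_{\mathrm{vb}} + \lambda\, \mathbb{E}_{q(x_0)}\, \mathbb{E}_{q(x_t \mid x_0)}\big[-\log \tilde{p}_\theta(x_0 \mid x_t)\big]
$$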
Here, however, it is implemented the other way around, with the coefficient applied to the VB term instead of the cross-entropy term:

`return self.hybrid_loss_coeff * vb_loss + ce_loss, {`
(Permalink: `d3pm/d3pm_runner.py`, line 294 in commit 3ceb637.)
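To make the difference concrete, here is a minimal sketch contrasting the two orderings on scalar losses. The function names `hybrid_loss_eq5` and `hybrid_loss_swapped` are hypothetical, and the input values are made up for illustration. They are not taken from either codebase.

```python
def hybrid_loss_eq5(vb_loss: float, ce_loss: float, coeff: float) -> float:
    """Ordering from eq. 5 / the official code: coeff scales the CE term."""
    return vb_loss + coeff * ce_loss


def hybrid_loss_swapped(vb_loss: float, ce_loss: float, coeff: float) -> float:
    """Ordering in d3pm_runner.py: coeff scales the VB term instead."""
    return coeff * vb_loss + ce_loss


if __name__ == "__main__":
    # Hypothetical values, chosen so the floating-point results are exact.
    vb, ce, lam = 2.0, 4.0, 0.25
    print(hybrid_loss_eq5(vb, ce, lam))      # 3.0
    print(hybrid_loss_swapped(vb, ce, lam))  # 4.5
    # The swapped form equals lam * (eq. 5 with coefficient 1 / lam).
    print(lam * hybrid_loss_eq5(vb, ce, 1.0 / lam))  # 4.5
```

The last line shows why this matters. Since `coeff * vb + ce == coeff * (vb + ce / coeff)`, the swapped version behaves like eq. 5 with coefficient `1 / coeff`, rescaled by an overall factor `coeff` (which also rescales the gradients). For a small coefficient this weights the cross-entropy term far more heavily than eq. 5 intends, rather than far less. Moving the factor onto `ce_loss` (i.e. `vb_loss + self.hybrid_loss_coeff * ce_loss`) would presumably restore the ordering of eq. 5.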