Hi, thanks for open-sourcing your excellent work with nicely organized code. I found a small typo in the sCM loss that I think is crucial.
In the code:
second_term=-r * (torch.cos(t) * torch.sin(t) * x_t + sigma_data * F_theta_grad)
differs from line 18 of Algorithm 2 in the Sana-Sprint paper.
The correct implementation should be:
second_term=-r * torch.cos(t) * torch.sin(t) * (x_t + sigma_data * F_theta_grad)
To make it clearer:
coeff = -r * torch.cos(t) * torch.sin(t), and the other factor is (x_t + sigma_data * F_theta_grad).
The formula in the paper is correct and aligns with the original sCM paper.
I have also tested the current version, second_term=-r * (torch.cos(t) * torch.sin(t) * x_t + sigma_data * F_theta_grad), on my own model, and it still produces reasonable results.
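
To make the gap between the two expressions concrete, here is a minimal sketch that evaluates both on plain Python floats. The scalars stand in for the tensors, the function names and input values are made up for illustration, and nothing here is taken from the repository beyond the two expressions quoted above.

```python
import math


def second_term_current(r, t, x_t, sigma_data, F_theta_grad):
    # Expression as currently written: only x_t is scaled by cos(t) * sin(t).
    return -r * (math.cos(t) * math.sin(t) * x_t + sigma_data * F_theta_grad)


def second_term_proposed(r, t, x_t, sigma_data, F_theta_grad):
    # Proposed expression: cos(t) * sin(t) scales the whole sum.
    return -r * math.cos(t) * math.sin(t) * (x_t + sigma_data * F_theta_grad)


# Arbitrary illustrative values (assumptions, not taken from any training run).
r, t, x_t, sigma_data, F_theta_grad = 1.0, 0.7, 0.3, 0.5, 2.0

current = second_term_current(r, t, x_t, sigma_data, F_theta_grad)
proposed = second_term_proposed(r, t, x_t, sigma_data, F_theta_grad)

# Expanding both expressions, the difference is
#   current - proposed = -r * sigma_data * F_theta_grad * (1 - cos(t) * sin(t))
gap = current - proposed
expected_gap = -r * sigma_data * F_theta_grad * (1.0 - math.cos(t) * math.sin(t))
assert math.isclose(gap, expected_gap)

print("current: ", current)
print("proposed:", proposed)
print("gap:     ", gap)
```

Since cos(t) * sin(t) = sin(2t) / 2 never exceeds 1/2, the factor (1 - cos(t) * sin(t)) is never zero, so the two versions agree only when r, sigma_data, or F_theta_grad is zero.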