Skip to content

Commit c6495de

Browse files
committed
*2 multiplier to huber loss cause of 1/2 a^2 conv.
The Taylor expansion of sqrt near zero gives 1/2 a^2, which differs from a^2 of the standard MSE loss. This change scales them better against one another
1 parent a58f290 commit c6495de

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

library/train_util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4665,7 +4665,7 @@ def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str
46654665
if loss_type == 'l2':
46664666
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
46674667
elif loss_type == 'huber' or loss_type == 'huber_scheduled':
4668-
loss = huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
4668+
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
46694669
if reduction == "mean":
46704670
loss = torch.mean(loss)
46714671
elif reduction == "sum":

0 commit comments

Comments
 (0)