Hello,
I have a question about cosine_angular_loss. After the permute operation, the channel dimension ends up in dim 3, but in the normalize operation you set dim=1. Is there a mistake here? Could you please explain it?
import torch

def masked_cosine_angular_loss(preds, target, mask_valid):
    # Rescale predictions and targets from [0, 1] to [-1, 1].
    preds = (2 * preds - 1).clamp(-1, 1)
    target = (2 * target - 1).clamp(-1, 1)
    # Reduce the mask to a boolean per-pixel mask ([B, H, W] for a [B, C, H, W] input).
    mask_valid = mask_valid[:, 0, :, :].bool().squeeze(1)
    # Move channels to the last dimension and keep only the valid pixels.
    preds = preds.permute(0, 2, 3, 1)[mask_valid, :]
    target = target.permute(0, 2, 3, 1)[mask_valid, :]
    # L2-normalize along dim=1.
    preds_norm = torch.nn.functional.normalize(preds, p=2, dim=1)
    target_norm = torch.nn.functional.normalize(target, p=2, dim=1)
    # Loss is the negative mean cosine similarity.
    loss = torch.mean(-torch.sum(preds_norm * target_norm, dim=1))
    return loss
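For reference, here is a minimal, self-contained shape check (with made-up tensor sizes) that mirrors the permute and boolean-mask indexing from the snippet above and prints the shape at each step:

import torch

# Hypothetical sizes: batch of 2, 3-channel surface normals, 4x5 images.
B, C, H, W = 2, 3, 4, 5
preds = torch.rand(B, C, H, W)
mask_valid = torch.ones(B, 1, H, W)

mask = mask_valid[:, 0, :, :].bool().squeeze(1)   # boolean mask
print(mask.shape)                                 # torch.Size([2, 4, 5])

permuted = preds.permute(0, 2, 3, 1)              # channels moved to the last dim
print(permuted.shape)                             # torch.Size([2, 4, 5, 3])

selected = permuted[mask, :]                      # boolean indexing over the first three dims
print(selected.shape)                             # torch.Size([40, 3])

normalized = torch.nn.functional.normalize(selected, p=2, dim=1)
print(normalized.shape)                           # torch.Size([40, 3])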
Hi, your concern looks right to me, but the models are quite good, so I think the problem might just be in the released code (and in the evaluation code we borrowed directly from the OASIS paper). @Ainaz99 what do you think?
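For context, surface-normal evaluation code typically converts the cosine similarity into an angular error in degrees. A rough, generic sketch (not necessarily the exact evaluation code inherited from the OASIS paper) might look like this:

import torch

def mean_angular_error_deg(preds_norm, target_norm):
    # preds_norm, target_norm: [N, 3] unit vectors (e.g. after F.normalize along dim=1).
    cos_sim = torch.sum(preds_norm * target_norm, dim=1).clamp(-1.0, 1.0)
    # Per-pixel angular error in degrees, averaged over all valid pixels.
    return torch.rad2deg(torch.acos(cos_sim)).mean()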