PointMLP/classification_ModelNet40/helper.py

22 lines
606 B
Python
Raw Normal View History

2021-10-04 07:22:15 +00:00
import torch
import torch.nn.functional as F
2023-08-03 14:40:14 +00:00
def cal_loss(pred, gold, smoothing=True, eps=0.2):
    """Calculate cross-entropy loss, applying label smoothing if requested.

    Args:
        pred: Unnormalized class scores (logits) of shape (N, C); classes
            are along dim 1.
        gold: Integer class labels, one per sample; flattened to shape (N,).
        smoothing: If True, score ``pred`` against smoothed targets in which
            the true class gets probability ``1 - eps`` and each of the other
            ``C - 1`` classes gets ``eps / (C - 1)``. If False, use plain
            cross-entropy.
        eps: Probability mass moved off the true class when smoothing
            (default 0.2, the value this function always used before).

    Returns:
        Scalar tensor: the loss averaged over the N samples.
    """
    gold = gold.contiguous().view(-1)
    # With a single class there is no "other" class to receive the smoothing
    # mass: eps / (n_class - 1) would be 0/0 and the loss would become NaN.
    # Fall back to plain cross-entropy in that degenerate case.
    # (pred.size(1) is only evaluated when smoothing is requested.)
    if smoothing and pred.size(1) > 1:
        n_class = pred.size(1)
        # Build a one-hot target, then redistribute eps uniformly over the
        # non-target classes.
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)
        loss = -(one_hot * log_prb).sum(dim=1).mean()
    else:
        loss = F.cross_entropy(pred, gold, reduction="mean")
    return loss