# PVD/modules/loss.py

import torch.nn as nn
import modules.functional as F
__all__ = ["KLLoss"]
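

class KLLoss(nn.Module):
    """nn.Module wrapper around modules.functional.kl_loss.

    Lets the KL loss be used like any other loss module, e.g.
    ``criterion = KLLoss(); loss = criterion(x, y)``.
    """

    def forward(self, x, y):
        # Delegate the computation to the functional implementation.
        return F.kl_loss(x, y)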
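

# Illustrative sketch only (an assumption, not this repository's code): the real
# F.kl_loss lives in modules/functional, which is not shown here. One common
# formulation treats x as target logits (detached, so no gradient flows into them)
# and y as predicted logits, and returns KL(softmax(x) || softmax(y)) summed over
# the class dimension (dim=1) and averaged over the remaining elements.
# The name _kl_loss_sketch is hypothetical and is deliberately not listed in
# __all__. torch is imported inside the function so the module-level alias F
# (modules.functional) is not shadowed.
def _kl_loss_sketch(x, y):
    import torch
    import torch.nn.functional as nnf

    # Log-probabilities of the target, via log_softmax for numerical stability.
    log_p = nnf.log_softmax(x.detach(), dim=1)
    # Log-probabilities of the prediction.
    log_q = nnf.log_softmax(y, dim=1)
    # Pointwise p * (log p - log q), summed over classes, then averaged.
    return torch.mean(torch.sum(log_p.exp() * (log_p - log_q), dim=1))


# Under these assumptions, for 2-D (N, C) inputs the result matches
# torch.nn.functional.kl_div(log_q, p, reduction="batchmean"), and it is 0
# when x and y are identical.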