PVD/modules/loss.py
Last commit 2f6aa752a6 ("PVD") by Linqi (Alex) Zhou, 2021-10-19 13:54:46 -07:00

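import torch.nn as nn
import modules.functional as F  # project-local ops package, not torch.nn.functional

__all__ = ['KLLoss']


class KLLoss(nn.Module):
    """nn.Module wrapper that delegates to kl_loss in modules.functional."""

    def forward(self, x, y):
        return F.kl_loss(x, y)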
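The wrapper does not show what kl_loss in modules.functional actually computes. The sketch below assumes a common definition for logit-matching losses: the softmax of the first argument is treated as a fixed target distribution (in PyTorch it would typically be detached, so gradients flow only through the second argument), the second argument goes through log-softmax, and the per-sample KL divergence KL(softmax(x) || softmax(y)) is averaged over the batch, with classes along dimension 1. It is a dependency-free reference for checking that assumed behavior, not the repository's code. The names kl_loss_reference, softmax and log_softmax are hypothetical helpers, and the sketch uses log_softmax(x) rather than log(softmax(x)) for numerical stability.

```python
import math
from typing import List, Sequence


def softmax(row: Sequence[float]) -> List[float]:
    # Subtract the max logit before exponentiating to avoid overflow.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(row: Sequence[float]) -> List[float]:
    # log(softmax(v)) = v - logsumexp(row), computed stably via the max logit.
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]


def kl_loss_reference(
    x: Sequence[Sequence[float]], y: Sequence[Sequence[float]]
) -> float:
    """Batch-mean KL(softmax(x) || softmax(y)) over rows of logits.

    Hypothetical reference: assumes x and y are [batch, classes] logits,
    with the class dimension second (dim=1 in PyTorch terms).
    """
    if len(x) != len(y):
        raise ValueError("x and y must have the same batch size")
    per_row = []
    for xr, yr in zip(x, y):
        if len(xr) != len(yr):
            raise ValueError("x and y must have the same number of classes")
        p = softmax(xr)          # target distribution (held constant)
        log_p = log_softmax(xr)  # log of the target distribution
        log_q = log_softmax(yr)  # predicted log-probabilities
        # KL(p || q) = sum_i p_i * (log p_i - log q_i)
        per_row.append(sum(pi * (lp - lq) for pi, lp, lq in zip(p, log_p, log_q)))
    # Average the per-sample divergences over the batch.
    return sum(per_row) / len(per_row)
```

With identical logits the loss is zero: kl_loss_reference([[1.0, 2.0, 3.0]], [[1.0, 2.0, 3.0]]) returns 0.0. Because the divergence is asymmetric, swapping x and y generally changes the value, which is why the argument order passed to KLLoss matters.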