import torch.nn as nn
import modules.functional as F
__all__ = ["KLLoss"]
class KLLoss(nn.Module):
    """``nn.Module`` wrapper that delegates to ``F.kl_loss``.

    The module holds no parameters or state of its own. Calling it simply
    forwards both inputs, in order, to ``modules.functional.kl_loss``.

    NOTE(review): the name suggests a Kullback-Leibler divergence loss. The
    exact semantics are defined in ``modules.functional`` and are not visible
    here: which argument is the target, whether inputs are log-probabilities,
    and what reduction is applied. Confirm them there.
    """

    def forward(self, x, y):
        """Return ``F.kl_loss(x, y)`` unchanged.

        Args:
            x: first operand, passed to ``F.kl_loss`` as its first positional
                argument.
            y: second operand, passed to ``F.kl_loss`` as its second
                positional argument.

        Returns:
            Whatever ``F.kl_loss`` returns for these inputs.
        """
        divergence = F.kl_loss(x, y)
        return divergence