11 lines
163 B
Python
11 lines
163 B
Python
import torch.nn as nn
|
|
|
|
import modules.functional as F
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "KLLoss",
]


class KLLoss(nn.Module):
    """``nn.Module`` wrapper around ``modules.functional.kl_loss``.

    The module defines no ``__init__`` and holds no parameters or buffers of
    its own; ``forward`` hands both inputs, in the order received, to
    ``F.kl_loss`` and returns its result unchanged.
    """

    def forward(self, x, y):
        """Compute ``F.kl_loss(x, y)``.

        Args:
            x: First input, passed through unchanged as the first argument.
            y: Second input, passed through unchanged as the second argument.

        Returns:
            Whatever ``F.kl_loss`` returns for ``(x, y)``.
        """
        # NOTE(review): required shapes / normalisation (e.g. whether ``x``
        # must already be log-probabilities) are decided by F.kl_loss, which
        # is not visible here -- confirm there before relying on a form.
        kl_fn = F.kl_loss
        loss = kl_fn(x, y)
        return loss
|