11 lines
163 B
Python
11 lines
163 B
Python
|
import torch.nn as nn
|
||
|
|
||
|
import modules.functional as F
|
||
|
|
||
|
# Public names exported by a star-import of this module.
__all__ = ['KLLoss']
|
||
|
|
||
|
|
||
|
class KLLoss(nn.Module):
    """Loss module that forwards its two inputs to ``F.kl_loss``.

    The class defines no parameters or attributes of its own; the actual
    computation is implemented by ``kl_loss`` in ``modules.functional``
    (imported in this file as ``F``).
    """

    def forward(self, x, y):
        """Return the result of ``F.kl_loss(x, y)``.

        Both arguments are passed through positionally and unchanged.
        NOTE(review): the expected types/shapes of ``x`` and ``y`` are
        determined by ``F.kl_loss`` and are not visible here -- confirm there.
        """
        loss = F.kl_loss(x, y)
        return loss
|