Debug: update encoder
This commit is contained in:
parent
b439ea71e7
commit
4dda09e2d0
|
@ -52,10 +52,10 @@ class Encoder(nn.Module):
|
||||||
x = torch.max(x, 2, keepdim=True)[0]
|
x = torch.max(x, 2, keepdim=True)[0]
|
||||||
x = x.view(-1, 512)
|
x = x.view(-1, 512)
|
||||||
|
|
||||||
|
if self.use_deterministic_encoder:
|
||||||
ms = F.relu(self.fc_bn1(self.fc1(x)))
|
ms = F.relu(self.fc_bn1(self.fc1(x)))
|
||||||
ms = F.relu(self.fc_bn2(self.fc2(ms)))
|
ms = F.relu(self.fc_bn2(self.fc2(ms)))
|
||||||
ms = self.fc3(ms)
|
ms = self.fc3(ms)
|
||||||
if self.use_deterministic_encoder:
|
|
||||||
m, v = ms, 0
|
m, v = ms, 0
|
||||||
else:
|
else:
|
||||||
m = F.relu(self.fc_bn1_m(self.fc1_m(x)))
|
m = F.relu(self.fc_bn1_m(self.fc1_m(x)))
|
||||||
|
|
Loading…
Reference in a new issue