Debug: update encoder

This commit is contained in:
Guandao Yang 2019-07-15 20:58:09 -07:00
parent b439ea71e7
commit 4dda09e2d0

View file

@ -52,10 +52,10 @@ class Encoder(nn.Module):
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 512)
ms = F.relu(self.fc_bn1(self.fc1(x)))
ms = F.relu(self.fc_bn2(self.fc2(ms)))
ms = self.fc3(ms)
if self.use_deterministic_encoder:
ms = F.relu(self.fc_bn1(self.fc1(x)))
ms = F.relu(self.fc_bn2(self.fc2(ms)))
ms = self.fc3(ms)
m, v = ms, 0
else:
m = F.relu(self.fc_bn1_m(self.fc1_m(x)))