diff --git a/models/networks.py b/models/networks.py index 88a4935..c514966 100644 --- a/models/networks.py +++ b/models/networks.py @@ -52,10 +52,10 @@ class Encoder(nn.Module): x = torch.max(x, 2, keepdim=True)[0] x = x.view(-1, 512) - ms = F.relu(self.fc_bn1(self.fc1(x))) - ms = F.relu(self.fc_bn2(self.fc2(ms))) - ms = self.fc3(ms) if self.use_deterministic_encoder: + ms = F.relu(self.fc_bn1(self.fc1(x))) + ms = F.relu(self.fc_bn2(self.fc2(ms))) + ms = self.fc3(ms) m, v = ms, 0 else: m = F.relu(self.fc_bn1_m(self.fc1_m(x)))