print number of parameters of network

This commit is contained in:
Laurent FAINSIN 2023-05-22 09:33:10 +02:00
parent c09ff5b20f
commit 4beb963aff

View file

@ -323,6 +323,9 @@ class Model(nn.Module):
extra_feature_channels=0, extra_feature_channels=0,
) )
pytorch_total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
print("Total params: {}".format(pytorch_total_params))
def prior_kl(self, x0): def prior_kl(self, x0):
return self.diffusion._prior_bpd(x0) return self.diffusion._prior_bpd(x0)
@ -510,7 +513,7 @@ def generate(model, opt):
x = data["positions"].transpose(1, 2) x = data["positions"].transpose(1, 2)
# m, s = data["mean"].float(), data["std"].float() # m, s = data["mean"].float(), data["std"].float()
shape = torch.Size((*x.shape[:-1], 35000)) shape = torch.Size((*x.shape[:-1], 75000))
gen = model.gen_samples(shape, "cuda", clip_denoised=False).detach().cpu() gen = model.gen_samples(shape, "cuda", clip_denoised=False).detach().cpu()
gen = gen.transpose(1, 2).contiguous() gen = gen.transpose(1, 2).contiguous()