From 4beb963aff870e24b28b46aa39c88e031434a9fe Mon Sep 17 00:00:00 2001 From: Laurent FAINSIN Date: Mon, 22 May 2023 09:33:10 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20print=20number=20of=20parameters=20?= =?UTF-8?q?of=20network?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test_generation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test_generation.py b/test_generation.py index b018fb9..4f8d018 100644 --- a/test_generation.py +++ b/test_generation.py @@ -323,6 +323,9 @@ class Model(nn.Module): extra_feature_channels=0, ) + pytorch_total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + print("Total params: {}".format(pytorch_total_params)) + def prior_kl(self, x0): return self.diffusion._prior_bpd(x0) @@ -510,7 +513,7 @@ def generate(model, opt): x = data["positions"].transpose(1, 2) # m, s = data["mean"].float(), data["std"].float() - shape = torch.Size((*x.shape[:-1], 35000)) + shape = torch.Size((*x.shape[:-1], 75000)) gen = model.gen_samples(shape, "cuda", clip_denoised=False).detach().cpu() gen = gen.transpose(1, 2).contiguous()