✨ print number of parameters of network
This commit is contained in:
parent
c09ff5b20f
commit
4beb963aff
|
@ -323,6 +323,9 @@ class Model(nn.Module):
|
|||
extra_feature_channels=0,
|
||||
)
|
||||
|
||||
pytorch_total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||
print("Total params: {}".format(pytorch_total_params))
|
||||
|
||||
def prior_kl(self, x0):
|
||||
return self.diffusion._prior_bpd(x0)
|
||||
|
||||
|
@ -510,7 +513,7 @@ def generate(model, opt):
|
|||
x = data["positions"].transpose(1, 2)
|
||||
# m, s = data["mean"].float(), data["std"].float()
|
||||
|
||||
shape = torch.Size((*x.shape[:-1], 35000))
|
||||
shape = torch.Size((*x.shape[:-1], 75000))
|
||||
gen = model.gen_samples(shape, "cuda", clip_denoised=False).detach().cpu()
|
||||
|
||||
gen = gen.transpose(1, 2).contiguous()
|
||||
|
|
Loading…
Reference in a new issue