✨ print number of parameters of network
This commit is contained in:
parent
c09ff5b20f
commit
4beb963aff
|
@ -323,6 +323,9 @@ class Model(nn.Module):
|
||||||
extra_feature_channels=0,
|
extra_feature_channels=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pytorch_total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||||
|
print("Total params: {}".format(pytorch_total_params))
|
||||||
|
|
||||||
def prior_kl(self, x0):
|
def prior_kl(self, x0):
|
||||||
return self.diffusion._prior_bpd(x0)
|
return self.diffusion._prior_bpd(x0)
|
||||||
|
|
||||||
|
@ -510,7 +513,7 @@ def generate(model, opt):
|
||||||
x = data["positions"].transpose(1, 2)
|
x = data["positions"].transpose(1, 2)
|
||||||
# m, s = data["mean"].float(), data["std"].float()
|
# m, s = data["mean"].float(), data["std"].float()
|
||||||
|
|
||||||
shape = torch.Size((*x.shape[:-1], 35000))
|
shape = torch.Size((*x.shape[:-1], 75000))
|
||||||
gen = model.gen_samples(shape, "cuda", clip_denoised=False).detach().cpu()
|
gen = model.gen_samples(shape, "cuda", clip_denoised=False).detach().cpu()
|
||||||
|
|
||||||
gen = gen.transpose(1, 2).contiguous()
|
gen = gen.transpose(1, 2).contiguous()
|
||||||
|
|
Loading…
Reference in a new issue