# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

from models.vae_adain import Model as VAE
from models.latent_points_ada_localprior import PVCNN2Prior as LocalPrior
from utils.diffusion_pvd import DiffusionDiscretized
from utils.vis_helper import plot_points
from utils.model_helper import import_model
from diffusers import DDPMScheduler
import torch
import numpy as np  # get_mixing_component below relies on np.sqrt
from matplotlib import pyplot as plt


class LION(object):
    def __init__(self, cfg):
        self.vae = VAE(cfg).cuda()
        # two diffusion priors: a global one over the shape-level (style)
        # latent and a local one over the point-structured latent
        GlobalPrior = import_model(cfg.latent_pts.style_prior)
        global_prior = GlobalPrior(cfg.sde, cfg.latent_pts.style_dim, cfg).cuda()
        local_prior = LocalPrior(cfg.sde, cfg.shapelatent.latent_dim, cfg).cuda()
        self.priors = torch.nn.ModuleList([global_prior, local_prior])
        self.scheduler = DDPMScheduler(
            clip_sample=False,
            beta_start=cfg.ddpm.beta_1,
            beta_end=cfg.ddpm.beta_T,
            beta_schedule=cfg.ddpm.sched_mode,
            num_train_timesteps=cfg.ddpm.num_steps,
            variance_type=cfg.ddpm.model_var_type)
        self.diffusion = DiffusionDiscretized(None, None, cfg)
        # self.load_model(cfg)

    def load_model(self, model_path):
        # model_path = cfg.ckpt.path
        ckpt = torch.load(model_path)
        self.priors.load_state_dict(ckpt['dae_state_dict'])
        self.vae.load_state_dict(ckpt['vae_state_dict'])
        print(f'INFO: finished loading from {model_path}')

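    # Sampling runs two diffusion chains in sequence: the global (style)
    # latent is denoised from pure noise first, then conditions the denoising
    # of the point-structured local latent; finally the VAE decoder maps both
    # latents to a point cloud.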
    @torch.no_grad()
    def sample(self, num_samples=10, clip_feat=None, save_img=False):
        self.scheduler.set_timesteps(1000, device='cuda')
        timesteps = self.scheduler.timesteps
        latent_shape = self.vae.latent_shape()
        global_prior, local_prior = self.priors[0], self.priors[1]
        assert not local_prior.mixed_prediction and not global_prior.mixed_prediction
        sampled_list = []
        output_dict = {}

        # sample the global (style) latent from the global prior
        x_T_shape = [num_samples] + latent_shape[0]
        x_noisy = torch.randn(size=x_T_shape, device='cuda')
        condition_input = None
        for i, t in enumerate(timesteps):
            # the prior expects 1-indexed float timesteps, hence (t + 1)
            t_tensor = torch.ones(num_samples, dtype=torch.int64, device='cuda') * (t + 1)
            noise_pred = global_prior(x=x_noisy, t=t_tensor.float(),
                                      condition_input=condition_input, clip_feat=clip_feat)
            # DDPM ancestral update: x_{t-1} from x_t and the predicted noise
            x_noisy = self.scheduler.step(noise_pred, t, x_noisy).prev_sample
        sampled_list.append(x_noisy)
        output_dict['z_global'] = x_noisy

        # the denoised global latent conditions the local prior
        condition_input = self.vae.global2style(x_noisy)

        # sample the point-structured local latent from the local prior,
        # conditioned on the global latent
        x_T_shape = [num_samples] + latent_shape[1]
        x_noisy = torch.randn(size=x_T_shape, device='cuda')

        for i, t in enumerate(timesteps):
            t_tensor = torch.ones(num_samples, dtype=torch.int64, device='cuda') * (t + 1)
            noise_pred = local_prior(x=x_noisy, t=t_tensor.float(),
                                     condition_input=condition_input, clip_feat=clip_feat)
            x_noisy = self.scheduler.step(noise_pred, t, x_noisy).prev_sample
        sampled_list.append(x_noisy)
        output_dict['z_local'] = x_noisy

        # decode the two latents into a point cloud with the VAE decoder
        output = self.vae.sample(num_samples=num_samples, decomposed_eps=sampled_list)
        if save_img:
            out_name = plot_points(output, "/tmp/tmp.png")
            print(f'INFO: saved plot image at {out_name}')
        output_dict['points'] = output
        return output_dict

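    # Helper for priors trained with mixed_prediction=True (the LSGM-style
    # mixed score parameterization, where the network output is blended with
    # an analytic standard-normal component). The checkpoints handled by
    # sample() above assert mixed_prediction is off, so this is unused there;
    # the usage comment inside shows how it would enter the denoising loop.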
    def get_mixing_component(self, noise_pred, t):
        # usage:
        # if global_prior.mixed_prediction:
        #     mixing_component = self.get_mixing_component(noise_pred, t)
        #     coeff = torch.sigmoid(global_prior.mixing_logit)
        #     noise_pred = (1 - coeff) * mixing_component + coeff * noise_pred

        # scale by sqrt(1 - alpha_bar_t), the forward-process noise std at step t
        alpha_bar = self.scheduler.alphas_cumprod[t]
        one_minus_alpha_bars_sqrt = np.sqrt(1.0 - alpha_bar)
        return noise_pred * one_minus_alpha_bars_sqrt
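
# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the original file): the
# config path, checkpoint path, and the use of OmegaConf are placeholders.
# The loaded config must define the fields referenced above (sde, ddpm,
# shapelatent, latent_pts).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from omegaconf import OmegaConf
    cfg = OmegaConf.load('/path/to/cfg.yml')      # placeholder config path
    lion = LION(cfg)
    lion.load_model('/path/to/checkpoint.pt')     # placeholder checkpoint path
    out = lion.sample(num_samples=4, save_img=True)
    print('generated point clouds:', out['points'].shape)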