from pprint import pprint
from math import log, pi
import os
import random

import numpy as np
from sklearn.svm import LinearSVC

import torch
import torch.distributed as dist

import matplotlib
matplotlib.use('Agg')  # headless backend: figures are rendered off-screen in visualize_point_clouds
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection)


class AverageValueMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
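

# Illustrative sketch (not part of the original module): typical use of
# AverageValueMeter inside a training loop, with made-up loss values.
def _example_average_value_meter():
    meter = AverageValueMeter()
    for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
        meter.update(batch_loss, n=batch_size)
    # meter.avg is the sample-size weighted mean, meter.val the last value seen
    return meter.avg

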
def gaussian_log_likelihood(x, mean, logvar, clip=True):
    """Summed log-density of x under a diagonal Gaussian with the given mean and log-variance."""
    if clip:
        logvar = torch.clamp(logvar, min=-4, max=3)
    a = log(2 * pi)
    b = logvar
    c = (x - mean) ** 2 / torch.exp(logvar)
    return -0.5 * torch.sum(a + b + c)


def bernoulli_log_likelihood(x, p, clip=True, eps=1e-6):
    """Summed Bernoulli log-likelihood of binary targets x under probabilities p."""
    if clip:
        p = torch.clamp(p, min=eps, max=1 - eps)
    return torch.sum((x * torch.log(p)) + ((1 - x) * torch.log(1 - p)))


def kl_diagnormal_stdnormal(mean, logvar):
    """KL divergence KL(N(mean, exp(logvar)) || N(0, I)), summed over all elements."""
    a = mean ** 2
    b = torch.exp(logvar)
    c = -1
    d = -logvar
    return 0.5 * torch.sum(a + b + c + d)


def kl_diagnormal_diagnormal(q_mean, q_logvar, p_mean, p_logvar):
    """KL divergence between two diagonal Gaussians, KL(q || p), summed over all elements."""
    # Ensure correct shapes since no numpy broadcasting yet
    p_mean = p_mean.expand_as(q_mean)
    p_logvar = p_logvar.expand_as(q_logvar)

    a = p_logvar
    b = -1
    c = -q_logvar
    d = ((q_mean - p_mean) ** 2 + torch.exp(q_logvar)) / torch.exp(p_logvar)
    return 0.5 * torch.sum(a + b + c + d)
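

# Illustrative sanity check (not part of the original module): with a standard
# normal prior, kl_diagnormal_diagnormal should agree with kl_diagnormal_stdnormal.
def _example_kl_consistency():
    q_mean = torch.randn(8, 16)
    q_logvar = torch.randn(8, 16)
    kl_a = kl_diagnormal_stdnormal(q_mean, q_logvar)
    kl_b = kl_diagnormal_diagnormal(q_mean, q_logvar,
                                    torch.zeros(1, 16), torch.zeros(1, 16))
    assert torch.allclose(kl_a, kl_b, atol=1e-4)
    return kl_a

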
# Taken from https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
def truncated_normal(tensor, mean=0, std=1, trunc_std=2):
    """Fill `tensor` in-place with a truncated standard normal (|z| < trunc_std), then scale by std and shift by mean."""
    size = tensor.shape
    # draw 4 candidates per element and keep the first one inside the truncation bounds
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < trunc_std) & (tmp > -trunc_std)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor
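

# Illustrative sketch (not part of the original module): initializing a layer's
# weights with the truncated normal helper above.
def _example_truncated_normal_init():
    import torch.nn as nn
    layer = nn.Linear(128, 64)
    with torch.no_grad():
        truncated_normal(layer.weight, mean=0.0, std=0.02, trunc_std=2)
    # With overwhelming probability every entry lies within 2 std (|w| <= 0.04 here).
    return layer

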
def reduce_tensor(tensor, world_size=None):
    """Average a tensor across all distributed processes (sum-reduce, then divide by world size)."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    if world_size is None:
        world_size = dist.get_world_size()

    rt /= world_size
    return rt
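

# Illustrative sketch (not part of the original module): averaging a loss across
# ranks for logging. Only meaningful once torch.distributed has been initialized.
def _example_reduce_tensor(loss):
    if dist.is_available() and dist.is_initialized():
        # every rank must call this collectively with a tensor of the same shape
        return reduce_tensor(loss.detach())
    return loss.detach()

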
def standard_normal_logprob(z):
    """Elementwise standard-normal log-density term; the normalizing constant uses the size of z's last dimension."""
    dim = z.size(-1)
    log_z = -0.5 * dim * log(2 * pi)
    return log_z - z.pow(2) / 2


def set_random_seed(seed):
    """Set the random seed for python, numpy and torch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Visualization
def visualize_point_clouds(pts, gtr, idx, pert_order=[0, 1, 2]):
    """Render a sample and its ground-truth point cloud side by side; returns a (C, H, W) uint8 image array."""
    pts = pts.cpu().detach().numpy()[:, pert_order]
    gtr = gtr.cpu().detach().numpy()[:, pert_order]

    fig = plt.figure(figsize=(6, 3))
    ax1 = fig.add_subplot(121, projection='3d')
    ax1.set_title("Sample:%s" % idx)
    ax1.scatter(pts[:, 0], pts[:, 1], pts[:, 2], s=5)

    ax2 = fig.add_subplot(122, projection='3d')
    ax2.set_title("Ground Truth:%s" % idx)
    ax2.scatter(gtr[:, 0], gtr[:, 1], gtr[:, 2], s=5)

    fig.canvas.draw()

    # grab the Agg pixel buffer (H x W x 4, RGBA) and dump it into a numpy array
    res = np.array(fig.canvas.renderer._renderer)
    res = np.transpose(res, (2, 0, 1))

    plt.close()
    return res
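

# Illustrative sketch (not part of the original module): logging the rendered
# comparison image to TensorBoard. `writer` is assumed to be a SummaryWriter.
def _example_visualize_point_clouds(writer, samples, references, step):
    # samples / references: tensors of shape (B, N, 3)
    image = visualize_point_clouds(samples[0], references[0], idx=0)
    writer.add_image('reconstruction/0', torch.as_tensor(image), step)

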
# Augmentation
def apply_random_rotation(pc, rot_axis=1):
    """Rotate each point cloud in the batch by a random angle around the chosen axis."""
    B = pc.shape[0]

    theta = np.random.rand(B) * 2 * np.pi
    zeros = np.zeros(B)
    ones = np.ones(B)
    cos = np.cos(theta)
    sin = np.sin(theta)

    if rot_axis == 0:
        rot = np.stack([
            cos, -sin, zeros,
            sin, cos, zeros,
            zeros, zeros, ones
        ]).T.reshape(B, 3, 3)
    elif rot_axis == 1:
        rot = np.stack([
            cos, zeros, -sin,
            zeros, ones, zeros,
            sin, zeros, cos
        ]).T.reshape(B, 3, 3)
    elif rot_axis == 2:
        rot = np.stack([
            ones, zeros, zeros,
            zeros, cos, -sin,
            zeros, sin, cos
        ]).T.reshape(B, 3, 3)
    else:
        raise ValueError("Invalid rotation axis")
    rot = torch.from_numpy(rot).to(pc)

    # (B, N, 3) mul (B, 3, 3) -> (B, N, 3)
    pc_rotated = torch.bmm(pc, rot)
    return pc_rotated, rot, theta
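

# Illustrative sketch (not part of the original module): augmenting a random batch
# and checking that the returned per-sample matrices are orthonormal rotations.
def _example_apply_random_rotation():
    pc = torch.randn(4, 1024, 3)
    pc_rotated, rot, theta = apply_random_rotation(pc, rot_axis=1)
    eye = torch.eye(3, dtype=rot.dtype).expand_as(rot)
    assert torch.allclose(torch.bmm(rot, rot.transpose(1, 2)), eye, atol=1e-5)
    return pc_rotated, theta

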
def validate_classification(loaders, model, args):
    """Fit a linear SVM on encoded latents from the train split and report accuracy on the test split."""
    train_loader, test_loader = loaders

    tr_latent = []
    tr_label = []
    for data in train_loader:
        tr_pc = data['train_points']
        tr_pc = tr_pc.cuda() if args.gpu is None else tr_pc.cuda(args.gpu)
        latent = model.encode(tr_pc)
        label = data['cate_idx']
        tr_latent.append(latent.cpu().detach().numpy())
        tr_label.append(label.cpu().detach().numpy())
    tr_label = np.concatenate(tr_label)
    tr_latent = np.concatenate(tr_latent)

    te_latent = []
    te_label = []
    for data in test_loader:
        te_pc = data['train_points']
        te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu)
        latent = model.encode(te_pc)
        label = data['cate_idx']
        te_latent.append(latent.cpu().detach().numpy())
        te_label.append(label.cpu().detach().numpy())
    te_label = np.concatenate(te_label)
    te_latent = np.concatenate(te_latent)

    clf = LinearSVC(random_state=0)
    clf.fit(tr_latent, tr_label)
    test_pred = clf.predict(te_latent)
    test_gt = te_label.flatten()
    acc = np.mean((test_pred == test_gt).astype(float)) * 100.
    res = {'acc': acc}
    print("Acc:%s" % acc)
    return res


def validate_conditioned(loader, model, args, max_samples=None, save_dir=None):
    """Evaluate reconstructions against the reference point clouds with EMD/CD metrics."""
    from metrics.evaluation_metrics import EMD_CD
    all_idx = []
    all_sample = []
    all_ref = []
    ttl_samples = 0

    for data in loader:
        idx_b, tr_pc, te_pc = data['idx'], data['train_points'], data['test_points']
        tr_pc = tr_pc.cuda() if args.gpu is None else tr_pc.cuda(args.gpu)
        te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu)

        if tr_pc.size(1) > te_pc.size(1):
            tr_pc = tr_pc[:, :te_pc.size(1), :]
        out_pc = model.reconstruct(tr_pc, num_points=te_pc.size(1))

        # denormalize
        m, s = data['mean'].float(), data['std'].float()
        m = m.cuda() if args.gpu is None else m.cuda(args.gpu)
        s = s.cuda() if args.gpu is None else s.cuda(args.gpu)
        out_pc = out_pc * s + m
        te_pc = te_pc * s + m

        all_sample.append(out_pc)
        all_ref.append(te_pc)
        all_idx.append(idx_b)

        ttl_samples += int(te_pc.size(0))
        if max_samples is not None and ttl_samples >= max_samples:
            break

    # Compute MMD-CD and MMD-EMD between reconstructions and references
    sample_pcs = torch.cat(all_sample, dim=0)
    ref_pcs = torch.cat(all_ref, dim=0)
    print("[rank %s] Recon Sample size:%s Ref size: %s"
          % (args.rank, sample_pcs.size(), ref_pcs.size()))

    if save_dir is not None and args.save_val_results:
        smp_pcs_save_name = os.path.join(save_dir, "smp_recon_pcls_gpu%s.npy" % args.gpu)
        ref_pcs_save_name = os.path.join(save_dir, "ref_recon_pcls_gpu%s.npy" % args.gpu)
        np.save(smp_pcs_save_name, sample_pcs.cpu().detach().numpy())
        np.save(ref_pcs_save_name, ref_pcs.cpu().detach().numpy())
        print("Saving file:%s %s" % (smp_pcs_save_name, ref_pcs_save_name))

    res = EMD_CD(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True)
    mmd_cd = res['MMD-CD'] if 'MMD-CD' in res else None
    mmd_emd = res['MMD-EMD'] if 'MMD-EMD' in res else None

    print("MMD-CD :%s" % mmd_cd)
    print("MMD-EMD :%s" % mmd_emd)

    return res


def validate_sample(loader, model, args, max_samples=None, save_dir=None):
    """Evaluate unconditional samples against the reference set with the full metric suite and JSD."""
    from metrics.evaluation_metrics import compute_all_metrics, jsd_between_point_cloud_sets as JSD
    all_sample = []
    all_ref = []
    ttl_samples = 0

    for data in loader:
        idx_b, te_pc = data['idx'], data['test_points']
        te_pc = te_pc.cuda() if args.gpu is None else te_pc.cuda(args.gpu)
        _, out_pc = model.sample(te_pc.size(0), te_pc.size(1), gpu=args.gpu)

        # denormalize
        m, s = data['mean'].float(), data['std'].float()
        m = m.cuda() if args.gpu is None else m.cuda(args.gpu)
        s = s.cuda() if args.gpu is None else s.cuda(args.gpu)
        out_pc = out_pc * s + m
        te_pc = te_pc * s + m

        all_sample.append(out_pc)
        all_ref.append(te_pc)

        ttl_samples += int(te_pc.size(0))
        if max_samples is not None and ttl_samples >= max_samples:
            break

    sample_pcs = torch.cat(all_sample, dim=0)
    ref_pcs = torch.cat(all_ref, dim=0)
    print("[rank %s] Generation Sample size:%s Ref size: %s"
          % (args.rank, sample_pcs.size(), ref_pcs.size()))

    if save_dir is not None and args.save_val_results:
        smp_pcs_save_name = os.path.join(save_dir, "smp_syn_pcls_gpu%s.npy" % args.gpu)
        ref_pcs_save_name = os.path.join(save_dir, "ref_syn_pcls_gpu%s.npy" % args.gpu)
        np.save(smp_pcs_save_name, sample_pcs.cpu().detach().numpy())
        np.save(ref_pcs_save_name, ref_pcs.cpu().detach().numpy())
        print("Saving file:%s %s" % (smp_pcs_save_name, ref_pcs_save_name))

    res = compute_all_metrics(sample_pcs, ref_pcs, args.batch_size, accelerated_cd=True)
    pprint(res)

    sample_pcs = sample_pcs.cpu().detach().numpy()
    ref_pcs = ref_pcs.cpu().detach().numpy()
    jsd = JSD(sample_pcs, ref_pcs)
    jsd = torch.tensor(jsd).cuda() if args.gpu is None else torch.tensor(jsd).cuda(args.gpu)
    res.update({"JSD": jsd})
    print("JSD :%s" % jsd)
    return res


def save(model, optimizer, epoch, path):
    """Save a training checkpoint (epoch, model and optimizer state) to `path`."""
    d = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(d, path)


def resume(path, model, optimizer=None, strict=True):
    """Load a checkpoint written by `save` and restore model (and optionally optimizer) state."""
    ckpt = torch.load(path)
    model.load_state_dict(ckpt['model'], strict=strict)
    start_epoch = ckpt['epoch']
    if optimizer is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
    return model, optimizer, start_epoch
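

# Illustrative sketch (not part of the original module): a checkpoint round trip
# with a tiny stand-in model and optimizer, written to a throwaway path.
def _example_checkpoint_roundtrip(path='/tmp/example_checkpoint.pt'):
    import torch.nn as nn
    model = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    save(model, optimizer, epoch=5, path=path)
    model, optimizer, start_epoch = resume(path, model, optimizer)
    assert start_epoch == 5
    return start_epoch

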
def validate(test_loader, model, epoch, writer, save_dir, args, clf_loaders=None):
    """Run all validation passes (classification, sampling, reconstruction) and log results to the writer."""
    model.eval()

    # Make epoch-wise save directory
    if writer is not None and args.save_val_results:
        save_dir = os.path.join(save_dir, 'epoch-%d' % epoch)
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
    else:
        save_dir = None

    # classification
    if args.eval_classification and clf_loaders is not None:
        for clf_expr, loaders in clf_loaders.items():
            with torch.no_grad():
                clf_val_res = validate_classification(loaders, model, args)

            for k, v in clf_val_res.items():
                if writer is not None and v is not None:
                    writer.add_scalar('val_%s/%s' % (clf_expr, k), v, epoch)

    # samples
    if args.use_latent_flow:
        with torch.no_grad():
            val_sample_res = validate_sample(
                test_loader, model, args, max_samples=args.max_validate_shapes,
                save_dir=save_dir)

        for k, v in val_sample_res.items():
            if not isinstance(v, float):
                v = v.cpu().detach().item()
            if writer is not None and v is not None:
                writer.add_scalar('val_sample/%s' % k, v, epoch)

    # reconstructions
    with torch.no_grad():
        val_res = validate_conditioned(
            test_loader, model, args, max_samples=args.max_validate_shapes,
            save_dir=save_dir)
    for k, v in val_res.items():
        if not isinstance(v, float):
            v = v.cpu().detach().item()
        if writer is not None and v is not None:
            writer.add_scalar('val_conditioned/%s' % k, v, epoch)
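

# Illustrative sketch (not part of the original module): the attributes of `args`
# that validate() and its helpers read, collected in a hypothetical Namespace.
# Actual values come from the project's argument parser.
def _example_validate_args():
    from argparse import Namespace
    return Namespace(
        gpu=None,                   # CUDA device index, or None for the default device
        rank=0,                     # distributed rank, used only for logging
        batch_size=32,              # batch size passed to the metric routines
        save_val_results=False,     # dump sampled/reference point clouds to .npy files
        eval_classification=False,  # run the linear-SVM classification probe
        use_latent_flow=True,       # evaluate unconditional samples as well
        max_validate_shapes=None,   # cap on the number of shapes used for validation
    )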