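"""Generation script: load a trained Encode2Points model, run the generator
over the test split, export meshes and point clouds (plus optional DPSR
oracle meshes), and record per-model timing statistics."""
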
import torch
from torch.utils.data import DataLoader
import numpy as np
import shutil
import argparse
import os
import pandas as pd
from collections import defaultdict
from tqdm import tqdm

from src import config
from src.dpsr import DPSR
from src.model import Encode2Points
from src.utils import (mc_from_psr, export_mesh, export_pointcloud,
                       load_config, load_model_manual, scale2onet,
                       is_url, load_url)

np.set_printoptions(precision=4)


def main():
    parser = argparse.ArgumentParser(
        description='Generate meshes and point clouds from a trained model.')
    parser.add_argument('config', type=str, help='Path to config file.')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--iter', type=int, metavar='S',
                        help='the training iteration to be evaluated')

    args = parser.parse_args()
    cfg = load_config(args.config, 'configs/default.yaml')
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    torch.manual_seed(args.seed)
    data_type = cfg['data']['data_type']
    input_type = cfg['data']['input_type']
    vis_n_outputs = cfg['generation']['vis_n_outputs']
    if vis_n_outputs is None:
        vis_n_outputs = -1

    # Shorthands
    out_dir = cfg['train']['out_dir']
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    generation_dir = os.path.join(out_dir, cfg['generation']['generation_dir'])
    out_time_file = os.path.join(generation_dir, 'time_generation_full.pkl')
    out_time_file_class = os.path.join(generation_dir, 'time_generation.pkl')

    # Requires PyTorch >= 1.0
    assert int(torch.__version__.split('.')[0]) >= 1, 'PyTorch >= 1.0 required'

    dataset = config.get_dataset('test', cfg, return_idx=True)
    test_loader = DataLoader(
        dataset, batch_size=1, num_workers=0, shuffle=False)

    model = Encode2Points(cfg).to(device)

    # Load the model checkpoint. Precedence: URL in the config, then an
    # explicit iteration from the config, then --iter, then model_best.pt.
    try:
        if is_url(cfg['test']['model_file']):
            state_dict = load_url(cfg['test']['model_file'])
        elif cfg['generation'].get('iter', 0) != 0:
            state_dict = torch.load(os.path.join(
                out_dir, 'model', '%04d.pt' % cfg['generation']['iter']))
            generation_dir += '_%04d' % cfg['generation']['iter']
        elif args.iter is not None:
            state_dict = torch.load(os.path.join(
                out_dir, 'model', '%04d.pt' % args.iter))
        else:
            state_dict = torch.load(os.path.join(out_dir, 'model_best.pt'))
        load_model_manual(state_dict['state_dict'], model)
    except Exception as e:
        print('Model loading error: %s. Exiting.' % e)
        exit()

    # Generator
    generator = config.get_generator(model, cfg, device=device)

    # Determine what to generate
    generate_mesh = cfg['generation']['generate_mesh']
    generate_pointcloud = cfg['generation']['generate_pointcloud']

    # Statistics
    time_dicts = []

    # Generate
    model.eval()
    dpsr = DPSR(res=(cfg['generation']['psr_resolution'],
                     cfg['generation']['psr_resolution'],
                     cfg['generation']['psr_resolution']),
                sig=cfg['generation']['psr_sigma']).to(device)
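    # DPSR here is the differentiable Poisson surface reconstruction operator:
    # from an oriented point cloud it produces a psr_resolution^3 indicator
    # grid (Gaussian-smoothed with psr_sigma); `mc_from_psr` later extracts a
    # mesh from such a grid via marching cubes (see the oracle branch below).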

    # Count how many models have already been generated per category
    model_counter = defaultdict(int)

    print('Generating...')
    for it, data in enumerate(tqdm(test_loader)):

        # Output folders
        mesh_dir = os.path.join(generation_dir, 'meshes')
        in_dir = os.path.join(generation_dir, 'input')
        pointcloud_dir = os.path.join(generation_dir, 'pointcloud')
        generation_vis_dir = os.path.join(generation_dir, 'vis')

        # Get index etc.
        idx = data['idx'].item()

        try:
            model_dict = dataset.get_model_dict(idx)
        except AttributeError:
            model_dict = {'model': str(idx), 'category': 'n/a'}

        modelname = model_dict['model']
        category_id = model_dict['category']

        try:
            category_name = dataset.metadata[category_id].get('name', 'n/a')
        except AttributeError:
            category_name = 'n/a'

        if category_id != 'n/a':
            mesh_dir = os.path.join(mesh_dir, str(category_id))
            pointcloud_dir = os.path.join(pointcloud_dir, str(category_id))
            in_dir = os.path.join(in_dir, str(category_id))

            folder_name = str(category_id)
            if category_name != 'n/a':
                folder_name = folder_name + '_' + category_name.split(',')[0]

            generation_vis_dir = os.path.join(generation_vis_dir, folder_name)

        # Create directories if necessary
        if vis_n_outputs >= 0 and not os.path.exists(generation_vis_dir):
            os.makedirs(generation_vis_dir)
        if generate_mesh and not os.path.exists(mesh_dir):
            os.makedirs(mesh_dir)
        if generate_pointcloud and not os.path.exists(pointcloud_dir):
            os.makedirs(pointcloud_dir)
        if not os.path.exists(in_dir):
            os.makedirs(in_dir)

        # Timing dict
        time_dict = {
            'idx': idx,
            'class id': category_id,
            'class name': category_name,
            'modelname': modelname,
        }
        time_dicts.append(time_dict)

        # Generate outputs
        out_file_dict = {}

        if generate_mesh:
            # TODO: move the generation logic into a separate class
            out = generator.generate_mesh(data)
            v, f, points, normals, stats_dict = out
            time_dict.update(stats_dict)

            # Write output
            mesh_out_file = os.path.join(mesh_dir, '%s.off' % modelname)
            export_mesh(mesh_out_file, scale2onet(v), f)
            out_file_dict['mesh'] = mesh_out_file

        if generate_pointcloud:
            # NOTE: `points`/`normals` come from generate_mesh() above, so this
            # branch assumes generate_mesh is enabled as well.
            pointcloud_out_file = os.path.join(
                pointcloud_dir, '%s.ply' % modelname)
            export_pointcloud(pointcloud_out_file, scale2onet(points), normals)
            out_file_dict['pointcloud'] = pointcloud_out_file

        if cfg['generation']['copy_input']:
            inputs_path = os.path.join(in_dir, '%s.ply' % modelname)
            p = data.get('inputs').to(device)
            export_pointcloud(inputs_path, scale2onet(p))
            out_file_dict['in'] = inputs_path

        # Copy to visualization directory for the first vis_n_outputs samples
        c_it = model_counter[category_id]
        if c_it < vis_n_outputs:
            # Save output files
            for k, filepath in out_file_dict.items():
                ext = os.path.splitext(filepath)[1]
                out_file = os.path.join(generation_vis_dir,
                                        '%02d_%s%s' % (c_it, k, ext))
                shutil.copyfile(filepath, out_file)

            # Also generate oracle meshes from ground-truth points and normals
            if cfg['generation']['exp_oracle']:
                points_gt = data.get('gt_points').to(device)
                normals_gt = data.get('gt_points.normals').to(device)
                psr_gt = dpsr(points_gt, normals_gt)
                v, f, _ = mc_from_psr(psr_gt,
                                      zero_level=cfg['data']['zero_level'])
                out_file = os.path.join(generation_vis_dir,
                                        '%02d_%s%s' % (c_it, 'mesh_oracle', '.off'))
                export_mesh(out_file, scale2onet(v), f)

        model_counter[category_id] += 1

    # Create pandas dataframe and save
    time_df = pd.DataFrame(time_dicts)
    time_df.set_index(['idx'], inplace=True)
    time_df.to_pickle(out_time_file)

    # Create pickle file with the main statistics (per-class means)
    time_df_class = time_df.groupby(by=['class name']).mean(numeric_only=True)
    time_df_class.loc['mean'] = time_df_class.mean()
    time_df_class.to_pickle(out_time_file_class)

    # Print results
    print('Timings [s]:')
    print(time_df_class)


if __name__ == '__main__':
    main()
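
# Example invocation (script and config names are placeholders):
#   python generate.py configs/<config>.yaml
#   python generate.py configs/<config>.yaml --iter 1000 --no_cuda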