From 0ef9148666803dbb9edbcdfaf60033a94059572f Mon Sep 17 00:00:00 2001
From: Laurent FAINSIN
Date: Tue, 11 Apr 2023 16:01:07 +0200
Subject: [PATCH] feat: modify train_generation script to use rotor37 dataset

The ShapeNet15kPointClouds loader is commented out in favor of a
HuggingFace `datasets` loading script (dataset/rotor37_data.py).
Batches now expose a "positions" field instead of "train_points", and
the fixed per-sample noises_init tensor is replaced by fresh noise
drawn with torch.randn_like each batch, so train() no longer takes a
noises_init argument and main() no longer needs to build the dataset
up front.
---
 train_generation.py | 71 +++++++++++++++++++++++-----------------------
 1 file changed, 37 insertions(+), 34 deletions(-)

diff --git a/train_generation.py b/train_generation.py
index d87c78a..e208afd 100644
--- a/train_generation.py
+++ b/train_generation.py
@@ -1,5 +1,6 @@
 import argparse
 
+import datasets
 import numpy as np
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -8,7 +9,7 @@
 import torch.optim as optim
 import torch.utils.data
 from torch.distributions import Normal
-from dataset.shapenet_data_pc import ShapeNet15kPointClouds
+# from dataset.shapenet_data_pc import ShapeNet15kPointClouds
 from model.pvcnn_generation import PVCNN2Base
 from utils.file_utils import *
 from utils.visualize import *
@@ -549,30 +550,35 @@ def get_betas(schedule_type, b_start, b_end, time_num):
 
 
 def get_dataset(dataroot, npoints, category):
-    tr_dataset = ShapeNet15kPointClouds(
-        root_dir=dataroot,
-        categories=[category],
-        split="train",
-        tr_sample_size=npoints,
-        te_sample_size=npoints,
-        scale=1.0,
-        normalize_per_shape=False,
-        normalize_std_per_axis=False,
-        random_subsample=True,
-    )
-    te_dataset = ShapeNet15kPointClouds(
-        root_dir=dataroot,
-        categories=[category],
-        split="val",
-        tr_sample_size=npoints,
-        te_sample_size=npoints,
-        scale=1.0,
-        normalize_per_shape=False,
-        normalize_std_per_axis=False,
-        all_points_mean=tr_dataset.all_points_mean,
-        all_points_std=tr_dataset.all_points_std,
-    )
-    return tr_dataset, te_dataset
+    # tr_dataset = ShapeNet15kPointClouds(
+    #     root_dir=dataroot,
+    #     categories=[category],
+    #     split="train",
+    #     tr_sample_size=npoints,
+    #     te_sample_size=npoints,
+    #     scale=1.0,
+    #     normalize_per_shape=False,
+    #     normalize_std_per_axis=False,
+    #     random_subsample=True,
+    # )
+    # te_dataset = ShapeNet15kPointClouds(
+    #     root_dir=dataroot,
+    #     categories=[category],
+    #     split="val",
+    #     tr_sample_size=npoints,
+    #     te_sample_size=npoints,
+    #     scale=1.0,
+    #     normalize_per_shape=False,
+    #     normalize_std_per_axis=False,
+    #     all_points_mean=tr_dataset.all_points_mean,
+    #     all_points_std=tr_dataset.all_points_std,
+    # )
+    train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
+    train_ds = train_ds.with_format("torch")
+
+    test_ds = datasets.load_dataset("dataset/rotor37_data.py", split="test")
+    test_ds = test_ds.with_format("torch")
+    return train_ds, test_ds
 
 
 def get_dataloader(opt, train_dataset, test_dataset=None):
@@ -614,7 +620,7 @@
     return train_dataloader, test_dataloader, train_sampler, test_sampler
 
 
-def train(gpu, opt, output_dir, noises_init):
+def train(gpu, opt, output_dir):
     set_seed(opt)
     logger = setup_logging(output_dir)
     if opt.distribution_type == "multi":
@@ -702,8 +708,9 @@
         lr_scheduler.step(epoch)
 
         for i, data in enumerate(dataloader):
-            x = data["train_points"].transpose(1, 2)
-            noises_batch = noises_init[data["idx"]].transpose(1, 2)
+            # x = data["train_points"].transpose(1, 2)
+            x = data["positions"].transpose(1, 2)
+            noises_batch = torch.randn_like(x)
 
             """
             train diffusion
@@ -830,10 +837,6 @@ def main():
     output_dir = get_output_dir(dir_id, exp_id)
     copy_source(__file__, output_dir)
 
-    """ workaround """
-    train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category)
-    noises_init = torch.randn(len(train_dataset), opt.npoints, opt.nc)
-
     if opt.dist_url == "env://" and opt.world_size == -1:
         opt.world_size = int(os.environ["WORLD_SIZE"])
 
@@ -842,7 +845,7 @@ def main():
         opt.world_size = opt.ngpus_per_node * opt.world_size
-        mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
+        mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir))
     else:
-        train(opt.gpu, opt, output_dir, noises_init)
+        train(opt.gpu, opt, output_dir)
 
 
 def parse_args():
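
A minimal smoke test for the new data path, sketched below. It assumes
dataset/rotor37_data.py exposes "train" and "test" splits whose records
carry a "positions" array of shape (npoints, 3), and that every sample
has the same number of points so the default collate can stack them;
the batch size is arbitrary.

    import datasets
    import torch
    from torch.utils.data import DataLoader

    # Load the train split through the local loading script, as get_dataset() now does.
    train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
    train_ds = train_ds.with_format("torch")  # yield torch tensors instead of python lists

    # Batch the samples the same way the training dataloader would.
    loader = DataLoader(train_ds, batch_size=4, shuffle=True)
    batch = next(iter(loader))

    # Mirror the training loop: (B, npoints, 3) -> (B, 3, npoints), fresh noise per batch.
    x = batch["positions"].transpose(1, 2)
    noises_batch = torch.randn_like(x)
    print(x.shape, noises_batch.shape)

On the noise change: noises_init was previously sampled once per
dataset index in main() and re-used across epochs, whereas
torch.randn_like(x) draws fresh Gaussian noise every batch; this also
removes the need to know the dataset length before spawning workers.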