feat: modify train_generation script to use rotor37 dataset

Laurent FAINSIN 2023-04-11 16:01:07 +02:00
parent df48f8272a
commit 0ef9148666


@@ -1,5 +1,6 @@
 import argparse
+import datasets
 import numpy as np
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -8,7 +9,7 @@ import torch.optim as optim
 import torch.utils.data
 from torch.distributions import Normal
-from dataset.shapenet_data_pc import ShapeNet15kPointClouds
+# from dataset.shapenet_data_pc import ShapeNet15kPointClouds
 from model.pvcnn_generation import PVCNN2Base
 from utils.file_utils import *
 from utils.visualize import *
@@ -549,30 +550,35 @@ def get_betas(schedule_type, b_start, b_end, time_num):
 def get_dataset(dataroot, npoints, category):
-    tr_dataset = ShapeNet15kPointClouds(
-        root_dir=dataroot,
-        categories=[category],
-        split="train",
-        tr_sample_size=npoints,
-        te_sample_size=npoints,
-        scale=1.0,
-        normalize_per_shape=False,
-        normalize_std_per_axis=False,
-        random_subsample=True,
-    )
-    te_dataset = ShapeNet15kPointClouds(
-        root_dir=dataroot,
-        categories=[category],
-        split="val",
-        tr_sample_size=npoints,
-        te_sample_size=npoints,
-        scale=1.0,
-        normalize_per_shape=False,
-        normalize_std_per_axis=False,
-        all_points_mean=tr_dataset.all_points_mean,
-        all_points_std=tr_dataset.all_points_std,
-    )
-    return tr_dataset, te_dataset
+    # tr_dataset = ShapeNet15kPointClouds(
+    #     root_dir=dataroot,
+    #     categories=[category],
+    #     split="train",
+    #     tr_sample_size=npoints,
+    #     te_sample_size=npoints,
+    #     scale=1.0,
+    #     normalize_per_shape=False,
+    #     normalize_std_per_axis=False,
+    #     random_subsample=True,
+    # )
+    # te_dataset = ShapeNet15kPointClouds(
+    #     root_dir=dataroot,
+    #     categories=[category],
+    #     split="val",
+    #     tr_sample_size=npoints,
+    #     te_sample_size=npoints,
+    #     scale=1.0,
+    #     normalize_per_shape=False,
+    #     normalize_std_per_axis=False,
+    #     all_points_mean=tr_dataset.all_points_mean,
+    #     all_points_std=tr_dataset.all_points_std,
+    # )
+    train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
+    train_ds = train_ds.with_format("torch")
+    test_ds = datasets.load_dataset("dataset/rotor37_data.py", split="test")
+    test_ds = test_ds.with_format("torch")
+    return train_ds, test_ds
 
 
 def get_dataloader(opt, train_dataset, test_dataset=None):
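
Note: load_dataset is pointed at a local loading script, dataset/rotor37_data.py, which is not included in this commit. As a rough sketch of the builder interface such a script would implement — the class name, split sizes, and the random placeholder data are assumptions; only the "positions" feature name is confirmed by the training loop further down:

import datasets
import numpy as np

class Rotor37(datasets.GeneratorBasedBuilder):
    """Hypothetical sketch of dataset/rotor37_data.py (the real script is not in this diff)."""

    def _info(self):
        # One (npoints, 3) float32 point cloud per example, stored under "positions".
        return datasets.DatasetInfo(
            features=datasets.Features(
                {"positions": datasets.Array2D(shape=(None, 3), dtype="float32")}
            )
        )

    def _split_generators(self, dl_manager):
        return [
            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"split": "train"}),
            datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"split": "test"}),
        ]

    def _generate_examples(self, split):
        # Placeholder generator: a real script would read the rotor37 files here.
        rng = np.random.default_rng(0 if split == "train" else 1)
        for idx in range(8):
            yield idx, {"positions": rng.standard_normal((2048, 3)).astype(np.float32)}

with_format("torch") then makes indexing and batching return torch tensors keyed by feature name, which is why the training loop can read data["positions"] directly.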
@@ -614,7 +620,7 @@ def get_dataloader(opt, train_dataset, test_dataset=None):
     return train_dataloader, test_dataloader, train_sampler, test_sampler
 
 
-def train(gpu, opt, output_dir, noises_init):
+def train(gpu, opt, output_dir):
     set_seed(opt)
     logger = setup_logging(output_dir)
     if opt.distribution_type == "multi":
@@ -702,8 +708,9 @@ def train(gpu, opt, output_dir, noises_init):
         lr_scheduler.step(epoch)
 
         for i, data in enumerate(dataloader):
-            x = data["train_points"].transpose(1, 2)
-            noises_batch = noises_init[data["idx"]].transpose(1, 2)
+            # x = data["train_points"].transpose(1, 2)
+            x = data["positions"].transpose(1, 2)
+            noises_batch = torch.randn_like(x)
 
             """
             train diffusion
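
The old pipeline drew one fixed noise cloud per dataset item up front (noises_init, built in main()) and looked it up through data["idx"]; the new code resamples Gaussian noise for every batch instead. A minimal sketch of the two behaviors, assuming batches come out as (B, 3, N) after the transpose:

import torch

B, N = 4, 2048
x = torch.randn(B, 3, N)  # stand-in batch of point clouds, channels-first

# Before: noise fixed per dataset index, identical across epochs.
noises_init = torch.randn(100, N, 3)           # one (N, 3) noise cloud per sample
idx = torch.tensor([0, 1, 2, 3])               # batch indices from the dataset
noises_old = noises_init[idx].transpose(1, 2)  # (B, 3, N)

# After: fresh noise every batch, matching x's shape, dtype and device.
noises_new = torch.randn_like(x)               # (B, 3, N)

Resampling also removes the need to know the dataset length (or an "idx" field) up front, which is what allows the workaround block in main() to be deleted below.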
@@ -830,10 +837,6 @@ def main():
     output_dir = get_output_dir(dir_id, exp_id)
     copy_source(__file__, output_dir)
 
-    """ workaround """
-    train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category)
-    noises_init = torch.randn(len(train_dataset), opt.npoints, opt.nc)
-
     if opt.dist_url == "env://" and opt.world_size == -1:
         opt.world_size = int(os.environ["WORLD_SIZE"])
@@ -842,7 +845,7 @@ def main():
         opt.world_size = opt.ngpus_per_node * opt.world_size
-        mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
+        mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir))
     else:
-        train(opt.gpu, opt, output_dir, noises_init)
+        train(opt.gpu, opt, output_dir)
 
 
 def parse_args():
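
Both call sites have to track the new train(gpu, opt, output_dir) signature because mp.spawn invokes its target as fn(i, *args), injecting the process index as the leading positional argument. A self-contained sketch of that convention (toy values, not the real option object):

import torch.multiprocessing as mp

def train(gpu, opt, output_dir):
    # gpu is the process index injected by mp.spawn (0 .. nprocs-1).
    print(f"worker {gpu}: opt={opt}, output_dir={output_dir}")

if __name__ == "__main__":
    # args supplies everything after the leading index, so it must
    # match the parameters of train from the second one onward.
    mp.spawn(train, nprocs=2, args=({"lr": 1e-3}, "out/"))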