feat: modify train_generation script to use rotor37 dataset
parent df48f8272a
commit 0ef9148666
@@ -1,5 +1,6 @@
 import argparse
 
+import datasets
 import numpy as np
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -8,7 +9,7 @@ import torch.optim as optim
 import torch.utils.data
 from torch.distributions import Normal
 
-from dataset.shapenet_data_pc import ShapeNet15kPointClouds
+# from dataset.shapenet_data_pc import ShapeNet15kPointClouds
 from model.pvcnn_generation import PVCNN2Base
 from utils.file_utils import *
 from utils.visualize import *
@@ -549,30 +550,35 @@ def get_betas(schedule_type, b_start, b_end, time_num):
 
 
 def get_dataset(dataroot, npoints, category):
-    tr_dataset = ShapeNet15kPointClouds(
-        root_dir=dataroot,
-        categories=[category],
-        split="train",
-        tr_sample_size=npoints,
-        te_sample_size=npoints,
-        scale=1.0,
-        normalize_per_shape=False,
-        normalize_std_per_axis=False,
-        random_subsample=True,
-    )
-    te_dataset = ShapeNet15kPointClouds(
-        root_dir=dataroot,
-        categories=[category],
-        split="val",
-        tr_sample_size=npoints,
-        te_sample_size=npoints,
-        scale=1.0,
-        normalize_per_shape=False,
-        normalize_std_per_axis=False,
-        all_points_mean=tr_dataset.all_points_mean,
-        all_points_std=tr_dataset.all_points_std,
-    )
-    return tr_dataset, te_dataset
+    # tr_dataset = ShapeNet15kPointClouds(
+    #     root_dir=dataroot,
+    #     categories=[category],
+    #     split="train",
+    #     tr_sample_size=npoints,
+    #     te_sample_size=npoints,
+    #     scale=1.0,
+    #     normalize_per_shape=False,
+    #     normalize_std_per_axis=False,
+    #     random_subsample=True,
+    # )
+    # te_dataset = ShapeNet15kPointClouds(
+    #     root_dir=dataroot,
+    #     categories=[category],
+    #     split="val",
+    #     tr_sample_size=npoints,
+    #     te_sample_size=npoints,
+    #     scale=1.0,
+    #     normalize_per_shape=False,
+    #     normalize_std_per_axis=False,
+    #     all_points_mean=tr_dataset.all_points_mean,
+    #     all_points_std=tr_dataset.all_points_std,
+    # )
+    train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
+    train_ds = train_ds.with_format("torch")
+
+    test_ds = datasets.load_dataset("dataset/rotor37_data.py", split="test")
+    test_ds = test_ds.with_format("torch")
+    return train_ds, test_ds
 
 
 def get_dataloader(opt, train_dataset, test_dataset=None):
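Note: the rewritten get_dataset() relies on the Hugging Face datasets API, where load_dataset() accepts a local loading script and with_format("torch") makes indexing return torch tensors instead of Python lists. A minimal sketch of consuming the result, assuming the rotor37 script exposes a "positions" column of (npoints, 3) point clouds (the column name appears in a later hunk; the shape is an assumption):

import datasets

# Load the train split via the local loading script referenced in the diff.
train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
train_ds = train_ds.with_format("torch")  # tensors, not Python lists

sample = train_ds[0]
x = sample["positions"]             # assumed shape: (npoints, 3)
x = x.unsqueeze(0).transpose(1, 2)  # (1, 3, npoints), channel-first as train() expects
print(x.shape)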
@@ -614,7 +620,7 @@ def get_dataloader(opt, train_dataset, test_dataset=None):
     return train_dataloader, test_dataloader, train_sampler, test_sampler
 
 
-def train(gpu, opt, output_dir, noises_init):
+def train(gpu, opt, output_dir):
     set_seed(opt)
     logger = setup_logging(output_dir)
     if opt.distribution_type == "multi":
@@ -702,8 +708,9 @@ def train(gpu, opt, output_dir, noises_init):
         lr_scheduler.step(epoch)
 
         for i, data in enumerate(dataloader):
-            x = data["train_points"].transpose(1, 2)
-            noises_batch = noises_init[data["idx"]].transpose(1, 2)
+            # x = data["train_points"].transpose(1, 2)
+            x = data["positions"].transpose(1, 2)
+            noises_batch = torch.randn_like(x)
 
             """
             train diffusion
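This hunk swaps the precomputed per-sample noise table (noises_init indexed by data["idx"]) for fresh Gaussian noise drawn every batch. A minimal sketch of the two behaviours, with made-up shapes and dataset size:

import torch

B, C, N = 16, 3, 2048      # hypothetical batch size, channels, points
x = torch.randn(B, C, N)   # stand-in for data["positions"].transpose(1, 2)

# Old behaviour: one fixed noise tensor per dataset sample, reused across epochs.
noises_init = torch.randn(100, N, C)           # one entry per training sample
idx = torch.randint(0, 100, (B,))              # stand-in for data["idx"]
noises_old = noises_init[idx].transpose(1, 2)  # (B, C, N)

# New behaviour: i.i.d. noise redrawn each step, same shape/dtype/device as x.
noises_new = torch.randn_like(x)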
@@ -830,10 +837,6 @@ def main():
     output_dir = get_output_dir(dir_id, exp_id)
     copy_source(__file__, output_dir)
 
-    """ workaround """
-    train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category)
-    noises_init = torch.randn(len(train_dataset), opt.npoints, opt.nc)
-
     if opt.dist_url == "env://" and opt.world_size == -1:
         opt.world_size = int(os.environ["WORLD_SIZE"])
 
@@ -842,7 +845,7 @@ def main():
         opt.world_size = opt.ngpus_per_node * opt.world_size
         mp.spawn(train, nprocs=opt.ngpus_per_node, args=(opt, output_dir, noises_init))
     else:
-        train(opt.gpu, opt, output_dir, noises_init)
+        train(opt.gpu, opt, output_dir)
 
 
 def parse_args():
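For reference, torch.multiprocessing.spawn invokes the target as fn(i, *args), prepending the process index, so the tuple passed via args must match the remaining parameters of train() after gpu. A minimal sketch with placeholder arguments (opt and the output path here are stand-ins, not values from the script):

import torch.multiprocessing as mp

def train(gpu, opt, output_dir):
    print(f"worker {gpu} writing to {output_dir}")

if __name__ == "__main__":
    opt = {"npoints": 2048}  # placeholder for the parsed options
    # spawn calls train(0, opt, "out"), train(1, opt, "out"), ...
    mp.spawn(train, nprocs=2, args=(opt, "out"))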