feat: train on deformation instead of absolute position

This commit is contained in:
Laurent FAINSIN 2023-04-14 14:00:11 +02:00
parent 859da1a847
commit a3e23f59c5
2 changed files with 13 additions and 9 deletions

16
.vscode/launch.json vendored
View file

@ -12,14 +12,14 @@
"console": "integratedTerminal", "console": "integratedTerminal",
"justMyCode": true, "justMyCode": true,
"args": [ "args": [
"--model", // "--model",
"output/train_generation/2023-04-11-23-38-23/epoch_99.pth", // "output/train_generation/2023-04-11-23-38-23/epoch_99.pth",
"--generate", // "--generate",
"True", // "True",
"--workers", // "--workers",
"4", // "4",
"--batch_size", // "--batch_size",
"6" // "6"
] ]
} }
] ]

View file

@ -9,6 +9,8 @@ import torch.optim as optim
import torch.utils.data import torch.utils.data
from torch.distributions import Normal from torch.distributions import Normal
import pyvista as pv
# from dataset.shapenet_data_pc import ShapeNet15kPointClouds # from dataset.shapenet_data_pc import ShapeNet15kPointClouds
from model.pvcnn_generation import PVCNN2Base from model.pvcnn_generation import PVCNN2Base
from utils.file_utils import * from utils.file_utils import *
@ -650,6 +652,8 @@ def train(gpu, opt, output_dir):
""" data """ """ data """
train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category) train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category)
dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None) dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
VTKFILE_NOMINAL = Path("~/data/stage-laurent-f/datasets/Rotor37/processed/nominal_blade_rotated.vtk")
nominal = pv.read(VTKFILE_NOMINAL)
""" """
create networks create networks
@ -708,7 +712,7 @@ def train(gpu, opt, output_dir):
lr_scheduler.step(epoch) lr_scheduler.step(epoch)
for i, data in enumerate(dataloader): for i, data in enumerate(dataloader):
# x = data["train_points"].transpose(1, 2) x = data["positions"] - nominal.points
x = data["positions"].transpose(1, 2) x = data["positions"].transpose(1, 2)
noises_batch = torch.randn_like(x) noises_batch = torch.randn_like(x)