From a3e23f59c5f934a7f96f29800d91fb9e79050aec Mon Sep 17 00:00:00 2001 From: Laurent FAINSIN Date: Fri, 14 Apr 2023 14:00:11 +0200 Subject: [PATCH] feat: train on deformation instead of absolute position --- .vscode/launch.json | 16 ++++++++-------- train_generation.py | 6 +++++- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index b41f803..cb98571 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,14 +12,14 @@ "console": "integratedTerminal", "justMyCode": true, "args": [ - "--model", - "output/train_generation/2023-04-11-23-38-23/epoch_99.pth", - "--generate", - "True", - "--workers", - "4", - "--batch_size", - "6" + // "--model", + // "output/train_generation/2023-04-11-23-38-23/epoch_99.pth", + // "--generate", + // "True", + // "--workers", + // "4", + // "--batch_size", + // "6" ] } ] diff --git a/train_generation.py b/train_generation.py index a8b53ec..523361e 100644 --- a/train_generation.py +++ b/train_generation.py @@ -9,6 +9,8 @@ import torch.optim as optim import torch.utils.data from torch.distributions import Normal +import pyvista as pv + # from dataset.shapenet_data_pc import ShapeNet15kPointClouds from model.pvcnn_generation import PVCNN2Base from utils.file_utils import * @@ -650,6 +652,8 @@ def train(gpu, opt, output_dir): """ data """ train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category) dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None) + VTKFILE_NOMINAL = Path("~/data/stage-laurent-f/datasets/Rotor37/processed/nominal_blade_rotated.vtk") + nominal = pv.read(VTKFILE_NOMINAL) """ create networks @@ -708,7 +712,7 @@ def train(gpu, opt, output_dir): lr_scheduler.step(epoch) for i, data in enumerate(dataloader): - # x = data["train_points"].transpose(1, 2) + x = data["positions"] - nominal.points x = data["positions"].transpose(1, 2) noises_batch = torch.randn_like(x)