feat: train on deformation instead of absolute position
This commit is contained in:
parent
859da1a847
commit
a3e23f59c5
16
.vscode/launch.json
vendored
16
.vscode/launch.json
vendored
|
@ -12,14 +12,14 @@
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"justMyCode": true,
|
"justMyCode": true,
|
||||||
"args": [
|
"args": [
|
||||||
"--model",
|
// "--model",
|
||||||
"output/train_generation/2023-04-11-23-38-23/epoch_99.pth",
|
// "output/train_generation/2023-04-11-23-38-23/epoch_99.pth",
|
||||||
"--generate",
|
// "--generate",
|
||||||
"True",
|
// "True",
|
||||||
"--workers",
|
// "--workers",
|
||||||
"4",
|
// "4",
|
||||||
"--batch_size",
|
// "--batch_size",
|
||||||
"6"
|
// "6"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
|
@ -9,6 +9,8 @@ import torch.optim as optim
|
||||||
import torch.utils.data
|
import torch.utils.data
|
||||||
from torch.distributions import Normal
|
from torch.distributions import Normal
|
||||||
|
|
||||||
|
import pyvista as pv
|
||||||
|
|
||||||
# from dataset.shapenet_data_pc import ShapeNet15kPointClouds
|
# from dataset.shapenet_data_pc import ShapeNet15kPointClouds
|
||||||
from model.pvcnn_generation import PVCNN2Base
|
from model.pvcnn_generation import PVCNN2Base
|
||||||
from utils.file_utils import *
|
from utils.file_utils import *
|
||||||
|
@ -650,6 +652,8 @@ def train(gpu, opt, output_dir):
|
||||||
""" data """
|
""" data """
|
||||||
train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category)
|
train_dataset, _ = get_dataset(opt.dataroot, opt.npoints, opt.category)
|
||||||
dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
|
dataloader, _, train_sampler, _ = get_dataloader(opt, train_dataset, None)
|
||||||
|
VTKFILE_NOMINAL = Path("~/data/stage-laurent-f/datasets/Rotor37/processed/nominal_blade_rotated.vtk")
|
||||||
|
nominal = pv.read(VTKFILE_NOMINAL)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
create networks
|
create networks
|
||||||
|
@ -708,7 +712,7 @@ def train(gpu, opt, output_dir):
|
||||||
lr_scheduler.step(epoch)
|
lr_scheduler.step(epoch)
|
||||||
|
|
||||||
for i, data in enumerate(dataloader):
|
for i, data in enumerate(dataloader):
|
||||||
# x = data["train_points"].transpose(1, 2)
|
x = data["positions"] - nominal.points
|
||||||
x = data["positions"].transpose(1, 2)
|
x = data["positions"].transpose(1, 2)
|
||||||
noises_batch = torch.randn_like(x)
|
noises_batch = torch.randn_like(x)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue