feat: modify test_generation to translate back sampled blade from norminal deformations

This commit is contained in:
Laurent FAINSIN 2023-04-17 10:35:46 +02:00
parent a3e23f59c5
commit 091d2074ca
3 changed files with 9 additions and 5 deletions

View file

@ -9,7 +9,7 @@ VTKFILE_NOMINAL = Path("~/data/stage-laurent-f/datasets/Rotor37/processed/nomina
nominal = pv.read(VTKFILE_NOMINAL) nominal = pv.read(VTKFILE_NOMINAL)
# for each generated/sampled blade # for each generated/sampled blade
gen_files = Path.cwd().glob("gen*.txt") gen_files = Path("./output").glob("gen*.txt")
for gen_file in gen_files: for gen_file in gen_files:
# load numpy txt # load numpy txt
blade = np.loadtxt(gen_file) blade = np.loadtxt(gen_file)
@ -28,10 +28,10 @@ for gen_file in gen_files:
blade -= center blade -= center
# save to txt # save to txt
np.savetxt(f"test_{gen_file.stem}.txt", blade) np.savetxt(f"output/test_{gen_file.stem}.txt", blade)
# swap nominal points to blade points # swap nominal points to blade points
nominal.points = blade nominal.points = blade
# save altered blade to vtk # save altered blade to vtk
nominal.save(f"test_{gen_file.stem}.vtk") nominal.save(f"output/test_{gen_file.stem}.vtk")

View file

@ -18,7 +18,7 @@ for idx, blade in enumerate(test_ds):
pc = pc * STD + MEAN pc = pc * STD + MEAN
print(f"Saving point cloud {idx}...") print(f"Saving point cloud {idx}...")
np.savetxt(f"pc_{idx}.txt", pc) np.savetxt(f"output/pc_{idx}.txt", pc)
if idx >= 10: if idx >= 10:
break break

View file

@ -2,6 +2,7 @@ import argparse
from pprint import pprint from pprint import pprint
import datasets import datasets
import pyvista as pv
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.data import torch.utils.data
@ -501,6 +502,8 @@ def generate(model, opt):
test_dataloader = torch.utils.data.DataLoader( test_dataloader = torch.utils.data.DataLoader(
test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False test_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=int(opt.workers), drop_last=False
) )
VTKFILE_NOMINAL = Path("~/data/stage-laurent-f/datasets/Rotor37/processed/nominal_blade_rotated.vtk")
nominal = pv.read(VTKFILE_NOMINAL)
with torch.no_grad(): with torch.no_grad():
samples = [] samples = []
@ -514,6 +517,7 @@ def generate(model, opt):
gen = gen.transpose(1, 2).contiguous() gen = gen.transpose(1, 2).contiguous()
x = x.transpose(1, 2).contiguous() x = x.transpose(1, 2).contiguous()
x = x + nominal.points
# gen = gen * s + m # gen = gen * s + m
# x = x * s + m # x = x * s + m
@ -528,7 +532,7 @@ def generate(model, opt):
pc = pc * STD + MEAN pc = pc * STD + MEAN
print(f"Saving point cloud {idx}...") print(f"Saving point cloud {idx}...")
np.savetxt(f"gen_{i}_{idx}.txt", pc) np.savetxt(f"output/gen_{i}_{idx}.txt", pc)
if idx >= 10: if idx >= 10:
break break