PVD/dataset/test_rotor37_data.py

25 lines
603 B
Python
Raw Permalink Normal View History

2023-04-11 14:00:54 +00:00
import itertools
import os

import datasets
import numpy as np

from rotor37_data import MEAN, STD
2023-04-11 14:00:54 +00:00
# Load the "test" split through the local loading script (path is resolved
# against the current working directory) and expose columns as torch tensors.
test_ds = datasets.load_dataset(
    "dataset/rotor37_data.py", split="test"
).with_format("torch")
print(test_ds)
2023-04-11 15:32:30 +00:00
# Load the "train" split through the same loading script, torch-formatted,
# and print its summary (it is only printed in this script).
train_ds = datasets.load_dataset(
    "dataset/rotor37_data.py", split="train"
).with_format("torch")
print(train_ds)
# Save the first few test point clouds to txt for ParaView visualization.
# Each file holds one blade's "positions", mapped back from normalized
# coordinates with STD/MEAN.
OUTPUT_DIR = "output"
# Export blades 0..10 (11 files), i.e. the former `if idx >= 10: break` count.
N_EXPORT = 11

# np.savetxt does not create missing parent directories; without this a fresh
# checkout fails on the first write.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# islice stops iterating after N_EXPORT rows, so no extra row is fetched.
for idx, blade in enumerate(itertools.islice(test_ds, N_EXPORT)):
    pc = blade["positions"]
    # Unnormalize. NOTE(review): pc comes from a torch-formatted dataset; the
    # result type depends on what MEAN/STD are (not visible here) — confirm
    # they broadcast against the positions array.
    pc = pc * STD + MEAN
    print(f"Saving point cloud {idx}...")
    np.savetxt(f"{OUTPUT_DIR}/pc_{idx}.txt", pc)