PVD/dataset/rotor37_data.py

75 lines
2.4 KiB
Python
Raw Normal View History

from pathlib import Path
2023-04-11 14:00:54 +00:00
import datasets
import h5py
2023-04-11 15:32:30 +00:00
import numpy as np
# Absolute location of the Rotor37 data and of the HDF5 file backing each split.
DATASET_DIR = Path("/gpfs_new/cold-data/InputData/public_datasets/rotor37/rotor37_1200/")
H5FILE_TRAIN = DATASET_DIR.joinpath("h5", "blade_meshes_train.h5")
H5FILE_TEST = DATASET_DIR.joinpath("h5", "blade_meshes_test.h5")
2023-04-11 14:00:54 +00:00
# Number of points in each point cloud; used as the row count of every
# Array2D feature declared by the builder.
N_POINTS = 29773
# Version string reported in the DatasetInfo.
_VERSION = "1.0.0"
# Human-readable description reported in the DatasetInfo.
# NOTE(review): the text says "wind turbine"; Rotor 37 is usually described as
# a compressor rotor — confirm and correct the wording if needed.
_DESCRIPTION = """
This dataset is a collection of 1200 pointclouds, each representing a blade of a wind turbine.
The dataset is split into 2 subsets: train and test, with 1000 and 200 clouds respectively.
Each pointcloud has 29773 points, each point has 3D coordinates, 3D normals and physical properties.
"""
# Per-coordinate statistics used to standardize point positions as
# (positions - MEAN) / STD when examples are generated.
# NOTE(review): how these values were computed is not shown here — confirm
# their provenance before reusing them on other data.
MEAN = np.array([0.01994637, 0.2205227, -0.00095343])
STD = np.array([0.01270086, 0.0280048, 0.01615675])
2023-04-11 14:00:54 +00:00
class Rotor37(datasets.GeneratorBasedBuilder):
    """Builder exposing the Rotor37 point clouds as test and train splits.

    Every example carries three per-point arrays read from an HDF5 file:
    ``positions`` (standardized with ``MEAN``/``STD``), ``normals`` and
    ``features``.
    """

    def _info(self):
        """Return the dataset metadata: version, description and feature schema."""
        schema = datasets.Features(
            {
                "positions": datasets.Array2D(shape=(N_POINTS, 3), dtype="float32"),
                "normals": datasets.Array2D(shape=(N_POINTS, 3), dtype="float32"),
                "features": datasets.Array2D(shape=(N_POINTS, 4), dtype="float32"),
            }
        )
        return datasets.DatasetInfo(
            version=_VERSION,
            description=_DESCRIPTION,
            features=schema,
        )

    def _split_generators(self, dl_manager):
        """Return the test and train split generators, each bound to its HDF5 file.

        ``dl_manager`` is not used: both files are read from fixed paths.
        """
        test_split = datasets.SplitGenerator(
            name=datasets.Split.TEST,  # type: ignore
            gen_kwargs={"h5file": H5FILE_TEST},
        )
        train_split = datasets.SplitGenerator(
            name=datasets.Split.TRAIN,  # type: ignore
            gen_kwargs={"h5file": H5FILE_TRAIN},
        )
        return [test_split, train_split]

    def _generate_examples(self, h5file: Path):
        """Yield ``(index, example)`` pairs for every sample stored in ``h5file``.

        Each example is a dict with the keys ``positions``, ``normals`` and
        ``features``; positions are standardized before being yielded.
        """
        with h5py.File(h5file, "r") as h5:
            # Standardize all coordinates at once with the module-level statistics.
            raw_points = np.asarray(h5["points"])
            scaled_points = (raw_points - MEAN) / STD

            # Walk the three collections in lockstep, one sample at a time.
            samples = zip(
                scaled_points,
                h5["normals"],  # type: ignore
                h5["output_fields"],  # type: ignore
            )

            # The loop stays inside the ``with`` block: normals and fields are
            # pulled from the open file lazily as the zip is consumed.
            for index, (sample_positions, sample_normals, sample_fields) in enumerate(samples):
                example = {
                    "positions": sample_positions,
                    "normals": sample_normals,
                    "features": sample_fields,
                }
                yield index, example