75 lines
2.4 KiB
Python
75 lines
2.4 KiB
Python
from pathlib import Path
|
|
|
|
import datasets
|
|
import h5py
|
|
import numpy as np
|
|
|
|
# Root directory of the Rotor37 data on the shared filesystem.
DATASET_DIR = Path("/gpfs_new/cold-data/InputData/public_datasets/rotor37/rotor37_1200/")

# One HDF5 file per split, both stored under the "h5" sub-directory.
H5FILE_TRAIN = DATASET_DIR.joinpath("h5", "blade_meshes_train.h5")
H5FILE_TEST = DATASET_DIR.joinpath("h5", "blade_meshes_test.h5")

# Number of points in every cloud; used as the row count of each feature array.
N_POINTS = 29_773

_VERSION = "1.0.0"

# NOTE(review): the "wind turbine" wording cannot be verified from this file —
# confirm it against the dataset's original documentation.
_DESCRIPTION = """
This dataset is a collection of 1200 pointclouds, each representing a blade of a wind turbine.
The dataset is split into 2 subsets: train and test, with 1000 and 200 clouds respectively.
Each pointcloud has 29773 points, each point has 3D coordinates, 3D normals and physical properties.
"""

# Per-axis offset and scale used to standardize point positions when examples
# are generated: (positions - MEAN) / STD.
# NOTE(review): how these statistics were computed is not shown here — confirm.
MEAN = np.array(
    [0.01994637, 0.2205227, -0.00095343],
)
STD = np.array(
    [0.01270086, 0.0280048, 0.01615675],
)
|
|
|
|
|
|
class Rotor37(datasets.GeneratorBasedBuilder):
    """Dataset builder that reads Rotor37 blade point clouds from local HDF5 files.

    Each example carries three arrays declared as ``float32`` with ``N_POINTS``
    rows: ``positions`` (3 columns, standardized with ``MEAN`` and ``STD``),
    ``normals`` (3 columns) and ``features`` (4 columns, taken from the
    ``output_fields`` entry of the HDF5 file).
    """

    def _info(self):
        """Return the dataset metadata: version, description and feature schema."""
        # Column count of each per-point array; all share N_POINTS rows.
        column_counts = {"positions": 3, "normals": 3, "features": 4}
        schema = datasets.Features(
            {
                name: datasets.Array2D(shape=(N_POINTS, width), dtype="float32")
                for name, width in column_counts.items()
            }
        )
        return datasets.DatasetInfo(
            version=_VERSION,
            description=_DESCRIPTION,
            features=schema,
        )

    def _split_generators(self, dl_manager):
        """Return one split generator per HDF5 file, test split first.

        ``dl_manager`` is not used: both files are read from fixed local paths.
        """
        split_files = (
            (datasets.Split.TEST, H5FILE_TEST),
            (datasets.Split.TRAIN, H5FILE_TRAIN),
        )
        return [
            datasets.SplitGenerator(
                name=split_name,  # type: ignore
                gen_kwargs={"h5file": split_path},
            )
            for split_name, split_path in split_files
        ]

    def _generate_examples(self, h5file: Path):
        """Yield ``(index, example)`` pairs read from ``h5file``.

        The ``points`` entry is loaded in full and standardized with the
        module-level ``MEAN`` and ``STD``; ``normals`` and ``output_fields``
        are iterated in step with it.
        """
        with h5py.File(h5file, "r") as h5:
            # Standardize every point position in a single vectorized pass.
            raw_points = np.asarray(h5["points"])
            scaled_points = (raw_points - MEAN) / STD

            # Walk the three per-cloud sequences in lockstep.
            clouds = zip(
                scaled_points,
                h5["normals"],  # type: ignore
                h5["output_fields"],  # type: ignore
            )

            for key, (cloud_positions, cloud_normals, cloud_fields) in enumerate(clouds):
                example = {
                    "positions": cloud_positions,
                    "normals": cloud_normals,
                    "features": cloud_fields,
                }
                yield key, example
|