feat: normalize rotor37 dataset

This commit is contained in:
Laurent FAINSIN 2023-04-11 17:32:30 +02:00
parent 0ef9148666
commit ff454a8048
5 changed files with 18 additions and 7 deletions

View file

@ -1,5 +1,5 @@
{
"python.defaultInterpreterPath": "/local_scratch/lfainsin/.conda/envs/pyg/bin/python", // required for python ide tools
// "python.defaultInterpreterPath": "/local_scratch/lfainsin/.conda/envs/pyg/bin/python", // required for python ide tools
"python.terminal.activateEnvironment": false, // or else terminal gets bugged
"python.analysis.typeCheckingMode": "basic", // get ready to be annoyed
"python.formatting.provider": "black", // opinionated, fuck off

View file

@ -2,6 +2,7 @@ from pathlib import Path
import datasets
import h5py
import numpy as np
DATASET_DIR = Path("/gpfs_new/cold-data/InputData/public_datasets/rotor37/rotor37_1200/")
H5FILE_TRAIN = DATASET_DIR / "h5" / "blade_meshes_train.h5"
@ -51,8 +52,17 @@ class Rotor37(datasets.GeneratorBasedBuilder):
def _generate_examples(self, h5file: Path):
with h5py.File(h5file, "r") as f:
# compute mean and std of positions
positions = np.asarray(f["points"])
positions_mean = positions.mean(axis=(0, 1))
positions_std = positions.std(axis=(0, 1))
# normalize positions
positions = (positions - positions_mean) / positions_std
# zip attributes
attributes = zip(
f["points"], # type: ignore
positions,
f["normals"], # type: ignore
f["output_fields"], # type: ignore
)

View file

@ -1,11 +1,11 @@
import datasets
train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
train_ds = train_ds.with_format("torch")
print(train_ds)
test_ds = datasets.load_dataset("dataset/rotor37_data.py", split="test")
test_ds = test_ds.with_format("torch")
print(test_ds)
train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
train_ds = train_ds.with_format("torch")
print(train_ds)
print("yay")

View file

@ -19,6 +19,7 @@ dependencies:
- scipy
- scikit-learn
- pyvista
- h5py
- datasets
#---# toolings
- ruff

View file

@ -900,7 +900,7 @@ def parse_args():
parser.add_argument("--gpu", default=None, type=int, help="GPU id to use. None means using all available GPUs.")
"""eval"""
parser.add_argument("--saveIter", default=100, help="unit: epoch")
parser.add_argument("--saveIter", default=100, type=int, help="unit: epoch")
parser.add_argument("--diagIter", default=50, help="unit: epoch")
parser.add_argument("--vizIter", default=50, help="unit: epoch")
parser.add_argument("--print_freq", default=50, help="unit: iter")