feat: normalize rotor37 dataset
This commit is contained in:
parent
0ef9148666
commit
ff454a8048
2
.vscode/settings.json
vendored
2
.vscode/settings.json
vendored
|
@ -1,5 +1,5 @@
|
||||||
{
|
{
|
||||||
"python.defaultInterpreterPath": "/local_scratch/lfainsin/.conda/envs/pyg/bin/python", // required for python ide tools
|
// "python.defaultInterpreterPath": "/local_scratch/lfainsin/.conda/envs/pyg/bin/python", // required for python ide tools
|
||||||
"python.terminal.activateEnvironment": false, // or else terminal gets bugged
|
"python.terminal.activateEnvironment": false, // or else terminal gets bugged
|
||||||
"python.analysis.typeCheckingMode": "basic", // get ready to be annoyed
|
"python.analysis.typeCheckingMode": "basic", // get ready to be annoyed
|
||||||
"python.formatting.provider": "black", // opinionated, fuck off
|
"python.formatting.provider": "black", // opinionated, fuck off
|
||||||
|
|
|
@ -2,6 +2,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import h5py
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
DATASET_DIR = Path("/gpfs_new/cold-data/InputData/public_datasets/rotor37/rotor37_1200/")
|
DATASET_DIR = Path("/gpfs_new/cold-data/InputData/public_datasets/rotor37/rotor37_1200/")
|
||||||
H5FILE_TRAIN = DATASET_DIR / "h5" / "blade_meshes_train.h5"
|
H5FILE_TRAIN = DATASET_DIR / "h5" / "blade_meshes_train.h5"
|
||||||
|
@ -51,8 +52,17 @@ class Rotor37(datasets.GeneratorBasedBuilder):
|
||||||
|
|
||||||
def _generate_examples(self, h5file: Path):
|
def _generate_examples(self, h5file: Path):
|
||||||
with h5py.File(h5file, "r") as f:
|
with h5py.File(h5file, "r") as f:
|
||||||
|
# compute mean and std of positions
|
||||||
|
positions = np.asarray(f["points"])
|
||||||
|
positions_mean = positions.mean(axis=(0, 1))
|
||||||
|
positions_std = positions.std(axis=(0, 1))
|
||||||
|
|
||||||
|
# normalize positions
|
||||||
|
positions = (positions - positions_mean) / positions_std
|
||||||
|
|
||||||
|
# zip attributes
|
||||||
attributes = zip(
|
attributes = zip(
|
||||||
f["points"], # type: ignore
|
positions,
|
||||||
f["normals"], # type: ignore
|
f["normals"], # type: ignore
|
||||||
f["output_fields"], # type: ignore
|
f["output_fields"], # type: ignore
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
|
|
||||||
train_ds = train_ds.with_format("torch")
|
|
||||||
print(train_ds)
|
|
||||||
|
|
||||||
test_ds = datasets.load_dataset("dataset/rotor37_data.py", split="test")
|
test_ds = datasets.load_dataset("dataset/rotor37_data.py", split="test")
|
||||||
test_ds = test_ds.with_format("torch")
|
test_ds = test_ds.with_format("torch")
|
||||||
print(test_ds)
|
print(test_ds)
|
||||||
|
|
||||||
|
train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
|
||||||
|
train_ds = train_ds.with_format("torch")
|
||||||
|
print(train_ds)
|
||||||
|
|
||||||
print("yay")
|
print("yay")
|
||||||
|
|
|
@ -19,6 +19,7 @@ dependencies:
|
||||||
- scipy
|
- scipy
|
||||||
- scikit-learn
|
- scikit-learn
|
||||||
- pyvista
|
- pyvista
|
||||||
|
- h5py
|
||||||
- datasets
|
- datasets
|
||||||
#---# toolings
|
#---# toolings
|
||||||
- ruff
|
- ruff
|
||||||
|
|
|
@ -900,7 +900,7 @@ def parse_args():
|
||||||
parser.add_argument("--gpu", default=None, type=int, help="GPU id to use. None means using all available GPUs.")
|
parser.add_argument("--gpu", default=None, type=int, help="GPU id to use. None means using all available GPUs.")
|
||||||
|
|
||||||
"""eval"""
|
"""eval"""
|
||||||
parser.add_argument("--saveIter", default=100, help="unit: epoch")
|
parser.add_argument("--saveIter", default=100, type=int, help="unit: epoch")
|
||||||
parser.add_argument("--diagIter", default=50, help="unit: epoch")
|
parser.add_argument("--diagIter", default=50, help="unit: epoch")
|
||||||
parser.add_argument("--vizIter", default=50, help="unit: epoch")
|
parser.add_argument("--vizIter", default=50, help="unit: epoch")
|
||||||
parser.add_argument("--print_freq", default=50, help="unit: iter")
|
parser.add_argument("--print_freq", default=50, help="unit: iter")
|
||||||
|
|
Loading…
Reference in a new issue