feat: normalize rotor37 dataset
This commit is contained in:
parent
0ef9148666
commit
ff454a8048
2
.vscode/settings.json
vendored
2
.vscode/settings.json
vendored
|
@ -1,5 +1,5 @@
|
|||
{
|
||||
"python.defaultInterpreterPath": "/local_scratch/lfainsin/.conda/envs/pyg/bin/python", // required for python ide tools
|
||||
// "python.defaultInterpreterPath": "/local_scratch/lfainsin/.conda/envs/pyg/bin/python", // required for python ide tools
|
||||
"python.terminal.activateEnvironment": false, // or else terminal gets bugged
|
||||
"python.analysis.typeCheckingMode": "basic", // get ready to be annoyed
|
||||
"python.formatting.provider": "black", // opinionated, fuck off
|
||||
|
|
|
@ -2,6 +2,7 @@ from pathlib import Path
|
|||
|
||||
import datasets
|
||||
import h5py
|
||||
import numpy as np
|
||||
|
||||
DATASET_DIR = Path("/gpfs_new/cold-data/InputData/public_datasets/rotor37/rotor37_1200/")
|
||||
H5FILE_TRAIN = DATASET_DIR / "h5" / "blade_meshes_train.h5"
|
||||
|
@ -51,8 +52,17 @@ class Rotor37(datasets.GeneratorBasedBuilder):
|
|||
|
||||
def _generate_examples(self, h5file: Path):
|
||||
with h5py.File(h5file, "r") as f:
|
||||
# compute mean and std of positions
|
||||
positions = np.asarray(f["points"])
|
||||
positions_mean = positions.mean(axis=(0, 1))
|
||||
positions_std = positions.std(axis=(0, 1))
|
||||
|
||||
# normalize positions
|
||||
positions = (positions - positions_mean) / positions_std
|
||||
|
||||
# zip attributes
|
||||
attributes = zip(
|
||||
f["points"], # type: ignore
|
||||
positions,
|
||||
f["normals"], # type: ignore
|
||||
f["output_fields"], # type: ignore
|
||||
)
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import datasets
|
||||
|
||||
train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
|
||||
train_ds = train_ds.with_format("torch")
|
||||
print(train_ds)
|
||||
|
||||
test_ds = datasets.load_dataset("dataset/rotor37_data.py", split="test")
|
||||
test_ds = test_ds.with_format("torch")
|
||||
print(test_ds)
|
||||
|
||||
train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
|
||||
train_ds = train_ds.with_format("torch")
|
||||
print(train_ds)
|
||||
|
||||
print("yay")
|
||||
|
|
|
@ -19,6 +19,7 @@ dependencies:
|
|||
- scipy
|
||||
- scikit-learn
|
||||
- pyvista
|
||||
- h5py
|
||||
- datasets
|
||||
#---# toolings
|
||||
- ruff
|
||||
|
|
|
@ -900,7 +900,7 @@ def parse_args():
|
|||
parser.add_argument("--gpu", default=None, type=int, help="GPU id to use. None means using all available GPUs.")
|
||||
|
||||
"""eval"""
|
||||
parser.add_argument("--saveIter", default=100, help="unit: epoch")
|
||||
parser.add_argument("--saveIter", default=100, type=int, help="unit: epoch")
|
||||
parser.add_argument("--diagIter", default=50, help="unit: epoch")
|
||||
parser.add_argument("--vizIter", default=50, help="unit: epoch")
|
||||
parser.add_argument("--print_freq", default=50, help="unit: iter")
|
||||
|
|
Loading…
Reference in a new issue