fix: compute mean and std using entire dataset (not inside splits)
This commit is contained in:
parent
3d0b8f8620
commit
ce88fc9f88
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -12,6 +12,8 @@ eval_data
|
|||
ShapeNetCore.v2.PC15k*
|
||||
checkpoints
|
||||
|
||||
*.txt
|
||||
|
||||
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
||||
# Basic .gitignore for a python repo.
|
||||
|
||||
|
|
10
.vscode/launch.json
vendored
10
.vscode/launch.json
vendored
|
@ -12,8 +12,14 @@
|
|||
"console": "integratedTerminal",
|
||||
"justMyCode": true,
|
||||
"args": [
|
||||
"--category",
|
||||
"car",
|
||||
"--model",
|
||||
"output/train_generation/2023-04-11-23-38-23/epoch_99.pth",
|
||||
"--generate",
|
||||
"True",
|
||||
"--workers",
|
||||
"4",
|
||||
"--batch_size",
|
||||
"6"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
|
@ -17,6 +17,9 @@ The dataset is split into 2 subsets: train and test, with 1000 and 200 clouds re
|
|||
Each pointcloud has 29773 points, each point has 3D coordinates, 3D normals and physical properties.
|
||||
"""
|
||||
|
||||
MEAN = np.array([0.01994637, 0.2205227, -0.00095343])
|
||||
STD = np.array([0.01270086, 0.0280048, 0.01615675])
|
||||
|
||||
|
||||
class Rotor37(datasets.GeneratorBasedBuilder):
|
||||
"""Rotor37 dataset."""
|
||||
|
@ -52,13 +55,9 @@ class Rotor37(datasets.GeneratorBasedBuilder):
|
|||
|
||||
def _generate_examples(self, h5file: Path):
|
||||
with h5py.File(h5file, "r") as f:
|
||||
# compute mean and std of positions
|
||||
positions = np.asarray(f["points"])
|
||||
positions_mean = positions.mean(axis=(0, 1))
|
||||
positions_std = positions.std(axis=(0, 1))
|
||||
|
||||
# normalize positions
|
||||
positions = (positions - positions_mean) / positions_std
|
||||
positions = np.asarray(f["points"])
|
||||
positions = (positions - MEAN) / STD
|
||||
|
||||
# zip attributes
|
||||
attributes = zip(
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import datasets
|
||||
import numpy as np
|
||||
|
||||
test_ds = datasets.load_dataset("dataset/rotor37_data.py", split="test")
|
||||
test_ds = test_ds.with_format("torch")
|
||||
|
@ -8,4 +9,15 @@ train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
|
|||
train_ds = train_ds.with_format("torch")
|
||||
print(train_ds)
|
||||
|
||||
print("yay")
|
||||
# save pointcloud to txt for paraview viz
|
||||
for idx, blade in enumerate(test_ds):
|
||||
pc = blade["positions"]
|
||||
|
||||
# unnormalize
|
||||
pc = pc * blade["std"] + blade["mean"]
|
||||
|
||||
print(f"Saving point cloud {idx}...")
|
||||
np.savetxt(f"pc_{idx}.txt", pc)
|
||||
|
||||
if idx >= 10:
|
||||
break
|
||||
|
|
|
@ -8,6 +8,8 @@ import torch.utils.data
|
|||
from torch.distributions import Normal
|
||||
from tqdm import tqdm
|
||||
|
||||
from dataset.rotor37_data import MEAN, STD
|
||||
|
||||
# from dataset.shapenet_data_pc import ShapeNet15kPointClouds
|
||||
from metrics.evaluation_metrics import compute_all_metrics
|
||||
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
|
||||
|
@ -468,10 +470,11 @@ def evaluate_gen(opt, ref_pcs, logger):
|
|||
)
|
||||
ref = []
|
||||
for data in tqdm(test_dataloader, total=len(test_dataloader), desc="Generating Samples"):
|
||||
x = data["test_points"]
|
||||
m, s = data["mean"].float(), data["std"].float()
|
||||
x = data["positions"]
|
||||
# m, s = data["mean"].float(), data["std"].float()
|
||||
|
||||
ref.append(x * s + m)
|
||||
# ref.append(x * s + m)
|
||||
ref.append(x)
|
||||
|
||||
ref_pcs = torch.cat(ref, dim=0).contiguous()
|
||||
|
||||
|
@ -517,38 +520,18 @@ def generate(model, opt):
|
|||
samples.append(gen)
|
||||
ref.append(x)
|
||||
|
||||
# visualize_pointcloud_batch(
|
||||
# os.path.join(str(Path(opt.eval_path).parent), 'x.png'),
|
||||
# gen[:64],
|
||||
# None,
|
||||
# None,
|
||||
# None
|
||||
# )
|
||||
# visualize_voxels(
|
||||
# os.path.join(str(Path(opt.eval_path).parent), 'x.png'),
|
||||
# gen[:64],
|
||||
# 1,
|
||||
# 0.5,
|
||||
# )
|
||||
# save pointcloud to txt for paraview viz
|
||||
for idx, blade in enumerate(gen):
|
||||
pc = blade
|
||||
|
||||
# visualize using matplotlib
|
||||
import matplotlib
|
||||
import matplotlib.cm as cm
|
||||
import matplotlib.pyplot as plt
|
||||
# unnormalize
|
||||
pc = pc * STD + MEAN
|
||||
|
||||
matplotlib.use("TkAgg")
|
||||
for idx, pc in enumerate(gen[:64]):
|
||||
print(f"Visualizing point cloud {idx}...")
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111, projection="3d")
|
||||
ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], c=pc[:, 2], cmap=cm.jet)
|
||||
ax.set_aspect("equal")
|
||||
ax.axis("off")
|
||||
# ax.set_xlim(-1, 1)
|
||||
# ax.set_ylim(-1, 1)
|
||||
# ax.set_zlim(-1, 1)
|
||||
plt.show()
|
||||
plt.close()
|
||||
print(f"Saving point cloud {idx}...")
|
||||
np.savetxt(f"gen_{i}_{idx}.txt", pc)
|
||||
|
||||
if idx >= 10:
|
||||
break
|
||||
|
||||
samples = torch.cat(samples, dim=0)
|
||||
ref = torch.cat(ref, dim=0)
|
||||
|
|
Loading…
Reference in a new issue