fix: compute mean and std using entire dataset (not inside splits)
This commit is contained in:
parent
3d0b8f8620
commit
ce88fc9f88
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -12,6 +12,8 @@ eval_data
|
||||||
ShapeNetCore.v2.PC15k*
|
ShapeNetCore.v2.PC15k*
|
||||||
checkpoints
|
checkpoints
|
||||||
|
|
||||||
|
*.txt
|
||||||
|
|
||||||
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
||||||
# Basic .gitignore for a python repo.
|
# Basic .gitignore for a python repo.
|
||||||
|
|
||||||
|
|
10
.vscode/launch.json
vendored
10
.vscode/launch.json
vendored
|
@ -12,8 +12,14 @@
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"justMyCode": true,
|
"justMyCode": true,
|
||||||
"args": [
|
"args": [
|
||||||
"--category",
|
"--model",
|
||||||
"car",
|
"output/train_generation/2023-04-11-23-38-23/epoch_99.pth",
|
||||||
|
"--generate",
|
||||||
|
"True",
|
||||||
|
"--workers",
|
||||||
|
"4",
|
||||||
|
"--batch_size",
|
||||||
|
"6"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
|
@ -17,6 +17,9 @@ The dataset is split into 2 subsets: train and test, with 1000 and 200 clouds re
|
||||||
Each pointcloud has 29773 points, each point has 3D coordinates, 3D normals and physical properties.
|
Each pointcloud has 29773 points, each point has 3D coordinates, 3D normals and physical properties.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
MEAN = np.array([0.01994637, 0.2205227, -0.00095343])
|
||||||
|
STD = np.array([0.01270086, 0.0280048, 0.01615675])
|
||||||
|
|
||||||
|
|
||||||
class Rotor37(datasets.GeneratorBasedBuilder):
|
class Rotor37(datasets.GeneratorBasedBuilder):
|
||||||
"""Rotor37 dataset."""
|
"""Rotor37 dataset."""
|
||||||
|
@ -52,13 +55,9 @@ class Rotor37(datasets.GeneratorBasedBuilder):
|
||||||
|
|
||||||
def _generate_examples(self, h5file: Path):
|
def _generate_examples(self, h5file: Path):
|
||||||
with h5py.File(h5file, "r") as f:
|
with h5py.File(h5file, "r") as f:
|
||||||
# compute mean and std of positions
|
|
||||||
positions = np.asarray(f["points"])
|
|
||||||
positions_mean = positions.mean(axis=(0, 1))
|
|
||||||
positions_std = positions.std(axis=(0, 1))
|
|
||||||
|
|
||||||
# normalize positions
|
# normalize positions
|
||||||
positions = (positions - positions_mean) / positions_std
|
positions = np.asarray(f["points"])
|
||||||
|
positions = (positions - MEAN) / STD
|
||||||
|
|
||||||
# zip attributes
|
# zip attributes
|
||||||
attributes = zip(
|
attributes = zip(
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import datasets
|
import datasets
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
test_ds = datasets.load_dataset("dataset/rotor37_data.py", split="test")
|
test_ds = datasets.load_dataset("dataset/rotor37_data.py", split="test")
|
||||||
test_ds = test_ds.with_format("torch")
|
test_ds = test_ds.with_format("torch")
|
||||||
|
@ -8,4 +9,15 @@ train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
|
||||||
train_ds = train_ds.with_format("torch")
|
train_ds = train_ds.with_format("torch")
|
||||||
print(train_ds)
|
print(train_ds)
|
||||||
|
|
||||||
print("yay")
|
# save pointcloud to txt for paraview viz
|
||||||
|
for idx, blade in enumerate(test_ds):
|
||||||
|
pc = blade["positions"]
|
||||||
|
|
||||||
|
# unnormalize
|
||||||
|
pc = pc * blade["std"] + blade["mean"]
|
||||||
|
|
||||||
|
print(f"Saving point cloud {idx}...")
|
||||||
|
np.savetxt(f"pc_{idx}.txt", pc)
|
||||||
|
|
||||||
|
if idx >= 10:
|
||||||
|
break
|
||||||
|
|
|
@ -8,6 +8,8 @@ import torch.utils.data
|
||||||
from torch.distributions import Normal
|
from torch.distributions import Normal
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from dataset.rotor37_data import MEAN, STD
|
||||||
|
|
||||||
# from dataset.shapenet_data_pc import ShapeNet15kPointClouds
|
# from dataset.shapenet_data_pc import ShapeNet15kPointClouds
|
||||||
from metrics.evaluation_metrics import compute_all_metrics
|
from metrics.evaluation_metrics import compute_all_metrics
|
||||||
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
|
from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD
|
||||||
|
@ -468,10 +470,11 @@ def evaluate_gen(opt, ref_pcs, logger):
|
||||||
)
|
)
|
||||||
ref = []
|
ref = []
|
||||||
for data in tqdm(test_dataloader, total=len(test_dataloader), desc="Generating Samples"):
|
for data in tqdm(test_dataloader, total=len(test_dataloader), desc="Generating Samples"):
|
||||||
x = data["test_points"]
|
x = data["positions"]
|
||||||
m, s = data["mean"].float(), data["std"].float()
|
# m, s = data["mean"].float(), data["std"].float()
|
||||||
|
|
||||||
ref.append(x * s + m)
|
# ref.append(x * s + m)
|
||||||
|
ref.append(x)
|
||||||
|
|
||||||
ref_pcs = torch.cat(ref, dim=0).contiguous()
|
ref_pcs = torch.cat(ref, dim=0).contiguous()
|
||||||
|
|
||||||
|
@ -517,38 +520,18 @@ def generate(model, opt):
|
||||||
samples.append(gen)
|
samples.append(gen)
|
||||||
ref.append(x)
|
ref.append(x)
|
||||||
|
|
||||||
# visualize_pointcloud_batch(
|
# save pointcloud to txt for paraview viz
|
||||||
# os.path.join(str(Path(opt.eval_path).parent), 'x.png'),
|
for idx, blade in enumerate(gen):
|
||||||
# gen[:64],
|
pc = blade
|
||||||
# None,
|
|
||||||
# None,
|
|
||||||
# None
|
|
||||||
# )
|
|
||||||
# visualize_voxels(
|
|
||||||
# os.path.join(str(Path(opt.eval_path).parent), 'x.png'),
|
|
||||||
# gen[:64],
|
|
||||||
# 1,
|
|
||||||
# 0.5,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# visualize using matplotlib
|
# unnormalize
|
||||||
import matplotlib
|
pc = pc * STD + MEAN
|
||||||
import matplotlib.cm as cm
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
matplotlib.use("TkAgg")
|
print(f"Saving point cloud {idx}...")
|
||||||
for idx, pc in enumerate(gen[:64]):
|
np.savetxt(f"gen_{i}_{idx}.txt", pc)
|
||||||
print(f"Visualizing point cloud {idx}...")
|
|
||||||
fig = plt.figure()
|
if idx >= 10:
|
||||||
ax = fig.add_subplot(111, projection="3d")
|
break
|
||||||
ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], c=pc[:, 2], cmap=cm.jet)
|
|
||||||
ax.set_aspect("equal")
|
|
||||||
ax.axis("off")
|
|
||||||
# ax.set_xlim(-1, 1)
|
|
||||||
# ax.set_ylim(-1, 1)
|
|
||||||
# ax.set_zlim(-1, 1)
|
|
||||||
plt.show()
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
samples = torch.cat(samples, dim=0)
|
samples = torch.cat(samples, dim=0)
|
||||||
ref = torch.cat(ref, dim=0)
|
ref = torch.cat(ref, dim=0)
|
||||||
|
|
Loading…
Reference in a new issue