fix: compute mean and std using entire dataset (not inside splits)

Laurent FAINSIN 2023-04-12 17:19:12 +02:00
parent 3d0b8f8620
commit ce88fc9f88
5 changed files with 44 additions and 42 deletions
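
The hardcoded MEAN and STD introduced below are per-axis statistics over all point positions. A minimal sketch of how they could be recomputed over the full dataset, assuming the HDF5 layout used in _generate_examples (a "points" array of shape (num_clouds, 29773, 3)); the file names here are placeholders, not paths from the repository:

import h5py
import numpy as np

points = []
for h5file in ("rotor37_train.h5", "rotor37_test.h5"):  # hypothetical paths
    with h5py.File(h5file, "r") as f:
        points.append(np.asarray(f["points"]))

points = np.concatenate(points, axis=0)    # (num_clouds, 29773, 3)
print("MEAN =", points.mean(axis=(0, 1)))  # per-axis mean over every point of every cloud
print("STD  =", points.std(axis=(0, 1)))   # per-axis std over every point of every cloud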

.gitignore (vendored): 2 changes

@@ -12,6 +12,8 @@ eval_data
 ShapeNetCore.v2.PC15k*
 checkpoints
+*.txt
+
 # https://github.com/github/gitignore/blob/main/Python.gitignore
 # Basic .gitignore for a python repo.

.vscode/launch.json (vendored): 10 changes

@@ -12,8 +12,14 @@
             "console": "integratedTerminal",
             "justMyCode": true,
             "args": [
-                "--category",
-                "car",
+                "--model",
+                "output/train_generation/2023-04-11-23-38-23/epoch_99.pth",
+                "--generate",
+                "True",
+                "--workers",
+                "4",
+                "--batch_size",
+                "6"
             ]
         }
     ]
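
The new launch configuration drives a generation run from a trained checkpoint instead of a category-conditioned run. A hypothetical argparse sketch consistent with these flags; the actual parser in the repository may differ, and str2bool is an assumption used because "--generate True" arrives as a string:

import argparse

def str2bool(value: str) -> bool:
    # "--generate True" is passed as a string, so convert it explicitly
    return str(value).lower() in ("true", "t", "yes", "1")

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, help="checkpoint to load, e.g. epoch_99.pth")
parser.add_argument("--generate", type=str2bool, default=False, help="run sample generation")
parser.add_argument("--workers", type=int, default=4, help="dataloader worker processes")
parser.add_argument("--batch_size", type=int, default=6, help="point clouds per batch")
opt = parser.parse_args()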

dataset/rotor37_data.py

@@ -17,6 +17,9 @@ The dataset is split into 2 subsets: train and test, with 1000 and 200 clouds respectively.
 Each pointcloud has 29773 points, each point has 3D coordinates, 3D normals and physical properties.
 """

+MEAN = np.array([0.01994637, 0.2205227, -0.00095343])
+STD = np.array([0.01270086, 0.0280048, 0.01615675])
+

 class Rotor37(datasets.GeneratorBasedBuilder):
     """Rotor37 dataset."""

@@ -52,13 +55,9 @@ class Rotor37(datasets.GeneratorBasedBuilder):
     def _generate_examples(self, h5file: Path):
         with h5py.File(h5file, "r") as f:
-            # compute mean and std of positions
-            positions = np.asarray(f["points"])
-            positions_mean = positions.mean(axis=(0, 1))
-            positions_std = positions.std(axis=(0, 1))
-
             # normalize positions
-            positions = (positions - positions_mean) / positions_std
+            positions = np.asarray(f["points"])
+            positions = (positions - MEAN) / STD

             # zip attributes
             attributes = zip(
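
With the statistics now fixed at module level, both splits are normalized with the same constants. A quick sanity-check sketch (not part of the commit), assuming MEAN and STD were computed over train and test combined, so the pooled positions should come out roughly zero-mean and unit-std per axis:

import datasets
import numpy as np

train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
test_ds = datasets.load_dataset("dataset/rotor37_data.py", split="test")

pos = np.concatenate(
    [np.asarray(train_ds["positions"]), np.asarray(test_ds["positions"])],
    axis=0,
)
print(pos.mean(axis=(0, 1)))  # expected close to [0, 0, 0]
print(pos.std(axis=(0, 1)))   # expected close to [1, 1, 1]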

(dataset loading test script)

@@ -1,4 +1,5 @@
 import datasets
+import numpy as np

 test_ds = datasets.load_dataset("dataset/rotor37_data.py", split="test")
 test_ds = test_ds.with_format("torch")

@@ -8,4 +9,15 @@ train_ds = datasets.load_dataset("dataset/rotor37_data.py", split="train")
 train_ds = train_ds.with_format("torch")
 print(train_ds)

-print("yay")
+# save pointcloud to txt for paraview viz
+for idx, blade in enumerate(test_ds):
+    pc = blade["positions"]
+
+    # unnormalize
+    pc = pc * blade["std"] + blade["mean"]
+
+    print(f"Saving point cloud {idx}...")
+    np.savetxt(f"pc_{idx}.txt", pc)
+
+    if idx >= 10:
+        break
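
The loop above writes each cloud as plain whitespace-delimited x y z rows for the ParaView visualization mentioned in the comment. A small reload sketch to check what was written, where pc_0.txt is the first file produced by that loop:

import numpy as np

pc = np.loadtxt("pc_0.txt")  # produced by the loop above
print(pc.shape)              # expected (29773, 3), one row per point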

(generation and evaluation script)

@@ -8,6 +8,8 @@ import torch.utils.data
 from torch.distributions import Normal
 from tqdm import tqdm

+from dataset.rotor37_data import MEAN, STD
+
 # from dataset.shapenet_data_pc import ShapeNet15kPointClouds
 from metrics.evaluation_metrics import compute_all_metrics
 from metrics.evaluation_metrics import jsd_between_point_cloud_sets as JSD

@@ -468,10 +470,11 @@ def evaluate_gen(opt, ref_pcs, logger):
     )

     ref = []
     for data in tqdm(test_dataloader, total=len(test_dataloader), desc="Generating Samples"):
-        x = data["test_points"]
-        m, s = data["mean"].float(), data["std"].float()
-        ref.append(x * s + m)
+        x = data["positions"]
+        # m, s = data["mean"].float(), data["std"].float()
+        # ref.append(x * s + m)
+        ref.append(x)

     ref_pcs = torch.cat(ref, dim=0).contiguous()
@@ -517,38 +520,18 @@ def generate(model, opt):
         samples.append(gen)
         ref.append(x)

-        # visualize_pointcloud_batch(
-        #     os.path.join(str(Path(opt.eval_path).parent), 'x.png'),
-        #     gen[:64],
-        #     None,
-        #     None,
-        #     None
-        # )
-        # visualize_voxels(
-        #     os.path.join(str(Path(opt.eval_path).parent), 'x.png'),
-        #     gen[:64],
-        #     1,
-        #     0.5,
-        # )
-
-        # visualize using matplotlib
-        import matplotlib
-        import matplotlib.cm as cm
-        import matplotlib.pyplot as plt
-
-        matplotlib.use("TkAgg")
-
-        for idx, pc in enumerate(gen[:64]):
-            print(f"Visualizing point cloud {idx}...")
-            fig = plt.figure()
-            ax = fig.add_subplot(111, projection="3d")
-            ax.scatter(pc[:, 0], pc[:, 1], pc[:, 2], c=pc[:, 2], cmap=cm.jet)
-            ax.set_aspect("equal")
-            ax.axis("off")
-            # ax.set_xlim(-1, 1)
-            # ax.set_ylim(-1, 1)
-            # ax.set_zlim(-1, 1)
-            plt.show()
-            plt.close()
+        # save pointcloud to txt for paraview viz
+        for idx, blade in enumerate(gen):
+            pc = blade
+
+            # unnormalize
+            pc = pc * STD + MEAN
+
+            print(f"Saving point cloud {idx}...")
+            np.savetxt(f"gen_{i}_{idx}.txt", pc)
+
+            if idx >= 10:
+                break

     samples = torch.cat(samples, dim=0)
     ref = torch.cat(ref, dim=0)
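
In the new code, generated samples are unnormalized with the global constants before being saved, while evaluate_gen now compares clouds in normalized space. A sketch of the same unnormalization written with explicit torch tensors (a convenience wrapper, not the repository's code), where `batch` stands for a (B, N, 3) batch such as `gen`:

import torch
from dataset.rotor37_data import MEAN, STD

_mean = torch.as_tensor(MEAN, dtype=torch.float32)  # shape (3,)
_std = torch.as_tensor(STD, dtype=torch.float32)    # shape (3,)

def unnormalize(batch: torch.Tensor) -> torch.Tensor:
    # broadcast the per-axis constants over a (B, N, 3) batch on CPU
    return batch.detach().cpu() * _std + _mean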