diff --git a/part_segmentation/data.py b/part_segmentation/data.py deleted file mode 100644 index a8fb553..0000000 --- a/part_segmentation/data.py +++ /dev/null @@ -1,185 +0,0 @@ - - -import os -import sys -import glob -import h5py -import numpy as np -import torch -from torch.utils.data import Dataset - - -# change this to your data root -DATA_DIR = 'data/' -os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" - -def download_modelnet40(): - if not os.path.exists(DATA_DIR): - os.mkdir(DATA_DIR) - if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')): - os.mkdir(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')) - www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip' - zipfile = os.path.basename(www) - os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile)) - os.system('mv %s %s' % (zipfile[:-4], DATA_DIR)) - os.system('rm %s' % (zipfile)) - - -def download_shapenetpart(): - if not os.path.exists(DATA_DIR): - os.mkdir(DATA_DIR) - if not os.path.exists(os.path.join(DATA_DIR)): - os.mkdir(os.path.join(DATA_DIR)) - www = 'https://shapenet.cs.stanford.edu/media/shapenet_part_seg_hdf5_data.zip' - zipfile = os.path.basename(www) - os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile)) - os.system('mv %s %s' % (zipfile[:-4], os.path.join(DATA_DIR))) - os.system('rm %s' % (zipfile)) - - -def load_data_normal(partition): - f = h5py.File(os.path.join(DATA_DIR, 'modelnet40_normal', 'normal_%s.h5'%partition), 'r+') - data = f['xyz'][:].astype('float32') - label = f['normal'][:].astype('float32') - f.close() - return data, label - - -def load_data_cls(partition): - download_modelnet40() - all_data = [] - all_label = [] - for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40*hdf5_2048', '*%s*.h5'%partition)): - f = h5py.File(h5_name, 'r+') - data = f['data'][:].astype('float32') - label = f['label'][:].astype('int64') - f.close() - all_data.append(data) - all_label.append(label) - all_data = np.concatenate(all_data, axis=0) - all_label = np.concatenate(all_label, axis=0) - return all_data, all_label - - -def load_data_partseg(partition): - download_shapenetpart() - all_data = [] - all_label = [] - all_seg = [] - if partition == 'trainval': - file = glob.glob(os.path.join(DATA_DIR, 'part_segmentation_data', '*train*.h5')) \ - + glob.glob(os.path.join(DATA_DIR, 'part_segmentation_data', '*val*.h5')) - else: - file = glob.glob(os.path.join(DATA_DIR, 'part_segmentation_data', '*%s*.h5'%partition)) - for h5_name in file: - f = h5py.File(h5_name, 'r+') - data = f['data'][:].astype('float32') - label = f['label'][:].astype('int64') - seg = f['pid'][:].astype('int64') - f.close() - all_data.append(data) - all_label.append(label) - all_seg.append(seg) - all_data = np.concatenate(all_data, axis=0) - all_label = np.concatenate(all_label, axis=0) - all_seg = np.concatenate(all_seg, axis=0) - return all_data, all_label, all_seg - - -def translate_pointcloud(pointcloud): - xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3]) - xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) - - translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') - return translated_pointcloud - - -def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02): - N, C = pointcloud.shape - pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip) - return pointcloud - - -def rotate_pointcloud(pointcloud): - theta = np.pi*2 * np.random.uniform() - rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]]) - pointcloud[:,[0,2]] = pointcloud[:,[0,2]].dot(rotation_matrix) # random rotation (x,z) - return pointcloud - - -class ModelNet40(Dataset): - def __init__(self, num_points, partition='train'): - self.data, self.label = load_data_cls(partition) - self.num_points = num_points - self.partition = partition - - def __getitem__(self, item): - pointcloud = self.data[item][:self.num_points] - label = self.label[item] - if self.partition == 'train': - pointcloud = translate_pointcloud(pointcloud) - #pointcloud = rotate_pointcloud(pointcloud) - np.random.shuffle(pointcloud) - return pointcloud, label - - def __len__(self): - return self.data.shape[0] - -class ModelNetNormal(Dataset): - def __init__(self, num_points, partition='train'): - self.data, self.label = load_data_normal(partition) - self.num_points = num_points - self.partition = partition - - def __getitem__(self, item): - pointcloud = self.data[item][:self.num_points] - label = self.label[item][:self.num_points] - if self.partition == 'train': - #pointcloud = translate_pointcloud(pointcloud) - idx = np.arange(0, pointcloud.shape[0], dtype=np.int64) - np.random.shuffle(idx) - pointcloud = self.data[item][idx] - label = self.label[item][idx] - return pointcloud, label - - def __len__(self): - return self.data.shape[0] - -class ShapeNetPart(Dataset): - def __init__(self, num_points=2048, partition='train', class_choice=None): - self.data, self.label, self.seg = load_data_partseg(partition) - self.cat2id = {'airplane': 0, 'bag': 1, 'cap': 2, 'car': 3, 'chair': 4, - 'earphone': 5, 'guitar': 6, 'knife': 7, 'lamp': 8, 'laptop': 9, - 'motor': 10, 'mug': 11, 'pistol': 12, 'rocket': 13, 'skateboard': 14, 'table': 15} - self.seg_num = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3] - self.index_start = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47] - self.num_points = num_points - self.partition = partition - self.class_choice = class_choice - - if self.class_choice != None: - id_choice = self.cat2id[self.class_choice] - indices = (self.label == id_choice).squeeze() - self.data = self.data[indices] - self.label = self.label[indices] - self.seg = self.seg[indices] - self.seg_num_all = self.seg_num[id_choice] - self.seg_start_index = self.index_start[id_choice] - else: - self.seg_num_all = 50 - self.seg_start_index = 0 - - def __getitem__(self, item): - pointcloud = self.data[item][:self.num_points] - label = self.label[item] - seg = self.seg[item][:self.num_points] - if self.partition == 'trainval': - pointcloud = translate_pointcloud(pointcloud) - indices = list(range(pointcloud.shape[0])) - np.random.shuffle(indices) - pointcloud = pointcloud[indices] - seg = seg[indices] - return pointcloud, label, seg - - def __len__(self): - return self.data.shape[0] diff --git a/part_segmentation/models/__init__.py b/part_segmentation/model/__init__.py similarity index 100% rename from part_segmentation/models/__init__.py rename to part_segmentation/model/__init__.py diff --git a/part_segmentation/models/pointMLP.py b/part_segmentation/model/pointMLP.py similarity index 98% rename from part_segmentation/models/pointMLP.py rename to part_segmentation/model/pointMLP.py index d5c5256..7790387 100644 --- a/part_segmentation/models/pointMLP.py +++ b/part_segmentation/model/pointMLP.py @@ -338,7 +338,7 @@ class PointMLP(nn.Module): self.stages = len(pre_blocks) self.class_num = num_classes self.points = points - self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation) + self.embedding = ConvBNReLU1D(6, embed_dim, bias=bias, activation=activation) assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \ "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers." self.local_grouper_list = nn.ModuleList() @@ -401,14 +401,14 @@ class PointMLP(nn.Module): self.classifier = nn.Sequential( nn.Conv1d(gmp_dim+cls_dim+de_dims[-1], 128, 1, bias=bias), nn.BatchNorm1d(128), - self.act, nn.Dropout(), nn.Conv1d(128, num_classes, 1, bias=bias) ) self.en_dims = en_dims - def forward(self, x, cls_label): + def forward(self, x, norm_plt, cls_label): xyz = x.permute(0, 2, 1) + x = torch.cat([x,norm_plt],dim=1) x = self.embedding(x) # B,D,N xyz_list = [xyz] # [B, N, 3] @@ -440,8 +440,8 @@ class PointMLP(nn.Module): cls_token = self.cls_map(cls_label.unsqueeze(dim=-1)) # [b, cls_dim, 1] x = torch.cat([x, global_context.repeat([1, 1, x.shape[-1]]), cls_token.repeat([1, 1, x.shape[-1]])], dim=1) x = self.classifier(x) - # x = F.log_softmax(x, dim=1) - # x = x.permute(0, 2, 1) + x = F.log_softmax(x, dim=1) + x = x.permute(0, 2, 1) return x @@ -459,6 +459,6 @@ if __name__ == '__main__': norm = torch.rand(2, 3, 2048) cls_label = torch.rand([2, 16]) print("===> testing modelD ...") - model = model31G(50) + model = pointMLP(50) out = model(data, cls_label) # [2,2048,50] print(out.shape) diff --git a/part_segmentation/util.py b/part_segmentation/util.py deleted file mode 100644 index e648bb2..0000000 --- a/part_segmentation/util.py +++ /dev/null @@ -1,38 +0,0 @@ - - -import numpy as np -import torch -import torch.nn.functional as F - - -def cal_loss(pred, gold, smoothing=True): - ''' Calculate cross entropy loss, apply label smoothing if needed. ''' - - gold = gold.contiguous().view(-1) - - if smoothing: - eps = 0.2 - n_class = pred.size(1) - - one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) - one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) - log_prb = F.log_softmax(pred, dim=1) - - loss = -(one_hot * log_prb).sum(dim=1).mean() - else: - loss = F.cross_entropy(pred, gold, reduction='mean') - - return loss - - -class IOStream(): - def __init__(self, path): - self.f = open(path, 'a') - - def cprint(self, text): - print(text) - self.f.write(text+'\n') - self.f.flush() - - def close(self): - self.f.close() diff --git a/part_segmentation/util/__init__.py b/part_segmentation/util/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/part_segmentation/util/data_util.py b/part_segmentation/util/data_util.py new file mode 100755 index 0000000..9725ab0 --- /dev/null +++ b/part_segmentation/util/data_util.py @@ -0,0 +1,164 @@ +import glob +import h5py +import numpy as np +from torch.utils.data import Dataset +import os +import json +os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" + + +def load_data(partition): + all_data = [] + all_label = [] + for h5_name in glob.glob('./data/modelnet40_ply_hdf5_2048/ply_data_%s*.h5' % partition): + f = h5py.File(h5_name) + data = f['data'][:].astype('float32') + label = f['label'][:].astype('int64') + f.close() + all_data.append(data) + all_label.append(label) + all_data = np.concatenate(all_data, axis=0) + all_label = np.concatenate(all_label, axis=0) + return all_data, all_label + + +def pc_normalize(pc): + centroid = np.mean(pc, axis=0) + pc = pc - centroid + m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) + pc = pc / m + return pc + + +def translate_pointcloud(pointcloud): + xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3]) + xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) + + translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') + return translated_pointcloud + + +def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02): + N, C = pointcloud.shape + pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip) + return pointcloud + + +# =========== ModelNet40 ================= +class ModelNet40(Dataset): + def __init__(self, num_points, partition='train'): + self.data, self.label = load_data(partition) + self.num_points = num_points + self.partition = partition # Here the new given partition will cover the 'train' + + def __getitem__(self, item): # indice of the pts or label + pointcloud = self.data[item][:self.num_points] + label = self.label[item] + if self.partition == 'train': + # pointcloud = pc_normalize(pointcloud) # you can try to add it or not to train our model + pointcloud = translate_pointcloud(pointcloud) + np.random.shuffle(pointcloud) # shuffle the order of pts + return pointcloud, label + + def __len__(self): + return self.data.shape[0] + + +# =========== ShapeNet Part ================= +class PartNormalDataset(Dataset): + def __init__(self, npoints=2500, split='train', normalize=False): + self.npoints = npoints + self.root = './data/shapenetcore_partanno_segmentation_benchmark_v0_normal' + self.catfile = os.path.join(self.root, 'synsetoffset2category.txt') + self.cat = {} + self.normalize = normalize + + with open(self.catfile, 'r') as f: + for line in f: + ls = line.strip().split() + self.cat[ls[0]] = ls[1] + self.cat = {k: v for k, v in self.cat.items()} + + self.meta = {} + with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f: + train_ids = set([str(d.split('/')[2]) for d in json.load(f)]) + with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f: + val_ids = set([str(d.split('/')[2]) for d in json.load(f)]) + with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f: + test_ids = set([str(d.split('/')[2]) for d in json.load(f)]) + for item in self.cat: + self.meta[item] = [] + dir_point = os.path.join(self.root, self.cat[item]) + fns = sorted(os.listdir(dir_point)) + + if split == 'trainval': + fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] + elif split == 'train': + fns = [fn for fn in fns if fn[0:-4] in train_ids] + elif split == 'val': + fns = [fn for fn in fns if fn[0:-4] in val_ids] + elif split == 'test': + fns = [fn for fn in fns if fn[0:-4] in test_ids] + else: + print('Unknown split: %s. Exiting..' % (split)) + exit(-1) + + for fn in fns: + token = (os.path.splitext(os.path.basename(fn))[0]) + self.meta[item].append(os.path.join(dir_point, token + '.txt')) + + self.datapath = [] + for item in self.cat: + for fn in self.meta[item]: + self.datapath.append((item, fn)) + + self.classes = dict(zip(self.cat, range(len(self.cat)))) + # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels + self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], + 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], + 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], + 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], + 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} + + self.cache = {} # from index to (point_set, cls, seg) tuple + self.cache_size = 20000 + + def __getitem__(self, index): + if index in self.cache: + point_set, normal, seg, cls = self.cache[index] + else: + fn = self.datapath[index] + cat = self.datapath[index][0] + cls = self.classes[cat] + cls = np.array([cls]).astype(np.int32) + data = np.loadtxt(fn[1]).astype(np.float32) + point_set = data[:, 0:3] + normal = data[:, 3:6] + seg = data[:, -1].astype(np.int32) + if len(self.cache) < self.cache_size: + self.cache[index] = (point_set, normal, seg, cls) + + if self.normalize: + point_set = pc_normalize(point_set) + + choice = np.random.choice(len(seg), self.npoints, replace=True) + + # resample + # note that the number of points in some points clouds is less than 2048, thus use random.choice + # remember to use the same seed during train and test for a getting stable result + point_set = point_set[choice, :] + seg = seg[choice] + normal = normal[choice, :] + + return point_set, cls, seg, normal + + def __len__(self): + return len(self.datapath) + + +if __name__ == '__main__': + train = PartNormalDataset(npoints=2048, split='trainval', normalize=False) + test = PartNormalDataset(npoints=2048, split='test', normalize=False) + for data, label, _, _ in train: + print(data.shape) + print(label.shape) diff --git a/part_segmentation/util/util.py b/part_segmentation/util/util.py new file mode 100755 index 0000000..00afdd8 --- /dev/null +++ b/part_segmentation/util/util.py @@ -0,0 +1,69 @@ +import numpy as np +import torch +import torch.nn.functional as F + + +def cal_loss(pred, gold, smoothing=True): + ''' Calculate cross entropy loss, apply label smoothing if needed. ''' + + gold = gold.contiguous().view(-1) # gold is the groudtruth label in the dataloader + + if smoothing: + eps = 0.2 + n_class = pred.size(1) # the number of feature_dim of the ouput, which is output channels + + one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) + one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) + log_prb = F.log_softmax(pred, dim=1) + + loss = -(one_hot * log_prb).sum(dim=1).mean() + else: + loss = F.cross_entropy(pred, gold, reduction='mean') + + return loss + + +# create a file and write the text into it: +class IOStream(): + def __init__(self, path): + self.f = open(path, 'a') + + def cprint(self, text): + print(text) + self.f.write(text+'\n') + self.f.flush() + + def close(self): + self.f.close() + + +def to_categorical(y, num_classes): + """ 1-hot encodes a tensor """ + new_y = torch.eye(num_classes)[y.cpu().data.numpy(),] + if (y.is_cuda): + return new_y.cuda(non_blocking=True) + return new_y + + +def compute_overall_iou(pred, target, num_classes): + shape_ious = [] + pred = pred.max(dim=2)[1] # (batch_size, num_points) the pred_class_idx of each point in each sample + pred_np = pred.cpu().data.numpy() + + target_np = target.cpu().data.numpy() + for shape_idx in range(pred.size(0)): # sample_idx + part_ious = [] + for part in range(num_classes): # class_idx! no matter which category, only consider all part_classes of all categories, check all 50 classes + # for target, each point has a class no matter which category owns this point! also 50 classes!!! + # only return 1 when both belongs to this class, which means correct: + I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part)) + # always return 1 when either is belongs to this class: + U = np.sum(np.logical_or(pred_np[shape_idx] == part, target_np[shape_idx] == part)) + + F = np.sum(target_np[shape_idx] == part) + + if F != 0: + iou = I / float(U) # iou across all points for this class + part_ious.append(iou) # append the iou of this class + shape_ious.append(np.mean(part_ious)) # each time append an average iou across all classes of this sample (sample_level!) + return shape_ious # [batch_size]