""" Usage: python main.py --model CurveNet --exp_name=demo1 @Author: An Tao @Contact: ta19@mails.tsinghua.edu.cn @File: main_partseg.py @Time: 2019/12/31 11:17 AM Modified by @Author: Tiange Xiang @Contact: txia7609@uni.sydney.edu.au @Time: 2021/01/21 3:10 PM """ from __future__ import print_function import os import datetime import argparse import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, MultiStepLR from data import ShapeNetPart import models as models import numpy as np from torch.utils.data import DataLoader from util import cal_loss, IOStream import sklearn.metrics as metrics seg_num = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3] index_start = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47] def _init_(): # fix random seed if args.seed is not None: torch.manual_seed(args.seed) np.random.seed(args.seed) torch.cuda.manual_seed_all(args.seed) torch.cuda.manual_seed(args.seed) torch.set_printoptions(10) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True os.environ['PYTHONHASHSEED'] = str(args.seed) # prepare file structures if not os.path.exists('checkpoints'): os.makedirs('checkpoints') if not os.path.exists('checkpoints/'+args.exp_name): os.makedirs('checkpoints/'+args.exp_name) if not os.path.exists('checkpoints/'+args.exp_name+'/'+'models'): os.makedirs('checkpoints/'+args.exp_name+'/'+'models') def calculate_shape_IoU(pred_np, seg_np, label, class_choice, eva=False): label = label.squeeze() shape_ious = [] category = {} for shape_idx in range(seg_np.shape[0]): if not class_choice: start_index = index_start[label[shape_idx]] num = seg_num[label[shape_idx]] parts = range(start_index, start_index + num) else: parts = range(seg_num[label[0]]) part_ious = [] for part in parts: I = np.sum(np.logical_and(pred_np[shape_idx] == part, seg_np[shape_idx] == part)) U = np.sum(np.logical_or(pred_np[shape_idx] == part, seg_np[shape_idx] == part)) if U == 0: iou = 1 # If the union of groundtruth and prediction points is empty, then count part IoU as 1 else: iou = I / float(U) part_ious.append(iou) shape_ious.append(np.mean(part_ious)) if label[shape_idx] not in category: category[label[shape_idx]] = [shape_ious[-1]] else: category[label[shape_idx]].append(shape_ious[-1]) if eva: return shape_ious, category else: return shape_ious def train(args, io): train_dataset = ShapeNetPart(partition='trainval', num_points=args.num_points, class_choice=args.class_choice) if (len(train_dataset) < 100): drop_last = False else: drop_last = True train_loader = DataLoader(train_dataset, num_workers=8, batch_size=args.batch_size, shuffle=True, drop_last=drop_last, pin_memory=True) test_loader = DataLoader(ShapeNetPart(partition='test', num_points=args.num_points, class_choice=args.class_choice), num_workers=8, batch_size=args.test_batch_size, shuffle=False, drop_last=False, pin_memory=True) device = torch.device("cuda" if args.cuda else "cpu") io.cprint("Let's use " + str(torch.cuda.device_count()) + " GPUs!") seg_num_all = train_loader.dataset.seg_num_all seg_start_index = train_loader.dataset.seg_start_index # create model model = models.__dict__[args.model]().to(device) io.cprint(str(model)) model = nn.DataParallel(model) if args.use_sgd: print("Use SGD") opt = optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4) else: print("Use Adam") opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4) if args.scheduler == 'cos': if args.use_sgd: eta_min = args.lr/5.0 else: eta_min = args.lr/100.0 scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=eta_min) elif args.scheduler == 'step': scheduler = MultiStepLR(opt, [140, 180], gamma=0.1) criterion = cal_loss best_test_iou = 0 for epoch in range(args.epochs): #################### # Train #################### train_time_cost = datetime.datetime.now() train_loss = 0.0 count = 0.0 model.train() train_true_cls = [] train_pred_cls = [] train_true_seg = [] train_pred_seg = [] train_label_seg = [] for data, label, seg in train_loader: seg = seg - seg_start_index label_one_hot = np.zeros((label.shape[0], 16)) for idx in range(label.shape[0]): label_one_hot[idx, label[idx]] = 1 label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32)) data, label_one_hot, seg = data.to(device), label_one_hot.to(device), seg.to(device) data = data.permute(0, 2, 1) batch_size = data.size()[0] opt.zero_grad() seg_pred = model(data, label_one_hot) seg_pred = seg_pred.permute(0, 2, 1).contiguous() loss = criterion(seg_pred.view(-1, seg_num_all), seg.view(-1,1).squeeze()) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1) opt.step() pred = seg_pred.max(dim=2)[1] # (batch_size, num_points) count += batch_size train_loss += loss.item() * batch_size seg_np = seg.cpu().numpy() # (batch_size, num_points) pred_np = pred.detach().cpu().numpy() # (batch_size, num_points) train_true_cls.append(seg_np.reshape(-1)) # (batch_size * num_points) train_pred_cls.append(pred_np.reshape(-1)) # (batch_size * num_points) train_true_seg.append(seg_np) train_pred_seg.append(pred_np) train_label_seg.append(label.reshape(-1)) if args.scheduler == 'cos': scheduler.step() elif args.scheduler == 'step': if opt.param_groups[0]['lr'] > 1e-5: scheduler.step() if opt.param_groups[0]['lr'] < 1e-5: for param_group in opt.param_groups: param_group['lr'] = 1e-5 train_true_cls = np.concatenate(train_true_cls) train_pred_cls = np.concatenate(train_pred_cls) train_acc = metrics.accuracy_score(train_true_cls, train_pred_cls) avg_per_class_acc = metrics.balanced_accuracy_score(train_true_cls, train_pred_cls) train_true_seg = np.concatenate(train_true_seg, axis=0) train_pred_seg = np.concatenate(train_pred_seg, axis=0) train_label_seg = np.concatenate(train_label_seg) train_ious = calculate_shape_IoU(train_pred_seg, train_true_seg, train_label_seg, args.class_choice) train_time_cost = int((datetime.datetime.now() - train_time_cost).total_seconds()) outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f, train iou: %.6f' % (epoch, train_loss*1.0/count, train_acc, avg_per_class_acc, np.mean(train_ious)) io.cprint(outstr) io.cprint(f"Training time: {train_time_cost} seconds.") #################### # Test #################### test_time_cost = datetime.datetime.now() test_loss = 0.0 count = 0.0 model.eval() test_true_cls = [] test_pred_cls = [] test_true_seg = [] test_pred_seg = [] test_label_seg = [] for data, label, seg in test_loader: seg = seg - seg_start_index label_one_hot = np.zeros((label.shape[0], 16)) for idx in range(label.shape[0]): label_one_hot[idx, label[idx]] = 1 label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32)) data, label_one_hot, seg = data.to(device), label_one_hot.to(device), seg.to(device) data = data.permute(0, 2, 1) batch_size = data.size()[0] seg_pred = model(data, label_one_hot) seg_pred = seg_pred.permute(0, 2, 1).contiguous() loss = criterion(seg_pred.view(-1, seg_num_all), seg.view(-1,1).squeeze()) pred = seg_pred.max(dim=2)[1] count += batch_size test_loss += loss.item() * batch_size seg_np = seg.cpu().numpy() pred_np = pred.detach().cpu().numpy() test_true_cls.append(seg_np.reshape(-1)) test_pred_cls.append(pred_np.reshape(-1)) test_true_seg.append(seg_np) test_pred_seg.append(pred_np) test_label_seg.append(label.reshape(-1)) test_true_cls = np.concatenate(test_true_cls) test_pred_cls = np.concatenate(test_pred_cls) test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls) avg_per_class_acc = metrics.balanced_accuracy_score(test_true_cls, test_pred_cls) test_true_seg = np.concatenate(test_true_seg, axis=0) test_pred_seg = np.concatenate(test_pred_seg, axis=0) test_label_seg = np.concatenate(test_label_seg) test_ious = calculate_shape_IoU(test_pred_seg, test_true_seg, test_label_seg, args.class_choice) test_time_cost = int((datetime.datetime.now() - test_time_cost).total_seconds()) outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f, test iou: %.6f, best iou %.6f' % (epoch, test_loss*1.0/count, test_acc, avg_per_class_acc, np.mean(test_ious), best_test_iou) io.cprint(outstr) io.cprint(f"Testing time: {test_time_cost} seconds.") if np.mean(test_ious) >= best_test_iou: best_test_iou = np.mean(test_ious) torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % args.exp_name) def test(args, io): test_loader = DataLoader(ShapeNetPart(partition='test', num_points=args.num_points, class_choice=args.class_choice), batch_size=args.test_batch_size, shuffle=True, drop_last=False) device = torch.device("cuda" if args.cuda else "cpu") #Try to load models seg_start_index = test_loader.dataset.seg_start_index model = models.__dict__[args.model]().to(device) model = nn.DataParallel(model) model.load_state_dict(torch.load(args.model_path)) model = model.eval() test_acc = 0.0 test_true_cls = [] test_pred_cls = [] test_true_seg = [] test_pred_seg = [] test_label_seg = [] category = {} for data, label, seg in test_loader: seg = seg - seg_start_index label_one_hot = np.zeros((label.shape[0], 16)) for idx in range(label.shape[0]): label_one_hot[idx, label[idx]] = 1 label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32)) data, label_one_hot, seg = data.to(device), label_one_hot.to(device), seg.to(device) data = data.permute(0, 2, 1) seg_pred = model(data, label_one_hot) seg_pred = seg_pred.permute(0, 2, 1).contiguous() pred = seg_pred.max(dim=2)[1] seg_np = seg.cpu().numpy() pred_np = pred.detach().cpu().numpy() test_true_cls.append(seg_np.reshape(-1)) test_pred_cls.append(pred_np.reshape(-1)) test_true_seg.append(seg_np) test_pred_seg.append(pred_np) test_label_seg.append(label.reshape(-1)) test_true_cls = np.concatenate(test_true_cls) test_pred_cls = np.concatenate(test_pred_cls) test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls) avg_per_class_acc = metrics.balanced_accuracy_score(test_true_cls, test_pred_cls) test_true_seg = np.concatenate(test_true_seg, axis=0) test_pred_seg = np.concatenate(test_pred_seg, axis=0) test_label_seg = np.concatenate(test_label_seg) test_ious,category = calculate_shape_IoU(test_pred_seg, test_true_seg, test_label_seg, args.class_choice, eva=True) outstr = 'Test :: test acc: %.6f, test avg acc: %.6f, test iou: %.6f' % (test_acc, avg_per_class_acc, np.mean(test_ious)) io.cprint(outstr) results = [] for key in category.keys(): results.append((int(key), np.mean(category[key]), len(category[key]))) results.sort(key=lambda x:x[0]) for re in results: io.cprint('idx: %d mIoU: %.3f num: %d' % (re[0], re[1], re[2])) if __name__ == "__main__": # Training settings parser = argparse.ArgumentParser(description='Point Cloud Part Segmentation') parser.add_argument('--model', type=str, default='CurveNet') parser.add_argument('--exp_name', type=str, default='exp', metavar='N', help='Name of the experiment') parser.add_argument('--dataset', type=str, default='shapenetpart', metavar='N', choices=['shapenetpart']) parser.add_argument('--class_choice', type=str, default=None, metavar='N', choices=['airplane', 'bag', 'cap', 'car', 'chair', 'earphone', 'guitar', 'knife', 'lamp', 'laptop', 'motor', 'mug', 'pistol', 'rocket', 'skateboard', 'table']) parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size', help='Size of batch)') parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size', help='Size of batch)') parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of episode to train ') parser.add_argument('--seed', type=int) parser.add_argument('--use_sgd', type=bool, default=True, help='Use SGD') parser.add_argument('--lr', type=float, default=0.0005, metavar='LR', help='learning rate (default: 0.001, 0.1 if using sgd)') parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)') parser.add_argument('--scheduler', type=str, default='step', metavar='N', choices=['cos', 'step'], help='Scheduler to use, [cos, step]') parser.add_argument('--no_cuda', type=bool, default=False, help='enables CUDA training') parser.add_argument('--eval', type=bool, default=False, help='evaluate the model') parser.add_argument('--num_points', type=int, default=2048, help='num of points to use') parser.add_argument('--model_path', type=str, default='', metavar='N', help='Pretrained model path') args = parser.parse_args() time_str = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S')) if args.exp_name is None: args.exp_name = time_str args.exp_name = args.model+"_"+args.exp_name _init_() if args.eval: io = IOStream('checkpoints/' + args.exp_name + '/eval.log') else: io = IOStream('checkpoints/' + args.exp_name + '/run.log') io.cprint(str(args)) io.cprint('random seed is: ' + str(args.seed)) args.cuda = not args.no_cuda and torch.cuda.is_available() if args.cuda: io.cprint( 'Using GPU : ' + str(torch.cuda.current_device()) + ' from ' + str(torch.cuda.device_count()) + ' devices') else: io.cprint('Using CPU') if not args.eval: train(args, io) else: with torch.no_grad(): test(args, io)