PointMLP/part_segmentation/main.py

"""
Usage:
python main.py --model CurveNet --exp_name=demo1
@Author: An Tao
@Contact: ta19@mails.tsinghua.edu.cn
@File: main_partseg.py
@Time: 2019/12/31 11:17 AM
Modified by
@Author: Tiange Xiang
@Contact: txia7609@uni.sydney.edu.au
@Time: 2021/01/21 3:10 PM
"""
from __future__ import print_function
import os
import datetime
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
from data import ShapeNetPart
import models as models
import numpy as np
from torch.utils.data import DataLoader
from util import cal_loss, IOStream
import sklearn.metrics as metrics
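
# ShapeNetPart label layout: parts per category (16 categories, 50 part labels in
# total) and the global index at which each category's part labels start.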
seg_num = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]
index_start = [0, 4, 6, 8, 12, 16, 19, 22, 24, 28, 30, 36, 38, 41, 44, 47]

def _init_():
    # fix random seed
    if args.seed is not None:
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.set_printoptions(10)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        os.environ['PYTHONHASHSEED'] = str(args.seed)
    # prepare file structures
    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')
    if not os.path.exists('checkpoints/' + args.exp_name):
        os.makedirs('checkpoints/' + args.exp_name)
    if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
        os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
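
# Per-shape part IoU: average the IoU over the part labels belonging to each
# shape's category; a part absent from both prediction and ground truth counts
# as IoU 1. With eva=True, also returns per-category lists of shape IoUs.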
def calculate_shape_IoU(pred_np, seg_np, label, class_choice, eva=False):
    label = label.squeeze()
    shape_ious = []
    category = {}
    for shape_idx in range(seg_np.shape[0]):
        if not class_choice:
            start_index = index_start[label[shape_idx]]
            num = seg_num[label[shape_idx]]
            parts = range(start_index, start_index + num)
        else:
            parts = range(seg_num[label[0]])
        part_ious = []
        for part in parts:
            I = np.sum(np.logical_and(pred_np[shape_idx] == part, seg_np[shape_idx] == part))
            U = np.sum(np.logical_or(pred_np[shape_idx] == part, seg_np[shape_idx] == part))
            if U == 0:
                iou = 1  # if the union of ground-truth and predicted points is empty, count the part IoU as 1
            else:
                iou = I / float(U)
            part_ious.append(iou)
        shape_ious.append(np.mean(part_ious))
        if label[shape_idx] not in category:
            category[label[shape_idx]] = [shape_ious[-1]]
        else:
            category[label[shape_idx]].append(shape_ious[-1])
    if eva:
        return shape_ious, category
    else:
        return shape_ious
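
# Train on the ShapeNetPart 'trainval' split, evaluating on the 'test' split after
# every epoch and checkpointing the model with the best mean instance IoU.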
def train(args, io):
    train_dataset = ShapeNetPart(partition='trainval', num_points=args.num_points, class_choice=args.class_choice)
    if len(train_dataset) < 100:
        drop_last = False
    else:
        drop_last = True
    train_loader = DataLoader(train_dataset, num_workers=8, batch_size=args.batch_size, shuffle=True,
                              drop_last=drop_last, pin_memory=True)
    test_loader = DataLoader(ShapeNetPart(partition='test', num_points=args.num_points, class_choice=args.class_choice),
                             num_workers=8, batch_size=args.test_batch_size, shuffle=False, drop_last=False,
                             pin_memory=True)
    device = torch.device("cuda" if args.cuda else "cpu")
    io.cprint("Let's use " + str(torch.cuda.device_count()) + " GPUs!")
    seg_num_all = train_loader.dataset.seg_num_all
    seg_start_index = train_loader.dataset.seg_start_index

    # create model
    model = models.__dict__[args.model]().to(device)
    io.cprint(str(model))
    model = nn.DataParallel(model)

    if args.use_sgd:
        print("Use SGD")
        opt = optim.SGD(model.parameters(), lr=args.lr * 100, momentum=args.momentum, weight_decay=1e-4)
    else:
        print("Use Adam")
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)

    if args.scheduler == 'cos':
        if args.use_sgd:
            eta_min = args.lr / 5.0
        else:
            eta_min = args.lr / 100.0
        scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=eta_min)
    elif args.scheduler == 'step':
        scheduler = MultiStepLR(opt, [140, 180], gamma=0.1)

    criterion = cal_loss
    best_test_iou = 0

    for epoch in range(args.epochs):
        ####################
        # Train
        ####################
        train_time_cost = datetime.datetime.now()
        train_loss = 0.0
        count = 0.0
        model.train()
        train_true_cls = []
        train_pred_cls = []
        train_true_seg = []
        train_pred_seg = []
        train_label_seg = []
        for data, label, seg in train_loader:
            seg = seg - seg_start_index
            # encode the object category as a one-hot vector (16 ShapeNetPart categories),
            # passed to the model as a second input
            label_one_hot = np.zeros((label.shape[0], 16))
            for idx in range(label.shape[0]):
                label_one_hot[idx, label[idx]] = 1
            label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
            data, label_one_hot, seg = data.to(device), label_one_hot.to(device), seg.to(device)
            data = data.permute(0, 2, 1)
            batch_size = data.size()[0]
            opt.zero_grad()
            seg_pred = model(data, label_one_hot)
            seg_pred = seg_pred.permute(0, 2, 1).contiguous()
            loss = criterion(seg_pred.view(-1, seg_num_all), seg.view(-1, 1).squeeze())
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            opt.step()
            pred = seg_pred.max(dim=2)[1]  # (batch_size, num_points)
            count += batch_size
            train_loss += loss.item() * batch_size
            seg_np = seg.cpu().numpy()  # (batch_size, num_points)
            pred_np = pred.detach().cpu().numpy()  # (batch_size, num_points)
            train_true_cls.append(seg_np.reshape(-1))  # (batch_size * num_points)
            train_pred_cls.append(pred_np.reshape(-1))  # (batch_size * num_points)
            train_true_seg.append(seg_np)
            train_pred_seg.append(pred_np)
            train_label_seg.append(label.reshape(-1))

        if args.scheduler == 'cos':
            scheduler.step()
        elif args.scheduler == 'step':
            # step the learning rate down until it reaches 1e-5, then clamp it there
            if opt.param_groups[0]['lr'] > 1e-5:
                scheduler.step()
            if opt.param_groups[0]['lr'] < 1e-5:
                for param_group in opt.param_groups:
                    param_group['lr'] = 1e-5

        train_true_cls = np.concatenate(train_true_cls)
        train_pred_cls = np.concatenate(train_pred_cls)
        train_acc = metrics.accuracy_score(train_true_cls, train_pred_cls)
        avg_per_class_acc = metrics.balanced_accuracy_score(train_true_cls, train_pred_cls)
        train_true_seg = np.concatenate(train_true_seg, axis=0)
        train_pred_seg = np.concatenate(train_pred_seg, axis=0)
        train_label_seg = np.concatenate(train_label_seg)
        train_ious = calculate_shape_IoU(train_pred_seg, train_true_seg, train_label_seg, args.class_choice)
        train_time_cost = int((datetime.datetime.now() - train_time_cost).total_seconds())
        outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f, train iou: %.6f' % (
            epoch, train_loss * 1.0 / count, train_acc, avg_per_class_acc, np.mean(train_ious))
        io.cprint(outstr)
        io.cprint(f"Training time: {train_time_cost} seconds.")

        ####################
        # Test
        ####################
        test_time_cost = datetime.datetime.now()
        test_loss = 0.0
        count = 0.0
        model.eval()
        test_true_cls = []
        test_pred_cls = []
        test_true_seg = []
        test_pred_seg = []
        test_label_seg = []
        for data, label, seg in test_loader:
            seg = seg - seg_start_index
            label_one_hot = np.zeros((label.shape[0], 16))
            for idx in range(label.shape[0]):
                label_one_hot[idx, label[idx]] = 1
            label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
            data, label_one_hot, seg = data.to(device), label_one_hot.to(device), seg.to(device)
            data = data.permute(0, 2, 1)
            batch_size = data.size()[0]
            seg_pred = model(data, label_one_hot)
            seg_pred = seg_pred.permute(0, 2, 1).contiguous()
            loss = criterion(seg_pred.view(-1, seg_num_all), seg.view(-1, 1).squeeze())
            pred = seg_pred.max(dim=2)[1]
            count += batch_size
            test_loss += loss.item() * batch_size
            seg_np = seg.cpu().numpy()
            pred_np = pred.detach().cpu().numpy()
            test_true_cls.append(seg_np.reshape(-1))
            test_pred_cls.append(pred_np.reshape(-1))
            test_true_seg.append(seg_np)
            test_pred_seg.append(pred_np)
            test_label_seg.append(label.reshape(-1))
        test_true_cls = np.concatenate(test_true_cls)
        test_pred_cls = np.concatenate(test_pred_cls)
        test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls)
        avg_per_class_acc = metrics.balanced_accuracy_score(test_true_cls, test_pred_cls)
        test_true_seg = np.concatenate(test_true_seg, axis=0)
        test_pred_seg = np.concatenate(test_pred_seg, axis=0)
        test_label_seg = np.concatenate(test_label_seg)
        test_ious = calculate_shape_IoU(test_pred_seg, test_true_seg, test_label_seg, args.class_choice)
        test_time_cost = int((datetime.datetime.now() - test_time_cost).total_seconds())
        outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f, test iou: %.6f, best iou %.6f' % (
            epoch, test_loss * 1.0 / count, test_acc, avg_per_class_acc, np.mean(test_ious), best_test_iou)
        io.cprint(outstr)
        io.cprint(f"Testing time: {test_time_cost} seconds.")
        # keep the checkpoint with the best mean instance IoU seen so far
        if np.mean(test_ious) >= best_test_iou:
            best_test_iou = np.mean(test_ious)
            torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % args.exp_name)
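
# Standalone evaluation: loads a trained checkpoint from --model_path and reports
# overall accuracy, balanced per-class accuracy, mean IoU, and per-category mIoU.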
def test(args, io):
    test_loader = DataLoader(ShapeNetPart(partition='test', num_points=args.num_points, class_choice=args.class_choice),
                             batch_size=args.test_batch_size, shuffle=True, drop_last=False)
    device = torch.device("cuda" if args.cuda else "cpu")

    # try to load the trained model
    seg_start_index = test_loader.dataset.seg_start_index
    model = models.__dict__[args.model]().to(device)
    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model_path))
    model = model.eval()

    test_acc = 0.0
    test_true_cls = []
    test_pred_cls = []
    test_true_seg = []
    test_pred_seg = []
    test_label_seg = []
    category = {}
    for data, label, seg in test_loader:
        seg = seg - seg_start_index
        label_one_hot = np.zeros((label.shape[0], 16))
        for idx in range(label.shape[0]):
            label_one_hot[idx, label[idx]] = 1
        label_one_hot = torch.from_numpy(label_one_hot.astype(np.float32))
        data, label_one_hot, seg = data.to(device), label_one_hot.to(device), seg.to(device)
        data = data.permute(0, 2, 1)
        seg_pred = model(data, label_one_hot)
        seg_pred = seg_pred.permute(0, 2, 1).contiguous()
        pred = seg_pred.max(dim=2)[1]
        seg_np = seg.cpu().numpy()
        pred_np = pred.detach().cpu().numpy()
        test_true_cls.append(seg_np.reshape(-1))
        test_pred_cls.append(pred_np.reshape(-1))
        test_true_seg.append(seg_np)
        test_pred_seg.append(pred_np)
        test_label_seg.append(label.reshape(-1))
    test_true_cls = np.concatenate(test_true_cls)
    test_pred_cls = np.concatenate(test_pred_cls)
    test_acc = metrics.accuracy_score(test_true_cls, test_pred_cls)
    avg_per_class_acc = metrics.balanced_accuracy_score(test_true_cls, test_pred_cls)
    test_true_seg = np.concatenate(test_true_seg, axis=0)
    test_pred_seg = np.concatenate(test_pred_seg, axis=0)
    test_label_seg = np.concatenate(test_label_seg)
    test_ious, category = calculate_shape_IoU(test_pred_seg, test_true_seg, test_label_seg, args.class_choice, eva=True)
    outstr = 'Test :: test acc: %.6f, test avg acc: %.6f, test iou: %.6f' % (
        test_acc, avg_per_class_acc, np.mean(test_ious))
    io.cprint(outstr)
    # report per-category mIoU, sorted by category index
    results = []
    for key in category.keys():
        results.append((int(key), np.mean(category[key]), len(category[key])))
    results.sort(key=lambda x: x[0])
    for re in results:
        io.cprint('idx: %d mIoU: %.3f num: %d' % (re[0], re[1], re[2]))

if __name__ == "__main__":
    # Training settings
    parser = argparse.ArgumentParser(description='Point Cloud Part Segmentation')
    parser.add_argument('--model', type=str, default='CurveNet')
    parser.add_argument('--exp_name', type=str, default='exp', metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--dataset', type=str, default='shapenetpart', metavar='N',
                        choices=['shapenetpart'])
    parser.add_argument('--class_choice', type=str, default=None, metavar='N',
                        choices=['airplane', 'bag', 'cap', 'car', 'chair',
                                 'earphone', 'guitar', 'knife', 'lamp', 'laptop',
                                 'motor', 'mug', 'pistol', 'rocket', 'skateboard', 'table'])
    parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size',
                        help='Size of training batch')
    parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size',
                        help='Size of testing batch')
    parser.add_argument('--epochs', type=int, default=200, metavar='N',
                        help='number of epochs to train')
    parser.add_argument('--seed', type=int)
    # note: argparse does not convert strings with type=bool (any non-empty string
    # is truthy), so boolean options are parsed explicitly or exposed as flags
    parser.add_argument('--use_sgd', type=lambda x: x.lower() == 'true', default=True,
                        help='use SGD optimizer (pass "false" to use Adam)')
    parser.add_argument('--lr', type=float, default=0.0005, metavar='LR',
                        help='base learning rate (default: 0.0005; multiplied by 100 when using SGD)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--scheduler', type=str, default='step', metavar='N',
                        choices=['cos', 'step'],
                        help='Scheduler to use, [cos, step]')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--eval', action='store_true', default=False,
                        help='evaluate the model')
    parser.add_argument('--num_points', type=int, default=2048,
                        help='num of points to use')
    parser.add_argument('--model_path', type=str, default='', metavar='N',
                        help='Pretrained model path')
    args = parser.parse_args()

    time_str = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
    if args.exp_name is None:
        args.exp_name = time_str
    args.exp_name = args.model + "_" + args.exp_name

    _init_()

    if args.eval:
        io = IOStream('checkpoints/' + args.exp_name + '/eval.log')
    else:
        io = IOStream('checkpoints/' + args.exp_name + '/run.log')
    io.cprint(str(args))
    io.cprint('random seed is: ' + str(args.seed))

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        io.cprint('Using GPU : ' + str(torch.cuda.current_device()) + ' from ' +
                  str(torch.cuda.device_count()) + ' devices')
    else:
        io.cprint('Using CPU')

    if not args.eval:
        train(args, io)
    else:
        with torch.no_grad():
            test(args, io)