# PointMLP/classification_ScanObjectNN/main.py
"""for training with resume functions.
2021-10-04 07:22:15 +00:00
Usage:
python main.py --model PointNet --msg demo
or
2023-08-03 14:40:14 +00:00
CUDA_VISIBLE_DEVICES=0 nohup python main.py --model PointNet --msg demo > nohup/PointNet_demo.out &.
2021-10-04 07:22:15 +00:00
"""
import argparse
import datetime
import logging
import os

import models
import numpy as np
import sklearn.metrics as metrics
import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from ScanObjectNN import ScanObjectNN
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from utils import Logger, cal_loss, mkdir_p, progress_bar, save_args, save_model


def parse_args():
    """Parameters"""
    parser = argparse.ArgumentParser("training")
    parser.add_argument(
        "-c",
        "--checkpoint",
        type=str,
        metavar="PATH",
        help="path to save checkpoint (default: checkpoint)",
    )
    parser.add_argument("--msg", type=str, help="message after checkpoint")
    parser.add_argument("--batch_size", type=int, default=32, help="batch size in training")
    parser.add_argument("--model", default="PointNet", help="model name [default: PointNet]")
    parser.add_argument("--num_classes", default=15, type=int, help="number of classes in ScanObjectNN")
    parser.add_argument("--epoch", default=200, type=int, help="number of epochs in training")
    parser.add_argument("--num_points", type=int, default=1024, help="number of points per sample")
    parser.add_argument("--learning_rate", default=0.01, type=float, help="learning rate in training")
    parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay rate")
    parser.add_argument("--smoothing", action="store_true", default=False, help="loss smoothing")
    parser.add_argument("--seed", type=int, help="random seed")
    parser.add_argument("--workers", default=4, type=int, help="number of data-loading workers")
    return parser.parse_args()


def main():
    args = parse_args()
    os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"  # avoid HDF5 locking issues on shared filesystems
    if args.seed is not None:
        torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        device = "cuda"
        if args.seed is not None:
            torch.cuda.manual_seed(args.seed)
    else:
        device = "cpu"
    time_str = datetime.datetime.now().strftime("-%Y%m%d%H%M%S")
    if args.msg is None:
        message = time_str
    else:
        message = "-" + args.msg
    args.checkpoint = "checkpoints/" + args.model + message
    if not os.path.isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    screen_logger = logging.getLogger("Model")
    screen_logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(message)s")
    file_handler = logging.FileHandler(os.path.join(args.checkpoint, "out.txt"))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    screen_logger.addHandler(file_handler)

    def printf(text):
        # Log to the checkpoint's out.txt and echo to stdout.
        screen_logger.info(text)
        print(text)

    # Model
    printf(f"args: {args}")
    printf("==> Building model..")
    net = models.__dict__[args.model](num_classes=args.num_classes)
    criterion = cal_loss  # cross-entropy from utils; the --smoothing flag above is not consumed in this file
    net = net.to(device)
    # criterion = criterion.to(device)
    if device == "cuda":
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True
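
    # Note: DataParallel stores parameters under a "module." prefix, so the
    # checkpoints saved below only load back into a net wrapped the same way.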
    best_test_acc = 0.0  # best test accuracy
    best_train_acc = 0.0
    best_test_acc_avg = 0.0
    best_train_acc_avg = 0.0
    best_test_loss = float("inf")
    best_train_loss = float("inf")
    start_epoch = 0  # start from epoch 0 or the last checkpoint epoch
    optimizer_dict = None

    if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")):
        # Fresh run: record the args and start a new log.
        save_args(args)
        logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ScanObjectNN" + args.model)
        logger.set_names(
            [
                "Epoch-Num",
                "Learning-Rate",
                "Train-Loss",
                "Train-acc-B",
                "Train-acc",
                "Valid-Loss",
                "Valid-acc-B",
                "Valid-acc",
            ],
        )
    else:
        # Resume: restore weights, bookkeeping, and the optimizer state.
        printf(f"Resuming last checkpoint from {args.checkpoint}")
        checkpoint_path = os.path.join(args.checkpoint, "last_checkpoint.pth")
        checkpoint = torch.load(checkpoint_path)
        net.load_state_dict(checkpoint["net"])
        start_epoch = checkpoint["epoch"]
        best_test_acc = checkpoint["best_test_acc"]
        best_train_acc = checkpoint["best_train_acc"]
        best_test_acc_avg = checkpoint["best_test_acc_avg"]
        best_train_acc_avg = checkpoint["best_train_acc_avg"]
        best_test_loss = checkpoint["best_test_loss"]
        best_train_loss = checkpoint["best_train_loss"]
        logger = Logger(os.path.join(args.checkpoint, "log.txt"), title="ScanObjectNN" + args.model, resume=True)
        optimizer_dict = checkpoint["optimizer"]
printf("==> Preparing data..")
train_loader = DataLoader(
ScanObjectNN(partition="training", num_points=args.num_points),
num_workers=args.workers,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
)
test_loader = DataLoader(
ScanObjectNN(partition="test", num_points=args.num_points),
num_workers=args.workers,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
)
2021-10-04 07:22:15 +00:00
optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
if optimizer_dict is not None:
optimizer.load_state_dict(optimizer_dict)
scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=args.learning_rate / 100, last_epoch=start_epoch - 1)
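    # last_epoch=start_epoch - 1 fast-forwards the cosine schedule when
    # resuming; eta_min keeps the final learning rate at 1% of the initial one.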

    for epoch in range(start_epoch, args.epoch):
        printf("Epoch(%d/%s) Learning Rate %s:" % (epoch + 1, args.epoch, optimizer.param_groups[0]["lr"]))
        train_out = train(net, train_loader, optimizer, criterion, device)  # {"loss", "acc", "acc_avg", "time"}
        test_out = validate(net, test_loader, criterion, device)
        scheduler.step()

        is_best = test_out["acc"] > best_test_acc
        best_test_acc = max(best_test_acc, test_out["acc"])
        best_train_acc = max(best_train_acc, train_out["acc"])
        best_test_acc_avg = max(best_test_acc_avg, test_out["acc_avg"])
        best_train_acc_avg = max(best_train_acc_avg, train_out["acc_avg"])
        best_test_loss = min(best_test_loss, test_out["loss"])
        best_train_loss = min(best_train_loss, train_out["loss"])

        save_model(
            net,
            epoch,
            path=args.checkpoint,
            acc=test_out["acc"],
            is_best=is_best,
            best_test_acc=best_test_acc,  # best test accuracy
            best_train_acc=best_train_acc,
            best_test_acc_avg=best_test_acc_avg,
            best_train_acc_avg=best_train_acc_avg,
            best_test_loss=best_test_loss,
            best_train_loss=best_train_loss,
            optimizer=optimizer.state_dict(),
        )
        logger.append(
            [
                epoch,
                optimizer.param_groups[0]["lr"],
                train_out["loss"],
                train_out["acc_avg"],
                train_out["acc"],
                test_out["loss"],
                test_out["acc_avg"],
                test_out["acc"],
            ],
        )
        printf(
            f"Training loss:{train_out['loss']} acc_avg:{train_out['acc_avg']}% "
            f"acc:{train_out['acc']}% time:{train_out['time']}s",
        )
        printf(
            f"Testing loss:{test_out['loss']} acc_avg:{test_out['acc_avg']}% "
            f"acc:{test_out['acc']}% time:{test_out['time']}s [best test acc: {best_test_acc}%]\n\n",
        )

    logger.close()
    printf("++++++++" * 2 + "Final results" + "++++++++" * 2)
    printf(f"++ Last Train time: {train_out['time']} | Last Test time: {test_out['time']} ++")
    printf(f"++ Best Train loss: {best_train_loss} | Best Test loss: {best_test_loss} ++")
    printf(f"++ Best Train acc_B: {best_train_acc_avg} | Best Test acc_B: {best_test_acc_avg} ++")
    printf(f"++ Best Train acc: {best_train_acc} | Best Test acc: {best_test_acc} ++")
    printf("++++++++" * 5)
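

# train() and validate() both return {"loss", "acc", "acc_avg", "time"}:
# "acc" is overall accuracy and "acc_avg" is the class-balanced accuracy
# (mean per-class recall), both in percent.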
def train(net, trainloader, optimizer, criterion, device):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    train_pred = []
    train_true = []
    time_cost = datetime.datetime.now()
    for batch_idx, (data, label) in enumerate(trainloader):
        data, label = data.to(device), label.to(device).squeeze()
        data = data.permute(0, 2, 1)  # [batch, points, 3] -> [batch, 3, points], channels-first for the model
        optimizer.zero_grad()
        logits = net(data)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        preds = logits.max(dim=1)[1]
        train_true.append(label.cpu().numpy())
        train_pred.append(preds.detach().cpu().numpy())
        total += label.size(0)
        correct += preds.eq(label).sum().item()
        progress_bar(
            batch_idx,
            len(trainloader),
            "Loss: %.3f | Acc: %.3f%% (%d/%d)"
            % (train_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
        )
    time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
    train_true = np.concatenate(train_true)
    train_pred = np.concatenate(train_pred)
    return {
        "loss": float("%.3f" % (train_loss / (batch_idx + 1))),
        "acc": float("%.3f" % (100.0 * metrics.accuracy_score(train_true, train_pred))),
        "acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(train_true, train_pred))),
        "time": time_cost,
    }


def validate(net, testloader, criterion, device):
    # Same loop as train(), but under torch.no_grad() and without optimizer steps.
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    test_true = []
    test_pred = []
    time_cost = datetime.datetime.now()
    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(testloader):
            data, label = data.to(device), label.to(device).squeeze()
            data = data.permute(0, 2, 1)
            logits = net(data)
            loss = criterion(logits, label)
            test_loss += loss.item()
            preds = logits.max(dim=1)[1]
            test_true.append(label.cpu().numpy())
            test_pred.append(preds.detach().cpu().numpy())
            total += label.size(0)
            correct += preds.eq(label).sum().item()
            progress_bar(
                batch_idx,
                len(testloader),
                "Loss: %.3f | Acc: %.3f%% (%d/%d)"
                % (test_loss / (batch_idx + 1), 100.0 * correct / total, correct, total),
            )
    time_cost = int((datetime.datetime.now() - time_cost).total_seconds())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    return {
        "loss": float("%.3f" % (test_loss / (batch_idx + 1))),
        "acc": float("%.3f" % (100.0 * metrics.accuracy_score(test_true, test_pred))),
        "acc_avg": float("%.3f" % (100.0 * metrics.balanced_accuracy_score(test_true, test_pred))),
        "time": time_cost,
    }


if __name__ == "__main__":
    main()