PointMLP/classification_ScanObjectNN/utils/misc.py
2023-08-03 16:40:14 +02:00

244 lines
6 KiB
Python

"""Some helper functions for PyTorch, including:
- get_mean_and_std: calculate the mean and std value of dataset.
- init_params: net parameter initialization.
- progress_bar: progress bar mimic xlua.progress.
"""
import errno
import os
import random
import shutil
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# Public API of this module: the names made available by
# ``from <module> import *``.
__all__ = [
"get_mean_and_std",
"init_params",
"mkdir_p",
"AverageMeter",
"progress_bar",
"save_model",
"save_args",
"set_seed",
"IOStream",
"cal_loss",
]
def get_mean_and_std(dataset):
    """Estimate per-channel mean and std of an image dataset.

    Samples are loaded one at a time and indexed as ``[:, channel, :, :]``,
    so each item must be a 3-channel image tensor paired with a target.

    Note that the returned std is the average of the per-sample standard
    deviations, not the standard deviation over the whole dataset.

    Args:
        dataset: Map-style dataset yielding ``(image, target)`` pairs.

    Returns:
        Tuple ``(mean, std)`` of two tensors with 3 elements each.
    """
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=True, num_workers=2
    )
    num_channels = 3
    mean = torch.zeros(num_channels)
    std = torch.zeros(num_channels)
    print("==> Computing mean and std..")
    for batch, _ in loader:
        for channel in range(num_channels):
            plane = batch[:, channel, :, :]
            mean[channel] += plane.mean()
            std[channel] += plane.std()
    # Turn the accumulated per-sample statistics into averages.
    num_samples = len(dataset)
    mean.div_(num_samples)
    std.div_(num_samples)
    return mean, std
def init_params(net):
    """Initialise Conv2d, BatchNorm2d and Linear layers of ``net`` in place.

    * ``nn.Conv2d``: Kaiming-normal weights (``mode="fan_out"``), zero bias.
    * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: normal weights with ``std=1e-3``, zero bias.

    Layers created without a bias (or BatchNorm without affine parameters)
    are handled: missing parameters are simply skipped.

    Args:
        net: Module whose sub-modules (``net.modules()``) are initialised.
    """
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode="fan_out")
            # Compare against None: the truth value of a multi-element
            # tensor is ambiguous and raises a RuntimeError.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer is built with affine=False.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
def mkdir_p(path):
    """Create ``path`` and any missing parent directories (like ``mkdir -p``).

    Succeeds silently when the directory already exists.

    Args:
        path: Directory to create.

    Raises:
        OSError: If the directory cannot be created, e.g. because ``path``
            already exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
class AverageMeter:
    """Keep the most recent value and a running weighted average of a metric.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262.

    Attributes are ``val`` (last value passed to ``update``), ``sum``
    (weighted sum of all values), ``count`` (total weight) and ``avg``
    (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
# Width, in characters, of the bar drawn by progress_bar.
TOTAL_BAR_LENGTH = 65.0
# Module-level timestamps shared across progress_bar calls: ``last_time`` is
# when the previous step was drawn, ``begin_time`` is when the current bar
# started. Both start out at import time.
begin_time = last_time = time.time()
def progress_bar(current, total, msg=None):
    """Draw a one-line textual progress bar on stdout.

    The line shows the bar itself, the time since the previous call, the
    time since the bar was started, an optional message and a
    ``current + 1``/``total`` counter. The line is ended with a carriage
    return so the next call overwrites it, except on the last step
    (``current == total - 1``) where a newline is written instead.

    Args:
        current: Zero-based index of the step just completed; ``0`` restarts
            the total-time clock.
        total: Total number of steps.
        msg: Optional text appended after the timing information.
    """
    global last_time, begin_time
    if current == 0:
        # First step of a new bar: restart the total-time clock.
        begin_time = time.time()

    filled = int(TOTAL_BAR_LENGTH * current / total)
    empty = int(TOTAL_BAR_LENGTH - filled) - 1
    sys.stdout.write(" [" + "=" * filled + ">" + "." * empty + "]")

    now = time.time()
    step_elapsed = now - last_time
    last_time = now
    total_elapsed = now - begin_time

    pieces = [
        " Step: %s" % format_time(step_elapsed),
        " | Tot: %s" % format_time(total_elapsed),
    ]
    if msg:
        pieces.append(" | " + msg)
    sys.stdout.write("".join(pieces))

    sys.stdout.write(" %d/%d " % (current + 1, total))
    is_last_step = not current < total - 1
    sys.stdout.write("\n" if is_last_step else "\r")
    sys.stdout.flush()
def format_time(seconds):
    """Format a duration in seconds as a short string such as ``"1h1m"``.

    The duration is split into days (``D``), hours (``h``), minutes (``m``),
    seconds (``s``) and milliseconds (``ms``). Only the two most significant
    non-zero units are kept; ``"0ms"`` is returned when every unit is zero.

    Args:
        seconds: Duration in seconds.

    Returns:
        The formatted duration string.
    """
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining -= days * 3600 * 24
    hours = int(remaining / 3600)
    remaining -= hours * 3600
    minutes = int(remaining / 60)
    remaining -= minutes * 60
    whole_seconds = int(remaining)
    remaining -= whole_seconds
    millis = int(remaining * 1000)

    units = (
        (days, "D"),
        (hours, "h"),
        (minutes, "m"),
        (whole_seconds, "s"),
        (millis, "ms"),
    )
    shown = [str(amount) + suffix for amount, suffix in units if amount > 0]
    # Keep at most the two most significant units.
    text = "".join(shown[:2])
    return text if text else "0ms"
def save_model(net, epoch, path, acc, is_best, **kwargs):
    """Write a training checkpoint to ``path``.

    The checkpoint is a dict holding ``net.state_dict()`` under ``"net"``,
    plus ``"epoch"``, ``"acc"`` and every extra keyword argument. It is
    always saved as ``last_checkpoint.pth``; when ``is_best`` is true that
    file is additionally copied to ``best_checkpoint.pth``.

    Args:
        net: Model providing ``state_dict()``.
        epoch: Epoch value stored in the checkpoint.
        path: Existing directory receiving the checkpoint file(s).
        acc: Accuracy value stored in the checkpoint.
        is_best: Whether to also update ``best_checkpoint.pth``.
        **kwargs: Additional entries stored in the checkpoint dict.
    """
    checkpoint = {
        "net": net.state_dict(),
        "epoch": epoch,
        "acc": acc,
        **kwargs,
    }
    last_path = os.path.join(path, "last_checkpoint.pth")
    torch.save(checkpoint, last_path)
    if not is_best:
        return
    shutil.copyfile(last_path, os.path.join(path, "best_checkpoint.pth"))
def save_args(args):
    """Dump every attribute of ``args`` to ``<args.checkpoint>/args.txt``.

    One ``name:\\t value`` line is written per attribute, overwriting any
    existing file.

    Args:
        args: Namespace-like object (anything ``vars()`` accepts) with a
            ``checkpoint`` attribute naming an existing directory.
    """
    # The context manager closes the file even if a write fails.
    with open(os.path.join(args.checkpoint, "args.txt"), "w") as file:
        for k, v in vars(args).items():
            file.write(f"{k}:\t {v}\n")
def set_seed(seed=None):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Does nothing when ``seed`` is ``None``. Otherwise the seed is applied to
    ``random``, the ``PYTHONHASHSEED`` environment variable, NumPy and
    PyTorch (CPU and CUDA), and cuDNN is switched to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Seed value, or ``None`` to leave all RNG state untouched.
    """
    if seed is None:
        return

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = "%s" % seed
    for seeder in (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
class IOStream:
    """Logger that echoes text to stdout and appends it to a file.

    The file at ``path`` is opened in append mode on construction. Call
    ``close()`` when done, or use the instance as a context manager so the
    file is closed automatically::

        with IOStream("run.log") as io:
            io.cprint("hello")
    """

    def __init__(self, path):
        self.f = open(path, "a")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the file handle; exceptions are not suppressed.
        self.close()

    def cprint(self, text):
        """Print ``text`` and append it, followed by a newline, to the file."""
        print(text)
        # Format rather than concatenate so that any object print() accepts
        # can also be written to the file.
        self.f.write(f"{text}\n")
        self.f.flush()

    def close(self):
        """Close the underlying file."""
        self.f.close()
def cal_loss(pred, gold, smoothing=True, eps=0.2):
    """Calculate cross entropy loss, apply label smoothing if needed.

    With smoothing, the target distribution puts ``1 - eps`` on the true
    class and spreads ``eps`` evenly over the remaining classes.

    Args:
        pred: Logits; classes are along dimension 1 (``(N, C)``).
        gold: Integer class indices; flattened to 1-D before use.
        smoothing: Use the label-smoothed loss when true, otherwise plain
            ``F.cross_entropy``.
        eps: Smoothing strength, only used when ``smoothing`` is true.
            Defaults to 0.2.

    Returns:
        Scalar tensor holding the mean loss over the batch.
    """
    gold = gold.contiguous().view(-1)
    if not smoothing:
        return F.cross_entropy(pred, gold, reduction="mean")

    n_class = pred.size(1)
    # With a single class there are no "other" classes to share eps; clamp
    # the divisor so the (all-zero) off-target term does not divide by zero.
    n_other = max(n_class - 1, 1)
    one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
    one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / n_other
    log_prb = F.log_softmax(pred, dim=1)
    return -(one_hot * log_prb).sum(dim=1).mean()