
244 lines
6 KiB
Raw Permalink Normal View History

2023-08-03 14:40:14 +00:00
"""Some helper functions for PyTorch, including:
- get_mean_and_std: calculate the mean and std value of dataset.
- init_params: net parameter initialization.
- progress_bar: progress bar mimicking xlua.progress.
"""
2021-10-04 07:22:15 +00:00
import errno
import os
2023-08-03 14:40:14 +00:00
import random
import shutil
2021-10-04 07:22:15 +00:00
import sys
import time
2023-08-03 14:40:14 +00:00
import numpy as np
import torch
2021-10-04 07:22:15 +00:00
import torch.nn as nn
2023-08-03 14:40:14 +00:00
import torch.nn.functional as F
2021-10-04 07:22:15 +00:00
import torch.nn.init as init
2023-08-03 14:40:14 +00:00
# Public names exported by `from <this module> import *`.
# NOTE(review): the entries of this list are not visible here — confirm it
# names every helper below that callers rely on via star-import.
__all__ = [
2021-10-04 07:22:15 +00:00
def get_mean_and_std(dataset, num_workers=2):
    """Compute the per-channel mean and std value of dataset.

    Samples are loaded one at a time. For each sample the per-channel mean
    and (unbiased) standard deviation are computed, and both are averaged
    over all samples. The returned std is therefore the mean of per-image
    stds, an approximation of the dataset-wide pixel std.

    Args:
        dataset: a dataset yielding ``(input, target)`` pairs whose inputs
            have the channel axis first (batched to ``(1, C, ...)``).
        num_workers: number of DataLoader worker processes.

    Returns:
        ``(mean, std)``: two 1-D tensors of length C.

    Raises:
        ValueError: if the dataset yields no samples.
    """
    # Order is irrelevant for an average, so no shuffling is needed.
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=num_workers
    )
    mean = None
    std = None
    count = 0
    print("==> Computing mean and std..")
    for inputs, _targets in dataloader:
        # (1, C, ...) -> (C, pixels): one row per channel.
        per_channel = inputs.transpose(0, 1).reshape(inputs.size(1), -1)
        if mean is None:
            mean = torch.zeros(per_channel.size(0))
            std = torch.zeros(per_channel.size(0))
        mean += per_channel.mean(dim=1)
        std += per_channel.std(dim=1)
        count += 1
    if count == 0:
        raise ValueError("get_mean_and_std: dataset is empty")
    # Turn the accumulated sums into averages over samples.
    return mean / count, std / count
2023-08-03 14:40:14 +00:00
2021-10-04 07:22:15 +00:00
def init_params(net):
    """Init layer parameters of ``net`` in place.

    - ``nn.Conv2d``: Kaiming-normal weights (``mode="fan_out"``), zero bias.
    - ``nn.BatchNorm2d``: weight 1, bias 0.
    - ``nn.Linear``: normal weights with std 1e-3, zero bias.

    Other module types are left untouched.
    """
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode="fan_out")
            # `if m.bias:` is ambiguous for multi-element tensors; test presence.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # With affine=False, BatchNorm's weight and bias are None.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
2023-08-03 14:40:14 +00:00
2021-10-04 07:22:15 +00:00
def mkdir_p(path):
    """Make dir if not exist, like ``mkdir -p``.

    Creates ``path`` together with any missing parent directories. An
    already-existing directory is not an error; if ``path`` exists but is not
    a directory, ``FileExistsError`` (an ``OSError`` with errno EEXIST) is
    raised.
    """
    os.makedirs(path, exist_ok=True)
2023-08-03 14:40:14 +00:00
class AverageMeter:
    """Computes and stores the average and current value.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262.

    Attributes (all reset to 0 by ``reset()``):
        val: the most recent value passed to ``update``.
        sum: weighted sum of all values (each ``val`` counted ``n`` times).
        count: total weight ``n`` accumulated so far.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # n=0 on a fresh meter would otherwise divide by zero.
        if self.count != 0:
            self.avg = self.sum / self.count
2023-08-03 14:40:14 +00:00
2021-10-04 07:22:15 +00:00
# Wall-clock timestamps shared with progress_bar(), which rebinds them via
# `global`: begin_time marks the start of the current bar and last_time the
# previous call.
begin_time = last_time = time.time()
2023-08-03 14:40:14 +00:00
2021-10-04 07:22:15 +00:00
def progress_bar(current, total, msg=None, bar_length=65):
    """Draw a one-line text progress bar on stdout, mimicking xlua.progress.

    Call once per step with ``current`` running from 0 to ``total - 1``. Each
    call rewrites the same line (terminated by ``"\\r"``); the last step ends
    the line with ``"\\n"``. Step and total elapsed times are shown using the
    module-level ``last_time``/``begin_time`` timestamps.

    Args:
        current: zero-based index of the current step.
        total: total number of steps; must be positive.
        msg: optional extra text appended after the timing info.
        bar_length: width of the bar in characters.

    Raises:
        ValueError: if ``total`` is not positive.
    """
    global last_time, begin_time
    if total <= 0:
        raise ValueError("progress_bar: total must be positive, got %r" % (total,))
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(bar_length * current / total)
    rest_len = int(bar_length - cur_len) - 1

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    parts = [" [", "=" * cur_len, ">", "." * rest_len, "]"]
    parts.append(" Step: %s" % format_time(step_time))
    parts.append(" | Tot: %s" % format_time(tot_time))
    if msg:
        parts.append(" | " + msg)
    parts.append(" %d/%d " % (current + 1, total))
    # "\r" lets the next call overwrite this line; the final step ends it.
    parts.append("\r" if current < total - 1 else "\n")

    # Build the whole line first and emit it with a single write.
    sys.stdout.write("".join(parts))
    sys.stdout.flush()
2021-10-04 07:22:15 +00:00
2023-08-03 14:40:14 +00:00
2021-10-04 07:22:15 +00:00
def format_time(seconds):
    """Render a duration given in seconds as a compact string.

    The duration is split into days ("D"), hours ("h"), minutes ("m"),
    seconds ("s") and milliseconds ("ms"). Only the first two non-zero units,
    from largest to smallest, are shown. For example, 3725.5 becomes "1h2m".
    If no unit is non-zero, the result is "0ms".
    """
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining = remaining - days * 3600 * 24
    hours = int(remaining / 3600)
    remaining = remaining - hours * 3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes * 60
    whole_seconds = int(remaining)
    remaining = remaining - whole_seconds
    millis = int(remaining * 1000)

    units = (
        (days, "D"),
        (hours, "h"),
        (minutes, "m"),
        (whole_seconds, "s"),
        (millis, "ms"),
    )
    pieces = []
    for amount, suffix in units:
        if len(pieces) >= 2:
            break
        if amount > 0:
            pieces.append(str(amount) + suffix)
    return "".join(pieces) or "0ms"
def save_model(net, epoch, path, acc, is_best, **kwargs):
    """Save a training checkpoint to ``<path>/last_checkpoint.pth``.

    The checkpoint is a dict with keys ``"net"`` (``net.state_dict()``),
    ``"epoch"`` and ``"acc"``, plus every extra keyword argument. Extra
    keywords override those keys if the names collide. When ``is_best`` is
    true, the file is also copied to ``<path>/best_checkpoint.pth``. The
    directory ``path`` is not created here.
    """
    state = {
        "net": net.state_dict(),
        "epoch": epoch,
        "acc": acc,
    }
    state.update(kwargs)
    filepath = os.path.join(path, "last_checkpoint.pth")
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(path, "best_checkpoint.pth"))
2021-10-04 07:22:15 +00:00
def save_args(args):
    """Write every attribute of ``args`` to ``<args.checkpoint>/args.txt``.

    Writes one line per entry of ``vars(args)``, in insertion order, formatted
    as the name, a colon, a tab, a space and the value. An existing file is
    overwritten. The file is closed even if a write fails.
    """
    with open(os.path.join(args.checkpoint, "args.txt"), "w") as file:
        for k, v in vars(args).items():
            file.write(f"{k}:\t {v}\n")
def set_seed(seed=None):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Does nothing when ``seed`` is None.

    Note: setting ``PYTHONHASHSEED`` at runtime only affects child processes
    started afterwards. The current interpreter's str hashing is fixed at
    startup.
    """
    if seed is None:
        return
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = "%s" % seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Silently ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for reproducible kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
class IOStream:
    """Create (or append to) a text file and write lines of text into it.

    Can be used as a context manager so the file is always closed.
    """

    def __init__(self, path):
        self.f = open(path, "a")

    def cprint(self, text):
        """Append ``text`` plus a newline, flushing so it reaches the file now."""
        self.f.write(text + "\n")
        self.f.flush()

    def close(self):
        """Close the underlying file."""
        self.f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
def cal_loss(pred, gold, smoothing=True, eps=0.2):
    """Calculate cross entropy loss, apply label smoothing if needed.

    Args:
        pred: ``(N, C)`` logits.
        gold: integer class labels; flattened to ``(N,)``.
        smoothing: if true, the target puts ``1 - eps`` on the true class and
            spreads ``eps`` evenly over the other ``C - 1`` classes. If false,
            plain ``F.cross_entropy`` is used.
        eps: total probability mass moved off the true class when smoothing.

    Returns:
        Scalar tensor: the loss averaged over the N samples.
    """
    gold = gold.contiguous().view(-1)
    if smoothing:
        n_class = pred.size(1)
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)
        loss = -(one_hot * log_prb).sum(dim=1).mean()
    else:
        loss = F.cross_entropy(pred, gold, reduction="mean")
    return loss