"""Some helper functions for PyTorch, including: - get_mean_and_std: calculate the mean and std value of dataset. - msr_init: net parameter initialization. - progress_bar: progress bar mimic xlua.progress. """ import errno import os import random import shutil import sys import time import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init __all__ = [ "get_mean_and_std", "init_params", "mkdir_p", "AverageMeter", "progress_bar", "save_model", "save_args", "set_seed", "IOStream", "cal_loss", ] def get_mean_and_std(dataset): """Compute the mean and std value of dataset.""" dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) mean = torch.zeros(3) std = torch.zeros(3) print("==> Computing mean and std..") for inputs, _targets in dataloader: for i in range(3): mean[i] += inputs[:, i, :, :].mean() std[i] += inputs[:, i, :, :].std() mean.div_(len(dataset)) std.div_(len(dataset)) return mean, std def init_params(net): """Init layer parameters.""" for m in net.modules(): if isinstance(m, nn.Conv2d): init.kaiming_normal(m.weight, mode="fan_out") if m.bias: init.constant(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): init.constant(m.weight, 1) init.constant(m.bias, 0) elif isinstance(m, nn.Linear): init.normal(m.weight, std=1e-3) if m.bias: init.constant(m.bias, 0) def mkdir_p(path): """Make dir if not exist.""" try: os.makedirs(path) except OSError as exc: # Python >2.5 if exc.errno == errno.EEXIST and os.path.isdir(path): pass else: raise class AverageMeter: """Computes and stores the average and current value Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262. """ def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count TOTAL_BAR_LENGTH = 65.0 last_time = time.time() begin_time = last_time def progress_bar(current, total, msg=None): global last_time, begin_time if current == 0: begin_time = time.time() # Reset for new bar. cur_len = int(TOTAL_BAR_LENGTH * current / total) rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 sys.stdout.write(" [") for _i in range(cur_len): sys.stdout.write("=") sys.stdout.write(">") for _i in range(rest_len): sys.stdout.write(".") sys.stdout.write("]") cur_time = time.time() step_time = cur_time - last_time last_time = cur_time tot_time = cur_time - begin_time L = [] L.append(" Step: %s" % format_time(step_time)) L.append(" | Tot: %s" % format_time(tot_time)) if msg: L.append(" | " + msg) msg = "".join(L) sys.stdout.write(msg) # for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): # sys.stdout.write(' ') # Go back to the center of the bar. # for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): # sys.stdout.write('\b') sys.stdout.write(" %d/%d " % (current + 1, total)) if current < total - 1: sys.stdout.write("\r") else: sys.stdout.write("\n") sys.stdout.flush() def format_time(seconds): days = int(seconds / 3600 / 24) seconds = seconds - days * 3600 * 24 hours = int(seconds / 3600) seconds = seconds - hours * 3600 minutes = int(seconds / 60) seconds = seconds - minutes * 60 secondsf = int(seconds) seconds = seconds - secondsf millis = int(seconds * 1000) f = "" i = 1 if days > 0: f += str(days) + "D" i += 1 if hours > 0 and i <= 2: f += str(hours) + "h" i += 1 if minutes > 0 and i <= 2: f += str(minutes) + "m" i += 1 if secondsf > 0 and i <= 2: f += str(secondsf) + "s" i += 1 if millis > 0 and i <= 2: f += str(millis) + "ms" i += 1 if f == "": f = "0ms" return f def save_model(net, epoch, path, acc, is_best, **kwargs): state = { "net": net.state_dict(), "epoch": epoch, "acc": acc, } for key, value in kwargs.items(): state[key] = value filepath = os.path.join(path, "last_checkpoint.pth") torch.save(state, filepath) if is_best: shutil.copyfile(filepath, os.path.join(path, "best_checkpoint.pth")) def save_args(args): file = open(os.path.join(args.checkpoint, "args.txt"), "w") for k, v in vars(args).items(): file.write(f"{k}:\t {v}\n") file.close() def set_seed(seed=None): if seed is None: return random.seed(seed) os.environ["PYTHONHASHSEED"] = "%s" % seed np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True # create a file and write the text into it class IOStream: def __init__(self, path): self.f = open(path, "a") def cprint(self, text): print(text) self.f.write(text + "\n") self.f.flush() def close(self): self.f.close() def cal_loss(pred, gold, smoothing=True): """Calculate cross entropy loss, apply label smoothing if needed.""" gold = gold.contiguous().view(-1) if smoothing: eps = 0.2 n_class = pred.size(1) one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) log_prb = F.log_softmax(pred, dim=1) loss = -(one_hot * log_prb).sum(dim=1).mean() else: loss = F.cross_entropy(pred, gold, reduction="mean") return loss