"""Some helper functions for PyTorch, including:

- get_mean_and_std: calculate the mean and std value of a dataset.
- init_params: net parameter initialization.
- progress_bar: progress bar mimicking xlua.progress.
- save_model / save_args: checkpoint and argument persistence.
- set_seed: seed every random number generator used during training.
- IOStream: tee text to stdout and a log file.
- cal_loss: cross entropy with optional label smoothing.
"""
import errno
import math
import os
import random
import shutil
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.utils.data
from torch.autograd import Variable

__all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter',
           'progress_bar', 'save_model', "save_args", "set_seed", "IOStream", "cal_loss"]


def get_mean_and_std(dataset):
    """Compute per-channel mean and std statistics of a dataset.

    The dataset is iterated one sample at a time; for every channel the
    sample's mean and std are accumulated and finally divided by
    ``len(dataset)``.  The returned ``std`` is therefore the *average of the
    per-sample standard deviations*, not the standard deviation of all pixels
    pooled together.

    Args:
        dataset: A dataset yielding ``(inputs, targets)`` pairs.  ``inputs``
            is indexed as ``inputs[:, c, :, :]`` after batching, i.e. each
            sample is expected to be channel-first with two trailing dims.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel.

    Raises:
        ValueError: If the loader yields no samples.
    """
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = None
    std = None
    print('==> Computing mean and std..')
    for inputs, _targets in dataloader:
        if mean is None:
            # Size the accumulators from the data instead of assuming RGB.
            channels = inputs.size(1)
            mean = torch.zeros(channels)
            std = torch.zeros(channels)
        for i in range(mean.numel()):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    if mean is None:
        raise ValueError('cannot compute mean and std of an empty dataset')
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std


def init_params(net):
    """Initialise the parameters of every Conv2d, BatchNorm2d and Linear layer.

    * ``nn.Conv2d``: Kaiming-normal weights (``mode='fan_out'``), zero bias.
    * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: normal weights with ``std=1e-3``, zero bias.

    Layers built without a bias (or BatchNorm without affine parameters) have
    ``None`` in place of the tensor and are skipped for that parameter.

    Args:
        net: Module whose sub-modules (``net.modules()``) are initialised
            in place.
    """
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # ``if m.bias:`` would evaluate the tensor's truth value, which
            # raises for tensors with more than one element.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


def mkdir_p(path):
    """Create directory ``path`` (and missing parents) if it does not exist.

    An already existing directory is not an error.

    Raises:
        OSError: If creation fails, including when ``path`` exists but is not
            a directory.
    """
    os.makedirs(path, exist_ok=True)


class AverageMeter(object):
    """Computes and stores the average and current value.

    Imported from
    https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # running weighted average (sum / count)
        self.sum = 0    # weighted sum of all values
        self.count = 0  # total weight seen so far

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty first batch) must not divide by 0.
        self.avg = self.sum / self.count if self.count else 0


TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time


def progress_bar(current, total, msg=None):
    """Draw a one-line textual progress bar on stdout.

    The line is terminated with a carriage return so the next call overwrites
    it; the final step (``current >= total - 1``) ends with a newline instead.
    Timing is tracked in the module globals ``last_time`` / ``begin_time``;
    ``begin_time`` is reset whenever ``current == 0``.

    Args:
        current: Zero-based index of the step that just finished.
        total: Total number of steps (must be non-zero).
        msg: Optional text appended after the timing information.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH * current / total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    # Assemble the whole line first so it reaches the stream in one write.
    # A negative ``rest_len`` yields an empty string, like an empty range.
    pieces = [
        ' [', '=' * cur_len, '>', '.' * rest_len, ']',
        ' Step: %s' % format_time(step_time),
        ' | Tot: %s' % format_time(tot_time),
    ]
    if msg:
        pieces.append(' | ' + msg)
    pieces.append(' %d/%d ' % (current + 1, total))
    pieces.append('\r' if current < total - 1 else '\n')

    sys.stdout.write(''.join(pieces))
    sys.stdout.flush()


def format_time(seconds):
    """Format a duration in seconds using its two most significant units.

    Units are days (``D``), hours (``h``), minutes (``m``), seconds (``s``)
    and milliseconds (``ms``); zero-valued units are skipped and at most two
    units are emitted, e.g. ``3661.5 -> '1h1m'``.  A duration with no
    positive unit is rendered as ``'0ms'``.
    """
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds * 1000)

    units = ((days, 'D'), (hours, 'h'), (minutes, 'm'),
             (secondsf, 's'), (millis, 'ms'))
    parts = []
    for value, suffix in units:
        if value > 0 and len(parts) < 2:
            parts.append(str(value) + suffix)
    return ''.join(parts) or '0ms'


def save_model(net, epoch, path, acc, is_best, **kwargs):
    """Save a training checkpoint to ``path``.

    Writes ``last_checkpoint.pth`` containing the keys ``'net'`` (the model's
    ``state_dict()``), ``'epoch'`` and ``'acc'`` plus every extra keyword
    argument.  When ``is_best`` is true the file is also copied to
    ``best_checkpoint.pth`` in the same directory.

    Args:
        net: Model providing ``state_dict()``.
        epoch: Epoch value stored under ``'epoch'``.
        path: Existing directory that receives the checkpoint files.
        acc: Accuracy value stored under ``'acc'``.
        is_best: Whether to also store the checkpoint as the best one.
        **kwargs: Additional entries stored in the checkpoint; a key equal to
            one of the standard keys overrides it.
    """
    state = {
        'net': net.state_dict(),
        'epoch': epoch,
        'acc': acc
    }
    state.update(kwargs)
    filepath = os.path.join(path, "last_checkpoint.pth")
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(path, 'best_checkpoint.pth'))


def save_args(args):
    """Dump every attribute of ``args`` to ``<args.checkpoint>/args.txt``.

    One ``name:<TAB> value`` line is written per attribute of ``vars(args)``.
    The file is closed even if writing fails part-way.
    """
    with open(os.path.join(args.checkpoint, 'args.txt'), "w") as handle:
        for k, v in vars(args).items():
            handle.write(f"{k}:\t {v}\n")


def set_seed(seed=None):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) generators.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic,
    non-benchmark mode.  Does nothing when ``seed`` is ``None``.
    """
    if seed is None:
        return
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = ("%s" % seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class IOStream():
    """Print text and append it to a log file at the same time.

    The file is opened in append mode on construction.  The object can be
    used as a context manager, in which case the file is closed on exit.
    """

    def __init__(self, path):
        self.f = open(path, 'a')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow exceptions

    def cprint(self, text):
        """Print ``text`` and append it, newline-terminated, to the file."""
        print(text)
        self.f.write(text + '\n')
        self.f.flush()

    def close(self):
        """Close the underlying log file."""
        self.f.close()


def cal_loss(pred, gold, smoothing=True, eps=0.2):
    """Calculate cross entropy loss, apply label smoothing if needed.

    Args:
        pred: Class scores with classes along dim 1 (``pred.size(1)`` is the
            number of classes); one row per target.
        gold: Integer class indices; flattened to 1-D before use.
        smoothing: When true, the target distribution puts ``1 - eps`` on the
            true class and spreads ``eps`` uniformly over the other classes.
        eps: Label-smoothing mass; only used when ``smoothing`` is true.

    Returns:
        Scalar tensor holding the mean loss over all targets.
    """
    gold = gold.contiguous().view(-1)

    if smoothing:
        n_class = pred.size(1)
        # With a single class there are no "other" classes to receive the
        # smoothing mass; clamp the divisor so 0 * eps / 0 cannot yield NaN.
        n_other = max(n_class - 1, 1)

        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / n_other
        log_prb = F.log_softmax(pred, dim=1)

        loss = -(one_hot * log_prb).sum(dim=1).mean()
    else:
        loss = F.cross_entropy(pred, gold, reduction='mean')

    return loss