127 lines
4.3 KiB
Python
127 lines
4.3 KiB
Python
|
# A simple torch style logger
|
||
|
# (C) Wei YANG 2017
|
||
|
from __future__ import absolute_import
|
||
|
import matplotlib.pyplot as plt
|
||
|
import os
|
||
|
import sys
|
||
|
import numpy as np
|
||
|
|
||
|
# Public names exported by a star-import of this module; the plot_overlap
# helper defined below is not part of this list.
__all__ = ['Logger', 'LoggerMonitor', 'savefig']
|
||
|
|
||
|
def savefig(fname, dpi=None):
    '''Save the current matplotlib figure to ``fname``.

    Args:
        fname: destination passed straight through to ``plt.savefig``.
        dpi: resolution in dots per inch; 150 is used when None is given.
    '''
    # Identity test: ``== None`` would invoke a custom __eq__ on exotic inputs.
    dpi = 150 if dpi is None else dpi
    plt.savefig(fname, dpi=dpi)
|
||
|
|
||
|
def plot_overlap(logger, names=None):
    '''Plot selected columns of ``logger`` with ``plt.plot`` and return their labels.

    Args:
        logger: object exposing ``names``, ``numbers`` (a mapping from column
            name to a sequence of values) and ``title``.
        names: iterable of column names to draw; all of ``logger.names``
            when None.

    Returns:
        A list with one legend label per plotted column, formatted as
        ``title(name)``.
    '''
    # ``is None`` instead of ``== None``: equality on an array-like ``names``
    # is element-wise and cannot be used as a plain condition.
    names = logger.names if names is None else names
    numbers = logger.numbers
    for name in names:
        # x axis is simply the row index of each logged value.
        x = np.arange(len(numbers[name]))
        plt.plot(x, np.asarray(numbers[name]))
    return [logger.title + '(' + name + ')' for name in names]
|
||
|
|
||
|
class Logger(object):
    '''Save training process to log file with simple plot function.

    The log file is tab separated: the first line holds the column names and
    each following line holds one row of values formatted with six decimals.
    Every field, including the last one on a line, is followed by a tab.
    '''

    def __init__(self, fpath, title=None, resume=False):
        '''Create a logger backed by the file at ``fpath``.

        Args:
            fpath: path of the log file. When None, no file is opened and the
                logger only keeps values in memory.
            title: prefix used for legend labels in ``plot``; '' when None.
            resume: when True, read the names and values already stored in
                ``fpath`` and reopen it for appending; otherwise the file is
                opened in write mode (truncating any existing content).
        '''
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        # Defined up front so a logger without a file is still usable.
        self.names = []
        self.numbers = {}
        if fpath is None:
            return
        if resume:
            self._load(fpath)
            self.file = open(fpath, 'a')
        else:
            self.file = open(fpath, 'w')

    @staticmethod
    def _parse_value(text):
        '''Return ``text`` as a float, or unchanged when it is not numeric.'''
        try:
            return float(text)
        except ValueError:
            # Keep unparsable fields as raw strings instead of failing the
            # whole resume.
            return text

    def _load(self, fpath):
        '''Read the header and all previously logged rows from ``fpath``.'''
        # The context manager closes the read handle even if parsing raises.
        with open(fpath, 'r') as log_file:
            header = log_file.readline()
            # rstrip() also removes the trailing tab written after the last field.
            self.names = header.rstrip().split('\t')
            self.numbers = {name: [] for name in self.names}
            for line in log_file:
                if not line.strip():
                    # Skip blank lines rather than recording '' as a value.
                    continue
                fields = line.rstrip().split('\t')
                for name, field in zip(self.names, fields):
                    # Values are stored as text on disk; convert them back so
                    # resumed logs plot as numbers.
                    self.numbers[name].append(self._parse_value(field))

    def _write_row(self, fields):
        '''Write ``fields`` as one line, each followed by a tab, and flush.'''
        if self.file is None:
            return
        self.file.write(''.join(field + '\t' for field in fields) + '\n')
        self.file.flush()

    def set_names(self, names):
        '''Set the column names, reset stored values and write the header line.

        This behaves the same whether or not the logger was created with
        ``resume=True``: values held in memory are discarded and a header
        line is written to the file.

        Args:
            names: list of column names (strings).
        '''
        self.names = names
        # Start every column with an empty list of values.
        self.numbers = {name: [] for name in self.names}
        self._write_row(self.names)

    def append(self, numbers):
        '''Append one row of values, given in the same order as the names.

        Args:
            numbers: sequence with exactly one numeric value per column.
        '''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        # Format the whole row before writing so a value that cannot be
        # formatted does not leave a partial line in the file.
        fields = ["{0:.6f}".format(num) for num in numbers]
        self._write_row(fields)
        for name, num in zip(self.names, numbers):
            self.numbers[name].append(num)

    def plot(self, names=None):
        '''Plot the selected columns with ``plt.plot`` and add a legend and grid.

        Args:
            names: iterable of column names to draw; all columns when None.
        '''
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            # x axis is the row index of each logged value.
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        '''Close the underlying log file if one was opened.'''
        if self.file is not None:
            self.file.close()
|
||
|
|
||
|
class LoggerMonitor(object):
    '''Load and visualize multiple logs.'''

    def __init__(self, paths):
        '''Load every log file listed in ``paths``.

        Args:
            paths: dictionary of {title: filepath} pairs; each title becomes
                the legend prefix for the curves of that log.
        '''
        self.loggers = []
        for title, path in paths.items():
            logger = Logger(path, title=title, resume=True)
            # Resuming leaves the log open for appending, but the monitor only
            # needs the values already loaded into memory; close the handle so
            # it does not stay open for the lifetime of the monitor.
            logger.close()
            self.loggers.append(logger)

    def plot(self, names=None):
        '''Draw the selected columns of all loaded logs in a new figure.

        Args:
            names: iterable of column names to draw from every log; each
                log's full set of columns when None.
        '''
        plt.figure()
        # First cell of a 1x2 subplot grid; presumably this leaves room beside
        # the axes for the legend placed below.
        plt.subplot(121)
        legend_text = []
        for logger in self.loggers:
            legend_text += plot_overlap(logger, names)
        # Anchor the legend's upper-left corner (loc=2) just outside the right
        # edge of the axes.
        plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        plt.grid(True)
|
||
|
|
||
|
if __name__ == '__main__':
    # Example 1 (disabled): write a synthetic log with a single Logger and
    # plot it.
    #
    # logger = Logger('test.txt')
    # logger.set_names(['Train loss', 'Valid loss','Test loss'])
    #
    # length = 100
    # t = np.arange(length)
    # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
    #
    # for i in range(0, length):
    #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
    # logger.plot()

    # Example 2: overlay one column taken from several existing log files
    # and save the resulting figure.
    log_paths = {
        'resadvnet20': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
        'resadvnet32': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
        'resadvnet44': '/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
    }
    columns = ['Valid Acc.']

    log_monitor = LoggerMonitor(log_paths)
    log_monitor.plot(names=columns)
    savefig('test.eps')
|