LR scheduler + TensorBoard weight and gradient histograms

Former-commit-id: 0e7a96225ecc7761419aa67fef861ff6afae4c55
milesial 2020-03-15 21:38:51 -07:00
parent 54ba0e5d54
commit d40a71737d


@@ -25,7 +25,7 @@ def train_net(net,
               device,
               epochs=5,
               batch_size=1,
-              lr=0.1,
+              lr=0.001,
               val_percent=0.1,
               save_cp=True,
               img_scale=0.5):
@@ -51,7 +51,8 @@ def train_net(net,
         Images scaling:  {img_scale}
     ''')

-    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
+    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
     if net.n_classes > 1:
         criterion = nn.CrossEntropyLoss()
     else:
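The hunk above adds a momentum term to RMSprop and attaches a ReduceLROnPlateau scheduler keyed to the validation score ('min' when tracking multi-class cross entropy, 'max' when tracking Dice). Below is a minimal runnable sketch of that wiring, assuming a stand-in SmallNet model and a random placeholder in place of the repository's eval_net; it is an illustration, not the repository's code.

import torch
import torch.nn as nn
import torch.optim as optim

class SmallNet(nn.Module):
    """Stand-in for the UNet; only here to make the sketch runnable."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

net = SmallNet()
n_classes = 1  # single-class output -> the tracked metric is Dice, so the scheduler maximizes

optimizer = optim.RMSprop(net.parameters(), lr=0.001, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min' if n_classes > 1 else 'max', patience=2)

for epoch in range(5):
    imgs = torch.randn(2, 3, 32, 32)
    true_masks = torch.rand(2, 1, 32, 32)
    loss = nn.functional.binary_cross_entropy_with_logits(net(imgs), true_masks)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    val_score = torch.rand(1).item()  # placeholder for eval_net(net, val_loader, device)
    scheduler.step(val_score)         # LR is cut after `patience` evaluations without improvement
    print(epoch, optimizer.param_groups[0]['lr'])

ReduceLROnPlateau only reacts when step() receives the metric, so it has to be called after each validation pass, which is what the next hunk does.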
@@ -83,16 +84,23 @@ def train_net(net,
                 optimizer.zero_grad()
                 loss.backward()
+                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                 optimizer.step()

                 pbar.update(imgs.shape[0])
                 global_step += 1
                 if global_step % (len(dataset) // (10 * batch_size)) == 0:
+                    for tag, value in net.named_parameters():
+                        tag = tag.replace('.', '/')
+                        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
+                        writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
                     val_score = eval_net(net, val_loader, device)
+                    scheduler.step(val_score)
+                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
                     if net.n_classes > 1:
                         logging.info('Validation cross entropy: {}'.format(val_score))
                         writer.add_scalar('Loss/test', val_score, global_step)
                     else:
                         logging.info('Validation Dice Coeff: {}'.format(val_score))
                         writer.add_scalar('Dice/test', val_score, global_step)
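The last hunk clips gradients element-wise before optimizer.step() and, at each validation interval, writes per-layer weight and gradient histograms plus the current learning rate to TensorBoard. A self-contained sketch of the same logging pattern, assuming a toy Sequential model, a random batch, and a hypothetical 'runs/demo' log directory rather than the repository's full training loop:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=3, padding=1),
)
optimizer = optim.RMSprop(net.parameters(), lr=0.001, weight_decay=1e-8, momentum=0.9)
writer = SummaryWriter(log_dir='runs/demo')  # hypothetical log directory
global_step = 1

imgs = torch.randn(2, 3, 64, 64)
true_masks = torch.rand(2, 1, 64, 64)
loss = nn.functional.binary_cross_entropy_with_logits(net(imgs), true_masks)

optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_value_(net.parameters(), 0.1)  # clamp every gradient element to [-0.1, 0.1]
optimizer.step()

# One histogram per parameter tensor and per gradient tensor, grouped by layer
# name ('.' -> '/' so TensorBoard nests them under weights/ and grads/).
for tag, value in net.named_parameters():
    tag = tag.replace('.', '/')
    writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
    writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)

# The current LR as a scalar curve, so ReduceLROnPlateau drops are visible.
writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
writer.close()

Note that clip_grad_value_ clamps each gradient element independently, unlike clip_grad_norm_, which rescales the whole gradient vector; the element-wise variant is what the 0.1 cap in this commit uses.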