From d40a71737dcdcf4fd372a106e9108ab4df628b73 Mon Sep 17 00:00:00 2001
From: milesial
Date: Sun, 15 Mar 2020 21:38:51 -0700
Subject: [PATCH] LR scheduler + TensorBoard weights and grad hists

Former-commit-id: 0e7a96225ecc7761419aa67fef861ff6afae4c55
---
 train.py | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/train.py b/train.py
index 90a60ae..4aea941 100644
--- a/train.py
+++ b/train.py
@@ -25,7 +25,7 @@ def train_net(net,
               device,
               epochs=5,
               batch_size=1,
-              lr=0.1,
+              lr=0.001,
               val_percent=0.1,
               save_cp=True,
               img_scale=0.5):
@@ -51,7 +51,8 @@ def train_net(net,
         Images scaling:  {img_scale}
     ''')
 
-    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
+    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
     if net.n_classes > 1:
         criterion = nn.CrossEntropyLoss()
     else:
@@ -83,16 +84,23 @@ def train_net(net,
 
                 optimizer.zero_grad()
                 loss.backward()
+                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                 optimizer.step()
 
                 pbar.update(imgs.shape[0])
                 global_step += 1
                 if global_step % (len(dataset) // (10 * batch_size)) == 0:
+                    for tag, value in net.named_parameters():
+                        tag = tag.replace('.', '/')
+                        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
+                        writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
                     val_score = eval_net(net, val_loader, device)
+                    scheduler.step(val_score)
+                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
+
                     if net.n_classes > 1:
                         logging.info('Validation cross entropy: {}'.format(val_score))
                         writer.add_scalar('Loss/test', val_score, global_step)
-
                     else:
                         logging.info('Validation Dice Coeff: {}'.format(val_score))
                         writer.add_scalar('Dice/test', val_score, global_step)