mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
LR scheduler + TensorBoard weights and grad hists
Former-commit-id: 0e7a96225ecc7761419aa67fef861ff6afae4c55
This commit is contained in:
parent
54ba0e5d54
commit
d40a71737d
14
train.py
14
train.py
|
@ -25,7 +25,7 @@ def train_net(net,
|
|||
device,
|
||||
epochs=5,
|
||||
batch_size=1,
|
||||
lr=0.1,
|
||||
lr=0.001,
|
||||
val_percent=0.1,
|
||||
save_cp=True,
|
||||
img_scale=0.5):
|
||||
|
@ -51,7 +51,8 @@ def train_net(net,
|
|||
Images scaling: {img_scale}
|
||||
''')
|
||||
|
||||
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
|
||||
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
|
||||
if net.n_classes > 1:
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
else:
|
||||
|
@ -83,16 +84,23 @@ def train_net(net,
|
|||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_value_(net.parameters(), 0.1)
|
||||
optimizer.step()
|
||||
|
||||
pbar.update(imgs.shape[0])
|
||||
global_step += 1
|
||||
if global_step % (len(dataset) // (10 * batch_size)) == 0:
|
||||
for tag, value in net.named_parameters():
|
||||
tag = tag.replace('.', '/')
|
||||
writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
|
||||
writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
|
||||
val_score = eval_net(net, val_loader, device)
|
||||
scheduler.step(val_score)
|
||||
writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
|
||||
|
||||
if net.n_classes > 1:
|
||||
logging.info('Validation cross entropy: {}'.format(val_score))
|
||||
writer.add_scalar('Loss/test', val_score, global_step)
|
||||
|
||||
else:
|
||||
logging.info('Validation Dice Coeff: {}'.format(val_score))
|
||||
writer.add_scalar('Dice/test', val_score, global_step)
|
||||
|
|
Loading…
Reference in a new issue