mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
LR scheduler + TensorBoard weights and grad hists
Former-commit-id: 0e7a96225ecc7761419aa67fef861ff6afae4c55
This commit is contained in:
parent
54ba0e5d54
commit
d40a71737d
14
train.py
14
train.py
|
@ -25,7 +25,7 @@ def train_net(net,
|
||||||
device,
|
device,
|
||||||
epochs=5,
|
epochs=5,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
lr=0.1,
|
lr=0.001,
|
||||||
val_percent=0.1,
|
val_percent=0.1,
|
||||||
save_cp=True,
|
save_cp=True,
|
||||||
img_scale=0.5):
|
img_scale=0.5):
|
||||||
|
@ -51,7 +51,8 @@ def train_net(net,
|
||||||
Images scaling: {img_scale}
|
Images scaling: {img_scale}
|
||||||
''')
|
''')
|
||||||
|
|
||||||
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
|
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
|
||||||
|
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
|
||||||
if net.n_classes > 1:
|
if net.n_classes > 1:
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
else:
|
else:
|
||||||
|
@ -83,16 +84,23 @@ def train_net(net,
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
nn.utils.clip_grad_value_(net.parameters(), 0.1)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
pbar.update(imgs.shape[0])
|
pbar.update(imgs.shape[0])
|
||||||
global_step += 1
|
global_step += 1
|
||||||
if global_step % (len(dataset) // (10 * batch_size)) == 0:
|
if global_step % (len(dataset) // (10 * batch_size)) == 0:
|
||||||
|
for tag, value in net.named_parameters():
|
||||||
|
tag = tag.replace('.', '/')
|
||||||
|
writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
|
||||||
|
writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
|
||||||
val_score = eval_net(net, val_loader, device)
|
val_score = eval_net(net, val_loader, device)
|
||||||
|
scheduler.step(val_score)
|
||||||
|
writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
|
||||||
|
|
||||||
if net.n_classes > 1:
|
if net.n_classes > 1:
|
||||||
logging.info('Validation cross entropy: {}'.format(val_score))
|
logging.info('Validation cross entropy: {}'.format(val_score))
|
||||||
writer.add_scalar('Loss/test', val_score, global_step)
|
writer.add_scalar('Loss/test', val_score, global_step)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logging.info('Validation Dice Coeff: {}'.format(val_score))
|
logging.info('Validation Dice Coeff: {}'.format(val_score))
|
||||||
writer.add_scalar('Dice/test', val_score, global_step)
|
writer.add_scalar('Dice/test', val_score, global_step)
|
||||||
|
|
Loading…
Reference in a new issue