mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
bug fixes for i/o and low val set
Former-commit-id: bdcad8300e3c930b43b976ccd2562f27c9867892
This commit is contained in:
parent
c912d9b726
commit
6438b1dcdd
|
@ -154,7 +154,7 @@ You can also download it using the helper script:
|
|||
bash scripts/download_data.sh
|
||||
```
|
||||
|
||||
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. For Carvana, images are RGB and masks are black and white.
|
||||
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folder should not contain any sub-folder or any other files, due to the greedy data-loader). For Carvana, images are RGB and masks are black and white.
|
||||
|
||||
You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`.
|
||||
|
||||
|
|
|
@ -35,4 +35,8 @@ def evaluate(net, dataloader, device):
|
|||
|
||||
|
||||
net.train()
|
||||
|
||||
# Fixes a potential division by zero error
|
||||
if num_val_batches == 0:
|
||||
return dice_score
|
||||
return dice_score / num_val_batches
|
||||
|
|
44
train.py
44
train.py
|
@ -111,29 +111,31 @@ def train_net(net,
|
|||
pbar.set_postfix(**{'loss (batch)': loss.item()})
|
||||
|
||||
# Evaluation round
|
||||
if global_step % (n_train // (10 * batch_size)) == 0:
|
||||
histograms = {}
|
||||
for tag, value in net.named_parameters():
|
||||
tag = tag.replace('/', '.')
|
||||
histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
|
||||
histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
|
||||
division_step = (n_train // (10 * batch_size))
|
||||
if division_step > 0:
|
||||
if global_step % division_step == 0:
|
||||
histograms = {}
|
||||
for tag, value in net.named_parameters():
|
||||
tag = tag.replace('/', '.')
|
||||
histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
|
||||
histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
|
||||
|
||||
val_score = evaluate(net, val_loader, device)
|
||||
scheduler.step(val_score)
|
||||
val_score = evaluate(net, val_loader, device)
|
||||
scheduler.step(val_score)
|
||||
|
||||
logging.info('Validation Dice score: {}'.format(val_score))
|
||||
experiment.log({
|
||||
'learning rate': optimizer.param_groups[0]['lr'],
|
||||
'validation Dice': val_score,
|
||||
'images': wandb.Image(images[0].cpu()),
|
||||
'masks': {
|
||||
'true': wandb.Image(true_masks[0].float().cpu()),
|
||||
'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
|
||||
},
|
||||
'step': global_step,
|
||||
'epoch': epoch,
|
||||
**histograms
|
||||
})
|
||||
logging.info('Validation Dice score: {}'.format(val_score))
|
||||
experiment.log({
|
||||
'learning rate': optimizer.param_groups[0]['lr'],
|
||||
'validation Dice': val_score,
|
||||
'images': wandb.Image(images[0].cpu()),
|
||||
'masks': {
|
||||
'true': wandb.Image(true_masks[0].float().cpu()),
|
||||
'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
|
||||
},
|
||||
'step': global_step,
|
||||
'epoch': epoch,
|
||||
**histograms
|
||||
})
|
||||
|
||||
if save_checkpoint:
|
||||
Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
|
|
Loading…
Reference in a new issue