bug fixes for i/o and low val set
Former-commit-id: bdcad8300e3c930b43b976ccd2562f27c9867892
This commit is contained in:
parent
c912d9b726
commit
6438b1dcdd
|
@ -154,7 +154,7 @@ You can also download it using the helper script:
|
||||||
bash scripts/download_data.sh
|
bash scripts/download_data.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. For Carvana, images are RGB and masks are black and white.
|
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folder should not contain any sub-folder or any other files, due to the greedy data-loader). For Carvana, images are RGB and masks are black and white.
|
||||||
|
|
||||||
You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`.
|
You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`.
|
||||||
|
|
||||||
|
|
|
@ -35,4 +35,8 @@ def evaluate(net, dataloader, device):
|
||||||
|
|
||||||
|
|
||||||
net.train()
|
net.train()
|
||||||
|
|
||||||
|
# Fixes a potential division by zero error
|
||||||
|
if num_val_batches == 0:
|
||||||
|
return dice_score
|
||||||
return dice_score / num_val_batches
|
return dice_score / num_val_batches
|
||||||
|
|
44
train.py
44
train.py
|
@ -111,29 +111,31 @@ def train_net(net,
|
||||||
pbar.set_postfix(**{'loss (batch)': loss.item()})
|
pbar.set_postfix(**{'loss (batch)': loss.item()})
|
||||||
|
|
||||||
# Evaluation round
|
# Evaluation round
|
||||||
if global_step % (n_train // (10 * batch_size)) == 0:
|
division_step = (n_train // (10 * batch_size))
|
||||||
histograms = {}
|
if division_step > 0:
|
||||||
for tag, value in net.named_parameters():
|
if global_step % division_step == 0:
|
||||||
tag = tag.replace('/', '.')
|
histograms = {}
|
||||||
histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
|
for tag, value in net.named_parameters():
|
||||||
histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
|
tag = tag.replace('/', '.')
|
||||||
|
histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
|
||||||
|
histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
|
||||||
|
|
||||||
val_score = evaluate(net, val_loader, device)
|
val_score = evaluate(net, val_loader, device)
|
||||||
scheduler.step(val_score)
|
scheduler.step(val_score)
|
||||||
|
|
||||||
logging.info('Validation Dice score: {}'.format(val_score))
|
logging.info('Validation Dice score: {}'.format(val_score))
|
||||||
experiment.log({
|
experiment.log({
|
||||||
'learning rate': optimizer.param_groups[0]['lr'],
|
'learning rate': optimizer.param_groups[0]['lr'],
|
||||||
'validation Dice': val_score,
|
'validation Dice': val_score,
|
||||||
'images': wandb.Image(images[0].cpu()),
|
'images': wandb.Image(images[0].cpu()),
|
||||||
'masks': {
|
'masks': {
|
||||||
'true': wandb.Image(true_masks[0].float().cpu()),
|
'true': wandb.Image(true_masks[0].float().cpu()),
|
||||||
'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
|
'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
|
||||||
},
|
},
|
||||||
'step': global_step,
|
'step': global_step,
|
||||||
'epoch': epoch,
|
'epoch': epoch,
|
||||||
**histograms
|
**histograms
|
||||||
})
|
})
|
||||||
|
|
||||||
if save_checkpoint:
|
if save_checkpoint:
|
||||||
Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
|
Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
Loading…
Reference in a new issue