bug fixes for i/o and low val set

Former-commit-id: bdcad8300e3c930b43b976ccd2562f27c9867892
Arka 2021-10-24 17:07:54 -04:00
parent c912d9b726
commit 6438b1dcdd
3 changed files with 28 additions and 22 deletions


@@ -154,7 +154,7 @@ You can also download it using the helper script:
 bash scripts/download_data.sh
 ```
-The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. For Carvana, images are RGB and masks are black and white.
+The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folders should not contain any sub-folders or any other files, due to the greedy data loader). For Carvana, images are RGB and masks are black and white.
 You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`.
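
Why the README note matters: a "greedy" loader indexes every entry it finds under `data/imgs` and `data/masks`. A minimal sketch of that failure mode, assuming the file list is built with a plain `os.listdir` as in many dataset classes (the actual logic lives in `utils/data_loading.py`):

```python
import os

def list_image_ids(images_dir):
    # A greedy listing keeps every entry in the folder, including sub-folders
    # and unrelated files such as .DS_Store or a stray notes.txt.
    return [os.path.splitext(name)[0] for name in os.listdir(images_dir)
            if not name.startswith('.')]

# Any sub-folder or extra file ends up in the id list, so the later
# image/mask lookup for that id fails or pairs the wrong files.
```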


@@ -35,4 +35,8 @@ def evaluate(net, dataloader, device):
     net.train()

+    # Fixes a potential division by zero error
+    if num_val_batches == 0:
+        return dice_score
     return dice_score / num_val_batches
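
Why `num_val_batches` can be zero (the "low val set" in the commit title): with a tiny dataset, a percentage-based split can round down to an empty validation set, so the dataloader yields no batches. A hedged illustration with made-up numbers:

```python
# Hypothetical numbers, for illustration only.
dataset_size = 8
val_percent = 0.1

n_val = int(dataset_size * val_percent)  # int(0.8) == 0 -> empty validation set
print(n_val)
# With no validation batches, num_val_batches == 0 and the final
# `dice_score / num_val_batches` would raise ZeroDivisionError,
# which the early return added above avoids.
```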


@@ -111,29 +111,31 @@ def train_net(net,
                 pbar.set_postfix(**{'loss (batch)': loss.item()})

                 # Evaluation round
-                if global_step % (n_train // (10 * batch_size)) == 0:
-                    histograms = {}
-                    for tag, value in net.named_parameters():
-                        tag = tag.replace('/', '.')
-                        histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
-                        histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
-                    val_score = evaluate(net, val_loader, device)
-                    scheduler.step(val_score)
-                    logging.info('Validation Dice score: {}'.format(val_score))
-                    experiment.log({
-                        'learning rate': optimizer.param_groups[0]['lr'],
-                        'validation Dice': val_score,
-                        'images': wandb.Image(images[0].cpu()),
-                        'masks': {
-                            'true': wandb.Image(true_masks[0].float().cpu()),
-                            'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
-                        },
-                        'step': global_step,
-                        'epoch': epoch,
-                        **histograms
-                    })
+                division_step = (n_train // (10 * batch_size))
+                if division_step > 0:
+                    if global_step % division_step == 0:
+                        histograms = {}
+                        for tag, value in net.named_parameters():
+                            tag = tag.replace('/', '.')
+                            histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
+                            histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
+                        val_score = evaluate(net, val_loader, device)
+                        scheduler.step(val_score)
+                        logging.info('Validation Dice score: {}'.format(val_score))
+                        experiment.log({
+                            'learning rate': optimizer.param_groups[0]['lr'],
+                            'validation Dice': val_score,
+                            'images': wandb.Image(images[0].cpu()),
+                            'masks': {
+                                'true': wandb.Image(true_masks[0].float().cpu()),
+                                'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
+                            },
+                            'step': global_step,
+                            'epoch': epoch,
+                            **histograms
+                        })

         if save_checkpoint:
             Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
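
The train.py change guards the same small-dataset case: if `n_train < 10 * batch_size`, the old expression `global_step % (n_train // (10 * batch_size))` takes a modulus by zero on every step. A minimal sketch of the guarded pattern, with hypothetical sizes:

```python
# Hypothetical sizes, for illustration only.
n_train = 30
batch_size = 4
global_step = 7

division_step = n_train // (10 * batch_size)  # 30 // 40 == 0 for small datasets

# The commit wraps the evaluation round in a check so the modulus is never
# taken with a zero divisor.
if division_step > 0:
    if global_step % division_step == 0:
        print('run evaluation round')
```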