Merge pull request #308 from Arka161/bug_fixes

Fixed bugs for no validation set, and potential module zero error

Former-commit-id: 4f283f1b7e47992a2af1bc738c54562d5cd117c8
This commit is contained in:
milesial 2021-10-25 15:05:03 +02:00 committed by GitHub
commit 229dd44fca
3 changed files with 28 additions and 22 deletions

View file

@ -154,7 +154,7 @@ You can also download it using the helper script:
bash scripts/download_data.sh bash scripts/download_data.sh
``` ```
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. For Carvana, images are RGB and masks are black and white. The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folder should not contain any sub-folder or any other files, due to the greedy data-loader). For Carvana, images are RGB and masks are black and white.
You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`. You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`.

View file

@ -35,4 +35,8 @@ def evaluate(net, dataloader, device):
net.train() net.train()
# Fixes a potential division by zero error
if num_val_batches == 0:
return dice_score
return dice_score / num_val_batches return dice_score / num_val_batches

View file

@ -111,29 +111,31 @@ def train_net(net,
pbar.set_postfix(**{'loss (batch)': loss.item()}) pbar.set_postfix(**{'loss (batch)': loss.item()})
# Evaluation round # Evaluation round
if global_step % (n_train // (10 * batch_size)) == 0: division_step = (n_train // (10 * batch_size))
histograms = {} if division_step > 0:
for tag, value in net.named_parameters(): if global_step % division_step == 0:
tag = tag.replace('/', '.') histograms = {}
histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu()) for tag, value in net.named_parameters():
histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu()) tag = tag.replace('/', '.')
histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
val_score = evaluate(net, val_loader, device) val_score = evaluate(net, val_loader, device)
scheduler.step(val_score) scheduler.step(val_score)
logging.info('Validation Dice score: {}'.format(val_score)) logging.info('Validation Dice score: {}'.format(val_score))
experiment.log({ experiment.log({
'learning rate': optimizer.param_groups[0]['lr'], 'learning rate': optimizer.param_groups[0]['lr'],
'validation Dice': val_score, 'validation Dice': val_score,
'images': wandb.Image(images[0].cpu()), 'images': wandb.Image(images[0].cpu()),
'masks': { 'masks': {
'true': wandb.Image(true_masks[0].float().cpu()), 'true': wandb.Image(true_masks[0].float().cpu()),
'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()), 'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
}, },
'step': global_step, 'step': global_step,
'epoch': epoch, 'epoch': epoch,
**histograms **histograms
}) })
if save_checkpoint: if save_checkpoint:
Path(dir_checkpoint).mkdir(parents=True, exist_ok=True) Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)