diff --git a/README.md b/README.md index 3620a45..3516e10 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,7 @@ You can also download it using the helper script: bash scripts/download_data.sh ``` -The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. For Carvana, images are RGB and masks are black and white. +The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folder should not contain any sub-folder or any other files, due to the greedy data-loader). For Carvana, images are RGB and masks are black and white. You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`. diff --git a/evaluate.py b/evaluate.py index 504432f..2b4ebf9 100644 --- a/evaluate.py +++ b/evaluate.py @@ -35,4 +35,8 @@ def evaluate(net, dataloader, device): net.train() + + # Fixes a potential division by zero error + if num_val_batches == 0: + return dice_score return dice_score / num_val_batches diff --git a/train.py b/train.py index dc04e96..146373a 100644 --- a/train.py +++ b/train.py @@ -111,29 +111,31 @@ def train_net(net, pbar.set_postfix(**{'loss (batch)': loss.item()}) # Evaluation round - if global_step % (n_train // (10 * batch_size)) == 0: - histograms = {} - for tag, value in net.named_parameters(): - tag = tag.replace('/', '.') - histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu()) - histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu()) + division_step = (n_train // (10 * batch_size)) + if division_step > 0: + if global_step % division_step == 0: + histograms = {} + for tag, value in net.named_parameters(): + tag = tag.replace('/', '.') + histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu()) + histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu()) - val_score = evaluate(net, val_loader, device) - scheduler.step(val_score) + val_score = evaluate(net, val_loader, device) + scheduler.step(val_score) - logging.info('Validation Dice score: {}'.format(val_score)) - experiment.log({ - 'learning rate': optimizer.param_groups[0]['lr'], - 'validation Dice': val_score, - 'images': wandb.Image(images[0].cpu()), - 'masks': { - 'true': wandb.Image(true_masks[0].float().cpu()), - 'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()), - }, - 'step': global_step, - 'epoch': epoch, - **histograms - }) + logging.info('Validation Dice score: {}'.format(val_score)) + experiment.log({ + 'learning rate': optimizer.param_groups[0]['lr'], + 'validation Dice': val_score, + 'images': wandb.Image(images[0].cpu()), + 'masks': { + 'true': wandb.Image(true_masks[0].float().cpu()), + 'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()), + }, + 'step': global_step, + 'epoch': epoch, + **histograms + }) if save_checkpoint: Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)