Merge pull request #308 from Arka161/bug_fixes

Fixed bugs for no validation set, and potential module zero error

Former-commit-id: 4f283f1b7e47992a2af1bc738c54562d5cd117c8
This commit is contained in:
milesial 2021-10-25 15:05:03 +02:00 committed by GitHub
commit 229dd44fca
3 changed files with 28 additions and 22 deletions

View file

@ -154,7 +154,7 @@ You can also download it using the helper script:
bash scripts/download_data.sh bash scripts/download_data.sh
``` ```
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. For Carvana, images are RGB and masks are black and white. The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folder should not contain any sub-folder or any other files, due to the greedy data-loader). For Carvana, images are RGB and masks are black and white.
You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`. You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`.

View file

@ -35,4 +35,8 @@ def evaluate(net, dataloader, device):
net.train() net.train()
# Fixes a potential division by zero error
if num_val_batches == 0:
return dice_score
return dice_score / num_val_batches return dice_score / num_val_batches

View file

@ -111,7 +111,9 @@ def train_net(net,
pbar.set_postfix(**{'loss (batch)': loss.item()}) pbar.set_postfix(**{'loss (batch)': loss.item()})
# Evaluation round # Evaluation round
if global_step % (n_train // (10 * batch_size)) == 0: division_step = (n_train // (10 * batch_size))
if division_step > 0:
if global_step % division_step == 0:
histograms = {} histograms = {}
for tag, value in net.named_parameters(): for tag, value in net.named_parameters():
tag = tag.replace('/', '.') tag = tag.replace('/', '.')