bug fixes for i/o and low val set

Former-commit-id: bdcad8300e3c930b43b976ccd2562f27c9867892
This commit is contained in:
Arka 2021-10-24 17:07:54 -04:00
parent c912d9b726
commit 6438b1dcdd
3 changed files with 28 additions and 22 deletions

View file

@ -154,7 +154,7 @@ You can also download it using the helper script:
bash scripts/download_data.sh
```
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. For Carvana, images are RGB and masks are black and white.
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folder should not contain any sub-folder or any other files, due to the greedy data-loader). For Carvana, images are RGB and masks are black and white.
You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`.

View file

@ -35,4 +35,8 @@ def evaluate(net, dataloader, device):
net.train()
# Fixes a potential division by zero error
if num_val_batches == 0:
return dice_score
return dice_score / num_val_batches

View file

@ -111,7 +111,9 @@ def train_net(net,
pbar.set_postfix(**{'loss (batch)': loss.item()})
# Evaluation round
if global_step % (n_train // (10 * batch_size)) == 0:
division_step = (n_train // (10 * batch_size))
if division_step > 0:
if global_step % division_step == 0:
histograms = {}
for tag, value in net.named_parameters():
tag = tag.replace('/', '.')