mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
bug fixes for i/o and low val set
Former-commit-id: bdcad8300e3c930b43b976ccd2562f27c9867892
This commit is contained in:
parent
c912d9b726
commit
6438b1dcdd
|
@ -154,7 +154,7 @@ You can also download it using the helper script:
|
|||
bash scripts/download_data.sh
|
||||
```
|
||||
|
||||
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. For Carvana, images are RGB and masks are black and white.
|
||||
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folder should not contain any sub-folder or any other files, due to the greedy data-loader). For Carvana, images are RGB and masks are black and white.
|
||||
|
||||
You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`.
|
||||
|
||||
|
|
|
@ -35,4 +35,8 @@ def evaluate(net, dataloader, device):
|
|||
|
||||
|
||||
net.train()
|
||||
|
||||
# Fixes a potential division by zero error
|
||||
if num_val_batches == 0:
|
||||
return dice_score
|
||||
return dice_score / num_val_batches
|
||||
|
|
4
train.py
4
train.py
|
@ -111,7 +111,9 @@ def train_net(net,
|
|||
pbar.set_postfix(**{'loss (batch)': loss.item()})
|
||||
|
||||
# Evaluation round
|
||||
if global_step % (n_train // (10 * batch_size)) == 0:
|
||||
division_step = (n_train // (10 * batch_size))
|
||||
if division_step > 0:
|
||||
if global_step % division_step == 0:
|
||||
histograms = {}
|
||||
for tag, value in net.named_parameters():
|
||||
tag = tag.replace('/', '.')
|
||||
|
|
Loading…
Reference in a new issue