Merge pull request #308 from Arka161/bug_fixes
Fixed bugs for no validation set, and potential module zero error Former-commit-id: 4f283f1b7e47992a2af1bc738c54562d5cd117c8
This commit is contained in:
commit
229dd44fca
|
@ -154,7 +154,7 @@ You can also download it using the helper script:
|
||||||
bash scripts/download_data.sh
|
bash scripts/download_data.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively. For Carvana, images are RGB and masks are black and white.
|
The input images and target masks should be in the `data/imgs` and `data/masks` folders respectively (note that the `imgs` and `masks` folder should not contain any sub-folder or any other files, due to the greedy data-loader). For Carvana, images are RGB and masks are black and white.
|
||||||
|
|
||||||
You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`.
|
You can use your own dataset as long as you make sure it is loaded properly in `utils/data_loading.py`.
|
||||||
|
|
||||||
|
|
|
@ -35,4 +35,8 @@ def evaluate(net, dataloader, device):
|
||||||
|
|
||||||
|
|
||||||
net.train()
|
net.train()
|
||||||
|
|
||||||
|
# Fixes a potential division by zero error
|
||||||
|
if num_val_batches == 0:
|
||||||
|
return dice_score
|
||||||
return dice_score / num_val_batches
|
return dice_score / num_val_batches
|
||||||
|
|
4
train.py
4
train.py
|
@ -111,7 +111,9 @@ def train_net(net,
|
||||||
pbar.set_postfix(**{'loss (batch)': loss.item()})
|
pbar.set_postfix(**{'loss (batch)': loss.item()})
|
||||||
|
|
||||||
# Evaluation round
|
# Evaluation round
|
||||||
if global_step % (n_train // (10 * batch_size)) == 0:
|
division_step = (n_train // (10 * batch_size))
|
||||||
|
if division_step > 0:
|
||||||
|
if global_step % division_step == 0:
|
||||||
histograms = {}
|
histograms = {}
|
||||||
for tag, value in net.named_parameters():
|
for tag, value in net.named_parameters():
|
||||||
tag = tag.replace('/', '.')
|
tag = tag.replace('/', '.')
|
||||||
|
|
Loading…
Reference in a new issue