Remove assert for multiclass
Former-commit-id: 16809ec7ba14f67bfa5e5a095d3b512ef669a964
This commit is contained in:
parent
67c4dd34a0
commit
4e1f0398a1
5
train.py
5
train.py
|
@ -70,11 +70,6 @@ def train_net(net,
|
||||||
f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
|
f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
|
||||||
'the images are loaded correctly.'
|
'the images are loaded correctly.'
|
||||||
|
|
||||||
assert true_masks.shape[1] == net.n_classes, \
|
|
||||||
f'Network has been defined with {net.n_classes} output classes, ' \
|
|
||||||
f'but loaded masks have {true_masks.shape[1]} channels. Please check that ' \
|
|
||||||
'the masks are loaded correctly.'
|
|
||||||
|
|
||||||
imgs = imgs.to(device=device, dtype=torch.float32)
|
imgs = imgs.to(device=device, dtype=torch.float32)
|
||||||
true_masks = true_masks.to(device=device, dtype=torch.float32)
|
true_masks = true_masks.to(device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue