mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
Move the sigmoid activation to the model itself
Former-commit-id: e3f8ca7b1ac7c5e9694637a81be260e9b48973b9
This commit is contained in:
parent
46d1db3115
commit
81234be7da
2
eval.py
2
eval.py
|
@ -20,7 +20,7 @@ def eval_net(net, dataset, gpu=False):
|
||||||
true_mask = true_mask.cuda()
|
true_mask = true_mask.cuda()
|
||||||
|
|
||||||
mask_pred = net(img)[0]
|
mask_pred = net(img)[0]
|
||||||
mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
|
mask_pred = (mask_pred > 0.5).float()
|
||||||
|
|
||||||
tot += dice_coeff(mask_pred, true_mask).item()
|
tot += dice_coeff(mask_pred, true_mask).item()
|
||||||
return tot / i
|
return tot / i
|
||||||
|
|
|
@ -43,8 +43,8 @@ def predict_img(net,
|
||||||
output_left = net(X_left)
|
output_left = net(X_left)
|
||||||
output_right = net(X_right)
|
output_right = net(X_right)
|
||||||
|
|
||||||
left_probs = F.sigmoid(output_left).squeeze(0)
|
left_probs = output_left.squeeze(0)
|
||||||
right_probs = F.sigmoid(output_right).squeeze(0)
|
right_probs = output_right.squeeze(0)
|
||||||
|
|
||||||
tf = transforms.Compose(
|
tf = transforms.Compose(
|
||||||
[
|
[
|
||||||
|
|
4
train.py
4
train.py
|
@ -6,7 +6,6 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
import torch.backends.cudnn as cudnn
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import optim
|
from torch import optim
|
||||||
|
|
||||||
from eval import eval_net
|
from eval import eval_net
|
||||||
|
@ -74,8 +73,7 @@ def train_net(net,
|
||||||
true_masks = true_masks.cuda()
|
true_masks = true_masks.cuda()
|
||||||
|
|
||||||
masks_pred = net(imgs)
|
masks_pred = net(imgs)
|
||||||
masks_probs = F.sigmoid(masks_pred)
|
masks_probs_flat = masks_pred.view(-1)
|
||||||
masks_probs_flat = masks_probs.view(-1)
|
|
||||||
|
|
||||||
true_masks_flat = true_masks.view(-1)
|
true_masks_flat = true_masks.view(-1)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
# full assembly of the sub-parts to form the complete net
|
# full assembly of the sub-parts to form the complete net
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from .unet_parts import *
|
from .unet_parts import *
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
|
@ -27,4 +29,4 @@ class UNet(nn.Module):
|
||||||
x = self.up3(x, x2)
|
x = self.up3(x, x2)
|
||||||
x = self.up4(x, x1)
|
x = self.up4(x, x1)
|
||||||
x = self.outc(x)
|
x = self.outc(x)
|
||||||
return x
|
return F.sigmoid(x)
|
||||||
|
|
Loading…
Reference in a new issue