From 81234be7dab7ddad315931d21d26d374b2d30fb9 Mon Sep 17 00:00:00 2001
From: rht
Date: Sat, 10 Nov 2018 22:42:16 +0000
Subject: [PATCH] Move the sigmoid activation to the model itself

Former-commit-id: e3f8ca7b1ac7c5e9694637a81be260e9b48973b9
---
 eval.py            | 2 +-
 predict.py         | 4 ++--
 train.py           | 4 +---
 unet/unet_model.py | 4 +++-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/eval.py b/eval.py
index 57c9972..1aada5e 100644
--- a/eval.py
+++ b/eval.py
@@ -20,7 +20,7 @@ def eval_net(net, dataset, gpu=False):
             true_mask = true_mask.cuda()
 
         mask_pred = net(img)[0]
-        mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
+        mask_pred = (mask_pred > 0.5).float()
 
         tot += dice_coeff(mask_pred, true_mask).item()
     return tot / i
diff --git a/predict.py b/predict.py
index 5347d88..8b08847 100755
--- a/predict.py
+++ b/predict.py
@@ -43,8 +43,8 @@ def predict_img(net,
         output_left = net(X_left)
         output_right = net(X_right)
 
-        left_probs = F.sigmoid(output_left).squeeze(0)
-        right_probs = F.sigmoid(output_right).squeeze(0)
+        left_probs = output_left.squeeze(0)
+        right_probs = output_right.squeeze(0)
 
         tf = transforms.Compose(
             [
diff --git a/train.py b/train.py
index a470606..cfdeb3b 100644
--- a/train.py
+++ b/train.py
@@ -6,7 +6,6 @@ import numpy as np
 import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn
-import torch.nn.functional as F
 from torch import optim
 
 from eval import eval_net
@@ -74,8 +73,7 @@ def train_net(net,
                 true_masks = true_masks.cuda()
 
             masks_pred = net(imgs)
-            masks_probs = F.sigmoid(masks_pred)
-            masks_probs_flat = masks_probs.view(-1)
+            masks_probs_flat = masks_pred.view(-1)
 
             true_masks_flat = true_masks.view(-1)
 
diff --git a/unet/unet_model.py b/unet/unet_model.py
index a09ee5b..5990649 100644
--- a/unet/unet_model.py
+++ b/unet/unet_model.py
@@ -1,5 +1,7 @@
 # full assembly of the sub-parts to form the complete net
 
+import torch.nn.functional as F
+
 from .unet_parts import *
 
 class UNet(nn.Module):
@@ -27,4 +29,4 @@ class UNet(nn.Module):
         x = self.up3(x, x2)
         x = self.up4(x, x1)
         x = self.outc(x)
-        return x
+        return F.sigmoid(x)
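
A minimal sketch of the pattern this patch introduces, not the repository's
actual code: TinyNet below is a hypothetical stand-in for the repo's UNet.
With the sigmoid inside forward(), every caller receives probabilities in
[0, 1], so eval.py and predict.py can threshold or use the output directly,
and train.py's loss (a BCE-style criterion that expects probabilities) can
consume it without an extra activation. The sketch uses torch.sigmoid, since
F.sigmoid was deprecated in favor of torch.sigmoid in later PyTorch releases.

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for the repo's UNet."""
    def __init__(self):
        super().__init__()
        # 1x1 output convolution, analogous to self.outc in unet_model.py
        self.outc = nn.Conv2d(3, 1, kernel_size=1)

    def forward(self, x):
        x = self.outc(x)
        # Activation lives in the model, as in the patched UNet.forward
        return torch.sigmoid(x)

net = TinyNet()
img = torch.randn(1, 3, 64, 64)
probs = net(img)              # already probabilities; no sigmoid at call site
mask = (probs > 0.5).float()  # thresholding as in the patched eval.py
# Loss computed directly on the model output, as in the patched train.py
# (the all-ones target here is a placeholder, not real data)
loss = nn.BCELoss()(probs.view(-1), torch.ones_like(probs).view(-1))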