Mirror of https://github.com/Laurent2916/REVA-QCAV.git (synced 2024-11-08 14:39:00 +00:00)
Cleanup + now using tensorboard
Former-commit-id: 79928c84cdf990ef6fe1043a3e4f74b9cc252642
parent 35f955cbf8
commit 9d7be6e234
eval.py (8 changes)

@@ -10,9 +10,10 @@ def eval_net(net, loader, device, n_val):
     net.eval()
     tot = 0
 
-    for i, b in tqdm(enumerate(loader), desc='Validation round', unit='img'):
-        imgs = b['image']
-        true_masks = b['mask']
+    with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
+        for batch in loader:
+            imgs = batch['image']
+            true_masks = batch['mask']
 
             imgs = imgs.to(device=device, dtype=torch.float32)
             true_masks = true_masks.to(device=device, dtype=torch.float32)
@@ -25,5 +26,6 @@ def eval_net(net, loader, device, n_val):
                 tot += F.cross_entropy(mask_pred.unsqueeze(dim=0), true_mask.unsqueeze(dim=0)).item()
             else:
                 tot += dice_coeff(mask_pred, true_mask.squeeze(dim=1)).item()
+            pbar.update(imgs.shape[0])
 
     return tot / n_val
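The change above stops wrapping the loader in tqdm and instead drives the progress bar manually, so the bar counts images processed rather than batches. A minimal runnable sketch of the same idiom (the fake loader and counts below are illustrative stand-ins, not the repo's code):

from tqdm import tqdm

# Hypothetical stand-in for the validation DataLoader: 3 batches of 4 images each.
fake_loader = [{'image': [0] * 4, 'mask': [1] * 4} for _ in range(3)]
n_val = 12  # total number of validation images, as passed to eval_net

with tqdm(total=n_val, desc='Validation round', unit='img', leave=False) as pbar:
    for batch in fake_loader:
        imgs = batch['image']
        # ... run the network on imgs here ...
        pbar.update(len(imgs))  # advance by images in this batch, not by 1 per batch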
predict.py

@@ -9,8 +9,10 @@ from torchvision import transforms
 import torch.nn.functional as F
 
 from unet import UNet
-from utils import plot_img_and_mask
+from utils.data_vis import plot_img_and_mask
 from utils.dataset import BasicDataset
+from utils.crf import dense_crf
 
 
 def predict_img(net,
                 full_img,
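predict.py now pulls dense_crf from utils.crf. The helper's body is not part of this diff; as a rough orientation only, such helpers are commonly built on pydensecrf, along the lines of the sketch below. Everything here (the function name crf_refine, the kernel parameters, the iteration count) is an assumption for illustration, not the repo's actual utils/crf.py:

import numpy as np
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax

def crf_refine(img, prob):
    # img: H x W x 3 uint8 image; prob: H x W foreground probabilities in [0, 1].
    h, w = prob.shape
    softmax = np.stack([1.0 - prob, prob])   # (2, H, W): background, foreground
    unary = unary_from_softmax(softmax)      # negative log-probabilities, shape (2, H*W)
    d = dcrf.DenseCRF2D(w, h, 2)
    d.setUnaryEnergy(unary)
    d.addPairwiseGaussian(sxy=3, compat=3)   # location-only smoothness kernel
    d.addPairwiseBilateral(sxy=80, srgb=13, rgbim=np.ascontiguousarray(img), compat=10)
    q = np.array(d.inference(5))             # 5 mean-field iterations
    return q[1].reshape(h, w)                # refined foreground probability map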
submit.py (15 changes)

@@ -4,10 +4,23 @@ import os
 
 import torch
 from PIL import Image
+import numpy as np
 
 from predict import predict_img
 from unet import UNet
-from utils import rle_encode
+
+
+# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
+def rle_encode(mask_image):
+    pixels = mask_image.flatten()
+    # We avoid issues with '1' at the start or end (at the corners of
+    # the original image) by setting those pixels to '0' explicitly.
+    # We do not expect these to be non-zero for an accurate mask,
+    # so this should not harm the score.
+    pixels[0] = 0
+    pixels[-1] = 0
+    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
+    runs[1::2] = runs[1::2] - runs[:-1:2]
+    return runs
 
 
 def submit(net, gpu=False):
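A quick worked example of the inlined rle_encode: the output alternates 1-indexed run starts with run lengths, which is why change positions are offset by 2 and every second entry is turned into a length. The toy mask is illustrative:

import numpy as np

def rle_encode(mask_image):  # copy of the function inlined above, comments omitted
    pixels = mask_image.flatten()
    pixels[0] = 0
    pixels[-1] = 0
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
    runs[1::2] = runs[1::2] - runs[:-1:2]
    return runs

mask = np.array([[0, 0, 1, 1],
                 [1, 0, 0, 0]])
print(rle_encode(mask))  # [3 3]: one run starting at pixel 3, of length 3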
train.py (40 changes)

@@ -11,8 +11,8 @@ from tqdm import tqdm
 
 from eval import eval_net
 from unet import UNet
-from utils import get_ids, split_train_val, get_imgs_and_masks, batch
+from torch.utils.tensorboard import SummaryWriter
 from utils.dataset import BasicDataset
 from torch.utils.data import DataLoader, random_split
 
@@ -26,7 +26,7 @@ def train_net(net,
               epochs=5,
               batch_size=1,
               lr=0.1,
-              val_percent=0.15,
+              val_percent=0.1,
               save_cp=True,
               img_scale=0.5):
 
@@ -34,8 +34,11 @@ def train_net(net,
     n_val = int(len(dataset) * val_percent)
     n_train = len(dataset) - n_val
     train, val = random_split(dataset, [n_train, n_val])
-    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4)
-    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=4)
+    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
+    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
 
+    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
+    global_step = 0
 
     logging.info(f'''Starting training:
         Epochs:          {epochs}
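A self-contained sketch of the split-and-load setup introduced above; TensorDataset and the sizes below are stand-ins for BasicDataset and the real hyperparameters:

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.utils.tensorboard import SummaryWriter

# Stand-in dataset: 100 fake 1-channel 8x8 images with binary targets.
dataset = TensorDataset(torch.randn(100, 1, 8, 8), torch.randint(0, 2, (100,)))

val_percent, batch_size, lr, img_scale = 0.1, 4, 0.001, 0.5
n_val = int(len(dataset) * val_percent)   # 10 validation images
n_train = len(dataset) - n_val            # 90 training images
train, val = random_split(dataset, [n_train, n_val])

# pin_memory pre-pins host buffers to speed host-to-GPU copies;
# workers load and preprocess batches in parallel with training.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

# The comment string is appended to the run's log directory name, so runs
# with different hyperparameters stay distinguishable in TensorBoard.
writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')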
@@ -48,7 +51,7 @@ def train_net(net,
         Images scaling:  {img_scale}
     ''')
 
-    optimizer = optim.Adam(net.parameters(), lr=lr)
+    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
     if net.n_classes > 1:
         criterion = nn.CrossEntropyLoss()
     else:
@@ -78,6 +81,7 @@ def train_net(net,
                 masks_pred = net(imgs)
                 loss = criterion(masks_pred, true_masks)
                 epoch_loss += loss.item()
+                writer.add_scalar('Loss/train', loss.item(), global_step)
 
                 pbar.set_postfix(**{'loss (batch)': loss.item()})
 
@@ -85,7 +89,22 @@ def train_net(net,
                 loss.backward()
                 optimizer.step()
 
-                pbar.update(batch_size)
+                pbar.update(imgs.shape[0])
+                global_step += 1
+                if global_step % (len(dataset) // (10 * batch_size)) == 0:
+                    val_score = eval_net(net, val_loader, device, n_val)
+                    if net.n_classes > 1:
+                        logging.info('Validation cross entropy: {}'.format(val_score))
+                        writer.add_scalar('Loss/test', val_score, global_step)
+                    else:
+                        logging.info('Validation Dice Coeff: {}'.format(val_score))
+                        writer.add_scalar('Dice/test', val_score, global_step)
+
+                    writer.add_images('images', imgs, global_step)
+                    if net.n_classes == 1:
+                        writer.add_images('masks/true', true_masks, global_step)
+                        writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
 
         if save_cp:
             try:
@@ -97,12 +116,7 @@ def train_net(net,
                        dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
             logging.info(f'Checkpoint {epoch + 1} saved !')
 
-        val_score = eval_net(net, val_loader, device, n_val)
-        if net.n_classes > 1:
-            logging.info('Validation cross entropy: {}'.format(val_score))
-        else:
-            logging.info('Validation Dice Coeff: {}'.format(val_score))
+    writer.close()
 
 
 def get_args():
@@ -118,7 +132,7 @@ def get_args():
                         help='Load model from a .pth file')
     parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
                         help='Downscaling factor of the images')
-    parser.add_argument('-v', '--validation', dest='val', type=float, default=15.0,
+    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                         help='Percent of the data that is used as validation (0-100)')
 
     return parser.parse_args()
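Two details of the new cadence are easy to miss: len(dataset) // (10 * batch_size) is the number of optimizer steps between validation rounds (roughly ten rounds per epoch), and pbar.update(imgs.shape[0]) keeps the bar accurate when the last batch is short. A minimal sketch of the logging calls with stand-in tensors (shapes and values are illustrative):

import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(comment='_demo')
dataset_len, batch_size = 100, 4
eval_every = dataset_len // (10 * batch_size)    # validate every 2 steps here

for global_step in range(1, 11):
    loss = torch.rand(1).item()                  # stand-in for the training loss
    writer.add_scalar('Loss/train', loss, global_step)
    if global_step % eval_every == 0:
        val_score = torch.rand(1).item()         # stand-in for eval_net(...)
        writer.add_scalar('Dice/test', val_score, global_step)
        imgs = torch.rand(4, 3, 32, 32)          # add_images expects N x C x H x W
        preds = (torch.rand(4, 1, 32, 32) > 0.5).float()  # thresholded masks
        writer.add_images('images', imgs, global_step)
        writer.add_images('masks/pred', preds, global_step)

writer.close()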
utils/__init__.py (deleted)

@@ -1,4 +0,0 @@
-from .crf import *
-from .load import *
-from .utils import *
-from .data_vis import *
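With the package-level wildcard re-exports gone, callers import from the concrete submodules instead, as the updated files in this commit already do:

from utils.data_vis import plot_img_and_mask
from utils.dataset import BasicDataset
from utils.crf import dense_crf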
utils/dataset.py

@@ -25,6 +25,7 @@ class BasicDataset(Dataset):
     def preprocess(self, pil_img):
         w, h = pil_img.size
         newW, newH = int(self.scale * w), int(self.scale * h)
+        assert newW > 0 and newH > 0, 'Scale is too small'
         pil_img = pil_img.resize((newW, newH))
 
         img_nd = np.array(pil_img)
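The new assert catches scale factors small enough to round a dimension down to zero, where resize would otherwise fail less legibly. A quick illustration with a dummy PIL image (the sizes and scale are arbitrary):

from PIL import Image

scale = 0.01
pil_img = Image.new('L', (50, 40))               # dummy 50 x 40 grayscale image
w, h = pil_img.size
newW, newH = int(scale * w), int(scale * h)      # rounds to 0 x 0 at this scale
try:
    assert newW > 0 and newH > 0, 'Scale is too small'
except AssertionError as e:
    print(e)  # -> Scale is too small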
utils/utils.py (deleted)

@@ -1,14 +0,0 @@
-import numpy as np
-
-# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
-def rle_encode(mask_image):
-    pixels = mask_image.flatten()
-    # We avoid issues with '1' at the start or end (at the corners of
-    # the original image) by setting those pixels to '0' explicitly.
-    # We do not expect these to be non-zero for an accurate mask,
-    # so this should not harm the score.
-    pixels[0] = 0
-    pixels[-1] = 0
-    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
-    runs[1::2] = runs[1::2] - runs[:-1:2]
-    return runs