wtf am i doing
Former-commit-id: dde43cce52408ec8f67372b365796b9014ceee57
This commit is contained in:
parent
ad179241ce
commit
842aaa81c2
1
INTERRUPTED.pth.REMOVED.git-id
Normal file
1
INTERRUPTED.pth.REMOVED.git-id
Normal file
|
@ -0,0 +1 @@
|
|||
94f4597495259e6d28987c9ec3b6b2aa43df9810
|
|
@ -2,7 +2,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
|
||||
from src.utils.dice import multiclass_dice_coeff, dice_coeff
|
||||
from src.utils.dice import dice_coeff, multiclass_dice_coeff
|
||||
|
||||
|
||||
def evaluate(net, dataloader, device):
|
||||
|
@ -11,32 +11,23 @@ def evaluate(net, dataloader, device):
|
|||
dice_score = 0
|
||||
|
||||
# iterate over the validation set
|
||||
for batch in tqdm(dataloader, total=num_val_batches, desc="Validation round", unit="batch", leave=False):
|
||||
image, mask_true = batch["image"], batch["mask"]
|
||||
# move images and labels to correct device and type
|
||||
image = image.to(device=device, dtype=torch.float32)
|
||||
mask_true = mask_true.to(device=device, dtype=torch.long)
|
||||
mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
|
||||
with tqdm(dataloader, total=len(dataloader.dataset), desc="Validation", unit="img", leave=False) as pbar:
|
||||
for images, masks_true in dataloader:
|
||||
# move images and labels to correct device
|
||||
images = images.to(device=device)
|
||||
masks_true = masks_true.unsqueeze(1).to(device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
# predict the mask
|
||||
mask_pred = net(image)
|
||||
with torch.inference_mode():
|
||||
# predict the mask
|
||||
masks_pred = net(images)
|
||||
masks_pred = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||
|
||||
# convert to one-hot format
|
||||
if net.n_classes == 1:
|
||||
mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
|
||||
# compute the Dice score
|
||||
dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
|
||||
else:
|
||||
mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
|
||||
# compute the Dice score, ignoring background
|
||||
dice_score += multiclass_dice_coeff(
|
||||
mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False
|
||||
)
|
||||
dice_score += dice_coeff(masks_pred, masks_true, reduce_batch_first=False)
|
||||
|
||||
pbar.update(images.shape[0])
|
||||
|
||||
net.train()
|
||||
|
||||
# Fixes a potential division by zero error
|
||||
if num_val_batches == 0:
|
||||
return dice_score
|
||||
return dice_score / num_val_batches
|
||||
return dice_score / num_val_batches if num_val_batches else dice_score
|
||||
|
|
280
src/train.py
280
src/train.py
|
@ -19,159 +19,12 @@ from unet import UNet
|
|||
from utils.paste import RandomPaste
|
||||
|
||||
CHECKPOINT_DIR = Path("./checkpoints/")
|
||||
DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/train2017")
|
||||
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/val2017/")
|
||||
# DIR_VALID_MASK = Path("/home/lilian/data_disk/lfainsin/val2017mask/")
|
||||
DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017")
|
||||
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017/")
|
||||
DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
|
||||
DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")
|
||||
|
||||
|
||||
def train_net(
|
||||
net,
|
||||
device,
|
||||
epochs: int = 5,
|
||||
batch_size: int = 1,
|
||||
learning_rate: float = 1e-5,
|
||||
save_checkpoint: bool = True,
|
||||
amp: bool = False,
|
||||
):
|
||||
# 1. Create transforms
|
||||
tf_train = A.Compose(
|
||||
[
|
||||
A.Flip(),
|
||||
A.ColorJitter(),
|
||||
RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
|
||||
A.ISONoise(),
|
||||
A.ToFloat(max_value=255),
|
||||
A.pytorch.ToTensorV2(),
|
||||
],
|
||||
)
|
||||
|
||||
tf_valid = A.Compose(
|
||||
[
|
||||
RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
|
||||
A.ToFloat(max_value=255),
|
||||
ToTensorV2(),
|
||||
],
|
||||
)
|
||||
|
||||
# 2. Create datasets
|
||||
ds_train = SphereDataset(images_dir=DIR_TRAIN_IMG, transform=tf_train)
|
||||
# ds_valid = SphereDataset(images_dir=DIR_VALID_IMG, masks_dir=DIR_VALID_MASK, transform=tf_valid)
|
||||
ds_valid = SphereDataset(images_dir=DIR_VALID_IMG, transform=tf_valid)
|
||||
|
||||
# 3. Create data loaders
|
||||
loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
|
||||
train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
|
||||
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
|
||||
|
||||
# (Initialize logging)
|
||||
experiment = wandb.init(
|
||||
project="U-Net",
|
||||
config=dict(
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
learning_rate=learning_rate,
|
||||
save_checkpoint=save_checkpoint,
|
||||
amp=amp,
|
||||
),
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"""Starting training:
|
||||
Epochs: {epochs}
|
||||
Batch size: {batch_size}
|
||||
Learning rate: {learning_rate}
|
||||
Training size: {len(ds_train)}
|
||||
Validation size: {len(ds_valid)}
|
||||
Checkpoints: {save_checkpoint}
|
||||
Device: {device.type}
|
||||
Mixed Precision: {amp}
|
||||
"""
|
||||
)
|
||||
|
||||
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
|
||||
optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) # goal: maximize Dice score
|
||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
global_step = 0
|
||||
|
||||
# 5. Begin training
|
||||
for epoch in range(1, epochs + 1):
|
||||
net.train()
|
||||
epoch_loss = 0
|
||||
|
||||
with tqdm(total=len(ds_train), desc=f"Epoch {epoch}/{epochs}", unit="img") as pbar:
|
||||
for batch in train_loader:
|
||||
images = batch["image"]
|
||||
true_masks = batch["mask"]
|
||||
|
||||
assert images.shape[1] == net.n_channels, (
|
||||
f"Network has been defined with {net.n_channels} input channels, "
|
||||
f"but loaded images have {images.shape[1]} channels. Please check that "
|
||||
"the images are loaded correctly."
|
||||
)
|
||||
|
||||
images = images.to(device=device, dtype=torch.float32)
|
||||
true_masks = true_masks.to(device=device, dtype=torch.long)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=amp):
|
||||
masks_pred = net(images)
|
||||
loss = criterion(masks_pred, true_masks) + dice_loss(
|
||||
F.softmax(masks_pred, dim=1).float(),
|
||||
F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
|
||||
multiclass=True,
|
||||
)
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
grad_scaler.scale(loss).backward()
|
||||
grad_scaler.step(optimizer)
|
||||
grad_scaler.update()
|
||||
|
||||
pbar.update(images.shape[0])
|
||||
global_step += 1
|
||||
epoch_loss += loss.item()
|
||||
experiment.log({"train loss": loss.item(), "step": global_step, "epoch": epoch})
|
||||
pbar.set_postfix(**{"loss (batch)": loss.item()})
|
||||
|
||||
# Evaluation round
|
||||
division_step = len(ds_train) // (10 * batch_size)
|
||||
if division_step > 0:
|
||||
if global_step % division_step == 0:
|
||||
histograms = {}
|
||||
for tag, value in net.named_parameters():
|
||||
tag = tag.replace("/", ".")
|
||||
histograms["Weights/" + tag] = wandb.Histogram(value.data.cpu())
|
||||
histograms["Gradients/" + tag] = wandb.Histogram(value.grad.data.cpu())
|
||||
|
||||
val_score = evaluate(net, val_loader, device)
|
||||
scheduler.step(val_score)
|
||||
|
||||
logging.info("Validation Dice score: {}".format(val_score))
|
||||
experiment.log(
|
||||
{
|
||||
"learning rate": optimizer.param_groups[0]["lr"],
|
||||
"validation Dice": val_score,
|
||||
"images": wandb.Image(images[0].cpu()),
|
||||
"masks": {
|
||||
"true": wandb.Image(true_masks[0].float().cpu()),
|
||||
"pred": wandb.Image(
|
||||
torch.softmax(masks_pred, dim=1).argmax(dim=1)[0].float().cpu()
|
||||
),
|
||||
},
|
||||
"step": global_step,
|
||||
"epoch": epoch,
|
||||
**histograms,
|
||||
}
|
||||
)
|
||||
|
||||
if save_checkpoint:
|
||||
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
|
||||
torch.save(net.state_dict(), str(CHECKPOINT_DIR / "checkpoint_epoch{}.pth".format(epoch)))
|
||||
logging.info(f"Checkpoint {epoch} saved!")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train the UNet on images and target masks",
|
||||
|
@ -190,7 +43,7 @@ def get_args():
|
|||
dest="batch_size",
|
||||
metavar="B",
|
||||
type=int,
|
||||
default=32,
|
||||
default=10,
|
||||
help="Batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -226,39 +79,148 @@ def get_args():
|
|||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
# get args from cli
|
||||
args = get_args()
|
||||
|
||||
# setup logging
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
|
||||
# enable cuda, if possible
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logging.info(f"Using device {device}")
|
||||
|
||||
# 0. Create network
|
||||
net = UNet(n_channels=3, n_classes=args.classes)
|
||||
|
||||
logging.info(
|
||||
f"""Network:
|
||||
\t{net.n_channels} input channels
|
||||
\t{net.n_classes} output channels (classes)
|
||||
input channels: {net.n_channels}
|
||||
output channels: {net.n_classes}
|
||||
"""
|
||||
)
|
||||
|
||||
# Load weights, if needed
|
||||
if args.load:
|
||||
net.load_state_dict(torch.load(args.load, map_location=device))
|
||||
logging.info(f"Model loaded from {args.load}")
|
||||
|
||||
# transfer network to device
|
||||
net.to(device=device)
|
||||
|
||||
try:
|
||||
train_net(
|
||||
net=net,
|
||||
# 1. Create transforms
|
||||
tf_train = A.Compose(
|
||||
[
|
||||
A.Resize(500, 500),
|
||||
A.Flip(),
|
||||
A.ColorJitter(),
|
||||
RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
|
||||
A.ISONoise(),
|
||||
A.ToFloat(max_value=255),
|
||||
A.pytorch.ToTensorV2(),
|
||||
],
|
||||
)
|
||||
tf_valid = A.Compose(
|
||||
[
|
||||
A.Resize(500, 500),
|
||||
RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
|
||||
A.ToFloat(max_value=255),
|
||||
ToTensorV2(),
|
||||
],
|
||||
)
|
||||
|
||||
# 2. Create datasets
|
||||
ds_train = SphereDataset(image_dir=DIR_TRAIN_IMG, transform=tf_train)
|
||||
ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)
|
||||
|
||||
# 3. Create data loaders
|
||||
loader_args = dict(batch_size=args.batch_size, num_workers=4, pin_memory=True)
|
||||
train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
|
||||
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
|
||||
|
||||
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
|
||||
optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) # goal: maximize Dice score
|
||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
|
||||
# connect to wandb
|
||||
wandb.init(
|
||||
project="U-Net",
|
||||
config=dict(
|
||||
epochs=args.epochs,
|
||||
batch_size=args.batch_size,
|
||||
learning_rate=args.lr,
|
||||
device=device,
|
||||
amp=args.amp,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
logging.info(
|
||||
f"""Starting training:
|
||||
Epochs: {args.epochs}
|
||||
Batch size: {args.batch_size}
|
||||
Learning rate: {args.lr}
|
||||
Training size: {len(ds_train)}
|
||||
Validation size: {len(ds_valid)}
|
||||
Device: {device.type}
|
||||
Mixed Precision: {args.amp}
|
||||
"""
|
||||
)
|
||||
|
||||
try:
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
with tqdm(total=len(ds_train), desc=f"{epoch}/{args.epochs}", unit="img") as pbar:
|
||||
|
||||
# Training round
|
||||
for step, (images, true_masks) in enumerate(train_loader):
|
||||
assert images.shape[1] == net.n_channels, (
|
||||
f"Network has been defined with {net.n_channels} input channels, "
|
||||
f"but loaded images have {images.shape[1]} channels. Please check that "
|
||||
"the images are loaded correctly."
|
||||
)
|
||||
|
||||
images = images.to(device=device)
|
||||
true_masks = true_masks.unsqueeze(1).to(device=device)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=args.amp):
|
||||
masks_pred = net(images)
|
||||
train_loss = criterion(masks_pred, true_masks) # TODO: rajouter le diceloss
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
grad_scaler.scale(train_loss).backward()
|
||||
grad_scaler.step(optimizer)
|
||||
grad_scaler.update()
|
||||
|
||||
pbar.update(images.shape[0])
|
||||
pbar.set_postfix(**{"loss": train_loss.item()})
|
||||
|
||||
wandb.log( # log training metrics
|
||||
{
|
||||
"train/epoch": epoch + step / epoch,
|
||||
"train/train_loss": train_loss,
|
||||
}
|
||||
)
|
||||
|
||||
# Evaluation round
|
||||
val_loss = evaluate(net, val_loader, device)
|
||||
scheduler.step(val_loss)
|
||||
wandb.log( # log validation metrics
|
||||
{
|
||||
"val/val_loss": val_loss,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:3f}")
|
||||
|
||||
# save weights when epoch end
|
||||
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
|
||||
torch.save(net.state_dict(), str(CHECKPOINT_DIR / "checkpoint_epoch{}.pth".format(epoch)))
|
||||
logging.info(f"Checkpoint {epoch} saved!")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
torch.save(net.state_dict(), "INTERRUPTED.pth")
|
||||
logging.info("Saved interrupt")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -1,81 +1,29 @@
|
|||
import logging
|
||||
from os import listdir
|
||||
from os.path import splitext
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
import albumentations as A
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class SphereDataset(Dataset):
|
||||
def __init__(self, images_dir: str, transform: A.Compose, masks_dir: str = None):
|
||||
self.images_dir = Path(images_dir)
|
||||
self.masks_dir = Path(masks_dir) if masks_dir else None
|
||||
|
||||
self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith(".")]
|
||||
|
||||
if not self.ids:
|
||||
raise RuntimeError(f"No input file found in {images_dir}, make sure you put your images there")
|
||||
|
||||
logging.info(f"Creating dataset with {len(self.ids)} examples")
|
||||
def __init__(self, image_dir, transform=None):
|
||||
self.image_dir = image_dir
|
||||
self.transform = transform
|
||||
self.images = os.listdir(image_dir)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.ids)
|
||||
return len(self.images)
|
||||
|
||||
@staticmethod
|
||||
def preprocess(pil_img, scale, is_mask):
|
||||
w, h = pil_img.size
|
||||
newW, newH = int(scale * w), int(scale * h)
|
||||
def __getitem__(self, index):
|
||||
img_path = os.path.join(self.image_dir, self.images[index])
|
||||
image = np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8)
|
||||
|
||||
assert newW > 0 and newH > 0, "Scale is too small, resized images would have no pixel"
|
||||
mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32)
|
||||
|
||||
pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
|
||||
img_ndarray = np.asarray(pil_img)
|
||||
if self.transform is not None:
|
||||
augmentations = self.transform(image=image, mask=mask)
|
||||
image = augmentations["image"]
|
||||
mask = augmentations["mask"]
|
||||
|
||||
if not is_mask:
|
||||
if img_ndarray.ndim == 2:
|
||||
img_ndarray = img_ndarray[np.newaxis, ...]
|
||||
else:
|
||||
img_ndarray = img_ndarray.transpose((2, 0, 1))
|
||||
|
||||
img_ndarray = img_ndarray / 255
|
||||
|
||||
return img_ndarray
|
||||
|
||||
@staticmethod
|
||||
def load(filename):
|
||||
ext = splitext(filename)[1]
|
||||
|
||||
if ext in [".npz", ".npy"]:
|
||||
return Image.fromarray(np.load(filename))
|
||||
elif ext in [".pt", ".pth"]:
|
||||
return Image.fromarray(torch.load(filename).numpy())
|
||||
else:
|
||||
return Image.open(filename)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
name = self.ids[idx]
|
||||
|
||||
mask_file = list(self.masks_dir.glob(name + self.mask_suffix + ".*"))
|
||||
img_file = list(self.images_dir.glob(name + ".*"))
|
||||
|
||||
assert len(img_file) == 1, f"Either no image or multiple images found for the ID {name}: {img_file}"
|
||||
assert len(mask_file) == 1, f"Either no mask or multiple masks found for the ID {name}: {mask_file}"
|
||||
|
||||
mask = self.load(mask_file[0])
|
||||
img = self.load(img_file[0])
|
||||
|
||||
assert (
|
||||
img.size == mask.size
|
||||
), f"Image and mask {name} should be the same size, but are {img.size} and {mask.size}"
|
||||
|
||||
img = self.preprocess(img, self.scale, is_mask=False)
|
||||
mask = self.preprocess(mask, self.scale, is_mask=True)
|
||||
|
||||
return {
|
||||
"image": torch.as_tensor(img.copy()).float().contiguous(),
|
||||
"mask": torch.as_tensor(mask.copy()).long().contiguous(),
|
||||
}
|
||||
return image, mask
|
||||
|
|
Loading…
Reference in a new issue