diff --git a/INTERRUPTED.pth.REMOVED.git-id b/INTERRUPTED.pth.REMOVED.git-id
new file mode 100644
index 0000000..2165389
--- /dev/null
+++ b/INTERRUPTED.pth.REMOVED.git-id
@@ -0,0 +1 @@
+94f4597495259e6d28987c9ec3b6b2aa43df9810
\ No newline at end of file
diff --git a/src/evaluate.py b/src/evaluate.py
index bdaeb78..3d56d51 100644
--- a/src/evaluate.py
+++ b/src/evaluate.py
@@ -2,7 +2,7 @@ import torch
 import torch.nn.functional as F
 from tqdm import tqdm
 
-from src.utils.dice import multiclass_dice_coeff, dice_coeff
+from src.utils.dice import dice_coeff
 
 
 def evaluate(net, dataloader, device):
@@ -11,32 +11,23 @@ def evaluate(net, dataloader, device):
     dice_score = 0
 
     # iterate over the validation set
-    for batch in tqdm(dataloader, total=num_val_batches, desc="Validation round", unit="batch", leave=False):
-        image, mask_true = batch["image"], batch["mask"]
-        # move images and labels to correct device and type
-        image = image.to(device=device, dtype=torch.float32)
-        mask_true = mask_true.to(device=device, dtype=torch.long)
-        mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
+    with tqdm(total=len(dataloader.dataset), desc="Validation", unit="img", leave=False) as pbar:
+        for images, masks_true in dataloader:
+            # move images and labels to correct device
+            images = images.to(device=device)
+            masks_true = masks_true.unsqueeze(1).to(device=device)
 
-        with torch.no_grad():
-            # predict the mask
-            mask_pred = net(image)
+            with torch.inference_mode():
+                # predict the mask, then binarize it
+                masks_pred = net(images)
+                masks_pred = (torch.sigmoid(masks_pred) > 0.5).float()
 
-            # convert to one-hot format
-            if net.n_classes == 1:
-                mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
                 # compute the Dice score
-                dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
-            else:
-                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
-                # compute the Dice score, ignoring background
-                dice_score += multiclass_dice_coeff(
-                    mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False
-                )
+                dice_score += dice_coeff(masks_pred, masks_true, reduce_batch_first=False)
+
+            pbar.update(images.shape[0])
 
     net.train()
 
     # Fixes a potential division by zero error
-    if num_val_batches == 0:
-        return dice_score
-    return dice_score / num_val_batches
+    return dice_score / num_val_batches if num_val_batches else dice_score
diff --git a/src/train.py b/src/train.py
index 969a43a..768de17 100644
--- a/src/train.py
+++ b/src/train.py
@@ -19,159 +19,12 @@ from unet import UNet
 from utils.paste import RandomPaste
 
 CHECKPOINT_DIR = Path("./checkpoints/")
-DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/train2017")
-DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/val2017/")
-# DIR_VALID_MASK = Path("/home/lilian/data_disk/lfainsin/val2017mask/")
+DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017")
+DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017/")
 DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
 DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")
 
 
-def train_net(
-    net,
-    device,
-    epochs: int = 5,
-    batch_size: int = 1,
-    learning_rate: float = 1e-5,
-    save_checkpoint: bool = True,
-    amp: bool = False,
-):
-    # 1. 
Create transforms - tf_train = A.Compose( - [ - A.Flip(), - A.ColorJitter(), - RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK), - A.ISONoise(), - A.ToFloat(max_value=255), - A.pytorch.ToTensorV2(), - ], - ) - - tf_valid = A.Compose( - [ - RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK), - A.ToFloat(max_value=255), - ToTensorV2(), - ], - ) - - # 2. Create datasets - ds_train = SphereDataset(images_dir=DIR_TRAIN_IMG, transform=tf_train) - # ds_valid = SphereDataset(images_dir=DIR_VALID_IMG, masks_dir=DIR_VALID_MASK, transform=tf_valid) - ds_valid = SphereDataset(images_dir=DIR_VALID_IMG, transform=tf_valid) - - # 3. Create data loaders - loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True) - train_loader = DataLoader(ds_train, shuffle=True, **loader_args) - val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args) - - # (Initialize logging) - experiment = wandb.init( - project="U-Net", - config=dict( - epochs=epochs, - batch_size=batch_size, - learning_rate=learning_rate, - save_checkpoint=save_checkpoint, - amp=amp, - ), - ) - - logging.info( - f"""Starting training: - Epochs: {epochs} - Batch size: {batch_size} - Learning rate: {learning_rate} - Training size: {len(ds_train)} - Validation size: {len(ds_valid)} - Checkpoints: {save_checkpoint} - Device: {device.type} - Mixed Precision: {amp} - """ - ) - - # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP - optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9) - scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) # goal: maximize Dice score - grad_scaler = torch.cuda.amp.GradScaler(enabled=amp) - criterion = nn.CrossEntropyLoss() - global_step = 0 - - # 5. Begin training - for epoch in range(1, epochs + 1): - net.train() - epoch_loss = 0 - - with tqdm(total=len(ds_train), desc=f"Epoch {epoch}/{epochs}", unit="img") as pbar: - for batch in train_loader: - images = batch["image"] - true_masks = batch["mask"] - - assert images.shape[1] == net.n_channels, ( - f"Network has been defined with {net.n_channels} input channels, " - f"but loaded images have {images.shape[1]} channels. Please check that " - "the images are loaded correctly." 
-                )
-
-                images = images.to(device=device, dtype=torch.float32)
-                true_masks = true_masks.to(device=device, dtype=torch.long)
-
-                with torch.cuda.amp.autocast(enabled=amp):
-                    masks_pred = net(images)
-                    loss = criterion(masks_pred, true_masks) + dice_loss(
-                        F.softmax(masks_pred, dim=1).float(),
-                        F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
-                        multiclass=True,
-                    )
-
-                optimizer.zero_grad(set_to_none=True)
-                grad_scaler.scale(loss).backward()
-                grad_scaler.step(optimizer)
-                grad_scaler.update()
-
-                pbar.update(images.shape[0])
-                global_step += 1
-                epoch_loss += loss.item()
-                experiment.log({"train loss": loss.item(), "step": global_step, "epoch": epoch})
-                pbar.set_postfix(**{"loss (batch)": loss.item()})
-
-                # Evaluation round
-                division_step = len(ds_train) // (10 * batch_size)
-                if division_step > 0:
-                    if global_step % division_step == 0:
-                        histograms = {}
-                        for tag, value in net.named_parameters():
-                            tag = tag.replace("/", ".")
-                            histograms["Weights/" + tag] = wandb.Histogram(value.data.cpu())
-                            histograms["Gradients/" + tag] = wandb.Histogram(value.grad.data.cpu())
-
-                        val_score = evaluate(net, val_loader, device)
-                        scheduler.step(val_score)
-
-                        logging.info("Validation Dice score: {}".format(val_score))
-                        experiment.log(
-                            {
-                                "learning rate": optimizer.param_groups[0]["lr"],
-                                "validation Dice": val_score,
-                                "images": wandb.Image(images[0].cpu()),
-                                "masks": {
-                                    "true": wandb.Image(true_masks[0].float().cpu()),
-                                    "pred": wandb.Image(
-                                        torch.softmax(masks_pred, dim=1).argmax(dim=1)[0].float().cpu()
-                                    ),
-                                },
-                                "step": global_step,
-                                "epoch": epoch,
-                                **histograms,
-                            }
-                        )
-
-        if save_checkpoint:
-            Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
-            torch.save(net.state_dict(), str(CHECKPOINT_DIR / "checkpoint_epoch{}.pth".format(epoch)))
-            logging.info(f"Checkpoint {epoch} saved!")
-
-
 def get_args():
     parser = argparse.ArgumentParser(
         description="Train the UNet on images and target masks",
@@ -190,7 +43,7 @@ def get_args():
         dest="batch_size",
         metavar="B",
         type=int,
-        default=32,
+        default=10,
         help="Batch size",
     )
     parser.add_argument(
@@ -226,39 +79,148 @@ def get_args():
     return parser.parse_args()
 
 
-if __name__ == "__main__":
+def main():
+    # get args from cli
     args = get_args()
 
+    # setup logging
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
 
+    # enable cuda, if possible
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     logging.info(f"Using device {device}")
 
+    # 0. Create network
     net = UNet(n_channels=3, n_classes=args.classes)
-
     logging.info(
         f"""Network:
-    \t{net.n_channels} input channels
-    \t{net.n_classes} output channels (classes)
+        input channels: {net.n_channels}
+        output channels: {net.n_classes}
     """
     )
 
+    # Load weights, if needed
    if args.load:
         net.load_state_dict(torch.load(args.load, map_location=device))
         logging.info(f"Model loaded from {args.load}")
 
+    # transfer network to device
     net.to(device=device)
 
-    try:
-        train_net(
-            net=net,
+    # 1. Create transforms
+    tf_train = A.Compose(
+        [
+            A.Resize(500, 500),
+            A.Flip(),
+            A.ColorJitter(),
+            RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
+            A.ISONoise(),
+            A.ToFloat(max_value=255),
+            ToTensorV2(),
+        ],
+    )
+    tf_valid = A.Compose(
+        [
+            A.Resize(500, 500),
+            RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
+            A.ToFloat(max_value=255),
+            ToTensorV2(),
+        ],
+    )
+
+    # 2. Create datasets
+    ds_train = SphereDataset(image_dir=DIR_TRAIN_IMG, transform=tf_train)
+    ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)
+
+    # 3. 
Create data loaders
+    loader_args = dict(batch_size=args.batch_size, num_workers=4, pin_memory=True)
+    train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
+    val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
+
+    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
+    optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)  # goal: maximize Dice score
+    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
+    criterion = nn.BCEWithLogitsLoss()
+
+    # connect to wandb
+    wandb.init(
+        project="U-Net",
+        config=dict(
             epochs=args.epochs,
             batch_size=args.batch_size,
             learning_rate=args.lr,
-            device=device,
             amp=args.amp,
-        )
+        ),
+    )
+
+    logging.info(
+        f"""Starting training:
+        Epochs: {args.epochs}
+        Batch size: {args.batch_size}
+        Learning rate: {args.lr}
+        Training size: {len(ds_train)}
+        Validation size: {len(ds_valid)}
+        Device: {device.type}
+        Mixed Precision: {args.amp}
+    """
+    )
+
+    try:
+        for epoch in range(1, args.epochs + 1):
+            with tqdm(total=len(ds_train), desc=f"Epoch {epoch}/{args.epochs}", unit="img") as pbar:
+
+                # Training round
+                for step, (images, true_masks) in enumerate(train_loader):
+                    assert images.shape[1] == net.n_channels, (
+                        f"Network has been defined with {net.n_channels} input channels, "
+                        f"but loaded images have {images.shape[1]} channels. Please check that "
+                        "the images are loaded correctly."
+                    )
+
+                    images = images.to(device=device)
+                    true_masks = true_masks.unsqueeze(1).to(device=device)
+
+                    with torch.cuda.amp.autocast(enabled=args.amp):
+                        masks_pred = net(images)
+                        train_loss = criterion(masks_pred, true_masks)  # TODO: add the dice loss back
+
+                    optimizer.zero_grad(set_to_none=True)
+                    grad_scaler.scale(train_loss).backward()
+                    grad_scaler.step(optimizer)
+                    grad_scaler.update()
+
+                    pbar.update(images.shape[0])
+                    pbar.set_postfix(**{"loss": train_loss.item()})
+
+                    wandb.log(  # log training metrics
+                        {
+                            "train/epoch": epoch + step / len(train_loader),
+                            "train/train_loss": train_loss.item(),
+                        }
+                    )
+
+            # Evaluation round, evaluate() returns the mean Dice score (higher is better)
+            val_score = evaluate(net, val_loader, device)
+            scheduler.step(val_score)
+            wandb.log(  # log validation metrics
+                {
+                    "val/val_score": val_score,
+                }
+            )
+
+            print(f"Train Loss: {train_loss:.3f}, Valid Dice: {val_score:.3f}")
+
+            # save weights at the end of each epoch
+            Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
+            torch.save(net.state_dict(), str(CHECKPOINT_DIR / "checkpoint_epoch{}.pth".format(epoch)))
+            logging.info(f"Checkpoint {epoch} saved!")
+
     except KeyboardInterrupt:
         torch.save(net.state_dict(), "INTERRUPTED.pth")
         logging.info("Saved interrupt")
         raise
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/utils/dataset.py b/src/utils/dataset.py
index 6680bd8..e4d8369 100644
--- a/src/utils/dataset.py
+++ b/src/utils/dataset.py
@@ -1,81 +1,32 @@
 import logging
-from os import listdir
-from os.path import splitext
-from pathlib import Path
+import os
 
-import albumentations as A
 import numpy as np
-import torch
 from PIL import Image
 from torch.utils.data import Dataset
 
 
 class SphereDataset(Dataset):
-    def __init__(self, images_dir: str, transform: A.Compose, masks_dir: str = None):
-        self.images_dir = Path(images_dir)
-        self.masks_dir = Path(masks_dir) if masks_dir else None
-
-        self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith(".")]
-
-        if not self.ids:
-            raise RuntimeError(f"No input file found in {images_dir}, make sure you put 
your images there")
-
-        logging.info(f"Creating dataset with {len(self.ids)} examples")
+    def __init__(self, image_dir, transform=None):
+        self.image_dir = image_dir
+        self.transform = transform
+        # skip hidden files and keep a deterministic ordering
+        self.images = sorted(f for f in os.listdir(image_dir) if not f.startswith("."))
 
     def __len__(self):
-        return len(self.ids)
+        return len(self.images)
 
-    @staticmethod
-    def preprocess(pil_img, scale, is_mask):
-        w, h = pil_img.size
-        newW, newH = int(scale * w), int(scale * h)
+    def __getitem__(self, index):
+        img_path = os.path.join(self.image_dir, self.images[index])
+        image = np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8)
 
-        assert newW > 0 and newH > 0, "Scale is too small, resized images would have no pixel"
+        # start from an empty mask: the RandomPaste transform pastes spheres
+        # into the image and fills in the corresponding mask regions
+        mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32)
 
-        pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
-        img_ndarray = np.asarray(pil_img)
+        if self.transform is not None:
+            augmentations = self.transform(image=image, mask=mask)
+            image = augmentations["image"]
+            mask = augmentations["mask"]
 
-        if not is_mask:
-            if img_ndarray.ndim == 2:
-                img_ndarray = img_ndarray[np.newaxis, ...]
-            else:
-                img_ndarray = img_ndarray.transpose((2, 0, 1))
-
-            img_ndarray = img_ndarray / 255
-
-        return img_ndarray
-
-    @staticmethod
-    def load(filename):
-        ext = splitext(filename)[1]
-
-        if ext in [".npz", ".npy"]:
-            return Image.fromarray(np.load(filename))
-        elif ext in [".pt", ".pth"]:
-            return Image.fromarray(torch.load(filename).numpy())
-        else:
-            return Image.open(filename)
-
-    def __getitem__(self, idx):
-        name = self.ids[idx]
-
-        mask_file = list(self.masks_dir.glob(name + self.mask_suffix + ".*"))
-        img_file = list(self.images_dir.glob(name + ".*"))
-
-        assert len(img_file) == 1, f"Either no image or multiple images found for the ID {name}: {img_file}"
-        assert len(mask_file) == 1, f"Either no mask or multiple masks found for the ID {name}: {mask_file}"
-
-        mask = self.load(mask_file[0])
-        img = self.load(img_file[0])
-
-        assert (
-            img.size == mask.size
-        ), f"Image and mask {name} should be the same size, but are {img.size} and {mask.size}"
-
-        img = self.preprocess(img, self.scale, is_mask=False)
-        mask = self.preprocess(mask, self.scale, is_mask=True)
-
-        return {
-            "image": torch.as_tensor(img.copy()).float().contiguous(),
-            "mask": torch.as_tensor(mask.copy()).long().contiguous(),
-        }
+        return image, mask
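Note: src/utils/dice.py is not touched by this diff, so the dice_coeff helper that evaluate imports is not shown here. For reference, a minimal sketch of a compatible implementation, assuming binary masks of shape (N, 1, H, W); the epsilon smoothing term and the exact semantics of reduce_batch_first are assumptions, not the repository's actual code:

import torch


def dice_coeff(input: torch.Tensor, target: torch.Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6) -> torch.Tensor:
    # Dice coefficient: twice the intersection divided by the sum of the two mask areas
    assert input.size() == target.size()

    # sum over all non-batch dims; with reduce_batch_first the batch dim is
    # folded into the sets as well, producing a single global score
    dims = tuple(range(0 if reduce_batch_first else 1, input.dim()))

    inter = 2 * (input * target).sum(dim=dims)
    sets_sum = input.sum(dim=dims) + target.sum(dim=dims)

    # epsilon keeps a pair of empty masks at a perfect score of 1 instead of 0/0
    dice = (inter + epsilon) / (sets_sum + epsilon)
    return dice.mean()

With reduce_batch_first=False, as called in evaluate, this returns the mean per-image Dice of the batch; evaluate then accumulates one such score per batch and divides by num_val_batches.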