From de7e2e895f9a0bb2dff4ed1f772e65668d532acb Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Wed, 29 Jun 2022 10:20:35 +0200 Subject: [PATCH] f5 Former-commit-id: c3c20fee02acb1c9eec7332ec173e05e9df9fb99 [formerly a0bdb02a4ec4299e767d011e86970412be10d25c] Former-commit-id: 495e626a762ffa4d8df3f66dd6e3d7c896950f94 --- src/train.py | 31 ++++++++++++++++++------------- src/utils/dataset.py | 4 ++-- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/train.py b/src/train.py index fcd901e..b286cfa 100644 --- a/src/train.py +++ b/src/train.py @@ -5,7 +5,6 @@ from pathlib import Path import albumentations as A import torch import torch.nn as nn -import torch.nn.functional as F from albumentations.pytorch import ToTensorV2 from torch import optim from torch.utils.data import DataLoader @@ -112,16 +111,16 @@ def main(): A.Resize(500, 500), A.Flip(), A.ColorJitter(), - RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK), + RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK), A.ISONoise(), A.ToFloat(max_value=255), - A.pytorch.ToTensorV2(), + ToTensorV2(), ], ) tf_valid = A.Compose( [ A.Resize(500, 500), - RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK), + RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK), A.ToFloat(max_value=255), ToTensorV2(), ], @@ -132,19 +131,18 @@ def main(): ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid) # 3. Create data loaders - loader_args = dict(batch_size=args.batch_size, num_workers=4, pin_memory=True) + loader_args = dict(batch_size=args.batch_size, num_workers=5, pin_memory=True) train_loader = DataLoader(ds_train, shuffle=True, **loader_args) val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args) # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9) - scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) # goal: maximize Dice score + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp) criterion = nn.BCEWithLogitsLoss() - # connect to wandb wandb.init( - project="U-Net", + project="U-Net-tmp", config=dict( epochs=args.epochs, batch_size=args.batch_size, @@ -177,24 +175,29 @@ def main(): "the images are loaded correctly." ) + # transfer images to device images = images.to(device=device) true_masks = true_masks.unsqueeze(1).to(device=device) + # forward with torch.cuda.amp.autocast(enabled=args.amp): - masks_pred = net(images) - train_loss = criterion(masks_pred, true_masks) # TODO: rajouter le diceloss + pred_masks = net(images) + train_loss = criterion(pred_masks, pred_masks) + # backward optimizer.zero_grad(set_to_none=True) grad_scaler.scale(train_loss).backward() grad_scaler.step(optimizer) grad_scaler.update() + # update tqdm progress bar pbar.update(images.shape[0]) pbar.set_postfix(**{"loss": train_loss.item()}) - wandb.log( # log training metrics + # log training metrics + wandb.log( { - "train/epoch": epoch + step / len(train_loader), + "train/epoch": epoch - 1 + step / len(train_loader), "train/train_loss": train_loss, } ) @@ -202,7 +205,9 @@ def main(): # Evaluation round val_score = evaluate(net, val_loader, device) scheduler.step(val_score) - wandb.log( # log validation metrics + + # log validation metrics + wandb.log( { "val/val_score": val_score, } diff --git a/src/utils/dataset.py b/src/utils/dataset.py index e4d8369..f1cb2be 100644 --- a/src/utils/dataset.py +++ b/src/utils/dataset.py @@ -19,11 +19,11 @@ class SphereDataset(Dataset): img_path = os.path.join(self.image_dir, self.images[index]) image = np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8) - mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32) + mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) if self.transform is not None: augmentations = self.transform(image=image, mask=mask) image = augmentations["image"] - mask = augmentations["mask"] + mask = augmentations["mask"].float() return image, mask