Former-commit-id: c3c20fee02acb1c9eec7332ec173e05e9df9fb99 [formerly a0bdb02a4ec4299e767d011e86970412be10d25c]
Former-commit-id: 495e626a762ffa4d8df3f66dd6e3d7c896950f94
This commit is contained in:
Laurent Fainsin 2022-06-29 10:20:35 +02:00
parent 826ae4ac9b
commit de7e2e895f
2 changed files with 20 additions and 15 deletions

View file

@ -5,7 +5,6 @@ from pathlib import Path
import albumentations as A import albumentations as A
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from torch import optim from torch import optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -112,16 +111,16 @@ def main():
A.Resize(500, 500), A.Resize(500, 500),
A.Flip(), A.Flip(),
A.ColorJitter(), A.ColorJitter(),
RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK), RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
A.ISONoise(), A.ISONoise(),
A.ToFloat(max_value=255), A.ToFloat(max_value=255),
A.pytorch.ToTensorV2(), ToTensorV2(),
], ],
) )
tf_valid = A.Compose( tf_valid = A.Compose(
[ [
A.Resize(500, 500), A.Resize(500, 500),
RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK), RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
A.ToFloat(max_value=255), A.ToFloat(max_value=255),
ToTensorV2(), ToTensorV2(),
], ],
@ -132,19 +131,18 @@ def main():
ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid) ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)
# 3. Create data loaders # 3. Create data loaders
loader_args = dict(batch_size=args.batch_size, num_workers=4, pin_memory=True) loader_args = dict(batch_size=args.batch_size, num_workers=5, pin_memory=True)
train_loader = DataLoader(ds_train, shuffle=True, **loader_args) train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args) val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9) optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) # goal: maximize Dice score scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp) grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
criterion = nn.BCEWithLogitsLoss() criterion = nn.BCEWithLogitsLoss()
# connect to wandb
wandb.init( wandb.init(
project="U-Net", project="U-Net-tmp",
config=dict( config=dict(
epochs=args.epochs, epochs=args.epochs,
batch_size=args.batch_size, batch_size=args.batch_size,
@ -177,24 +175,29 @@ def main():
"the images are loaded correctly." "the images are loaded correctly."
) )
# transfer images to device
images = images.to(device=device) images = images.to(device=device)
true_masks = true_masks.unsqueeze(1).to(device=device) true_masks = true_masks.unsqueeze(1).to(device=device)
# forward
with torch.cuda.amp.autocast(enabled=args.amp): with torch.cuda.amp.autocast(enabled=args.amp):
masks_pred = net(images) pred_masks = net(images)
train_loss = criterion(masks_pred, true_masks) # TODO: rajouter le diceloss train_loss = criterion(pred_masks, pred_masks)
# backward
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(train_loss).backward() grad_scaler.scale(train_loss).backward()
grad_scaler.step(optimizer) grad_scaler.step(optimizer)
grad_scaler.update() grad_scaler.update()
# update tqdm progress bar
pbar.update(images.shape[0]) pbar.update(images.shape[0])
pbar.set_postfix(**{"loss": train_loss.item()}) pbar.set_postfix(**{"loss": train_loss.item()})
wandb.log( # log training metrics # log training metrics
wandb.log(
{ {
"train/epoch": epoch + step / len(train_loader), "train/epoch": epoch - 1 + step / len(train_loader),
"train/train_loss": train_loss, "train/train_loss": train_loss,
} }
) )
@ -202,7 +205,9 @@ def main():
# Evaluation round # Evaluation round
val_score = evaluate(net, val_loader, device) val_score = evaluate(net, val_loader, device)
scheduler.step(val_score) scheduler.step(val_score)
wandb.log( # log validation metrics
# log validation metrics
wandb.log(
{ {
"val/val_score": val_score, "val/val_score": val_score,
} }

View file

@ -19,11 +19,11 @@ class SphereDataset(Dataset):
img_path = os.path.join(self.image_dir, self.images[index]) img_path = os.path.join(self.image_dir, self.images[index])
image = np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8) image = np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8)
mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32) mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
if self.transform is not None: if self.transform is not None:
augmentations = self.transform(image=image, mask=mask) augmentations = self.transform(image=image, mask=mask)
image = augmentations["image"] image = augmentations["image"]
mask = augmentations["mask"] mask = augmentations["mask"].float()
return image, mask return image, mask