mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
f5
Former-commit-id: c3c20fee02acb1c9eec7332ec173e05e9df9fb99 [formerly a0bdb02a4ec4299e767d011e86970412be10d25c] Former-commit-id: 495e626a762ffa4d8df3f66dd6e3d7c896950f94
This commit is contained in:
parent
826ae4ac9b
commit
de7e2e895f
31
src/train.py
31
src/train.py
|
@ -5,7 +5,6 @@ from pathlib import Path
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from albumentations.pytorch import ToTensorV2
|
from albumentations.pytorch import ToTensorV2
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
@ -112,16 +111,16 @@ def main():
|
||||||
A.Resize(500, 500),
|
A.Resize(500, 500),
|
||||||
A.Flip(),
|
A.Flip(),
|
||||||
A.ColorJitter(),
|
A.ColorJitter(),
|
||||||
RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
|
RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
|
||||||
A.ISONoise(),
|
A.ISONoise(),
|
||||||
A.ToFloat(max_value=255),
|
A.ToFloat(max_value=255),
|
||||||
A.pytorch.ToTensorV2(),
|
ToTensorV2(),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
tf_valid = A.Compose(
|
tf_valid = A.Compose(
|
||||||
[
|
[
|
||||||
A.Resize(500, 500),
|
A.Resize(500, 500),
|
||||||
RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
|
RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
|
||||||
A.ToFloat(max_value=255),
|
A.ToFloat(max_value=255),
|
||||||
ToTensorV2(),
|
ToTensorV2(),
|
||||||
],
|
],
|
||||||
|
@ -132,19 +131,18 @@ def main():
|
||||||
ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)
|
ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)
|
||||||
|
|
||||||
# 3. Create data loaders
|
# 3. Create data loaders
|
||||||
loader_args = dict(batch_size=args.batch_size, num_workers=4, pin_memory=True)
|
loader_args = dict(batch_size=args.batch_size, num_workers=5, pin_memory=True)
|
||||||
train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
|
train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
|
||||||
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
|
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
|
||||||
|
|
||||||
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
|
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
|
||||||
optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
|
optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
|
||||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) # goal: maximize Dice score
|
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
|
||||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
|
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
|
||||||
criterion = nn.BCEWithLogitsLoss()
|
criterion = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
# connect to wandb
|
|
||||||
wandb.init(
|
wandb.init(
|
||||||
project="U-Net",
|
project="U-Net-tmp",
|
||||||
config=dict(
|
config=dict(
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
|
@ -177,24 +175,29 @@ def main():
|
||||||
"the images are loaded correctly."
|
"the images are loaded correctly."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# transfer images to device
|
||||||
images = images.to(device=device)
|
images = images.to(device=device)
|
||||||
true_masks = true_masks.unsqueeze(1).to(device=device)
|
true_masks = true_masks.unsqueeze(1).to(device=device)
|
||||||
|
|
||||||
|
# forward
|
||||||
with torch.cuda.amp.autocast(enabled=args.amp):
|
with torch.cuda.amp.autocast(enabled=args.amp):
|
||||||
masks_pred = net(images)
|
pred_masks = net(images)
|
||||||
train_loss = criterion(masks_pred, true_masks) # TODO: rajouter le diceloss
|
train_loss = criterion(pred_masks, pred_masks)
|
||||||
|
|
||||||
|
# backward
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
grad_scaler.scale(train_loss).backward()
|
grad_scaler.scale(train_loss).backward()
|
||||||
grad_scaler.step(optimizer)
|
grad_scaler.step(optimizer)
|
||||||
grad_scaler.update()
|
grad_scaler.update()
|
||||||
|
|
||||||
|
# update tqdm progress bar
|
||||||
pbar.update(images.shape[0])
|
pbar.update(images.shape[0])
|
||||||
pbar.set_postfix(**{"loss": train_loss.item()})
|
pbar.set_postfix(**{"loss": train_loss.item()})
|
||||||
|
|
||||||
wandb.log( # log training metrics
|
# log training metrics
|
||||||
|
wandb.log(
|
||||||
{
|
{
|
||||||
"train/epoch": epoch + step / len(train_loader),
|
"train/epoch": epoch - 1 + step / len(train_loader),
|
||||||
"train/train_loss": train_loss,
|
"train/train_loss": train_loss,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -202,7 +205,9 @@ def main():
|
||||||
# Evaluation round
|
# Evaluation round
|
||||||
val_score = evaluate(net, val_loader, device)
|
val_score = evaluate(net, val_loader, device)
|
||||||
scheduler.step(val_score)
|
scheduler.step(val_score)
|
||||||
wandb.log( # log validation metrics
|
|
||||||
|
# log validation metrics
|
||||||
|
wandb.log(
|
||||||
{
|
{
|
||||||
"val/val_score": val_score,
|
"val/val_score": val_score,
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,11 +19,11 @@ class SphereDataset(Dataset):
|
||||||
img_path = os.path.join(self.image_dir, self.images[index])
|
img_path = os.path.join(self.image_dir, self.images[index])
|
||||||
image = np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8)
|
image = np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8)
|
||||||
|
|
||||||
mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32)
|
mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||||
|
|
||||||
if self.transform is not None:
|
if self.transform is not None:
|
||||||
augmentations = self.transform(image=image, mask=mask)
|
augmentations = self.transform(image=image, mask=mask)
|
||||||
image = augmentations["image"]
|
image = augmentations["image"]
|
||||||
mask = augmentations["mask"]
|
mask = augmentations["mask"].float()
|
||||||
|
|
||||||
return image, mask
|
return image, mask
|
||||||
|
|
Loading…
Reference in a new issue