mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-09-19 11:35:28 +00:00
feat: ajout des données de test
Former-commit-id: c582ae71d296afe90d25127f541c696052172a2a [formerly b5fe53254d424e3d6ea74573378a716ccd429d84] Former-commit-id: f0ae70c1025d70af43b0f172e5abaeeba819999f
This commit is contained in:
parent
2571e5c6d3
commit
0d6f85518e
118
src/train.py
118
src/train.py
|
@ -27,6 +27,7 @@ if __name__ == "__main__":
|
||||||
config=dict(
|
config=dict(
|
||||||
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
|
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
|
||||||
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/",
|
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/",
|
||||||
|
DIR_TEST_IMG="/home/lilian/data_disk/lfainsin/test/",
|
||||||
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
||||||
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
||||||
FEATURES=[64, 128, 256, 512],
|
FEATURES=[64, 128, 256, 512],
|
||||||
|
@ -88,10 +89,12 @@ if __name__ == "__main__":
|
||||||
# 2. Create datasets
|
# 2. Create datasets
|
||||||
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
||||||
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
|
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
|
||||||
|
ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
|
||||||
|
|
||||||
# 2.5. Create subset, if uncommented
|
# 2.5. Create subset, if uncommented
|
||||||
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
|
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
|
||||||
ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
|
ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000)))
|
||||||
|
ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
|
||||||
|
|
||||||
# 3. Create data loaders
|
# 3. Create data loaders
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
|
@ -109,6 +112,14 @@ if __name__ == "__main__":
|
||||||
num_workers=wandb.config.WORKERS,
|
num_workers=wandb.config.WORKERS,
|
||||||
pin_memory=wandb.config.PIN_MEMORY,
|
pin_memory=wandb.config.PIN_MEMORY,
|
||||||
)
|
)
|
||||||
|
test_loader = DataLoader(
|
||||||
|
ds_test,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=False,
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=wandb.config.WORKERS,
|
||||||
|
pin_memory=wandb.config.PIN_MEMORY,
|
||||||
|
)
|
||||||
|
|
||||||
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
|
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
|
||||||
optimizer = torch.optim.RMSprop(
|
optimizer = torch.optim.RMSprop(
|
||||||
|
@ -121,12 +132,6 @@ if __name__ == "__main__":
|
||||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
|
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
|
||||||
criterion = torch.nn.BCEWithLogitsLoss()
|
criterion = torch.nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
# save model.pth
|
|
||||||
torch.save(net.state_dict(), "checkpoints/model-0.pth")
|
|
||||||
artifact = wandb.Artifact("pth", type="model")
|
|
||||||
artifact.add_file("checkpoints/model-0.pth")
|
|
||||||
wandb.run.log_artifact(artifact)
|
|
||||||
|
|
||||||
# save model.onxx
|
# save model.onxx
|
||||||
dummy_input = torch.randn(
|
dummy_input = torch.randn(
|
||||||
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
|
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
|
||||||
|
@ -143,14 +148,6 @@ if __name__ == "__main__":
|
||||||
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
|
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# wandb init log
|
|
||||||
# wandb.log(
|
|
||||||
# {
|
|
||||||
# "train/learning_rate": scheduler.get_lr(),
|
|
||||||
# },
|
|
||||||
# commit=False,
|
|
||||||
# )
|
|
||||||
|
|
||||||
for epoch in range(1, wandb.config.EPOCHS + 1):
|
for epoch in range(1, wandb.config.EPOCHS + 1):
|
||||||
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
|
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
|
||||||
|
|
||||||
|
@ -205,7 +202,7 @@ if __name__ == "__main__":
|
||||||
val_loss = 0
|
val_loss = 0
|
||||||
dice = 0
|
dice = 0
|
||||||
mae = 0
|
mae = 0
|
||||||
with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar2:
|
with tqdm(val_loader, total=len(ds_valid), desc="val.", unit="img", leave=False) as pbar2:
|
||||||
for images, masks_true in val_loader:
|
for images, masks_true in val_loader:
|
||||||
|
|
||||||
# transfer images to device
|
# transfer images to device
|
||||||
|
@ -217,10 +214,10 @@ if __name__ == "__main__":
|
||||||
masks_pred = net(images)
|
masks_pred = net(images)
|
||||||
|
|
||||||
# compute metrics
|
# compute metrics
|
||||||
val_loss += criterion(pred_masks, true_masks)
|
val_loss += criterion(pred_masks, masks_true)
|
||||||
mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
|
|
||||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||||
accuracy += (true_masks == pred_masks_bin).float().mean()
|
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||||
|
accuracy += (masks_true == masks_pred_bin).float().mean()
|
||||||
dice += dice_coeff(masks_pred_bin, masks_true)
|
dice += dice_coeff(masks_pred_bin, masks_true)
|
||||||
|
|
||||||
# update progress bar
|
# update progress bar
|
||||||
|
@ -235,10 +232,10 @@ if __name__ == "__main__":
|
||||||
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
||||||
for i, (img, mask, pred, pred_bin) in enumerate(
|
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||||
zip(
|
zip(
|
||||||
images.to("cpu"),
|
images.cpu(),
|
||||||
masks_true.to("cpu"),
|
masks_true.cpu(),
|
||||||
masks_pred.to("cpu"),
|
masks_pred.cpu(),
|
||||||
masks_pred_bin.to("cpu").squeeze().int().numpy(),
|
masks_pred_bin.cpu().squeeze(1).int().numpy(),
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
table.add_data(
|
table.add_data(
|
||||||
|
@ -259,7 +256,7 @@ if __name__ == "__main__":
|
||||||
# log validation metrics
|
# log validation metrics
|
||||||
wandb.log(
|
wandb.log(
|
||||||
{
|
{
|
||||||
"predictions": table,
|
"val/predictions": table,
|
||||||
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
|
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
|
||||||
"val/accuracy": accuracy,
|
"val/accuracy": accuracy,
|
||||||
"val/bce": val_loss,
|
"val/bce": val_loss,
|
||||||
|
@ -280,9 +277,80 @@ if __name__ == "__main__":
|
||||||
artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
|
artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
|
|
||||||
|
# testing round
|
||||||
|
net.eval()
|
||||||
|
accuracy = 0
|
||||||
|
val_loss = 0
|
||||||
|
dice = 0
|
||||||
|
mae = 0
|
||||||
|
with tqdm(test_loader, total=len(ds_test), desc="test", unit="img", leave=False) as pbar3:
|
||||||
|
for images, masks_true in test_loader:
|
||||||
|
|
||||||
|
# transfer images to device
|
||||||
|
images = images.to(device=device)
|
||||||
|
masks_true = masks_true.unsqueeze(1).to(device=device)
|
||||||
|
|
||||||
|
# forward
|
||||||
|
with torch.inference_mode():
|
||||||
|
masks_pred = net(images)
|
||||||
|
|
||||||
|
# compute metrics
|
||||||
|
val_loss += criterion(masks_pred, masks_true)
|
||||||
|
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||||
|
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
|
||||||
|
accuracy += (masks_true == masks_pred_bin).float().mean()
|
||||||
|
dice += dice_coeff(masks_pred_bin, masks_true)
|
||||||
|
|
||||||
|
# update progress bar
|
||||||
|
pbar3.update(images.shape[0])
|
||||||
|
|
||||||
|
accuracy /= len(test_loader)
|
||||||
|
val_loss /= len(test_loader)
|
||||||
|
dice /= len(test_loader)
|
||||||
|
mae /= len(test_loader)
|
||||||
|
|
||||||
|
# save the last validation batch to table
|
||||||
|
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
||||||
|
for i, (img, mask, pred, pred_bin) in enumerate(
|
||||||
|
zip(
|
||||||
|
images.cpu(),
|
||||||
|
masks_true.cpu(),
|
||||||
|
masks_pred.cpu(),
|
||||||
|
masks_pred_bin.cpu().squeeze(1).int().numpy(),
|
||||||
|
)
|
||||||
|
):
|
||||||
|
table.add_data(
|
||||||
|
i,
|
||||||
|
wandb.Image(img),
|
||||||
|
wandb.Image(mask),
|
||||||
|
wandb.Image(
|
||||||
|
pred,
|
||||||
|
masks={
|
||||||
|
"predictions": {
|
||||||
|
"mask_data": pred_bin,
|
||||||
|
"class_labels": class_labels,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# log validation metrics
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
"test/predictions": table,
|
||||||
|
"test/accuracy": accuracy,
|
||||||
|
"test/bce": val_loss,
|
||||||
|
"test/dice": dice,
|
||||||
|
"test/mae": mae,
|
||||||
|
},
|
||||||
|
commit=False,
|
||||||
|
)
|
||||||
|
|
||||||
# stop wandb
|
# stop wandb
|
||||||
wandb.run.finish()
|
wandb.run.finish()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
torch.save(net.state_dict(), "INTERRUPTED.pth")
|
torch.save(net.state_dict(), "INTERRUPTED.pth")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# sapin de noel
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import albumentations as A
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from albumentations.pytorch import ToTensorV2
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
@ -16,12 +18,25 @@ class SphereDataset(Dataset):
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
|
image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
|
||||||
|
|
||||||
mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
|
||||||
|
|
||||||
if self.transform is not None:
|
if self.transform is not None:
|
||||||
|
mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||||
augmentations = self.transform(image=image, mask=mask)
|
augmentations = self.transform(image=image, mask=mask)
|
||||||
image = augmentations["image"]
|
image = augmentations["image"]
|
||||||
mask = augmentations["mask"]
|
mask = augmentations["mask"]
|
||||||
|
else:
|
||||||
|
mask_path = self.images[index].parent.joinpath("MASK.PNG")
|
||||||
|
mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) / 255
|
||||||
|
|
||||||
|
preprocess = A.Compose(
|
||||||
|
[
|
||||||
|
A.SmallestMaxSize(1024),
|
||||||
|
A.ToFloat(max_value=255),
|
||||||
|
ToTensorV2(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
augmentations = preprocess(image=image, mask=mask)
|
||||||
|
image = augmentations["image"]
|
||||||
|
mask = augmentations["mask"]
|
||||||
|
|
||||||
# make sure image and mask are floats
|
# make sure image and mask are floats
|
||||||
image = image.float()
|
image = image.float()
|
||||||
|
|
Loading…
Reference in a new issue