diff --git a/src/train.py b/src/train.py
index d8aff3e..b65cd45 100644
--- a/src/train.py
+++ b/src/train.py
@@ -27,6 +27,7 @@ if __name__ == "__main__":
         config=dict(
             DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
             DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/",
+            DIR_TEST_IMG="/home/lilian/data_disk/lfainsin/test/",
             DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
             DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
             FEATURES=[64, 128, 256, 512],
@@ -88,10 +89,12 @@ if __name__ == "__main__":
     # 2. Create datasets
     ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
     ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
+    ds_test = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
 
     # 2.5. Create subset, if uncommented
-    ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
-    ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
+    ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
+    ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 1000)))
+    ds_test = torch.utils.data.Subset(ds_test, list(range(0, len(ds_test), len(ds_test) // 100)))
 
     # 3. Create data loaders
     train_loader = DataLoader(
@@ -109,6 +112,14 @@ if __name__ == "__main__":
         num_workers=wandb.config.WORKERS,
         pin_memory=wandb.config.PIN_MEMORY,
     )
+    test_loader = DataLoader(
+        ds_test,
+        shuffle=False,
+        drop_last=False,
+        batch_size=1,
+        num_workers=wandb.config.WORKERS,
+        pin_memory=wandb.config.PIN_MEMORY,
+    )
 
     # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
     optimizer = torch.optim.RMSprop(
@@ -121,12 +132,6 @@ if __name__ == "__main__":
     grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
     criterion = torch.nn.BCEWithLogitsLoss()
 
-    # save model.pth
-    torch.save(net.state_dict(), "checkpoints/model-0.pth")
-    artifact = wandb.Artifact("pth", type="model")
-    artifact.add_file("checkpoints/model-0.pth")
-    wandb.run.log_artifact(artifact)
-
     # save model.onxx
     dummy_input = torch.randn(
         1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
@@ -143,14 +148,6 @@ if __name__ == "__main__":
     logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
 
     try:
-        # wandb init log
-        # wandb.log(
-        #     {
-        #         "train/learning_rate": scheduler.get_lr(),
-        #     },
-        #     commit=False,
-        # )
-
         for epoch in range(1, wandb.config.EPOCHS + 1):
             with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
 
@@ -205,7 +202,7 @@ if __name__ == "__main__":
                     val_loss = 0
                     dice = 0
                     mae = 0
-                    with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar2:
+                    with tqdm(val_loader, total=len(ds_valid), desc="val.", unit="img", leave=False) as pbar2:
                         for images, masks_true in val_loader:
 
                             # transfer images to device
@@ -217,10 +214,10 @@ if __name__ == "__main__":
                                 masks_pred = net(images)
 
                             # compute metrics
-                            val_loss += criterion(pred_masks, true_masks)
-                            mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
+                            val_loss += criterion(masks_pred, masks_true)
                             masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
-                            accuracy += (true_masks == pred_masks_bin).float().mean()
+                            mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
+                            accuracy += (masks_true == masks_pred_bin).float().mean()
                             dice += dice_coeff(masks_pred_bin, masks_true)
 
                             # update progress bar
@@ -235,10 +232,10 @@ if __name__ == "__main__":
                     table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
                     for i, (img, mask, pred, pred_bin) in enumerate(
                         zip(
-                            images.to("cpu"),
-                            masks_true.to("cpu"),
-                            masks_pred.to("cpu"),
-                            masks_pred_bin.to("cpu").squeeze().int().numpy(),
+                            images.cpu(),
+                            masks_true.cpu(),
+                            masks_pred.cpu(),
+                            masks_pred_bin.cpu().squeeze(1).int().numpy(),
                         )
                     ):
                         table.add_data(
@@ -259,7 +256,7 @@ if __name__ == "__main__":
                     # log validation metrics
                     wandb.log(
                         {
-                            "predictions": table,
+                            "val/predictions": table,
                             "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
                             "val/accuracy": accuracy,
                             "val/bce": val_loss,
@@ -280,9 +277,80 @@ if __name__ == "__main__":
                     artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
                     wandb.run.log_artifact(artifact)
 
+        # testing round
+        net.eval()
+        accuracy = 0
+        val_loss = 0
+        dice = 0
+        mae = 0
+        with tqdm(test_loader, total=len(ds_test), desc="test", unit="img", leave=False) as pbar3:
+            for images, masks_true in test_loader:
+
+                # transfer images to device
+                images = images.to(device=device)
+                masks_true = masks_true.unsqueeze(1).to(device=device)
+
+                # forward
+                with torch.inference_mode():
+                    masks_pred = net(images)
+
+                # compute metrics
+                val_loss += criterion(masks_pred, masks_true)
+                masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+                mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
+                accuracy += (masks_true == masks_pred_bin).float().mean()
+                dice += dice_coeff(masks_pred_bin, masks_true)
+
+                # update progress bar
+                pbar3.update(images.shape[0])
+
+        accuracy /= len(test_loader)
+        val_loss /= len(test_loader)
+        dice /= len(test_loader)
+        mae /= len(test_loader)
+
+        # save the last test batch to table
+        table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
+        for i, (img, mask, pred, pred_bin) in enumerate(
+            zip(
+                images.cpu(),
+                masks_true.cpu(),
+                masks_pred.cpu(),
+                masks_pred_bin.cpu().squeeze(1).int().numpy(),
+            )
+        ):
+            table.add_data(
+                i,
+                wandb.Image(img),
+                wandb.Image(mask),
+                wandb.Image(
+                    pred,
+                    masks={
+                        "predictions": {
+                            "mask_data": pred_bin,
+                            "class_labels": class_labels,
+                        },
+                    },
+                ),
+            )
+
+        # log test metrics
+        wandb.log(
+            {
+                "test/predictions": table,
+                "test/accuracy": accuracy,
+                "test/bce": val_loss,
+                "test/dice": dice,
+                "test/mae": mae,
+            },
+            commit=False,
+        )
+
         # stop wandb
         wandb.run.finish()
 
     except KeyboardInterrupt:
         torch.save(net.state_dict(), "INTERRUPTED.pth")
         raise
+
+# christmas tree
diff --git a/src/utils/dataset.py b/src/utils/dataset.py
index 5f47bf9..bf821a4 100644
--- a/src/utils/dataset.py
+++ b/src/utils/dataset.py
@@ -1,6 +1,8 @@
 from pathlib import Path
 
+import albumentations as A
 import numpy as np
+from albumentations.pytorch import ToTensorV2
 from PIL import Image
 from torch.utils.data import Dataset
 
@@ -16,12 +18,25 @@ class SphereDataset(Dataset):
     def __getitem__(self, index):
         image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
 
-        mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
-
         if self.transform is not None:
+            mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
             augmentations = self.transform(image=image, mask=mask)
             image = augmentations["image"]
             mask = augmentations["mask"]
+        else:
+            mask_path = self.images[index].parent.joinpath("MASK.PNG")
+            mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) / 255
+
+            preprocess = A.Compose(
+                [
+                    A.SmallestMaxSize(1024),
+                    A.ToFloat(max_value=255),
+                    ToTensorV2(),
+                ],
+            )
+            augmentations = preprocess(image=image, mask=mask)
+            image = augmentations["image"]
+            mask = augmentations["mask"]
 
         # make sure image and mask are floats
         image = image.float()
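Both evaluation loops above accumulate `dice_coeff(masks_pred_bin, masks_true)` per batch and then divide by the loader length. For reference, a minimal sketch of a binary Dice coefficient matching that call signature — an assumption for illustration only; the project's actual `dice_coeff` in `src/utils/` may be implemented differently:

```python
import torch

def dice_coeff(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Dice = 2|X ∩ Y| / (|X| + |Y|) over a batch of binarized masks."""
    pred = pred.flatten()      # (N, 1, H, W) -> (N*H*W,)
    target = target.flatten()
    intersection = (pred * target).sum()
    # eps guards against division by zero when both masks are empty
    return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)
```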