diff --git a/src/data/dataset.py b/src/data/dataset.py index d30931f..61508f2 100644 --- a/src/data/dataset.py +++ b/src/data/dataset.py @@ -58,7 +58,7 @@ class LabeledDataset(Dataset): # open and convert mask mask_path = self.images[index].parent.joinpath("MASK.PNG") - mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) / 255 + mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8) // 255 # convert image & mask to Tensor float in [0, 1] post_process = A.Compose( @@ -72,4 +72,8 @@ class LabeledDataset(Dataset): image = augmentations["image"] mask = augmentations["mask"] + # make sure image and mask are floats, TODO: mettre dans le post_process, ToFloat Image only + image = image.float() + mask = mask.float() + return image, mask diff --git a/src/train.py b/src/train.py index 85625e0..7167634 100644 --- a/src/train.py +++ b/src/train.py @@ -38,7 +38,7 @@ if __name__ == "__main__": # model.load_state_dict(state_dict) # log gradients and weights regularly - logger.watch(model, log="all") + logger.watch(model.model, log="all") # Create the dataloaders datamodule = Spheres() @@ -49,7 +49,7 @@ if __name__ == "__main__": accelerator=wandb.config.DEVICE, benchmark=wandb.config.BENCHMARK, # profiler="simple", - # precision=16, + precision=16, logger=logger, log_every_n_steps=1, val_check_interval=100, diff --git a/src/unet/module.py b/src/unet/module.py index c7ae528..43b09d0 100644 --- a/src/unet/module.py +++ b/src/unet/module.py @@ -38,7 +38,7 @@ class UNetModule(pl.LightningModule): # forward pass, compute masks prediction = self.model(data) - binary = (torch.sigmoid(prediction) > 0.5).float() # TODO: check if float necessary + binary = (torch.sigmoid(prediction) > 0.5).half() # compute metrics (in dictionnary) metrics = { diff --git a/src/utils/callback.py b/src/utils/callback.py index 20497f2..969b619 100644 --- a/src/utils/callback.py +++ b/src/utils/callback.py @@ -31,7 +31,7 @@ class TableLog(Callback): zip( images.cpu(), ground_truth.cpu(), - predictions["linear"].cpu(), + predictions["linear"].cpu().float(), predictions["binary"].cpu().squeeze(1).int().numpy(), ) ): diff --git a/src/utils/paste.py b/src/utils/paste.py index 02c1a1b..be6f25d 100644 --- a/src/utils/paste.py +++ b/src/utils/paste.py @@ -87,7 +87,7 @@ class RandomPaste(A.DualTransform): img.paste(paste_img, (x, y), paste_mask) - return np.asarray(img.convert("RGB")) + return np.array(img.convert("RGB")) def apply_to_mask(self, mask, augmentations, paste_mask, **params): # convert mask to Image, needed for `paste` function @@ -116,7 +116,7 @@ class RandomPaste(A.DualTransform): mask.paste(paste_mask, (x, y), paste_mask_bin) - return np.asarray(mask.convert("L")) + return np.array(mask.convert("L")) def get_params_dependent_on_targets(self, params): # choose a random image and its corresponding mask