Review notes (text before the first "diff --git" header is ignored by patch tools):

* src/utils/dataset.py: the reformatted asserts had the form
  `assert (condition, "message")`, which asserts a non-empty tuple and can
  never fail.  They are restored to `assert (condition), "message"`; the
  image/mask size assert is therefore left exactly as in the pre-image.
  Hunk line counts and offsets are recomputed accordingly.
* src/utils/dice_score.py: return annotations are `Tensor` (the functions
  return torch tensors, not Python floats), the generated `_description_`
  docstring placeholders are filled in, and the per-element recursion in
  `dice_coeff` now forwards `epsilon`.
* The post-image blob ids on the `index` lines of dataset.py and
  dice_score.py predate these fixes and are stale.
* NOTE(review): `--amp` keeps `action="store_true"` with `default=True`, so
  the flag can no longer be switched off from the command line -- confirm
  this is intended (argparse.BooleanOptionalAction would add `--no-amp`).
* NOTE(review): `val_percent` / `--validation` are removed and
  `BasicDataset` / `CarvanaDataset` are renamed/removed; callers and imports
  outside the visible hunks presumably need matching changes -- verify.

diff --git a/src/train.py b/src/train.py
index 471b1c3..3e559a7 100644
--- a/src/train.py
+++ b/src/train.py
@@ -26,7 +26,6 @@ def train_net(
     epochs: int = 5,
     batch_size: int = 1,
     learning_rate: float = 1e-5,
-    val_percent: float = 0.1,
     save_checkpoint: bool = True,
     img_scale: float = 0.5,
     amp: bool = False,
@@ -200,25 +199,17 @@ def get_args():
         default=0.5,
         help="Downscaling factor of the images",
     )
-    parser.add_argument(
-        "--validation",
-        "-v",
-        dest="val",
-        type=float,
-        default=10.0,
-        help="Percent of the data that is used as validation (0-100)",
-    )
     parser.add_argument(
         "--amp",
         action="store_true",
-        default=False,
+        default=True,
         help="Use mixed precision",
     )
     parser.add_argument(
         "--classes",
         "-c",
         type=int,
-        default=2,
+        default=1,
         help="Number of classes",
     )
 
@@ -232,9 +223,6 @@ if __name__ == "__main__":
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     logging.info(f"Using device {device}")
 
-    # Change here to adapt to your data
-    # n_channels=3 for RGB images
-    # n_classes is the number of probabilities you want to get per pixel
     net = UNet(n_channels=3, n_classes=args.classes)
 
     logging.info(
diff --git a/src/utils/dataset.py b/src/utils/dataset.py
index 4767740..76e3a1f 100644
--- a/src/utils/dataset.py
+++ b/src/utils/dataset.py
@@ -9,7 +9,7 @@ from PIL import Image
 from torch.utils.data import Dataset
 
 
-class BasicDataset(Dataset):
+class SphereDataset(Dataset):
     def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ""):
         self.images_dir = Path(images_dir)
         self.masks_dir = Path(masks_dir)
@@ -29,7 +29,11 @@ class BasicDataset(Dataset):
     def preprocess(pil_img, scale, is_mask):
         w, h = pil_img.size
         newW, newH = int(scale * w), int(scale * h)
-        assert newW > 0 and newH > 0, "Scale is too small, resized images would have no pixel"
+
+        assert (
+            newW > 0 and newH > 0
+        ), "Scale is too small, resized images would have no pixel"
+
         pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
         img_ndarray = np.asarray(pil_img)
 
@@ -46,6 +50,7 @@ class BasicDataset(Dataset):
     @staticmethod
     def load(filename):
         ext = splitext(filename)[1]
+
         if ext in [".npz", ".npy"]:
             return Image.fromarray(np.load(filename))
         elif ext in [".pt", ".pth"]:
@@ -58,14 +63,19 @@ class BasicDataset(Dataset):
         mask_file = list(self.masks_dir.glob(name + self.mask_suffix + ".*"))
         img_file = list(self.images_dir.glob(name + ".*"))
 
-        assert len(img_file) == 1, f"Either no image or multiple images found for the ID {name}: {img_file}"
-        assert len(mask_file) == 1, f"Either no mask or multiple masks found for the ID {name}: {mask_file}"
+        assert (
+            len(img_file) == 1
+        ), f"Either no image or multiple images found for the ID {name}: {img_file}"
+        assert (
+            len(mask_file) == 1
+        ), f"Either no mask or multiple masks found for the ID {name}: {mask_file}"
+
         mask = self.load(mask_file[0])
         img = self.load(img_file[0])
 
         assert (
             img.size == mask.size
         ), f"Image and mask {name} should be the same size, but are {img.size} and {mask.size}"
 
         img = self.preprocess(img, self.scale, is_mask=False)
         mask = self.preprocess(mask, self.scale, is_mask=True)
@@ -74,8 +84,3 @@ class BasicDataset(Dataset):
             "image": torch.as_tensor(img.copy()).float().contiguous(),
             "mask": torch.as_tensor(mask.copy()).long().contiguous(),
         }
-
-
-class CarvanaDataset(BasicDataset):
-    def __init__(self, images_dir, masks_dir, scale=1):
-        super().__init__(images_dir, masks_dir, scale, mask_suffix="_mask")
diff --git a/src/utils/dice_score.py b/src/utils/dice_score.py
index 79dc616..a29f794 100644
--- a/src/utils/dice_score.py
+++ b/src/utils/dice_score.py
@@ -2,15 +2,30 @@ import torch
 from torch import Tensor
 
 
-def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
-    # Average of Dice coefficient for all batches, or for a single mask
+def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> Tensor:
+    """Average of Dice coefficient for all batches, or for a single mask.
+
+    Args:
+        input (Tensor): First mask tensor; must have the same size as ``target``.
+        target (Tensor): Second mask tensor; must have the same size as ``input``.
+        reduce_batch_first (bool, optional): Score the whole tensor as one set instead of averaging per element. Defaults to False.
+        epsilon (float, optional): Smoothing constant for the Dice ratio. Defaults to 1e-6.
+
+    Raises:
+        ValueError: If ``reduce_batch_first`` is set but ``input`` is 2-D (no batch dimension).
+
+    Returns:
+        Tensor: The Dice coefficient (a tensor, not a Python float).
+    """
     assert input.size() == target.size()
+
     if input.dim() == 2 and reduce_batch_first:
         raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
 
     if input.dim() == 2 or reduce_batch_first:
         inter = torch.dot(input.reshape(-1), target.reshape(-1))
         sets_sum = torch.sum(input) + torch.sum(target)
+
         if sets_sum.item() == 0:
             sets_sum = 2 * inter
 
@@ -18,23 +33,48 @@ def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False,
     else:
         # compute and average metric for each batch element
         dice = 0
+
         for i in range(input.shape[0]):
-            dice += dice_coeff(input[i, ...], target[i, ...])
+            dice += dice_coeff(input[i, ...], target[i, ...], epsilon=epsilon)
+
         return dice / input.shape[0]
 
 
-def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
-    # Average of Dice coefficient for all classes
+def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> Tensor:
+    """Average of Dice coefficient for all classes.
+
+    Args:
+        input (Tensor): First tensor, with classes along dimension 1; same size as ``target``.
+        target (Tensor): Second tensor, with classes along dimension 1; same size as ``input``.
+        reduce_batch_first (bool, optional): Forwarded to ``dice_coeff`` for each class. Defaults to False.
+        epsilon (float, optional): Forwarded to ``dice_coeff`` for each class. Defaults to 1e-6.
+
+    Returns:
+        Tensor: Mean of the per-class Dice coefficients.
+    """
     assert input.size() == target.size()
+
     dice = 0
+
     for channel in range(input.shape[1]):
         dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
 
     return dice / input.shape[1]
 
 
-def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
-    # Dice loss (objective to minimize) between 0 and 1
+def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> Tensor:
+    """Dice loss (objective to minimize) between 0 and 1.
+
+    Args:
+        input (Tensor): First tensor; must have the same size as ``target``.
+        target (Tensor): Second tensor; must have the same size as ``input``.
+        multiclass (bool, optional): Use ``multiclass_dice_coeff`` instead of ``dice_coeff``. Defaults to False.
+
+    Returns:
+        Tensor: ``1 - dice``, computed with ``reduce_batch_first=True``.
+    """
     assert input.size() == target.size()
+
     fn = multiclass_dice_coeff if multiclass else dice_coeff
+
     return 1 - fn(input, target, reduce_batch_first=True)