style: formatting

Former-commit-id: 2ccef30ce44d33beb611b63adef635ab2c1226bb
This commit is contained in:
Your Name 2022-06-27 16:40:04 +02:00
parent a42190ec61
commit 21c388ca2e
3 changed files with 68 additions and 31 deletions

View file

@ -26,7 +26,6 @@ def train_net(
epochs: int = 5,
batch_size: int = 1,
learning_rate: float = 1e-5,
val_percent: float = 0.1,
save_checkpoint: bool = True,
img_scale: float = 0.5,
amp: bool = False,
@ -200,25 +199,17 @@ def get_args():
default=0.5,
help="Downscaling factor of the images",
)
parser.add_argument(
"--validation",
"-v",
dest="val",
type=float,
default=10.0,
help="Percent of the data that is used as validation (0-100)",
)
parser.add_argument(
"--amp",
action="store_true",
default=False,
default=True,
help="Use mixed precision",
)
parser.add_argument(
"--classes",
"-c",
type=int,
default=2,
default=1,
help="Number of classes",
)
@ -232,9 +223,6 @@ if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device {device}")
# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
net = UNet(n_channels=3, n_classes=args.classes)
logging.info(

View file

@ -9,7 +9,7 @@ from PIL import Image
from torch.utils.data import Dataset
class BasicDataset(Dataset):
class SphereDataset(Dataset):
def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ""):
self.images_dir = Path(images_dir)
self.masks_dir = Path(masks_dir)
@ -29,7 +29,12 @@ class BasicDataset(Dataset):
def preprocess(pil_img, scale, is_mask):
w, h = pil_img.size
newW, newH = int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, "Scale is too small, resized images would have no pixel"
assert (
newW > 0 and newH > 0,
"Scale is too small, resized images would have no pixel",
)
pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
img_ndarray = np.asarray(pil_img)
@ -46,6 +51,7 @@ class BasicDataset(Dataset):
@staticmethod
def load(filename):
ext = splitext(filename)[1]
if ext in [".npz", ".npy"]:
return Image.fromarray(np.load(filename))
elif ext in [".pt", ".pth"]:
@ -58,14 +64,22 @@ class BasicDataset(Dataset):
mask_file = list(self.masks_dir.glob(name + self.mask_suffix + ".*"))
img_file = list(self.images_dir.glob(name + ".*"))
assert len(img_file) == 1, f"Either no image or multiple images found for the ID {name}: {img_file}"
assert len(mask_file) == 1, f"Either no mask or multiple masks found for the ID {name}: {mask_file}"
assert (
len(img_file) == 1,
f"Either no image or multiple images found for the ID {name}: {img_file}",
)
assert (
len(mask_file) == 1,
f"Either no mask or multiple masks found for the ID {name}: {mask_file}",
)
mask = self.load(mask_file[0])
img = self.load(img_file[0])
assert (
img.size == mask.size
), f"Image and mask {name} should be the same size, but are {img.size} and {mask.size}"
img.size == mask.size,
f"Image and mask {name} should be the same size, but are {img.size} and {mask.size}",
)
img = self.preprocess(img, self.scale, is_mask=False)
mask = self.preprocess(mask, self.scale, is_mask=True)
@ -74,8 +88,3 @@ class BasicDataset(Dataset):
"image": torch.as_tensor(img.copy()).float().contiguous(),
"mask": torch.as_tensor(mask.copy()).long().contiguous(),
}
class CarvanaDataset(BasicDataset):
def __init__(self, images_dir, masks_dir, scale=1):
super().__init__(images_dir, masks_dir, scale, mask_suffix="_mask")

View file

@ -2,15 +2,30 @@ import torch
from torch import Tensor
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
# Average of Dice coefficient for all batches, or for a single mask
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
"""Average of Dice coefficient for all batches, or for a single mask.
Args:
input (Tensor): _description_
target (Tensor): _description_
reduce_batch_first (bool, optional): _description_. Defaults to False.
epsilon (_type_, optional): _description_. Defaults to 1e-6.
Raises:
ValueError: _description_
Returns:
float: _description_
"""
assert input.size() == target.size()
if input.dim() == 2 and reduce_batch_first:
raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
if input.dim() == 2 or reduce_batch_first:
inter = torch.dot(input.reshape(-1), target.reshape(-1))
sets_sum = torch.sum(input) + torch.sum(target)
if sets_sum.item() == 0:
sets_sum = 2 * inter
@ -18,23 +33,48 @@ def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False,
else:
# compute and average metric for each batch element
dice = 0
for i in range(input.shape[0]):
dice += dice_coeff(input[i, ...], target[i, ...])
return dice / input.shape[0]
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
# Average of Dice coefficient for all classes
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
"""Average of Dice coefficient for all classes.
Args:
input (Tensor): _description_
target (Tensor): _description_
reduce_batch_first (bool, optional): _description_. Defaults to False.
epsilon (_type_, optional): _description_. Defaults to 1e-6.
Returns:
float: _description_
"""
assert input.size() == target.size()
dice = 0
for channel in range(input.shape[1]):
dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
return dice / input.shape[1]
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
# Dice loss (objective to minimize) between 0 and 1
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
"""Dice loss (objective to minimize) between 0 and 1.
Args:
input (Tensor): _description_
target (Tensor): _description_
multiclass (bool, optional): _description_. Defaults to False.
Returns:
float: _description_
"""
assert input.size() == target.size()
fn = multiclass_dice_coeff if multiclass else dice_coeff
return 1 - fn(input, target, reduce_batch_first=True)