style: formatting
Former-commit-id: 2ccef30ce44d33beb611b63adef635ab2c1226bb
This commit is contained in:
parent
a42190ec61
commit
21c388ca2e
16
src/train.py
16
src/train.py
|
@ -26,7 +26,6 @@ def train_net(
|
||||||
epochs: int = 5,
|
epochs: int = 5,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
learning_rate: float = 1e-5,
|
learning_rate: float = 1e-5,
|
||||||
val_percent: float = 0.1,
|
|
||||||
save_checkpoint: bool = True,
|
save_checkpoint: bool = True,
|
||||||
img_scale: float = 0.5,
|
img_scale: float = 0.5,
|
||||||
amp: bool = False,
|
amp: bool = False,
|
||||||
|
@ -200,25 +199,17 @@ def get_args():
|
||||||
default=0.5,
|
default=0.5,
|
||||||
help="Downscaling factor of the images",
|
help="Downscaling factor of the images",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--validation",
|
|
||||||
"-v",
|
|
||||||
dest="val",
|
|
||||||
type=float,
|
|
||||||
default=10.0,
|
|
||||||
help="Percent of the data that is used as validation (0-100)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--amp",
|
"--amp",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
default=False,
|
default=True,
|
||||||
help="Use mixed precision",
|
help="Use mixed precision",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--classes",
|
"--classes",
|
||||||
"-c",
|
"-c",
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=1,
|
||||||
help="Number of classes",
|
help="Number of classes",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -232,9 +223,6 @@ if __name__ == "__main__":
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
logging.info(f"Using device {device}")
|
logging.info(f"Using device {device}")
|
||||||
|
|
||||||
# Change here to adapt to your data
|
|
||||||
# n_channels=3 for RGB images
|
|
||||||
# n_classes is the number of probabilities you want to get per pixel
|
|
||||||
net = UNet(n_channels=3, n_classes=args.classes)
|
net = UNet(n_channels=3, n_classes=args.classes)
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
|
|
|
@ -9,7 +9,7 @@ from PIL import Image
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
|
||||||
class BasicDataset(Dataset):
|
class SphereDataset(Dataset):
|
||||||
def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ""):
|
def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ""):
|
||||||
self.images_dir = Path(images_dir)
|
self.images_dir = Path(images_dir)
|
||||||
self.masks_dir = Path(masks_dir)
|
self.masks_dir = Path(masks_dir)
|
||||||
|
@ -29,7 +29,12 @@ class BasicDataset(Dataset):
|
||||||
def preprocess(pil_img, scale, is_mask):
|
def preprocess(pil_img, scale, is_mask):
|
||||||
w, h = pil_img.size
|
w, h = pil_img.size
|
||||||
newW, newH = int(scale * w), int(scale * h)
|
newW, newH = int(scale * w), int(scale * h)
|
||||||
assert newW > 0 and newH > 0, "Scale is too small, resized images would have no pixel"
|
|
||||||
|
assert (
|
||||||
|
newW > 0 and newH > 0,
|
||||||
|
"Scale is too small, resized images would have no pixel",
|
||||||
|
)
|
||||||
|
|
||||||
pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
|
pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
|
||||||
img_ndarray = np.asarray(pil_img)
|
img_ndarray = np.asarray(pil_img)
|
||||||
|
|
||||||
|
@ -46,6 +51,7 @@ class BasicDataset(Dataset):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load(filename):
|
def load(filename):
|
||||||
ext = splitext(filename)[1]
|
ext = splitext(filename)[1]
|
||||||
|
|
||||||
if ext in [".npz", ".npy"]:
|
if ext in [".npz", ".npy"]:
|
||||||
return Image.fromarray(np.load(filename))
|
return Image.fromarray(np.load(filename))
|
||||||
elif ext in [".pt", ".pth"]:
|
elif ext in [".pt", ".pth"]:
|
||||||
|
@ -58,14 +64,22 @@ class BasicDataset(Dataset):
|
||||||
mask_file = list(self.masks_dir.glob(name + self.mask_suffix + ".*"))
|
mask_file = list(self.masks_dir.glob(name + self.mask_suffix + ".*"))
|
||||||
img_file = list(self.images_dir.glob(name + ".*"))
|
img_file = list(self.images_dir.glob(name + ".*"))
|
||||||
|
|
||||||
assert len(img_file) == 1, f"Either no image or multiple images found for the ID {name}: {img_file}"
|
assert (
|
||||||
assert len(mask_file) == 1, f"Either no mask or multiple masks found for the ID {name}: {mask_file}"
|
len(img_file) == 1,
|
||||||
|
f"Either no image or multiple images found for the ID {name}: {img_file}",
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
len(mask_file) == 1,
|
||||||
|
f"Either no mask or multiple masks found for the ID {name}: {mask_file}",
|
||||||
|
)
|
||||||
|
|
||||||
mask = self.load(mask_file[0])
|
mask = self.load(mask_file[0])
|
||||||
img = self.load(img_file[0])
|
img = self.load(img_file[0])
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
img.size == mask.size
|
img.size == mask.size,
|
||||||
), f"Image and mask {name} should be the same size, but are {img.size} and {mask.size}"
|
f"Image and mask {name} should be the same size, but are {img.size} and {mask.size}",
|
||||||
|
)
|
||||||
|
|
||||||
img = self.preprocess(img, self.scale, is_mask=False)
|
img = self.preprocess(img, self.scale, is_mask=False)
|
||||||
mask = self.preprocess(mask, self.scale, is_mask=True)
|
mask = self.preprocess(mask, self.scale, is_mask=True)
|
||||||
|
@ -74,8 +88,3 @@ class BasicDataset(Dataset):
|
||||||
"image": torch.as_tensor(img.copy()).float().contiguous(),
|
"image": torch.as_tensor(img.copy()).float().contiguous(),
|
||||||
"mask": torch.as_tensor(mask.copy()).long().contiguous(),
|
"mask": torch.as_tensor(mask.copy()).long().contiguous(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class CarvanaDataset(BasicDataset):
|
|
||||||
def __init__(self, images_dir, masks_dir, scale=1):
|
|
||||||
super().__init__(images_dir, masks_dir, scale, mask_suffix="_mask")
|
|
||||||
|
|
|
@ -2,15 +2,30 @@ import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
|
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
|
||||||
# Average of Dice coefficient for all batches, or for a single mask
|
"""Average of Dice coefficient for all batches, or for a single mask.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (Tensor): _description_
|
||||||
|
target (Tensor): _description_
|
||||||
|
reduce_batch_first (bool, optional): _description_. Defaults to False.
|
||||||
|
epsilon (_type_, optional): _description_. Defaults to 1e-6.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: _description_
|
||||||
|
"""
|
||||||
assert input.size() == target.size()
|
assert input.size() == target.size()
|
||||||
|
|
||||||
if input.dim() == 2 and reduce_batch_first:
|
if input.dim() == 2 and reduce_batch_first:
|
||||||
raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
|
raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
|
||||||
|
|
||||||
if input.dim() == 2 or reduce_batch_first:
|
if input.dim() == 2 or reduce_batch_first:
|
||||||
inter = torch.dot(input.reshape(-1), target.reshape(-1))
|
inter = torch.dot(input.reshape(-1), target.reshape(-1))
|
||||||
sets_sum = torch.sum(input) + torch.sum(target)
|
sets_sum = torch.sum(input) + torch.sum(target)
|
||||||
|
|
||||||
if sets_sum.item() == 0:
|
if sets_sum.item() == 0:
|
||||||
sets_sum = 2 * inter
|
sets_sum = 2 * inter
|
||||||
|
|
||||||
|
@ -18,23 +33,48 @@ def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False,
|
||||||
else:
|
else:
|
||||||
# compute and average metric for each batch element
|
# compute and average metric for each batch element
|
||||||
dice = 0
|
dice = 0
|
||||||
|
|
||||||
for i in range(input.shape[0]):
|
for i in range(input.shape[0]):
|
||||||
dice += dice_coeff(input[i, ...], target[i, ...])
|
dice += dice_coeff(input[i, ...], target[i, ...])
|
||||||
|
|
||||||
return dice / input.shape[0]
|
return dice / input.shape[0]
|
||||||
|
|
||||||
|
|
||||||
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
|
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
|
||||||
# Average of Dice coefficient for all classes
|
"""Average of Dice coefficient for all classes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (Tensor): _description_
|
||||||
|
target (Tensor): _description_
|
||||||
|
reduce_batch_first (bool, optional): _description_. Defaults to False.
|
||||||
|
epsilon (_type_, optional): _description_. Defaults to 1e-6.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: _description_
|
||||||
|
"""
|
||||||
assert input.size() == target.size()
|
assert input.size() == target.size()
|
||||||
|
|
||||||
dice = 0
|
dice = 0
|
||||||
|
|
||||||
for channel in range(input.shape[1]):
|
for channel in range(input.shape[1]):
|
||||||
dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
|
dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
|
||||||
|
|
||||||
return dice / input.shape[1]
|
return dice / input.shape[1]
|
||||||
|
|
||||||
|
|
||||||
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
|
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
|
||||||
# Dice loss (objective to minimize) between 0 and 1
|
"""Dice loss (objective to minimize) between 0 and 1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input (Tensor): _description_
|
||||||
|
target (Tensor): _description_
|
||||||
|
multiclass (bool, optional): _description_. Defaults to False.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: _description_
|
||||||
|
"""
|
||||||
assert input.size() == target.size()
|
assert input.size() == target.size()
|
||||||
|
|
||||||
fn = multiclass_dice_coeff if multiclass else dice_coeff
|
fn = multiclass_dice_coeff if multiclass else dice_coeff
|
||||||
|
|
||||||
return 1 - fn(input, target, reduce_batch_first=True)
|
return 1 - fn(input, target, reduce_batch_first=True)
|
||||||
|
|
Loading…
Reference in a new issue