mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
style: formatting
Former-commit-id: 2ccef30ce44d33beb611b63adef635ab2c1226bb
This commit is contained in:
parent
a42190ec61
commit
21c388ca2e
16
src/train.py
16
src/train.py
|
@ -26,7 +26,6 @@ def train_net(
|
|||
epochs: int = 5,
|
||||
batch_size: int = 1,
|
||||
learning_rate: float = 1e-5,
|
||||
val_percent: float = 0.1,
|
||||
save_checkpoint: bool = True,
|
||||
img_scale: float = 0.5,
|
||||
amp: bool = False,
|
||||
|
@ -200,25 +199,17 @@ def get_args():
|
|||
default=0.5,
|
||||
help="Downscaling factor of the images",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation",
|
||||
"-v",
|
||||
dest="val",
|
||||
type=float,
|
||||
default=10.0,
|
||||
help="Percent of the data that is used as validation (0-100)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--amp",
|
||||
action="store_true",
|
||||
default=False,
|
||||
default=True,
|
||||
help="Use mixed precision",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--classes",
|
||||
"-c",
|
||||
type=int,
|
||||
default=2,
|
||||
default=1,
|
||||
help="Number of classes",
|
||||
)
|
||||
|
||||
|
@ -232,9 +223,6 @@ if __name__ == "__main__":
|
|||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
logging.info(f"Using device {device}")
|
||||
|
||||
# Change here to adapt to your data
|
||||
# n_channels=3 for RGB images
|
||||
# n_classes is the number of probabilities you want to get per pixel
|
||||
net = UNet(n_channels=3, n_classes=args.classes)
|
||||
|
||||
logging.info(
|
||||
|
|
|
@ -9,7 +9,7 @@ from PIL import Image
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class BasicDataset(Dataset):
|
||||
class SphereDataset(Dataset):
|
||||
def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ""):
|
||||
self.images_dir = Path(images_dir)
|
||||
self.masks_dir = Path(masks_dir)
|
||||
|
@ -29,7 +29,12 @@ class BasicDataset(Dataset):
|
|||
def preprocess(pil_img, scale, is_mask):
|
||||
w, h = pil_img.size
|
||||
newW, newH = int(scale * w), int(scale * h)
|
||||
assert newW > 0 and newH > 0, "Scale is too small, resized images would have no pixel"
|
||||
|
||||
assert (
|
||||
newW > 0 and newH > 0,
|
||||
"Scale is too small, resized images would have no pixel",
|
||||
)
|
||||
|
||||
pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
|
||||
img_ndarray = np.asarray(pil_img)
|
||||
|
||||
|
@ -46,6 +51,7 @@ class BasicDataset(Dataset):
|
|||
@staticmethod
|
||||
def load(filename):
|
||||
ext = splitext(filename)[1]
|
||||
|
||||
if ext in [".npz", ".npy"]:
|
||||
return Image.fromarray(np.load(filename))
|
||||
elif ext in [".pt", ".pth"]:
|
||||
|
@ -58,14 +64,22 @@ class BasicDataset(Dataset):
|
|||
mask_file = list(self.masks_dir.glob(name + self.mask_suffix + ".*"))
|
||||
img_file = list(self.images_dir.glob(name + ".*"))
|
||||
|
||||
assert len(img_file) == 1, f"Either no image or multiple images found for the ID {name}: {img_file}"
|
||||
assert len(mask_file) == 1, f"Either no mask or multiple masks found for the ID {name}: {mask_file}"
|
||||
assert (
|
||||
len(img_file) == 1,
|
||||
f"Either no image or multiple images found for the ID {name}: {img_file}",
|
||||
)
|
||||
assert (
|
||||
len(mask_file) == 1,
|
||||
f"Either no mask or multiple masks found for the ID {name}: {mask_file}",
|
||||
)
|
||||
|
||||
mask = self.load(mask_file[0])
|
||||
img = self.load(img_file[0])
|
||||
|
||||
assert (
|
||||
img.size == mask.size
|
||||
), f"Image and mask {name} should be the same size, but are {img.size} and {mask.size}"
|
||||
img.size == mask.size,
|
||||
f"Image and mask {name} should be the same size, but are {img.size} and {mask.size}",
|
||||
)
|
||||
|
||||
img = self.preprocess(img, self.scale, is_mask=False)
|
||||
mask = self.preprocess(mask, self.scale, is_mask=True)
|
||||
|
@ -74,8 +88,3 @@ class BasicDataset(Dataset):
|
|||
"image": torch.as_tensor(img.copy()).float().contiguous(),
|
||||
"mask": torch.as_tensor(mask.copy()).long().contiguous(),
|
||||
}
|
||||
|
||||
|
||||
class CarvanaDataset(BasicDataset):
|
||||
def __init__(self, images_dir, masks_dir, scale=1):
|
||||
super().__init__(images_dir, masks_dir, scale, mask_suffix="_mask")
|
||||
|
|
|
@ -2,15 +2,30 @@ import torch
|
|||
from torch import Tensor
|
||||
|
||||
|
||||
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
|
||||
# Average of Dice coefficient for all batches, or for a single mask
|
||||
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
|
||||
"""Average of Dice coefficient for all batches, or for a single mask.
|
||||
|
||||
Args:
|
||||
input (Tensor): _description_
|
||||
target (Tensor): _description_
|
||||
reduce_batch_first (bool, optional): _description_. Defaults to False.
|
||||
epsilon (_type_, optional): _description_. Defaults to 1e-6.
|
||||
|
||||
Raises:
|
||||
ValueError: _description_
|
||||
|
||||
Returns:
|
||||
float: _description_
|
||||
"""
|
||||
assert input.size() == target.size()
|
||||
|
||||
if input.dim() == 2 and reduce_batch_first:
|
||||
raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
|
||||
|
||||
if input.dim() == 2 or reduce_batch_first:
|
||||
inter = torch.dot(input.reshape(-1), target.reshape(-1))
|
||||
sets_sum = torch.sum(input) + torch.sum(target)
|
||||
|
||||
if sets_sum.item() == 0:
|
||||
sets_sum = 2 * inter
|
||||
|
||||
|
@ -18,23 +33,48 @@ def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False,
|
|||
else:
|
||||
# compute and average metric for each batch element
|
||||
dice = 0
|
||||
|
||||
for i in range(input.shape[0]):
|
||||
dice += dice_coeff(input[i, ...], target[i, ...])
|
||||
|
||||
return dice / input.shape[0]
|
||||
|
||||
|
||||
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
|
||||
# Average of Dice coefficient for all classes
|
||||
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
|
||||
"""Average of Dice coefficient for all classes.
|
||||
|
||||
Args:
|
||||
input (Tensor): _description_
|
||||
target (Tensor): _description_
|
||||
reduce_batch_first (bool, optional): _description_. Defaults to False.
|
||||
epsilon (_type_, optional): _description_. Defaults to 1e-6.
|
||||
|
||||
Returns:
|
||||
float: _description_
|
||||
"""
|
||||
assert input.size() == target.size()
|
||||
|
||||
dice = 0
|
||||
|
||||
for channel in range(input.shape[1]):
|
||||
dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
|
||||
|
||||
return dice / input.shape[1]
|
||||
|
||||
|
||||
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
|
||||
# Dice loss (objective to minimize) between 0 and 1
|
||||
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
|
||||
"""Dice loss (objective to minimize) between 0 and 1.
|
||||
|
||||
Args:
|
||||
input (Tensor): _description_
|
||||
target (Tensor): _description_
|
||||
multiclass (bool, optional): _description_. Defaults to False.
|
||||
|
||||
Returns:
|
||||
float: _description_
|
||||
"""
|
||||
assert input.size() == target.size()
|
||||
|
||||
fn = multiclass_dice_coeff if multiclass else dice_coeff
|
||||
|
||||
return 1 - fn(input, target, reduce_batch_first=True)
|
||||
|
|
Loading…
Reference in a new issue