mirror of https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00

Commit 164417b13f: Merge branch 'pl'

Former-commit-id: 2cff12df68a22afe6b9e1133702203c4c6d86d24 [formerly 4f606e6e20e260cf94c1384ceb65c29c9ac12164]
Former-commit-id: ee4167081531b498b295c0402260da58cc4f7374
2  .gitignore (vendored)

@@ -4,10 +4,12 @@ __pycache__/
 wandb/
 images/
+lightning_logs/

 checkpoints/
 *.pth
 *.onnx
+*.ckpt

 *.png
 *.jpg
8  .vscode/launch.json (vendored)

@@ -12,11 +12,13 @@
             "console": "integratedTerminal",
             "args": [
                 "--input",
-                "images/SM.png",
+                "images/test.png",
                 "--output",
-                "output.png",
+                "output_onnx.png",
+                "--model",
+                "good.onnx",
             ],
             "justMyCode": true
         }
     ]
 }

@@ -1 +0,0 @@
-9cbd3cff7e664a80a5a1fa1404898b7bba3cae0d
3454  poetry.lock (generated)

File diff suppressed because it is too large
pyproject.toml

@@ -8,8 +8,11 @@ version = "0.1.0"
 albumentations = "^1.2.0"
 matplotlib = "^3.5.2"
 numpy = "^1.23.0"
 onnx = "^1.12.0"
+onnxruntime = "^1.11.1"
 python = ">=3.8,<3.11"
+pytorch-lightning = "^1.6.4"
+rich = "^12.4.4"
 scipy = "^1.8.1"
 torch = "^1.11.0"
 torchvision = "^0.12.0"
1  src/comp.ipynb.REMOVED.git-id (new file)

@@ -0,0 +1 @@
+0f3136c724eea42fdf1ee15e721ef33604e9a46d
105  src/dynamic.ipynb (new file)

@@ -0,0 +1,105 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from unet import UNet\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "net = UNet(\n",
+    "    n_channels=3,\n",
+    "    n_classes=1,\n",
+    "    batch_size=1,\n",
+    "    learning_rate=1e-4,\n",
+    "    features=[8, 16, 32, 64],\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "net.load_state_dict(\n",
+    "    torch.load(\"../best.pth\")\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dummy_input = torch.randn(1, 3, 1024, 1024, requires_grad=True)\n",
+    "torch.onnx.export(\n",
+    "    net,\n",
+    "    dummy_input,\n",
+    "    \"model-test.onnx\",\n",
+    "    opset_version=14,\n",
+    "    input_names=[\"input\"],\n",
+    "    output_names=[\"output\"],\n",
+    "    dynamic_axes={\n",
+    "        \"input\": {\n",
+    "            2: \"height\",\n",
+    "            3: \"width\",\n",
+    "        },\n",
+    "        \"output\": {\n",
+    "            2: \"height\",\n",
+    "            3: \"width\",\n",
+    "        },\n",
+    "    },\n",
+    ")\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.8.0 ('.venv': poetry)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.0"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "dc80d2c03865715c8671359a6bf138f6c8ae4e26ae025f2543e0980b8db0ed7e"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
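The dynamic_axes mapping above is what lets the exported network accept resolutions other than the 1024x1024 tracing input. A quick way to verify the export is to run the file through ONNX Runtime at a different resolution; a minimal sketch, assuming the model-test.onnx produced by the cell above and a single-class output:

import numpy as np
import onnxruntime

# load the model exported by the notebook cell above
session = onnxruntime.InferenceSession("model-test.onnx")

# deliberately pick a resolution that differs from the 1024x1024 dummy input
x = np.random.randn(1, 3, 512, 768).astype(np.float32)
(out,) = session.run(None, {"input": x})

# with dynamic height/width axes, the output resolution follows the input
print(out.shape)  # expected: (1, 1, 512, 768)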
177  src/extract.ipynb (new file)

File diff suppressed because one or more lines are too long
@@ -3,12 +3,12 @@ import logging

 import albumentations as A
 import numpy as np
+import onnx
+import onnxruntime
-import torch
 from albumentations.pytorch import ToTensorV2
 from PIL import Image

-from unet import UNet


 def get_args():
     parser = argparse.ArgumentParser(

@@ -38,23 +38,23 @@ def get_args():
     return parser.parse_args()


+def sigmoid(x):
+    return 1 / (1 + np.exp(-x))
+
+
 if __name__ == "__main__":
     args = get_args()

     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

-    net = UNet(n_channels=3, n_classes=1)
+    onnx_model = onnx.load(args.model)
+    onnx.checker.check_model(onnx_model)

-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    logging.info(f"Using device {device}")
+    ort_session = onnxruntime.InferenceSession(args.model)

-    logging.info("Transferring model to device")
-    net.to(device=device)
+    def to_numpy(tensor):
+        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

-    logging.info(f"Loading model {args.model}")
-    net.load_state_dict(torch.load(args.model, map_location=device))
-
     logging.info(f"Loading image {args.input}")
     img = Image.open(args.input).convert("RGB")

     logging.info(f"Preprocessing image {args.input}")

@@ -68,17 +68,14 @@ if __name__ == "__main__":
     img = aug["image"]

     logging.info(f"Predicting image {args.input}")
-    img = img.unsqueeze(0).to(device=device, dtype=torch.float32)
+    img = img.unsqueeze(0)

-    net.eval()
-    with torch.inference_mode():
-        mask = net(img)
-        mask = torch.sigmoid(mask)[0]
-        mask = mask.cpu()
-        mask = mask.squeeze()
-        mask = mask > 0.5
-        mask = np.asarray(mask)
+    # compute ONNX Runtime output prediction
+    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
+    ort_outs = ort_session.run(None, ort_inputs)

     logging.info(f"Saving prediction to {args.output}")
-    mask = Image.fromarray(mask)
-    mask.save(args.output)
+    img_out_y = ort_outs[0]
+
+    img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode="L")
+
+    img_out_y.save(args.output)
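The last hunk picks up at img = aug["image"], so the preprocessing that builds aug sits between the two hunks and is not shown here. Judging from the imports and from the transform pipeline in train.py, it plausibly reduces to the following sketch (the exact transform list is an assumption):

import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2
from PIL import Image

img = Image.open("images/test.png").convert("RGB")

# assumed shape of the elided step: scale to [0, 1], then HWC -> CHW tensor
tf = A.Compose([A.ToFloat(max_value=255), ToTensorV2()])
aug = tf(image=np.asarray(img))
img = aug["image"]  # ready for img.unsqueeze(0) above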
308  src/train.py

@@ -1,20 +1,33 @@
 import logging

-import albumentations as A
+import pytorch_lightning as pl
 import torch
-import yaml
-from albumentations.pytorch import ToTensorV2
-from torch.utils.data import DataLoader
-from tqdm import tqdm
+from pytorch_lightning.callbacks import RichProgressBar
+from pytorch_lightning.loggers import WandbLogger

 import wandb
-from src.utils.dataset import SphereDataset
 from unet import UNet
-from utils.dice import dice_coeff
-from utils.paste import RandomPaste

-class_labels = {
-    1: "sphere",
-}
+CONFIG = {
+    "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/",
+    "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/",
+    "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/",
+    "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/realspheres/",
+    "FEATURES": [8, 16, 32, 64],
+    "N_CHANNELS": 3,
+    "N_CLASSES": 1,
+    "AMP": True,
+    "PIN_MEMORY": True,
+    "BENCHMARK": True,
+    "DEVICE": "gpu",
+    "WORKERS": 10,
+    "EPOCHS": 1,
+    "BATCH_SIZE": 32,
+    "LEARNING_RATE": 1e-4,
+    "WEIGHT_DECAY": 1e-8,
+    "MOMENTUM": 0.9,
+    "IMG_SIZE": 512,
+    "SPHERES": 5,
+}

 if __name__ == "__main__":
@@ -22,266 +35,49 @@ if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

     # setup wandb
-    wandb.init(
+    logger = WandbLogger(
         project="U-Net",
-        config=dict(
-            DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
-            DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/",
-            DIR_TEST_IMG="/home/lilian/data_disk/lfainsin/test/",
-            DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
-            DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
-            FEATURES=[16, 32, 64, 128],
-            N_CHANNELS=3,
-            N_CLASSES=1,
-            AMP=True,
-            PIN_MEMORY=True,
-            BENCHMARK=True,
-            DEVICE="cuda",
-            WORKERS=7,
-            EPOCHS=5,
-            BATCH_SIZE=16,
-            LEARNING_RATE=1e-4,
-            WEIGHT_DECAY=1e-8,
-            MOMENTUM=0.9,
-            IMG_SIZE=512,
-            SPHERES=5,
-        ),
+        config=CONFIG,
         settings=wandb.Settings(
             code_dir="./src/",
         ),
     )

-    # create device
-    device = torch.device(wandb.config.DEVICE)
+    # seed random generators
+    pl.seed_everything(69420, workers=True)

-    # enable cudnn benchmarking
-    torch.backends.cudnn.benchmark = wandb.config.BENCHMARK
-
-    # 0. Create network
-    net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
-    wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
-
-    # transfer network to device
-    net.to(device=device)
-
-    # 1. Create transforms
-    tf_train = A.Compose(
-        [
-            A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
-            A.Flip(),
-            A.ColorJitter(),
-            RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
-            A.GaussianBlur(),
-            A.ISONoise(),
-            A.ToFloat(max_value=255),
-            ToTensorV2(),
-        ],
-    )
-    tf_valid = A.Compose(
-        [
-            A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
-            RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
-            A.ToFloat(max_value=255),
-            ToTensorV2(),
-        ],
-    )
+    # Create network
+    net = UNet(
+        n_channels=CONFIG["N_CHANNELS"],
+        n_classes=CONFIG["N_CLASSES"],
+        batch_size=CONFIG["BATCH_SIZE"],
+        learning_rate=CONFIG["LEARNING_RATE"],
+        features=CONFIG["FEATURES"],
+    )

-    # 2. Create datasets
-    ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
-    ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
+    # log gradients and weights regularly
+    logger.watch(net, log="all")

-    # 2.5. Create subset, if uncommented
-    ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
-
-    # 3. Create data loaders
-    train_loader = DataLoader(
-        ds_train,
-        shuffle=True,
-        batch_size=wandb.config.BATCH_SIZE,
-        num_workers=wandb.config.WORKERS,
-        pin_memory=wandb.config.PIN_MEMORY,
-    )
-    val_loader = DataLoader(
-        ds_valid,
-        shuffle=False,
-        drop_last=True,
-        batch_size=1,
-        num_workers=wandb.config.WORKERS,
-        pin_memory=wandb.config.PIN_MEMORY,
-    )
-
-    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
-    optimizer = torch.optim.RMSprop(
-        net.parameters(),
-        lr=wandb.config.LEARNING_RATE,
-        weight_decay=wandb.config.WEIGHT_DECAY,
-        momentum=wandb.config.MOMENTUM,
-    )
-    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
-    grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
-    criterion = torch.nn.BCEWithLogitsLoss()
-
-    # save model.onnx
-    dummy_input = torch.randn(
-        1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
-    ).to(device)
-    torch.onnx.export(net, dummy_input, "checkpoints/model.onnx")
-    artifact = wandb.Artifact("onnx", type="model")
-    artifact.add_file("checkpoints/model.onnx")
-    wandb.run.log_artifact(artifact)
-
-    # log gradients and weights four times per epoch
-    wandb.watch(net, log_freq=100)
-
-    # print the config
-    logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
-
-    # wandb init log
-    wandb.log(
-        {
-            "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
-        },
-        commit=False,
-    )
+    # Create the trainer
+    trainer = pl.Trainer(
+        max_epochs=CONFIG["EPOCHS"],
+        accelerator=CONFIG["DEVICE"],
+        # precision=16,
+        # auto_scale_batch_size="binsearch",
+        # auto_lr_find=True,
+        benchmark=CONFIG["BENCHMARK"],
+        val_check_interval=100,
+        callbacks=RichProgressBar(),
+        logger=logger,
+        log_every_n_steps=1,
+    )

     try:
-        global_step = 0
-        for epoch in range(1, wandb.config.EPOCHS + 1):
-            with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
-
-                # Training round
-                for step, (images, true_masks) in enumerate(train_loader):
-                    assert images.shape[1] == net.n_channels, (
-                        f"Network has been defined with {net.n_channels} input channels, "
-                        f"but loaded images have {images.shape[1]} channels. Please check that "
-                        "the images are loaded correctly."
-                    )
-
-                    # transfer images to device
-                    images = images.to(device=device)
-                    true_masks = true_masks.unsqueeze(1).to(device=device)
-
-                    # forward
-                    with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
-                        pred_masks = net(images)
-                        train_loss = criterion(pred_masks, true_masks)
-
-                    # backward
-                    optimizer.zero_grad(set_to_none=True)
-                    grad_scaler.scale(train_loss).backward()
-                    grad_scaler.step(optimizer)
-                    grad_scaler.update()
-
-                    # compute metrics
-                    pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
-                    accuracy = (true_masks == pred_masks_bin).float().mean()
-                    dice = dice_coeff(pred_masks_bin, true_masks)
-                    mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
-
-                    # update tqdm progress bar
-                    pbar.update(images.shape[0])
-                    pbar.set_postfix(**{"loss": train_loss.item()})
-
-                    # log metrics
-                    wandb.log(
-                        {
-                            "train/accuracy": accuracy,
-                            "train/bce": train_loss,
-                            "train/dice": dice,
-                            "train/mae": mae,
-                        }
-                    )
-
-                    global_step += 1
-
-                    if global_step % 100 == 0:
-                        # Evaluation round
-                        net.eval()
-                        accuracy = 0
-                        val_loss = 0
-                        dice = 0
-                        mae = 0
-                        with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar2:
-                            for images, masks_true in val_loader:
-
-                                # transfer images to device
-                                images = images.to(device=device)
-                                masks_true = masks_true.unsqueeze(1).to(device=device)
-
-                                # forward
-                                with torch.inference_mode():
-                                    masks_pred = net(images)
-
-                                # compute metrics
-                                val_loss += criterion(masks_pred, masks_true)
-                                masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
-                                mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
-                                accuracy += (masks_true == masks_pred_bin).float().mean()
-                                dice += dice_coeff(masks_pred_bin, masks_true)
-
-                                # update progress bars
-                                pbar2.update(images.shape[0])
-                                pbar.refresh()
-
-                        accuracy /= len(val_loader)
-                        val_loss /= len(val_loader)
-                        dice /= len(val_loader)
-                        mae /= len(val_loader)
-
-                        # save the last validation batch to table
-                        table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
-                        for i, (img, mask, pred, pred_bin) in enumerate(
-                            zip(
-                                images.cpu(),
-                                masks_true.cpu(),
-                                masks_pred.cpu(),
-                                masks_pred_bin.cpu().squeeze(1).int().numpy(),
-                            )
-                        ):
-                            table.add_data(
-                                i,
-                                wandb.Image(img),
-                                wandb.Image(mask),
-                                wandb.Image(
-                                    pred,
-                                    masks={
-                                        "predictions": {
-                                            "mask_data": pred_bin,
-                                            "class_labels": class_labels,
-                                        },
-                                    },
-                                ),
-                            )
-
-                        # log validation metrics
-                        wandb.log(
-                            {
-                                "val/predictions": table,
-                                "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
-                                "val/accuracy": accuracy,
-                                "val/bce": val_loss,
-                                "val/dice": dice,
-                                "val/mae": mae,
-                            },
-                            commit=False,
-                        )
-
-                        # update hyperparameters
-                        net.train()
-                        scheduler.step(dice)
-
-                        # export model to onnx format when validation ends
-                        dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
-                        torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}-{step}.onnx")
-                        artifact = wandb.Artifact("onnx", type="model")
-                        artifact.add_file(f"checkpoints/model-{epoch}-{step}.onnx")
-                        wandb.run.log_artifact(artifact)
-
-        # stop wandb
-        wandb.run.finish()
-
+        trainer.tune(net)
+        trainer.fit(model=net)
     except KeyboardInterrupt:
         torch.save(net.state_dict(), "INTERRUPTED.pth")
         raise

-    # christmas tree
     # stop wandb
     wandb.run.finish()
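As written, trainer.tune(net) is effectively a no-op because the auto-tuning flags are commented out; it only does something when auto_scale_batch_size or auto_lr_find is enabled (which is also why the UNet module below guards its log_table calls with a try/except: during tuning the logger is replaced by a dummy). A minimal sketch of the tuned variant, assuming pytorch-lightning 1.6 as pinned in pyproject.toml and an active wandb run, since the module's dataloaders read wandb.config:

import pytorch_lightning as pl

from unet import UNet

net = UNet(n_channels=3, n_classes=1, batch_size=32, learning_rate=1e-4, features=[8, 16, 32, 64])

trainer = pl.Trainer(
    max_epochs=1,
    accelerator="gpu",  # assumes a CUDA device, as in CONFIG
    auto_scale_batch_size="binsearch",  # binary-search the largest batch size that fits in memory
    auto_lr_find=True,  # LR range test; writes the result back to net.learning_rate
)
trainer.tune(net)  # reads and updates net.batch_size / net.learning_rate
trainer.fit(model=net)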
@@ -1,4 +1,4 @@
-""" Parts of the U-Net model """
+"""Parts of the U-Net model."""

 import torch
 import torch.nn as nn

@@ -59,8 +59,14 @@ class Up(nn.Module):
         # input is CHW
         diffY = x2.size()[2] - x1.size()[2]
         diffX = x2.size()[3] - x1.size()[3]
+        diffY2 = torch.div(diffY, 2, rounding_mode="trunc")
+        diffX2 = torch.div(diffX, 2, rounding_mode="trunc")

-        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
+        x1 = F.pad(
+            input=x1,
+            pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2],
+        )

         x = torch.cat([x2, x1], dim=1)

         return self.conv(x)
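Behaviorally the two versions pad identically; torch.div(..., rounding_mode="trunc") presumably replaces // because floor division on traced values is deprecated and complicates ONNX export. For concreteness, if the skip connection x2 is one pixel larger than the upsampled x1, the pad splits as 0 on one side and 1 on the other; a minimal sketch of the arithmetic, mirroring the new code:

import torch
import torch.nn.functional as F

x1 = torch.zeros(1, 8, 6, 6)  # upsampled decoder features
x2 = torch.zeros(1, 8, 7, 7)  # encoder skip connection

diffY = x2.size()[2] - x1.size()[2]  # 1
diffX = x2.size()[3] - x1.size()[3]  # 1
diffY2 = torch.div(diffY, 2, rounding_mode="trunc")  # 0
diffX2 = torch.div(diffX, 2, rounding_mode="trunc")  # 0

# pad = [left, right, top, bottom] = [0, 1, 0, 1]
x1 = F.pad(input=x1, pad=[diffX2, diffX - diffX2, diffY2, diffY - diffY2])
print(x1.shape)  # torch.Size([1, 8, 7, 7])
print(torch.cat([x2, x1], dim=1).shape)  # torch.Size([1, 16, 7, 7])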
@@ -1,14 +1,38 @@
 """ Full assembly of the parts to form the complete network """

+import itertools
+
+import albumentations as A
+import pytorch_lightning as pl
+from albumentations.pytorch import ToTensorV2
+from torch.utils.data import DataLoader
+
+import wandb
+from src.utils.dataset import SphereDataset
+from utils.dice import dice_loss
+from utils.paste import RandomPaste
+
 from .blocks import *

+class_labels = {
+    1: "sphere",
+}
+

-class UNet(nn.Module):
-    def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
+class UNet(pl.LightningModule):
+    def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
         super(UNet, self).__init__()

+        # Hyperparameters
         self.n_channels = n_channels
         self.n_classes = n_classes
+        self.learning_rate = learning_rate
+        self.batch_size = batch_size
+
+        # log hyperparameters
+        self.save_hyperparameters()

         # Network
         self.inc = DoubleConv(n_channels, features[0])

         self.downs = nn.ModuleList()

@@ -28,6 +52,8 @@ class UNet(nn.Module):
     def forward(self, x):
         skips = []

+        x = x.to(self.device)
+
         x = self.inc(x)

         for down in self.downs:

@@ -40,3 +66,224 @@ class UNet(nn.Module):
         x = self.outc(x)

         return x
+
+    def train_dataloader(self):
+        tf_train = A.Compose(
+            [
+                A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
+                A.Flip(),
+                A.ColorJitter(),
+                RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE),
+                A.GaussianBlur(),
+                A.ISONoise(),
+                A.ToFloat(max_value=255),
+                ToTensorV2(),
+            ],
+        )
+
+        ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
+        # ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 10000)))
+
+        return DataLoader(
+            ds_train,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=wandb.config.WORKERS,
+            pin_memory=wandb.config.PIN_MEMORY,
+        )
+
+    def val_dataloader(self):
+        ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG)
+
+        return DataLoader(
+            ds_valid,
+            shuffle=False,
+            batch_size=1,
+            num_workers=wandb.config.WORKERS,
+            pin_memory=wandb.config.PIN_MEMORY,
+        )
+
+    def training_step(self, batch, batch_idx):
+        # unpacking
+        images, masks_true = batch
+        masks_true = masks_true.unsqueeze(1)
+
+        # forward pass
+        masks_pred = self(images)
+
+        # compute metrics
+        bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+        dice = dice_loss(masks_pred, masks_true)
+
+        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+        dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
+        mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
+        accuracy = (masks_true == masks_pred_bin).float().mean()
+
+        self.log_dict(
+            {
+                "train/accuracy": accuracy,
+                "train/dice": dice,
+                "train/dice_bin": dice_bin,
+                "train/bce": bce,
+                "train/mae": mae,
+            },
+        )
+
+        if batch_idx == 22000:
+            rows = []
+            columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
+            for i, (img, mask, pred, pred_bin) in enumerate(
+                zip(
+                    images.cpu(),
+                    masks_true.cpu(),
+                    masks_pred.cpu(),
+                    masks_pred_bin.cpu().squeeze(1).int().numpy(),
+                )
+            ):
+                rows.append(
+                    [
+                        i,
+                        wandb.Image(img),
+                        wandb.Image(mask),
+                        wandb.Image(
+                            pred,
+                            masks={
+                                "predictions": {
+                                    "mask_data": pred_bin,
+                                    "class_labels": class_labels,
+                                },
+                            },
+                        ),
+                        dice,
+                        dice_bin,
+                    ]
+                )
+
+            # logging
+            try:  # required by autofinding, logger replaced by dummy
+                self.logger.log_table(
+                    key="train/predictions",
+                    columns=columns,
+                    data=rows,
+                )
+            except:
+                pass
+
+        return dict(
+            accuracy=accuracy,
+            loss=dice,
+            bce=bce,
+            mae=mae,
+        )
+
+    def validation_step(self, batch, batch_idx):
+        # unpacking
+        images, masks_true = batch
+        masks_true = masks_true.unsqueeze(1)
+
+        # forward pass
+        masks_pred = self(images)
+
+        # compute metrics
+        bce = F.binary_cross_entropy_with_logits(masks_pred, masks_true)
+        dice = dice_loss(masks_pred, masks_true)
+
+        masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
+        dice_bin = dice_loss(masks_pred_bin, masks_true, logits=False)
+        mae = torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
+        accuracy = (masks_true == masks_pred_bin).float().mean()
+
+        rows = []
+        if batch_idx % 50 == 0 or dice < 0.1:
+            for i, (img, mask, pred, pred_bin) in enumerate(
+                zip(
+                    images.cpu(),
+                    masks_true.cpu(),
+                    masks_pred.cpu(),
+                    masks_pred_bin.cpu().squeeze(1).int().numpy(),
+                )
+            ):
+                rows.append(
+                    [
+                        i,
+                        wandb.Image(img),
+                        wandb.Image(mask),
+                        wandb.Image(
+                            pred,
+                            masks={
+                                "predictions": {
+                                    "mask_data": pred_bin,
+                                    "class_labels": class_labels,
+                                },
+                            },
+                        ),
+                        dice,
+                        dice_bin,
+                    ]
+                )
+
+        return dict(
+            accuracy=accuracy,
+            loss=dice,
+            dice_bin=dice_bin,
+            bce=bce,
+            mae=mae,
+            table_rows=rows,
+        )
+
+    def validation_epoch_end(self, validation_outputs):
+        # metrics unpacking
+        accuracy = torch.stack([d["accuracy"] for d in validation_outputs]).mean()
+        dice_bin = torch.stack([d["dice_bin"] for d in validation_outputs]).mean()
+        loss = torch.stack([d["loss"] for d in validation_outputs]).mean()
+        bce = torch.stack([d["bce"] for d in validation_outputs]).mean()
+        mae = torch.stack([d["mae"] for d in validation_outputs]).mean()
+
+        # table unpacking
+        columns = ["ID", "image", "ground truth", "prediction", "dice", "dice_bin"]
+        rowss = [d["table_rows"] for d in validation_outputs]
+        rows = list(itertools.chain.from_iterable(rowss))
+
+        # logging
+        try:  # required by autofinding, logger replaced by dummy
+            self.logger.log_table(
+                key="val/predictions",
+                columns=columns,
+                data=rows,
+            )
+        except:
+            pass
+
+        self.log_dict(
+            {
+                "val/accuracy": accuracy,
+                "val/dice": loss,
+                "val/dice_bin": dice_bin,
+                "val/bce": bce,
+                "val/mae": mae,
+            }
+        )
+
+        # export model to pth
+        torch.save(self.state_dict(), f"checkpoints/model.pth")
+        artifact = wandb.Artifact("pth", type="model")
+        artifact.add_file("checkpoints/model.pth")
+        wandb.run.log_artifact(artifact)
+
+        # export model to onnx
+        dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True)
+        torch.onnx.export(self, dummy_input, f"checkpoints/model.onnx")
+        artifact = wandb.Artifact("onnx", type="model")
+        artifact.add_file("checkpoints/model.onnx")
+        wandb.run.log_artifact(artifact)
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.RMSprop(
+            self.parameters(),
+            lr=self.learning_rate,
+            weight_decay=wandb.config.WEIGHT_DECAY,
+            momentum=wandb.config.MOMENTUM,
+        )
+
+        return optimizer
@@ -1,80 +1,18 @@
 import torch
-from torch import Tensor


-def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
-    """Average of Dice coefficient for all batches, or for a single mask.
-
-    Args:
-        input (Tensor): _description_
-        target (Tensor): _description_
-        reduce_batch_first (bool, optional): _description_. Defaults to False.
-        epsilon (_type_, optional): _description_. Defaults to 1e-6.
-
-    Raises:
-        ValueError: _description_
-
-    Returns:
-        float: _description_
-    """
-    assert input.size() == target.size()
-
-    if input.dim() == 2 and reduce_batch_first:
-        raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
-
-    if input.dim() == 2 or reduce_batch_first:
-        inter = torch.dot(input.reshape(-1), target.reshape(-1))
-        sets_sum = torch.sum(input) + torch.sum(target)
-
-        if sets_sum.item() == 0:
-            sets_sum = 2 * inter
-
-        return (2 * inter + epsilon) / (sets_sum + epsilon)
-    else:
-        # compute and average metric for each batch element
-        dice = 0
-
-        for i in range(input.shape[0]):
-            dice += dice_coeff(input[i, ...], target[i, ...])
-
-        return dice / input.shape[0]
-
-
-def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6) -> float:
-    """Average of Dice coefficient for all classes.
-
-    Args:
-        input (Tensor): _description_
-        target (Tensor): _description_
-        reduce_batch_first (bool, optional): _description_. Defaults to False.
-        epsilon (_type_, optional): _description_. Defaults to 1e-6.
-
-    Returns:
-        float: _description_
-    """
-    assert input.size() == target.size()
-
-    dice = 0
-
-    for channel in range(input.shape[1]):
-        dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
-
-    return dice / input.shape[1]
-
-
-def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False) -> float:
-    """Dice loss (objective to minimize) between 0 and 1.
-
-    Args:
-        input (Tensor): _description_
-        target (Tensor): _description_
-        multiclass (bool, optional): _description_. Defaults to False.
-
-    Returns:
-        float: _description_
-    """
-    assert input.size() == target.size()
-
-    fn = multiclass_dice_coeff if multiclass else dice_coeff
-
-    return 1 - fn(input, target, reduce_batch_first=True)
+def dice_score(inputs, targets, smooth=1, logits=True):
+    # comment out if your model contains a sigmoid or equivalent activation layer
+    if logits:
+        inputs = torch.sigmoid(inputs)
+
+    # flatten label and prediction tensors
+    inputs = inputs.view(-1)
+    targets = targets.view(-1)
+
+    intersection = (inputs * targets).sum()
+    return (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
+
+
+def dice_loss(inputs, targets, smooth=1, logits=True):
+    return 1 - dice_score(inputs, targets, smooth, logits)
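A quick sanity check of the rewritten loss: with logits=True the inputs go through a sigmoid first, a perfect prediction drives dice_loss toward 0, and the smooth term keeps even a total miss strictly below 1. A minimal sketch (import path as used by the model code above):

import torch

from utils.dice import dice_loss

target = torch.tensor([1.0, 1.0, 0.0, 0.0])
good_logits = torch.tensor([10.0, 10.0, -10.0, -10.0])  # sigmoid(+-10) is ~1 / ~0

print(dice_loss(good_logits, target).item())   # ~0.0
print(dice_loss(-good_logits, target).item())  # ~0.8: smoothing keeps a total miss below 1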
@@ -1,9 +1,10 @@
 import os
 import random as rd
+from pathlib import Path

 import albumentations as A
 import numpy as np
-from PIL import Image, ImageEnhance
+import torchvision.transforms as T
+from PIL import Image


 class RandomPaste(A.DualTransform):

@@ -22,15 +23,15 @@ class RandomPaste(A.DualTransform):
     def __init__(
         self,
         nb,
-        path_paste_img_dir,
-        path_paste_mask_dir,
-        scale_range=(0.1, 0.2),
+        image_dir,
+        scale_range=(0.05, 0.25),
         always_apply=True,
         p=1.0,
     ):
         super().__init__(always_apply, p)
-        self.path_paste_img_dir = path_paste_img_dir
-        self.path_paste_mask_dir = path_paste_mask_dir
+        self.images = []
+        self.images.extend(list(Path(image_dir).glob("**/*.jpg")))
+        self.images.extend(list(Path(image_dir).glob("**/*.png")))
         self.scale_range = scale_range
         self.nb = nb

@@ -38,105 +39,129 @@ class RandomPaste(A.DualTransform):
     def targets_as_params(self):
         return ["image"]

-    def apply(self, img, positions, paste_img, paste_mask, **params):
+    def apply(self, img, augmentations, paste_img, paste_mask, **params):
         # convert img to Image, needed for `paste` function
         img = Image.fromarray(img)

+        # copy paste_img and paste_mask
+        paste_mask = paste_mask.copy()
+        paste_img = paste_img.copy()
+
         # paste spheres
-        for pos in positions:
-            img.paste(paste_img, pos, paste_mask)
+        for (x, y, shearx, sheary, shape, angle, brightness, contrast) in augmentations:
+            paste_img = T.functional.adjust_contrast(
+                paste_img,
+                contrast_factor=contrast,
+            )
+            paste_img = T.functional.adjust_brightness(
+                paste_img,
+                brightness_factor=brightness,
+            )
+            paste_img = T.functional.affine(
+                paste_img,
+                scale=0.95,
+                angle=angle,
+                translate=(0, 0),
+                shear=(shearx, sheary),
+                interpolation=T.InterpolationMode.BICUBIC,
+            )
+            paste_img = T.functional.resize(
+                paste_img,
+                size=shape,
+                interpolation=T.InterpolationMode.BICUBIC,
+            )
+
+            paste_mask = T.functional.affine(
+                paste_mask,
+                scale=0.95,
+                angle=angle,
+                translate=(0, 0),
+                shear=(shearx, sheary),
+                interpolation=T.InterpolationMode.BICUBIC,
+            )
+            paste_mask = T.functional.resize(
+                paste_mask,
+                size=shape,
+                interpolation=T.InterpolationMode.BICUBIC,
+            )
+
+            img.paste(paste_img, (x, y), paste_mask)

         return np.asarray(img.convert("RGB"))

-    def apply_to_mask(self, mask, positions, paste_mask, **params):
+    def apply_to_mask(self, mask, augmentations, paste_mask, **params):
         # convert mask to Image, needed for `paste` function
         mask = Image.fromarray(mask)

-        # binarize the mask -> {0, 1}
-        paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0)
+        # copy paste_mask
+        paste_mask = paste_mask.copy()

         # paste spheres
-        for pos in positions:
-            mask.paste(paste_mask, pos, paste_mask_bin)
+        for (x, y, shearx, sheary, shape, angle, _, _) in augmentations:
+            paste_mask = T.functional.affine(
+                paste_mask,
+                scale=0.95,
+                angle=angle,
+                translate=(0, 0),
+                shear=(shearx, sheary),
+                interpolation=T.InterpolationMode.BICUBIC,
+            )
+            paste_mask = T.functional.resize(
+                paste_mask,
+                size=shape,
+                interpolation=T.InterpolationMode.BICUBIC,
+            )
+
+            # binarize the mask -> {0, 1}
+            paste_mask_bin = paste_mask.point(lambda p: 1 if p > 10 else 0)
+
+            mask.paste(paste_mask, (x, y), paste_mask_bin)

         return np.asarray(mask.convert("L"))

-    @staticmethod
-    def overlap(positions, x1, y1, w, h):
-        for x2, y2 in positions:
-            if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h:
-                return True
-        return False
-
     def get_params_dependent_on_targets(self, params):
-        # choose a random image inside the image folder
-        filename = rd.choice(os.listdir(self.path_paste_img_dir))
+        # choose a random image and its corresponding mask
+        img_path = rd.choice(self.images)
+        mask_path = img_path.parent.joinpath("MASK.PNG")

-        # load the "paste" image
-        paste_img = Image.open(
-            os.path.join(
-                self.path_paste_img_dir,
-                filename,
-            )
-        ).convert("RGBA")
-
-        # load its respective mask
-        paste_mask = Image.open(
-            os.path.join(
-                self.path_paste_mask_dir,
-                filename,
-            )
-        ).convert("LA")
-
-        # load the target image
+        # load images (w/ transparency)
+        paste_img = Image.open(img_path).convert("RGBA")
+        paste_mask = Image.open(mask_path).convert("LA")
         target_img = params["image"]

         # compute shapes
         target_shape = np.array(target_img.shape[:2], dtype=np.uint)
         paste_shape = np.array(paste_img.size, dtype=np.uint)

-        # compute the minimum scaling to fit inside target image
+        # compute minimum scaling to fit inside target
         min_scale = np.min(target_shape / paste_shape)

-        # randomize the relative scaling
-        scale = rd.uniform(*self.scale_range)
-
-        # rotate the image and its mask
-        angle = rd.uniform(0, 360)
-        paste_img = paste_img.rotate(angle, expand=True)
-        paste_mask = paste_mask.rotate(angle, expand=True)
-
-        # scale the "paste" image and its mask
-        paste_img = paste_img.resize(
-            tuple((paste_shape * min_scale * scale).astype(np.uint)),
-            resample=Image.Resampling.LANCZOS,
-        )
-        paste_mask = paste_mask.resize(
-            tuple((paste_shape * min_scale * scale).astype(np.uint)),
-            resample=Image.Resampling.LANCZOS,
-        )
-
-        # update paste_shape after scaling
-        paste_shape = np.array(paste_img.size, dtype=np.uint)
-
-        # change brightness randomly
-        filter = ImageEnhance.Brightness(paste_img)
-        paste_img = filter.enhance(rd.uniform(0.5, 1.5))
-
-        # generate some positions
-        positions = []
+        # generate augmentations
+        augmentations = []
         NB = rd.randint(1, self.nb)
-        while len(positions) < NB:
-            x = rd.randint(0, target_shape[0] - paste_shape[0])
-            y = rd.randint(0, target_shape[1] - paste_shape[1])
+        while len(augmentations) < NB:  # TODO: add a max-iteration stopping condition
+            scale = rd.uniform(*self.scale_range) * min_scale
+            shape = np.array(paste_shape * scale, dtype=np.uint)
+
+            x = rd.randint(0, target_shape[0] - shape[0])
+            y = rd.randint(0, target_shape[1] - shape[1])

             # check for overlapping
-            if RandomPaste.overlap(positions, x, y, paste_shape[0], paste_shape[1]):
+            if RandomPaste.overlap(augmentations, x, y, shape[0], shape[1]):
                 continue

-            positions.append((x, y))
+            shearx = rd.uniform(-2, 2)
+            sheary = rd.uniform(-2, 2)
+
+            angle = rd.uniform(0, 360)
+
+            brightness = rd.uniform(0.8, 1.2)
+            contrast = rd.uniform(0.8, 1.2)
+
+            augmentations.append((x, y, shearx, sheary, tuple(shape), angle, brightness, contrast))

         params.update(
             {
-                "positions": positions,
+                "augmentations": augmentations,
                 "paste_img": paste_img,
                 "paste_mask": paste_mask,
             }

@@ -144,5 +169,9 @@ class RandomPaste(A.DualTransform):

         return params

-    def get_transform_init_args_names(self):
-        return "scale_range", "path_paste_img_dir", "path_paste_mask_dir"
+    @staticmethod
+    def overlap(positions, x1, y1, w, h):
+        for x2, y2, _, _, _, _, _, _ in positions:
+            if x1 + w >= x2 and x1 <= x2 + w and y1 + h >= y2 and y1 <= y2 + h:
+                return True
+        return False
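With the new signature, RandomPaste takes a single directory, globs it recursively for sphere images, and expects each image's mask to sit next to it as MASK.PNG. A minimal usage sketch in the style of the training pipeline (the directory layout is an assumption based on the mask_path lookup above):

import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2

from utils.paste import RandomPaste

tf = A.Compose(
    [
        A.Resize(512, 512),
        RandomPaste(5, "/home/lilian/data_disk/lfainsin/realspheres/"),  # paste up to 5 spheres
        A.ToFloat(max_value=255),
        ToTensorV2(),
    ]
)

image = np.zeros((1024, 1024, 3), dtype=np.uint8)  # stand-in background
mask = np.zeros((1024, 1024), dtype=np.uint8)
out = tf(image=image, mask=mask)
image, mask = out["image"], out["mask"]  # CHW tensor in [0, 1] and the pasted-sphere mask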