mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
f5
Former-commit-id: 4207d1193b6eadcb491a72f51f2c512150f774c3 [formerly cc23543f39cfd6e0894d12d18cd33a18d2b4a20f] Former-commit-id: 23b4c151217a9db6b9e21105b1e0954ec2f78ce1
This commit is contained in:
parent
8c9ed80c6a
commit
24df16a612
|
@ -34,33 +34,28 @@ def get_args():
|
||||||
metavar="OUTPUT",
|
metavar="OUTPUT",
|
||||||
help="Filenames of output images",
|
help="Filenames of output images",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--threshold",
|
|
||||||
"-t",
|
|
||||||
type=float,
|
|
||||||
default=0.5,
|
|
||||||
help="Minimum probability value to consider a mask pixel white",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def predict_img(net, img, device, threshold):
|
def predict_img(net, img, device):
|
||||||
img = img.unsqueeze(0)
|
img = img.unsqueeze(0)
|
||||||
img = img.to(device=device, dtype=torch.float32)
|
img = img.to(device=device, dtype=torch.float32)
|
||||||
|
|
||||||
net.eval()
|
net.eval()
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
output = net(img)
|
output = net(img)
|
||||||
preds = torch.sigmoid(output)[0]
|
# preds = torch.sigmoid(output)[0]
|
||||||
full_mask = preds.cpu().squeeze()
|
# full_mask = output.squeeze(0).cpu()
|
||||||
|
|
||||||
return np.asarray(full_mask > threshold)
|
return np.asarray(output.squeeze().cpu())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = get_args()
|
args = get_args()
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
|
|
||||||
net = UNet(n_channels=3, n_classes=1)
|
net = UNet(n_channels=3, n_classes=1)
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
@ -86,8 +81,8 @@ if __name__ == "__main__":
|
||||||
img = aug["image"]
|
img = aug["image"]
|
||||||
|
|
||||||
logging.info(f"Predicting image {args.input}")
|
logging.info(f"Predicting image {args.input}")
|
||||||
mask = predict_img(net=net, img=img, threshold=args.threshold, device=device)
|
mask = predict_img(net=net, img=img, device=device)
|
||||||
|
|
||||||
logging.info(f"Saving prediction to {args.output}")
|
logging.info(f"Saving prediction to {args.output}")
|
||||||
mask = Image.fromarray(mask)
|
mask = Image.fromarray(mask, "L")
|
||||||
mask.write(args.output)
|
mask.save(args.output)
|
||||||
|
|
108
src/train.py
108
src/train.py
|
@ -18,40 +18,39 @@ def main():
|
||||||
# setup logging
|
# setup logging
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
|
|
||||||
# enable cuda, if possible
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
|
|
||||||
# setup wandb
|
# setup wandb
|
||||||
wandb.init(
|
wandb.init(
|
||||||
project="U-Net",
|
project="U-Net",
|
||||||
config=dict(
|
config=dict(
|
||||||
n_channels=3,
|
|
||||||
n_classes=1,
|
|
||||||
epochs=5,
|
|
||||||
batch_size=70,
|
|
||||||
learning_rate=1e-5,
|
|
||||||
amp=True,
|
|
||||||
num_workers=8,
|
|
||||||
pin_memory=True,
|
|
||||||
features=[16, 32, 64, 128],
|
|
||||||
benchmark=False,
|
|
||||||
device=device.type,
|
|
||||||
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017",
|
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017",
|
||||||
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
|
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
|
||||||
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
||||||
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
||||||
|
FEATURES=[64, 128, 256, 512],
|
||||||
|
N_CHANNELS=3,
|
||||||
|
N_CLASSES=1,
|
||||||
|
AMP=True,
|
||||||
|
PIN_MEMORY=True,
|
||||||
|
BENCHMARK=False,
|
||||||
|
DEVICE="cuda",
|
||||||
|
WORKERS=8,
|
||||||
|
EPOCHS=5,
|
||||||
|
BATCH_SIZE=16,
|
||||||
|
LEARNING_RATE=1e-5,
|
||||||
|
IMG_SIZE=512,
|
||||||
|
SPHERES=5,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# create device
|
||||||
|
device = torch.device(wandb.config.device)
|
||||||
|
|
||||||
# enable cudnn benchmarking
|
# enable cudnn benchmarking
|
||||||
torch.backends.cudnn.benchmark = wandb.config.benchmark
|
torch.backends.cudnn.benchmark = wandb.config.BENCHMARK
|
||||||
|
|
||||||
# 0. Create network
|
# 0. Create network
|
||||||
net = UNet(n_channels=3, n_classes=wandb.config.n_classes, features=wandb.config.features)
|
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
|
||||||
wandb.config.params = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
wandb.config.parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
||||||
|
|
||||||
# save initial model.pth
|
|
||||||
torch.save(net.state_dict(), "model.pth")
|
|
||||||
|
|
||||||
# transfer network to device
|
# transfer network to device
|
||||||
net.to(device=device)
|
net.to(device=device)
|
||||||
|
@ -59,10 +58,10 @@ def main():
|
||||||
# 1. Create transforms
|
# 1. Create transforms
|
||||||
tf_train = A.Compose(
|
tf_train = A.Compose(
|
||||||
[
|
[
|
||||||
A.Resize(512, 512),
|
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
|
||||||
A.Flip(),
|
A.Flip(),
|
||||||
A.ColorJitter(),
|
A.ColorJitter(),
|
||||||
RandomPaste(5, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
|
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
|
||||||
A.GaussianBlur(),
|
A.GaussianBlur(),
|
||||||
A.ISONoise(),
|
A.ISONoise(),
|
||||||
A.ToFloat(max_value=255),
|
A.ToFloat(max_value=255),
|
||||||
|
@ -71,8 +70,8 @@ def main():
|
||||||
)
|
)
|
||||||
tf_valid = A.Compose(
|
tf_valid = A.Compose(
|
||||||
[
|
[
|
||||||
A.Resize(512, 512),
|
A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
|
||||||
RandomPaste(5, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
|
RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
|
||||||
A.ToFloat(max_value=255),
|
A.ToFloat(max_value=255),
|
||||||
ToTensorV2(),
|
ToTensorV2(),
|
||||||
],
|
],
|
||||||
|
@ -83,16 +82,26 @@ def main():
|
||||||
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
|
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
|
||||||
|
|
||||||
# 3. Create data loaders
|
# 3. Create data loaders
|
||||||
loader_args = dict(
|
train_loader = DataLoader(
|
||||||
batch_size=wandb.config.batch_size, num_workers=wandb.config.num_workers, pin_memory=wandb.config.pin_memory
|
ds_train,
|
||||||
|
shuffle=True,
|
||||||
|
batch_size=wandb.config.BATCH_SIZE,
|
||||||
|
num_workers=wandb.config.WORKERS,
|
||||||
|
pin_memory=wandb.config.PIN_MEMORY,
|
||||||
|
)
|
||||||
|
val_loader = DataLoader(
|
||||||
|
ds_valid,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=True,
|
||||||
|
batch_size=wandb.config.BATCH_SIZE,
|
||||||
|
num_workers=wandb.config.WORKERS,
|
||||||
|
pin_memory=wandb.config.PIN_MEMORY,
|
||||||
)
|
)
|
||||||
train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
|
|
||||||
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
|
|
||||||
|
|
||||||
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
|
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
|
||||||
optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.learning_rate, weight_decay=1e-8, momentum=0.9)
|
optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9)
|
||||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
|
||||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.amp)
|
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
|
||||||
criterion = torch.nn.BCEWithLogitsLoss()
|
criterion = torch.nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
# save model.pth
|
# save model.pth
|
||||||
|
@ -100,27 +109,31 @@ def main():
|
||||||
artifact = wandb.Artifact("pth", type="model")
|
artifact = wandb.Artifact("pth", type="model")
|
||||||
artifact.add_file("model.pth")
|
artifact.add_file("model.pth")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
logging.info("model.pth saved")
|
|
||||||
|
|
||||||
# save model.onxx
|
# save model.onxx
|
||||||
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
dummy_input = torch.randn(
|
||||||
|
1, wandb.config.n_channels, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
|
||||||
|
).to(device)
|
||||||
torch.onnx.export(net, dummy_input, "model.onnx")
|
torch.onnx.export(net, dummy_input, "model.onnx")
|
||||||
artifact = wandb.Artifact("onnx", type="model")
|
artifact = wandb.Artifact("onnx", type="model")
|
||||||
artifact.add_file("model.onnx")
|
artifact.add_file("model.onnx")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
logging.info("model.onnx saved")
|
|
||||||
|
|
||||||
# print the config
|
# print the config
|
||||||
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
|
logging.info(
|
||||||
|
f"""wandb config:
|
||||||
|
{yaml.dump(wandb.config.as_dict())}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for epoch in range(1, wandb.config.epochs + 1):
|
for epoch in range(1, wandb.config.EPOCHS + 1):
|
||||||
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.epochs}", unit="img") as pbar:
|
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
|
||||||
|
|
||||||
# Training round
|
# Training round
|
||||||
for step, (images, true_masks) in enumerate(train_loader):
|
for step, (images, true_masks) in enumerate(train_loader):
|
||||||
assert images.shape[1] == net.n_channels, (
|
assert images.shape[1] == net.N_CHANNELS, (
|
||||||
f"Network has been defined with {net.n_channels} input channels, "
|
f"Network has been defined with {net.N_CHANNELS} input channels, "
|
||||||
f"but loaded images have {images.shape[1]} channels. Please check that "
|
f"but loaded images have {images.shape[1]} channels. Please check that "
|
||||||
"the images are loaded correctly."
|
"the images are loaded correctly."
|
||||||
)
|
)
|
||||||
|
@ -130,7 +143,7 @@ def main():
|
||||||
true_masks = true_masks.unsqueeze(1).to(device=device)
|
true_masks = true_masks.unsqueeze(1).to(device=device)
|
||||||
|
|
||||||
# forward
|
# forward
|
||||||
with torch.cuda.amp.autocast(enabled=wandb.config.amp):
|
with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
|
||||||
pred_masks = net(images)
|
pred_masks = net(images)
|
||||||
train_loss = criterion(pred_masks, true_masks)
|
train_loss = criterion(pred_masks, true_masks)
|
||||||
|
|
||||||
|
@ -163,36 +176,25 @@ def main():
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
logging.info(
|
|
||||||
f"""Validation ended:
|
|
||||||
Train Loss: {train_loss}
|
|
||||||
Valid Score: {val_score}
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# save weights when epoch end
|
# save weights when epoch end
|
||||||
torch.save(net.state_dict(), "model.pth")
|
torch.save(net.state_dict(), "model.pth")
|
||||||
artifact = wandb.Artifact("pth", type="model")
|
artifact = wandb.Artifact("pth", type="model")
|
||||||
artifact.add_file("model.pth")
|
artifact.add_file("model.pth")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
logging.info("model.pth saved")
|
|
||||||
|
|
||||||
# export model to onnx format
|
# export model to onnx format
|
||||||
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
||||||
torch.onnx.export(net, dummy_input, "model.onnx")
|
torch.onnx.export(net, dummy_input, "model.onnx")
|
||||||
artifact = wandb.Artifact("pnnx", type="model")
|
artifact = wandb.Artifact("onnx", type="model")
|
||||||
artifact.add_file("model.onnx")
|
artifact.add_file("model.onnx")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
logging.info("model.onnx saved")
|
|
||||||
|
|
||||||
wandb.run.finish()
|
wandb.run.finish()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
torch.save(net.state_dict(), "INTERRUPTED.pth")
|
torch.save(net.state_dict(), "INTERRUPTED.pth")
|
||||||
logging.info("Saved interrupt")
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main() # TODO: fix toutes les metrics, loss, accuracy, dice...
|
||||||
# TODO: fix toutes les metrics, loss, accuracy, dice...
|
|
||||||
|
|
Loading…
Reference in a new issue