Former-commit-id: 4207d1193b6eadcb491a72f51f2c512150f774c3 [formerly cc23543f39cfd6e0894d12d18cd33a18d2b4a20f]
Former-commit-id: 23b4c151217a9db6b9e21105b1e0954ec2f78ce1
This commit is contained in:
Laurent Fainsin 2022-06-30 14:36:48 +02:00
parent 8c9ed80c6a
commit 24df16a612
2 changed files with 64 additions and 67 deletions

View file

@ -34,33 +34,28 @@ def get_args():
metavar="OUTPUT", metavar="OUTPUT",
help="Filenames of output images", help="Filenames of output images",
) )
parser.add_argument(
"--threshold",
"-t",
type=float,
default=0.5,
help="Minimum probability value to consider a mask pixel white",
)
return parser.parse_args() return parser.parse_args()
def predict_img(net, img, device, threshold): def predict_img(net, img, device):
img = img.unsqueeze(0) img = img.unsqueeze(0)
img = img.to(device=device, dtype=torch.float32) img = img.to(device=device, dtype=torch.float32)
net.eval() net.eval()
with torch.inference_mode(): with torch.inference_mode():
output = net(img) output = net(img)
preds = torch.sigmoid(output)[0] # preds = torch.sigmoid(output)[0]
full_mask = preds.cpu().squeeze() # full_mask = output.squeeze(0).cpu()
return np.asarray(full_mask > threshold) return np.asarray(output.squeeze().cpu())
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
net = UNet(n_channels=3, n_classes=1) net = UNet(n_channels=3, n_classes=1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -86,8 +81,8 @@ if __name__ == "__main__":
img = aug["image"] img = aug["image"]
logging.info(f"Predicting image {args.input}") logging.info(f"Predicting image {args.input}")
mask = predict_img(net=net, img=img, threshold=args.threshold, device=device) mask = predict_img(net=net, img=img, device=device)
logging.info(f"Saving prediction to {args.output}") logging.info(f"Saving prediction to {args.output}")
mask = Image.fromarray(mask) mask = Image.fromarray(mask, "L")
mask.write(args.output) mask.save(args.output)

View file

@ -18,40 +18,39 @@ def main():
# setup logging # setup logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
# enable cuda, if possible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# setup wandb # setup wandb
wandb.init( wandb.init(
project="U-Net", project="U-Net",
config=dict( config=dict(
n_channels=3,
n_classes=1,
epochs=5,
batch_size=70,
learning_rate=1e-5,
amp=True,
num_workers=8,
pin_memory=True,
features=[16, 32, 64, 128],
benchmark=False,
device=device.type,
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017", DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017",
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/", DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/", DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/", DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
FEATURES=[64, 128, 256, 512],
N_CHANNELS=3,
N_CLASSES=1,
AMP=True,
PIN_MEMORY=True,
BENCHMARK=False,
DEVICE="cuda",
WORKERS=8,
EPOCHS=5,
BATCH_SIZE=16,
LEARNING_RATE=1e-5,
IMG_SIZE=512,
SPHERES=5,
), ),
) )
# create device
device = torch.device(wandb.config.device)
# enable cudnn benchmarking # enable cudnn benchmarking
torch.backends.cudnn.benchmark = wandb.config.benchmark torch.backends.cudnn.benchmark = wandb.config.BENCHMARK
# 0. Create network # 0. Create network
net = UNet(n_channels=3, n_classes=wandb.config.n_classes, features=wandb.config.features) net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
wandb.config.params = sum(p.numel() for p in net.parameters() if p.requires_grad) wandb.config.parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
# save initial model.pth
torch.save(net.state_dict(), "model.pth")
# transfer network to device # transfer network to device
net.to(device=device) net.to(device=device)
@ -59,10 +58,10 @@ def main():
# 1. Create transforms # 1. Create transforms
tf_train = A.Compose( tf_train = A.Compose(
[ [
A.Resize(512, 512), A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
A.Flip(), A.Flip(),
A.ColorJitter(), A.ColorJitter(),
RandomPaste(5, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
A.GaussianBlur(), A.GaussianBlur(),
A.ISONoise(), A.ISONoise(),
A.ToFloat(max_value=255), A.ToFloat(max_value=255),
@ -71,8 +70,8 @@ def main():
) )
tf_valid = A.Compose( tf_valid = A.Compose(
[ [
A.Resize(512, 512), A.Resize(wandb.config.IMG_SIZE, wandb.config.IMG_SIZE),
RandomPaste(5, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), RandomPaste(wandb.config.SPHERES, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
A.ToFloat(max_value=255), A.ToFloat(max_value=255),
ToTensorV2(), ToTensorV2(),
], ],
@ -83,16 +82,26 @@ def main():
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid) ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
# 3. Create data loaders # 3. Create data loaders
loader_args = dict( train_loader = DataLoader(
batch_size=wandb.config.batch_size, num_workers=wandb.config.num_workers, pin_memory=wandb.config.pin_memory ds_train,
shuffle=True,
batch_size=wandb.config.BATCH_SIZE,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
val_loader = DataLoader(
ds_valid,
shuffle=False,
drop_last=True,
batch_size=wandb.config.BATCH_SIZE,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
) )
train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.learning_rate, weight_decay=1e-8, momentum=0.9) optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.amp) grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
criterion = torch.nn.BCEWithLogitsLoss() criterion = torch.nn.BCEWithLogitsLoss()
# save model.pth # save model.pth
@ -100,27 +109,31 @@ def main():
artifact = wandb.Artifact("pth", type="model") artifact = wandb.Artifact("pth", type="model")
artifact.add_file("model.pth") artifact.add_file("model.pth")
wandb.run.log_artifact(artifact) wandb.run.log_artifact(artifact)
logging.info("model.pth saved")
# save model.onxx # save model.onxx
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) dummy_input = torch.randn(
1, wandb.config.n_channels, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
).to(device)
torch.onnx.export(net, dummy_input, "model.onnx") torch.onnx.export(net, dummy_input, "model.onnx")
artifact = wandb.Artifact("onnx", type="model") artifact = wandb.Artifact("onnx", type="model")
artifact.add_file("model.onnx") artifact.add_file("model.onnx")
wandb.run.log_artifact(artifact) wandb.run.log_artifact(artifact)
logging.info("model.onnx saved")
# print the config # print the config
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") logging.info(
f"""wandb config:
{yaml.dump(wandb.config.as_dict())}
"""
)
try: try:
for epoch in range(1, wandb.config.epochs + 1): for epoch in range(1, wandb.config.EPOCHS + 1):
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.epochs}", unit="img") as pbar: with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
# Training round # Training round
for step, (images, true_masks) in enumerate(train_loader): for step, (images, true_masks) in enumerate(train_loader):
assert images.shape[1] == net.n_channels, ( assert images.shape[1] == net.N_CHANNELS, (
f"Network has been defined with {net.n_channels} input channels, " f"Network has been defined with {net.N_CHANNELS} input channels, "
f"but loaded images have {images.shape[1]} channels. Please check that " f"but loaded images have {images.shape[1]} channels. Please check that "
"the images are loaded correctly." "the images are loaded correctly."
) )
@ -130,7 +143,7 @@ def main():
true_masks = true_masks.unsqueeze(1).to(device=device) true_masks = true_masks.unsqueeze(1).to(device=device)
# forward # forward
with torch.cuda.amp.autocast(enabled=wandb.config.amp): with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
pred_masks = net(images) pred_masks = net(images)
train_loss = criterion(pred_masks, true_masks) train_loss = criterion(pred_masks, true_masks)
@ -163,36 +176,25 @@ def main():
} }
) )
logging.info(
f"""Validation ended:
Train Loss: {train_loss}
Valid Score: {val_score}
"""
)
# save weights when epoch end # save weights when epoch end
torch.save(net.state_dict(), "model.pth") torch.save(net.state_dict(), "model.pth")
artifact = wandb.Artifact("pth", type="model") artifact = wandb.Artifact("pth", type="model")
artifact.add_file("model.pth") artifact.add_file("model.pth")
wandb.run.log_artifact(artifact) wandb.run.log_artifact(artifact)
logging.info("model.pth saved")
# export model to onnx format # export model to onnx format
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
torch.onnx.export(net, dummy_input, "model.onnx") torch.onnx.export(net, dummy_input, "model.onnx")
artifact = wandb.Artifact("pnnx", type="model") artifact = wandb.Artifact("onnx", type="model")
artifact.add_file("model.onnx") artifact.add_file("model.onnx")
wandb.run.log_artifact(artifact) wandb.run.log_artifact(artifact)
logging.info("model.onnx saved")
wandb.run.finish() wandb.run.finish()
except KeyboardInterrupt: except KeyboardInterrupt:
torch.save(net.state_dict(), "INTERRUPTED.pth") torch.save(net.state_dict(), "INTERRUPTED.pth")
logging.info("Saved interrupt")
raise raise
if __name__ == "__main__": if __name__ == "__main__":
main() main() # TODO: fix toutes les metrics, loss, accuracy, dice...
# TODO: fix toutes les metrics, loss, accuracy, dice...