feat: got positive loss values !

Former-commit-id: 84e2a715b843ecee2e12e4878fcee4a52bb0a4cb [formerly 1a5fc82bc099885853b7b4deff81b779dafd0168]
Former-commit-id: c82cd66d6c432555a126e506631dfa2fd756437e
This commit is contained in:
Laurent Fainsin 2022-06-30 16:47:28 +02:00
parent 3e335fbcb5
commit d9f2dc2bfb
5 changed files with 39 additions and 32 deletions

1
.gitignore vendored
View file

@ -5,6 +5,7 @@ __pycache__/
wandb/ wandb/
images/ images/
checkpoints/
*.pth *.pth
*.onnx *.onnx

2
.vscode/launch.json vendored
View file

@ -19,4 +19,4 @@
"justMyCode": true "justMyCode": true
} }
] ]
} }

View file

@ -38,19 +38,6 @@ def get_args():
return parser.parse_args() return parser.parse_args()
def predict_img(net, img, device):
img = img.unsqueeze(0)
img = img.to(device=device, dtype=torch.float32)
net.eval()
with torch.inference_mode():
output = net(img)
# preds = torch.sigmoid(output)[0]
# full_mask = output.squeeze(0).cpu()
return np.asarray(output.squeeze().cpu())
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
@ -81,8 +68,17 @@ if __name__ == "__main__":
img = aug["image"] img = aug["image"]
logging.info(f"Predicting image {args.input}") logging.info(f"Predicting image {args.input}")
mask = predict_img(net=net, img=img, device=device) img = img.unsqueeze(0).to(device=device, dtype=torch.float32)
net.eval()
with torch.inference_mode():
mask = net(img)
mask = torch.sigmoid(mask)[0]
mask = mask.cpu()
mask = mask.squeeze()
mask = mask > 0.5
mask = np.asarray(mask)
logging.info(f"Saving prediction to {args.output}") logging.info(f"Saving prediction to {args.output}")
mask = Image.fromarray(mask, "L") mask = Image.fromarray(mask)
mask.save(args.output) mask.save(args.output)

View file

@ -26,7 +26,7 @@ def main():
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/", DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/", DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/", DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
FEATURES=[64, 128, 256, 512], FEATURES=[16, 32, 64, 128],
N_CHANNELS=3, N_CHANNELS=3,
N_CLASSES=1, N_CLASSES=1,
AMP=True, AMP=True,
@ -35,8 +35,8 @@ def main():
DEVICE="cuda", DEVICE="cuda",
WORKERS=8, WORKERS=8,
EPOCHS=5, EPOCHS=5,
BATCH_SIZE=16, BATCH_SIZE=64,
LEARNING_RATE=1e-5, LEARNING_RATE=1e-4,
IMG_SIZE=512, IMG_SIZE=512,
SPHERES=5, SPHERES=5,
), ),
@ -50,7 +50,8 @@ def main():
# 0. Create network # 0. Create network
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES) net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
wandb.config.parameters = sum(p.numel() for p in net.parameters() if p.requires_grad) wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
wandb.watch(net, log_freq=100)
# transfer network to device # transfer network to device
net.to(device=device) net.to(device=device)
@ -80,6 +81,11 @@ def main():
# 2. Create datasets # 2. Create datasets
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid) ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
# ds_train_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/train/", transform=tf_train)
# ds_valid_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/testval/", transform=tf_valid)
# ds_train = torch.utils.data.ChainDataset([ds_train_coco, ds_train_bg20k])
# ds_valid = torch.utils.data.ChainDataset([ds_valid_coco, ds_valid_bg20k]) # TODO: modifier la classe SphereDataset pour prendre plusieurs dossiers
# 3. Create data loaders # 3. Create data loaders
train_loader = DataLoader( train_loader = DataLoader(
@ -99,24 +105,24 @@ def main():
) )
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9) optimizer = torch.optim.Adam(net.parameters(), lr=wandb.config.LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP) grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
criterion = torch.nn.BCEWithLogitsLoss() criterion = torch.nn.BCEWithLogitsLoss()
# save model.pth # save model.pth
wandb.watch(net, log_freq=100) torch.save(net.state_dict(), "checkpoints/model-0.pth")
artifact = wandb.Artifact("pth", type="model") artifact = wandb.Artifact("pth", type="model")
artifact.add_file("model.pth") artifact.add_file("checkpoints/model-0.pth")
wandb.run.log_artifact(artifact) wandb.run.log_artifact(artifact)
# save model.onxx # save model.onxx
dummy_input = torch.randn( dummy_input = torch.randn(
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True 1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
).to(device) ).to(device)
torch.onnx.export(net, dummy_input, "model.onnx") torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx")
artifact = wandb.Artifact("onnx", type="model") artifact = wandb.Artifact("onnx", type="model")
artifact.add_file("model.onnx") artifact.add_file("checkpoints/model-0.onnx")
wandb.run.log_artifact(artifact) wandb.run.log_artifact(artifact)
# print the config # print the config
@ -145,7 +151,7 @@ def main():
# forward # forward
with torch.cuda.amp.autocast(enabled=wandb.config.AMP): with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
pred_masks = net(images) pred_masks = net(images)
train_loss = criterion(pred_masks, true_masks) train_loss = criterion(true_masks, pred_masks)
# backward # backward
optimizer.zero_grad(set_to_none=True) optimizer.zero_grad(set_to_none=True)
@ -167,7 +173,7 @@ def main():
# Evaluation round # Evaluation round
val_score = evaluate(net, val_loader, device) val_score = evaluate(net, val_loader, device)
scheduler.step(val_score) # scheduler.step(val_score)
# log validation metrics # log validation metrics
wandb.log( wandb.log(
@ -177,18 +183,19 @@ def main():
) )
# save weights when epoch end # save weights when epoch end
torch.save(net.state_dict(), "model.pth") torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth")
artifact = wandb.Artifact("pth", type="model") artifact = wandb.Artifact("pth", type="model")
artifact.add_file("model.pth") artifact.add_file(f"checkpoints/model-{epoch}.pth")
wandb.run.log_artifact(artifact) wandb.run.log_artifact(artifact)
# export model to onnx format # export model to onnx format
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
torch.onnx.export(net, dummy_input, "model.onnx") torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}.onnx")
artifact = wandb.Artifact("onnx", type="model") artifact = wandb.Artifact("onnx", type="model")
artifact.add_file("model.onnx") artifact.add_file(f"checkpoints/model-{epoch}.onnx")
wandb.run.log_artifact(artifact) wandb.run.log_artifact(artifact)
# stop wandb
wandb.run.finish() wandb.run.finish()
except KeyboardInterrupt: except KeyboardInterrupt:

View file

@ -70,7 +70,10 @@ class OutConv(nn.Module):
def __init__(self, in_channels, out_channels): def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__() super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1),
nn.Sigmoid(),
)
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)