feat: got positive loss values !

Former-commit-id: 84e2a715b843ecee2e12e4878fcee4a52bb0a4cb [formerly 1a5fc82bc099885853b7b4deff81b779dafd0168]
Former-commit-id: c82cd66d6c432555a126e506631dfa2fd756437e
This commit is contained in:
Laurent Fainsin 2022-06-30 16:47:28 +02:00
parent 3e335fbcb5
commit d9f2dc2bfb
5 changed files with 39 additions and 32 deletions

1
.gitignore vendored
View file

@ -5,6 +5,7 @@ __pycache__/
wandb/
images/
checkpoints/
*.pth
*.onnx

View file

@ -38,19 +38,6 @@ def get_args():
return parser.parse_args()
def predict_img(net, img, device):
img = img.unsqueeze(0)
img = img.to(device=device, dtype=torch.float32)
net.eval()
with torch.inference_mode():
output = net(img)
# preds = torch.sigmoid(output)[0]
# full_mask = output.squeeze(0).cpu()
return np.asarray(output.squeeze().cpu())
if __name__ == "__main__":
args = get_args()
@ -81,8 +68,17 @@ if __name__ == "__main__":
img = aug["image"]
logging.info(f"Predicting image {args.input}")
mask = predict_img(net=net, img=img, device=device)
img = img.unsqueeze(0).to(device=device, dtype=torch.float32)
net.eval()
with torch.inference_mode():
mask = net(img)
mask = torch.sigmoid(mask)[0]
mask = mask.cpu()
mask = mask.squeeze()
mask = mask > 0.5
mask = np.asarray(mask)
logging.info(f"Saving prediction to {args.output}")
mask = Image.fromarray(mask, "L")
mask = Image.fromarray(mask)
mask.save(args.output)

View file

@ -26,7 +26,7 @@ def main():
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
FEATURES=[64, 128, 256, 512],
FEATURES=[16, 32, 64, 128],
N_CHANNELS=3,
N_CLASSES=1,
AMP=True,
@ -35,8 +35,8 @@ def main():
DEVICE="cuda",
WORKERS=8,
EPOCHS=5,
BATCH_SIZE=16,
LEARNING_RATE=1e-5,
BATCH_SIZE=64,
LEARNING_RATE=1e-4,
IMG_SIZE=512,
SPHERES=5,
),
@ -50,7 +50,8 @@ def main():
# 0. Create network
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
wandb.config.parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
wandb.watch(net, log_freq=100)
# transfer network to device
net.to(device=device)
@ -80,6 +81,11 @@ def main():
# 2. Create datasets
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
# ds_train_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/train/", transform=tf_train)
# ds_valid_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/testval/", transform=tf_valid)
# ds_train = torch.utils.data.ChainDataset([ds_train_coco, ds_train_bg20k])
# ds_valid = torch.utils.data.ChainDataset([ds_valid_coco, ds_valid_bg20k]) # TODO: modifier la classe SphereDataset pour prendre plusieurs dossiers
# 3. Create data loaders
train_loader = DataLoader(
@ -99,24 +105,24 @@ def main():
)
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9)
optimizer = torch.optim.Adam(net.parameters(), lr=wandb.config.LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
criterion = torch.nn.BCEWithLogitsLoss()
# save model.pth
wandb.watch(net, log_freq=100)
torch.save(net.state_dict(), "checkpoints/model-0.pth")
artifact = wandb.Artifact("pth", type="model")
artifact.add_file("model.pth")
artifact.add_file("checkpoints/model-0.pth")
wandb.run.log_artifact(artifact)
# save model.onxx
dummy_input = torch.randn(
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
).to(device)
torch.onnx.export(net, dummy_input, "model.onnx")
torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx")
artifact = wandb.Artifact("onnx", type="model")
artifact.add_file("model.onnx")
artifact.add_file("checkpoints/model-0.onnx")
wandb.run.log_artifact(artifact)
# print the config
@ -145,7 +151,7 @@ def main():
# forward
with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
pred_masks = net(images)
train_loss = criterion(pred_masks, true_masks)
train_loss = criterion(true_masks, pred_masks)
# backward
optimizer.zero_grad(set_to_none=True)
@ -167,7 +173,7 @@ def main():
# Evaluation round
val_score = evaluate(net, val_loader, device)
scheduler.step(val_score)
# scheduler.step(val_score)
# log validation metrics
wandb.log(
@ -177,18 +183,19 @@ def main():
)
# save weights when epoch end
torch.save(net.state_dict(), "model.pth")
torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth")
artifact = wandb.Artifact("pth", type="model")
artifact.add_file("model.pth")
artifact.add_file(f"checkpoints/model-{epoch}.pth")
wandb.run.log_artifact(artifact)
# export model to onnx format
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
torch.onnx.export(net, dummy_input, "model.onnx")
torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}.onnx")
artifact = wandb.Artifact("onnx", type="model")
artifact.add_file("model.onnx")
artifact.add_file(f"checkpoints/model-{epoch}.onnx")
wandb.run.log_artifact(artifact)
# stop wandb
wandb.run.finish()
except KeyboardInterrupt:

View file

@ -70,7 +70,10 @@ class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1),
nn.Sigmoid(),
)
def forward(self, x):
return self.conv(x)