feat: got positive loss values!
Former-commit-id: 84e2a715b843ecee2e12e4878fcee4a52bb0a4cb [formerly 1a5fc82bc099885853b7b4deff81b779dafd0168] Former-commit-id: c82cd66d6c432555a126e506631dfa2fd756437e
This commit is contained in:
parent
3e335fbcb5
commit
d9f2dc2bfb
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -5,6 +5,7 @@ __pycache__/
|
|||
wandb/
|
||||
images/
|
||||
|
||||
checkpoints/
|
||||
*.pth
|
||||
*.onnx
|
||||
|
||||
|
|
2
.vscode/launch.json
vendored
2
.vscode/launch.json
vendored
|
@ -19,4 +19,4 @@
|
|||
"justMyCode": true
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
|
@ -38,19 +38,6 @@ def get_args():
|
|||
return parser.parse_args()
|
||||
|
||||
|
||||
def predict_img(net, img, device):
|
||||
img = img.unsqueeze(0)
|
||||
img = img.to(device=device, dtype=torch.float32)
|
||||
|
||||
net.eval()
|
||||
with torch.inference_mode():
|
||||
output = net(img)
|
||||
# preds = torch.sigmoid(output)[0]
|
||||
# full_mask = output.squeeze(0).cpu()
|
||||
|
||||
return np.asarray(output.squeeze().cpu())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
|
@ -81,8 +68,17 @@ if __name__ == "__main__":
|
|||
img = aug["image"]
|
||||
|
||||
logging.info(f"Predicting image {args.input}")
|
||||
mask = predict_img(net=net, img=img, device=device)
|
||||
img = img.unsqueeze(0).to(device=device, dtype=torch.float32)
|
||||
|
||||
net.eval()
|
||||
with torch.inference_mode():
|
||||
mask = net(img)
|
||||
mask = torch.sigmoid(mask)[0]
|
||||
mask = mask.cpu()
|
||||
mask = mask.squeeze()
|
||||
mask = mask > 0.5
|
||||
mask = np.asarray(mask)
|
||||
|
||||
logging.info(f"Saving prediction to {args.output}")
|
||||
mask = Image.fromarray(mask, "L")
|
||||
mask = Image.fromarray(mask)
|
||||
mask.save(args.output)
|
||||
|
|
37
src/train.py
37
src/train.py
|
@ -26,7 +26,7 @@ def main():
|
|||
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
|
||||
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
||||
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
||||
FEATURES=[64, 128, 256, 512],
|
||||
FEATURES=[16, 32, 64, 128],
|
||||
N_CHANNELS=3,
|
||||
N_CLASSES=1,
|
||||
AMP=True,
|
||||
|
@ -35,8 +35,8 @@ def main():
|
|||
DEVICE="cuda",
|
||||
WORKERS=8,
|
||||
EPOCHS=5,
|
||||
BATCH_SIZE=16,
|
||||
LEARNING_RATE=1e-5,
|
||||
BATCH_SIZE=64,
|
||||
LEARNING_RATE=1e-4,
|
||||
IMG_SIZE=512,
|
||||
SPHERES=5,
|
||||
),
|
||||
|
@ -50,7 +50,8 @@ def main():
|
|||
|
||||
# 0. Create network
|
||||
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
|
||||
wandb.config.parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
||||
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
||||
wandb.watch(net, log_freq=100)
|
||||
|
||||
# transfer network to device
|
||||
net.to(device=device)
|
||||
|
@ -80,6 +81,11 @@ def main():
|
|||
# 2. Create datasets
|
||||
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
||||
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
|
||||
# ds_train_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/train/", transform=tf_train)
|
||||
# ds_valid_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/testval/", transform=tf_valid)
|
||||
|
||||
# ds_train = torch.utils.data.ChainDataset([ds_train_coco, ds_train_bg20k])
|
||||
# ds_valid = torch.utils.data.ChainDataset([ds_valid_coco, ds_valid_bg20k]) # TODO: modify the SphereDataset class to accept multiple directories
|
||||
|
||||
# 3. Create data loaders
|
||||
train_loader = DataLoader(
|
||||
|
@ -99,24 +105,24 @@ def main():
|
|||
)
|
||||
|
||||
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
|
||||
optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9)
|
||||
optimizer = torch.optim.Adam(net.parameters(), lr=wandb.config.LEARNING_RATE)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
|
||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
|
||||
criterion = torch.nn.BCEWithLogitsLoss()
|
||||
|
||||
# save model.pth
|
||||
wandb.watch(net, log_freq=100)
|
||||
torch.save(net.state_dict(), "checkpoints/model-0.pth")
|
||||
artifact = wandb.Artifact("pth", type="model")
|
||||
artifact.add_file("model.pth")
|
||||
artifact.add_file("checkpoints/model-0.pth")
|
||||
wandb.run.log_artifact(artifact)
|
||||
|
||||
# save model.onnx
|
||||
dummy_input = torch.randn(
|
||||
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
|
||||
).to(device)
|
||||
torch.onnx.export(net, dummy_input, "model.onnx")
|
||||
torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx")
|
||||
artifact = wandb.Artifact("onnx", type="model")
|
||||
artifact.add_file("model.onnx")
|
||||
artifact.add_file("checkpoints/model-0.onnx")
|
||||
wandb.run.log_artifact(artifact)
|
||||
|
||||
# print the config
|
||||
|
@ -145,7 +151,7 @@ def main():
|
|||
# forward
|
||||
with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
|
||||
pred_masks = net(images)
|
||||
train_loss = criterion(pred_masks, true_masks)
|
||||
train_loss = criterion(true_masks, pred_masks)
|
||||
|
||||
# backward
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
@ -167,7 +173,7 @@ def main():
|
|||
|
||||
# Evaluation round
|
||||
val_score = evaluate(net, val_loader, device)
|
||||
scheduler.step(val_score)
|
||||
# scheduler.step(val_score)
|
||||
|
||||
# log validation metrics
|
||||
wandb.log(
|
||||
|
@ -177,18 +183,19 @@ def main():
|
|||
)
|
||||
|
||||
# save weights when the epoch ends
|
||||
torch.save(net.state_dict(), "model.pth")
|
||||
torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth")
|
||||
artifact = wandb.Artifact("pth", type="model")
|
||||
artifact.add_file("model.pth")
|
||||
artifact.add_file(f"checkpoints/model-{epoch}.pth")
|
||||
wandb.run.log_artifact(artifact)
|
||||
|
||||
# export model to onnx format
|
||||
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
||||
torch.onnx.export(net, dummy_input, "model.onnx")
|
||||
torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}.onnx")
|
||||
artifact = wandb.Artifact("onnx", type="model")
|
||||
artifact.add_file("model.onnx")
|
||||
artifact.add_file(f"checkpoints/model-{epoch}.onnx")
|
||||
wandb.run.log_artifact(artifact)
|
||||
|
||||
# stop wandb
|
||||
wandb.run.finish()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
|
|
|
@ -70,7 +70,10 @@ class OutConv(nn.Module):
|
|||
def __init__(self, in_channels, out_channels):
|
||||
super(OutConv, self).__init__()
|
||||
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
|
Loading…
Reference in a new issue