feat: got positive loss values !
Former-commit-id: 84e2a715b843ecee2e12e4878fcee4a52bb0a4cb [formerly 1a5fc82bc099885853b7b4deff81b779dafd0168] Former-commit-id: c82cd66d6c432555a126e506631dfa2fd756437e
This commit is contained in:
parent
3e335fbcb5
commit
d9f2dc2bfb
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -5,6 +5,7 @@ __pycache__/
|
||||||
wandb/
|
wandb/
|
||||||
images/
|
images/
|
||||||
|
|
||||||
|
checkpoints/
|
||||||
*.pth
|
*.pth
|
||||||
*.onnx
|
*.onnx
|
||||||
|
|
||||||
|
|
|
@ -38,19 +38,6 @@ def get_args():
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def predict_img(net, img, device):
|
|
||||||
img = img.unsqueeze(0)
|
|
||||||
img = img.to(device=device, dtype=torch.float32)
|
|
||||||
|
|
||||||
net.eval()
|
|
||||||
with torch.inference_mode():
|
|
||||||
output = net(img)
|
|
||||||
# preds = torch.sigmoid(output)[0]
|
|
||||||
# full_mask = output.squeeze(0).cpu()
|
|
||||||
|
|
||||||
return np.asarray(output.squeeze().cpu())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = get_args()
|
args = get_args()
|
||||||
|
|
||||||
|
@ -81,8 +68,17 @@ if __name__ == "__main__":
|
||||||
img = aug["image"]
|
img = aug["image"]
|
||||||
|
|
||||||
logging.info(f"Predicting image {args.input}")
|
logging.info(f"Predicting image {args.input}")
|
||||||
mask = predict_img(net=net, img=img, device=device)
|
img = img.unsqueeze(0).to(device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
net.eval()
|
||||||
|
with torch.inference_mode():
|
||||||
|
mask = net(img)
|
||||||
|
mask = torch.sigmoid(mask)[0]
|
||||||
|
mask = mask.cpu()
|
||||||
|
mask = mask.squeeze()
|
||||||
|
mask = mask > 0.5
|
||||||
|
mask = np.asarray(mask)
|
||||||
|
|
||||||
logging.info(f"Saving prediction to {args.output}")
|
logging.info(f"Saving prediction to {args.output}")
|
||||||
mask = Image.fromarray(mask, "L")
|
mask = Image.fromarray(mask)
|
||||||
mask.save(args.output)
|
mask.save(args.output)
|
||||||
|
|
37
src/train.py
37
src/train.py
|
@ -26,7 +26,7 @@ def main():
|
||||||
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
|
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
|
||||||
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
||||||
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
||||||
FEATURES=[64, 128, 256, 512],
|
FEATURES=[16, 32, 64, 128],
|
||||||
N_CHANNELS=3,
|
N_CHANNELS=3,
|
||||||
N_CLASSES=1,
|
N_CLASSES=1,
|
||||||
AMP=True,
|
AMP=True,
|
||||||
|
@ -35,8 +35,8 @@ def main():
|
||||||
DEVICE="cuda",
|
DEVICE="cuda",
|
||||||
WORKERS=8,
|
WORKERS=8,
|
||||||
EPOCHS=5,
|
EPOCHS=5,
|
||||||
BATCH_SIZE=16,
|
BATCH_SIZE=64,
|
||||||
LEARNING_RATE=1e-5,
|
LEARNING_RATE=1e-4,
|
||||||
IMG_SIZE=512,
|
IMG_SIZE=512,
|
||||||
SPHERES=5,
|
SPHERES=5,
|
||||||
),
|
),
|
||||||
|
@ -50,7 +50,8 @@ def main():
|
||||||
|
|
||||||
# 0. Create network
|
# 0. Create network
|
||||||
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
|
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
|
||||||
wandb.config.parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
||||||
|
wandb.watch(net, log_freq=100)
|
||||||
|
|
||||||
# transfer network to device
|
# transfer network to device
|
||||||
net.to(device=device)
|
net.to(device=device)
|
||||||
|
@ -80,6 +81,11 @@ def main():
|
||||||
# 2. Create datasets
|
# 2. Create datasets
|
||||||
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
||||||
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
|
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
|
||||||
|
# ds_train_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/train/", transform=tf_train)
|
||||||
|
# ds_valid_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/testval/", transform=tf_valid)
|
||||||
|
|
||||||
|
# ds_train = torch.utils.data.ChainDataset([ds_train_coco, ds_train_bg20k])
|
||||||
|
# ds_valid = torch.utils.data.ChainDataset([ds_valid_coco, ds_valid_bg20k]) # TODO: modifier la classe SphereDataset pour prendre plusieurs dossiers
|
||||||
|
|
||||||
# 3. Create data loaders
|
# 3. Create data loaders
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
|
@ -99,24 +105,24 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
|
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
|
||||||
optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9)
|
optimizer = torch.optim.Adam(net.parameters(), lr=wandb.config.LEARNING_RATE)
|
||||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
|
||||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
|
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
|
||||||
criterion = torch.nn.BCEWithLogitsLoss()
|
criterion = torch.nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
# save model.pth
|
# save model.pth
|
||||||
wandb.watch(net, log_freq=100)
|
torch.save(net.state_dict(), "checkpoints/model-0.pth")
|
||||||
artifact = wandb.Artifact("pth", type="model")
|
artifact = wandb.Artifact("pth", type="model")
|
||||||
artifact.add_file("model.pth")
|
artifact.add_file("checkpoints/model-0.pth")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
|
|
||||||
# save model.onxx
|
# save model.onxx
|
||||||
dummy_input = torch.randn(
|
dummy_input = torch.randn(
|
||||||
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
|
1, wandb.config.N_CHANNELS, wandb.config.IMG_SIZE, wandb.config.IMG_SIZE, requires_grad=True
|
||||||
).to(device)
|
).to(device)
|
||||||
torch.onnx.export(net, dummy_input, "model.onnx")
|
torch.onnx.export(net, dummy_input, "checkpoints/model-0.onnx")
|
||||||
artifact = wandb.Artifact("onnx", type="model")
|
artifact = wandb.Artifact("onnx", type="model")
|
||||||
artifact.add_file("model.onnx")
|
artifact.add_file("checkpoints/model-0.onnx")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
|
|
||||||
# print the config
|
# print the config
|
||||||
|
@ -145,7 +151,7 @@ def main():
|
||||||
# forward
|
# forward
|
||||||
with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
|
with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
|
||||||
pred_masks = net(images)
|
pred_masks = net(images)
|
||||||
train_loss = criterion(pred_masks, true_masks)
|
train_loss = criterion(true_masks, pred_masks)
|
||||||
|
|
||||||
# backward
|
# backward
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
@ -167,7 +173,7 @@ def main():
|
||||||
|
|
||||||
# Evaluation round
|
# Evaluation round
|
||||||
val_score = evaluate(net, val_loader, device)
|
val_score = evaluate(net, val_loader, device)
|
||||||
scheduler.step(val_score)
|
# scheduler.step(val_score)
|
||||||
|
|
||||||
# log validation metrics
|
# log validation metrics
|
||||||
wandb.log(
|
wandb.log(
|
||||||
|
@ -177,18 +183,19 @@ def main():
|
||||||
)
|
)
|
||||||
|
|
||||||
# save weights when epoch end
|
# save weights when epoch end
|
||||||
torch.save(net.state_dict(), "model.pth")
|
torch.save(net.state_dict(), f"checkpoints/model-{epoch}.pth")
|
||||||
artifact = wandb.Artifact("pth", type="model")
|
artifact = wandb.Artifact("pth", type="model")
|
||||||
artifact.add_file("model.pth")
|
artifact.add_file(f"checkpoints/model-{epoch}.pth")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
|
|
||||||
# export model to onnx format
|
# export model to onnx format
|
||||||
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
||||||
torch.onnx.export(net, dummy_input, "model.onnx")
|
torch.onnx.export(net, dummy_input, f"checkpoints/model-{epoch}.onnx")
|
||||||
artifact = wandb.Artifact("onnx", type="model")
|
artifact = wandb.Artifact("onnx", type="model")
|
||||||
artifact.add_file("model.onnx")
|
artifact.add_file(f"checkpoints/model-{epoch}.onnx")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
|
|
||||||
|
# stop wandb
|
||||||
wandb.run.finish()
|
wandb.run.finish()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
|
|
@ -70,7 +70,10 @@ class OutConv(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels):
|
def __init__(self, in_channels, out_channels):
|
||||||
super(OutConv, self).__init__()
|
super(OutConv, self).__init__()
|
||||||
|
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, out_channels, kernel_size=1),
|
||||||
|
nn.Sigmoid(),
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.conv(x)
|
return self.conv(x)
|
||||||
|
|
Loading…
Reference in a new issue