feat: add training metrics

Former-commit-id: 05f9c9c44914cc2ef3443791e71d2f9ca893751b [formerly 0cec2187f92c52328f07cca65ebaeb7ab3641ac3]
Former-commit-id: 4749dbea8fd69798e08e34a91156d3111c7504a8
commit e20a989c41
parent d9f2dc2bfb
Author: Laurent Fainsin
Date:   2022-06-30 21:26:12 +02:00

2 changed files with 20 additions and 8 deletions


@@ -26,7 +26,7 @@ def main():
         DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
         DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
         DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
-        FEATURES=[16, 32, 64, 128],
+        FEATURES=[64, 128, 256, 512],
         N_CHANNELS=3,
         N_CLASSES=1,
         AMP=True,
@@ -35,7 +35,7 @@ def main():
         DEVICE="cuda",
         WORKERS=8,
         EPOCHS=5,
-        BATCH_SIZE=64,
+        BATCH_SIZE=16,
         LEARNING_RATE=1e-4,
         IMG_SIZE=512,
         SPHERES=5,
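
Note: the two config hunks above interact. Widening FEATURES from [16, 32, 64, 128] to [64, 128, 256, 512] multiplies the convolution parameter count by roughly 16 (it scales with the product of consecutive channel widths), so the BATCH_SIZE drop from 64 to 16 is presumably there to keep activations within GPU memory. A rough back-of-the-envelope sketch, assuming a single plain 3x3 convolution between consecutive widths (an assumption; the repo's actual UNet blocks may stack more convs per stage):

def conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    # weights + bias of a single k x k convolution
    return c_in * c_out * k * k + c_out

for features in ([16, 32, 64, 128], [64, 128, 256, 512]):
    total = sum(conv_params(a, b) for a, b in zip(features, features[1:]))
    print(features, "->", f"{total:,}", "params in the down path")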
@@ -105,11 +105,15 @@ def main():
     )

     # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
-    optimizer = torch.optim.Adam(net.parameters(), lr=wandb.config.LEARNING_RATE)
+    optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
     grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
     criterion = torch.nn.BCEWithLogitsLoss()

+    # accuracy stuff
+    mse = torch.nn.MSELoss()
+    mae = torch.nn.L1Loss()
+
     # save model.pth
     torch.save(net.state_dict(), "checkpoints/model-0.pth")
     artifact = wandb.Artifact("pth", type="model")
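
Note: the mse and mae nn.Module instances set up in this hunk are shadowed inside the training loop below, where the same names are rebound to torch.nn.functional results on every step, so the instances never actually run. The module and functional forms compute the same values, so nothing breaks; a minimal check:

import torch
import torch.nn.functional as F

a, b = torch.randn(8), torch.randn(8)
assert torch.allclose(torch.nn.MSELoss()(a, b), F.mse_loss(a, b))
assert torch.allclose(torch.nn.L1Loss()(a, b), F.l1_loss(a, b))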
@@ -151,7 +155,7 @@ def main():
             # forward
             with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
                 pred_masks = net(images)
-                train_loss = criterion(true_masks, pred_masks)
+                train_loss = criterion(pred_masks, true_masks)

             # backward
             optimizer.zero_grad(set_to_none=True)
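
Note: the argument swap to criterion(pred_masks, true_masks) is a real fix, not a style change. torch.nn.BCEWithLogitsLoss takes (input, target) in that order and is not symmetric: reversed, it treats the ground-truth masks as logits and the network outputs as targets, silently producing a meaningless loss. A minimal demonstration:

import torch

crit = torch.nn.BCEWithLogitsLoss()
logits = torch.tensor([3.0, -2.0])  # raw network outputs
labels = torch.tensor([1.0, 0.0])   # binary mask values

print(crit(logits, labels))  # correct order: small loss for confident, correct predictions
print(crit(labels, logits))  # swapped order: runs without error, but the value is meaningless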
@@ -159,21 +163,29 @@ def main():
             grad_scaler.step(optimizer)
             grad_scaler.update()

+            # compute metrics
+            accuracy = (true_masks == pred_masks).float().mean()
+            mse = torch.nn.functional.mse_loss(pred_masks, true_masks)
+            mae = torch.nn.functional.l1_loss(pred_masks, true_masks)
+
             # update tqdm progress bar
             pbar.update(images.shape[0])
             pbar.set_postfix(**{"loss": train_loss.item()})

-            # log training metrics
+            # log metrics
             wandb.log(
                 {
                     "train/epoch": epoch - 1 + step / len(train_loader),
-                    "train/train_loss": train_loss,
+                    "train/accuracy": accuracy,
+                    "train/loss": train_loss,
+                    "train/mse": mse,
+                    "train/mae": mae,
                 }
             )

         # Evaluation round
         val_score = evaluate(net, val_loader, device)
-        # scheduler.step(val_score)
+        scheduler.step(val_score)

         # log validation metrics
         wandb.log(
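
Note: two details in this hunk are worth flagging. First, uncommenting scheduler.step(val_score) makes the ReduceLROnPlateau from the setup hunk live; its "max" mode assumes val_score is a quantity to maximize (e.g., a Dice score). Second, the new accuracy line compares raw logits to {0, 1} masks with ==, so exact matches are vanishingly rare and the metric will sit near zero. Binary-segmentation pixel accuracy is normally computed on thresholded probabilities; a sketch of that alternative (not what this commit does):

import torch

def pixel_accuracy(logits: torch.Tensor, masks: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    # binarize sigmoid probabilities, then count pixels that match the mask
    preds = (torch.sigmoid(logits) > threshold).float()
    return (preds == masks).float().mean()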


@@ -72,7 +72,7 @@ class OutConv(nn.Module):
         self.conv = nn.Sequential(
             nn.Conv2d(in_channels, out_channels, kernel_size=1),
-            nn.Sigmoid(),
+            # nn.Sigmoid(),
         )

     def forward(self, x):
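
Note: commenting out nn.Sigmoid() here matches the BCEWithLogitsLoss criterion in the training script: that loss applies the sigmoid internally in a numerically stable, fused form, so with nn.Sigmoid() still in OutConv the sigmoid was effectively being applied twice. The network should emit raw logits; a minimal equivalence check:

import torch

x = torch.randn(4)                     # raw logits
t = torch.randint(0, 2, (4,)).float()  # binary targets

fused = torch.nn.BCEWithLogitsLoss()(x, t)
manual = torch.nn.BCELoss()(torch.sigmoid(x), t)
print(torch.allclose(fused, manual))   # True: the sigmoid is folded into the loss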