mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 14:39:00 +00:00
feat: added some training metrics
Former-commit-id: 05f9c9c44914cc2ef3443791e71d2f9ca893751b [formerly 0cec2187f92c52328f07cca65ebaeb7ab3641ac3] Former-commit-id: 4749dbea8fd69798e08e34a91156d3111c7504a8
This commit is contained in:
parent
d9f2dc2bfb
commit
e20a989c41
26
src/train.py
26
src/train.py
|
@ -26,7 +26,7 @@ def main():
|
|||
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
|
||||
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
||||
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
||||
FEATURES=[16, 32, 64, 128],
|
||||
FEATURES=[64, 128, 256, 512],
|
||||
N_CHANNELS=3,
|
||||
N_CLASSES=1,
|
||||
AMP=True,
|
||||
|
@ -35,7 +35,7 @@ def main():
|
|||
DEVICE="cuda",
|
||||
WORKERS=8,
|
||||
EPOCHS=5,
|
||||
BATCH_SIZE=64,
|
||||
BATCH_SIZE=16,
|
||||
LEARNING_RATE=1e-4,
|
||||
IMG_SIZE=512,
|
||||
SPHERES=5,
|
||||
|
@ -105,11 +105,15 @@ def main():
|
|||
)
|
||||
|
||||
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
|
||||
optimizer = torch.optim.Adam(net.parameters(), lr=wandb.config.LEARNING_RATE)
|
||||
optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
|
||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
|
||||
criterion = torch.nn.BCEWithLogitsLoss()
|
||||
|
||||
# accuracy stuff
|
||||
mse = torch.nn.MSELoss()
|
||||
mae = torch.nn.L1Loss()
|
||||
|
||||
# save model.pth
|
||||
torch.save(net.state_dict(), "checkpoints/model-0.pth")
|
||||
artifact = wandb.Artifact("pth", type="model")
|
||||
|
@ -151,7 +155,7 @@ def main():
|
|||
# forward
|
||||
with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
|
||||
pred_masks = net(images)
|
||||
train_loss = criterion(true_masks, pred_masks)
|
||||
train_loss = criterion(pred_masks, true_masks)
|
||||
|
||||
# backward
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
@ -159,21 +163,29 @@ def main():
|
|||
grad_scaler.step(optimizer)
|
||||
grad_scaler.update()
|
||||
|
||||
# compute metrics
|
||||
accuracy = (true_masks == pred_masks).float().mean()
|
||||
mse = torch.nn.functional.mse_loss(pred_masks, true_masks)
|
||||
mae = torch.nn.functional.l1_loss(pred_masks, true_masks)
|
||||
|
||||
# update tqdm progress bar
|
||||
pbar.update(images.shape[0])
|
||||
pbar.set_postfix(**{"loss": train_loss.item()})
|
||||
|
||||
# log training metrics
|
||||
# log metrics
|
||||
wandb.log(
|
||||
{
|
||||
"train/epoch": epoch - 1 + step / len(train_loader),
|
||||
"train/train_loss": train_loss,
|
||||
"train/accuracy": accuracy,
|
||||
"train/loss": train_loss,
|
||||
"train/mse": mse,
|
||||
"train/mae": mae,
|
||||
}
|
||||
)
|
||||
|
||||
# Evaluation round
|
||||
val_score = evaluate(net, val_loader, device)
|
||||
# scheduler.step(val_score)
|
||||
scheduler.step(val_score)
|
||||
|
||||
# log validation metrics
|
||||
wandb.log(
|
||||
|
|
|
@ -72,7 +72,7 @@ class OutConv(nn.Module):
|
|||
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=1),
|
||||
nn.Sigmoid(),
|
||||
# nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
Loading…
Reference in a new issue