Mirror of https://github.com/Laurent2916/REVA-QCAV.git (synced 2024-11-09 23:12:05 +00:00)

feat: added some training metrics

Former-commit-id: 05f9c9c44914cc2ef3443791e71d2f9ca893751b [formerly 0cec2187f92c52328f07cca65ebaeb7ab3641ac3]
Former-commit-id: 4749dbea8fd69798e08e34a91156d3111c7504a8

This commit is contained in: parent d9f2dc2bfb, commit e20a989c41

src/train.py (26 lines changed)
@@ -26,7 +26,7 @@ def main():
         DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
         DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
         DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
-        FEATURES=[16, 32, 64, 128],
+        FEATURES=[64, 128, 256, 512],
         N_CHANNELS=3,
         N_CLASSES=1,
         AMP=True,
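Note: widening FEATURES from [16, 32, 64, 128] to [64, 128, 256, 512] restores the channel widths of the original U-Net and grows the convolutional parameter count by roughly 16x, since conv weights scale with c_in * c_out. A back-of-envelope sketch (a hypothetical double-conv estimate with 3x3 kernels, not the repository's model code):

# Rough per-stage parameter count for a 3x3 double-conv U-Net encoder
# (hypothetical estimate, not the repository's actual model code).
def double_conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    # two kxk convs with bias: (k*k*c_in + 1)*c_out + (k*k*c_out + 1)*c_out
    return (k * k * c_in + 1) * c_out + (k * k * c_out + 1) * c_out

for features in ([16, 32, 64, 128], [64, 128, 256, 512]):
    total, c_prev = 0, 3  # N_CHANNELS=3 input
    for c in features:
        total += double_conv_params(c_prev, c)
        c_prev = c
    print(features, f"~{total / 1e6:.2f}M encoder params")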
@@ -35,7 +35,7 @@ def main():
         DEVICE="cuda",
         WORKERS=8,
         EPOCHS=5,
-        BATCH_SIZE=64,
+        BATCH_SIZE=16,
         LEARNING_RATE=1e-4,
         IMG_SIZE=512,
         SPHERES=5,
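Note: BATCH_SIZE drops from 64 to 16, presumably to offset the 4x wider network. Since per-stage activation memory scales with batch * channels * H * W, the 4x batch reduction exactly cancels the 4x channel increase at the first stage, while parameter and optimizer-state memory still grow. A rough estimate (assuming fp16 activations under AMP; an illustration, not a measurement):

# Back-of-envelope activation memory for the first encoder stage,
# assuming fp16 (2 bytes per element) under AMP -- an estimate, not a measurement.
def stage_activation_mb(batch: int, channels: int, h: int, w: int, bytes_per_el: int = 2) -> float:
    return batch * channels * h * w * bytes_per_el / 2**20

old = stage_activation_mb(64, 16, 512, 512)  # old config: batch 64, 16 channels
new = stage_activation_mb(16, 64, 512, 512)  # new config: batch 16, 64 channels
print(f"first stage: {old:.0f} MiB -> {new:.0f} MiB")  # unchanged: the factors cancel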
@@ -105,11 +105,15 @@ def main():
     )

     # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
-    optimizer = torch.optim.Adam(net.parameters(), lr=wandb.config.LEARNING_RATE)
+    optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9)
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
     grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
     criterion = torch.nn.BCEWithLogitsLoss()

+    # accuracy stuff
+    mse = torch.nn.MSELoss()
+    mae = torch.nn.L1Loss()
+
     # save model.pth
     torch.save(net.state_dict(), "checkpoints/model-0.pth")
     artifact = wandb.Artifact("pth", type="model")
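Note: the optimizer switches from Adam to RMSprop with explicit weight decay and momentum. Two things worth flagging: the mse/mae module instances created here are rebound by the functional calls added further down (torch.nn.functional.mse_loss / l1_loss), so these two lines are effectively dead code; and ReduceLROnPlateau is built in "max" mode, so the value passed to scheduler.step() must be a score to maximize (e.g. Dice), not a loss. A minimal, self-contained sketch of the wiring (toy model; hyperparameter values as in the diff):

import torch

# Minimal sketch of the optimizer/scheduler/scaler setup used above
# (toy one-layer model; not the repository's training script).
net = torch.nn.Conv2d(3, 1, kernel_size=1)
optimizer = torch.optim.RMSprop(net.parameters(), lr=1e-4, weight_decay=1e-8, momentum=0.9)
# "max" mode: the LR drops when the monitored score stops increasing,
# so the scheduler must be fed a score (e.g. Dice), not a loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
grad_scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
criterion = torch.nn.BCEWithLogitsLoss()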
@@ -151,7 +155,7 @@ def main():
             # forward
             with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
                 pred_masks = net(images)
-                train_loss = criterion(true_masks, pred_masks)
+                train_loss = criterion(pred_masks, true_masks)

             # backward
             optimizer.zero_grad(set_to_none=True)
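Note: this one-liner is a genuine bug fix. torch.nn.BCEWithLogitsLoss follows PyTorch's (input, target) convention: the first argument is the raw logits (the sigmoid is applied internally) and the second is the ground truth. With the arguments swapped, the sigmoid was applied to the binary masks while the unbounded logits were treated as targets, producing a meaningless loss. A standalone demonstration (not repository code):

import torch

# BCEWithLogitsLoss is not symmetric in its arguments.
criterion = torch.nn.BCEWithLogitsLoss()
logits = torch.tensor([2.0, -1.0, 0.5])   # raw network outputs
targets = torch.tensor([1.0, 0.0, 1.0])   # binary ground truth

print(criterion(logits, targets).item())  # correct: sigmoid applied to the logits
print(criterion(targets, logits).item())  # swapped: a different, meaningless value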
@@ -159,21 +163,29 @@ def main():
                 grad_scaler.step(optimizer)
                 grad_scaler.update()

+                # compute metrics
+                accuracy = (true_masks == pred_masks).float().mean()
+                mse = torch.nn.functional.mse_loss(pred_masks, true_masks)
+                mae = torch.nn.functional.l1_loss(pred_masks, true_masks)
+
                 # update tqdm progress bar
                 pbar.update(images.shape[0])
                 pbar.set_postfix(**{"loss": train_loss.item()})

-                # log training metrics
+                # log metrics
                 wandb.log(
                     {
                         "train/epoch": epoch - 1 + step / len(train_loader),
-                        "train/train_loss": train_loss,
+                        "train/accuracy": accuracy,
+                        "train/loss": train_loss,
+                        "train/mse": mse,
+                        "train/mae": mae,
                     }
                 )

             # Evaluation round
             val_score = evaluate(net, val_loader, device)
-            # scheduler.step(val_score)
+            scheduler.step(val_score)

             # log validation metrics
             wandb.log(
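Note on the new metrics: pred_masks here are raw logits (the Sigmoid is removed from the output head in the next hunk), so the exact-equality accuracy (true_masks == pred_masks) compares floats that essentially never match and will log a value near zero; likewise mse/mae mix logit and {0,1} scales. Thresholding the sigmoid of the logits is the usual fix, sketched below (standalone example, not repository code). Also note that scheduler.step(val_score) is now actually called, activating the "max"-mode plateau scheduler set up earlier.

import torch

# Pixel accuracy for a binary segmenter that outputs raw logits:
# threshold the probabilities instead of testing exact equality.
def pixel_accuracy(pred_logits: torch.Tensor, true_masks: torch.Tensor, thresh: float = 0.5) -> torch.Tensor:
    pred_bin = (torch.sigmoid(pred_logits) > thresh).float()
    return (pred_bin == true_masks).float().mean()

pred_masks = torch.randn(2, 1, 4, 4)                     # raw logits
true_masks = torch.randint(0, 2, (2, 1, 4, 4)).float()   # binary ground truth
print(pixel_accuracy(pred_masks, true_masks).item())

The final hunk belongs to a second file whose path is not shown in this mirrored view; it is the output-head change that makes the network emit logits: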
@@ -72,7 +72,7 @@ class OutConv(nn.Module):

         self.conv = nn.Sequential(
             nn.Conv2d(in_channels, out_channels, kernel_size=1),
-            nn.Sigmoid(),
+            # nn.Sigmoid(),
         )

     def forward(self, x):
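Note: commenting out nn.Sigmoid() in OutConv is the counterpart of training with BCEWithLogitsLoss, which fuses the sigmoid into the loss for numerical stability; leaving the Sigmoid in place would apply it twice. The trade-off is that inference-time consumers must now apply torch.sigmoid themselves. A minimal sketch of the pattern (illustrative toy head, not the repository's class):

import torch
from torch import nn

# Logits-out head: no Sigmoid in the module, BCEWithLogitsLoss handles it.
head = nn.Conv2d(64, 1, kernel_size=1)
x = torch.randn(1, 64, 8, 8)
logits = head(x)
loss = nn.BCEWithLogitsLoss()(logits, torch.ones_like(logits))  # sigmoid fused into the loss
probs = torch.sigmoid(logits)  # apply explicitly for thresholding/visualization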