feat: log learning_rate

Former-commit-id: aaf6be4efe43d65e70650ee8c07b81b584a8d70e [formerly c4289255d70c75c72b684886824832ab61df533b]
Former-commit-id: a163c42fa2ca66e32c093424ed8ffdc3b82b5ea5
This commit is contained in:
Laurent Fainsin 2022-07-01 14:32:30 +02:00
parent d839aec1af
commit cf8f52735a

View file

@ -17,8 +17,7 @@ class_labels = {
1: "sphere",
}
def main():
if __name__ == "__main__":
# setup logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
@ -41,6 +40,8 @@ def main():
EPOCHS=5,
BATCH_SIZE=16,
LEARNING_RATE=1e-4,
WEIGHT_DECAY=1e-8,
MOMENTUM=0.9,
IMG_SIZE=512,
SPHERES=5,
),
@ -88,7 +89,7 @@ def main():
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
# 2.5 Create subset, if uncommented
# 2.5. Create subset, if uncommented
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
@ -110,7 +111,12 @@ def main():
)
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9)
optimizer = torch.optim.RMSprop(
net.parameters(),
lr=wandb.config.LEARNING_RATE,
weight_decay=wandb.config.WEIGHT_DECAY,
momentum=wandb.config.MOMENTUM,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
criterion = torch.nn.BCEWithLogitsLoss()
@ -137,6 +143,14 @@ def main():
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
try:
# wandb init log
# wandb.log(
# {
# "train/learning_rate": scheduler.get_lr(),
# },
# commit=False,
# )
for epoch in range(1, wandb.config.EPOCHS + 1):
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
@ -245,6 +259,7 @@ def main():
wandb.log(
{
"predictions": table,
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
"val/accuracy": accuracy,
"val/bce": val_loss,
"val/dice": dice,
@ -276,7 +291,3 @@ def main():
except KeyboardInterrupt:
torch.save(net.state_dict(), "INTERRUPTED.pth")
raise
if __name__ == "__main__":
main() # TODO: fix toutes les metrics, loss, accuracy, dice...