feat: log learning_rate
Former-commit-id: aaf6be4efe43d65e70650ee8c07b81b584a8d70e [formerly c4289255d70c75c72b684886824832ab61df533b]
Former-commit-id: a163c42fa2ca66e32c093424ed8ffdc3b82b5ea5
parent d839aec1af
commit cf8f52735a

src/train.py (27 changed lines)
@@ -17,8 +17,7 @@ class_labels = {
     1: "sphere",
 }
 
 
-def main():
+if __name__ == "__main__":
     # setup logging
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
-
@@ -41,6 +40,8 @@ def main():
         EPOCHS=5,
         BATCH_SIZE=16,
         LEARNING_RATE=1e-4,
+        WEIGHT_DECAY=1e-8,
+        MOMENTUM=0.9,
         IMG_SIZE=512,
         SPHERES=5,
     ),
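The two new keys follow the file's existing pattern of keeping every hyperparameter in wandb.config, so the optimizer further down can read them from a single place. A minimal runnable sketch of that pattern (the project name is hypothetical; mode="offline" just avoids needing a wandb account):

    import wandb

    # All hyperparameters live in one config dict passed to wandb.init();
    # the rest of the script reads them back through wandb.config.
    wandb.init(
        project="spheres",  # hypothetical project name
        mode="offline",     # log locally, no account required
        config=dict(
            LEARNING_RATE=1e-4,
            WEIGHT_DECAY=1e-8,
            MOMENTUM=0.9,
        ),
    )
    print(wandb.config.LEARNING_RATE)  # 0.0001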
@@ -88,7 +89,7 @@ def main():
     ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
     ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
 
-    # 2.5 Create subset, if uncommented
+    # 2.5. Create subset, if uncommented
     ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
     ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
 
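The stride trick in this hunk keeps a roughly fixed number of evenly spaced samples rather than a random subset. A runnable sketch with a toy dataset (the size is hypothetical):

    import torch

    ds = torch.utils.data.TensorDataset(torch.arange(100_000))
    # stride = len(ds) // 5000 = 20, so every 20th sample is kept,
    # giving ~5000 evenly spaced samples across the whole dataset.
    subset = torch.utils.data.Subset(ds, list(range(0, len(ds), len(ds) // 5000)))
    print(len(subset))  # 5000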
@@ -110,7 +111,12 @@ def main():
     )
 
     # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
-    optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9)
+    optimizer = torch.optim.RMSprop(
+        net.parameters(),
+        lr=wandb.config.LEARNING_RATE,
+        weight_decay=wandb.config.WEIGHT_DECAY,
+        momentum=wandb.config.MOMENTUM,
+    )
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
     grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
     criterion = torch.nn.BCEWithLogitsLoss()
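For context on the scheduler line: ReduceLROnPlateau has no fixed schedule; in "max" mode with patience=2 it cuts the learning rate once the monitored metric has failed to improve for more than two steps, and it must be stepped with that metric. A runnable sketch with a stand-in model and made-up metric values:

    import torch

    net = torch.nn.Linear(4, 1)  # stand-in model
    optimizer = torch.optim.RMSprop(net.parameters(), lr=1e-4, weight_decay=1e-8, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)

    # The metric plateaus after the second step, so once `patience` is
    # exhausted the scheduler multiplies the LR by its default factor (0.1).
    for dice in [0.50, 0.60, 0.60, 0.60, 0.60]:
        scheduler.step(dice)
        print(optimizer.param_groups[0]["lr"])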
@@ -137,6 +143,14 @@ def main():
     logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
 
     try:
+        # wandb init log
+        # wandb.log(
+        #     {
+        #         "train/learning_rate": scheduler.get_lr(),
+        #     },
+        #     commit=False,
+        # )
+
         for epoch in range(1, wandb.config.EPOCHS + 1):
             with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
 
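The block above presumably stays commented out because ReduceLROnPlateau, unlike the _LRScheduler-based schedulers, does not provide get_lr(): it mutates the optimizer's parameter groups directly, so the current rate has to be read back from the optimizer, which is exactly what the wandb.log call added in the next hunk does.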
@@ -245,6 +259,7 @@ def main():
             wandb.log(
                 {
                     "predictions": table,
+                    "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
                     "val/accuracy": accuracy,
                     "val/bce": val_loss,
                     "val/dice": dice,
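The new logging line reads the rate straight from the optimizer, so it stays correct after ReduceLROnPlateau rewrites it. A short sketch showing that the state_dict() lookup and the more direct param_groups access yield the same value (stand-in model again):

    import torch

    net = torch.nn.Linear(4, 1)  # stand-in model
    optimizer = torch.optim.RMSprop(net.parameters(), lr=1e-4)

    # state_dict()["param_groups"] serializes the same list the optimizer
    # holds in .param_groups, so both reads return the current LR.
    lr_a = optimizer.state_dict()["param_groups"][0]["lr"]
    lr_b = optimizer.param_groups[0]["lr"]
    assert lr_a == lr_b  # both 0.0001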
@@ -276,7 +291,3 @@ def main():
     except KeyboardInterrupt:
         torch.save(net.state_dict(), "INTERRUPTED.pth")
         raise
-
-
-if __name__ == "__main__":
-    main()  # TODO: fix all the metrics, loss, accuracy, dice...