feat: log learning_rate
Former-commit-id: aaf6be4efe43d65e70650ee8c07b81b584a8d70e [formerly c4289255d70c75c72b684886824832ab61df533b] Former-commit-id: a163c42fa2ca66e32c093424ed8ffdc3b82b5ea5
commit cf8f52735a
parent d839aec1af
src/train.py | 27 +++++++++++++++++++--------
1 file changed, 19 insertions(+), 8 deletions(-)
@@ -17,8 +17,7 @@ class_labels = {
     1: "sphere",
 }
 
-
-if __name__ == "__main__":
+def main():
     # setup logging
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
 
@@ -41,6 +40,8 @@ def main():
         EPOCHS=5,
         BATCH_SIZE=16,
         LEARNING_RATE=1e-4,
+        WEIGHT_DECAY=1e-8,
+        MOMENTUM=0.9,
         IMG_SIZE=512,
         SPHERES=5,
     ),
@@ -88,7 +89,7 @@ def main():
     ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
     ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
 
-    # 2.5 Create subset, if uncommented
+    # 2.5. Create subset, if uncommented
     ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
     ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
 
@@ -110,7 +111,12 @@ def main():
     )
 
     # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp
-    optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9)
+    optimizer = torch.optim.RMSprop(
+        net.parameters(),
+        lr=wandb.config.LEARNING_RATE,
+        weight_decay=wandb.config.WEIGHT_DECAY,
+        momentum=wandb.config.MOMENTUM,
+    )
     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
     grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP)
     criterion = torch.nn.BCEWithLogitsLoss()
@@ -137,6 +143,14 @@ def main():
     logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
 
     try:
+        # wandb init log
+        # wandb.log(
+        #     {
+        #         "train/learning_rate": scheduler.get_lr(),
+        #     },
+        #     commit=False,
+        # )
+
         for epoch in range(1, wandb.config.EPOCHS + 1):
             with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
 
@@ -245,6 +259,7 @@ def main():
             wandb.log(
                 {
                     "predictions": table,
+                    "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
                     "val/accuracy": accuracy,
                     "val/bce": val_loss,
                     "val/dice": dice,
@@ -276,7 +291,3 @@ def main():
     except KeyboardInterrupt:
         torch.save(net.state_dict(), "INTERRUPTED.pth")
         raise
-
-
-if __name__ == "__main__":
-    main()  # TODO: fix all the metrics, loss, accuracy, dice...
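Note on the change itself: the learning rate is logged by reading it back out of the optimizer's param groups rather than from the scheduler, likely because ReduceLROnPlateau does not provide the get_lr() method of the standard _LRScheduler classes, which would explain why the wandb.log call using scheduler.get_lr() stays commented out. A minimal runnable sketch of the pattern; the stand-in linear model and the literal hyperparameter values are illustrative, not taken from src/train.py:

    import torch

    net = torch.nn.Linear(8, 1)  # stand-in model; src/train.py builds its own net

    # mirror the hyperparameters this commit promotes into wandb.config
    optimizer = torch.optim.RMSprop(
        net.parameters(),
        lr=1e-4,            # LEARNING_RATE
        weight_decay=1e-8,  # WEIGHT_DECAY
        momentum=0.9,       # MOMENTUM
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)

    # ReduceLROnPlateau mutates the lr stored in the optimizer's param groups,
    # so the live value is read from the optimizer, as the commit does:
    current_lr = optimizer.state_dict()["param_groups"][0]["lr"]
    assert current_lr == optimizer.param_groups[0]["lr"]  # equivalent, cheaper read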
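On the commit=False flag in the commented-out block: wandb.log(..., commit=False) stages values without advancing the step counter, so they land on the same step as the next committed log call, which is how a learning rate can share a step with the validation metrics logged later in the epoch. A sketch under the assumption that a run is initialized; the project name and metric values here are made up:

    import wandb

    run = wandb.init(project="sphere-unet", mode="offline")  # hypothetical project; offline avoids network

    wandb.log({"train/learning_rate": 1e-4}, commit=False)  # staged, step not advanced
    wandb.log({"val/accuracy": 0.93, "val/dice": 0.88})     # commits both dicts on one step

    run.finish()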