diff --git a/src/train.py b/src/train.py index b7475ce..b0e7c84 100644 --- a/src/train.py +++ b/src/train.py @@ -17,8 +17,7 @@ class_labels = { 1: "sphere", } - -def main(): +if __name__ == "__main__": # setup logging logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") @@ -41,6 +40,8 @@ def main(): EPOCHS=5, BATCH_SIZE=16, LEARNING_RATE=1e-4, + WEIGHT_DECAY=1e-8, + MOMENTUM=0.9, IMG_SIZE=512, SPHERES=5, ), @@ -88,7 +89,7 @@ def main(): ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid) - # 2.5 Create subset, if uncommented + # 2.5. Create subset, if uncommented ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000))) ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100))) @@ -110,7 +111,12 @@ def main(): ) # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for amp - optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.LEARNING_RATE, weight_decay=1e-8, momentum=0.9) + optimizer = torch.optim.RMSprop( + net.parameters(), + lr=wandb.config.LEARNING_RATE, + weight_decay=wandb.config.WEIGHT_DECAY, + momentum=wandb.config.MOMENTUM, + ) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.AMP) criterion = torch.nn.BCEWithLogitsLoss() @@ -137,6 +143,14 @@ def main(): logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") try: + # wandb init log + # wandb.log( + # { + # "train/learning_rate": scheduler.get_lr(), + # }, + # commit=False, + # ) + for epoch in range(1, wandb.config.EPOCHS + 1): with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar: @@ -245,6 +259,7 @@ def main(): wandb.log( { "predictions": table, + "train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"], "val/accuracy": accuracy, "val/bce": val_loss, "val/dice": dice, @@ -276,7 +291,3 @@ def main(): except KeyboardInterrupt: torch.save(net.state_dict(), "INTERRUPTED.pth") raise - - -if __name__ == "__main__": - main() # TODO: fix toutes les metrics, loss, accuracy, dice...