diff --git a/src/data/dataloader.py b/src/data/dataloader.py
index fe26b0b..9e67c4c 100644
--- a/src/data/dataloader.py
+++ b/src/data/dataloader.py
@@ -31,19 +31,19 @@ class Spheres(pl.LightningDataModule):
             dataset,
             shuffle=True,
             prefetch_factor=8,
-            batch_size=wandb.config.BATCH_SIZE,
+            batch_size=wandb.config.TRAIN_BATCH_SIZE,
             num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,
         )

     def val_dataloader(self):
         dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
-        dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
+        # dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))

         return DataLoader(
             dataset,
             shuffle=False,
-            batch_size=8,
+            batch_size=wandb.config.VAL_BATCH_SIZE,
             prefetch_factor=8,
             num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,
diff --git a/src/train.py b/src/train.py
index 1a7dbe5..85625e0 100644
--- a/src/train.py
+++ b/src/train.py
@@ -13,9 +13,10 @@ if __name__ == "__main__":
     # setup logging
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

-    # setup wandb, config loaded from config-default.yaml
+    # setup wandb
     logger = WandbLogger(
         project="U-Net",
+        config="wandb.yaml",
         settings=wandb.Settings(
             code_dir="./src/",
         ),
@@ -28,8 +29,6 @@ if __name__ == "__main__":
     model = UNetModule(
         n_channels=wandb.config.N_CHANNELS,
         n_classes=wandb.config.N_CLASSES,
-        batch_size=wandb.config.BATCH_SIZE,
-        learning_rate=wandb.config.LEARNING_RATE,
         features=wandb.config.FEATURES,
     )

diff --git a/src/unet/module.py b/src/unet/module.py
index 86bb107..c7ae528 100644
--- a/src/unet/module.py
+++ b/src/unet/module.py
@@ -16,14 +16,12 @@ class_labels = {


 class UNetModule(pl.LightningModule):
-    def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
+    def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
         super(UNetModule, self).__init__()

         # Hyperparameters
         self.n_channels = n_channels
         self.n_classes = n_classes
-        self.learning_rate = learning_rate
-        self.batch_size = batch_size

         # log hyperparameters
         self.save_hyperparameters()
@@ -111,7 +109,7 @@ class UNetModule(pl.LightningModule):
     def configure_optimizers(self):
         optimizer = torch.optim.RMSprop(
             self.parameters(),
-            lr=self.learning_rate,
+            lr=wandb.config.LEARNING_RATE,
             weight_decay=wandb.config.WEIGHT_DECAY,
             momentum=wandb.config.MOMENTUM,
         )
diff --git a/src/config-defaults.yaml b/wandb.yaml
similarity index 79%
rename from src/config-defaults.yaml
rename to wandb.yaml
index 4fcda67..2515950 100644
--- a/src/config-defaults.yaml
+++ b/wandb.yaml
@@ -6,11 +6,11 @@ DIR_SPHERE:
   value: "/home/lilian/data_disk/lfainsin/spheres+real/"

 FEATURES:
-  value: { 8, 16, 32, 64 }
+  value: [8, 16, 32, 64]
 N_CHANNELS:
-  value: 3,
+  value: 3
 N_CLASSES:
-  value: 1,
+  value: 1

 AMP:
   value: True
@@ -30,12 +30,14 @@ SPHERES:

 EPOCHS:
   value: 10
-BATCH_SIZE:
+TRAIN_BATCH_SIZE:
   value: 16
+VAL_BATCH_SIZE:
+  value: 8

 LEARNING_RATE:
-  value: 1e-4
+  value: 1.0e-4
 WEIGHT_DECAY:
-  value: 1e-8
+  value: 1.0e-8
 MOMENTUM:
   value: 0.9
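For reference, a minimal sketch of the config flow this diff sets up. The
WandbLogger-to-wandb.init kwarg pass-through and the unwrapping of each key's
nested `value:` mapping are assumptions based on wandb's documented behavior,
not something verified against this repo:

    import wandb

    # WandbLogger forwards extra keyword arguments to wandb.init, so
    # config="wandb.yaml" should end up here; wandb.init accepts a
    # path to a YAML config file.
    run = wandb.init(project="U-Net", config="wandb.yaml")

    # Each top-level key in wandb.yaml becomes an attribute on
    # wandb.config, with its nested `value:` mapping unwrapped.
    assert wandb.config.TRAIN_BATCH_SIZE == 16
    assert wandb.config.VAL_BATCH_SIZE == 8
    assert wandb.config.LEARNING_RATE == 1e-4  # "1.0e-4" parses as a YAML float

Note that the YAML fixes in the rename matter here: `{ 8, 16, 32, 64 }` is a
YAML flow mapping (a set-like structure), not the list the code expects, the
trailing commas after `3` and `1` would load as strings, and plain `1e-4`
is parsed as a string by YAML 1.1 loaders, whereas `1.0e-4` is a float.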