fix: small bugs in config handling

Split BATCH_SIZE into TRAIN_BATCH_SIZE and VAL_BATCH_SIZE, load the wandb
config explicitly from wandb.yaml, move learning_rate and batch_size out of
the UNetModule constructor and into wandb.config, and fix invalid YAML
values (trailing commas, a flow set instead of a list, exponent floats
written so that YAML parses them as strings).

Laurent Fainsin 2022-07-11 16:13:16 +02:00
parent 75a4907591
commit ed07e130e6
4 changed files with 15 additions and 16 deletions


@@ -31,19 +31,19 @@ class Spheres(pl.LightningDataModule):
             dataset,
             shuffle=True,
             prefetch_factor=8,
-            batch_size=wandb.config.BATCH_SIZE,
+            batch_size=wandb.config.TRAIN_BATCH_SIZE,
             num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,
         )

     def val_dataloader(self):
         dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
-        dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
+        # dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))

         return DataLoader(
             dataset,
             shuffle=False,
-            batch_size=8,
+            batch_size=wandb.config.VAL_BATCH_SIZE,
             prefetch_factor=8,
             num_workers=wandb.config.WORKERS,
             pin_memory=wandb.config.PIN_MEMORY,
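
A minimal sketch of how the split batch sizes are consumed (the inline config values here are hypothetical stand-ins for what wandb.init normally loads from wandb.yaml; offline mode is only so the snippet runs without a W&B account):

import wandb
from torch.utils.data import DataLoader

# Stand-in for the config normally loaded from wandb.yaml.
wandb.init(mode="offline", config={"TRAIN_BATCH_SIZE": 16, "VAL_BATCH_SIZE": 8})

# Training and validation now draw separate batch sizes from wandb.config;
# range(128) is a dummy dataset. With the Subset downsampling commented out,
# validation runs over the full validation set.
train_loader = DataLoader(range(128), shuffle=True, batch_size=wandb.config.TRAIN_BATCH_SIZE)
val_loader = DataLoader(range(128), shuffle=False, batch_size=wandb.config.VAL_BATCH_SIZE)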


@@ -13,9 +13,10 @@ if __name__ == "__main__":
     # setup logging
     logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

-    # setup wandb, config loaded from config-default.yaml
+    # setup wandb
     logger = WandbLogger(
         project="U-Net",
+        config="wandb.yaml",
         settings=wandb.Settings(
             code_dir="./src/",
         ),
@@ -28,8 +29,6 @@ if __name__ == "__main__":

     model = UNetModule(
         n_channels=wandb.config.N_CHANNELS,
         n_classes=wandb.config.N_CLASSES,
-        batch_size=wandb.config.BATCH_SIZE,
-        learning_rate=wandb.config.LEARNING_RATE,
         features=wandb.config.FEATURES,
     )
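
The new config="wandb.yaml" argument is forwarded by WandbLogger to wandb.init when the run is created; wandb.init accepts a string config as a path to a YAML file and loads its keys into wandb.config. A sketch of the equivalent direct call (offline mode added only so it runs without a logged-in account):

import wandb

# A string config is treated as a path to a YAML file whose keys
# become attributes of wandb.config for the rest of the run.
wandb.init(project="U-Net", config="wandb.yaml", mode="offline")
print(wandb.config.N_CHANNELS)  # 3, from wandb.yaml (last file in this commit)

This is also why the constructor arguments are removed: once the values live on wandb.config, UNetModule no longer needs batch_size or learning_rate passed in.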


@@ -16,14 +16,12 @@ class_labels = {


 class UNetModule(pl.LightningModule):
-    def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
+    def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
         super(UNetModule, self).__init__()

         # Hyperparameters
         self.n_channels = n_channels
         self.n_classes = n_classes
-        self.learning_rate = learning_rate
-        self.batch_size = batch_size

         # log hyperparameters
         self.save_hyperparameters()
@@ -111,7 +109,7 @@ class UNetModule(pl.LightningModule):
     def configure_optimizers(self):
         optimizer = torch.optim.RMSprop(
             self.parameters(),
-            lr=self.learning_rate,
+            lr=wandb.config.LEARNING_RATE,
             weight_decay=wandb.config.WEIGHT_DECAY,
             momentum=wandb.config.MOMENTUM,
         )
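
After this change the constructor carries only architecture hyperparameters, so save_hyperparameters() logs only those; training hyperparameters are read from wandb.config when the Trainer calls configure_optimizers. A sketch of the pattern with a hypothetical stand-in module (TinyModule is illustrative, not the real UNetModule):

import pytorch_lightning as pl
import torch
import wandb

class TinyModule(pl.LightningModule):
    def __init__(self, n_channels, n_classes):
        super().__init__()
        # Records the remaining __init__ args (n_channels, n_classes)
        # into self.hparams and the checkpoint.
        self.save_hyperparameters()
        self.conv = torch.nn.Conv2d(n_channels, n_classes, kernel_size=1)

    def configure_optimizers(self):
        # Optimizer hyperparameters now come from wandb.config, not self.
        return torch.optim.RMSprop(
            self.parameters(),
            lr=wandb.config.LEARNING_RATE,
            weight_decay=wandb.config.WEIGHT_DECAY,
            momentum=wandb.config.MOMENTUM,
        )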


@@ -6,11 +6,11 @@ DIR_SPHERE:
   value: "/home/lilian/data_disk/lfainsin/spheres+real/"

 FEATURES:
-  value: { 8, 16, 32, 64 }
+  value: [8, 16, 32, 64]
 N_CHANNELS:
-  value: 3,
+  value: 3
 N_CLASSES:
-  value: 1,
+  value: 1

 AMP:
   value: True
@@ -30,12 +30,14 @@ SPHERES:

 EPOCHS:
   value: 10

-BATCH_SIZE:
+TRAIN_BATCH_SIZE:
   value: 16
+VAL_BATCH_SIZE:
+  value: 8
 LEARNING_RATE:
-  value: 1e-4
+  value: 1.0e-4
 WEIGHT_DECAY:
-  value: 1e-8
+  value: 1.0e-8
 MOMENTUM:
   value: 0.9
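
These YAML edits fix type bugs, not just style. Under YAML 1.1 as implemented by PyYAML (assumed here to be the parser behind wandb's config loading), the old values do not parse to the intended types, as a quick check shows:

import yaml

print(yaml.safe_load("value: 1e-4"))           # {'value': '1e-4'} -- a string: YAML 1.1 floats need a dot in the mantissa
print(yaml.safe_load("value: 1.0e-4"))         # {'value': 0.0001} -- a float
print(yaml.safe_load("value: 3,"))             # {'value': '3,'} -- a string, because of the trailing comma
print(yaml.safe_load("value: { 8, 16, 32 }"))  # {'value': {8: None, 16: None, 32: None}} -- a flow mapping, not a list
print(yaml.safe_load("value: [8, 16, 32]"))    # {'value': [8, 16, 32]} -- a proper list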