fix: small bugs

Former-commit-id: 61b601b392e0199527e13ab22658b81a87efaa37 [formerly c64178fb281cd142cb4cbbf9b9d4d63dfc0099e5]
Former-commit-id: 74f22879d2630647e9d853db8d1a21a889cab651
This commit is contained in:
Laurent Fainsin 2022-07-11 16:13:16 +02:00
parent 75a4907591
commit ed07e130e6
4 changed files with 15 additions and 16 deletions

View file

@ -31,19 +31,19 @@ class Spheres(pl.LightningDataModule):
dataset,
shuffle=True,
prefetch_factor=8,
batch_size=wandb.config.BATCH_SIZE,
batch_size=wandb.config.TRAIN_BATCH_SIZE,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,
)
def val_dataloader(self):
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
# dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
return DataLoader(
dataset,
shuffle=False,
batch_size=8,
batch_size=wandb.config.VAL_BATCH_SIZE,
prefetch_factor=8,
num_workers=wandb.config.WORKERS,
pin_memory=wandb.config.PIN_MEMORY,

View file

@ -13,9 +13,10 @@ if __name__ == "__main__":
# setup logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
# setup wandb, config loaded from config-default.yaml
# setup wandb
logger = WandbLogger(
project="U-Net",
config="wandb.yaml",
settings=wandb.Settings(
code_dir="./src/",
),
@ -28,8 +29,6 @@ if __name__ == "__main__":
model = UNetModule(
n_channels=wandb.config.N_CHANNELS,
n_classes=wandb.config.N_CLASSES,
batch_size=wandb.config.BATCH_SIZE,
learning_rate=wandb.config.LEARNING_RATE,
features=wandb.config.FEATURES,
)

View file

@ -16,14 +16,12 @@ class_labels = {
class UNetModule(pl.LightningModule):
def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
super(UNetModule, self).__init__()
# Hyperparameters
self.n_channels = n_channels
self.n_classes = n_classes
self.learning_rate = learning_rate
self.batch_size = batch_size
# log hyperparameters
self.save_hyperparameters()
@ -111,7 +109,7 @@ class UNetModule(pl.LightningModule):
def configure_optimizers(self):
optimizer = torch.optim.RMSprop(
self.parameters(),
lr=self.learning_rate,
lr=wandb.config.LEARNING_RATE,
weight_decay=wandb.config.WEIGHT_DECAY,
momentum=wandb.config.MOMENTUM,
)

View file

@ -6,11 +6,11 @@ DIR_SPHERE:
value: "/home/lilian/data_disk/lfainsin/spheres+real/"
FEATURES:
value: { 8, 16, 32, 64 }
value: [8, 16, 32, 64]
N_CHANNELS:
value: 3,
value: 3
N_CLASSES:
value: 1,
value: 1
AMP:
value: True
@ -30,12 +30,14 @@ SPHERES:
EPOCHS:
value: 10
BATCH_SIZE:
TRAIN_BATCH_SIZE:
value: 16
VAL_BATCH_SIZE:
value: 8
LEARNING_RATE:
value: 1e-4
value: 1.0e-4
WEIGHT_DECAY:
value: 1e-8
value: 1.0e-8
MOMENTUM:
value: 0.9