fix: small bugs
Former-commit-id: 61b601b392e0199527e13ab22658b81a87efaa37 [formerly c64178fb281cd142cb4cbbf9b9d4d63dfc0099e5] Former-commit-id: 74f22879d2630647e9d853db8d1a21a889cab651
This commit is contained in:
parent
75a4907591
commit
ed07e130e6
|
@ -31,19 +31,19 @@ class Spheres(pl.LightningDataModule):
|
|||
dataset,
|
||||
shuffle=True,
|
||||
prefetch_factor=8,
|
||||
batch_size=wandb.config.BATCH_SIZE,
|
||||
batch_size=wandb.config.TRAIN_BATCH_SIZE,
|
||||
num_workers=wandb.config.WORKERS,
|
||||
pin_memory=wandb.config.PIN_MEMORY,
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
||||
dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
|
||||
# dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
shuffle=False,
|
||||
batch_size=8,
|
||||
batch_size=wandb.config.VAL_BATCH_SIZE,
|
||||
prefetch_factor=8,
|
||||
num_workers=wandb.config.WORKERS,
|
||||
pin_memory=wandb.config.PIN_MEMORY,
|
||||
|
|
|
@ -13,9 +13,10 @@ if __name__ == "__main__":
|
|||
# setup logging
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||
|
||||
# setup wandb, config loaded from config-default.yaml
|
||||
# setup wandb
|
||||
logger = WandbLogger(
|
||||
project="U-Net",
|
||||
config="wandb.yaml",
|
||||
settings=wandb.Settings(
|
||||
code_dir="./src/",
|
||||
),
|
||||
|
@ -28,8 +29,6 @@ if __name__ == "__main__":
|
|||
model = UNetModule(
|
||||
n_channels=wandb.config.N_CHANNELS,
|
||||
n_classes=wandb.config.N_CLASSES,
|
||||
batch_size=wandb.config.BATCH_SIZE,
|
||||
learning_rate=wandb.config.LEARNING_RATE,
|
||||
features=wandb.config.FEATURES,
|
||||
)
|
||||
|
||||
|
|
|
@ -16,14 +16,12 @@ class_labels = {
|
|||
|
||||
|
||||
class UNetModule(pl.LightningModule):
|
||||
def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
|
||||
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
|
||||
super(UNetModule, self).__init__()
|
||||
|
||||
# Hyperparameters
|
||||
self.n_channels = n_channels
|
||||
self.n_classes = n_classes
|
||||
self.learning_rate = learning_rate
|
||||
self.batch_size = batch_size
|
||||
|
||||
# log hyperparameters
|
||||
self.save_hyperparameters()
|
||||
|
@ -111,7 +109,7 @@ class UNetModule(pl.LightningModule):
|
|||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.RMSprop(
|
||||
self.parameters(),
|
||||
lr=self.learning_rate,
|
||||
lr=wandb.config.LEARNING_RATE,
|
||||
weight_decay=wandb.config.WEIGHT_DECAY,
|
||||
momentum=wandb.config.MOMENTUM,
|
||||
)
|
||||
|
|
|
@ -6,11 +6,11 @@ DIR_SPHERE:
|
|||
value: "/home/lilian/data_disk/lfainsin/spheres+real/"
|
||||
|
||||
FEATURES:
|
||||
value: { 8, 16, 32, 64 }
|
||||
value: [8, 16, 32, 64]
|
||||
N_CHANNELS:
|
||||
value: 3,
|
||||
value: 3
|
||||
N_CLASSES:
|
||||
value: 1,
|
||||
value: 1
|
||||
|
||||
AMP:
|
||||
value: True
|
||||
|
@ -30,12 +30,14 @@ SPHERES:
|
|||
|
||||
EPOCHS:
|
||||
value: 10
|
||||
BATCH_SIZE:
|
||||
TRAIN_BATCH_SIZE:
|
||||
value: 16
|
||||
VAL_BATCH_SIZE:
|
||||
value: 8
|
||||
|
||||
LEARNING_RATE:
|
||||
value: 1e-4
|
||||
value: 1.0e-4
|
||||
WEIGHT_DECAY:
|
||||
value: 1e-8
|
||||
value: 1.0e-8
|
||||
MOMENTUM:
|
||||
value: 0.9
|
Loading…
Reference in a new issue