fix: small bugs
Former-commit-id: 61b601b392e0199527e13ab22658b81a87efaa37 [formerly c64178fb281cd142cb4cbbf9b9d4d63dfc0099e5] Former-commit-id: 74f22879d2630647e9d853db8d1a21a889cab651
This commit is contained in:
parent
75a4907591
commit
ed07e130e6
|
@ -31,19 +31,19 @@ class Spheres(pl.LightningDataModule):
|
||||||
dataset,
|
dataset,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
prefetch_factor=8,
|
prefetch_factor=8,
|
||||||
batch_size=wandb.config.BATCH_SIZE,
|
batch_size=wandb.config.TRAIN_BATCH_SIZE,
|
||||||
num_workers=wandb.config.WORKERS,
|
num_workers=wandb.config.WORKERS,
|
||||||
pin_memory=wandb.config.PIN_MEMORY,
|
pin_memory=wandb.config.PIN_MEMORY,
|
||||||
)
|
)
|
||||||
|
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
dataset = LabeledDataset(image_dir=wandb.config.DIR_VALID_IMG)
|
||||||
dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
|
# dataset = Subset(dataset, list(range(0, len(dataset), len(dataset) // 100 + 1)))
|
||||||
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
batch_size=8,
|
batch_size=wandb.config.VAL_BATCH_SIZE,
|
||||||
prefetch_factor=8,
|
prefetch_factor=8,
|
||||||
num_workers=wandb.config.WORKERS,
|
num_workers=wandb.config.WORKERS,
|
||||||
pin_memory=wandb.config.PIN_MEMORY,
|
pin_memory=wandb.config.PIN_MEMORY,
|
||||||
|
|
|
@ -13,9 +13,10 @@ if __name__ == "__main__":
|
||||||
# setup logging
|
# setup logging
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
|
|
||||||
# setup wandb, config loaded from config-default.yaml
|
# setup wandb
|
||||||
logger = WandbLogger(
|
logger = WandbLogger(
|
||||||
project="U-Net",
|
project="U-Net",
|
||||||
|
config="wandb.yaml",
|
||||||
settings=wandb.Settings(
|
settings=wandb.Settings(
|
||||||
code_dir="./src/",
|
code_dir="./src/",
|
||||||
),
|
),
|
||||||
|
@ -28,8 +29,6 @@ if __name__ == "__main__":
|
||||||
model = UNetModule(
|
model = UNetModule(
|
||||||
n_channels=wandb.config.N_CHANNELS,
|
n_channels=wandb.config.N_CHANNELS,
|
||||||
n_classes=wandb.config.N_CLASSES,
|
n_classes=wandb.config.N_CLASSES,
|
||||||
batch_size=wandb.config.BATCH_SIZE,
|
|
||||||
learning_rate=wandb.config.LEARNING_RATE,
|
|
||||||
features=wandb.config.FEATURES,
|
features=wandb.config.FEATURES,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -16,14 +16,12 @@ class_labels = {
|
||||||
|
|
||||||
|
|
||||||
class UNetModule(pl.LightningModule):
|
class UNetModule(pl.LightningModule):
|
||||||
def __init__(self, n_channels, n_classes, learning_rate, batch_size, features=[64, 128, 256, 512]):
|
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
|
||||||
super(UNetModule, self).__init__()
|
super(UNetModule, self).__init__()
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
self.n_channels = n_channels
|
self.n_channels = n_channels
|
||||||
self.n_classes = n_classes
|
self.n_classes = n_classes
|
||||||
self.learning_rate = learning_rate
|
|
||||||
self.batch_size = batch_size
|
|
||||||
|
|
||||||
# log hyperparameters
|
# log hyperparameters
|
||||||
self.save_hyperparameters()
|
self.save_hyperparameters()
|
||||||
|
@ -111,7 +109,7 @@ class UNetModule(pl.LightningModule):
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.RMSprop(
|
optimizer = torch.optim.RMSprop(
|
||||||
self.parameters(),
|
self.parameters(),
|
||||||
lr=self.learning_rate,
|
lr=wandb.config.LEARNING_RATE,
|
||||||
weight_decay=wandb.config.WEIGHT_DECAY,
|
weight_decay=wandb.config.WEIGHT_DECAY,
|
||||||
momentum=wandb.config.MOMENTUM,
|
momentum=wandb.config.MOMENTUM,
|
||||||
)
|
)
|
||||||
|
|
|
@ -6,11 +6,11 @@ DIR_SPHERE:
|
||||||
value: "/home/lilian/data_disk/lfainsin/spheres+real/"
|
value: "/home/lilian/data_disk/lfainsin/spheres+real/"
|
||||||
|
|
||||||
FEATURES:
|
FEATURES:
|
||||||
value: { 8, 16, 32, 64 }
|
value: [8, 16, 32, 64]
|
||||||
N_CHANNELS:
|
N_CHANNELS:
|
||||||
value: 3,
|
value: 3
|
||||||
N_CLASSES:
|
N_CLASSES:
|
||||||
value: 1,
|
value: 1
|
||||||
|
|
||||||
AMP:
|
AMP:
|
||||||
value: True
|
value: True
|
||||||
|
@ -30,12 +30,14 @@ SPHERES:
|
||||||
|
|
||||||
EPOCHS:
|
EPOCHS:
|
||||||
value: 10
|
value: 10
|
||||||
BATCH_SIZE:
|
TRAIN_BATCH_SIZE:
|
||||||
value: 16
|
value: 16
|
||||||
|
VAL_BATCH_SIZE:
|
||||||
|
value: 8
|
||||||
|
|
||||||
LEARNING_RATE:
|
LEARNING_RATE:
|
||||||
value: 1e-4
|
value: 1.0e-4
|
||||||
WEIGHT_DECAY:
|
WEIGHT_DECAY:
|
||||||
value: 1e-8
|
value: 1.0e-8
|
||||||
MOMENTUM:
|
MOMENTUM:
|
||||||
value: 0.9
|
value: 0.9
|
Loading…
Reference in a new issue