diff --git a/src/train.py b/src/train.py index a1e54c8..84ccb87 100644 --- a/src/train.py +++ b/src/train.py @@ -10,9 +10,8 @@ from unet import UNet CONFIG = { "DIR_TRAIN_IMG": "/home/lilian/data_disk/lfainsin/train/", - "DIR_VALID_IMG": "/home/lilian/data_disk/lfainsin/val/", - "DIR_TEST_IMG": "/home/lilian/data_disk/lfainsin/test/", - "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/realspheres/", + "DIR_VALID_IMG": "//home/lilian/data_disk/lfainsin/test_split/", + "DIR_SPHERE": "/home/lilian/data_disk/lfainsin/spheres+real_split/", "FEATURES": [8, 16, 32, 64], "N_CHANNELS": 3, "N_CLASSES": 1, diff --git a/src/unet/model.py b/src/unet/model.py index c244735..93872f5 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -93,7 +93,7 @@ class UNet(pl.LightningModule): ) def val_dataloader(self): - ds_valid = SphereDataset(image_dir=wandb.config.DIR_TEST_IMG) + ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG) return DataLoader( ds_valid,