diff --git a/src/train.py b/src/train.py index a48164e..b7475ce 100644 --- a/src/train.py +++ b/src/train.py @@ -26,8 +26,8 @@ def main(): wandb.init( project="U-Net", config=dict( - DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017", - DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/", + DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/", + DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/", DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/", DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/", FEATURES=[64, 128, 256, 512], @@ -87,11 +87,10 @@ def main(): # 2. Create datasets ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid) - # ds_train_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/train/", transform=tf_train) - # ds_valid_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/testval/", transform=tf_valid) - # ds_train = torch.utils.data.ChainDataset([ds_train_coco, ds_train_bg20k]) - # ds_valid = torch.utils.data.ChainDataset([ds_valid_coco, ds_valid_bg20k]) # TODO: modifier la classe SphereDataset pour prendre plusieurs dossiers + # 2.5 Create subset, if uncommented + ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000))) + ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100))) # 3. Create data loaders train_loader = DataLoader( diff --git a/src/utils/dataset.py b/src/utils/dataset.py index a798731..5f47bf9 100644 --- a/src/utils/dataset.py +++ b/src/utils/dataset.py @@ -1,4 +1,4 @@ -import os +from pathlib import Path import numpy as np from PIL import Image @@ -7,16 +7,14 @@ from torch.utils.data import Dataset class SphereDataset(Dataset): def __init__(self, image_dir, transform=None): - self.image_dir = image_dir + self.images = list(Path(image_dir).glob("**/*.jpg")) self.transform = transform - self.images = os.listdir(image_dir) def __len__(self): return len(self.images) def __getitem__(self, index): - img_path = os.path.join(self.image_dir, self.images[index]) - image = np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8) + image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8) mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)