feat: combine multiple datasets

Former-commit-id: bec67c2948227cdbfe05bb8f69e8083f7591e965 [formerly b24d77ef7c42456893a0b797e0a6c925a2402d08]
Former-commit-id: 391e17bf37e44151215179cde2d157f8e5bb5182
This commit is contained in:
Laurent Fainsin 2022-07-01 12:00:25 +02:00
parent beafa768f7
commit d839aec1af
2 changed files with 8 additions and 11 deletions

View file

@ -26,8 +26,8 @@ def main():
wandb.init(
project="U-Net",
config=dict(
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017",
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/",
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
FEATURES=[64, 128, 256, 512],
@ -87,11 +87,10 @@ def main():
# 2. Create datasets
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
# ds_train_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/train/", transform=tf_train)
# ds_valid_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/testval/", transform=tf_valid)
# ds_train = torch.utils.data.ChainDataset([ds_train_coco, ds_train_bg20k])
# ds_valid = torch.utils.data.ChainDataset([ds_valid_coco, ds_valid_bg20k]) # TODO: modifier la classe SphereDataset pour prendre plusieurs dossiers
# 2.5 Create subset, if uncommented
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
# 3. Create data loaders
train_loader = DataLoader(

View file

@ -1,4 +1,4 @@
import os
from pathlib import Path
import numpy as np
from PIL import Image
@ -7,16 +7,14 @@ from torch.utils.data import Dataset
class SphereDataset(Dataset):
def __init__(self, image_dir, transform=None):
self.image_dir = image_dir
self.images = list(Path(image_dir).glob("**/*.jpg"))
self.transform = transform
self.images = os.listdir(image_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, index):
img_path = os.path.join(self.image_dir, self.images[index])
image = np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8)
image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)