feat: combine multiple datasets
Former-commit-id: bec67c2948227cdbfe05bb8f69e8083f7591e965 [formerly b24d77ef7c42456893a0b797e0a6c925a2402d08] Former-commit-id: 391e17bf37e44151215179cde2d157f8e5bb5182
This commit is contained in:
parent
beafa768f7
commit
d839aec1af
11
src/train.py
11
src/train.py
|
@ -26,8 +26,8 @@ def main():
|
|||
wandb.init(
|
||||
project="U-Net",
|
||||
config=dict(
|
||||
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017",
|
||||
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
|
||||
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/train/",
|
||||
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/val/",
|
||||
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
||||
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
||||
FEATURES=[64, 128, 256, 512],
|
||||
|
@ -87,11 +87,10 @@ def main():
|
|||
# 2. Create datasets
|
||||
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
||||
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
|
||||
# ds_train_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/train/", transform=tf_train)
|
||||
# ds_valid_bg20k = SphereDataset(image_dir="/home/lilian/data_disk/lfainsin/BG-20k/testval/", transform=tf_valid)
|
||||
|
||||
# ds_train = torch.utils.data.ChainDataset([ds_train_coco, ds_train_bg20k])
|
||||
# ds_valid = torch.utils.data.ChainDataset([ds_valid_coco, ds_valid_bg20k]) # TODO: modifier la classe SphereDataset pour prendre plusieurs dossiers
|
||||
# 2.5 Create subset, if uncommented
|
||||
ds_train = torch.utils.data.Subset(ds_train, list(range(0, len(ds_train), len(ds_train) // 5000)))
|
||||
ds_valid = torch.utils.data.Subset(ds_valid, list(range(0, len(ds_valid), len(ds_valid) // 100)))
|
||||
|
||||
# 3. Create data loaders
|
||||
train_loader = DataLoader(
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
@ -7,16 +7,14 @@ from torch.utils.data import Dataset
|
|||
|
||||
class SphereDataset(Dataset):
|
||||
def __init__(self, image_dir, transform=None):
|
||||
self.image_dir = image_dir
|
||||
self.images = list(Path(image_dir).glob("**/*.jpg"))
|
||||
self.transform = transform
|
||||
self.images = os.listdir(image_dir)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.images)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_path = os.path.join(self.image_dir, self.images[index])
|
||||
image = np.array(Image.open(img_path).convert("RGB"), dtype=np.uint8)
|
||||
image = np.array(Image.open(self.images[index]).convert("RGB"), dtype=np.uint8)
|
||||
|
||||
mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||
|
||||
|
|
Loading…
Reference in a new issue