Carvana dataset loader
Former-commit-id: 4ad8323b2e54c7bbcf8968ca7fc0f2e3ddd87689
This commit is contained in:
parent
04c2cb1ed4
commit
8780e424b4
|
@ -9,10 +9,11 @@ from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
class BasicDataset(Dataset):
|
class BasicDataset(Dataset):
|
||||||
def __init__(self, imgs_dir, masks_dir, scale=1):
|
def __init__(self, imgs_dir, masks_dir, scale=1, mask_suffix=''):
|
||||||
self.imgs_dir = imgs_dir
|
self.imgs_dir = imgs_dir
|
||||||
self.masks_dir = masks_dir
|
self.masks_dir = masks_dir
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
self.mask_suffix = mask_suffix
|
||||||
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
|
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
|
||||||
|
|
||||||
self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
|
self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
|
||||||
|
@ -43,7 +44,7 @@ class BasicDataset(Dataset):
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
idx = self.ids[i]
|
idx = self.ids[i]
|
||||||
mask_file = glob(self.masks_dir + idx + '.*')
|
mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*')
|
||||||
img_file = glob(self.imgs_dir + idx + '.*')
|
img_file = glob(self.imgs_dir + idx + '.*')
|
||||||
|
|
||||||
assert len(mask_file) == 1, \
|
assert len(mask_file) == 1, \
|
||||||
|
@ -63,3 +64,8 @@ class BasicDataset(Dataset):
|
||||||
'image': torch.from_numpy(img).type(torch.FloatTensor),
|
'image': torch.from_numpy(img).type(torch.FloatTensor),
|
||||||
'mask': torch.from_numpy(mask).type(torch.FloatTensor)
|
'mask': torch.from_numpy(mask).type(torch.FloatTensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CarvanaDataset(BasicDataset):
|
||||||
|
def __init__(self, imgs_dir, masks_dir, scale=1):
|
||||||
|
super().__init__(imgs_dir, masks_dir, scale, mask_suffix='_mask')
|
||||||
|
|
Loading…
Reference in a new issue