mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-08 22:42:02 +00:00
48 lines
1.2 KiB
Python
48 lines
1.2 KiB
Python
import os
|
|
import random
|
|
import numpy as np
|
|
from PIL import Image
|
|
from functools import partial
|
|
from utils import resize_and_crop, get_square
|
|
|
|
|
|
def get_ids(dir):
    """Yield the id (filename without its extension) of each entry in dir.

    The result is a lazy, one-shot generator, not a list: wrap it in
    ``list(...)`` if it has to be iterated more than once.

    The extension is removed with ``os.path.splitext`` so that extensions
    of any length (``.jpg``, ``.jpeg``, ``.png`` ...) are handled, rather
    than blindly dropping the last four characters of the name.
    """
    return (os.path.splitext(name)[0] for name in os.listdir(dir))
|
|
|
|
def split_ids(ids, n=2):
    """Pair every id with each position 0..n-1, yielding (id, k) tuples.

    Tuples are produced position by position: first every id with
    position 0, then every id with position 1, and so on.

    ids may be any iterable, including a one-shot iterator or generator.
    n is the number of positions to create per id (default 2).
    """
    # The expression below walks the ids once per position. A one-shot
    # iterator would be exhausted after position 0 and silently drop all
    # other positions, so snapshot the ids first.
    id_pool = tuple(ids)
    return ((item, pos) for pos in range(n) for item in id_pool)
|
|
|
|
def shuffle_ids(ids):
    """Return a new list containing the given ids in random order.

    The input iterable is consumed into a fresh list, so a list passed in
    by the caller is left untouched.
    """
    shuffled = [*ids]
    random.shuffle(shuffled)
    return shuffled
|
|
|
|
def to_cropped_imgs(ids, dir, suffix):
    """Lazily yield one cropped square per (id, pos) tuple in ids.

    For each tuple, the file named ``id + suffix`` inside dir is opened
    with PIL, passed through ``resize_and_crop``, and the result of
    ``get_square(im, pos)`` is yielded (pos selects which square is taken,
    left or right).

    ids    -- iterable of (id, pos) tuples
    dir    -- directory holding the image files, with or without a
              trailing path separator
    suffix -- text appended to each id to form the filename, e.g. '.jpg'
    """
    for img_id, pos in ids:
        # os.path.join inserts a separator only when dir lacks one, so the
        # caller is not forced to pass a trailing slash.
        img_path = os.path.join(dir, img_id + suffix)
        im = resize_and_crop(Image.open(img_path))
        yield get_square(im, pos)
|
|
|
|
|
|
|
|
def get_imgs_and_masks(dir_img='data/train/', dir_mask='data/train_masks/'):
    """Return a lazy iterator of (img, mask) couples in shuffled order.

    dir_img  -- directory of the training images ('<id>.jpg'); it is also
                the directory the ids are listed from
    dir_mask -- directory of the matching masks ('<id>_mask.gif')

    Both default to the locations that used to be hard-coded, so calling
    the function without arguments behaves as before.
    """
    # get_ids hands back a one-shot generator, while split_ids walks its
    # input once per position. Materialise the ids so every position sees
    # every id instead of only position 0.
    ids = list(get_ids(dir_img))
    ids = split_ids(ids)
    # shuffle_ids returns a list, which is what allows the same (id, pos)
    # sequence to be walked twice below, once for images, once for masks.
    ids = shuffle_ids(ids)

    imgs = to_cropped_imgs(ids, dir_img, '.jpg')

    # need to transform from HWC to CHW
    imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs)

    masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')

    # zip advances both generators in lockstep, so each image is paired
    # with the mask built from the same (id, pos) tuple.
    return zip(imgs_switched, masks)
|