# ---- load.py (new file) ----------------------------------------------------
"""Helpers that pair training images with their masks, one square crop at a time."""
import os
import random
from functools import partial

import numpy as np


def get_ids(dir):
    """Return a lazy iterable of the ids found in directory ``dir``.

    An id is a file name with its last four characters removed (i.e. a dot
    followed by a three-letter extension such as ``.jpg``).  The result is a
    one-shot generator; wrap it in ``list`` if it must be traversed twice.
    """
    return (f[:-4] for f in os.listdir(dir))


def split_ids(ids, n=2):
    """Yield ``n`` tuples ``(id, k)``, ``k`` in ``range(n)``, for every id.

    Order is all ``(id, 0)`` first, then all ``(id, 1)``, and so on.

    ``ids`` may be any iterable, including a one-shot generator such as the
    one returned by ``get_ids``.
    """
    # Materialise once: the inner loop below walks ``ids`` n times, and a
    # generator would be exhausted after the first pass, silently dropping
    # every tuple with k >= 1.
    all_ids = tuple(ids)
    return ((item_id, k) for k in range(n) for item_id in all_ids)


def shuffle_ids(ids):
    """Return a new, randomly shuffled list of the ids."""
    lst = list(ids)
    random.shuffle(lst)
    return lst


def to_cropped_imgs(ids, dir, suffix):
    """Yield the square crop selected by each ``(id, pos)`` tuple in ``ids``.

    The image is read from ``dir + id + suffix``, passed through
    ``utils.resize_and_crop`` and then through ``get_square(im, pos)``.
    """
    # Imported lazily so the pure-Python helpers in this module stay usable
    # when Pillow / the project's utils module are not importable.
    from PIL import Image
    from utils import resize_and_crop

    for img_id, pos in ids:
        # The context manager closes the file handle as soon as the pixel
        # data has been copied into an ndarray by get_square.
        with Image.open(dir + img_id + suffix) as im:
            square = get_square(resize_and_crop(im), pos)
        yield square


def get_imgs_and_masks(dir_img='data/train/', dir_mask='data/train_masks/'):
    """Return a lazy iterable of ``(img, mask)`` couples.

    Ids are listed from ``dir_img``, split into crop positions and shuffled.
    Images are read as ``<id>.jpg`` from ``dir_img`` and transposed with axes
    ``[2, 0, 1]`` (HWC -> CHW); masks are read as ``<id>_mask.gif`` from
    ``dir_mask`` and returned as produced by ``to_cropped_imgs``.
    """
    ids = get_ids(dir_img)
    ids = split_ids(ids)
    # shuffle_ids returns a list, so it can safely be traversed twice below
    # (once for the images, once for the masks) in the same order.
    ids = shuffle_ids(ids)

    imgs = to_cropped_imgs(ids, dir_img, '.jpg')

    # need to transform from HWC to CHW
    imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs)

    masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')

    return zip(imgs_switched, masks)


# ---- utils.py ---------------------------------------------------------------
# NOTE: resize_and_crop(pilimg, scale=0.5, final_height=640) is also defined in
# utils.py; its body lies beyond this excerpt and is left untouched.

def get_square(img, pos):
    """Extract a left or a right square from an image.

    ``img`` is converted with ``np.array`` and is expected to have shape
    (H, W) or (H, W, C).  ``pos == 0`` selects the leftmost H columns; any
    other value selects the rightmost H columns.
    """
    img = np.array(img)
    h = img.shape[0]

    if pos == 0:
        return img[:, :h]
    return img[:, -h:]