Mirror of https://github.com/Laurent2916/REVA-QCAV.git (synced 2024-11-08 22:42:02 +00:00) — 56 lines, 1.1 KiB, Python
import PIL
|
|
import numpy as np
|
|
import random
|
|
|
|
|
|
def get_square(img, pos):
    """Return a square-ish crop taken from the left or right edge of *img*.

    *img* is converted with ``np.array``; its first axis is treated as the
    height and the second as the width (e.g. ``(H, W, C)`` for an RGB PIL
    image). The crop spans all rows and the first (``pos == 0``) or last
    (any other ``pos``) ``H`` columns. If the width is smaller than the
    height, slicing simply returns every column.
    """
    arr = np.array(img)
    side = arr.shape[0]
    # Left edge for pos == 0, right edge otherwise.
    columns = slice(None, side) if pos == 0 else slice(-side, None)
    return arr[:, columns]
|
|
|
|
|
|
def resize_and_crop(pilimg, scale=0.5, final_height=None):
    """Scale an image and optionally crop it vertically around its center.

    Args:
        pilimg: Source image; only its ``size``, ``resize`` and ``crop``
            members are used (a PIL image in this project).
        scale: Factor applied to both width and height; the scaled sizes
            are truncated with ``int``.
        final_height: Desired height after scaling. The crop window is
            centred vertically in the resized image. A falsy value
            (``None`` or ``0``) disables cropping.

    Returns:
        The resized image, cropped to exactly ``final_height`` rows when a
        target height is given. Its width is ``int(w * scale)``.
    """
    w, h = pilimg.size
    new_w = int(w * scale)
    new_h = int(h * scale)

    if not final_height:
        final_height = new_h

    # Derive the bottom edge from the top instead of trimming `diff // 2`
    # from both ends: with an odd surplus the symmetric trim would leave
    # one extra row. (If final_height exceeds new_h the box extends past
    # the image; how that area is filled is up to the image's crop.)
    top = (new_h - final_height) // 2

    img = pilimg.resize((new_w, new_h))
    return img.crop((0, top, new_w, top + final_height))
|
|
|
|
|
|
def batch(iterable, batch_size):
    """Yield consecutive lists of ``batch_size`` items drawn from *iterable*.

    A trailing partial list is yielded if the item count is not a multiple
    of ``batch_size``. Each yielded list is a fresh object.
    """
    pending = []
    for count, element in enumerate(iterable, start=1):
        pending.append(element)
        if count % batch_size == 0:
            yield pending
            pending = []

    # Flush whatever did not fill a complete batch.
    if pending:
        yield pending
|
|
|
|
|
|
def split_train_val(dataset, val_percent=0.05, seed=42):
    """Shuffle *dataset* reproducibly and split it into train/val lists.

    Args:
        dataset: Any iterable; it is materialised into a list first.
        val_percent: Fraction of items placed in the validation split. The
            count is ``int(len(dataset) * val_percent)``, clamped to
            ``[0, len(dataset)]``.
        seed: Value passed to ``random.seed`` before shuffling. Note that
            this reseeds the module-level ``random`` generator, which also
            affects later ``random`` calls elsewhere in the process.

    Returns:
        ``{'train': [...], 'val': [...]}``, where ``'val'`` holds the last
        ``n`` items of the shuffled list and ``'train'`` the rest.
    """
    items = list(dataset)
    length = len(items)
    n_val = min(max(int(length * val_percent), 0), length)

    random.seed(seed)
    random.shuffle(items)

    # Split at an explicit index: `items[:-n]` / `items[-n:]` invert the
    # split when n == 0 (train empty, validation gets everything).
    cut = length - n_val
    return {'train': items[:cut], 'val': items[cut:]}
|
|
|
|
|
|
def normalize(x):
    """Divide *x* by 255 and return the result.

    Works for scalars or anything supporting true division (e.g. NumPy
    arrays). Presumably used to map 8-bit pixel values into ``[0, 1]`` —
    confirm against callers.
    """
    max_intensity = 255
    scaled = x / max_intensity
    return scaled
|