projet-long/utils.py

55 lines
1.1 KiB
Python

import PIL
import numpy as np
import random
def get_square(img, pos):
    """Return the left-most or right-most square crop of an image.

    ``img`` is converted with ``np.array``. Axis 0 is treated as height and
    axis 1 as width; for a PIL RGB image that gives shape ``(H, W, C)``.
    The square's side equals the height. ``pos == 0`` selects the left
    square; any other value selects the right square.

    Note: if the image is taller than it is wide, the full width is returned,
    so the result is not square.
    """
    pixels = np.array(img)
    side = pixels.shape[0]
    if pos != 0:
        return pixels[:, -side:]
    return pixels[:, :side]
def resize_and_crop(pilimg, scale=0.2, final_height=None):
    """Resize an image by ``scale`` and optionally crop it to a fixed height.

    Args:
        pilimg: Image object exposing ``size`` as ``(width, height)`` and
            ``resize`` / ``crop`` methods (e.g. a PIL image).
        scale: Factor applied to both width and height. The scaled sizes are
            truncated to ints.
        final_height: If truthy, the resized image is cropped vertically to
            exactly this many rows. Roughly equal amounts are removed from
            the top and bottom. If ``None`` (or 0), no height crop is applied.

    Returns:
        The resized (and possibly cropped) image, as returned by ``crop``.
    """
    width = pilimg.size[0]
    height = pilimg.size[1]
    new_w = int(width * scale)
    new_h = int(height * scale)
    img = pilimg.resize((new_w, new_h))

    if not final_height:
        top, bottom = 0, new_h
    else:
        # Anchor the bottom edge at ``top + final_height`` instead of
        # ``new_h - diff // 2``. With an odd ``diff`` the latter kept one
        # extra row, so the output was ``final_height + 1`` rows tall.
        top = (new_h - final_height) // 2
        bottom = top + final_height
    return img.crop((0, top, new_w, bottom))
def batch(iterable, batch_size):
    """Lazily group ``iterable`` into lists of ``batch_size`` items.

    For a positive ``batch_size``, every yielded list holds exactly
    ``batch_size`` items except possibly the last, which holds the leftovers.
    Nothing is yielded for an empty iterable.
    """
    chunk = []
    seen = 0
    for item in iterable:
        chunk.append(item)
        seen += 1
        if seen % batch_size:
            # Current chunk is not full yet.
            continue
        yield chunk
        chunk = []
    if chunk:
        yield chunk
def split_train_val(dataset, val_percent=0.05):
    """Shuffle ``dataset`` and split it into training and validation lists.

    Args:
        dataset: Any iterable; it is materialised into a list first.
        val_percent: Fraction of items to put in the validation split. The
            count is ``int(len * val_percent)`` (truncated), clamped to
            ``[0, len]``.

    Returns:
        ``{'train': [...], 'val': [...]}``. Together the two lists hold every
        item exactly once. Shuffling uses the global ``random`` state.
    """
    dataset = list(dataset)
    length = len(dataset)
    n_val = min(max(int(length * val_percent), 0), length)
    random.shuffle(dataset)
    # Split on an explicit index. The former ``dataset[:-n]`` / ``dataset[-n:]``
    # form put everything in 'val' and nothing in 'train' when n == 0
    # (because ``-0 == 0``), e.g. for datasets with fewer than 20 items at the
    # default 5 %.
    split = length - n_val
    return {'train': dataset[:split], 'val': dataset[split:]}
def normalize(x):
    """Return ``x`` divided by 255 (e.g. 8-bit pixel values scaled to [0, 1]).

    Works on anything that supports true division by an int, such as Python
    numbers or NumPy arrays.
    """
    max_intensity = 255
    return x / max_intensity