feat: show augmentation

This commit is contained in:
gdamms 2023-02-10 13:37:42 +01:00
parent 42ac3e0576
commit c6942d325b
2 changed files with 33 additions and 15 deletions

View file

@ -14,28 +14,23 @@ train_ds = splits["train"]
val_ds = splits["test"] val_ds = splits["test"]
test_ds = dataset["test"] test_ds = dataset["test"]
labels = train_ds.features['label'].names labels = ["Not AI", "AI"]
id2label = {k: v for k, v in enumerate(labels)} id2label = {k: v for k, v in enumerate(labels)}
label2id = {v: k for k, v in enumerate(labels)} label2id = {v: k for k, v in enumerate(labels)}
if __name__ == '__main__': if __name__ == '__main__':
import cv2 import matplotlib.pyplot as plt
import numpy as np
print(f"labels:\n {labels}") print(f"labels:\n {labels}")
print(f"label-id correspondances:\n {label2id}\n {id2label}") print(f"label-id correspondances:\n {label2id}\n {id2label}")
idx = 0 idx = 0
label = id2label[dataset['train'][idx]['label']] label = id2label[dataset['train'][idx]['label']]
image = cv2.cvtColor( plt.subplot(1, 2, 1)
np.array(dataset['train'][idx]['image']), plt.imshow(dataset['train'][idx]['image'])
cv2.COLOR_BGR2RGB) plt.title(f"Label: {label}")
cv2.namedWindow(label, cv2.WINDOW_NORMAL) plt.subplot(1, 2, 2)
cv2.imshow(label, image) plt.imshow(dataset['test'][idx]['image'])
image = cv2.cvtColor( plt.title("Test")
np.array(dataset['test'][idx]['image']), plt.show()
cv2.COLOR_BGR2RGB)
cv2.namedWindow("Test", cv2.WINDOW_NORMAL)
cv2.imshow("Test", image)
cv2.waitKey(0)

View file

@ -6,6 +6,7 @@ from torchvision.transforms import (
RandomVerticalFlip, RandomVerticalFlip,
ToTensor, ToTensor,
) )
from imgaug.augmenters import JpegCompression
# get feature extractor (to normalize images) # get feature extractor (to normalize images)
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
@ -13,9 +14,10 @@ normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# define train transform # define train transform
_train_transforms = Compose( _train_transforms = Compose(
[ [
# AugMix(), AugMix(),
RandomHorizontalFlip(), RandomHorizontalFlip(),
RandomVerticalFlip(), RandomVerticalFlip(),
# lambda img : JpegCompression(compression=(0, 30))(image=img),
ToTensor(), ToTensor(),
normalize, normalize,
] ]
@ -42,3 +44,24 @@ def val_transforms(examples):
"""Transforms for validation.""" """Transforms for validation."""
examples["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in examples["image"]] examples["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in examples["image"]]
return examples return examples
if __name__ == '__main__':
from dataset import train_ds
import matplotlib.pyplot as plt
import numpy as np
idx = 0
img = train_ds[idx]['image']
plt.subplot(1, 2, 1)
plt.imshow(img)
plt.title("Original")
img = _train_transforms(img.convert("RGB"))
img = np.array(img.permute(1, 2, 0))
img -= img.min()
img /= img.max()
plt.subplot(1, 2, 2)
plt.imshow(img)
plt.title("Augmented")
plt.show()