diff --git a/src/dataset.py b/src/dataset.py index 2d4a041..746c0dd 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -14,28 +14,23 @@ train_ds = splits["train"] val_ds = splits["test"] test_ds = dataset["test"] -labels = train_ds.features['label'].names +labels = ["Not AI", "AI"] id2label = {k: v for k, v in enumerate(labels)} label2id = {v: k for k, v in enumerate(labels)} if __name__ == '__main__': - import cv2 - import numpy as np + import matplotlib.pyplot as plt print(f"labels:\n {labels}") print(f"label-id correspondances:\n {label2id}\n {id2label}") idx = 0 label = id2label[dataset['train'][idx]['label']] - image = cv2.cvtColor( - np.array(dataset['train'][idx]['image']), - cv2.COLOR_BGR2RGB) - cv2.namedWindow(label, cv2.WINDOW_NORMAL) - cv2.imshow(label, image) - image = cv2.cvtColor( - np.array(dataset['test'][idx]['image']), - cv2.COLOR_BGR2RGB) - cv2.namedWindow("Test", cv2.WINDOW_NORMAL) - cv2.imshow("Test", image) - cv2.waitKey(0) + plt.subplot(1, 2, 1) + plt.imshow(dataset['train'][idx]['image']) + plt.title(f"Label: {label}") + plt.subplot(1, 2, 2) + plt.imshow(dataset['test'][idx]['image']) + plt.title("Test") + plt.show() diff --git a/src/transform.py b/src/transform.py index 17e0b8b..9fd267c 100644 --- a/src/transform.py +++ b/src/transform.py @@ -6,6 +6,7 @@ from torchvision.transforms import ( RandomVerticalFlip, ToTensor, ) +from imgaug.augmenters import JpegCompression # get feature extractor (to normalize images) normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) @@ -13,9 +14,10 @@ normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # define train transform _train_transforms = Compose( [ - # AugMix(), + AugMix(), RandomHorizontalFlip(), RandomVerticalFlip(), + # lambda img : JpegCompression(compression=(0, 30))(image=img), ToTensor(), normalize, ] @@ -42,3 +44,24 @@ def val_transforms(examples): """Transforms for validation.""" examples["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in examples["image"]] return examples + + +if __name__ == '__main__': + + from dataset import train_ds + import matplotlib.pyplot as plt + import numpy as np + + idx = 0 + img = train_ds[idx]['image'] + plt.subplot(1, 2, 1) + plt.imshow(img) + plt.title("Original") + img = _train_transforms(img.convert("RGB")) + img = np.array(img.permute(1, 2, 0)) + img -= img.min() + img /= img.max() + plt.subplot(1, 2, 2) + plt.imshow(img) + plt.title("Augmented") + plt.show() \ No newline at end of file