feat: show augmentation
This commit is contained in:
parent
42ac3e0576
commit
c6942d325b
|
@ -14,28 +14,23 @@ train_ds = splits["train"]
|
||||||
val_ds = splits["test"]
|
val_ds = splits["test"]
|
||||||
test_ds = dataset["test"]
|
test_ds = dataset["test"]
|
||||||
|
|
||||||
labels = train_ds.features['label'].names
|
labels = ["Not AI", "AI"]
|
||||||
id2label = {k: v for k, v in enumerate(labels)}
|
id2label = {k: v for k, v in enumerate(labels)}
|
||||||
label2id = {v: k for k, v in enumerate(labels)}
|
label2id = {v: k for k, v in enumerate(labels)}
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
import cv2
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
print(f"labels:\n {labels}")
|
print(f"labels:\n {labels}")
|
||||||
print(f"label-id correspondances:\n {label2id}\n {id2label}")
|
print(f"label-id correspondances:\n {label2id}\n {id2label}")
|
||||||
|
|
||||||
idx = 0
|
idx = 0
|
||||||
label = id2label[dataset['train'][idx]['label']]
|
label = id2label[dataset['train'][idx]['label']]
|
||||||
image = cv2.cvtColor(
|
plt.subplot(1, 2, 1)
|
||||||
np.array(dataset['train'][idx]['image']),
|
plt.imshow(dataset['train'][idx]['image'])
|
||||||
cv2.COLOR_BGR2RGB)
|
plt.title(f"Label: {label}")
|
||||||
cv2.namedWindow(label, cv2.WINDOW_NORMAL)
|
plt.subplot(1, 2, 2)
|
||||||
cv2.imshow(label, image)
|
plt.imshow(dataset['test'][idx]['image'])
|
||||||
image = cv2.cvtColor(
|
plt.title("Test")
|
||||||
np.array(dataset['test'][idx]['image']),
|
plt.show()
|
||||||
cv2.COLOR_BGR2RGB)
|
|
||||||
cv2.namedWindow("Test", cv2.WINDOW_NORMAL)
|
|
||||||
cv2.imshow("Test", image)
|
|
||||||
cv2.waitKey(0)
|
|
||||||
|
|
|
@ -6,6 +6,7 @@ from torchvision.transforms import (
|
||||||
RandomVerticalFlip,
|
RandomVerticalFlip,
|
||||||
ToTensor,
|
ToTensor,
|
||||||
)
|
)
|
||||||
|
from imgaug.augmenters import JpegCompression
|
||||||
|
|
||||||
# get feature extractor (to normalize images)
|
# get feature extractor (to normalize images)
|
||||||
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
@ -13,9 +14,10 @@ normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
# define train transform
|
# define train transform
|
||||||
_train_transforms = Compose(
|
_train_transforms = Compose(
|
||||||
[
|
[
|
||||||
# AugMix(),
|
AugMix(),
|
||||||
RandomHorizontalFlip(),
|
RandomHorizontalFlip(),
|
||||||
RandomVerticalFlip(),
|
RandomVerticalFlip(),
|
||||||
|
# lambda img : JpegCompression(compression=(0, 30))(image=img),
|
||||||
ToTensor(),
|
ToTensor(),
|
||||||
normalize,
|
normalize,
|
||||||
]
|
]
|
||||||
|
@ -42,3 +44,24 @@ def val_transforms(examples):
|
||||||
"""Transforms for validation."""
|
"""Transforms for validation."""
|
||||||
examples["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in examples["image"]]
|
examples["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in examples["image"]]
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
from dataset import train_ds
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
img = train_ds[idx]['image']
|
||||||
|
plt.subplot(1, 2, 1)
|
||||||
|
plt.imshow(img)
|
||||||
|
plt.title("Original")
|
||||||
|
img = _train_transforms(img.convert("RGB"))
|
||||||
|
img = np.array(img.permute(1, 2, 0))
|
||||||
|
img -= img.min()
|
||||||
|
img /= img.max()
|
||||||
|
plt.subplot(1, 2, 2)
|
||||||
|
plt.imshow(img)
|
||||||
|
plt.title("Augmented")
|
||||||
|
plt.show()
|
Reference in a new issue