feat: JpegAug

This commit is contained in:
gdamms 2023-02-10 13:42:50 +01:00
parent c6942d325b
commit 89f37e6bbf

View file

@ -1,3 +1,4 @@
import numpy as np
from torchvision.transforms import ( from torchvision.transforms import (
AugMix, AugMix,
Compose, Compose,
@ -17,7 +18,7 @@ _train_transforms = Compose(
AugMix(), AugMix(),
RandomHorizontalFlip(), RandomHorizontalFlip(),
RandomVerticalFlip(), RandomVerticalFlip(),
# lambda img : JpegCompression(compression=(0, 30))(image=img), lambda img : JpegCompression(compression=(0, 100))(image=np.array(img)),
ToTensor(), ToTensor(),
normalize, normalize,
] ]
@ -50,14 +51,13 @@ if __name__ == '__main__':
from dataset import train_ds from dataset import train_ds
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np
idx = 0 idx = 0
img = train_ds[idx]['image'] img = train_ds[idx]['image']
plt.subplot(1, 2, 1) plt.subplot(1, 2, 1)
plt.imshow(img) plt.imshow(img)
plt.title("Original") plt.title("Original")
img = _train_transforms(img.convert("RGB")) img = _train_transforms(img)
img = np.array(img.permute(1, 2, 0)) img = np.array(img.permute(1, 2, 0))
img -= img.min() img -= img.min()
img /= img.max() img /= img.max()