feat: JpegAug

This commit is contained in:
gdamms 2023-02-10 13:42:50 +01:00
parent c6942d325b
commit 89f37e6bbf

View file

@ -1,3 +1,4 @@
import numpy as np
from torchvision.transforms import (
AugMix,
Compose,
@ -17,7 +18,7 @@ _train_transforms = Compose(
AugMix(),
RandomHorizontalFlip(),
RandomVerticalFlip(),
# lambda img : JpegCompression(compression=(0, 30))(image=img),
lambda img : JpegCompression(compression=(0, 100))(image=np.array(img)),
ToTensor(),
normalize,
]
@ -50,14 +51,13 @@ if __name__ == '__main__':
from dataset import train_ds
import matplotlib.pyplot as plt
import numpy as np
idx = 0
img = train_ds[idx]['image']
plt.subplot(1, 2, 1)
plt.imshow(img)
plt.title("Original")
img = _train_transforms(img.convert("RGB"))
img = _train_transforms(img)
img = np.array(img.permute(1, 2, 0))
img -= img.min()
img /= img.max()