Compare commits
No commits in common. "a6411301cd05a7989e601414c5e25548fca189ed" and "89f9112fca82e1cdbcdb0ef7b25092f1d36fae2f" have entirely different histories.
a6411301cd
...
89f9112fca
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -1,8 +1,5 @@
|
||||||
.direnv/
|
.direnv/
|
||||||
data/
|
data/
|
||||||
test-aiornot/
|
|
||||||
submissions/
|
|
||||||
lightning_logs/
|
|
||||||
|
|
||||||
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
||||||
# Basic .gitignore for a python repo.
|
# Basic .gitignore for a python repo.
|
||||||
|
|
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
[submodule "aiornot_datasets"]
|
||||||
|
path = aiornot_datasets
|
||||||
|
url = https://huggingface.co/datasets/tocard-inc/aiornot
|
2
.vscode/settings.json
vendored
2
.vscode/settings.json
vendored
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
// "python.defaultInterpreterPath": ".venv/bin/python",
|
// "python.defaultInterpreterPath": ".venv/bin/python",
|
||||||
"python.analysis.typeCheckingMode": "off",
|
"python.analysis.typeCheckingMode": "basic",
|
||||||
"python.formatting.provider": "black",
|
"python.formatting.provider": "black",
|
||||||
"editor.formatOnSave": true,
|
"editor.formatOnSave": true,
|
||||||
"python.linting.enabled": true,
|
"python.linting.enabled": true,
|
||||||
|
|
2
LICENSE
2
LICENSE
|
@ -1,6 +1,6 @@
|
||||||
MIT License
|
MIT License
|
||||||
|
|
||||||
Copyright (c) 2023 Laurent Fainsin & Damien Guillotin
|
Copyright (c) 2023 Tocard-Inc
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
|
|
@ -1,5 +1 @@
|
||||||
# AIorNot
|
# AIorNot
|
||||||
|
|
||||||
https://huggingface.co/spaces/competitions/aiornot
|
|
||||||
|
|
||||||
8/98
|
|
||||||
|
|
1
aiornot_datasets
Submodule
1
aiornot_datasets
Submodule
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit a90618df992a19c775b6b0fb7e0de0fd45a4d505
|
1347
poetry.lock
generated
1347
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -11,13 +11,10 @@ version = "0.1.0"
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
albumentations = "^1.3.0"
|
albumentations = "^1.3.0"
|
||||||
python = ">=3.8.1,<4.0"
|
python = ">=3.8.1,<4.0"
|
||||||
rich = "^13.3.1"
|
rich = "^12.6.0"
|
||||||
torch = "^1.13.1"
|
torch = "^1.13.1"
|
||||||
datasets = "^2.9.0"
|
datasets = "^2.9.0"
|
||||||
transformers = "^4.26.0"
|
transformers = "^4.26.0"
|
||||||
evaluate = "^0.4.0"
|
|
||||||
pytorch-lightning = "^1.9.0"
|
|
||||||
timm = "^0.6.12"
|
|
||||||
|
|
||||||
[tool.poetry.group.notebooks]
|
[tool.poetry.group.notebooks]
|
||||||
optional = true
|
optional = true
|
||||||
|
@ -25,8 +22,6 @@ optional = true
|
||||||
[tool.poetry.group.notebooks.dependencies]
|
[tool.poetry.group.notebooks.dependencies]
|
||||||
ipykernel = "^6.20.2"
|
ipykernel = "^6.20.2"
|
||||||
matplotlib = "^3.6.3"
|
matplotlib = "^3.6.3"
|
||||||
ipywidgets = "^8.0.4"
|
|
||||||
jupyter = "^1.0.0"
|
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
Flake8-pyproject = "^1.1.0"
|
Flake8-pyproject = "^1.1.0"
|
||||||
|
|
0
requirements.txt
Normal file
0
requirements.txt
Normal file
|
@ -1,21 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
def BinaryCrossEntropy(y_true, y_pred):
|
|
||||||
y_pred = np.clip(y_pred, 1e-7, 1 - 1e-7)
|
|
||||||
term_0 = (1-y_true) * np.log(1-y_pred)
|
|
||||||
term_1 = y_true * np.log(y_pred)
|
|
||||||
return -np.mean(term_0+term_1, axis=0)
|
|
||||||
|
|
||||||
nb_tests = 43444
|
|
||||||
|
|
||||||
acc = 0.977
|
|
||||||
|
|
||||||
labels = np.ones(nb_tests)
|
|
||||||
|
|
||||||
nb_true = int(acc * nb_tests)
|
|
||||||
predicitions = np.concatenate((np.ones(nb_true), np.zeros(nb_tests - nb_true)))
|
|
||||||
|
|
||||||
logloss = BinaryCrossEntropy(labels, predicitions)
|
|
||||||
|
|
||||||
print(f"Accuracy: {acc}")
|
|
||||||
print(f"logloss: {logloss}")
|
|
4382
src/aiornot_baseline.ipynb
Normal file
4382
src/aiornot_baseline.ipynb
Normal file
File diff suppressed because one or more lines are too long
|
@ -1,28 +0,0 @@
|
||||||
from dataset import val_ds
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from rich.progress import track
|
|
||||||
|
|
||||||
|
|
||||||
val_labels = []
|
|
||||||
val_paths = []
|
|
||||||
for val_data in track(val_ds):
|
|
||||||
val_labels.append(val_data["label"])
|
|
||||||
val_paths.append(val_data["image_path"])
|
|
||||||
|
|
||||||
with open("results.csv", 'r') as f:
|
|
||||||
lines = f.read().splitlines()
|
|
||||||
lines = [line.split(',') for line in lines[1:]]
|
|
||||||
|
|
||||||
res = {0: [], 1: []}
|
|
||||||
for img_path, val_pred in lines:
|
|
||||||
index = val_paths.index(img_path)
|
|
||||||
label = val_labels[index]
|
|
||||||
if label == 1 and float(val_pred) < 0.5:
|
|
||||||
print(img_path)
|
|
||||||
res[label].append(float(val_pred))
|
|
||||||
|
|
||||||
plt.hist(res[0], bins=30, alpha=0.5, label="0", color="red")
|
|
||||||
plt.hist(res[1], bins=30, alpha=0.5, label="1", color="blue")
|
|
||||||
plt.yscale('log')
|
|
||||||
plt.legend()
|
|
||||||
plt.show()
|
|
|
@ -1,36 +0,0 @@
|
||||||
import datasets
|
|
||||||
|
|
||||||
# load dataset
|
|
||||||
dataset = datasets.load_dataset("competitions/aiornot")
|
|
||||||
|
|
||||||
# split up training into training + validation
|
|
||||||
splits = dataset["train"].train_test_split(test_size=0.1)
|
|
||||||
|
|
||||||
# define train, validation and test datasets
|
|
||||||
train_ds = splits["train"]
|
|
||||||
val_ds = splits["test"]
|
|
||||||
test_ds = dataset["test"]
|
|
||||||
|
|
||||||
labels = ["NOT_AI", "AI"]
|
|
||||||
id2label = {k: v for k, v in enumerate(labels)}
|
|
||||||
label2id = {v: k for k, v in enumerate(labels)}
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
print(f"labels:\n {labels}")
|
|
||||||
print(f"label-id correspondances:\n {label2id}\n {id2label}")
|
|
||||||
|
|
||||||
idx = 0
|
|
||||||
label = id2label[dataset["train"][idx]["label"]]
|
|
||||||
|
|
||||||
plt.subplot(1, 2, 1)
|
|
||||||
plt.imshow(dataset["train"][idx]["image"])
|
|
||||||
plt.title(f"Label: {label}")
|
|
||||||
|
|
||||||
plt.subplot(1, 2, 2)
|
|
||||||
plt.imshow(dataset["test"][idx]["image"])
|
|
||||||
plt.title("Test")
|
|
||||||
|
|
||||||
plt.show()
|
|
80
src/main.py
80
src/main.py
|
@ -1,80 +0,0 @@
|
||||||
import sys
|
|
||||||
from typing import List # TODO: update to python 3.11
|
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
import torch
|
|
||||||
from pytorch_lightning.callbacks import (
|
|
||||||
ModelCheckpoint,
|
|
||||||
RichModelSummary,
|
|
||||||
RichProgressBar,
|
|
||||||
)
|
|
||||||
from rich.progress import track
|
|
||||||
|
|
||||||
from dataset import test_ds
|
|
||||||
from model import AIorNOT
|
|
||||||
from parse import parse_args
|
|
||||||
|
|
||||||
|
|
||||||
def main(argv: List[str]) -> None:
|
|
||||||
"""Main entrypoint for training and inference."""
|
|
||||||
# parse args
|
|
||||||
args = parse_args(argv)
|
|
||||||
print(args)
|
|
||||||
|
|
||||||
# stfu warnings
|
|
||||||
torch.set_float32_matmul_precision("medium")
|
|
||||||
|
|
||||||
# set seed
|
|
||||||
pl.seed_everything(args.seed, workers=True)
|
|
||||||
|
|
||||||
if args.load_ckpt:
|
|
||||||
# get checkpointed model
|
|
||||||
model = AIorNOT.load_from_checkpoint(
|
|
||||||
args.load_ckpt,
|
|
||||||
args=args,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# get model
|
|
||||||
model = AIorNOT(args)
|
|
||||||
|
|
||||||
# # compile model
|
|
||||||
# model.net = torch.compile(model.net)
|
|
||||||
|
|
||||||
# define trainer
|
|
||||||
trainer = pl.Trainer(
|
|
||||||
accelerator="gpu",
|
|
||||||
devices="auto",
|
|
||||||
strategy="dp",
|
|
||||||
max_epochs=args.epochs,
|
|
||||||
precision="bf16",
|
|
||||||
log_every_n_steps=25,
|
|
||||||
val_check_interval=100,
|
|
||||||
benchmark=True,
|
|
||||||
callbacks=[
|
|
||||||
ModelCheckpoint(mode="max", monitor="val_acc"),
|
|
||||||
ModelCheckpoint(save_on_train_epoch_end=True),
|
|
||||||
RichModelSummary(max_depth=2),
|
|
||||||
RichProgressBar(),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# train model
|
|
||||||
trainer.fit(model)
|
|
||||||
|
|
||||||
if not args.skip_csv:
|
|
||||||
# make predictions on test set
|
|
||||||
test_results = trainer.predict(model, dataloaders=model.test_dataloader())
|
|
||||||
|
|
||||||
# save predictions to csv
|
|
||||||
# TODO: define track upper bound
|
|
||||||
with open(f"submissions/results_{trainer.logger.version}.csv", "w") as f:
|
|
||||||
i = 0
|
|
||||||
f.write("id,label\n")
|
|
||||||
for test_result in track(test_results):
|
|
||||||
for logit in test_result.float().sigmoid():
|
|
||||||
f.write(f"{test_ds[i]['id']},{float(logit)}\n")
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main(sys.argv[1:])
|
|
125
src/model.py
125
src/model.py
|
@ -1,125 +0,0 @@
|
||||||
import pytorch_lightning as pl
|
|
||||||
import timm
|
|
||||||
import torch
|
|
||||||
import torchmetrics
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
|
|
||||||
|
|
||||||
from dataset import test_ds, train_ds, val_ds
|
|
||||||
from transform import train_transforms, val_transforms
|
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(examples):
|
|
||||||
"""Collate function for training and validation."""
|
|
||||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
|
||||||
labels = torch.tensor([example["label"] for example in examples])
|
|
||||||
|
|
||||||
return pixel_values, labels
|
|
||||||
|
|
||||||
|
|
||||||
def collate_fn_test(examples):
|
|
||||||
"""Collate function for testing."""
|
|
||||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
|
||||||
|
|
||||||
return pixel_values
|
|
||||||
|
|
||||||
|
|
||||||
class AIorNOT(pl.LightningModule):
|
|
||||||
"""AIorNOT model."""
|
|
||||||
|
|
||||||
def __init__(self, args):
|
|
||||||
"""Initialize model."""
|
|
||||||
super().__init__()
|
|
||||||
self.args = args
|
|
||||||
self.save_hyperparameters()
|
|
||||||
|
|
||||||
self.net = timm.create_model(args.model_name, pretrained=True, num_classes=1)
|
|
||||||
|
|
||||||
self.criterion = torch.nn.BCEWithLogitsLoss()
|
|
||||||
self.val_accuracy = torchmetrics.Accuracy("binary")
|
|
||||||
# TODO: add train_accuracy
|
|
||||||
|
|
||||||
def forward(self, pixel_values):
|
|
||||||
"""Forward pass."""
|
|
||||||
outputs = self.net(pixel_values)
|
|
||||||
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def common_step(self, batch, batch_idx):
|
|
||||||
"""Common step for training and validation."""
|
|
||||||
pixel_values, labels = batch
|
|
||||||
labels = labels.float()
|
|
||||||
logits = self(pixel_values).squeeze(1).float()
|
|
||||||
loss = self.criterion(logits, labels)
|
|
||||||
|
|
||||||
return loss, logits.sigmoid(), labels
|
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
|
||||||
"""Training step."""
|
|
||||||
loss, _, _ = self.common_step(batch, batch_idx)
|
|
||||||
self.log("train_loss", loss)
|
|
||||||
self.log("lr", self.optimizers().param_groups[0]["lr"])
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx):
|
|
||||||
"""Validation step."""
|
|
||||||
loss, preds, targets = self.common_step(batch, batch_idx)
|
|
||||||
self.log("val_loss", loss, on_epoch=True)
|
|
||||||
|
|
||||||
return preds, targets
|
|
||||||
|
|
||||||
def validation_epoch_end(self, outputs):
|
|
||||||
"""Validation epoch end."""
|
|
||||||
preds = torch.cat([x[0] for x in outputs])
|
|
||||||
targets = torch.cat([x[1] for x in outputs])
|
|
||||||
acc = self.val_accuracy(preds, targets)
|
|
||||||
self.log("val_acc", acc, prog_bar=True)
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
"""Configure optimizers."""
|
|
||||||
optimizer = torch.optim.Adam(self.parameters(), self.args.lr, weight_decay=self.args.weight_decay)
|
|
||||||
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.warmup_steps,
|
|
||||||
num_training_steps=self.trainer.estimated_stepping_batches,
|
|
||||||
)
|
|
||||||
|
|
||||||
return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
|
|
||||||
|
|
||||||
def train_dataloader(self):
|
|
||||||
"""Train dataloader."""
|
|
||||||
return DataLoader(
|
|
||||||
train_ds.with_transform(train_transforms),
|
|
||||||
shuffle=True,
|
|
||||||
pin_memory=True,
|
|
||||||
collate_fn=collate_fn,
|
|
||||||
persistent_workers=True,
|
|
||||||
num_workers=self.args.num_workers,
|
|
||||||
batch_size=self.args.batch_size,
|
|
||||||
prefetch_factor=self.args.prefetch_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
def val_dataloader(self):
|
|
||||||
"""Validation dataloader."""
|
|
||||||
return DataLoader(
|
|
||||||
val_ds.with_transform(val_transforms),
|
|
||||||
pin_memory=True,
|
|
||||||
collate_fn=collate_fn,
|
|
||||||
persistent_workers=True,
|
|
||||||
num_workers=self.args.num_workers,
|
|
||||||
batch_size=self.args.batch_size_val,
|
|
||||||
prefetch_factor=self.args.prefetch_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_dataloader(self):
|
|
||||||
"""Test dataloader."""
|
|
||||||
return DataLoader(
|
|
||||||
test_ds.with_transform(val_transforms),
|
|
||||||
pin_memory=True,
|
|
||||||
collate_fn=collate_fn_test,
|
|
||||||
persistent_workers=True,
|
|
||||||
num_workers=self.args.num_workers,
|
|
||||||
batch_size=self.args.batch_size_val,
|
|
||||||
prefetch_factor=self.args.prefetch_factor,
|
|
||||||
)
|
|
82
src/parse.py
82
src/parse.py
|
@ -1,82 +0,0 @@
|
||||||
import argparse
|
|
||||||
from typing import List # TODO: update to python 3.11
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args(argv: List[str]) -> argparse.Namespace:
|
|
||||||
"""Parse command line arguments."""
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Train and inference for AIorNOT challenge",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--seed",
|
|
||||||
type=int,
|
|
||||||
default=42,
|
|
||||||
help="random seed",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model_name",
|
|
||||||
type=str,
|
|
||||||
default="timm/convnextv2_base.fcmae_ft_in22k_in1k_384",
|
|
||||||
help="model name to use from timm",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--epochs",
|
|
||||||
type=int,
|
|
||||||
default=3,
|
|
||||||
help="number of epochs to train",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--batch_size",
|
|
||||||
type=int,
|
|
||||||
default=35,
|
|
||||||
help="batch size to use for training",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--batch_size_val",
|
|
||||||
type=int,
|
|
||||||
default=250,
|
|
||||||
help="batch size to use for validation and testing",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--lr",
|
|
||||||
type=float,
|
|
||||||
default=5e-5,
|
|
||||||
help="learning rate",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--weight_decay",
|
|
||||||
type=float,
|
|
||||||
default=1e-4,
|
|
||||||
help="weight decay",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--warmup_steps",
|
|
||||||
type=int,
|
|
||||||
default=500,
|
|
||||||
help="number of warmup steps for cosine scheduler",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--skip_csv",
|
|
||||||
action="store_true",
|
|
||||||
help="skip export test inference to csv file",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--load_ckpt",
|
|
||||||
type=str,
|
|
||||||
help="checkpoint path to load from",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--prefetch_factor",
|
|
||||||
type=int,
|
|
||||||
default=3,
|
|
||||||
help="prefetch factor for dataloaders",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_workers",
|
|
||||||
type=int,
|
|
||||||
default=8,
|
|
||||||
help="number of workers for dataloaders",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args(argv)
|
|
20
src/tests/dataset.py
Normal file
20
src/tests/dataset.py
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
import datasets
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
dataset = datasets.load_dataset("src/dataset.py")
|
||||||
|
|
||||||
|
labels = dataset["train"].features["label"].names
|
||||||
|
print(labels)
|
||||||
|
|
||||||
|
id2label = {k: v for k, v in enumerate(labels)}
|
||||||
|
label2id = {v: k for k, v in enumerate(labels)}
|
||||||
|
print(label2id)
|
||||||
|
print(id2label)
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
plt.imshow(dataset["train"][idx]["image"])
|
||||||
|
plt.title(id2label[dataset["train"][idx]["label"]])
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
plt.imshow(dataset["test"][idx]["image"])
|
||||||
|
plt.show()
|
|
@ -1,71 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
from imgaug.augmenters import JpegCompression
|
|
||||||
from torchvision.transforms import (
|
|
||||||
AugMix,
|
|
||||||
Compose,
|
|
||||||
Normalize,
|
|
||||||
RandomHorizontalFlip,
|
|
||||||
RandomVerticalFlip,
|
|
||||||
ToTensor,
|
|
||||||
)
|
|
||||||
|
|
||||||
# get feature extractor (to normalize images)
|
|
||||||
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
|
||||||
|
|
||||||
# define train transform
|
|
||||||
_train_transforms = Compose(
|
|
||||||
[
|
|
||||||
# AugMix(),
|
|
||||||
RandomHorizontalFlip(),
|
|
||||||
RandomVerticalFlip(),
|
|
||||||
# lambda img : JpegCompression(compression=(0, 100))(image=np.array(img)),
|
|
||||||
ToTensor(),
|
|
||||||
normalize,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# define validation transform
|
|
||||||
_val_transforms = Compose(
|
|
||||||
[
|
|
||||||
ToTensor(),
|
|
||||||
normalize,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# actually define the train transform
|
|
||||||
def train_transforms(examples):
|
|
||||||
"""Transforms for training."""
|
|
||||||
examples["pixel_values"] = [_train_transforms(image.convert("RGB")) for image in examples["image"]]
|
|
||||||
return examples
|
|
||||||
|
|
||||||
|
|
||||||
# actually define the validation transform
|
|
||||||
def val_transforms(examples):
|
|
||||||
"""Transforms for validation."""
|
|
||||||
examples["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in examples["image"]]
|
|
||||||
return examples
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
from dataset import train_ds
|
|
||||||
|
|
||||||
idx = 0
|
|
||||||
img = train_ds[idx]["image"]
|
|
||||||
|
|
||||||
plt.subplot(1, 2, 1)
|
|
||||||
plt.imshow(img)
|
|
||||||
plt.title("Original")
|
|
||||||
|
|
||||||
img = _train_transforms(img)
|
|
||||||
img = np.array(img.permute(1, 2, 0))
|
|
||||||
img -= img.min()
|
|
||||||
img /= img.max()
|
|
||||||
|
|
||||||
plt.subplot(1, 2, 2)
|
|
||||||
plt.imshow(img)
|
|
||||||
plt.title("Augmented")
|
|
||||||
plt.show()
|
|
Reference in a new issue