Compare commits


10 commits

Author SHA1 Message Date
Laureηt a6411301cd 📝 clean up a bit 2023-09-22 15:53:30 +02:00
Laurent Fainsin 265e67bec8 fix: bad add_argument 2023-02-10 16:20:15 +01:00
Laurent Fainsin a88a55b8e8 feat: argparse 2023-02-10 15:23:17 +01:00
gdamms 89f37e6bbf feat: JpegAug 2023-02-10 13:42:50 +01:00
gdamms c6942d325b feat: show augmentation 2023-02-10 13:37:42 +01:00
gdamms 42ac3e0576 feat: 0.0192 2023-02-10 12:55:33 +01:00
gdamms fb5287eaff train and inference base scripts 2023-02-04 14:55:39 +01:00
gdamms 9ebf7de84d ignore checkpoints 2023-02-02 22:58:25 +01:00
gdamms aac135a3fc refactor: notebook to scripts main (TODO: reorganize main) 2023-02-02 22:56:20 +01:00
gdamms 7dfcc358e4 refactor: notebook to scripts dataloader 2023-02-02 20:06:36 +01:00
18 changed files with 1749 additions and 4475 deletions

3
.gitignore

@@ -1,5 +1,8 @@
.direnv/
data/
test-aiornot/
submissions/
lightning_logs/
# https://github.com/github/gitignore/blob/main/Python.gitignore
# Basic .gitignore for a python repo.

3
.gitmodules

@@ -1,3 +0,0 @@
[submodule "aiornot_datasets"]
path = aiornot_datasets
url = https://huggingface.co/datasets/tocard-inc/aiornot

.vscode/settings.json

@@ -1,6 +1,6 @@
{
// "python.defaultInterpreterPath": ".venv/bin/python",
"python.analysis.typeCheckingMode": "basic",
"python.analysis.typeCheckingMode": "off",
"python.formatting.provider": "black",
"editor.formatOnSave": true,
"python.linting.enabled": true,
@@ -24,4 +24,4 @@
"**/__pycache__": true,
"**/.mypy_cache": true,
},
}
}

LICENSE

@@ -1,6 +1,6 @@
MIT License
Copyright (c) 2023 Tocard-Inc
Copyright (c) 2023 Laurent Fainsin & Damien Guillotin
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

README.md

@@ -1 +1,5 @@
# AIorNot
# AIorNot
https://huggingface.co/spaces/competitions/aiornot
8/98

aiornot_datasets

@@ -1 +0,0 @@
Subproject commit a90618df992a19c775b6b0fb7e0de0fd45a4d505

1353
poetry.lock

File diff suppressed because it is too large

pyproject.toml

@@ -11,10 +11,13 @@ version = "0.1.0"
[tool.poetry.dependencies]
albumentations = "^1.3.0"
python = ">=3.8.1,<4.0"
rich = "^12.6.0"
rich = "^13.3.1"
torch = "^1.13.1"
datasets = "^2.9.0"
transformers = "^4.26.0"
evaluate = "^0.4.0"
pytorch-lightning = "^1.9.0"
timm = "^0.6.12"
[tool.poetry.group.notebooks]
optional = true
@@ -22,6 +25,8 @@ optional = true
[tool.poetry.group.notebooks.dependencies]
ipykernel = "^6.20.2"
matplotlib = "^3.6.3"
ipywidgets = "^8.0.4"
jupyter = "^1.0.0"
[tool.poetry.group.dev.dependencies]
Flake8-pyproject = "^1.1.0"

21
src/acclogloss.py

@@ -0,0 +1,21 @@
import numpy as np


def BinaryCrossEntropy(y_true, y_pred):
    y_pred = np.clip(y_pred, 1e-7, 1 - 1e-7)
    term_0 = (1 - y_true) * np.log(1 - y_pred)
    term_1 = y_true * np.log(y_pred)
    return -np.mean(term_0 + term_1, axis=0)


nb_tests = 43444
acc = 0.977

labels = np.ones(nb_tests)
nb_true = int(acc * nb_tests)
predictions = np.concatenate((np.ones(nb_true), np.zeros(nb_tests - nb_true)))

logloss = BinaryCrossEntropy(labels, predictions)

print(f"Accuracy: {acc}")
print(f"logloss: {logloss}")

File diff suppressed because one or more lines are too long

28
src/comparaison.py

@@ -0,0 +1,28 @@
from dataset import val_ds
import matplotlib.pyplot as plt
from rich.progress import track
val_labels = []
val_paths = []
for val_data in track(val_ds):
val_labels.append(val_data["label"])
val_paths.append(val_data["image_path"])
with open("results.csv", 'r') as f:
lines = f.read().splitlines()
lines = [line.split(',') for line in lines[1:]]
res = {0: [], 1: []}
for img_path, val_pred in lines:
index = val_paths.index(img_path)
label = val_labels[index]
if label == 1 and float(val_pred) < 0.5:
print(img_path)
res[label].append(float(val_pred))
plt.hist(res[0], bins=30, alpha=0.5, label="0", color="red")
plt.hist(res[1], bins=30, alpha=0.5, label="1", color="blue")
plt.yscale('log')
plt.legend()
plt.show()
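
The script assumes results.csv is a header row followed by image_path,prediction pairs. A minimal sketch of a compatible file (the paths and values are hypothetical):

# Hypothetical results.csv in the format comparaison.py reads
rows = [("val/0001.png", 0.91), ("val/0002.png", 0.07)]
with open("results.csv", "w") as f:
    f.write("image_path,prediction\n")  # header row, skipped by lines[1:]
    for path, pred in rows:
        f.write(f"{path},{pred}\n")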

36
src/dataset.py

@@ -0,0 +1,36 @@
import datasets
# load dataset
dataset = datasets.load_dataset("competitions/aiornot")
# split up training into training + validation
splits = dataset["train"].train_test_split(test_size=0.1)
# define train, validation and test datasets
train_ds = splits["train"]
val_ds = splits["test"]
test_ds = dataset["test"]
labels = ["NOT_AI", "AI"]
id2label = {k: v for k, v in enumerate(labels)}
label2id = {v: k for k, v in enumerate(labels)}
if __name__ == "__main__":
import matplotlib.pyplot as plt
print(f"labels:\n {labels}")
print(f"label-id correspondances:\n {label2id}\n {id2label}")
idx = 0
label = id2label[dataset["train"][idx]["label"]]
plt.subplot(1, 2, 1)
plt.imshow(dataset["train"][idx]["image"])
plt.title(f"Label: {label}")
plt.subplot(1, 2, 2)
plt.imshow(dataset["test"][idx]["image"])
plt.title("Test")
plt.show()
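
Note that train_test_split is called without a seed, so the train/validation split can change between runs; pl.seed_everything in main.py only executes after this module has been imported. If a fixed split is wanted, datasets supports a seed argument (a sketch, not the repository's current behaviour; the seed value is arbitrary):

# Sketch: pin the validation split across runs
splits = dataset["train"].train_test_split(test_size=0.1, seed=42)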

80
src/main.py

@@ -0,0 +1,80 @@
import sys
from typing import List # TODO: update to python 3.11
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import (
ModelCheckpoint,
RichModelSummary,
RichProgressBar,
)
from rich.progress import track
from dataset import test_ds
from model import AIorNOT
from parse import parse_args
def main(argv: List[str]) -> None:
"""Main entrypoint for training and inference."""
# parse args
args = parse_args(argv)
print(args)
    # silence tensor-core matmul precision warnings
    torch.set_float32_matmul_precision("medium")
# set seed
pl.seed_everything(args.seed, workers=True)
if args.load_ckpt:
# get checkpointed model
model = AIorNOT.load_from_checkpoint(
args.load_ckpt,
args=args,
)
else:
# get model
model = AIorNOT(args)
# # compile model
# model.net = torch.compile(model.net)
# define trainer
trainer = pl.Trainer(
accelerator="gpu",
devices="auto",
strategy="dp",
max_epochs=args.epochs,
precision="bf16",
log_every_n_steps=25,
val_check_interval=100,
benchmark=True,
callbacks=[
ModelCheckpoint(mode="max", monitor="val_acc"),
ModelCheckpoint(save_on_train_epoch_end=True),
RichModelSummary(max_depth=2),
RichProgressBar(),
],
)
# train model
trainer.fit(model)
if not args.skip_csv:
# make predictions on test set
test_results = trainer.predict(model, dataloaders=model.test_dataloader())
# save predictions to csv
# TODO: define track upper bound
with open(f"submissions/results_{trainer.logger.version}.csv", "w") as f:
i = 0
f.write("id,label\n")
for test_result in track(test_results):
for logit in test_result.float().sigmoid():
f.write(f"{test_ds[i]['id']},{float(logit)}\n")
i += 1
if __name__ == "__main__":
main(sys.argv[1:])
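
Since main() takes an explicit argv list, a run can be launched programmatically as well as from the shell (a sketch; the flag values are illustrative):

# Equivalent to: python src/main.py --epochs 1 --skip_csv
from main import main

main(["--epochs", "1", "--skip_csv"])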

125
src/model.py

@@ -0,0 +1,125 @@
import pytorch_lightning as pl
import timm
import torch
import torchmetrics
from torch.utils.data import DataLoader
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
from dataset import test_ds, train_ds, val_ds
from transform import train_transforms, val_transforms
def collate_fn(examples):
"""Collate function for training and validation."""
pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["label"] for example in examples])
return pixel_values, labels
def collate_fn_test(examples):
"""Collate function for testing."""
pixel_values = torch.stack([example["pixel_values"] for example in examples])
return pixel_values
class AIorNOT(pl.LightningModule):
"""AIorNOT model."""
def __init__(self, args):
"""Initialize model."""
super().__init__()
self.args = args
self.save_hyperparameters()
self.net = timm.create_model(args.model_name, pretrained=True, num_classes=1)
self.criterion = torch.nn.BCEWithLogitsLoss()
self.val_accuracy = torchmetrics.Accuracy("binary")
# TODO: add train_accuracy
def forward(self, pixel_values):
"""Forward pass."""
outputs = self.net(pixel_values)
return outputs
def common_step(self, batch, batch_idx):
"""Common step for training and validation."""
pixel_values, labels = batch
labels = labels.float()
logits = self(pixel_values).squeeze(1).float()
loss = self.criterion(logits, labels)
return loss, logits.sigmoid(), labels
def training_step(self, batch, batch_idx):
"""Training step."""
loss, _, _ = self.common_step(batch, batch_idx)
self.log("train_loss", loss)
self.log("lr", self.optimizers().param_groups[0]["lr"])
return loss
def validation_step(self, batch, batch_idx):
"""Validation step."""
loss, preds, targets = self.common_step(batch, batch_idx)
self.log("val_loss", loss, on_epoch=True)
return preds, targets
def validation_epoch_end(self, outputs):
"""Validation epoch end."""
preds = torch.cat([x[0] for x in outputs])
targets = torch.cat([x[1] for x in outputs])
acc = self.val_accuracy(preds, targets)
self.log("val_acc", acc, prog_bar=True)
def configure_optimizers(self):
"""Configure optimizers."""
optimizer = torch.optim.Adam(self.parameters(), self.args.lr, weight_decay=self.args.weight_decay)
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
optimizer,
num_warmup_steps=self.args.warmup_steps,
num_training_steps=self.trainer.estimated_stepping_batches,
)
return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
def train_dataloader(self):
"""Train dataloader."""
return DataLoader(
train_ds.with_transform(train_transforms),
shuffle=True,
pin_memory=True,
collate_fn=collate_fn,
persistent_workers=True,
num_workers=self.args.num_workers,
batch_size=self.args.batch_size,
prefetch_factor=self.args.prefetch_factor,
)
def val_dataloader(self):
"""Validation dataloader."""
return DataLoader(
val_ds.with_transform(val_transforms),
pin_memory=True,
collate_fn=collate_fn,
persistent_workers=True,
num_workers=self.args.num_workers,
batch_size=self.args.batch_size_val,
prefetch_factor=self.args.prefetch_factor,
)
def test_dataloader(self):
"""Test dataloader."""
return DataLoader(
test_ds.with_transform(val_transforms),
pin_memory=True,
collate_fn=collate_fn_test,
persistent_workers=True,
num_workers=self.args.num_workers,
batch_size=self.args.batch_size_val,
prefetch_factor=self.args.prefetch_factor,
)
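
A minimal smoke test for the module, assuming the default CLI arguments (the first call downloads the pretrained timm weights):

from model import AIorNOT
from parse import parse_args

model = AIorNOT(parse_args([]))  # all defaults from src/parse.py
n_params = sum(p.numel() for p in model.net.parameters())
print(f"{n_params / 1e6:.1f}M parameters")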

82
src/parse.py

@@ -0,0 +1,82 @@
import argparse
from typing import List # TODO: update to python 3.11
def parse_args(argv: List[str]) -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Train and inference for AIorNOT challenge",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="random seed",
)
parser.add_argument(
"--model_name",
type=str,
default="timm/convnextv2_base.fcmae_ft_in22k_in1k_384",
help="model name to use from timm",
)
parser.add_argument(
"--epochs",
type=int,
default=3,
help="number of epochs to train",
)
parser.add_argument(
"--batch_size",
type=int,
default=35,
help="batch size to use for training",
)
parser.add_argument(
"--batch_size_val",
type=int,
default=250,
help="batch size to use for validation and testing",
)
parser.add_argument(
"--lr",
type=float,
default=5e-5,
help="learning rate",
)
parser.add_argument(
"--weight_decay",
type=float,
default=1e-4,
help="weight decay",
)
parser.add_argument(
"--warmup_steps",
type=int,
default=500,
help="number of warmup steps for cosine scheduler",
)
parser.add_argument(
"--skip_csv",
action="store_true",
help="skip export test inference to csv file",
)
parser.add_argument(
"--load_ckpt",
type=str,
help="checkpoint path to load from",
)
parser.add_argument(
"--prefetch_factor",
type=int,
default=3,
help="prefetch factor for dataloaders",
)
parser.add_argument(
"--num_workers",
type=int,
default=8,
help="number of workers for dataloaders",
)
return parser.parse_args(argv)
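
Usage sketch: parse_args works on any argv list, so resuming from a checkpoint is just a matter of passing --load_ckpt (the checkpoint path below is hypothetical):

from parse import parse_args

args = parse_args(["--load_ckpt", "lightning_logs/version_0/checkpoints/last.ckpt"])
print(args.load_ckpt, args.epochs)  # hypothetical path, default epochs (3)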

@@ -1,20 +0,0 @@
import datasets
import matplotlib.pyplot as plt
dataset = datasets.load_dataset("src/dataset.py")
labels = dataset["train"].features["label"].names
print(labels)
id2label = {k: v for k, v in enumerate(labels)}
label2id = {v: k for k, v in enumerate(labels)}
print(label2id)
print(id2label)
idx = 0
plt.imshow(dataset["train"][idx]["image"])
plt.title(id2label[dataset["train"][idx]["label"]])
plt.show()
plt.imshow(dataset["test"][idx]["image"])
plt.show()

71
src/transform.py

@@ -0,0 +1,71 @@
import numpy as np
from imgaug.augmenters import JpegCompression
from torchvision.transforms import (
AugMix,
Compose,
Normalize,
RandomHorizontalFlip,
RandomVerticalFlip,
ToTensor,
)
# normalize images with ImageNet mean/std
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# define train transform
_train_transforms = Compose(
[
# AugMix(),
RandomHorizontalFlip(),
RandomVerticalFlip(),
# lambda img : JpegCompression(compression=(0, 100))(image=np.array(img)),
ToTensor(),
normalize,
]
)
# define validation transform
_val_transforms = Compose(
[
ToTensor(),
normalize,
]
)
# batch-level train transform, used via datasets.with_transform in model.py
def train_transforms(examples):
"""Transforms for training."""
examples["pixel_values"] = [_train_transforms(image.convert("RGB")) for image in examples["image"]]
return examples
# batch-level validation transform, used via datasets.with_transform in model.py
def val_transforms(examples):
"""Transforms for validation."""
examples["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in examples["image"]]
return examples
if __name__ == "__main__":
import matplotlib.pyplot as plt
from dataset import train_ds
idx = 0
img = train_ds[idx]["image"]
plt.subplot(1, 2, 1)
plt.imshow(img)
plt.title("Original")
img = _train_transforms(img)
img = np.array(img.permute(1, 2, 0))
img -= img.min()
img /= img.max()
plt.subplot(1, 2, 2)
plt.imshow(img)
plt.title("Augmented")
plt.show()
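
The JPEG augmentation from the "feat: JpegAug" commit is left commented out above. A sketch of how it could be enabled (imgaug's JpegCompression returns a numpy array, which ToTensor accepts directly; keeping it on is a judgment call, not something this diff decides):

# Sketch: enable the commented-out JPEG-compression augmentation
jpeg = JpegCompression(compression=(0, 100))

_train_transforms = Compose(
    [
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        lambda img: jpeg(image=np.array(img)),  # PIL -> uint8 ndarray, randomly recompressed
        ToTensor(),
        normalize,
    ]
)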