feat: 0.0192

This commit is contained in:
gdamms 2023-02-10 12:55:33 +01:00
parent fb5287eaff
commit 42ac3e0576
9 changed files with 1552 additions and 4572 deletions

2
.gitignore vendored
View file

@ -1,6 +1,8 @@
.direnv/ .direnv/
data/ data/
test-aiornot/ test-aiornot/
submissions/
lightning_logs/
# https://github.com/github/gitignore/blob/main/Python.gitignore # https://github.com/github/gitignore/blob/main/Python.gitignore
# Basic .gitignore for a python repo. # Basic .gitignore for a python repo.

1353
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -11,10 +11,13 @@ version = "0.1.0"
[tool.poetry.dependencies] [tool.poetry.dependencies]
albumentations = "^1.3.0" albumentations = "^1.3.0"
python = ">=3.8.1,<4.0" python = ">=3.8.1,<4.0"
rich = "^12.6.0" rich = "^13.3.1"
torch = "^1.13.1" torch = "^1.13.1"
datasets = "^2.9.0" datasets = "^2.9.0"
transformers = "^4.26.0" transformers = "^4.26.0"
evaluate = "^0.4.0"
pytorch-lightning = "^1.9.0"
timm = "^0.6.12"
[tool.poetry.group.notebooks] [tool.poetry.group.notebooks]
optional = true optional = true
@ -22,6 +25,8 @@ optional = true
[tool.poetry.group.notebooks.dependencies] [tool.poetry.group.notebooks.dependencies]
ipykernel = "^6.20.2" ipykernel = "^6.20.2"
matplotlib = "^3.6.3" matplotlib = "^3.6.3"
ipywidgets = "^8.0.4"
jupyter = "^1.0.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
Flake8-pyproject = "^1.1.0" Flake8-pyproject = "^1.1.0"

File diff suppressed because one or more lines are too long

28
src/comparaison.py Normal file
View file

@ -0,0 +1,28 @@
from dataset import val_ds
import matplotlib.pyplot as plt
from rich.progress import track
# Collect the ground-truth label and image path of every validation sample.
val_labels = []
val_paths = []
for val_data in track(val_ds):
    val_labels.append(val_data["label"])
    val_paths.append(val_data["image_path"])

# Index labels by path once, so each CSV row is resolved in O(1) instead of
# rescanning ``val_paths``. ``setdefault`` keeps the first label seen for a
# path, which matches the first-match semantics of ``list.index``.
label_by_path = {}
for path, path_label in zip(val_paths, val_labels):
    label_by_path.setdefault(path, path_label)

# Read the predictions file: a header line followed by "<path>,<score>" rows.
with open("results.csv", 'r') as f:
    lines = f.read().splitlines()
lines = [line.split(',') for line in lines[1:]]

# Group the predicted scores by ground-truth label.
res = {0: [], 1: []}
for img_path, val_pred in lines:
    try:
        label = label_by_path[img_path]
    except KeyError:
        # Keep the exception type a failed ``list.index`` lookup would raise.
        raise ValueError(f"{img_path!r} is not in list") from None
    score = float(val_pred)  # parse once, reuse for the check and the histogram
    if label == 1 and score < 0.5:
        # Label-1 sample scored below the 0.5 threshold: report it.
        print(img_path)
    res[label].append(score)

# Overlay both score distributions; the log scale keeps sparse bins visible.
plt.hist(res[0], bins=30, alpha=0.5, label="0", color="red")
plt.hist(res[1], bins=30, alpha=0.5, label="1", color="blue")
plt.yscale('log')
plt.legend()
plt.show()

View file

@ -1,15 +1,18 @@
from datasets import load_dataset import datasets
# set seed
RANDOM_SEED = 1010101
# load dataset # load dataset
dataset = load_dataset("tocard-inc/aiornot").shuffle(seed=42) dataset = datasets.load_dataset("competitions/aiornot")
# split up training into training + validation # split up training into training + validation
splits = dataset['train'].train_test_split(test_size=0.1) splits = dataset["train"].train_test_split(test_size=0.1, seed=RANDOM_SEED)
train_ds = splits['train'] # define train, validation and test datasets
val_ds = splits['test'] train_ds = splits["train"]
test_ds = dataset['test'] val_ds = splits["test"]
test_ds = dataset["test"]
labels = train_ds.features['label'].names labels = train_ds.features['label'].names
id2label = {k: v for k, v in enumerate(labels)} id2label = {k: v for k, v in enumerate(labels)}

View file

@ -80,4 +80,4 @@ submission_df = pd.DataFrame({"id": file_paths, "label": pred_ids})
submission_df.head() submission_df.head()
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S") TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
submission_df.to_csv(f"{TIMESTAMP}.csv", index=False) submission_df.to_csv(f"submissions/{TIMESTAMP}.csv", index=False)

View file

@ -1,61 +1,17 @@
from transformers import ( import pytorch_lightning as pl
AutoModelForImageClassification, import timm
AutoFeatureExtractor,
TrainingArguments,
Trainer
)
from torchvision.transforms import (
CenterCrop,
Compose,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor,
ToPILImage
)
import torch import torch
import numpy as np import torchmetrics
import matplotlib.pyplot as plt from pytorch_lightning.callbacks import (
from datasets import load_metric ModelCheckpoint,
RichModelSummary,
from dataset import train_ds, val_ds, test_ds, labels, id2label, label2id RichProgressBar,
feature_extractor = AutoFeatureExtractor.from_pretrained(
"facebook/convnext-xlarge-384-22k-1k")
normalize = Normalize(mean=feature_extractor.image_mean,
std=feature_extractor.image_std)
_train_transforms = Compose(
[
RandomResizedCrop((256, 256)),
RandomHorizontalFlip(),
ToTensor(),
normalize,
]
) )
from torch.utils.data import DataLoader
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
_val_transforms = Compose( from dataset import train_ds, val_ds, test_ds
[ from transform import train_transforms, val_transforms
Resize((256, 256)),
CenterCrop((256, 256)),
ToTensor(),
normalize,
]
)
def train_transforms(examples):
examples['pixel_values'] = [_train_transforms(
image.convert("RGB")) for image in examples['image']]
return examples
def val_transforms(examples):
examples['pixel_values'] = [_val_transforms(
image.convert("RGB")) for image in examples['image']]
return examples
# Set the transforms # Set the transforms
@ -63,76 +19,175 @@ train_ds.set_transform(train_transforms)
val_ds.set_transform(val_transforms) val_ds.set_transform(val_transforms)
test_ds.set_transform(val_transforms) test_ds.set_transform(val_transforms)
transform = ToPILImage()
img = train_ds[0]["pixel_values"]
img = img - min(img.flatten().numpy())
img = img / max(img.flatten().numpy())
plt.figure("Augmentation")
plt.subplot(1, 2, 1)
plt.imshow(train_ds[0]["image"])
plt.title("label: " + id2label[train_ds[0]["label"]])
plt.subplot(1, 2, 2)
plt.imshow(transform(img))
plt.title("augmented image")
plt.show()
# Prepare trainer
def collate_fn(examples): def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] """Collate function for training and validation."""
for example in examples]) pixel_values = torch.stack([example["pixel_values"] for example in examples])
labels = torch.tensor([example["label"] for example in examples]) labels = torch.tensor([example["label"] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}
return pixel_values, labels
model = AutoModelForImageClassification.from_pretrained( def collate_fn_test(examples):
'facebook/convnext-xlarge-384-22k-1k', """Collate function for testing."""
num_labels=len(labels), pixel_values = torch.stack([example["pixel_values"] for example in examples])
id2label=id2label,
label2id=label2id,
ignore_mismatched_sizes=True,
)
metric_name = "accuracy" return pixel_values
args = TrainingArguments(
f"test-aiornot",
save_strategy="steps",
evaluation_strategy="steps",
learning_rate=2e-5,
per_device_train_batch_size=24,
per_device_eval_batch_size=24,
num_train_epochs=3,
weight_decay=0.01,
load_best_model_at_end=True,
metric_for_best_model=metric_name,
eval_steps=250,
logging_dir='logs',
logging_steps=10,
remove_unused_columns=False,
)
metric = load_metric("accuracy")
def compute_metrics(eval_pred): BATCH_SIZE_TEST = 250
predictions, labels = eval_pred BATCH_SIZE_TRAIN = 35
predictions = np.argmax(predictions, axis=1) BATCH_SIZE_EVAL = 250
return metric.compute(predictions=predictions, references=labels) PREFETCH_FACTOR = 3
NUM_WORKERS = 8
trainer = Trainer( class AIorNOT(pl.LightningModule):
model, """AIorNOT model."""
args,
train_dataset=train_ds, def __init__(self, model_name, lr, weight_decay=1e-4, warmup_steps=0):
eval_dataset=val_ds, """Initialize model."""
data_collator=collate_fn, super().__init__()
compute_metrics=compute_metrics, self.save_hyperparameters()
tokenizer=feature_extractor,
) self.net = timm.create_model(model_name, pretrained=True, num_classes=1)
# Start tensorboard.
# %load_ext tensorboard self.criterion = torch.nn.BCEWithLogitsLoss()
# %tensorboard - -logdir logs/ self.val_accuracy = torchmetrics.Accuracy("binary")
trainer.train() # TODO: add train_accuracy
def forward(self, pixel_values):
    """Run the wrapped network on ``pixel_values`` and return its raw output."""
    return self.net(pixel_values)
def common_step(self, batch, batch_idx):
    """Compute the loss for one batch.

    Returns a ``(loss, probabilities, targets)`` tuple, where the
    probabilities are the sigmoid of the logits and the targets have been
    cast to float for the criterion.
    """
    images, targets = batch
    targets = targets.float()
    raw_scores = self(images)
    # Drop dimension 1 of the model output so it lines up with the targets.
    logits = raw_scores.squeeze(1).float()
    return self.criterion(logits, targets), logits.sigmoid(), targets
def training_step(self, batch, batch_idx):
    """Run one training batch, log the loss and learning rate, return the loss."""
    loss, _probabilities, _targets = self.common_step(batch, batch_idx)
    self.log("train_loss", loss)
    # Log the learning rate of the first parameter group.
    first_group = self.optimizers().param_groups[0]
    self.log("lr", first_group["lr"])
    return loss
def validation_step(self, batch, batch_idx):
    """Run one validation batch, log its loss, return ``(probabilities, targets)``."""
    step_loss, probabilities, step_targets = self.common_step(batch, batch_idx)
    self.log("val_loss", step_loss, on_epoch=True)
    return probabilities, step_targets
def validation_epoch_end(self, outputs):
    """Concatenate the per-step validation outputs and log their accuracy."""
    pred_chunks = []
    target_chunks = []
    for step_output in outputs:
        # Each entry is the (probabilities, targets) pair from validation_step.
        pred_chunks.append(step_output[0])
        target_chunks.append(step_output[1])
    accuracy = self.val_accuracy(torch.cat(pred_chunks), torch.cat(target_chunks))
    self.log("val_acc", accuracy, prog_bar=True)
def configure_optimizers(self):
    """Create the Adam optimizer and its per-step cosine-with-restarts schedule."""
    hparams = self.hparams
    optimizer = torch.optim.Adam(
        self.parameters(),
        lr=hparams.lr,
        weight_decay=hparams.weight_decay,
    )
    lr_schedule = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer,
        num_warmup_steps=hparams.warmup_steps,
        num_training_steps=self.trainer.estimated_stepping_batches,
    )
    # "interval": "step" makes the scheduler advance on every optimizer step.
    scheduler_config = {"scheduler": lr_schedule, "interval": "step"}
    return [optimizer], [scheduler_config]
def train_dataloader(self):
    """Build the shuffled loader over the training split."""
    loader_options = dict(
        batch_size=BATCH_SIZE_TRAIN,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=NUM_WORKERS,
        prefetch_factor=PREFETCH_FACTOR,
        persistent_workers=True,
        pin_memory=True,
    )
    return DataLoader(train_ds, **loader_options)
def val_dataloader(self):
    """Build the loader over the validation split (no shuffling requested)."""
    loader_options = dict(
        batch_size=BATCH_SIZE_EVAL,
        collate_fn=collate_fn,
        num_workers=NUM_WORKERS,
        prefetch_factor=PREFETCH_FACTOR,
        persistent_workers=True,
        pin_memory=True,
    )
    return DataLoader(val_ds, **loader_options)
def test_dataloader(self):
    """Build the loader over the test split, collating pixel values only."""
    loader_options = dict(
        batch_size=BATCH_SIZE_TEST,
        collate_fn=collate_fn_test,
        num_workers=NUM_WORKERS,
        prefetch_factor=PREFETCH_FACTOR,
        persistent_workers=True,
        pin_memory=True,
    )
    return DataLoader(test_ds, **loader_options)
if __name__ == "__main__":
    # Build the model to fine-tune.
    model_name = "timm/convnextv2_base.fcmae_ft_in22k_in1k_384"
    # model_name = "timm/convnextv2_large.fcmae_ft_in22k_in1k_384"  # TODO
    model = AIorNOT(model_name, lr=5e-5, warmup_steps=500)
    torch.set_float32_matmul_precision("medium")

    # # load checkpoint
    # model = AIorNOT.load_from_checkpoint(
    #     "/home/laurent/AIorNot/lightning_logs/version_73/checkpoints/epoch=2-step=1624.ckpt"
    # )

    # # compile model
    # model.net = torch.compile(model.net)

    # Callbacks: one checkpoint tracking the best "val_acc", one saved at the
    # end of each training epoch, plus the rich summary and progress bar.
    trainer_callbacks = [
        ModelCheckpoint(mode="max", monitor="val_acc"),
        ModelCheckpoint(save_on_train_epoch_end=True),
        RichModelSummary(max_depth=2),
        RichProgressBar(),
    ]
    trainer = pl.Trainer(
        accelerator="gpu",
        devices="auto",
        strategy="dp",
        max_epochs=5,
        precision=16,
        log_every_n_steps=25,
        val_check_interval=100,
        benchmark=True,
        callbacks=trainer_callbacks,
    )

    # Train, then run prediction over the test split.
    trainer.fit(model)
    test_results = trainer.predict(model, dataloaders=model.test_dataloader())

    # Write one "<id>,<probability>" row per test sample. The running row
    # number is used to index ``test_ds`` for the matching id.
    with open(f"results_{trainer.logger.version}.csv", "w") as f:
        f.write("id,label\n")
        probabilities = (
            prob
            for batch_logits in test_results
            for prob in batch_logits.float().sigmoid()
        )
        for row, prob in enumerate(probabilities):
            f.write(f"{test_ds[row]['id']},{float(prob)}\n")

44
src/transform.py Normal file
View file

@ -0,0 +1,44 @@
from torchvision.transforms import (
AugMix,
Compose,
Normalize,
RandomHorizontalFlip,
RandomVerticalFlip,
ToTensor,
)
# per-channel mean / std used to normalize image tensors (hard-coded values)
normalize = Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# training pipeline: random flips, then tensor conversion and normalization
_train_transforms = Compose([
    # AugMix(),
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    ToTensor(),
    normalize,
])

# validation pipeline: tensor conversion and normalization only
_val_transforms = Compose([ToTensor(), normalize])
# actually define the train transform
def train_transforms(examples):
    """Add training-pipeline tensors to a batch under ``pixel_values``.

    Each entry of ``examples["image"]`` is converted to RGB and passed
    through the training pipeline; the same ``examples`` mapping is returned.
    """
    rgb_images = (image.convert("RGB") for image in examples["image"])
    examples["pixel_values"] = [_train_transforms(rgb) for rgb in rgb_images]
    return examples
# actually define the validation transform
def val_transforms(examples):
    """Add validation-pipeline tensors to a batch under ``pixel_values``.

    Each entry of ``examples["image"]`` is converted to RGB and passed
    through the validation pipeline; the same ``examples`` mapping is returned.
    """
    rgb_images = (image.convert("RGB") for image in examples["image"])
    examples["pixel_values"] = [_val_transforms(rgb) for rgb in rgb_images]
    return examples