feat: 0.0192
This commit is contained in:
parent
fb5287eaff
commit
42ac3e0576
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,6 +1,8 @@
|
||||||
.direnv/
|
.direnv/
|
||||||
data/
|
data/
|
||||||
test-aiornot/
|
test-aiornot/
|
||||||
|
submissions/
|
||||||
|
lightning_logs/
|
||||||
|
|
||||||
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
# https://github.com/github/gitignore/blob/main/Python.gitignore
|
||||||
# Basic .gitignore for a python repo.
|
# Basic .gitignore for a python repo.
|
||||||
|
|
1353
poetry.lock
generated
1353
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -11,10 +11,13 @@ version = "0.1.0"
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
albumentations = "^1.3.0"
|
albumentations = "^1.3.0"
|
||||||
python = ">=3.8.1,<4.0"
|
python = ">=3.8.1,<4.0"
|
||||||
rich = "^12.6.0"
|
rich = "^13.3.1"
|
||||||
torch = "^1.13.1"
|
torch = "^1.13.1"
|
||||||
datasets = "^2.9.0"
|
datasets = "^2.9.0"
|
||||||
transformers = "^4.26.0"
|
transformers = "^4.26.0"
|
||||||
|
evaluate = "^0.4.0"
|
||||||
|
pytorch-lightning = "^1.9.0"
|
||||||
|
timm = "^0.6.12"
|
||||||
|
|
||||||
[tool.poetry.group.notebooks]
|
[tool.poetry.group.notebooks]
|
||||||
optional = true
|
optional = true
|
||||||
|
@ -22,6 +25,8 @@ optional = true
|
||||||
[tool.poetry.group.notebooks.dependencies]
|
[tool.poetry.group.notebooks.dependencies]
|
||||||
ipykernel = "^6.20.2"
|
ipykernel = "^6.20.2"
|
||||||
matplotlib = "^3.6.3"
|
matplotlib = "^3.6.3"
|
||||||
|
ipywidgets = "^8.0.4"
|
||||||
|
jupyter = "^1.0.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
Flake8-pyproject = "^1.1.0"
|
Flake8-pyproject = "^1.1.0"
|
||||||
|
|
File diff suppressed because one or more lines are too long
28
src/comparaison.py
Normal file
28
src/comparaison.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
from dataset import val_ds
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from rich.progress import track
|
||||||
|
|
||||||
|
|
||||||
|
val_labels = []
|
||||||
|
val_paths = []
|
||||||
|
for val_data in track(val_ds):
|
||||||
|
val_labels.append(val_data["label"])
|
||||||
|
val_paths.append(val_data["image_path"])
|
||||||
|
|
||||||
|
with open("results.csv", 'r') as f:
|
||||||
|
lines = f.read().splitlines()
|
||||||
|
lines = [line.split(',') for line in lines[1:]]
|
||||||
|
|
||||||
|
res = {0: [], 1: []}
|
||||||
|
for img_path, val_pred in lines:
|
||||||
|
index = val_paths.index(img_path)
|
||||||
|
label = val_labels[index]
|
||||||
|
if label == 1 and float(val_pred) < 0.5:
|
||||||
|
print(img_path)
|
||||||
|
res[label].append(float(val_pred))
|
||||||
|
|
||||||
|
plt.hist(res[0], bins=30, alpha=0.5, label="0", color="red")
|
||||||
|
plt.hist(res[1], bins=30, alpha=0.5, label="1", color="blue")
|
||||||
|
plt.yscale('log')
|
||||||
|
plt.legend()
|
||||||
|
plt.show()
|
|
@ -1,15 +1,18 @@
|
||||||
from datasets import load_dataset
|
import datasets
|
||||||
|
|
||||||
|
# set seed
|
||||||
|
RANDOM_SEED = 1010101
|
||||||
|
|
||||||
# load dataset
|
# load dataset
|
||||||
dataset = load_dataset("tocard-inc/aiornot").shuffle(seed=42)
|
dataset = datasets.load_dataset("competitions/aiornot")
|
||||||
|
|
||||||
# split up training into training + validation
|
# split up training into training + validation
|
||||||
splits = dataset['train'].train_test_split(test_size=0.1)
|
splits = dataset["train"].train_test_split(test_size=0.1, seed=RANDOM_SEED)
|
||||||
|
|
||||||
train_ds = splits['train']
|
# define train, validation and test datasets
|
||||||
val_ds = splits['test']
|
train_ds = splits["train"]
|
||||||
test_ds = dataset['test']
|
val_ds = splits["test"]
|
||||||
|
test_ds = dataset["test"]
|
||||||
|
|
||||||
labels = train_ds.features['label'].names
|
labels = train_ds.features['label'].names
|
||||||
id2label = {k: v for k, v in enumerate(labels)}
|
id2label = {k: v for k, v in enumerate(labels)}
|
||||||
|
|
|
@ -80,4 +80,4 @@ submission_df = pd.DataFrame({"id": file_paths, "label": pred_ids})
|
||||||
submission_df.head()
|
submission_df.head()
|
||||||
|
|
||||||
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
|
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
submission_df.to_csv(f"{TIMESTAMP}.csv", index=False)
|
submission_df.to_csv(f"submissions/{TIMESTAMP}.csv", index=False)
|
291
src/main.py
291
src/main.py
|
@ -1,61 +1,17 @@
|
||||||
from transformers import (
|
import pytorch_lightning as pl
|
||||||
AutoModelForImageClassification,
|
import timm
|
||||||
AutoFeatureExtractor,
|
|
||||||
TrainingArguments,
|
|
||||||
Trainer
|
|
||||||
)
|
|
||||||
from torchvision.transforms import (
|
|
||||||
CenterCrop,
|
|
||||||
Compose,
|
|
||||||
Normalize,
|
|
||||||
RandomHorizontalFlip,
|
|
||||||
RandomResizedCrop,
|
|
||||||
Resize,
|
|
||||||
ToTensor,
|
|
||||||
ToPILImage
|
|
||||||
)
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import torchmetrics
|
||||||
import matplotlib.pyplot as plt
|
from pytorch_lightning.callbacks import (
|
||||||
from datasets import load_metric
|
ModelCheckpoint,
|
||||||
|
RichModelSummary,
|
||||||
from dataset import train_ds, val_ds, test_ds, labels, id2label, label2id
|
RichProgressBar,
|
||||||
|
|
||||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
|
||||||
"facebook/convnext-xlarge-384-22k-1k")
|
|
||||||
|
|
||||||
normalize = Normalize(mean=feature_extractor.image_mean,
|
|
||||||
std=feature_extractor.image_std)
|
|
||||||
|
|
||||||
_train_transforms = Compose(
|
|
||||||
[
|
|
||||||
RandomResizedCrop((256, 256)),
|
|
||||||
RandomHorizontalFlip(),
|
|
||||||
ToTensor(),
|
|
||||||
normalize,
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
|
||||||
|
|
||||||
_val_transforms = Compose(
|
from dataset import train_ds, val_ds, test_ds
|
||||||
[
|
from transform import train_transforms, val_transforms
|
||||||
Resize((256, 256)),
|
|
||||||
CenterCrop((256, 256)),
|
|
||||||
ToTensor(),
|
|
||||||
normalize,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def train_transforms(examples):
|
|
||||||
examples['pixel_values'] = [_train_transforms(
|
|
||||||
image.convert("RGB")) for image in examples['image']]
|
|
||||||
return examples
|
|
||||||
|
|
||||||
|
|
||||||
def val_transforms(examples):
|
|
||||||
examples['pixel_values'] = [_val_transforms(
|
|
||||||
image.convert("RGB")) for image in examples['image']]
|
|
||||||
return examples
|
|
||||||
|
|
||||||
|
|
||||||
# Set the transforms
|
# Set the transforms
|
||||||
|
@ -63,76 +19,175 @@ train_ds.set_transform(train_transforms)
|
||||||
val_ds.set_transform(val_transforms)
|
val_ds.set_transform(val_transforms)
|
||||||
test_ds.set_transform(val_transforms)
|
test_ds.set_transform(val_transforms)
|
||||||
|
|
||||||
transform = ToPILImage()
|
|
||||||
|
|
||||||
img = train_ds[0]["pixel_values"]
|
|
||||||
img = img - min(img.flatten().numpy())
|
|
||||||
img = img / max(img.flatten().numpy())
|
|
||||||
|
|
||||||
plt.figure("Augmentation")
|
|
||||||
plt.subplot(1, 2, 1)
|
|
||||||
plt.imshow(train_ds[0]["image"])
|
|
||||||
plt.title("label: " + id2label[train_ds[0]["label"]])
|
|
||||||
plt.subplot(1, 2, 2)
|
|
||||||
plt.imshow(transform(img))
|
|
||||||
plt.title("augmented image")
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
# Prepare trainer
|
|
||||||
|
|
||||||
def collate_fn(examples):
|
def collate_fn(examples):
|
||||||
pixel_values = torch.stack([example["pixel_values"]
|
"""Collate function for training and validation."""
|
||||||
for example in examples])
|
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||||
labels = torch.tensor([example["label"] for example in examples])
|
labels = torch.tensor([example["label"] for example in examples])
|
||||||
return {"pixel_values": pixel_values, "labels": labels}
|
|
||||||
|
return pixel_values, labels
|
||||||
|
|
||||||
|
|
||||||
model = AutoModelForImageClassification.from_pretrained(
|
def collate_fn_test(examples):
|
||||||
'facebook/convnext-xlarge-384-22k-1k',
|
"""Collate function for testing."""
|
||||||
num_labels=len(labels),
|
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||||
id2label=id2label,
|
|
||||||
label2id=label2id,
|
|
||||||
ignore_mismatched_sizes=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
metric_name = "accuracy"
|
return pixel_values
|
||||||
|
|
||||||
args = TrainingArguments(
|
|
||||||
f"test-aiornot",
|
|
||||||
save_strategy="steps",
|
|
||||||
evaluation_strategy="steps",
|
|
||||||
learning_rate=2e-5,
|
|
||||||
per_device_train_batch_size=24,
|
|
||||||
per_device_eval_batch_size=24,
|
|
||||||
num_train_epochs=3,
|
|
||||||
weight_decay=0.01,
|
|
||||||
load_best_model_at_end=True,
|
|
||||||
metric_for_best_model=metric_name,
|
|
||||||
eval_steps=250,
|
|
||||||
logging_dir='logs',
|
|
||||||
logging_steps=10,
|
|
||||||
remove_unused_columns=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
metric = load_metric("accuracy")
|
|
||||||
|
|
||||||
|
|
||||||
def compute_metrics(eval_pred):
|
BATCH_SIZE_TEST = 250
|
||||||
predictions, labels = eval_pred
|
BATCH_SIZE_TRAIN = 35
|
||||||
predictions = np.argmax(predictions, axis=1)
|
BATCH_SIZE_EVAL = 250
|
||||||
return metric.compute(predictions=predictions, references=labels)
|
PREFETCH_FACTOR = 3
|
||||||
|
NUM_WORKERS = 8
|
||||||
|
|
||||||
|
|
||||||
trainer = Trainer(
|
class AIorNOT(pl.LightningModule):
|
||||||
model,
|
"""AIorNOT model."""
|
||||||
args,
|
|
||||||
train_dataset=train_ds,
|
def __init__(self, model_name, lr, weight_decay=1e-4, warmup_steps=0):
|
||||||
eval_dataset=val_ds,
|
"""Initialize model."""
|
||||||
data_collator=collate_fn,
|
super().__init__()
|
||||||
compute_metrics=compute_metrics,
|
self.save_hyperparameters()
|
||||||
tokenizer=feature_extractor,
|
|
||||||
)
|
self.net = timm.create_model(model_name, pretrained=True, num_classes=1)
|
||||||
# Start tensorboard.
|
|
||||||
# %load_ext tensorboard
|
self.criterion = torch.nn.BCEWithLogitsLoss()
|
||||||
# %tensorboard - -logdir logs/
|
self.val_accuracy = torchmetrics.Accuracy("binary")
|
||||||
trainer.train()
|
# TODO: add train_accuracy
|
||||||
|
|
||||||
|
def forward(self, pixel_values):
|
||||||
|
"""Forward pass."""
|
||||||
|
outputs = self.net(pixel_values)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def common_step(self, batch, batch_idx):
|
||||||
|
"""Common step for training and validation."""
|
||||||
|
pixel_values, labels = batch
|
||||||
|
labels = labels.float()
|
||||||
|
logits = self(pixel_values).squeeze(1).float()
|
||||||
|
loss = self.criterion(logits, labels)
|
||||||
|
|
||||||
|
return loss, logits.sigmoid(), labels
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
"""Training step."""
|
||||||
|
loss, _, _ = self.common_step(batch, batch_idx)
|
||||||
|
self.log("train_loss", loss)
|
||||||
|
self.log("lr", self.optimizers().param_groups[0]["lr"])
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
"""Validation step."""
|
||||||
|
loss, preds, targets = self.common_step(batch, batch_idx)
|
||||||
|
self.log("val_loss", loss, on_epoch=True)
|
||||||
|
|
||||||
|
return preds, targets
|
||||||
|
|
||||||
|
def validation_epoch_end(self, outputs):
|
||||||
|
"""Validation epoch end."""
|
||||||
|
preds = torch.cat([x[0] for x in outputs])
|
||||||
|
targets = torch.cat([x[1] for x in outputs])
|
||||||
|
acc = self.val_accuracy(preds, targets)
|
||||||
|
self.log("val_acc", acc, prog_bar=True)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
"""Configure optimizers."""
|
||||||
|
optimizer = torch.optim.Adam(self.parameters(), self.hparams.lr, weight_decay=self.hparams.weight_decay)
|
||||||
|
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.hparams.warmup_steps,
|
||||||
|
num_training_steps=self.trainer.estimated_stepping_batches,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
|
||||||
|
|
||||||
|
def train_dataloader(self):
|
||||||
|
"""Train dataloader."""
|
||||||
|
return DataLoader(
|
||||||
|
train_ds,
|
||||||
|
shuffle=True,
|
||||||
|
pin_memory=True,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
persistent_workers=True,
|
||||||
|
num_workers=NUM_WORKERS,
|
||||||
|
batch_size=BATCH_SIZE_TRAIN,
|
||||||
|
prefetch_factor=PREFETCH_FACTOR,
|
||||||
|
)
|
||||||
|
|
||||||
|
def val_dataloader(self):
|
||||||
|
"""Validation dataloader."""
|
||||||
|
return DataLoader(
|
||||||
|
val_ds,
|
||||||
|
pin_memory=True,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
persistent_workers=True,
|
||||||
|
num_workers=NUM_WORKERS,
|
||||||
|
batch_size=BATCH_SIZE_EVAL,
|
||||||
|
prefetch_factor=PREFETCH_FACTOR,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_dataloader(self):
|
||||||
|
"""Test dataloader."""
|
||||||
|
return DataLoader(
|
||||||
|
test_ds,
|
||||||
|
pin_memory=True,
|
||||||
|
collate_fn=collate_fn_test,
|
||||||
|
persistent_workers=True,
|
||||||
|
num_workers=NUM_WORKERS,
|
||||||
|
batch_size=BATCH_SIZE_TEST,
|
||||||
|
prefetch_factor=PREFETCH_FACTOR,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
# get model
|
||||||
|
model_name = "timm/convnextv2_base.fcmae_ft_in22k_in1k_384"
|
||||||
|
# model_name = "timm/convnextv2_large.fcmae_ft_in22k_in1k_384" # TODO
|
||||||
|
model = AIorNOT(model_name, lr=5e-5, warmup_steps=500)
|
||||||
|
|
||||||
|
torch.set_float32_matmul_precision("medium")
|
||||||
|
|
||||||
|
# # load checkpoint
|
||||||
|
# model = AIorNOT.load_from_checkpoint(
|
||||||
|
# "/home/laurent/AIorNot/lightning_logs/version_73/checkpoints/epoch=2-step=1624.ckpt"
|
||||||
|
# )
|
||||||
|
|
||||||
|
# # compile model
|
||||||
|
# model.net = torch.compile(model.net)
|
||||||
|
|
||||||
|
# define trainer
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
accelerator="gpu",
|
||||||
|
devices="auto",
|
||||||
|
strategy="dp",
|
||||||
|
max_epochs=5,
|
||||||
|
precision=16,
|
||||||
|
log_every_n_steps=25,
|
||||||
|
val_check_interval=100,
|
||||||
|
benchmark=True,
|
||||||
|
callbacks=[
|
||||||
|
ModelCheckpoint(mode="max", monitor="val_acc"),
|
||||||
|
ModelCheckpoint(save_on_train_epoch_end=True),
|
||||||
|
RichModelSummary(max_depth=2),
|
||||||
|
RichProgressBar(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# train model
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
# make predictions on test set
|
||||||
|
test_results = trainer.predict(model, dataloaders=model.test_dataloader())
|
||||||
|
|
||||||
|
# save predictions to csv
|
||||||
|
i = 0
|
||||||
|
with open(f"results_{trainer.logger.version}.csv", "w") as f:
|
||||||
|
f.write("id,label\n")
|
||||||
|
for test_result in test_results:
|
||||||
|
for logit in test_result.float().sigmoid():
|
||||||
|
f.write(f"{test_ds[i]['id']},{float(logit)}\n")
|
||||||
|
i += 1
|
||||||
|
|
44
src/transform.py
Normal file
44
src/transform.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
from torchvision.transforms import (
|
||||||
|
AugMix,
|
||||||
|
Compose,
|
||||||
|
Normalize,
|
||||||
|
RandomHorizontalFlip,
|
||||||
|
RandomVerticalFlip,
|
||||||
|
ToTensor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get feature extractor (to normalize images)
|
||||||
|
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||||
|
|
||||||
|
# define train transform
|
||||||
|
_train_transforms = Compose(
|
||||||
|
[
|
||||||
|
# AugMix(),
|
||||||
|
RandomHorizontalFlip(),
|
||||||
|
RandomVerticalFlip(),
|
||||||
|
ToTensor(),
|
||||||
|
normalize,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# define validation transform
|
||||||
|
_val_transforms = Compose(
|
||||||
|
[
|
||||||
|
ToTensor(),
|
||||||
|
normalize,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# actually define the train transform
|
||||||
|
def train_transforms(examples):
|
||||||
|
"""Transforms for training."""
|
||||||
|
examples["pixel_values"] = [_train_transforms(image.convert("RGB")) for image in examples["image"]]
|
||||||
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
# actually define the validation transform
|
||||||
|
def val_transforms(examples):
|
||||||
|
"""Transforms for validation."""
|
||||||
|
examples["pixel_values"] = [_val_transforms(image.convert("RGB")) for image in examples["image"]]
|
||||||
|
return examples
|
Reference in a new issue