feat: argparse

Laurent Fainsin 2023-02-10 15:23:17 +01:00
parent 89f37e6bbf
commit a88a55b8e8
7 changed files with 268 additions and 256 deletions

.vscode/settings.json
@@ -1,6 +1,6 @@
 {
     // "python.defaultInterpreterPath": ".venv/bin/python",
-    "python.analysis.typeCheckingMode": "basic",
+    "python.analysis.typeCheckingMode": "off",
     "python.formatting.provider": "black",
     "editor.formatOnSave": true,
     "python.linting.enabled": true,
@@ -24,4 +24,4 @@
         "**/__pycache__": true,
         "**/.mypy_cache": true,
     },
 }

src/dataset.py
@ -1,24 +1,21 @@
import datasets import datasets
# set seed
RANDOM_SEED = 1010101
# load dataset # load dataset
dataset = datasets.load_dataset("competitions/aiornot") dataset = datasets.load_dataset("competitions/aiornot")
# split up training into training + validation # split up training into training + validation
splits = dataset["train"].train_test_split(test_size=0.1, seed=RANDOM_SEED) splits = dataset["train"].train_test_split(test_size=0.1)
# define train, validation and test datasets # define train, validation and test datasets
train_ds = splits["train"] train_ds = splits["train"]
val_ds = splits["test"] val_ds = splits["test"]
test_ds = dataset["test"] test_ds = dataset["test"]
labels = ["Not AI", "AI"] labels = ["NOT_AI", "AI"]
id2label = {k: v for k, v in enumerate(labels)} id2label = {k: v for k, v in enumerate(labels)}
label2id = {v: k for k, v in enumerate(labels)} label2id = {v: k for k, v in enumerate(labels)}
if __name__ == '__main__': if __name__ == "__main__":
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -26,11 +23,14 @@ if __name__ == '__main__':
print(f"label-id correspondances:\n {label2id}\n {id2label}") print(f"label-id correspondances:\n {label2id}\n {id2label}")
idx = 0 idx = 0
label = id2label[dataset['train'][idx]['label']] label = id2label[dataset["train"][idx]["label"]]
plt.subplot(1, 2, 1) plt.subplot(1, 2, 1)
plt.imshow(dataset['train'][idx]['image']) plt.imshow(dataset["train"][idx]["image"])
plt.title(f"Label: {label}") plt.title(f"Label: {label}")
plt.subplot(1, 2, 2) plt.subplot(1, 2, 2)
plt.imshow(dataset['test'][idx]['image']) plt.imshow(dataset["test"][idx]["image"])
plt.title("Test") plt.title("Test")
plt.show() plt.show()
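
For reference, the two dict comprehensions above produce the following mappings (plain Python semantics, shown as a REPL session):

>>> labels = ["NOT_AI", "AI"]
>>> {k: v for k, v in enumerate(labels)}  # id2label
{0: 'NOT_AI', 1: 'AI'}
>>> {v: k for k, v in enumerate(labels)}  # label2id
{'NOT_AI': 0, 'AI': 1}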

(deleted file)
@@ -1,83 +0,0 @@
-import torch
-import pandas as pd
-from transformers import (
-    AutoModelForImageClassification,
-    AutoFeatureExtractor,
-)
-from torchvision.transforms import (
-    CenterCrop,
-    Compose,
-    Resize,
-    ToTensor,
-    Normalize,
-)
-from rich.progress import track
-from datetime import datetime
-
-from dataset import test_ds
-
-feature_extractor = AutoFeatureExtractor.from_pretrained(
-    'test-aiornot/checkpoint-500')
-
-normalize = Normalize(mean=feature_extractor.image_mean,
-                      std=feature_extractor.image_std)
-
-_val_transforms = Compose(
-    [
-        Resize((256, 256)),
-        CenterCrop((256, 256)),
-        ToTensor(),
-        normalize,
-    ]
-)
-
-
-def val_transforms(examples):
-    examples['pixel_values'] = [_val_transforms(
-        image.convert("RGB")) for image in examples['image']]
-    return examples
-
-
-test_ds.set_transform(val_transforms)
-
-
-def collate_fn(examples):
-    pixel_values = torch.stack([example["pixel_values"] for example in examples])
-    labels = torch.tensor([example["label"] for example in examples])
-    image_paths = [example["image_path"] for example in examples]
-    return {
-        "pixel_values": pixel_values,
-        "labels": labels,
-        "image_path": image_paths,
-    }
-
-
-model = AutoModelForImageClassification.from_pretrained(
-    'test-aiornot/checkpoint-1000',
-)
-
-test_loader = torch.utils.data.DataLoader(
-    test_ds, batch_size=128, collate_fn=collate_fn, pin_memory=True, num_workers=4
-)
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-_ = model.to(device)
-
-file_paths = []
-pred_ids = []
-
-for batch in track(test_loader):
-    image_paths = batch["image_path"]
-    image_paths = [x.split("/")[-1] for x in image_paths]
-    file_paths.extend(image_paths)
-
-    images = batch["pixel_values"].to(device)
-    inputs = {"pixel_values": images}
-
-    with torch.no_grad():
-        logits = model(**inputs).logits
-
-    predictions = logits.argmax(-1).cpu().numpy().tolist()
-    pred_ids.extend(predictions)
-
-submission_df = pd.DataFrame({"id": file_paths, "label": pred_ids})
-submission_df.head()
-
-TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
-submission_df.to_csv(f"submissions/{TIMESTAMP}.csv", index=False)
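
The deleted script above scored the test set with a two-logit classification head and hard argmax labels; the pipeline that replaces it (next file) uses the single-logit BCEWithLogitsLoss head and writes sigmoid probabilities to the csv instead. A minimal sketch of the difference, with made-up logit values:

import torch

# old head: two logits per image -> hard 0/1 label via argmax
two_logit = torch.tensor([[1.3, -0.2], [-0.5, 0.9]])
print(two_logit.argmax(-1).tolist())            # [0, 1]

# new head: one logit per image -> probability via sigmoid (what gets written out)
one_logit = torch.tensor([[-1.3], [0.8]])
print(one_logit.squeeze(1).sigmoid().tolist())  # ~[0.214, 0.690]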

src/main.py
@@ -1,160 +1,40 @@
+import sys
+from typing import List  # TODO: update to python 3.11
+
 import pytorch_lightning as pl
-import timm
 import torch
-import torchmetrics
 from pytorch_lightning.callbacks import (
     ModelCheckpoint,
     RichModelSummary,
     RichProgressBar,
 )
-from torch.utils.data import DataLoader
-from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
-
-from dataset import train_ds, val_ds, test_ds
-from transform import train_transforms, val_transforms
-
-# Set the transforms
-train_ds.set_transform(train_transforms)
-val_ds.set_transform(val_transforms)
-test_ds.set_transform(val_transforms)
-
-
-def collate_fn(examples):
-    """Collate function for training and validation."""
-    pixel_values = torch.stack([example["pixel_values"] for example in examples])
-    labels = torch.tensor([example["label"] for example in examples])
-    return pixel_values, labels
-
-
-def collate_fn_test(examples):
-    """Collate function for testing."""
-    pixel_values = torch.stack([example["pixel_values"] for example in examples])
-    return pixel_values
-
-
-BATCH_SIZE_TEST = 250
-BATCH_SIZE_TRAIN = 35
-BATCH_SIZE_EVAL = 250
-PREFETCH_FACTOR = 3
-NUM_WORKERS = 8
-
-
-class AIorNOT(pl.LightningModule):
-    """AIorNOT model."""
-
-    def __init__(self, model_name, lr, weight_decay=1e-4, warmup_steps=0):
-        """Initialize model."""
-        super().__init__()
-        self.save_hyperparameters()
-
-        self.net = timm.create_model(model_name, pretrained=True, num_classes=1)
-        self.criterion = torch.nn.BCEWithLogitsLoss()
-        self.val_accuracy = torchmetrics.Accuracy("binary")
-        # TODO: add train_accuracy
-
-    def forward(self, pixel_values):
-        """Forward pass."""
-        outputs = self.net(pixel_values)
-        return outputs
-
-    def common_step(self, batch, batch_idx):
-        """Common step for training and validation."""
-        pixel_values, labels = batch
-        labels = labels.float()
-        logits = self(pixel_values).squeeze(1).float()
-        loss = self.criterion(logits, labels)
-        return loss, logits.sigmoid(), labels
-
-    def training_step(self, batch, batch_idx):
-        """Training step."""
-        loss, _, _ = self.common_step(batch, batch_idx)
-        self.log("train_loss", loss)
-        self.log("lr", self.optimizers().param_groups[0]["lr"])
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        """Validation step."""
-        loss, preds, targets = self.common_step(batch, batch_idx)
-        self.log("val_loss", loss, on_epoch=True)
-        return preds, targets
-
-    def validation_epoch_end(self, outputs):
-        """Validation epoch end."""
-        preds = torch.cat([x[0] for x in outputs])
-        targets = torch.cat([x[1] for x in outputs])
-        acc = self.val_accuracy(preds, targets)
-        self.log("val_acc", acc, prog_bar=True)
-
-    def configure_optimizers(self):
-        """Configure optimizers."""
-        optimizer = torch.optim.Adam(self.parameters(), self.hparams.lr, weight_decay=self.hparams.weight_decay)
-        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer,
-            num_warmup_steps=self.hparams.warmup_steps,
-            num_training_steps=self.trainer.estimated_stepping_batches,
-        )
-        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
-
-    def train_dataloader(self):
-        """Train dataloader."""
-        return DataLoader(
-            train_ds,
-            shuffle=True,
-            pin_memory=True,
-            collate_fn=collate_fn,
-            persistent_workers=True,
-            num_workers=NUM_WORKERS,
-            batch_size=BATCH_SIZE_TRAIN,
-            prefetch_factor=PREFETCH_FACTOR,
-        )
-
-    def val_dataloader(self):
-        """Validation dataloader."""
-        return DataLoader(
-            val_ds,
-            pin_memory=True,
-            collate_fn=collate_fn,
-            persistent_workers=True,
-            num_workers=NUM_WORKERS,
-            batch_size=BATCH_SIZE_EVAL,
-            prefetch_factor=PREFETCH_FACTOR,
-        )
-
-    def test_dataloader(self):
-        """Test dataloader."""
-        return DataLoader(
-            test_ds,
-            pin_memory=True,
-            collate_fn=collate_fn_test,
-            persistent_workers=True,
-            num_workers=NUM_WORKERS,
-            batch_size=BATCH_SIZE_TEST,
-            prefetch_factor=PREFETCH_FACTOR,
-        )
-
-
-if __name__ == "__main__":
-    # get model
-    model_name = "timm/convnextv2_base.fcmae_ft_in22k_in1k_384"
-    # model_name = "timm/convnextv2_large.fcmae_ft_in22k_in1k_384"  # TODO
-    model = AIorNOT(model_name, lr=5e-5, warmup_steps=500)
-
-    # stfu warnings
+from rich.progress import track
+
+from dataset import test_ds
+from model import AIorNOT
+from parse import parse_args
+
+
+def main(argv: List[str]) -> None:
+    """Main entrypoint for training and inference."""
+    # parse args
+    args = parse_args(argv)
+
     torch.set_float32_matmul_precision("medium")
 
-    # # load checkpoint
-    # model = AIorNOT.load_from_checkpoint(
-    #     "/home/laurent/AIorNot/lightning_logs/version_73/checkpoints/epoch=2-step=1624.ckpt"
-    # )
+    # set seed
+    pl.seed_everything(args.seed, workers=True)
+
+    if args.load_ckpt:
+        # get checkpointed model
+        model = AIorNOT.load_from_checkpoint(
+            args.load_ckpt,
+            args=args,
+        )
+    else:
+        # get model
+        model = AIorNOT(args)
 
     # # compile model
     # model.net = torch.compile(model.net)
@@ -164,7 +44,7 @@ if __name__ == "__main__":
         accelerator="gpu",
         devices="auto",
         strategy="dp",
-        max_epochs=5,
+        max_epochs=args.epochs,
        precision=16,
         log_every_n_steps=25,
         val_check_interval=100,
@@ -180,14 +60,19 @@ if __name__ == "__main__":
     # train model
     trainer.fit(model)
 
-    # make predictions on test set
-    test_results = trainer.predict(model, dataloaders=model.test_dataloader())
-
-    # save predictions to csv
-    i = 0
-    with open(f"results_{trainer.logger.version}.csv", "w") as f:
-        f.write("id,label\n")
-        for test_result in test_results:
-            for logit in test_result.float().sigmoid():
-                f.write(f"{test_ds[i]['id']},{float(logit)}\n")
-                i += 1
+    if not args.skip_csv:
+        # make predictions on test set
+        test_results = trainer.predict(model, dataloaders=model.test_dataloader())
+
+        # save predictions to csv
+        with open(f"submissions/results_{trainer.logger.version}.csv", "w") as f:
+            i = 0
+            f.write("id,label\n")
+            for test_result in track(test_results):
+                for logit in test_result.float().sigmoid():
+                    f.write(f"{test_ds[i]['id']},{float(logit)}\n")
+                    i += 1
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
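
Since main now takes argv explicitly, the entrypoint can also be exercised programmatically. A hypothetical smoke run, assuming the module is src/main.py (the filename is not shown in this view) and a GPU machine; the flag values are illustrative only:

from main import main

# equivalent to: python src/main.py --epochs 1 --batch_size 8 --skip_csv
main(["--epochs", "1", "--batch_size", "8", "--skip_csv"])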

src/model.py (new file, 125 lines)
@@ -0,0 +1,125 @@
+import pytorch_lightning as pl
+import timm
+import torch
+import torchmetrics
+from torch.utils.data import DataLoader
+from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
+
+from dataset import test_ds, train_ds, val_ds
+from transform import train_transforms, val_transforms
+
+
+def collate_fn(examples):
+    """Collate function for training and validation."""
+    pixel_values = torch.stack([example["pixel_values"] for example in examples])
+    labels = torch.tensor([example["label"] for example in examples])
+    return pixel_values, labels
+
+
+def collate_fn_test(examples):
+    """Collate function for testing."""
+    pixel_values = torch.stack([example["pixel_values"] for example in examples])
+    return pixel_values
+
+
+class AIorNOT(pl.LightningModule):
+    """AIorNOT model."""
+
+    def __init__(self, args):
+        """Initialize model."""
+        super().__init__()
+        self.args = args
+        self.save_hyperparameters()
+
+        self.net = timm.create_model(args.model_name, pretrained=True, num_classes=1)
+        self.criterion = torch.nn.BCEWithLogitsLoss()
+        self.val_accuracy = torchmetrics.Accuracy("binary")
+        # TODO: add train_accuracy
+
+    def forward(self, pixel_values):
+        """Forward pass."""
+        outputs = self.net(pixel_values)
+        return outputs
+
+    def common_step(self, batch, batch_idx):
+        """Common step for training and validation."""
+        pixel_values, labels = batch
+        labels = labels.float()
+        logits = self(pixel_values).squeeze(1).float()
+        loss = self.criterion(logits, labels)
+        return loss, logits.sigmoid(), labels
+
+    def training_step(self, batch, batch_idx):
+        """Training step."""
+        loss, _, _ = self.common_step(batch, batch_idx)
+        self.log("train_loss", loss)
+        self.log("lr", self.optimizers().param_groups[0]["lr"])
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        """Validation step."""
+        loss, preds, targets = self.common_step(batch, batch_idx)
+        self.log("val_loss", loss, on_epoch=True)
+        return preds, targets
+
+    def validation_epoch_end(self, outputs):
+        """Validation epoch end."""
+        preds = torch.cat([x[0] for x in outputs])
+        targets = torch.cat([x[1] for x in outputs])
+        acc = self.val_accuracy(preds, targets)
+        self.log("val_acc", acc, prog_bar=True)
+
+    def configure_optimizers(self):
+        """Configure optimizers."""
+        optimizer = torch.optim.Adam(self.parameters(), self.args.lr, weight_decay=self.args.weight_decay)
+        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
+            optimizer,
+            num_warmup_steps=self.args.warmup_steps,
+            num_training_steps=self.trainer.estimated_stepping_batches,
+        )
+        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
+
+    def train_dataloader(self):
+        """Train dataloader."""
+        return DataLoader(
+            train_ds.with_transform(train_transforms),
+            shuffle=True,
+            pin_memory=True,
+            collate_fn=collate_fn,
+            persistent_workers=True,
+            num_workers=self.args.num_workers,
+            batch_size=self.args.batch_size,
+            prefetch_factor=self.args.prefetch_factor,
+        )
+
+    def val_dataloader(self):
+        """Validation dataloader."""
+        return DataLoader(
+            val_ds.with_transform(val_transforms),
+            pin_memory=True,
+            collate_fn=collate_fn,
+            persistent_workers=True,
+            num_workers=self.args.num_workers,
+            batch_size=self.args.batch_size_val,
+            prefetch_factor=self.args.prefetch_factor,
+        )
+
+    def test_dataloader(self):
+        """Test dataloader."""
+        return DataLoader(
+            test_ds.with_transform(val_transforms),
+            pin_memory=True,
+            collate_fn=collate_fn_test,
+            persistent_workers=True,
+            num_workers=self.args.num_workers,
+            batch_size=self.args.batch_size_val,
+            prefetch_factor=self.args.prefetch_factor,
+        )
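
Because the model now takes a single args namespace, it can be constructed without going through the CLI, e.g. in a notebook. A minimal sketch, assuming the defaults from src/parse.py below:

import argparse

from model import AIorNOT

args = argparse.Namespace(
    model_name="timm/convnextv2_base.fcmae_ft_in22k_in1k_384",
    lr=5e-5,
    weight_decay=1e-4,
    warmup_steps=500,
    batch_size=35,
    batch_size_val=250,
    num_workers=4,
    prefetch_factor=2,
)
model = AIorNOT(args)  # downloads the pretrained timm weights on first use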

src/parse.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+import argparse
+from typing import List  # TODO: update to python 3.11
+
+
+def parse_args(argv: List[str]) -> argparse.Namespace:
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(
+        description="pytorch lightning + classy vision TorchX example app",
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="random seed",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="timm/convnextv2_base.fcmae_ft_in22k_in1k_384",
+        help="model name to use from timm",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=3,
+        help="number of epochs to train",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=35,
+        help="batch size to use for training",
+    )
+    parser.add_argument(
+        "--batch_size_val",
+        type=int,
+        default=250,
+        help="batch size to use for validation and testing",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=5e-5,
+        help="learning rate",
+    )
+    parser.add_argument(
+        "--weight_decay",
+        type=float,
+        default=1e-4,
+        help="weight decay",
+    )
+    parser.add_argument(
+        "--warmup_steps",
+        type=int,
+        default=500,
+        help="number of warmup steps for cosine scheduler",
+    )
+    parser.add_argument(
+        "--skip_csv",
+        action="store_true",
+        help="skip export of test inference to csv file",
+    )
+    parser.add_argument(
+        "--load_ckpt",
+        type=str,
+        help="checkpoint path to load from",
+    )
+    parser.add_argument(
+        "--prefetch_factor",
+        type=int,
+        default=2,
+        help="prefetch factor for dataloaders",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=4,
+        help="number of workers for dataloaders",
+    )
+
+    return parser.parse_args(argv)
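
Because parse_args takes an explicit argv list instead of reading sys.argv itself, it is trivially unit-testable:

args = parse_args(["--lr", "1e-4", "--skip_csv"])
print(args.lr, args.epochs, args.skip_csv)  # 0.0001 3 True (unspecified fields keep their defaults)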

src/transform.py
@@ -1,4 +1,5 @@
 import numpy as np
+from imgaug.augmenters import JpegCompression
 from torchvision.transforms import (
     AugMix,
     Compose,
@@ -7,7 +8,6 @@ from torchvision.transforms import (
     RandomVerticalFlip,
     ToTensor,
 )
-from imgaug.augmenters import JpegCompression
 
 # get feature extractor (to normalize images)
 normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
@@ -15,10 +15,10 @@ normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 # define train transform
 _train_transforms = Compose(
     [
-        AugMix(),
+        # AugMix(),
         RandomHorizontalFlip(),
         RandomVerticalFlip(),
-        lambda img : JpegCompression(compression=(0, 100))(image=np.array(img)),
+        # lambda img : JpegCompression(compression=(0, 100))(image=np.array(img)),
         ToTensor(),
         normalize,
     ]
@@ -47,21 +47,25 @@ def val_transforms(examples):
     return examples
 
 
-if __name__ == '__main__':
-    from dataset import train_ds
+if __name__ == "__main__":
     import matplotlib.pyplot as plt
 
+    from dataset import train_ds
+
     idx = 0
-    img = train_ds[idx]['image']
+    img = train_ds[idx]["image"]
 
     plt.subplot(1, 2, 1)
     plt.imshow(img)
     plt.title("Original")
 
     img = _train_transforms(img)
     img = np.array(img.permute(1, 2, 0))
     img -= img.min()
     img /= img.max()
 
     plt.subplot(1, 2, 2)
     plt.imshow(img)
     plt.title("Augmented")
 
     plt.show()
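
For reference, the with_transform calls in src/model.py apply these batch transforms lazily, at access time, rather than preprocessing the whole dataset up front. A minimal sketch of that behavior (output shape depends on the source image size):

from dataset import val_ds
from transform import val_transforms

ds = val_ds.with_transform(val_transforms)  # nothing is computed here
sample = ds[0]                              # the transform runs on access
print(sample["pixel_values"].shape)         # e.g. torch.Size([3, H, W])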