From a88a55b8e8396dd927e865f02b09e344e8d0e300 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Fri, 10 Feb 2023 15:23:17 +0100 Subject: [PATCH] feat: argparse --- .vscode/settings.json | 4 +- src/dataset.py | 18 ++-- src/inference.py | 83 ------------------ src/main.py | 195 +++++++++--------------------------------- src/model.py | 125 +++++++++++++++++++++++++++ src/parse.py | 81 ++++++++++++++++++ src/transform.py | 18 ++-- 7 files changed, 268 insertions(+), 256 deletions(-) delete mode 100644 src/inference.py create mode 100644 src/model.py create mode 100644 src/parse.py diff --git a/.vscode/settings.json b/.vscode/settings.json index bc17442..cc8cdde 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,6 @@ { // "python.defaultInterpreterPath": ".venv/bin/python", - "python.analysis.typeCheckingMode": "basic", + "python.analysis.typeCheckingMode": "off", "python.formatting.provider": "black", "editor.formatOnSave": true, "python.linting.enabled": true, @@ -24,4 +24,4 @@ "**/__pycache__": true, "**/.mypy_cache": true, }, -} +} \ No newline at end of file diff --git a/src/dataset.py b/src/dataset.py index 746c0dd..2d44067 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -1,24 +1,21 @@ import datasets -# set seed -RANDOM_SEED = 1010101 - # load dataset dataset = datasets.load_dataset("competitions/aiornot") # split up training into training + validation -splits = dataset["train"].train_test_split(test_size=0.1, seed=RANDOM_SEED) +splits = dataset["train"].train_test_split(test_size=0.1) # define train, validation and test datasets train_ds = splits["train"] val_ds = splits["test"] test_ds = dataset["test"] -labels = ["Not AI", "AI"] +labels = ["NOT_AI", "AI"] id2label = {k: v for k, v in enumerate(labels)} label2id = {v: k for k, v in enumerate(labels)} -if __name__ == '__main__': +if __name__ == "__main__": import matplotlib.pyplot as plt @@ -26,11 +23,14 @@ if __name__ == '__main__': print(f"label-id correspondances:\n {label2id}\n {id2label}") idx = 0 - label = id2label[dataset['train'][idx]['label']] + label = id2label[dataset["train"][idx]["label"]] + plt.subplot(1, 2, 1) - plt.imshow(dataset['train'][idx]['image']) + plt.imshow(dataset["train"][idx]["image"]) plt.title(f"Label: {label}") + plt.subplot(1, 2, 2) - plt.imshow(dataset['test'][idx]['image']) + plt.imshow(dataset["test"][idx]["image"]) plt.title("Test") + plt.show() diff --git a/src/inference.py b/src/inference.py deleted file mode 100644 index b844511..0000000 --- a/src/inference.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import pandas as pd -from transformers import ( - AutoModelForImageClassification, - AutoFeatureExtractor, -) -from torchvision.transforms import ( - CenterCrop, - Compose, - Resize, - ToTensor, - Normalize, -) -from rich.progress import track -from datetime import datetime - -from dataset import test_ds - -feature_extractor = AutoFeatureExtractor.from_pretrained( - 'test-aiornot/checkpoint-500') - -normalize = Normalize(mean=feature_extractor.image_mean, - std=feature_extractor.image_std) - -_val_transforms = Compose( - [ - Resize((256, 256)), - CenterCrop((256, 256)), - ToTensor(), - normalize, - ] -) -def val_transforms(examples): - examples['pixel_values'] = [_val_transforms( - image.convert("RGB")) for image in examples['image']] - return examples - -test_ds.set_transform(val_transforms) - -def collate_fn(examples): - pixel_values = torch.stack([example["pixel_values"] for example in examples]) - labels = torch.tensor([example["label"] for example in examples]) - image_paths = [example["image_path"] for example in examples] - - return { - "pixel_values": pixel_values, - "labels": labels, - "image_path": image_paths, - } - -model = AutoModelForImageClassification.from_pretrained( - 'test-aiornot/checkpoint-1000', -) - -test_loader = torch.utils.data.DataLoader( - test_ds, batch_size=128, collate_fn=collate_fn, pin_memory=True, num_workers=4 -) -device = "cuda" if torch.cuda.is_available() else "cpu" -_ = model.to(device) - -file_paths = [] -pred_ids = [] - -for batch in track(test_loader): - image_paths = batch["image_path"] - image_paths = [x.split("/")[-1] for x in image_paths] - file_paths.extend(image_paths) - - images = batch["pixel_values"].to(device) - inputs = {"pixel_values": images} - - with torch.no_grad(): - logits = model(**inputs).logits - - predictions = logits.argmax(-1).cpu().numpy().tolist() - pred_ids.extend(predictions) - - -submission_df = pd.DataFrame({"id": file_paths, "label": pred_ids}) -submission_df.head() - -TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S") -submission_df.to_csv(f"submissions/{TIMESTAMP}.csv", index=False) \ No newline at end of file diff --git a/src/main.py b/src/main.py index 26f77b3..48b00e5 100644 --- a/src/main.py +++ b/src/main.py @@ -1,160 +1,40 @@ +import sys +from typing import List # TODO: update to python 3.11 + import pytorch_lightning as pl -import timm import torch -import torchmetrics from pytorch_lightning.callbacks import ( ModelCheckpoint, RichModelSummary, RichProgressBar, ) -from torch.utils.data import DataLoader -from transformers import get_cosine_with_hard_restarts_schedule_with_warmup +from rich.progress import track -from dataset import train_ds, val_ds, test_ds -from transform import train_transforms, val_transforms +from dataset import test_ds +from model import AIorNOT +from parse import parse_args -# Set the transforms -train_ds.set_transform(train_transforms) -val_ds.set_transform(val_transforms) -test_ds.set_transform(val_transforms) - - -def collate_fn(examples): - """Collate function for training and validation.""" - pixel_values = torch.stack([example["pixel_values"] for example in examples]) - labels = torch.tensor([example["label"] for example in examples]) - - return pixel_values, labels - - -def collate_fn_test(examples): - """Collate function for testing.""" - pixel_values = torch.stack([example["pixel_values"] for example in examples]) - - return pixel_values - - -BATCH_SIZE_TEST = 250 -BATCH_SIZE_TRAIN = 35 -BATCH_SIZE_EVAL = 250 -PREFETCH_FACTOR = 3 -NUM_WORKERS = 8 - - -class AIorNOT(pl.LightningModule): - """AIorNOT model.""" - - def __init__(self, model_name, lr, weight_decay=1e-4, warmup_steps=0): - """Initialize model.""" - super().__init__() - self.save_hyperparameters() - - self.net = timm.create_model(model_name, pretrained=True, num_classes=1) - - self.criterion = torch.nn.BCEWithLogitsLoss() - self.val_accuracy = torchmetrics.Accuracy("binary") - # TODO: add train_accuracy - - def forward(self, pixel_values): - """Forward pass.""" - outputs = self.net(pixel_values) - - return outputs - - def common_step(self, batch, batch_idx): - """Common step for training and validation.""" - pixel_values, labels = batch - labels = labels.float() - logits = self(pixel_values).squeeze(1).float() - loss = self.criterion(logits, labels) - - return loss, logits.sigmoid(), labels - - def training_step(self, batch, batch_idx): - """Training step.""" - loss, _, _ = self.common_step(batch, batch_idx) - self.log("train_loss", loss) - self.log("lr", self.optimizers().param_groups[0]["lr"]) - - return loss - - def validation_step(self, batch, batch_idx): - """Validation step.""" - loss, preds, targets = self.common_step(batch, batch_idx) - self.log("val_loss", loss, on_epoch=True) - - return preds, targets - - def validation_epoch_end(self, outputs): - """Validation epoch end.""" - preds = torch.cat([x[0] for x in outputs]) - targets = torch.cat([x[1] for x in outputs]) - acc = self.val_accuracy(preds, targets) - self.log("val_acc", acc, prog_bar=True) - - def configure_optimizers(self): - """Configure optimizers.""" - optimizer = torch.optim.Adam(self.parameters(), self.hparams.lr, weight_decay=self.hparams.weight_decay) - scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( - optimizer, - num_warmup_steps=self.hparams.warmup_steps, - num_training_steps=self.trainer.estimated_stepping_batches, - ) - - return [optimizer], [{"scheduler": scheduler, "interval": "step"}] - - def train_dataloader(self): - """Train dataloader.""" - return DataLoader( - train_ds, - shuffle=True, - pin_memory=True, - collate_fn=collate_fn, - persistent_workers=True, - num_workers=NUM_WORKERS, - batch_size=BATCH_SIZE_TRAIN, - prefetch_factor=PREFETCH_FACTOR, - ) - - def val_dataloader(self): - """Validation dataloader.""" - return DataLoader( - val_ds, - pin_memory=True, - collate_fn=collate_fn, - persistent_workers=True, - num_workers=NUM_WORKERS, - batch_size=BATCH_SIZE_EVAL, - prefetch_factor=PREFETCH_FACTOR, - ) - - def test_dataloader(self): - """Test dataloader.""" - return DataLoader( - test_ds, - pin_memory=True, - collate_fn=collate_fn_test, - persistent_workers=True, - num_workers=NUM_WORKERS, - batch_size=BATCH_SIZE_TEST, - prefetch_factor=PREFETCH_FACTOR, - ) - - -if __name__ == "__main__": - - # get model - model_name = "timm/convnextv2_base.fcmae_ft_in22k_in1k_384" - # model_name = "timm/convnextv2_large.fcmae_ft_in22k_in1k_384" # TODO - model = AIorNOT(model_name, lr=5e-5, warmup_steps=500) +def main(argv: List[str]) -> None: + """Main entrypoint for training and inference.""" + # parse args + args = parse_args(argv) + # stfu warnings torch.set_float32_matmul_precision("medium") - # # load checkpoint - # model = AIorNOT.load_from_checkpoint( - # "/home/laurent/AIorNot/lightning_logs/version_73/checkpoints/epoch=2-step=1624.ckpt" - # ) + # set seed + pl.seed_everything(args.seed, workers=True) + + if args.load_ckpt: + # get checkpointed model + model = AIorNOT.load_from_checkpoint( + args.load_ckpt, + args=args, + ) + else: + # get model + model = AIorNOT(args) # # compile model # model.net = torch.compile(model.net) @@ -164,7 +44,7 @@ if __name__ == "__main__": accelerator="gpu", devices="auto", strategy="dp", - max_epochs=5, + max_epochs=args.epochs, precision=16, log_every_n_steps=25, val_check_interval=100, @@ -180,14 +60,19 @@ if __name__ == "__main__": # train model trainer.fit(model) - # make predictions on test set - test_results = trainer.predict(model, dataloaders=model.test_dataloader()) + if not args.skip_csv: + # make predictions on test set + test_results = trainer.predict(model, dataloaders=model.test_dataloader()) - # save predictions to csv - i = 0 - with open(f"results_{trainer.logger.version}.csv", "w") as f: - f.write("id,label\n") - for test_result in test_results: - for logit in test_result.float().sigmoid(): - f.write(f"{test_ds[i]['id']},{float(logit)}\n") - i += 1 + # save predictions to csv + with open(f"submissions/results_{trainer.logger.version}.csv", "w") as f: + i = 0 + f.write("id,label\n") + for test_result in track(test_results): + for logit in test_result.float().sigmoid(): + f.write(f"{test_ds[i]['id']},{float(logit)}\n") + i += 1 + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000..6461d5e --- /dev/null +++ b/src/model.py @@ -0,0 +1,125 @@ +import pytorch_lightning as pl +import timm +import torch +import torchmetrics +from torch.utils.data import DataLoader +from transformers import get_cosine_with_hard_restarts_schedule_with_warmup + +from dataset import test_ds, train_ds, val_ds +from transform import train_transforms, val_transforms + + +def collate_fn(examples): + """Collate function for training and validation.""" + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + labels = torch.tensor([example["label"] for example in examples]) + + return pixel_values, labels + + +def collate_fn_test(examples): + """Collate function for testing.""" + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + + return pixel_values + + +class AIorNOT(pl.LightningModule): + """AIorNOT model.""" + + def __init__(self, args): + """Initialize model.""" + super().__init__() + self.args = args + self.save_hyperparameters() + + self.net = timm.create_model(args.model_name, pretrained=True, num_classes=1) + + self.criterion = torch.nn.BCEWithLogitsLoss() + self.val_accuracy = torchmetrics.Accuracy("binary") + # TODO: add train_accuracy + + def forward(self, pixel_values): + """Forward pass.""" + outputs = self.net(pixel_values) + + return outputs + + def common_step(self, batch, batch_idx): + """Common step for training and validation.""" + pixel_values, labels = batch + labels = labels.float() + logits = self(pixel_values).squeeze(1).float() + loss = self.criterion(logits, labels) + + return loss, logits.sigmoid(), labels + + def training_step(self, batch, batch_idx): + """Training step.""" + loss, _, _ = self.common_step(batch, batch_idx) + self.log("train_loss", loss) + self.log("lr", self.optimizers().param_groups[0]["lr"]) + + return loss + + def validation_step(self, batch, batch_idx): + """Validation step.""" + loss, preds, targets = self.common_step(batch, batch_idx) + self.log("val_loss", loss, on_epoch=True) + + return preds, targets + + def validation_epoch_end(self, outputs): + """Validation epoch end.""" + preds = torch.cat([x[0] for x in outputs]) + targets = torch.cat([x[1] for x in outputs]) + acc = self.val_accuracy(preds, targets) + self.log("val_acc", acc, prog_bar=True) + + def configure_optimizers(self): + """Configure optimizers.""" + optimizer = torch.optim.Adam(self.parameters(), self.args.lr, weight_decay=self.args.weight_decay) + scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( + optimizer, + num_warmup_steps=self.args.warmup_steps, + num_training_steps=self.trainer.estimated_stepping_batches, + ) + + return [optimizer], [{"scheduler": scheduler, "interval": "step"}] + + def train_dataloader(self): + """Train dataloader.""" + return DataLoader( + train_ds.with_transform(train_transforms), + shuffle=True, + pin_memory=True, + collate_fn=collate_fn, + persistent_workers=True, + num_workers=self.args.num_workers, + batch_size=self.args.batch_size, + prefetch_factor=self.args.prefetch_factor, + ) + + def val_dataloader(self): + """Validation dataloader.""" + return DataLoader( + val_ds.with_transform(val_transforms), + pin_memory=True, + collate_fn=collate_fn, + persistent_workers=True, + num_workers=self.args.num_workers, + batch_size=self.args.batch_size_val, + prefetch_factor=self.args.prefetch_factor, + ) + + def test_dataloader(self): + """Test dataloader.""" + return DataLoader( + test_ds.with_transform(val_transforms), + pin_memory=True, + collate_fn=collate_fn_test, + persistent_workers=True, + num_workers=self.args.num_workers, + batch_size=self.args.batch_size_val, + prefetch_factor=self.args.prefetch_factor, + ) diff --git a/src/parse.py b/src/parse.py new file mode 100644 index 0000000..5617efc --- /dev/null +++ b/src/parse.py @@ -0,0 +1,81 @@ +import argparse +from typing import List # TODO: update to python 3.11 + + +def parse_args(argv: List[str]) -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="pytorch lightning + classy vision TorchX example app", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="random seed", + ) + parser.add_argument( + "--model_name", + type=str, + default="timm/convnextv2_base.fcmae_ft_in22k_in1k_384", + help="model name to use from timm", + ) + parser.add_argument( + "--epochs", + type=int, + default=3, + help="number of epochs to train", + ) + parser.add_argument( + "--batch_size", + type=int, + default=35, + help="batch size to use for training", + ) + parser.add_argument( + "--batch_size_val", + type=int, + default=250, + help="batch size to use for validation and testing", + ) + parser.add_argument( + "--lr", + type=float, + default=5e-5, + help="learning rate", + ) + parser.add_argument( + "--weight_decay", + type=float, + default=1e-4, + help="weight decay", + ) + parser.add_argument( + "--warmup_steps", + type=int, + default=500, + help="number of warmup steps for cosine scheduler", + ) + parser.add_argument( + "--skip_csv", + desc="skip export test inference to csv file", + ) + parser.add_argument( + "--load_ckpt", + type=str, + help="checkpoint path to load from", + ) + parser.add_argument( + "--prefetch_factor", + type=int, + default=2, + help="prefetch factor for dataloaders", + ) + parser.add_argument( + "--num_workers", + type=int, + default=4, + help="number of workers for dataloaders", + ) + + return parser.parse_args(argv) diff --git a/src/transform.py b/src/transform.py index 1a6be40..f8bb9cf 100644 --- a/src/transform.py +++ b/src/transform.py @@ -1,4 +1,5 @@ import numpy as np +from imgaug.augmenters import JpegCompression from torchvision.transforms import ( AugMix, Compose, @@ -7,7 +8,6 @@ from torchvision.transforms import ( RandomVerticalFlip, ToTensor, ) -from imgaug.augmenters import JpegCompression # get feature extractor (to normalize images) normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) @@ -15,10 +15,10 @@ normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # define train transform _train_transforms = Compose( [ - AugMix(), + # AugMix(), RandomHorizontalFlip(), RandomVerticalFlip(), - lambda img : JpegCompression(compression=(0, 100))(image=np.array(img)), + # lambda img : JpegCompression(compression=(0, 100))(image=np.array(img)), ToTensor(), normalize, ] @@ -47,21 +47,25 @@ def val_transforms(examples): return examples -if __name__ == '__main__': +if __name__ == "__main__": - from dataset import train_ds import matplotlib.pyplot as plt + from dataset import train_ds + idx = 0 - img = train_ds[idx]['image'] + img = train_ds[idx]["image"] + plt.subplot(1, 2, 1) plt.imshow(img) plt.title("Original") + img = _train_transforms(img) img = np.array(img.permute(1, 2, 0)) img -= img.min() img /= img.max() + plt.subplot(1, 2, 2) plt.imshow(img) plt.title("Augmented") - plt.show() \ No newline at end of file + plt.show()