From aac135a3fcd1c664717524f6ec4c3374af233c6b Mon Sep 17 00:00:00 2001
From: gdamms
Date: Thu, 2 Feb 2023 22:56:20 +0100
Subject: [PATCH] refactor: notebook to scripts main

TODO: reorganize main
---
 requirements.txt |  19 ++++++-
 src/dataset.py   |  25 +++++----
 src/main.py      | 138 +++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 170 insertions(+), 12 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index db7233e..c5f5b0a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,15 +4,21 @@ async-timeout==4.0.2
 attrs==22.2.0
 autopep8==2.0.1
 certifi==2022.12.7
-charset-normalizer==3.0.1
+charset-normalizer==2.1.1
 click==8.1.3
+contourpy==1.0.7
+cycler==0.11.0
 datasets==2.9.0
 dill==0.3.6
 filelock==3.9.0
+fonttools==4.38.0
 frozenlist==1.3.3
 fsspec==2023.1.0
 huggingface-hub==0.12.0
 idna==3.4
+joblib==1.2.0
+kiwisolver==1.4.4
+matplotlib==3.6.3
 multidict==6.0.4
 multiprocess==0.70.14
 mypy-extensions==0.4.3
@@ -21,17 +27,28 @@ opencv-python==4.7.0.68
 packaging==23.0
 pandas==1.5.3
 pathspec==0.11.0
+Pillow==9.4.0
 platformdirs==2.6.2
 pyarrow==11.0.0
 pycodestyle==2.10.0
+pyparsing==3.0.9
 python-dateutil==2.8.2
 pytz==2022.7.1
 PyYAML==6.0
+regex==2022.10.31
 requests==2.28.2
 responses==0.18.0
+scikit-learn==1.2.1
+scipy==1.10.0
 six==1.16.0
+threadpoolctl==3.1.0
+tokenizers==0.13.2
 tomli==2.0.1
+torch==1.7.1+cu110
+torchaudio==0.7.2
+torchvision==0.8.2+cu110
 tqdm==4.64.1
+transformers==4.26.0
 typing-extensions==4.4.0
 urllib3==1.26.14
 xxhash==3.2.0

diff --git a/src/dataset.py b/src/dataset.py
index a037fd2..fc39740 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -11,25 +11,28 @@
 train_ds = splits['train']
 val_ds = splits['test']
 test_ds = dataset['test']
+labels = train_ds.features['label'].names
+id2label = {k: v for k, v in enumerate(labels)}
+label2id = {v: k for k, v in enumerate(labels)}
 
 if __name__ == '__main__':
     import cv2
     import numpy as np
 
-    labels = train_ds.features['label'].names
     print(f"labels:\n {labels}")
-
-    id2label = {k: v for k, v in enumerate(labels)}
-    label2id = {v: k for k, v in enumerate(labels)}
     print(f"label-id correspondances:\n {label2id}\n {id2label}")
 
     idx = 0
-    cv2.imshow(
-        id2label[dataset['train'][idx]['label']],
-        cv2.cvtColor(np.array(dataset['train'][idx]
-                     ['image']), cv2.COLOR_BGR2RGB),
-    )
-    cv2.imshow("Test", cv2.cvtColor(
-        np.array(dataset['test'][idx]['image']), cv2.COLOR_BGR2RGB))
+    label = id2label[dataset['train'][idx]['label']]
+    image = cv2.cvtColor(
+        np.array(dataset['train'][idx]['image']),
+        cv2.COLOR_BGR2RGB)
+    cv2.namedWindow(label, cv2.WINDOW_NORMAL)
+    cv2.imshow(label, image)
+    image = cv2.cvtColor(
+        np.array(dataset['test'][idx]['image']),
+        cv2.COLOR_BGR2RGB)
+    cv2.namedWindow("Test", cv2.WINDOW_NORMAL)
+    cv2.imshow("Test", image)
 
     cv2.waitKey(0)

diff --git a/src/main.py b/src/main.py
index e69de29..0a4c347 100644
--- a/src/main.py
+++ b/src/main.py
@@ -0,0 +1,138 @@
+from transformers import (
+    AutoModelForImageClassification,
+    AutoFeatureExtractor,
+    TrainingArguments,
+    Trainer
+)
+from torchvision.transforms import (
+    CenterCrop,
+    Compose,
+    Normalize,
+    RandomHorizontalFlip,
+    RandomResizedCrop,
+    Resize,
+    ToTensor,
+    ToPILImage
+)
+import torch
+import numpy as np
+import matplotlib.pyplot as plt
+from datasets import load_metric
+
+from dataset import train_ds, val_ds, test_ds, labels, id2label, label2id
+
+feature_extractor = AutoFeatureExtractor.from_pretrained(
+    "microsoft/swinv2-base-patch4-window16-256")
+
+normalize = Normalize(mean=feature_extractor.image_mean,
+                      std=feature_extractor.image_std)
+
+_train_transforms = Compose(
+    [
+        RandomResizedCrop((256, 256)),
+        RandomHorizontalFlip(),
+        ToTensor(),
+        normalize,
+    ]
+)
+
+_val_transforms = Compose(
+    [
+        Resize((256, 256)),
+        CenterCrop((256, 256)),
+        ToTensor(),
+        normalize,
+    ]
+)
+
+
+def train_transforms(examples):
+    examples['pixel_values'] = [_train_transforms(
+        image.convert("RGB")) for image in examples['image']]
+    return examples
+
+
+def val_transforms(examples):
+    examples['pixel_values'] = [_val_transforms(
+        image.convert("RGB")) for image in examples['image']]
+    return examples
+
+
+# Set the transforms
+train_ds.set_transform(train_transforms)
+val_ds.set_transform(val_transforms)
+test_ds.set_transform(val_transforms)
+
+transform = ToPILImage()
+
+img = train_ds[0]["pixel_values"]
+img = img - min(img.flatten().numpy())
+img = img / max(img.flatten().numpy())
+
+plt.figure("Augmentation")
+plt.subplot(1, 2, 1)
+plt.imshow(train_ds[0]["image"])
+plt.title("label: " + id2label[train_ds[0]["label"]])
+plt.subplot(1, 2, 2)
+plt.imshow(transform(img))
+plt.title("augmented image")
+plt.show()
+
+# Prepare trainer
+
+def collate_fn(examples):
+    pixel_values = torch.stack([example["pixel_values"]
+                                for example in examples])
+    labels = torch.tensor([example["label"] for example in examples])
+    return {"pixel_values": pixel_values, "labels": labels}
+
+
+model = AutoModelForImageClassification.from_pretrained(
+    'microsoft/swinv2-base-patch4-window16-256',
+    num_labels=len(labels),
+    id2label=id2label,
+    label2id=label2id,
+    ignore_mismatched_sizes=True,
+)
+
+metric_name = "accuracy"
+
+args = TrainingArguments(
+    f"test-aiornot",
+    save_strategy="steps",
+    evaluation_strategy="steps",
+    learning_rate=2e-5,
+    per_device_train_batch_size=8,
+    per_device_eval_batch_size=8,
+    num_train_epochs=3,
+    weight_decay=0.01,
+    load_best_model_at_end=True,
+    metric_for_best_model=metric_name,
+    eval_steps=250,
+    logging_dir='logs',
+    logging_steps=10,
+    remove_unused_columns=False,
+)
+
+metric = load_metric("accuracy")
+
+
+def compute_metrics(eval_pred):
+    predictions, labels = eval_pred
+    predictions = np.argmax(predictions, axis=1)
+    return metric.compute(predictions=predictions, references=labels)
+
+
+trainer = Trainer(
+    model,
+    args,
+    train_dataset=train_ds,
+    eval_dataset=val_ds,
+    data_collator=collate_fn,
+    compute_metrics=compute_metrics,
+    tokenizer=feature_extractor,
+)
+# Start tensorboard.
+# %load_ext tensorboard
+# %tensorboard --logdir logs/
+trainer.train()
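
Note (not part of the patch): once trainer.train() completes, the fine-tuned model can be sanity-checked with a short inference script. The sketch below makes two assumptions not found in the patch: that a checkpoint was saved under the "test-aiornot" output directory (Trainer writes checkpoint-<step> subdirectories, and the feature extractor is saved alongside because it was passed as tokenizer), and that "example.jpg" stands in for any image to classify.

import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Hypothetical checkpoint path; substitute an actual saved step.
checkpoint = "test-aiornot/checkpoint-250"
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
model = AutoModelForImageClassification.from_pretrained(checkpoint)
model.eval()

# Preprocess a single image and run a forward pass without gradients.
image = Image.open("example.jpg").convert("RGB")  # placeholder image
inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# id2label was stored in the model config when the model was created.
print(model.config.id2label[int(logits.argmax(-1))])

As an aside, datasets.load_metric used in src/main.py is deprecated in favor of the standalone evaluate package (evaluate.load("accuracy") is a drop-in replacement inside compute_metrics); the pinned datasets==2.9.0 still supports it.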