refactor: notebook to scripts main

TODO: reorganize main
This commit is contained in:
gdamms 2023-02-02 22:56:20 +01:00
parent 7dfcc358e4
commit aac135a3fc
3 changed files with 170 additions and 12 deletions

View file

@ -4,15 +4,21 @@ async-timeout==4.0.2
attrs==22.2.0 attrs==22.2.0
autopep8==2.0.1 autopep8==2.0.1
certifi==2022.12.7 certifi==2022.12.7
charset-normalizer==3.0.1 charset-normalizer==2.1.1
click==8.1.3 click==8.1.3
contourpy==1.0.7
cycler==0.11.0
datasets==2.9.0 datasets==2.9.0
dill==0.3.6 dill==0.3.6
filelock==3.9.0 filelock==3.9.0
fonttools==4.38.0
frozenlist==1.3.3 frozenlist==1.3.3
fsspec==2023.1.0 fsspec==2023.1.0
huggingface-hub==0.12.0 huggingface-hub==0.12.0
idna==3.4 idna==3.4
joblib==1.2.0
kiwisolver==1.4.4
matplotlib==3.6.3
multidict==6.0.4 multidict==6.0.4
multiprocess==0.70.14 multiprocess==0.70.14
mypy-extensions==0.4.3 mypy-extensions==0.4.3
@ -21,17 +27,28 @@ opencv-python==4.7.0.68
packaging==23.0 packaging==23.0
pandas==1.5.3 pandas==1.5.3
pathspec==0.11.0 pathspec==0.11.0
Pillow==9.4.0
platformdirs==2.6.2 platformdirs==2.6.2
pyarrow==11.0.0 pyarrow==11.0.0
pycodestyle==2.10.0 pycodestyle==2.10.0
pyparsing==3.0.9
python-dateutil==2.8.2 python-dateutil==2.8.2
pytz==2022.7.1 pytz==2022.7.1
PyYAML==6.0 PyYAML==6.0
regex==2022.10.31
requests==2.28.2 requests==2.28.2
responses==0.18.0 responses==0.18.0
scikit-learn==1.2.1
scipy==1.10.0
six==1.16.0 six==1.16.0
threadpoolctl==3.1.0
tokenizers==0.13.2
tomli==2.0.1 tomli==2.0.1
torch==1.7.1+cu110
torchaudio==0.7.2
torchvision==0.8.2+cu110
tqdm==4.64.1 tqdm==4.64.1
transformers==4.26.0
typing-extensions==4.4.0 typing-extensions==4.4.0
urllib3==1.26.14 urllib3==1.26.14
xxhash==3.2.0 xxhash==3.2.0

View file

@ -11,25 +11,28 @@ train_ds = splits['train']
val_ds = splits['test'] val_ds = splits['test']
test_ds = dataset['test'] test_ds = dataset['test']
labels = train_ds.features['label'].names
id2label = {k: v for k, v in enumerate(labels)}
label2id = {v: k for k, v in enumerate(labels)}
if __name__ == '__main__': if __name__ == '__main__':
import cv2 import cv2
import numpy as np import numpy as np
labels = train_ds.features['label'].names
print(f"labels:\n {labels}") print(f"labels:\n {labels}")
id2label = {k: v for k, v in enumerate(labels)}
label2id = {v: k for k, v in enumerate(labels)}
print(f"label-id correspondances:\n {label2id}\n {id2label}") print(f"label-id correspondances:\n {label2id}\n {id2label}")
idx = 0 idx = 0
cv2.imshow( label = id2label[dataset['train'][idx]['label']]
id2label[dataset['train'][idx]['label']], image = cv2.cvtColor(
cv2.cvtColor(np.array(dataset['train'][idx] np.array(dataset['train'][idx]['image']),
['image']), cv2.COLOR_BGR2RGB), cv2.COLOR_BGR2RGB)
) cv2.namedWindow(label, cv2.WINDOW_NORMAL)
cv2.imshow("Test", cv2.cvtColor( cv2.imshow(label, image)
np.array(dataset['test'][idx]['image']), cv2.COLOR_BGR2RGB)) image = cv2.cvtColor(
np.array(dataset['test'][idx]['image']),
cv2.COLOR_BGR2RGB)
cv2.namedWindow("Test", cv2.WINDOW_NORMAL)
cv2.imshow("Test", image)
cv2.waitKey(0) cv2.waitKey(0)

View file

@ -0,0 +1,138 @@
from transformers import (
AutoModelForImageClassification,
AutoFeatureExtractor,
TrainingArguments,
Trainer
)
from torchvision.transforms import (
CenterCrop,
Compose,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor,
ToPILImage
)
import torch
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_metric
from dataset import train_ds, val_ds, test_ds, labels, id2label, label2id
feature_extractor = AutoFeatureExtractor.from_pretrained(
"microsoft/swinv2-base-patch4-window16-256")
normalize = Normalize(mean=feature_extractor.image_mean,
std=feature_extractor.image_std)
_train_transforms = Compose(
[
RandomResizedCrop((256, 256)),
RandomHorizontalFlip(),
ToTensor(),
normalize,
]
)
_val_transforms = Compose(
[
Resize((256, 256)),
CenterCrop((256, 256)),
ToTensor(),
normalize,
]
)
def train_transforms(examples):
examples['pixel_values'] = [_train_transforms(
image.convert("RGB")) for image in examples['image']]
return examples
def val_transforms(examples):
examples['pixel_values'] = [_val_transforms(
image.convert("RGB")) for image in examples['image']]
return examples
# Set the transforms
train_ds.set_transform(train_transforms)
val_ds.set_transform(val_transforms)
test_ds.set_transform(val_transforms)
transform = ToPILImage()
img = train_ds[0]["pixel_values"]
img = img - min(img.flatten().numpy())
img = img / max(img.flatten().numpy())
plt.figure("Augmentation")
plt.subplot(1, 2, 1)
plt.imshow(train_ds[0]["image"])
plt.title("label: " + id2label[train_ds[0]["label"]])
plt.subplot(1, 2, 2)
plt.imshow(transform(img))
plt.title("augmented image")
plt.show()
# Prepare trainer
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"]
for example in examples])
labels = torch.tensor([example["label"] for example in examples])
return {"pixel_values": pixel_values, "labels": labels}
model = AutoModelForImageClassification.from_pretrained(
'microsoft/swinv2-base-patch4-window16-256',
num_labels=len(labels),
id2label=id2label,
label2id=label2id,
ignore_mismatched_sizes=True,
)
metric_name = "accuracy"
args = TrainingArguments(
f"test-aiornot",
save_strategy="steps",
evaluation_strategy="steps",
learning_rate=2e-5,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
num_train_epochs=3,
weight_decay=0.01,
load_best_model_at_end=True,
metric_for_best_model=metric_name,
eval_steps=250,
logging_dir='logs',
logging_steps=10,
remove_unused_columns=False,
)
metric = load_metric("accuracy")
def compute_metrics(eval_pred):
predictions, labels = eval_pred
predictions = np.argmax(predictions, axis=1)
return metric.compute(predictions=predictions, references=labels)
trainer = Trainer(
model,
args,
train_dataset=train_ds,
eval_dataset=val_ds,
data_collator=collate_fn,
compute_metrics=compute_metrics,
tokenizer=feature_extractor,
)
# Start tensorboard.
# %load_ext tensorboard
# %tensorboard - -logdir logs/
trainer.train()