refactor: notebook to scripts main
TODO: reorganize main
This commit is contained in:
parent
7dfcc358e4
commit
aac135a3fc
|
@ -4,15 +4,21 @@ async-timeout==4.0.2
|
||||||
attrs==22.2.0
|
attrs==22.2.0
|
||||||
autopep8==2.0.1
|
autopep8==2.0.1
|
||||||
certifi==2022.12.7
|
certifi==2022.12.7
|
||||||
charset-normalizer==3.0.1
|
charset-normalizer==2.1.1
|
||||||
click==8.1.3
|
click==8.1.3
|
||||||
|
contourpy==1.0.7
|
||||||
|
cycler==0.11.0
|
||||||
datasets==2.9.0
|
datasets==2.9.0
|
||||||
dill==0.3.6
|
dill==0.3.6
|
||||||
filelock==3.9.0
|
filelock==3.9.0
|
||||||
|
fonttools==4.38.0
|
||||||
frozenlist==1.3.3
|
frozenlist==1.3.3
|
||||||
fsspec==2023.1.0
|
fsspec==2023.1.0
|
||||||
huggingface-hub==0.12.0
|
huggingface-hub==0.12.0
|
||||||
idna==3.4
|
idna==3.4
|
||||||
|
joblib==1.2.0
|
||||||
|
kiwisolver==1.4.4
|
||||||
|
matplotlib==3.6.3
|
||||||
multidict==6.0.4
|
multidict==6.0.4
|
||||||
multiprocess==0.70.14
|
multiprocess==0.70.14
|
||||||
mypy-extensions==0.4.3
|
mypy-extensions==0.4.3
|
||||||
|
@ -21,17 +27,28 @@ opencv-python==4.7.0.68
|
||||||
packaging==23.0
|
packaging==23.0
|
||||||
pandas==1.5.3
|
pandas==1.5.3
|
||||||
pathspec==0.11.0
|
pathspec==0.11.0
|
||||||
|
Pillow==9.4.0
|
||||||
platformdirs==2.6.2
|
platformdirs==2.6.2
|
||||||
pyarrow==11.0.0
|
pyarrow==11.0.0
|
||||||
pycodestyle==2.10.0
|
pycodestyle==2.10.0
|
||||||
|
pyparsing==3.0.9
|
||||||
python-dateutil==2.8.2
|
python-dateutil==2.8.2
|
||||||
pytz==2022.7.1
|
pytz==2022.7.1
|
||||||
PyYAML==6.0
|
PyYAML==6.0
|
||||||
|
regex==2022.10.31
|
||||||
requests==2.28.2
|
requests==2.28.2
|
||||||
responses==0.18.0
|
responses==0.18.0
|
||||||
|
scikit-learn==1.2.1
|
||||||
|
scipy==1.10.0
|
||||||
six==1.16.0
|
six==1.16.0
|
||||||
|
threadpoolctl==3.1.0
|
||||||
|
tokenizers==0.13.2
|
||||||
tomli==2.0.1
|
tomli==2.0.1
|
||||||
|
torch==1.7.1+cu110
|
||||||
|
torchaudio==0.7.2
|
||||||
|
torchvision==0.8.2+cu110
|
||||||
tqdm==4.64.1
|
tqdm==4.64.1
|
||||||
|
transformers==4.26.0
|
||||||
typing-extensions==4.4.0
|
typing-extensions==4.4.0
|
||||||
urllib3==1.26.14
|
urllib3==1.26.14
|
||||||
xxhash==3.2.0
|
xxhash==3.2.0
|
||||||
|
|
|
@ -11,25 +11,28 @@ train_ds = splits['train']
|
||||||
val_ds = splits['test']
|
val_ds = splits['test']
|
||||||
test_ds = dataset['test']
|
test_ds = dataset['test']
|
||||||
|
|
||||||
|
labels = train_ds.features['label'].names
|
||||||
|
id2label = {k: v for k, v in enumerate(labels)}
|
||||||
|
label2id = {v: k for k, v in enumerate(labels)}
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
labels = train_ds.features['label'].names
|
|
||||||
print(f"labels:\n {labels}")
|
print(f"labels:\n {labels}")
|
||||||
|
|
||||||
id2label = {k: v for k, v in enumerate(labels)}
|
|
||||||
label2id = {v: k for k, v in enumerate(labels)}
|
|
||||||
print(f"label-id correspondances:\n {label2id}\n {id2label}")
|
print(f"label-id correspondances:\n {label2id}\n {id2label}")
|
||||||
|
|
||||||
idx = 0
|
idx = 0
|
||||||
cv2.imshow(
|
label = id2label[dataset['train'][idx]['label']]
|
||||||
id2label[dataset['train'][idx]['label']],
|
image = cv2.cvtColor(
|
||||||
cv2.cvtColor(np.array(dataset['train'][idx]
|
np.array(dataset['train'][idx]['image']),
|
||||||
['image']), cv2.COLOR_BGR2RGB),
|
cv2.COLOR_BGR2RGB)
|
||||||
)
|
cv2.namedWindow(label, cv2.WINDOW_NORMAL)
|
||||||
cv2.imshow("Test", cv2.cvtColor(
|
cv2.imshow(label, image)
|
||||||
np.array(dataset['test'][idx]['image']), cv2.COLOR_BGR2RGB))
|
image = cv2.cvtColor(
|
||||||
|
np.array(dataset['test'][idx]['image']),
|
||||||
|
cv2.COLOR_BGR2RGB)
|
||||||
|
cv2.namedWindow("Test", cv2.WINDOW_NORMAL)
|
||||||
|
cv2.imshow("Test", image)
|
||||||
cv2.waitKey(0)
|
cv2.waitKey(0)
|
||||||
|
|
138
src/main.py
138
src/main.py
|
@ -0,0 +1,138 @@
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForImageClassification,
|
||||||
|
AutoFeatureExtractor,
|
||||||
|
TrainingArguments,
|
||||||
|
Trainer
|
||||||
|
)
|
||||||
|
from torchvision.transforms import (
|
||||||
|
CenterCrop,
|
||||||
|
Compose,
|
||||||
|
Normalize,
|
||||||
|
RandomHorizontalFlip,
|
||||||
|
RandomResizedCrop,
|
||||||
|
Resize,
|
||||||
|
ToTensor,
|
||||||
|
ToPILImage
|
||||||
|
)
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from datasets import load_metric
|
||||||
|
|
||||||
|
from dataset import train_ds, val_ds, test_ds, labels, id2label, label2id
|
||||||
|
|
||||||
|
# Preprocessor of the pretrained checkpoint; it supplies the normalization
# statistics used by both transform pipelines below.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    "microsoft/swinv2-base-patch4-window16-256",
)

# Per-channel normalization with the mean/std stored on the feature extractor.
normalize = Normalize(
    mean=feature_extractor.image_mean,
    std=feature_extractor.image_std,
)

# Training pipeline: random resized crop to 256x256 and random horizontal
# flip, then tensor conversion and normalization.
_train_transforms = Compose(transforms=[
    RandomResizedCrop(size=(256, 256)),
    RandomHorizontalFlip(),
    ToTensor(),
    normalize,
])

# Evaluation pipeline: resize and center crop to 256x256 (no random ops),
# then tensor conversion and normalization.
_val_transforms = Compose(transforms=[
    Resize(size=(256, 256)),
    CenterCrop(size=(256, 256)),
    ToTensor(),
    normalize,
])
|
||||||
|
|
||||||
|
|
||||||
|
def train_transforms(examples):
    """Add augmented training tensors to a batch of examples.

    Every image under ``examples['image']`` is converted to RGB and passed
    through ``_train_transforms``; the results are stored under
    ``examples['pixel_values']``. The batch is modified in place and returned.
    """
    augmented = []
    for image in examples['image']:
        augmented.append(_train_transforms(image.convert("RGB")))
    examples['pixel_values'] = augmented
    return examples
|
||||||
|
|
||||||
|
|
||||||
|
def val_transforms(examples):
    """Add evaluation tensors to a batch of examples.

    Every image under ``examples['image']`` is converted to RGB and passed
    through ``_val_transforms``; the results are stored under
    ``examples['pixel_values']``. The batch is modified in place and returned.
    """
    rgb_images = (image.convert("RGB") for image in examples['image'])
    examples['pixel_values'] = list(map(_val_transforms, rgb_images))
    return examples
|
||||||
|
|
||||||
|
|
||||||
|
# Set the transforms: the given function is applied on the fly whenever a
# split is indexed, adding "pixel_values" to the returned examples.
train_ds.set_transform(train_transforms)
val_ds.set_transform(val_transforms)
test_ds.set_transform(val_transforms)

# Preview one training sample next to its augmented version.
transform = ToPILImage()

# Fetch the sample once: each access to train_ds[0] re-runs the transform,
# so indexing it separately for every field would repeat that work.
sample = train_ds[0]

# Rescale the normalized tensor into [0, 1] so it can be rendered as an image.
# Tensor reductions replace the builtin min()/max(), which walked the
# flattened array element by element in Python.
img = sample["pixel_values"]
img = img - img.min()
peak = img.max()
if peak > 0:  # a constant image would otherwise divide by zero
    img = img / peak

plt.figure("Augmentation")
plt.subplot(1, 2, 1)
plt.imshow(sample["image"])
plt.title("label: " + id2label[sample["label"]])
plt.subplot(1, 2, 2)
plt.imshow(transform(img))
plt.title("augmented image")
plt.show()
|
||||||
|
|
||||||
|
# Prepare trainer
|
||||||
|
|
||||||
|
def collate_fn(examples):
    """Collate a list of dataset examples into a single batch dict.

    Each example must provide a "pixel_values" tensor and a "label" value.
    The tensors are stacked along a new leading dimension and the labels are
    gathered into one tensor.

    Returns:
        A dict with keys "pixel_values" and "labels".
    """
    batch = {
        "pixel_values": torch.stack(
            [item["pixel_values"] for item in examples]
        ),
        "labels": torch.tensor([item["label"] for item in examples]),
    }
    return batch
|
||||||
|
|
||||||
|
|
||||||
|
# Load the pretrained checkpoint with a classification head sized for this
# dataset's labels. ignore_mismatched_sizes allows loading even though the
# checkpoint's head has a different shape.
model = AutoModelForImageClassification.from_pretrained(
    'microsoft/swinv2-base-patch4-window16-256',
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

# Metric used both for evaluation and for selecting the best checkpoint.
metric_name = "accuracy"

# NOTE(review): save_steps is left at the library default while eval_steps is
# 250 and load_best_model_at_end is set; confirm the default save interval is
# the intended one relative to the evaluation interval.
args = TrainingArguments(
    # Plain literal: the original f-string had no placeholders.
    output_dir="test-aiornot",
    save_strategy="steps",
    evaluation_strategy="steps",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    eval_steps=250,
    logging_dir='logs',
    logging_steps=10,
    # Keep the "image" column: the on-the-fly transforms need it.
    remove_unused_columns=False,
)

# NOTE(review): load_metric appears to be deprecated in favour of the separate
# `evaluate` package; migrate once that dependency is available.
metric = load_metric(metric_name)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level ``metric``.

    ``eval_pred`` unpacks into a pair: the raw model outputs and the
    reference labels. The predicted class is the argmax over axis 1 of the
    outputs.

    Returns:
        Whatever ``metric.compute`` returns for these predictions.
    """
    logits, references = eval_pred
    predicted_ids = np.argmax(logits, axis=1)
    return metric.compute(predictions=predicted_ids, references=references)
|
||||||
|
|
||||||
|
|
||||||
|
# Wire the model, arguments, datasets and callbacks into the Trainer.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

# To monitor the run from a notebook, start TensorBoard first:
# %load_ext tensorboard
# %tensorboard --logdir logs/
trainer.train()
|
Reference in a new issue