train and inference base scripts
This commit is contained in:
parent
9ebf7de84d
commit
fb5287eaff
|
@ -18,7 +18,9 @@ huggingface-hub==0.12.0
|
||||||
idna==3.4
|
idna==3.4
|
||||||
joblib==1.2.0
|
joblib==1.2.0
|
||||||
kiwisolver==1.4.4
|
kiwisolver==1.4.4
|
||||||
|
markdown-it-py==2.1.0
|
||||||
matplotlib==3.6.3
|
matplotlib==3.6.3
|
||||||
|
mdurl==0.1.2
|
||||||
multidict==6.0.4
|
multidict==6.0.4
|
||||||
multiprocess==0.70.14
|
multiprocess==0.70.14
|
||||||
mypy-extensions==0.4.3
|
mypy-extensions==0.4.3
|
||||||
|
@ -31,6 +33,7 @@ Pillow==9.4.0
|
||||||
platformdirs==2.6.2
|
platformdirs==2.6.2
|
||||||
pyarrow==11.0.0
|
pyarrow==11.0.0
|
||||||
pycodestyle==2.10.0
|
pycodestyle==2.10.0
|
||||||
|
Pygments==2.14.0
|
||||||
pyparsing==3.0.9
|
pyparsing==3.0.9
|
||||||
python-dateutil==2.8.2
|
python-dateutil==2.8.2
|
||||||
pytz==2022.7.1
|
pytz==2022.7.1
|
||||||
|
@ -38,6 +41,7 @@ PyYAML==6.0
|
||||||
regex==2022.10.31
|
regex==2022.10.31
|
||||||
requests==2.28.2
|
requests==2.28.2
|
||||||
responses==0.18.0
|
responses==0.18.0
|
||||||
|
rich==13.3.1
|
||||||
scikit-learn==1.2.1
|
scikit-learn==1.2.1
|
||||||
scipy==1.10.0
|
scipy==1.10.0
|
||||||
six==1.16.0
|
six==1.16.0
|
||||||
|
|
21
src/acclogloss.py
Normal file
21
src/acclogloss.py
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def BinaryCrossEntropy(y_true, y_pred, eps=1e-7):
    """Return the mean binary cross-entropy (log loss) of predictions.

    Args:
        y_true: Array-like of ground-truth targets.
        y_pred: Array-like of predicted probabilities, broadcastable
            against ``y_true``.
        eps: Clipping margin. Predictions are clipped to
            ``[eps, 1 - eps]`` before the logarithm is taken so that
            ``log(0)`` is never evaluated. Defaults to ``1e-7``.

    Returns:
        The negated mean over axis 0 of the per-sample log-likelihood
        terms (a scalar for 1-D inputs).
    """
    # Coerce to arrays so plain Python lists work with the arithmetic below.
    y_true = np.asarray(y_true)
    y_pred = np.clip(np.asarray(y_pred), eps, 1 - eps)

    term_0 = (1 - y_true) * np.log(1 - y_pred)  # contribution of y_true == 0
    term_1 = y_true * np.log(y_pred)            # contribution of y_true == 1
    return -np.mean(term_0 + term_1, axis=0)
|
||||||
|
|
||||||
|
# Compute the log loss obtained at a given accuracy when every prediction
# is a hard 0/1 value: all labels are 1, the first `nb_true` predictions
# are 1 (correct) and the remaining ones are 0 (wrong).
nb_tests = 43444
acc = 0.977
nb_true = int(acc * nb_tests)

labels = np.ones(nb_tests)

# Start from all-correct predictions, then zero out the incorrect tail.
predictions = np.ones(nb_tests)
predictions[nb_true:] = 0.0

logloss = BinaryCrossEntropy(labels, predictions)

print(f"Accuracy: {acc}")
print(f"logloss: {logloss}")
|
83
src/inference.py
Normal file
83
src/inference.py
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
import torch
|
||||||
|
import pandas as pd
|
||||||
|
from transformers import (
|
||||||
|
AutoModelForImageClassification,
|
||||||
|
AutoFeatureExtractor,
|
||||||
|
)
|
||||||
|
from torchvision.transforms import (
|
||||||
|
CenterCrop,
|
||||||
|
Compose,
|
||||||
|
Resize,
|
||||||
|
ToTensor,
|
||||||
|
Normalize,
|
||||||
|
)
|
||||||
|
from rich.progress import track
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from dataset import test_ds
|
||||||
|
|
||||||
|
# Preprocessing configuration (image mean / std) saved with the checkpoint.
# NOTE(review): this reads 'checkpoint-500' while the model in this file is
# loaded from 'checkpoint-1000' -- confirm the mismatch is intended.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    "test-aiornot/checkpoint-500"
)

normalize = Normalize(
    mean=feature_extractor.image_mean,
    std=feature_extractor.image_std,
)

# Evaluation-time pipeline: resize, centre-crop (same 256x256 size as the
# resize target), convert to tensor, then normalise.
_eval_steps = [
    Resize((256, 256)),
    CenterCrop((256, 256)),
    ToTensor(),
    normalize,
]
_val_transforms = Compose(_eval_steps)
|
||||||
|
def val_transforms(examples):
    """Attach transformed image tensors to a batch of examples.

    Each entry of ``examples['image']`` is converted to RGB and passed
    through ``_val_transforms``; the results are stored in order under
    ``examples['pixel_values']``.

    Args:
        examples: Mapping with an ``'image'`` entry holding the images.

    Returns:
        The same ``examples`` object, with ``'pixel_values'`` added.
    """
    transformed = []
    for image in examples["image"]:
        rgb_image = image.convert("RGB")
        transformed.append(_val_transforms(rgb_image))

    examples["pixel_values"] = transformed
    return examples
|
||||||
|
|
||||||
|
# Register val_transforms as the transform applied to the test dataset.
test_ds.set_transform(val_transforms)
|
||||||
|
|
||||||
|
def collate_fn(examples):
    """Combine a list of per-example dicts into one batch dict.

    Args:
        examples: Sequence of dicts, each providing ``"pixel_values"``
            (a tensor), ``"label"`` and ``"image_path"``.

    Returns:
        A dict with ``"pixel_values"`` (the tensors stacked along a new
        leading dimension), ``"labels"`` (a tensor built from the labels)
        and ``"image_path"`` (the list of paths, in input order).
    """
    def column(key):
        # Collect one field across all examples, preserving order.
        return [item[key] for item in examples]

    return {
        "pixel_values": torch.stack(column("pixel_values")),
        "labels": torch.tensor(column("label")),
        "image_path": column("image_path"),
    }
|
||||||
|
|
||||||
|
# Load the fine-tuned classifier.
# NOTE(review): the feature extractor in this file is loaded from
# 'checkpoint-500' while the model comes from 'checkpoint-1000' -- confirm.
model = AutoModelForImageClassification.from_pretrained(
    'test-aiornot/checkpoint-1000',
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # make inference mode explicit (disables dropout etc.)

test_loader = torch.utils.data.DataLoader(
    test_ds,
    batch_size=128,
    collate_fn=collate_fn,
    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only machine just triggers a warning.
    pin_memory=(device == "cuda"),
    num_workers=4,
)

file_paths = []  # file name (last path component) of every test image
pred_ids = []    # predicted class index for every test image

for batch in track(test_loader):
    # Keep only the file name of each image path for the submission id.
    file_paths.extend(x.split("/")[-1] for x in batch["image_path"])

    images = batch["pixel_values"].to(device)
    with torch.no_grad():
        logits = model(pixel_values=images).logits

    pred_ids.extend(logits.argmax(-1).cpu().numpy().tolist())

submission_df = pd.DataFrame({"id": file_paths, "label": pred_ids})

# Write the predictions to a CSV named after the current local time.
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
submission_df.to_csv(f"{TIMESTAMP}.csv", index=False)
|
|
@ -22,7 +22,7 @@ from datasets import load_metric
|
||||||
from dataset import train_ds, val_ds, test_ds, labels, id2label, label2id
|
from dataset import train_ds, val_ds, test_ds, labels, id2label, label2id
|
||||||
|
|
||||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||||
"microsoft/swinv2-base-patch4-window16-256")
|
"facebook/convnext-xlarge-384-22k-1k")
|
||||||
|
|
||||||
normalize = Normalize(mean=feature_extractor.image_mean,
|
normalize = Normalize(mean=feature_extractor.image_mean,
|
||||||
std=feature_extractor.image_std)
|
std=feature_extractor.image_std)
|
||||||
|
@ -88,7 +88,7 @@ def collate_fn(examples):
|
||||||
|
|
||||||
|
|
||||||
model = AutoModelForImageClassification.from_pretrained(
|
model = AutoModelForImageClassification.from_pretrained(
|
||||||
'microsoft/swinv2-base-patch4-window16-256',
|
'facebook/convnext-xlarge-384-22k-1k',
|
||||||
num_labels=len(labels),
|
num_labels=len(labels),
|
||||||
id2label=id2label,
|
id2label=id2label,
|
||||||
label2id=label2id,
|
label2id=label2id,
|
||||||
|
@ -102,8 +102,8 @@ args = TrainingArguments(
|
||||||
save_strategy="steps",
|
save_strategy="steps",
|
||||||
evaluation_strategy="steps",
|
evaluation_strategy="steps",
|
||||||
learning_rate=2e-5,
|
learning_rate=2e-5,
|
||||||
per_device_train_batch_size=8,
|
per_device_train_batch_size=24,
|
||||||
per_device_eval_batch_size=8,
|
per_device_eval_batch_size=24,
|
||||||
num_train_epochs=3,
|
num_train_epochs=3,
|
||||||
weight_decay=0.01,
|
weight_decay=0.01,
|
||||||
load_best_model_at_end=True,
|
load_best_model_at_end=True,
|
||||||
|
|
Reference in a new issue