train and inference base scripts

parent 9ebf7de84d
commit fb5287eaff

@@ -18,7 +18,9 @@ huggingface-hub==0.12.0
 idna==3.4
 joblib==1.2.0
 kiwisolver==1.4.4
+markdown-it-py==2.1.0
 matplotlib==3.6.3
+mdurl==0.1.2
 multidict==6.0.4
 multiprocess==0.70.14
 mypy-extensions==0.4.3

@@ -31,6 +33,7 @@ Pillow==9.4.0
 platformdirs==2.6.2
 pyarrow==11.0.0
 pycodestyle==2.10.0
+Pygments==2.14.0
 pyparsing==3.0.9
 python-dateutil==2.8.2
 pytz==2022.7.1
@@ -38,6 +41,7 @@ PyYAML==6.0
 regex==2022.10.31
 requests==2.28.2
 responses==0.18.0
+rich==13.3.1
 scikit-learn==1.2.1
 scipy==1.10.0
 six==1.16.0

src/acclogloss.py (new file)

import numpy as np


def BinaryCrossEntropy(y_true, y_pred):
    # Clip predictions away from exact 0 and 1 so the logs stay finite.
    y_pred = np.clip(y_pred, 1e-7, 1 - 1e-7)
    term_0 = (1 - y_true) * np.log(1 - y_pred)
    term_1 = y_true * np.log(y_pred)
    return -np.mean(term_0 + term_1, axis=0)


# Toy setup: all labels are 1, and the predictions are correct on exactly
# int(acc * nb_tests) of them, to see what logloss a given accuracy implies.
nb_tests = 43444
acc = 0.977

labels = np.ones(nb_tests)
nb_true = int(acc * nb_tests)
predictions = np.concatenate((np.ones(nb_true), np.zeros(nb_tests - nb_true)))

logloss = BinaryCrossEntropy(labels, predictions)

print(f"Accuracy: {acc}")
print(f"logloss: {logloss}")

src/inference.py (new file)

import torch
import pandas as pd
from transformers import (
    AutoModelForImageClassification,
    AutoFeatureExtractor,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Resize,
    ToTensor,
    Normalize,
)
from rich.progress import track
from datetime import datetime

from dataset import test_ds

feature_extractor = AutoFeatureExtractor.from_pretrained(
    'test-aiornot/checkpoint-500')

normalize = Normalize(mean=feature_extractor.image_mean,
                      std=feature_extractor.image_std)

_val_transforms = Compose(
    [
        Resize((256, 256)),
        CenterCrop((256, 256)),
        ToTensor(),
        normalize,
    ]
)


def val_transforms(examples):
    examples['pixel_values'] = [_val_transforms(
        image.convert("RGB")) for image in examples['image']]
    return examples


test_ds.set_transform(val_transforms)


def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    image_paths = [example["image_path"] for example in examples]

    return {
        "pixel_values": pixel_values,
        "labels": labels,
        "image_path": image_paths,
    }


model = AutoModelForImageClassification.from_pretrained(
    'test-aiornot/checkpoint-1000',
)

test_loader = torch.utils.data.DataLoader(
    test_ds, batch_size=128, collate_fn=collate_fn, pin_memory=True, num_workers=4
)
device = "cuda" if torch.cuda.is_available() else "cpu"
_ = model.to(device)

file_paths = []
pred_ids = []

for batch in track(test_loader):
    # Keep only the file name so it can serve as the submission id.
    image_paths = [x.split("/")[-1] for x in batch["image_path"]]
    file_paths.extend(image_paths)

    images = batch["pixel_values"].to(device)
    inputs = {"pixel_values": images}

    with torch.no_grad():
        logits = model(**inputs).logits

    predictions = logits.argmax(-1).cpu().numpy().tolist()
    pred_ids.extend(predictions)


submission_df = pd.DataFrame({"id": file_paths, "label": pred_ids})

TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
submission_df.to_csv(f"{TIMESTAMP}.csv", index=False)
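
inference.py imports test_ds from a dataset module that is not shown in this diff; it assumes the dataset exposes image, label, and image_path columns. A minimal sketch of a compatible stand-in, with illustrative file names only:

# Hypothetical stand-in for the dataset module's test_ds (names are illustrative).
from datasets import Dataset, Image

files = ["data/test/0001.jpg", "data/test/0002.jpg"]   # illustrative paths
test_ds = Dataset.from_dict({
    "image": files,               # the cast below decodes paths into PIL images
    "image_path": files,
    "label": [0] * len(files),    # placeholder labels for the unlabeled test split
}).cast_column("image", Image())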

@@ -22,7 +22,7 @@ from datasets import load_metric
 from dataset import train_ds, val_ds, test_ds, labels, id2label, label2id
 
 feature_extractor = AutoFeatureExtractor.from_pretrained(
-    "microsoft/swinv2-base-patch4-window16-256")
+    "facebook/convnext-xlarge-384-22k-1k")
 
 normalize = Normalize(mean=feature_extractor.image_mean,
                       std=feature_extractor.image_std)

@@ -88,7 +88,7 @@ def collate_fn(examples):
 
 
 model = AutoModelForImageClassification.from_pretrained(
-    'microsoft/swinv2-base-patch4-window16-256',
+    'facebook/convnext-xlarge-384-22k-1k',
    num_labels=len(labels),
     id2label=id2label,
     label2id=label2id,

@@ -102,8 +102,8 @@ args = TrainingArguments(
     save_strategy="steps",
     evaluation_strategy="steps",
     learning_rate=2e-5,
-    per_device_train_batch_size=8,
-    per_device_eval_batch_size=8,
+    per_device_train_batch_size=24,
+    per_device_eval_batch_size=24,
     num_train_epochs=3,
     weight_decay=0.01,
     load_best_model_at_end=True,
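
This hunk triples the per-device batch size while learning_rate stays at 2e-5. If one wanted to keep the update scale roughly comparable, the common linear-scaling heuristic (an option, not something this commit applies) would be:

base_lr = 2e-5                  # rate tuned at per-device batch size 8
scaled_lr = base_lr * 24 / 8    # 6e-5 under the linear-scaling heuristic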