refactor: notebook to scripts dataloader
This commit is contained in:
parent
89f9112fca
commit
7dfcc358e4
|
@ -0,0 +1,38 @@
|
||||||
|
aiohttp==3.8.3
|
||||||
|
aiosignal==1.3.1
|
||||||
|
async-timeout==4.0.2
|
||||||
|
attrs==22.2.0
|
||||||
|
autopep8==2.0.1
|
||||||
|
certifi==2022.12.7
|
||||||
|
charset-normalizer==3.0.1
|
||||||
|
click==8.1.3
|
||||||
|
datasets==2.9.0
|
||||||
|
dill==0.3.6
|
||||||
|
filelock==3.9.0
|
||||||
|
frozenlist==1.3.3
|
||||||
|
fsspec==2023.1.0
|
||||||
|
huggingface-hub==0.12.0
|
||||||
|
idna==3.4
|
||||||
|
multidict==6.0.4
|
||||||
|
multiprocess==0.70.14
|
||||||
|
mypy-extensions==0.4.3
|
||||||
|
numpy==1.24.1
|
||||||
|
opencv-python==4.7.0.68
|
||||||
|
packaging==23.0
|
||||||
|
pandas==1.5.3
|
||||||
|
pathspec==0.11.0
|
||||||
|
platformdirs==2.6.2
|
||||||
|
pyarrow==11.0.0
|
||||||
|
pycodestyle==2.10.0
|
||||||
|
python-dateutil==2.8.2
|
||||||
|
pytz==2022.7.1
|
||||||
|
PyYAML==6.0
|
||||||
|
requests==2.28.2
|
||||||
|
responses==0.18.0
|
||||||
|
six==1.16.0
|
||||||
|
tomli==2.0.1
|
||||||
|
tqdm==4.64.1
|
||||||
|
typing-extensions==4.4.0
|
||||||
|
urllib3==1.26.14
|
||||||
|
xxhash==3.2.0
|
||||||
|
yarl==1.8.2
|
35
src/dataset.py
Normal file
35
src/dataset.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
|
||||||
|
# load dataset
|
||||||
|
dataset = load_dataset("tocard-inc/aiornot").shuffle(seed=42)
|
||||||
|
|
||||||
|
# split up training into training + validation
|
||||||
|
splits = dataset['train'].train_test_split(test_size=0.1)
|
||||||
|
|
||||||
|
train_ds = splits['train']
|
||||||
|
val_ds = splits['test']
|
||||||
|
test_ds = dataset['test']
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
labels = train_ds.features['label'].names
|
||||||
|
print(f"labels:\n {labels}")
|
||||||
|
|
||||||
|
id2label = {k: v for k, v in enumerate(labels)}
|
||||||
|
label2id = {v: k for k, v in enumerate(labels)}
|
||||||
|
print(f"label-id correspondances:\n {label2id}\n {id2label}")
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
cv2.imshow(
|
||||||
|
id2label[dataset['train'][idx]['label']],
|
||||||
|
cv2.cvtColor(np.array(dataset['train'][idx]
|
||||||
|
['image']), cv2.COLOR_BGR2RGB),
|
||||||
|
)
|
||||||
|
cv2.imshow("Test", cv2.cvtColor(
|
||||||
|
np.array(dataset['test'][idx]['image']), cv2.COLOR_BGR2RGB))
|
||||||
|
cv2.waitKey(0)
|
|
@ -1,20 +0,0 @@
|
||||||
import datasets
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
dataset = datasets.load_dataset("src/dataset.py")
|
|
||||||
|
|
||||||
labels = dataset["train"].features["label"].names
|
|
||||||
print(labels)
|
|
||||||
|
|
||||||
id2label = {k: v for k, v in enumerate(labels)}
|
|
||||||
label2id = {v: k for k, v in enumerate(labels)}
|
|
||||||
print(label2id)
|
|
||||||
print(id2label)
|
|
||||||
|
|
||||||
idx = 0
|
|
||||||
plt.imshow(dataset["train"][idx]["image"])
|
|
||||||
plt.title(id2label[dataset["train"][idx]["label"]])
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
plt.imshow(dataset["test"][idx]["image"])
|
|
||||||
plt.show()
|
|
Reference in a new issue