fix: dataset dataloader

This commit is contained in:
Laureηt 2023-01-28 16:20:40 +01:00
parent 9b0a8bc1af
commit 684f627e73
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
3 changed files with 22 additions and 11 deletions

View file

@ -12,10 +12,12 @@
in {
devShell = pkgs.mkShell {
buildInputs = with pkgs; [
tk
poetry
python3
python310Packages.numpy
python310Packages.datasets
python310Packages.matplotlib
];
};
});

View file

@ -1,10 +1,10 @@
"""Dataset class AI or NOT HuggingFace competition."""
import csv
import pathlib
from typing import Optional
import datasets
import pandas as pd
_VERSION = "1.0.0"
@ -24,8 +24,8 @@ Please use the community tab for discussion and questions.
"""
_NAMES = [
"AI",
"NOT",
"AI",
]
@ -56,14 +56,14 @@ class aiornot(datasets.GeneratorBasedBuilder):
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
"data_dir": train_path,
"data_dir": train_path / "train",
"csv_file": csv_path,
},
),
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"data_dir": test_path,
"data_dir": test_path / "test",
},
),
]
@ -71,12 +71,14 @@ class aiornot(datasets.GeneratorBasedBuilder):
def _generate_examples(self, data_dir: pathlib.Path, csv_file: Optional[pathlib.Path] = None):
"""Generate images and labels for splits."""
if csv_file is not None:
df = pd.read_csv(csv_file)
for index, row in df.iterrows():
yield index, {
"image": str(data_dir / row["image"]),
"label": row["label"],
}
with open(csv_file, "r") as f:
reader = csv.reader(f)
next(reader)
for index, row in enumerate(reader):
yield index, {
"image": str(data_dir / row[0]),
"label": row[1],
}
else:
rglob = pathlib.Path(data_dir).rglob("*.jpg")
for index, filepath in enumerate(rglob):

View file

@ -1,5 +1,12 @@
import datasets
import matplotlib.pyplot as plt
# Load every split through the local builder script and print the summary.
dataset = datasets.load_dataset("src/dataset.py")
print(dataset)

# Index of the example to preview in each split.
idx = 0

# Training example: show the image, titled with its human-readable class name
# (the stored label is used as an index into the label feature's names).
train_split = dataset["train"]
train_sample = train_split[idx]
label_names = train_split.features["label"].names
plt.imshow(train_sample["image"])
plt.title(label_names[train_sample["label"]])
plt.show()

# Test example: show the image only; no title is set for this split.
test_sample = dataset["test"][idx]
plt.imshow(test_sample["image"])
plt.show()