fix: dataset dataloader

This commit is contained in:
Laureηt 2023-01-28 16:20:40 +01:00
parent 9b0a8bc1af
commit 684f627e73
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
3 changed files with 22 additions and 11 deletions

View file

@@ -12,10 +12,12 @@
in { in {
devShell = pkgs.mkShell { devShell = pkgs.mkShell {
buildInputs = with pkgs; [ buildInputs = with pkgs; [
tk
poetry poetry
python3 python3
python310Packages.numpy python310Packages.numpy
python310Packages.datasets python310Packages.datasets
python310Packages.matplotlib
]; ];
}; };
}); });

View file

@@ -1,10 +1,10 @@
"""Dataset class AI or NOT HuggingFace competition.""" """Dataset class AI or NOT HuggingFace competition."""
import csv
import pathlib import pathlib
from typing import Optional from typing import Optional
import datasets import datasets
import pandas as pd
_VERSION = "1.0.0" _VERSION = "1.0.0"
@@ -24,8 +24,8 @@ Please use the community tab for discussion and questions.
""" """
_NAMES = [ _NAMES = [
"AI",
"NOT", "NOT",
"AI",
] ]
@@ -56,14 +56,14 @@ class aiornot(datasets.GeneratorBasedBuilder):
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.TRAIN, name=datasets.Split.TRAIN,
gen_kwargs={ gen_kwargs={
"data_dir": train_path, "data_dir": train_path / "train",
"csv_file": csv_path, "csv_file": csv_path,
}, },
), ),
datasets.SplitGenerator( datasets.SplitGenerator(
name=datasets.Split.TEST, name=datasets.Split.TEST,
gen_kwargs={ gen_kwargs={
"data_dir": test_path, "data_dir": test_path / "test",
}, },
), ),
] ]
@@ -71,12 +71,14 @@ class aiornot(datasets.GeneratorBasedBuilder):
def _generate_examples(self, data_dir: pathlib.Path, csv_file: Optional[pathlib.Path] = None): def _generate_examples(self, data_dir: pathlib.Path, csv_file: Optional[pathlib.Path] = None):
"""Generate images and labels for splits.""" """Generate images and labels for splits."""
if csv_file is not None: if csv_file is not None:
df = pd.read_csv(csv_file) with open(csv_file, "r") as f:
for index, row in df.iterrows(): reader = csv.reader(f)
yield index, { next(reader)
"image": str(data_dir / row["image"]), for index, row in enumerate(reader):
"label": row["label"], yield index, {
} "image": str(data_dir / row[0]),
"label": row[1],
}
else: else:
rglob = pathlib.Path(data_dir).rglob("*.jpg") rglob = pathlib.Path(data_dir).rglob("*.jpg")
for index, filepath in enumerate(rglob): for index, filepath in enumerate(rglob):

View file

@@ -1,5 +1,12 @@
import datasets import datasets
import matplotlib.pyplot as plt
dataset = datasets.load_dataset("src/dataset.py") dataset = datasets.load_dataset("src/dataset.py")
print(dataset) idx = 0
plt.imshow(dataset["train"][idx]["image"])
plt.title(dataset["train"].features["label"].names[dataset["train"][idx]["label"]])
plt.show()
plt.imshow(dataset["test"][idx]["image"])
plt.show()