fix: dataset dataloader
This commit is contained in:
parent
9b0a8bc1af
commit
684f627e73
|
@@ -12,10 +12,12 @@
|
|||
in {
|
||||
devShell = pkgs.mkShell {
|
||||
buildInputs = with pkgs; [
|
||||
tk
|
||||
poetry
|
||||
python3
|
||||
python310Packages.numpy
|
||||
python310Packages.datasets
|
||||
python310Packages.matplotlib
|
||||
];
|
||||
};
|
||||
});
|
||||
|
|
|
@@ -1,10 +1,10 @@
|
|||
"""Dataset class AI or NOT HuggingFace competition."""
|
||||
|
||||
import csv
|
||||
import pathlib
|
||||
from typing import Optional
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
|
||||
|
@@ -24,8 +24,8 @@ Please use the community tab for discussion and questions.
|
|||
"""
|
||||
|
||||
_NAMES = [
|
||||
"AI",
|
||||
"NOT",
|
||||
"AI",
|
||||
]
|
||||
|
||||
|
||||
|
@@ -56,14 +56,14 @@ class aiornot(datasets.GeneratorBasedBuilder):
|
|||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"data_dir": train_path,
|
||||
"data_dir": train_path / "train",
|
||||
"csv_file": csv_path,
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"data_dir": test_path,
|
||||
"data_dir": test_path / "test",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
@@ -71,12 +71,14 @@ class aiornot(datasets.GeneratorBasedBuilder):
|
|||
def _generate_examples(self, data_dir: pathlib.Path, csv_file: Optional[pathlib.Path] = None):
|
||||
"""Generate images and labels for splits."""
|
||||
if csv_file is not None:
|
||||
df = pd.read_csv(csv_file)
|
||||
for index, row in df.iterrows():
|
||||
yield index, {
|
||||
"image": str(data_dir / row["image"]),
|
||||
"label": row["label"],
|
||||
}
|
||||
with open(csv_file, "r") as f:
|
||||
reader = csv.reader(f)
|
||||
next(reader)
|
||||
for index, row in enumerate(reader):
|
||||
yield index, {
|
||||
"image": str(data_dir / row[0]),
|
||||
"label": row[1],
|
||||
}
|
||||
else:
|
||||
rglob = pathlib.Path(data_dir).rglob("*.jpg")
|
||||
for index, filepath in enumerate(rglob):
|
||||
|
|
|
@@ -1,5 +1,12 @@
|
|||
import datasets
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
dataset = datasets.load_dataset("src/dataset.py")
|
||||
|
||||
print(dataset)
|
||||
idx = 0
|
||||
plt.imshow(dataset["train"][idx]["image"])
|
||||
plt.title(dataset["train"].features["label"].names[dataset["train"][idx]["label"]])
|
||||
plt.show()
|
||||
|
||||
plt.imshow(dataset["test"][idx]["image"])
|
||||
plt.show()
|
||||
|
|
Loading…
Reference in a new issue