fix: dataset dataloader
This commit is contained in:
parent
9b0a8bc1af
commit
684f627e73
|
@ -12,10 +12,12 @@
|
||||||
in {
|
in {
|
||||||
devShell = pkgs.mkShell {
|
devShell = pkgs.mkShell {
|
||||||
buildInputs = with pkgs; [
|
buildInputs = with pkgs; [
|
||||||
|
tk
|
||||||
poetry
|
poetry
|
||||||
python3
|
python3
|
||||||
python310Packages.numpy
|
python310Packages.numpy
|
||||||
python310Packages.datasets
|
python310Packages.datasets
|
||||||
|
python310Packages.matplotlib
|
||||||
];
|
];
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
"""Dataset class AI or NOT HuggingFace competition."""
|
"""Dataset class AI or NOT HuggingFace competition."""
|
||||||
|
|
||||||
|
import csv
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
_VERSION = "1.0.0"
|
_VERSION = "1.0.0"
|
||||||
|
|
||||||
|
@ -24,8 +24,8 @@ Please use the community tab for discussion and questions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_NAMES = [
|
_NAMES = [
|
||||||
"AI",
|
|
||||||
"NOT",
|
"NOT",
|
||||||
|
"AI",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,14 +56,14 @@ class aiornot(datasets.GeneratorBasedBuilder):
|
||||||
datasets.SplitGenerator(
|
datasets.SplitGenerator(
|
||||||
name=datasets.Split.TRAIN,
|
name=datasets.Split.TRAIN,
|
||||||
gen_kwargs={
|
gen_kwargs={
|
||||||
"data_dir": train_path,
|
"data_dir": train_path / "train",
|
||||||
"csv_file": csv_path,
|
"csv_file": csv_path,
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
datasets.SplitGenerator(
|
datasets.SplitGenerator(
|
||||||
name=datasets.Split.TEST,
|
name=datasets.Split.TEST,
|
||||||
gen_kwargs={
|
gen_kwargs={
|
||||||
"data_dir": test_path,
|
"data_dir": test_path / "test",
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
@ -71,11 +71,13 @@ class aiornot(datasets.GeneratorBasedBuilder):
|
||||||
def _generate_examples(self, data_dir: pathlib.Path, csv_file: Optional[pathlib.Path] = None):
|
def _generate_examples(self, data_dir: pathlib.Path, csv_file: Optional[pathlib.Path] = None):
|
||||||
"""Generate images and labels for splits."""
|
"""Generate images and labels for splits."""
|
||||||
if csv_file is not None:
|
if csv_file is not None:
|
||||||
df = pd.read_csv(csv_file)
|
with open(csv_file, "r") as f:
|
||||||
for index, row in df.iterrows():
|
reader = csv.reader(f)
|
||||||
|
next(reader)
|
||||||
|
for index, row in enumerate(reader):
|
||||||
yield index, {
|
yield index, {
|
||||||
"image": str(data_dir / row["image"]),
|
"image": str(data_dir / row[0]),
|
||||||
"label": row["label"],
|
"label": row[1],
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
rglob = pathlib.Path(data_dir).rglob("*.jpg")
|
rglob = pathlib.Path(data_dir).rglob("*.jpg")
|
||||||
|
|
|
@ -1,5 +1,12 @@
|
||||||
import datasets
|
import datasets
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
dataset = datasets.load_dataset("src/dataset.py")
|
dataset = datasets.load_dataset("src/dataset.py")
|
||||||
|
|
||||||
print(dataset)
|
idx = 0
|
||||||
|
plt.imshow(dataset["train"][idx]["image"])
|
||||||
|
plt.title(dataset["train"].features["label"].names[dataset["train"][idx]["label"]])
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
plt.imshow(dataset["test"][idx]["image"])
|
||||||
|
plt.show()
|
||||||
|
|
Loading…
Reference in a new issue