Patch summary (three files):

  * flake.nix: add `tk` and `python310Packages.matplotlib` to the dev
    shell's buildInputs.
  * src/dataset.py: drop the pandas import and read the label CSV with the
    stdlib `csv` module instead (header row skipped with next(reader),
    columns addressed by position: row[0] = image, row[1] = label);
    swap the order of _NAMES to "NOT", "AI"; point the TRAIN/TEST
    data_dir at the "train" / "test" subdirectories.
  * src/tests/dataset.py: replace print(dataset) with a matplotlib display
    of the first train image (titled with its label name) and the first
    test image.

NOTE(review): this patch was recovered from a copy in which line breaks and
leading whitespace had been collapsed. Line structure and blank context
lines were rebuilt so that every hunk matches its @@ line counts, but the
indentation of context/added lines is reconstructed by convention --
confirm with `git apply --check` (or apply with --ignore-whitespace).

NOTE(review): the Python csv documentation recommends opening files passed
to csv.reader with newline=""; consider
open(csv_file, newline="", encoding="utf-8") in a follow-up.

diff --git a/flake.nix b/flake.nix
index 8578d22..c12cdda 100644
--- a/flake.nix
+++ b/flake.nix
@@ -12,10 +12,12 @@
       in {
         devShell = pkgs.mkShell {
           buildInputs = with pkgs; [
+            tk
             poetry
             python3
             python310Packages.numpy
             python310Packages.datasets
+            python310Packages.matplotlib
           ];
         };
       });
diff --git a/src/dataset.py b/src/dataset.py
index ee774ca..a5408ec 100644
--- a/src/dataset.py
+++ b/src/dataset.py
@@ -1,10 +1,10 @@
 """Dataset class AI or NOT HuggingFace competition."""
 
+import csv
 import pathlib
 from typing import Optional
 
 import datasets
-import pandas as pd
 
 _VERSION = "1.0.0"
 
@@ -24,8 +24,8 @@ Please use the community tab for discussion and questions.
 """
 
 _NAMES = [
-    "AI",
     "NOT",
+    "AI",
 ]
 
 
@@ -56,14 +56,14 @@ class aiornot(datasets.GeneratorBasedBuilder):
             datasets.SplitGenerator(
                 name=datasets.Split.TRAIN,
                 gen_kwargs={
-                    "data_dir": train_path,
+                    "data_dir": train_path / "train",
                     "csv_file": csv_path,
                 },
             ),
             datasets.SplitGenerator(
                 name=datasets.Split.TEST,
                 gen_kwargs={
-                    "data_dir": test_path,
+                    "data_dir": test_path / "test",
                 },
             ),
         ]
@@ -71,12 +71,14 @@ class aiornot(datasets.GeneratorBasedBuilder):
     def _generate_examples(self, data_dir: pathlib.Path, csv_file: Optional[pathlib.Path] = None):
         """Generate images and labels for splits."""
         if csv_file is not None:
-            df = pd.read_csv(csv_file)
-            for index, row in df.iterrows():
-                yield index, {
-                    "image": str(data_dir / row["image"]),
-                    "label": row["label"],
-                }
+            with open(csv_file, "r") as f:
+                reader = csv.reader(f)
+                next(reader)
+                for index, row in enumerate(reader):
+                    yield index, {
+                        "image": str(data_dir / row[0]),
+                        "label": row[1],
+                    }
         else:
             rglob = pathlib.Path(data_dir).rglob("*.jpg")
             for index, filepath in enumerate(rglob):
diff --git a/src/tests/dataset.py b/src/tests/dataset.py
index c47355f..bd1af43 100644
--- a/src/tests/dataset.py
+++ b/src/tests/dataset.py
@@ -1,5 +1,12 @@
 import datasets
+import matplotlib.pyplot as plt
 
 dataset = datasets.load_dataset("src/dataset.py")
 
-print(dataset)
+idx = 0
+plt.imshow(dataset["train"][idx]["image"])
+plt.title(dataset["train"].features["label"].names[dataset["train"][idx]["label"]])
+plt.show()
+
+plt.imshow(dataset["test"][idx]["image"])
+plt.show()