From 89f9112fca82e1cdbcdb0ef7b25092f1d36fae2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Thu, 2 Feb 2023 11:14:52 +0100 Subject: [PATCH] chore: move dataset to submodule --- .gitmodules | 3 ++ aiornot_datasets | 1 + src/dataset.py | 88 -------------------------------------------- src/tests/dataset.py | 10 ++++- 4 files changed, 13 insertions(+), 89 deletions(-) create mode 100644 .gitmodules create mode 160000 aiornot_datasets delete mode 100644 src/dataset.py diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..00efabb --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "aiornot_datasets"] + path = aiornot_datasets + url = https://huggingface.co/datasets/tocard-inc/aiornot diff --git a/aiornot_datasets b/aiornot_datasets new file mode 160000 index 0000000..a90618d --- /dev/null +++ b/aiornot_datasets @@ -0,0 +1 @@ +Subproject commit a90618df992a19c775b6b0fb7e0de0fd45a4d505 diff --git a/src/dataset.py b/src/dataset.py deleted file mode 100644 index a5408ec..0000000 --- a/src/dataset.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Dataset class AI or NOT HuggingFace competition.""" - -import csv -import pathlib -from typing import Optional - -import datasets - -_VERSION = "1.0.0" - -_GIT_COMMIT_REVISION = "b843a82bd712648b2fe0dc043cf8a04475491d38" - -_BASE_URLS = { - "train": f"https://huggingface.co/datasets/competitions/aiornot/resolve/{_GIT_COMMIT_REVISION}/train.zip", - "test": f"https://huggingface.co/datasets/competitions/aiornot/resolve/{_GIT_COMMIT_REVISION}/test.zip", - "csv": f"https://huggingface.co/datasets/competitions/aiornot/resolve/{_GIT_COMMIT_REVISION}/train.csv", -} -_HOMEPAGE = "https://huggingface.co/spaces/competitions/aiornot" - -_DESCRIPTION = """ -The dataset consists of approximately 31000 images, some of which have been generated by ai. -Your task is to build a model that can identify ai generated images. -Please use the community tab for discussion and questions. -""" - -_NAMES = [ - "NOT", - "AI", -] - - -class aiornot(datasets.GeneratorBasedBuilder): - """Food-101 Images dataset.""" - - def _info(self): - return datasets.DatasetInfo( - description=_DESCRIPTION, - version=_VERSION, - features=datasets.Features( - { - "image": datasets.Image(), - "label": datasets.ClassLabel(names=_NAMES), - } - ), - supervised_keys=("image", "label"), - homepage=_HOMEPAGE, - task_templates=[datasets.tasks.ImageClassification(image_column="image", label_column="label")], - ) - - def _split_generators(self, dl_manager): - train_path = pathlib.Path(dl_manager.download_and_extract(_BASE_URLS["train"])) - test_path = pathlib.Path(dl_manager.download_and_extract(_BASE_URLS["test"])) - csv_path = pathlib.Path(dl_manager.download(_BASE_URLS["csv"])) - - return [ - datasets.SplitGenerator( - name=datasets.Split.TRAIN, - gen_kwargs={ - "data_dir": train_path / "train", - "csv_file": csv_path, - }, - ), - datasets.SplitGenerator( - name=datasets.Split.TEST, - gen_kwargs={ - "data_dir": test_path / "test", - }, - ), - ] - - def _generate_examples(self, data_dir: pathlib.Path, csv_file: Optional[pathlib.Path] = None): - """Generate images and labels for splits.""" - if csv_file is not None: - with open(csv_file, "r") as f: - reader = csv.reader(f) - next(reader) - for index, row in enumerate(reader): - yield index, { - "image": str(data_dir / row[0]), - "label": row[1], - } - else: - rglob = pathlib.Path(data_dir).rglob("*.jpg") - for index, filepath in enumerate(rglob): - yield index, { - "image": str(filepath), - "label": -1, - } diff --git a/src/tests/dataset.py b/src/tests/dataset.py index bd1af43..5d60e54 100644 --- a/src/tests/dataset.py +++ b/src/tests/dataset.py @@ -3,9 +3,17 @@ import matplotlib.pyplot as plt dataset = datasets.load_dataset("src/dataset.py") +labels = dataset["train"].features["label"].names +print(labels) + +id2label = {k: v for k, v in enumerate(labels)} +label2id = {v: k for k, v in enumerate(labels)} +print(label2id) +print(id2label) + idx = 0 plt.imshow(dataset["train"][idx]["image"]) -plt.title(dataset["train"].features["label"].names[dataset["train"][idx]["label"]]) +plt.title(id2label[dataset["train"][idx]["label"]]) plt.show() plt.imshow(dataset["test"][idx]["image"])