feat: cleanup dataset loaders a bit

This commit is contained in:
Laurent Fainsin 2023-04-01 18:31:24 +02:00
parent 8691735779
commit 2cc47bbb9e
4 changed files with 20 additions and 68 deletions

View file

@ -1,5 +1,3 @@
"""Dataset class AI or NOT HuggingFace competition."""
import json import json
import pathlib import pathlib
@ -8,8 +6,8 @@ import datasets
import numpy as np import numpy as np
prefix = "/data/local-files/?d=spheres/" prefix = "/data/local-files/?d=spheres/"
dataset_path = pathlib.Path("./dataset3/spheres/") dataset_path = pathlib.Path("./dataset_antoine_laurent/")
annotation_path = pathlib.Path("./annotations2.json") annotation_path = dataset_path / "annotations.json"
_VERSION = "1.0.0" _VERSION = "1.0.0"
@ -20,20 +18,13 @@ _HOMEPAGE = ""
_LICENSE = "" _LICENSE = ""
_NAMES = [ _NAMES = [
# "White",
# "Black",
# "Grey",
# "Red",
# "Chrome",
"Matte", "Matte",
"Shiny", "Shiny",
"Chrome", "Chrome",
] ]
class spheres(datasets.GeneratorBasedBuilder): class SphereAntoineLaurent(datasets.GeneratorBasedBuilder):
"""spheres image dataset."""
def _info(self): def _info(self):
return datasets.DatasetInfo( return datasets.DatasetInfo(
description=_DESCRIPTION, description=_DESCRIPTION,
@ -83,10 +74,6 @@ class spheres(datasets.GeneratorBasedBuilder):
image_name = image_name[len(prefix) :] image_name = image_name[len(prefix) :]
image_name = pathlib.Path(image_name) image_name = pathlib.Path(image_name)
# skip shitty images
# if "Soulages" in str(image_name):
# continue
# check image_name exists # check image_name exists
assert (dataset_path / image_name).is_file() assert (dataset_path / image_name).is_file()
@ -202,7 +189,7 @@ if __name__ == "__main__":
# load dataset # load dataset
dataset = datasets.load_dataset("src/spheres.py", split="train") dataset = datasets.load_dataset("src/spheres.py", split="train")
print("a") print("dataset loaded")
labels = dataset.features["objects"][0]["category_id"].names labels = dataset.features["objects"][0]["category_id"].names
id2label = {k: v for k, v in enumerate(labels)} id2label = {k: v for k, v in enumerate(labels)}
@ -214,16 +201,12 @@ if __name__ == "__main__":
print() print()
idx = 0 idx = 0
while True: while True:
image = dataset[idx]["image"] image = dataset[idx]["image"]
if "DSC_4234" in image.filename: if "DSC_4234" in image.filename:
break break
idx += 1 idx += 1
if idx > 10000:
break
print(f"image path: {image.filename}") print(f"image path: {image.filename}")
print(f"data: {dataset[idx]}") print(f"data: {dataset[idx]}")
@ -239,4 +222,4 @@ if __name__ == "__main__":
draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black") draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black")
# save image # save image
image.save("example.jpg") image.save("example_antoine_laurent.jpg")

View file

@ -1,12 +1,9 @@
"""Dataset class AI or NOT HuggingFace competition.""" import json
import pathlib import pathlib
import json
import datasets import datasets
dataset_path_train = pathlib.Path("/home/laurent/proj-long/dataset_illumination/") dataset_path_train = pathlib.Path("./dataset_illumination/")
dataset_path_test = pathlib.Path("/home/laurent/proj-long/dataset_illumination_test/")
_VERSION = "1.0.0" _VERSION = "1.0.0"
@ -23,9 +20,7 @@ _NAMES = [
] ]
class spheresSynth(datasets.GeneratorBasedBuilder): class SphereIllumination(datasets.GeneratorBasedBuilder):
"""spheres image dataset."""
def _info(self): def _info(self):
return datasets.DatasetInfo( return datasets.DatasetInfo(
description=_DESCRIPTION, description=_DESCRIPTION,
@ -60,12 +55,6 @@ class spheresSynth(datasets.GeneratorBasedBuilder):
"dataset_path": dataset_path_train, "dataset_path": dataset_path_train,
}, },
), ),
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"dataset_path": dataset_path_test,
},
),
] ]
def _generate_examples(self, dataset_path: pathlib.Path): def _generate_examples(self, dataset_path: pathlib.Path):
@ -172,4 +161,4 @@ if __name__ == "__main__":
draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black") draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black")
# save image # save image
image.save(f"example_{idx}.jpg") image.save(f"example_illumination_{idx}.jpg")

View file

@ -1,10 +1,8 @@
"""Dataset class AI or NOT HuggingFace competition."""
import pathlib import pathlib
import datasets import datasets
dataset_path = pathlib.Path("/home/laurent/proj-long/dataset_predict/") dataset_path = pathlib.Path("./dataset_predict/")
_VERSION = "1.0.0" _VERSION = "1.0.0"
@ -21,9 +19,7 @@ _NAMES = [
] ]
class spheresSynth(datasets.GeneratorBasedBuilder): class SpherePredict(datasets.GeneratorBasedBuilder):
"""spheres image dataset."""
def _info(self): def _info(self):
return datasets.DatasetInfo( return datasets.DatasetInfo(
description=_DESCRIPTION, description=_DESCRIPTION,
@ -98,16 +94,5 @@ if __name__ == "__main__":
print(f"image path: {image.filename}") print(f"image path: {image.filename}")
print(f"data: {dataset[idx]}") print(f"data: {dataset[idx]}")
draw = ImageDraw.Draw(image)
for obj in dataset[idx]["objects"]:
bbox = (
obj["bbox"][0],
obj["bbox"][1],
obj["bbox"][0] + obj["bbox"][2],
obj["bbox"][1] + obj["bbox"][3],
)
draw.rectangle(bbox, outline="red", width=3)
draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black")
# save image # save image
image.save(f"example_{idx}.jpg") image.save(f"example_predict_{idx}.jpg")

View file

@ -1,12 +1,8 @@
"""Dataset class AI or NOT HuggingFace competition."""
import pathlib import pathlib
import cv2
import datasets import datasets
import numpy as np
dataset_path = pathlib.Path("/home/laurent/proj-long/dataset_render/") dataset_path = pathlib.Path("./dataset_render/")
_VERSION = "1.0.0" _VERSION = "1.0.0"
@ -23,8 +19,7 @@ _NAMES = [
] ]
class spheresSynth(datasets.GeneratorBasedBuilder): class SphereSynth(datasets.GeneratorBasedBuilder):
"""spheres image dataset."""
def _info(self): def _info(self):
return datasets.DatasetInfo( return datasets.DatasetInfo(
@ -156,8 +151,8 @@ if __name__ == "__main__":
for idx in range(10): for idx in range(10):
image = dataset[idx]["image"] image = dataset[idx]["image"]
# print(f"image path: {image.filename}") print(f"image path: {image.filename}")
# print(f"data: {dataset[idx]}") print(f"data: {dataset[idx]}")
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
for obj in dataset[idx]["objects"]: for obj in dataset[idx]["objects"]:
@ -171,4 +166,4 @@ if __name__ == "__main__":
draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black") draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black")
# save image # save image
image.save(f"example_{idx}.jpg") image.save(f"example_synth_{idx}.jpg")