feat: clean up dataset loaders a bit
This commit is contained in:
parent
8691735779
commit
2cc47bbb9e
|
@ -1,5 +1,3 @@
|
|||
"""Dataset class AI or NOT HuggingFace competition."""
|
||||
|
||||
import json
|
||||
import pathlib
|
||||
|
||||
|
@ -8,8 +6,8 @@ import datasets
|
|||
import numpy as np
|
||||
|
||||
prefix = "/data/local-files/?d=spheres/"
|
||||
dataset_path = pathlib.Path("./dataset3/spheres/")
|
||||
annotation_path = pathlib.Path("./annotations2.json")
|
||||
dataset_path = pathlib.Path("./dataset_antoine_laurent/")
|
||||
annotation_path = dataset_path / "annotations.json"
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
|
||||
|
@ -20,20 +18,13 @@ _HOMEPAGE = ""
|
|||
_LICENSE = ""
|
||||
|
||||
_NAMES = [
|
||||
# "White",
|
||||
# "Black",
|
||||
# "Grey",
|
||||
# "Red",
|
||||
# "Chrome",
|
||||
"Matte",
|
||||
"Shiny",
|
||||
"Chrome",
|
||||
]
|
||||
|
||||
|
||||
class spheres(datasets.GeneratorBasedBuilder):
|
||||
"""spheres image dataset."""
|
||||
|
||||
class SphereAntoineLaurent(datasets.GeneratorBasedBuilder):
|
||||
def _info(self):
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
|
@ -83,10 +74,6 @@ class spheres(datasets.GeneratorBasedBuilder):
|
|||
image_name = image_name[len(prefix) :]
|
||||
image_name = pathlib.Path(image_name)
|
||||
|
||||
# skip shitty images
|
||||
# if "Soulages" in str(image_name):
|
||||
# continue
|
||||
|
||||
# check image_name exists
|
||||
assert (dataset_path / image_name).is_file()
|
||||
|
||||
|
@ -202,7 +189,7 @@ if __name__ == "__main__":
|
|||
|
||||
# load dataset
|
||||
dataset = datasets.load_dataset("src/spheres.py", split="train")
|
||||
print("a")
|
||||
print("dataset loaded")
|
||||
|
||||
labels = dataset.features["objects"][0]["category_id"].names
|
||||
id2label = {k: v for k, v in enumerate(labels)}
|
||||
|
@ -214,16 +201,12 @@ if __name__ == "__main__":
|
|||
print()
|
||||
|
||||
idx = 0
|
||||
|
||||
while True:
|
||||
image = dataset[idx]["image"]
|
||||
if "DSC_4234" in image.filename:
|
||||
break
|
||||
idx += 1
|
||||
|
||||
if idx > 10000:
|
||||
break
|
||||
|
||||
print(f"image path: {image.filename}")
|
||||
print(f"data: {dataset[idx]}")
|
||||
|
||||
|
@ -239,4 +222,4 @@ if __name__ == "__main__":
|
|||
draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black")
|
||||
|
||||
# save image
|
||||
image.save("example.jpg")
|
||||
image.save("example_antoine_laurent.jpg")
|
|
@ -1,12 +1,9 @@
|
|||
"""Dataset class AI or NOT HuggingFace competition."""
|
||||
|
||||
import json
|
||||
import pathlib
|
||||
|
||||
import json
|
||||
import datasets
|
||||
|
||||
dataset_path_train = pathlib.Path("/home/laurent/proj-long/dataset_illumination/")
|
||||
dataset_path_test = pathlib.Path("/home/laurent/proj-long/dataset_illumination_test/")
|
||||
dataset_path_train = pathlib.Path("./dataset_illumination/")
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
|
||||
|
@ -23,9 +20,7 @@ _NAMES = [
|
|||
]
|
||||
|
||||
|
||||
class spheresSynth(datasets.GeneratorBasedBuilder):
|
||||
"""spheres image dataset."""
|
||||
|
||||
class SphereIllumination(datasets.GeneratorBasedBuilder):
|
||||
def _info(self):
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
|
@ -60,12 +55,6 @@ class spheresSynth(datasets.GeneratorBasedBuilder):
|
|||
"dataset_path": dataset_path_train,
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"dataset_path": dataset_path_test,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
def _generate_examples(self, dataset_path: pathlib.Path):
|
||||
|
@ -172,4 +161,4 @@ if __name__ == "__main__":
|
|||
draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black")
|
||||
|
||||
# save image
|
||||
image.save(f"example_{idx}.jpg")
|
||||
image.save(f"example_illumination_{idx}.jpg")
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
"""Dataset class AI or NOT HuggingFace competition."""
|
||||
|
||||
import pathlib
|
||||
|
||||
import datasets
|
||||
|
||||
dataset_path = pathlib.Path("/home/laurent/proj-long/dataset_predict/")
|
||||
dataset_path = pathlib.Path("./dataset_predict/")
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
|
||||
|
@ -21,9 +19,7 @@ _NAMES = [
|
|||
]
|
||||
|
||||
|
||||
class spheresSynth(datasets.GeneratorBasedBuilder):
|
||||
"""spheres image dataset."""
|
||||
|
||||
class SpherePredict(datasets.GeneratorBasedBuilder):
|
||||
def _info(self):
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
|
@ -98,16 +94,5 @@ if __name__ == "__main__":
|
|||
print(f"image path: {image.filename}")
|
||||
print(f"data: {dataset[idx]}")
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
for obj in dataset[idx]["objects"]:
|
||||
bbox = (
|
||||
obj["bbox"][0],
|
||||
obj["bbox"][1],
|
||||
obj["bbox"][0] + obj["bbox"][2],
|
||||
obj["bbox"][1] + obj["bbox"][3],
|
||||
)
|
||||
draw.rectangle(bbox, outline="red", width=3)
|
||||
draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black")
|
||||
|
||||
# save image
|
||||
image.save(f"example_{idx}.jpg")
|
||||
image.save(f"example_predict_{idx}.jpg")
|
||||
|
|
|
@ -1,12 +1,8 @@
|
|||
"""Dataset class AI or NOT HuggingFace competition."""
|
||||
|
||||
import pathlib
|
||||
|
||||
import cv2
|
||||
import datasets
|
||||
import numpy as np
|
||||
|
||||
dataset_path = pathlib.Path("/home/laurent/proj-long/dataset_render/")
|
||||
dataset_path = pathlib.Path("./dataset_render/")
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
|
||||
|
@ -23,8 +19,7 @@ _NAMES = [
|
|||
]
|
||||
|
||||
|
||||
class spheresSynth(datasets.GeneratorBasedBuilder):
|
||||
"""spheres image dataset."""
|
||||
class SphereSynth(datasets.GeneratorBasedBuilder):
|
||||
|
||||
def _info(self):
|
||||
return datasets.DatasetInfo(
|
||||
|
@ -156,8 +151,8 @@ if __name__ == "__main__":
|
|||
for idx in range(10):
|
||||
image = dataset[idx]["image"]
|
||||
|
||||
# print(f"image path: {image.filename}")
|
||||
# print(f"data: {dataset[idx]}")
|
||||
print(f"image path: {image.filename}")
|
||||
print(f"data: {dataset[idx]}")
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
for obj in dataset[idx]["objects"]:
|
||||
|
@ -171,4 +166,4 @@ if __name__ == "__main__":
|
|||
draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black")
|
||||
|
||||
# save image
|
||||
image.save(f"example_{idx}.jpg")
|
||||
image.save(f"example_synth_{idx}.jpg")
|
||||
|
|
Loading…
Reference in a new issue