{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import sqlite3\n", "import PIL.Image\n", "import glob\n", "import os\n", "\n", "import tensorflow as tf\n", "import tensorflow_addons as tfa\n", "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import InputLayer, Dense, Flatten, Conv2D, MaxPooling2D\n", "from tensorflow.keras import optimizers\n", "\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "IMAGE_SIZE = (400, 150, 3)\n", "RESIZED_SIZE = (100, 50, 3)\n", "RESIZED_SIZE_PIL = (RESIZED_SIZE[1], RESIZED_SIZE[0], RESIZED_SIZE[2])\n", "DATASET_PATH = \"./data/\"\n", "DATASET_PATH = os.path.abspath(DATASET_PATH)\n", "\n", "print(DATASET_PATH)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "AUTOTUNE = tf.data.experimental.AUTOTUNE\n", "BATCH_SIZE = 32\n", "SHUFFLE_SIZE = 64\n", "LIMIT = 1000\n", "\n", "\n", "def customGenerator():\n", " data = (\n", " sqlite3.connect(f\"{DATASET_PATH}/index.db\")\n", " .execute(f\"SELECT uuid, model from data LIMIT {LIMIT}\")\n", " .fetchall()\n", " )\n", "\n", " for uuid, model in data:\n", " img = tf.io.read_file(f\"{DATASET_PATH}/{uuid}.jpg\")\n", " img = tf.image.decode_jpeg(img, channels=IMAGE_SIZE[2])\n", " img = tf.image.convert_image_dtype(img, tf.float32)\n", " img = tf.image.resize(img, RESIZED_SIZE[:-1])\n", "\n", " label = tf.convert_to_tensor(model, dtype=tf.uint8)\n", "\n", " yield img, label\n", "\n", "\n", "def cutout(image, label):\n", " img = tfa.image.random_cutout(image, (6, 6), constant_values=1)\n", " return (img, label)\n", "\n", "\n", "def rotate(image, label):\n", " img = tfa.image.rotate(image, tf.constant(np.pi))\n", " return (img, label)\n", "\n", "\n", "def set_shapes(image, label):\n", " image.set_shape(RESIZED_SIZE)\n", " label.set_shape([])\n", " return image, label\n", "\n", "\n", "dataset = tf.data.Dataset.from_generator(generator=customGenerator, output_types=(tf.float32, tf.uint8))\n", "\n", "(dataset_length,) = sqlite3.connect(f\"{DATASET_PATH}/index.db\").execute(\"SELECT count(uuid) from data\").fetchone()\n", "dataset_length = min(dataset_length, LIMIT)\n", "\n", "print(f\"dataset_length = {dataset_length}\")\n", "print(f\"batch size = {BATCH_SIZE}\")\n", "print(f\"number of batchs = {dataset_length // BATCH_SIZE}\")\n", "\n", "print()\n", "\n", "train_size = int(0.8 * dataset_length / BATCH_SIZE)\n", "print(f\"train_size = {train_size}\")\n", "print(f\"validation_size = {dataset_length - train_size}\")\n", "\n", "dataset = (\n", " dataset.shuffle(SHUFFLE_SIZE)\n", " .map(set_shapes)\n", " .batch(BATCH_SIZE)\n", " # .map(cutout)\n", " .prefetch(AUTOTUNE)\n", ")\n", "\n", "dataset_train = dataset.take(train_size)\n", "dataset_validate = dataset.skip(train_size)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Sequential(\n", " [\n", " InputLayer(input_shape=RESIZED_SIZE),\n", " Conv2D(32, 3, activation=\"relu\"),\n", " MaxPooling2D(pool_size=(2, 2)),\n", " Conv2D(64, 3, activation=\"relu\"),\n", " MaxPooling2D(pool_size=(2, 2)),\n", " Conv2D(92, 3, activation=\"relu\"),\n", " MaxPooling2D(pool_size=(2, 2)),\n", " Flatten(),\n", " Dense(250, activation=\"relu\"),\n", " Dense(4, activation=\"softmax\"),\n", " ]\n", ")\n", "\n", "model.summary()" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": {}, "outputs": [], "source": [ "adam = optimizers.Adam(learning_rate=7e-6)\n", "model.compile(optimizer=adam, loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"])\n", "history = model.fit(dataset_train, validation_data=dataset_validate, epochs=5, batch_size=BATCH_SIZE)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_training_analysis():\n", " acc = history.history[\"accuracy\"]\n", " val_acc = history.history[\"val_accuracy\"]\n", " loss = history.history[\"loss\"]\n", " val_loss = history.history[\"val_loss\"]\n", "\n", " epochs = range(len(loss))\n", "\n", " plt.plot(epochs, acc, \"b\", linestyle=\"--\", label=\"Training acc\")\n", " plt.plot(epochs, val_acc, \"g\", label=\"Validation acc\")\n", " plt.title(\"Training and validation accuracy\")\n", " plt.legend()\n", "\n", " plt.figure()\n", "\n", " plt.plot(epochs, loss, \"b\", linestyle=\"--\", label=\"Training loss\")\n", " plt.plot(epochs, val_loss, \"g\", label=\"Validation loss\")\n", " plt.title(\"Training and validation loss\")\n", " plt.legend()\n", "\n", " plt.show()\n", "\n", "\n", "plot_training_analysis()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Save the weights\n", "# model.save('models/rot_25e')" ] } ], "metadata": { "interpreter": { "hash": "e55666fbbf217aa3df372b978577f47b6009e2f78e2ec76a584f49cd54a1e62c" }, "kernelspec": { "display_name": ".env", "language": "python", "name": ".env" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }