205 lines
5.9 KiB
Plaintext
205 lines
5.9 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import sqlite3\n",
|
|
"import PIL.Image\n",
|
|
"import glob\n",
|
|
"import os\n",
|
|
"\n",
|
|
"import tensorflow as tf\n",
|
|
"import tensorflow_addons as tfa\n",
|
|
"from tensorflow.keras.models import Sequential\n",
|
|
"from tensorflow.keras.layers import InputLayer, Dense, Flatten, Conv2D, MaxPooling2D\n",
|
|
"from tensorflow.keras import optimizers\n",
|
|
"\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"%matplotlib inline\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Source image size, presumably (height, width, channels) of the JPEGs on
# disk -- TODO confirm. Only the channel count (index 2) is read by the
# input pipeline.
IMAGE_SIZE = (400, 150, 3)

# Network input size as (height, width, channels); the first two entries are
# what tf.image.resize receives.
RESIZED_SIZE = (100, 50, 3)

# The same size in PIL's (width, height, channels) order: the first two axes
# swapped, channels kept. NOTE(review): not used by any visible cell.
RESIZED_SIZE_PIL = RESIZED_SIZE[1::-1] + RESIZED_SIZE[2:]

# Dataset root (holds index.db and the <uuid>.jpg images) as an absolute path.
DATASET_PATH = os.path.abspath("./data/")

print(DATASET_PATH)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Let tf.data choose prefetch depth dynamically.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Samples per batch, elements held in the shuffle buffer, and the maximum
# number of rows read from index.db.
BATCH_SIZE, SHUFFLE_SIZE, LIMIT = 32, 64, 1000
|
|
"\n",
|
|
"\n",
|
|
def customGenerator():
    """Yield ``(image, label)`` pairs for up to ``LIMIT`` rows of the SQLite index.

    Reads ``(uuid, model)`` rows from the ``data`` table in
    ``{DATASET_PATH}/index.db``. For each row it loads
    ``{DATASET_PATH}/{uuid}.jpg`` and yields:

    * ``img`` -- float32 tensor decoded with ``IMAGE_SIZE[2]`` channels,
      rescaled to [0, 1] by ``convert_image_dtype``, and resized to
      ``RESIZED_SIZE[:-1]`` (height, width).
    * ``label`` -- scalar ``tf.uint8`` tensor built from the ``model`` column.

    Used as the ``generator`` of ``tf.data.Dataset.from_generator``, which
    calls it afresh on every pass over the dataset.
    """
    index_db = sqlite3.connect(f"{DATASET_PATH}/index.db")
    try:
        # Fetch all rows up front, then close the handle explicitly.
        # Previously the connection was never closed and cleanup was left to
        # garbage collection, once per dataset pass.
        # LIMIT is bound as a query parameter rather than formatted into SQL.
        # NOTE(review): there is no ORDER BY, so which rows come back (and in
        # what order) is left to SQLite -- confirm this is stable enough for a
        # reproducible train/validation split.
        rows = index_db.execute(
            "SELECT uuid, model from data LIMIT ?", (LIMIT,)
        ).fetchall()
    finally:
        index_db.close()

    for uuid, model in rows:
        img = tf.io.read_file(f"{DATASET_PATH}/{uuid}.jpg")
        img = tf.image.decode_jpeg(img, channels=IMAGE_SIZE[2])
        # uint8 -> float32 conversion also rescales pixel values to [0, 1].
        img = tf.image.convert_image_dtype(img, tf.float32)
        img = tf.image.resize(img, RESIZED_SIZE[:-1])

        # NOTE(review): assumes `model` holds small non-negative integer class
        # ids that fit in uint8 (0..255) -- TODO confirm against index.db.
        label = tf.convert_to_tensor(model, dtype=tf.uint8)

        yield img, label
|
|
"\n",
|
|
"\n",
|
|
def cutout(image, label):
    """Augmentation for ``Dataset.map``: random 6x6 cutout, label untouched.

    Masks a randomly placed 6x6 patch of ``image`` with the value 1 using
    ``tfa.image.random_cutout``. That op works on batched (4-D) images, so it
    has to be mapped after ``.batch()``.
    """
    return tfa.image.random_cutout(image, (6, 6), constant_values=1), label
|
|
"\n",
|
|
"\n",
|
|
def rotate(image, label):
    """Augmentation for ``Dataset.map``: turn the image upside down.

    Rotates ``image`` by pi radians (180 degrees) with ``tfa.image.rotate``;
    ``label`` passes through unchanged.
    """
    half_turn = tf.constant(np.pi)
    rotated = tfa.image.rotate(image, half_turn)
    return rotated, label
|
|
"\n",
|
|
"\n",
|
|
def set_shapes(image, label):
    """Attach static shapes that ``from_generator`` cannot infer.

    The dataset is built with ``output_types`` only, so its element shapes
    are unknown. This pins the image to ``RESIZED_SIZE`` and the label to a
    scalar, then returns both tensors.
    """
    for tensor, static_shape in ((image, RESIZED_SIZE), (label, [])):
        tensor.set_shape(static_shape)
    return image, label
|
|
"\n",
|
|
"\n",
|
|
dataset = tf.data.Dataset.from_generator(generator=customGenerator, output_types=(tf.float32, tf.uint8))

# Count rows up front to size the split; close the handle explicitly instead
# of leaving it to garbage collection.
index_connection = sqlite3.connect(f"{DATASET_PATH}/index.db")
try:
    (dataset_length,) = index_connection.execute("SELECT count(uuid) from data").fetchone()
finally:
    index_connection.close()
# customGenerator never yields more than LIMIT rows.
dataset_length = min(dataset_length, LIMIT)

# batch() keeps the final partial batch, so the batch count rounds UP.
num_batches = -(-dataset_length // BATCH_SIZE)

print(f"dataset_length = {dataset_length}")
print(f"batch size = {BATCH_SIZE}")
print(f"number of batches = {num_batches}")

print()

# Split in whole batches: ~80% of the batches go to training. The rest,
# including any final partial batch, go to validation. Both sizes are
# counted in batches.
train_size = int(0.8 * dataset_length / BATCH_SIZE)
validation_batches = num_batches - train_size
print(f"train_size = {train_size} batches")
print(f"validation_size = {validation_batches} batches")
assert train_size > 0, "dataset too small for an 80/20 split at this BATCH_SIZE"

# Shuffle ONCE, with a fixed seed and reshuffle_each_iteration=False, so that
# take()/skip() see the same element order on every pass. With the default
# (reshuffle on every iteration) the train/validation membership was re-drawn
# each epoch, leaking validation samples into training.
SPLIT_SEED = 42
dataset = (
    dataset.shuffle(SHUFFLE_SIZE, seed=SPLIT_SEED, reshuffle_each_iteration=False)
    .map(set_shapes)
)

# Split at the sample level into exactly `train_size` full training batches.
train_samples = train_size * BATCH_SIZE

# Training data is additionally reshuffled every epoch, AFTER the split.
# Batch-level augmentations (e.g. `cutout`) belong after .batch() here so
# they never touch the validation data.
dataset_train = (
    dataset.take(train_samples)
    .shuffle(SHUFFLE_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
dataset_validate = dataset.skip(train_samples).batch(BATCH_SIZE).prefetch(AUTOTUNE)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Small CNN classifier: three Conv2D(3x3, relu) + 2x2 max-pool stages with
# 32 / 64 / 92 filters, then a 250-unit dense layer and a softmax head.
# NOTE(review): the 4 output units presumably match the number of distinct
# `model` labels in index.db -- TODO confirm.
model = Sequential(
    [InputLayer(input_shape=RESIZED_SIZE)]
    + [
        layer
        for filters in (32, 64, 92)
        for layer in (
            Conv2D(filters, 3, activation="relu"),
            MaxPooling2D(pool_size=(2, 2)),
        )
    ]
    + [
        Flatten(),
        Dense(250, activation="relu"),
        Dense(4, activation="softmax"),
    ]
)

model.summary()
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
EPOCHS = 5

# NOTE(review): 7e-6 is far below Adam's Keras default (1e-3); with only
# EPOCHS epochs the model may barely move -- confirm this was tuned on purpose.
adam = optimizers.Adam(learning_rate=7e-6)

# Integer class labels + softmax outputs -> sparse categorical cross-entropy.
model.compile(optimizer=adam, loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# dataset_train / dataset_validate are already-batched tf.data pipelines.
# Keras documents that batch_size must not be specified for dataset inputs,
# because the dataset itself yields the batches.
history = model.fit(dataset_train, validation_data=dataset_validate, epochs=EPOCHS)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def plot_training_analysis(training_history=None):
    """Plot per-epoch training vs. validation accuracy and loss.

    Args:
        training_history: Keras ``History`` object returned by ``model.fit``.
            Defaults to the notebook-level ``history`` so the original
            zero-argument call keeps working.

    Draws two new figures, accuracy first and then loss. Training curves are
    dashed blue and validation curves solid green. Finally calls
    ``plt.show()``.
    """
    if training_history is None:
        # Backward-compatible fallback to the global set by the training cell.
        training_history = history
    metrics = training_history.history

    # Keras reports epochs as 1..N, so number the x axis the same way.
    epochs = range(1, len(metrics["loss"]) + 1)

    # (history key, legend suffix, y-axis label, figure title)
    panels = (
        ("accuracy", "acc", "Accuracy", "Training and validation accuracy"),
        ("loss", "loss", "Loss", "Training and validation loss"),
    )
    for key, short_name, y_label, title in panels:
        _, ax = plt.subplots()
        ax.plot(epochs, metrics[key], "b--", label=f"Training {short_name}")
        ax.plot(epochs, metrics[f"val_{key}"], "g", label=f"Validation {short_name}")
        ax.set(title=title, xlabel="Epoch", ylabel=y_label)
        ax.legend()

    plt.show()


plot_training_analysis(history)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Save the full model (architecture and weights)\n",
|
|
"# model.save('models/rot_25e')"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"interpreter": {
|
|
"hash": "e55666fbbf217aa3df372b978577f47b6009e2f78e2ec76a584f49cd54a1e62c"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": ".env",
|
|
"language": "python",
|
|
"name": ".env"
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|