feat: matrice de confusion

This commit is contained in:
Laureηt 2022-04-19 09:21:04 +02:00
parent bf92ebfed7
commit c4134de040
No known key found for this signature in database
GPG key ID: D88C6B294FD40994
2 changed files with 330 additions and 118 deletions

File diff suppressed because one or more lines are too long

204
src/notebook_test.ipynb Normal file
View file

@ -0,0 +1,204 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import sqlite3\n",
"import PIL.Image\n",
"import glob\n",
"import os\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_addons as tfa\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import InputLayer, Dense, Flatten, Conv2D, MaxPooling2D\n",
"from tensorflow.keras import optimizers\n",
"\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"IMAGE_SIZE = (400, 150, 3)\n",
"RESIZED_SIZE = (100, 50, 3)\n",
"RESIZED_SIZE_PIL = (RESIZED_SIZE[1], RESIZED_SIZE[0], RESIZED_SIZE[2])\n",
"DATASET_PATH = \"./data/\"\n",
"DATASET_PATH = os.path.abspath(DATASET_PATH)\n",
"\n",
"print(DATASET_PATH)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
"BATCH_SIZE = 32\n",
"SHUFFLE_SIZE = 64\n",
"LIMIT = 1000\n",
"\n",
"\n",
"def customGenerator():\n",
" data = (\n",
" sqlite3.connect(f\"{DATASET_PATH}/index.db\")\n",
" .execute(f\"SELECT uuid, model from data LIMIT {LIMIT}\")\n",
" .fetchall()\n",
" )\n",
"\n",
" for uuid, model in data:\n",
" img = tf.io.read_file(f\"{DATASET_PATH}/{uuid}.jpg\")\n",
" img = tf.image.decode_jpeg(img, channels=IMAGE_SIZE[2])\n",
" img = tf.image.convert_image_dtype(img, tf.float32)\n",
" img = tf.image.resize(img, RESIZED_SIZE[:-1])\n",
"\n",
" label = tf.convert_to_tensor(model, dtype=tf.uint8)\n",
"\n",
" yield img, label\n",
"\n",
"\n",
"def cutout(image, label):\n",
" img = tfa.image.random_cutout(image, (6, 6), constant_values=1)\n",
" return (img, label)\n",
"\n",
"\n",
"def rotate(image, label):\n",
" img = tfa.image.rotate(image, tf.constant(np.pi))\n",
" return (img, label)\n",
"\n",
"\n",
"def set_shapes(image, label):\n",
" image.set_shape(RESIZED_SIZE)\n",
" label.set_shape([])\n",
" return image, label\n",
"\n",
"\n",
"dataset = tf.data.Dataset.from_generator(generator=customGenerator, output_types=(tf.float32, tf.uint8))\n",
"\n",
"(dataset_length,) = sqlite3.connect(f\"{DATASET_PATH}/index.db\").execute(\"SELECT count(uuid) from data\").fetchone()\n",
"dataset_length = min(dataset_length, LIMIT)\n",
"\n",
"print(f\"dataset_length = {dataset_length}\")\n",
"print(f\"batch size = {BATCH_SIZE}\")\n",
"print(f\"number of batchs = {dataset_length // BATCH_SIZE}\")\n",
"\n",
"print()\n",
"\n",
"train_size = int(0.8 * dataset_length / BATCH_SIZE)\n",
"print(f\"train_size = {train_size}\")\n",
"print(f\"validation_size = {dataset_length - train_size}\")\n",
"\n",
"dataset = (\n",
" dataset.shuffle(SHUFFLE_SIZE)\n",
" .map(set_shapes)\n",
" .batch(BATCH_SIZE)\n",
" # .map(cutout)\n",
" .prefetch(AUTOTUNE)\n",
")\n",
"\n",
"dataset_train = dataset.take(train_size)\n",
"dataset_validate = dataset.skip(train_size)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = Sequential(\n",
" [\n",
" InputLayer(input_shape=RESIZED_SIZE),\n",
" Conv2D(32, 3, activation=\"relu\"),\n",
" MaxPooling2D(pool_size=(2, 2)),\n",
" Conv2D(64, 3, activation=\"relu\"),\n",
" MaxPooling2D(pool_size=(2, 2)),\n",
" Conv2D(92, 3, activation=\"relu\"),\n",
" MaxPooling2D(pool_size=(2, 2)),\n",
" Flatten(),\n",
" Dense(250, activation=\"relu\"),\n",
" Dense(4, activation=\"softmax\"),\n",
" ]\n",
")\n",
"\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"adam = optimizers.Adam(learning_rate=7e-6)\n",
"model.compile(optimizer=adam, loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"])\n",
"history = model.fit(dataset_train, validation_data=dataset_validate, epochs=5, batch_size=BATCH_SIZE)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_training_analysis():\n",
" acc = history.history[\"accuracy\"]\n",
" val_acc = history.history[\"val_accuracy\"]\n",
" loss = history.history[\"loss\"]\n",
" val_loss = history.history[\"val_loss\"]\n",
"\n",
" epochs = range(len(loss))\n",
"\n",
" plt.plot(epochs, acc, \"b\", linestyle=\"--\", label=\"Training acc\")\n",
" plt.plot(epochs, val_acc, \"g\", label=\"Validation acc\")\n",
" plt.title(\"Training and validation accuracy\")\n",
" plt.legend()\n",
"\n",
" plt.figure()\n",
"\n",
" plt.plot(epochs, loss, \"b\", linestyle=\"--\", label=\"Training loss\")\n",
" plt.plot(epochs, val_loss, \"g\", label=\"Validation loss\")\n",
" plt.title(\"Training and validation loss\")\n",
" plt.legend()\n",
"\n",
" plt.show()\n",
"\n",
"\n",
"plot_training_analysis()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the weights\n",
"# model.save('models/rot_25e')"
]
}
],
"metadata": {
"interpreter": {
"hash": "e55666fbbf217aa3df372b978577f47b6009e2f78e2ec76a584f49cd54a1e62c"
},
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": ".env"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}