feat: grad-cam ?
This commit is contained in:
parent 37d7c5da67
commit 95f31269af
@@ -1,204 +0,0 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import sqlite3\n",
    "import PIL.Image\n",
    "import glob\n",
    "import os\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_addons as tfa\n",
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.layers import InputLayer, Dense, Flatten, Conv2D, MaxPooling2D\n",
    "from tensorflow.keras import optimizers\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGE_SIZE = (400, 150, 3)\n",
    "RESIZED_SIZE = (100, 50, 3)\n",
    "RESIZED_SIZE_PIL = (RESIZED_SIZE[1], RESIZED_SIZE[0], RESIZED_SIZE[2])\n",
    "DATASET_PATH = \"./data/\"\n",
    "DATASET_PATH = os.path.abspath(DATASET_PATH)\n",
    "\n",
    "print(DATASET_PATH)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
    "BATCH_SIZE = 32\n",
    "SHUFFLE_SIZE = 64\n",
    "LIMIT = 1000\n",
    "\n",
    "\n",
    "def customGenerator():\n",
    "    data = (\n",
    "        sqlite3.connect(f\"{DATASET_PATH}/index.db\")\n",
    "        .execute(f\"SELECT uuid, model from data LIMIT {LIMIT}\")\n",
    "        .fetchall()\n",
    "    )\n",
    "\n",
    "    for uuid, model in data:\n",
    "        img = tf.io.read_file(f\"{DATASET_PATH}/{uuid}.jpg\")\n",
    "        img = tf.image.decode_jpeg(img, channels=IMAGE_SIZE[2])\n",
    "        img = tf.image.convert_image_dtype(img, tf.float32)\n",
    "        img = tf.image.resize(img, RESIZED_SIZE[:-1])\n",
    "\n",
    "        label = tf.convert_to_tensor(model, dtype=tf.uint8)\n",
    "\n",
    "        yield img, label\n",
    "\n",
    "\n",
    "def cutout(image, label):\n",
    "    img = tfa.image.random_cutout(image, (6, 6), constant_values=1)\n",
    "    return (img, label)\n",
    "\n",
    "\n",
    "def rotate(image, label):\n",
    "    img = tfa.image.rotate(image, tf.constant(np.pi))\n",
    "    return (img, label)\n",
    "\n",
    "\n",
    "def set_shapes(image, label):\n",
    "    image.set_shape(RESIZED_SIZE)\n",
    "    label.set_shape([])\n",
    "    return image, label\n",
    "\n",
    "\n",
    "dataset = tf.data.Dataset.from_generator(generator=customGenerator, output_types=(tf.float32, tf.uint8))\n",
    "\n",
    "(dataset_length,) = sqlite3.connect(f\"{DATASET_PATH}/index.db\").execute(\"SELECT count(uuid) from data\").fetchone()\n",
    "dataset_length = min(dataset_length, LIMIT)\n",
    "\n",
    "print(f\"dataset_length = {dataset_length}\")\n",
    "print(f\"batch size = {BATCH_SIZE}\")\n",
    "print(f\"number of batches = {dataset_length // BATCH_SIZE}\")\n",
    "\n",
    "print()\n",
    "\n",
    "# train/validation split, both measured in batches\n",
    "train_size = int(0.8 * dataset_length / BATCH_SIZE)\n",
    "print(f\"train_size = {train_size}\")\n",
    "print(f\"validation_size = {dataset_length // BATCH_SIZE - train_size}\")\n",
    "\n",
    "dataset = (\n",
    "    dataset.shuffle(SHUFFLE_SIZE, reshuffle_each_iteration=False)  # keep the take()/skip() split disjoint across epochs\n",
    "    .map(set_shapes)\n",
    "    .batch(BATCH_SIZE)\n",
    "    # .map(cutout)\n",
    "    .prefetch(AUTOTUNE)\n",
    ")\n",
    "\n",
    "dataset_train = dataset.take(train_size)\n",
    "dataset_validate = dataset.skip(train_size)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Sequential(\n",
    "    [\n",
    "        InputLayer(input_shape=RESIZED_SIZE),\n",
    "        Conv2D(32, 3, activation=\"relu\"),\n",
    "        MaxPooling2D(pool_size=(2, 2)),\n",
    "        Conv2D(64, 3, activation=\"relu\"),\n",
    "        MaxPooling2D(pool_size=(2, 2)),\n",
    "        Conv2D(92, 3, activation=\"relu\"),\n",
    "        MaxPooling2D(pool_size=(2, 2)),\n",
    "        Flatten(),\n",
    "        Dense(250, activation=\"relu\"),\n",
    "        Dense(4, activation=\"softmax\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adam = optimizers.Adam(learning_rate=7e-6)\n",
    "model.compile(optimizer=adam, loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"])\n",
    "# batch_size must not be passed to fit() when the input is an already-batched tf.data.Dataset\n",
    "history = model.fit(dataset_train, validation_data=dataset_validate, epochs=5)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_training_analysis():\n",
    "    acc = history.history[\"accuracy\"]\n",
    "    val_acc = history.history[\"val_accuracy\"]\n",
    "    loss = history.history[\"loss\"]\n",
    "    val_loss = history.history[\"val_loss\"]\n",
    "\n",
    "    epochs = range(len(loss))\n",
    "\n",
    "    plt.plot(epochs, acc, \"b\", linestyle=\"--\", label=\"Training acc\")\n",
    "    plt.plot(epochs, val_acc, \"g\", label=\"Validation acc\")\n",
    "    plt.title(\"Training and validation accuracy\")\n",
    "    plt.legend()\n",
    "\n",
    "    plt.figure()\n",
    "\n",
    "    plt.plot(epochs, loss, \"b\", linestyle=\"--\", label=\"Training loss\")\n",
    "    plt.plot(epochs, val_loss, \"g\", label=\"Validation loss\")\n",
    "    plt.title(\"Training and validation loss\")\n",
    "    plt.legend()\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_training_analysis()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the weights\n",
    "# model.save('models/rot_25e')"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "e55666fbbf217aa3df372b978577f47b6009e2f78e2ec76a584f49cd54a1e62c"
  },
  "kernelspec": {
   "display_name": ".env",
   "language": "python",
   "name": ".env"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
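For reference, a minimal sketch of reloading and querying the trained model — this assumes the commented-out model.save('models/rot_25e') call in the final cell was actually run, and "data/example.jpg" is a hypothetical sample file, not a path from the commit:

import tensorflow as tf

# Reload the model saved by the notebook's final cell.
model = tf.keras.models.load_model("models/rot_25e")

# Reuse the notebook's preprocessing: decode, scale to float32, resize to RESIZED_SIZE.
img = tf.io.read_file("data/example.jpg")
img = tf.image.decode_jpeg(img, channels=3)
img = tf.image.convert_image_dtype(img, tf.float32)
img = tf.image.resize(img, (100, 50))

probs = model.predict(tf.expand_dims(img, 0))  # add a batch dimension
print(probs.argmax(axis=-1))                   # index of the most likely of the 4 classes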
294 src/notebook_train.ipynb Normal file
File diff suppressed because one or more lines are too long
608 src/notebook_train_rot.ipynb Normal file
File diff suppressed because one or more lines are too long
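The commit title mentions Grad-CAM, but both new notebooks' diffs are suppressed, so none of that code is visible here. For orientation only, a minimal Grad-CAM sketch that would work with a Sequential CNN like the one deleted above — every name below is an assumption, not the commit's actual implementation:

import tensorflow as tf

def grad_cam(model, image, class_index=None):
    # Expose both the last Conv2D feature maps and the predictions in one model.
    last_conv = next(l for l in reversed(model.layers) if isinstance(l, tf.keras.layers.Conv2D))
    grad_model = tf.keras.Model(model.inputs, [last_conv.output, model.output])

    with tf.GradientTape() as tape:
        conv_out, preds = grad_model(image[tf.newaxis, ...])
        if class_index is None:
            class_index = tf.argmax(preds[0])
        class_score = preds[:, class_index]

    grads = tape.gradient(class_score, conv_out)         # d(score) / d(feature maps)
    weights = tf.reduce_mean(grads, axis=(0, 1, 2))      # global-average-pool the gradients
    cam = tf.reduce_sum(conv_out[0] * weights, axis=-1)  # weighted sum of feature maps
    cam = tf.nn.relu(cam)                                # keep only positive influence
    return cam / (tf.reduce_max(cam) + 1e-8)             # normalise to [0, 1]

The returned heatmap has the spatial resolution of the last convolutional layer; resizing it to the input size (e.g. with tf.image.resize) and overlaying it on the image via plt.imshow(..., alpha=0.5) gives the usual visualisation.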