2022-03-30 08:59:31 +00:00
{
2022-03-30 09:48:21 +00:00
"cells": [
{
"cell_type": "code",
2022-04-05 20:38:05 +00:00
"execution_count": 2,
2022-03-30 09:48:21 +00:00
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import PIL.Image\n",
"import glob\n",
"import os\n",
"\n",
"import matplotlib.pyplot as plt\n",
2022-04-11 20:52:16 +00:00
"%matplotlib inline\n"
2022-03-30 09:48:21 +00:00
]
},
{
"cell_type": "code",
2022-04-11 20:52:16 +00:00
"execution_count": 3,
2022-03-30 09:48:21 +00:00
"metadata": {},
"outputs": [
2022-03-30 08:59:31 +00:00
{
2022-04-11 20:52:16 +00:00
"name": "stdout",
"output_type": "stream",
"text": [
"/tmp/deepl/dataset_rot\n",
"['octane', 'werewolf', 'breakout', 'aftershock']\n"
2022-03-30 09:48:21 +00:00
]
}
],
"source": [
"IMAGE_SIZE = (400, 150, 3)\n",
"RESIZED_SIZE = (100, 50, 3)\n",
"RESIZED_SIZE_PIL = (RESIZED_SIZE[1], RESIZED_SIZE[0], RESIZED_SIZE[2])\n",
2022-04-05 20:38:05 +00:00
"DATASET_PATH = \"./dataset_rot/\"\n",
2022-03-30 09:48:21 +00:00
"DATASET_PATH = os.path.abspath(DATASET_PATH)\n",
2022-04-05 20:38:05 +00:00
"CLASSES = next(os.walk(DATASET_PATH))[1]\n",
2022-03-30 09:48:21 +00:00
"\n",
"print(DATASET_PATH)\n",
2022-04-11 20:52:16 +00:00
"print(CLASSES)\n"
2022-03-30 09:48:21 +00:00
]
},
{
"cell_type": "code",
2022-04-11 20:52:16 +00:00
"execution_count": 4,
2022-03-30 09:48:21 +00:00
"metadata": {},
2022-04-11 20:52:16 +00:00
"outputs": [],
2022-03-30 09:48:21 +00:00
"source": [
"def load_data():\n",
" # Récupération des fichiers\n",
2022-04-11 20:52:16 +00:00
" files = glob.glob(f\"{DATASET_PATH}/**/*.jpg\", recursive=True)\n",
2022-03-30 09:48:21 +00:00
"\n",
2022-04-11 20:52:16 +00:00
" # Initialise les structures de données\n",
2022-03-30 09:48:21 +00:00
" x = np.zeros((len(files), *RESIZED_SIZE_PIL))\n",
" y = np.zeros((len(files), 1))\n",
"\n",
" # print(f\"x.shape = {x.shape}\")\n",
"\n",
" for i, path in enumerate(files):\n",
2022-04-11 20:52:16 +00:00
" # Lecture de l'image\n",
2022-03-30 09:48:21 +00:00
" img = PIL.Image.open(path)\n",
"\n",
" # print(f\"img.size = {img.size}\")\n",
"\n",
" # Redimensionnement de l'image\n",
" img = img.resize(RESIZED_SIZE[:-1], PIL.Image.ANTIALIAS)\n",
"\n",
" # print(f\"img.size = {img.size}\")\n",
"\n",
" test = np.asarray(img)\n",
"\n",
" # print(f\"test.shape = {test.shape}\")\n",
"\n",
2022-04-11 20:52:16 +00:00
" # Remplissage de la variable x\n",
2022-03-30 09:48:21 +00:00
" x[i] = test\n",
"\n",
" # On récupère le nom de la classe (nom du dossier parent) dans le path\n",
" class_label = path.split(\"/\")[-2]\n",
"\n",
" # On récupère le numéro de la classe à partir du string\n",
" class_label = CLASSES.index(class_label)\n",
2022-04-11 20:52:16 +00:00
"\n",
2022-03-30 09:48:21 +00:00
" # Remplissage de la variable y\n",
" y[i] = class_label\n",
"\n",
2022-04-11 20:52:16 +00:00
" return x, y\n"
2022-03-30 09:48:21 +00:00
]
},
{
"cell_type": "code",
2022-04-11 20:52:16 +00:00
"execution_count": 5,
2022-03-30 09:48:21 +00:00
"metadata": {},
2022-04-11 20:52:16 +00:00
"outputs": [],
2022-03-30 09:48:21 +00:00
"source": [
"x, y = load_data()\n",
2022-04-11 20:52:16 +00:00
"x = x / 255\n"
2022-03-30 09:48:21 +00:00
]
},
{
"cell_type": "code",
2022-04-11 20:52:16 +00:00
"execution_count": 6,
2022-03-30 09:48:21 +00:00
"metadata": {},
"outputs": [
{
2022-04-11 20:52:16 +00:00
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAAGoCAYAAADB3ZMFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9Z6xty5bfh/1GVc2wwk4n3vxyJ7KbpCgwSJQZZME0JMrQF5kOEEgDovVBhgNlSRZgW4YtmrZFGTYkgRIBGgYkk4Qhw5BpmWayQUpsNdgkLXbufu/1Szece87ZZ4cVZqiq4Q+j5lprnxv6Pfa99zXZa+Duu89ea+ZZNWqE//gPUVWOcpSjHOUoRznKUY5ylKMc5ZMU9/2+gKMc5ShHOcpRjnKUoxzlKP/gydHROMpRjnKUoxzlKEc5ylGO8onL0dE4ylGOcpSjHOUoRznKUY7yicvR0TjKUY5ylKMc5ShHOcpRjvKJy9HROMpRjnKUoxzlKEc5ylGO8onL0dE4ylGOcpSjHOUoRznKUY7yicvR0fg1JiLyvxKRZyLy3mdwrm+IyH/5Ez7mvy4i/8EnecyjHOUoRznKUY5ylKP8/SdHR+PXkIjIW8AfBX5EVV8RkT8kIv/p9/u6jnKUoxzlw0REPi8iKiLh+30tRznKUf7BlqO++ftTjo7Gry15C3iuqu9/Egc7TsajHOUoRznKUY5ylKN8v+ToaHwfRET+VRH5mojcisjPisg/UyBMfwl4TURWIvLngD8J/M7y91XZtxGRf1NEviUiT0TkT4rIrHz3e0TkOyLyrxTo1f9JRB6IyJ8XkSsRuRSRvy4ih+/9N4vI3xWRaxH5cyLSHlznPy8iXy37/cci8trBd79BRP5S+e6JiPxrH3KflYj8GRH5j0Sk/lQe5lGOcpRPTETkh0Xk/1v0xc+IyD9dPp+JyJ8QkW8WXfGfFr3z18quV0VP/U4R+ZKI/FUReV5goP+hiJwfnOMbIvIvfYze+adE5P9XruFviMiPfaYP4ShHOcpnIkd98+tDjo7G90e+BvxjwBnwvwD+A+BngP8q8I6qLlX1vw78C8CPl7/Py75/HPgB4DcDXwZeB/5nB8d+BbgHfA74IxgU6zvAQ+Ax8K8BerD9Pwv8fuALwI8BfwhARH4f8L8u378KfBP4s+W7E+AvA38BeK1cx185vMGiFP7vQA/8s6o6fM9P6ShHOcpnJiJSAf8P4C8Cj4D/HvAfisgPAv8m8FuBfwTTL/8ykIH/Utn9vOipHwcE0x2vAT8MvAn86y+d7qP0zm8B/jTw3wXuA/8e8B+LSPOJ3/BRjnKU75sc9c2vHzk6Gt8HUdX/q6q+o6pZVf8c8EvAb/uV9hMRwZyH/6GqXqrqLfDHgD94sFkG/ueq2qvqFhgxR+Fzqjqq6l9X1UNH4/9YruUSm/S/uXz+3wL+tKr+bVXtgf8Jll35PPBPAe+p6p9Q1U5Vb1X1Jw6OeYo5IV8D/rCqpu/tCR3lKEf5PsjvAJbAH1fVQVX/KvDnMV3w3wH++6r6tqomVf0bRS98QFT1q6r6l4oOegr8W8Dvfmmzj9I7fwT491T1J8p5/s9YsOJ3fNI3e5SjHOX7Kkd98+tEjo7G90FE5J87SNVdAb8RePBd7PoQmAN/62Dfv1A+n+SpqnYHf//vgK8Cf1FEvi4i/+pLxzxkt9pgEx8sOvDN6QtVXQHPsQzKm5gT8VHyO7CowR9/yak5ylGO8mtXXgO+rar54LNvYvO95ePn/E5E5LGI/FkReVtEbrCM7cv67aP0zueAPzrpt6Lj3izXdpSjHOUfHDnqm18ncnQ0PmMRkc8Bfwr4F4H7BRL101j672V52Uh/BmyB36Cq5+XnTFWXH7VPyTb8UVX9IvBPA/8jEfnHv4tLfQebhNN1L7DU4tvAt4Evfsy+fxFLZf4VEXn8XZzrKEc5yvdf3gHefKmG6y1svnfAlz5knw8LJPyx8vmPquop8N
/mw/Xbh8m3gX/jQL+dq+pcVf/Md30XRznKUf5+kKO++XUiR0fjs5cFNimeAojIH8YyGh8mT4A3pkLq4vn/KeB/LyKPyv6vi8h/5aNOVgqdvlxgV9dAwuBVv5L8GeAPi8hvLnjFPwb8hKp+A0tvvioi/wOx4vQTEfnthzur6v8W+L9gzsZ3k605ylGO8v2Vn8Ciff+yGJHD7wH+ADaP/zTwb4nIayLiSxFmg+mxzN3AwwmwAq5F5HXgf/w9XMOfAv4FEfntYrIQkX+y1IUd5ShH+QdHjvrm14kcHY3PWFT1Z4E/Afw45kj8KPCffcTmfxUrEn9PRJ6Vz/4VDAr1n5c04V8GfvBjTvmVss2qnPPfVdX/z3dxnX8Z+J8C/xHwLhZd+IPlu1vgn8CUwntYjcnv/ZBj/C+xgvC/LCL3fqVzHuUoR/n+SSFs+AMYKcUz4N8F/jlV/XngXwJ+CvibwCXwvwGcqm6AfwP4zwr04HdgBBf/EBbY+H8C/7fv4Rp+EvjngX8beIHpuj/0SdzfUY5ylF87ctQ3v35EjhD6oxzlKEc5ylGOcpSjHOUon7QcMxpHOcpRjnKUoxzlKEc5ylE+cTk6Gkc5ylGOcpSjHOUoRznKUT5x+VU5GiLy+0XkF8S6R79Mm3qUoxzlKJ+YHPXNUY5ylM9CjrrmKEf55OTvuUZDRDzwi1hR8Hewop3/Ril2PspRjnKUT0yO+uYoRznKZyFHXXOUo3yyEn4V+/424Kuq+nUAEfmzwH8N+MjJOGtbXc7mXG97sirOAVqIkRX0ZerjAydIUTRP25btZNpeyj+lbCBUtcc5h3Nq32VHztB3CiLUdcA7qGtISen6SFZFc0ac4J2US1D7yaAqgMM57NpFERS1yyencsnlsnIhkRXvEOeQUKHiyK5GUJxGJEdc3CBk0Fju1W7NieC9I1SeGCNxjKjuH4sc3C+SQZScBQVi9OV67f5FZM8srftrc+5luulyN2Ln0WwXMz1eQe38CpARydSVxwdhHDM5K1mlnHt6jXpw7MPfHyIiiAhGrR1QDaABcYJz5ZULIM5+8FhiLtu9MoAkQvB2z85uJKcRcY727B7iK8bsynPM5ZlKOb0iIjjE7k8jXkcq7YjjSBz6co+K987ejw8478jZnk2KkZQzKSVUFVfGaUqJvh+IMX63HN9H2cv3pG9mbaPzds71ZosqiJuStzb2PvwFTGPg5Y/lzvZSNpA7f8suP7yfalLmnrP57FzRYzb5VA+PYtdm21PmwIF+2135wdyRD/7Ttjy4vpdu5gP3LYef6qGKQA6OevhLiwKY9ICUeRm8fT/Ng5wP9hd2z7Go/Zf0/d0z7j9VpGjZD357uO90hIMb+x5n2e7eP0w9iX7Ilh/3yeGVTf/a34Ps/lU+L+/cFZXlRHDsn+3uJHqwFhY9rkDKSgZyWUenZ/vkvSfPVPWwoetRfmX53m2bptH5rOVm05nOdzbnpzG5n4YvvcydAWTzVl+aaya63+dAF00qzZU57HbzzP4WsbmaUt7ZKHd0Q9lXkGIrOWKKpJw4DEB/9Lw71BZ3J81+/deXNj20mPbXJIfb3r3jlx/GdyEv65XpI93bjh+123cddy/Xf6iO79hX2ezGj971g+d96dXf1UN3tdt3c23T0Wz8TfbTpMSn8ba39wDkjt3Gbiy9fOz9+9m/uTgOpPThts2vxtF4HWt2Msl3gN/+8kYi8kewNu+czOf8k7/n9/FXf/YbDCnSNg5BimHqSFqZAnaTgZ9MATtIUUmjkrIjJQcSQNxuQFdVwDtHHAAcjx6f084q5suEd5C7hu0afvHnRkQCb711n5MTz5tvws1tx8/+0jP6PrLdjMzawNlZgxmviTRm4pCJQ804VLRzaFrF+YhIZEwQE2w2MI7gQkYFuo2Q1dGez/FtQ/vqa+Rmyap5DTTTDM+ou+ecXf5dJK7Q+BRVM9abyrFsA/cenvLKa/d5+vQ57733lM1qYLMZ8FLhxZNHR86CVA
P4xJA8MTkuny/pB493PeKgaWtzKkTJCbqtGTJV5YohZMa6EhFJiB+Jo9D35iQ55/E+410ijfY+gtzipeP1N044O29
"text/plain": [
"<Figure size 864x432 with 9 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
2022-03-30 09:48:21 +00:00
}
],
"source": [
"# Randomisation des indices et affichage de 9 images aléatoires de la base d'apprentissage\n",
"indices = np.arange(x.shape[0])\n",
"np.random.shuffle(indices)\n",
"\n",
"plt.figure(figsize=(12, 6))\n",
"\n",
2022-04-11 20:52:16 +00:00
"for i in range(0, 3 * 3):\n",
" plt.subplot(3, 3, i + 1)\n",
2022-03-30 09:48:21 +00:00
" plt.title(CLASSES[int(y[indices[i]])])\n",
" plt.imshow(x[indices[i]])\n",
"\n",
"plt.tight_layout()\n",
2022-04-11 20:52:16 +00:00
"plt.show()\n"
2022-03-30 09:48:21 +00:00
]
},
{
"cell_type": "code",
2022-04-11 20:52:16 +00:00
"execution_count": 7,
2022-03-30 09:48:21 +00:00
"metadata": {},
2022-04-11 20:52:16 +00:00
"outputs": [],
2022-03-30 09:48:21 +00:00
"source": [
"import tensorflow\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import InputLayer, Dense, Flatten, Conv2D, MaxPooling2D\n",
2022-04-11 20:52:16 +00:00
"from tensorflow.keras import optimizers\n"
2022-03-30 09:48:21 +00:00
]
},
{
"cell_type": "code",
2022-04-11 20:52:16 +00:00
"execution_count": 8,
2022-03-30 09:48:21 +00:00
"metadata": {},
"outputs": [
2022-04-05 20:38:05 +00:00
{
2022-04-11 20:52:16 +00:00
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" conv2d (Conv2D) (None, 48, 98, 32) 896 \n",
" \n",
" max_pooling2d (MaxPooling2D (None, 24, 49, 32) 0 \n",
" ) \n",
" \n",
" conv2d_1 (Conv2D) (None, 22, 47, 64) 18496 \n",
" \n",
" max_pooling2d_1 (MaxPooling (None, 11, 23, 64) 0 \n",
" 2D) \n",
" \n",
" conv2d_2 (Conv2D) (None, 9, 21, 92) 53084 \n",
" \n",
" max_pooling2d_2 (MaxPooling (None, 4, 10, 92) 0 \n",
" 2D) \n",
" \n",
" flatten (Flatten) (None, 3680) 0 \n",
" \n",
" dense (Dense) (None, 250) 920250 \n",
" \n",
" dense_1 (Dense) (None, 4) 1004 \n",
" \n",
"=================================================================\n",
"Total params: 993,730\n",
"Trainable params: 993,730\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-04-06 08:11:35.469372: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2022-04-06 08:11:36.213690: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 1538 MB memory: -> device: 0, name: Quadro K620, pci bus id: 0000:03:00.0, compute capability: 5.0\n",
"2022-04-06 08:11:37.102887: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 204000000 exceeds 10% of free system memory.\n",
"2022-04-06 08:11:37.327551: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 204000000 exceeds 10% of free system memory.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-04-06 08:11:38.606991: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100\n",
"2022-04-06 08:11:39.092316: W tensorflow/stream_executor/gpu/asm_compiler.cc:111] *** WARNING *** You are using ptxas 10.1.243, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.\n",
"\n",
"You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"136/136 [==============================] - 8s 38ms/step - loss: 1.3213 - accuracy: 0.3459 - val_loss: 2.0370 - val_accuracy: 0.0000e+00\n",
"Epoch 2/10\n",
"136/136 [==============================] - 5s 35ms/step - loss: 1.2617 - accuracy: 0.5968 - val_loss: 2.0015 - val_accuracy: 0.0000e+00\n",
"Epoch 3/10\n",
"136/136 [==============================] - 5s 35ms/step - loss: 1.1950 - accuracy: 0.7044 - val_loss: 2.0601 - val_accuracy: 0.0000e+00\n",
"Epoch 4/10\n",
"136/136 [==============================] - 5s 35ms/step - loss: 1.1062 - accuracy: 0.7488 - val_loss: 1.9844 - val_accuracy: 0.0000e+00\n",
"Epoch 5/10\n",
"136/136 [==============================] - 5s 35ms/step - loss: 0.9996 - accuracy: 0.7671 - val_loss: 1.9182 - val_accuracy: 0.0000e+00\n",
"Epoch 6/10\n",
"136/136 [==============================] - 5s 35ms/step - loss: 0.8900 - accuracy: 0.7841 - val_loss: 1.8775 - val_accuracy: 0.0000e+00\n",
"Epoch 7/10\n",
"136/136 [==============================] - 5s 36ms/step - loss: 0.7846 - accuracy: 0.7979 - val_loss: 1.6075 - val_accuracy: 0.0050\n",
"Epoch 8/10\n",
"136/136 [==============================] - 5s 35ms/step - loss: 0.6920 - accuracy: 0.8165 - val_loss: 1.3300 - val_accuracy: 0.2317\n",
"Epoch 9/10\n",
"136/136 [==============================] - 5s 35ms/step - loss: 0.6098 - accuracy: 0.8512 - val_loss: 1.3028 - val_accuracy: 0.2817\n",
"Epoch 10/10\n",
"136/136 [==============================] - 5s 35ms/step - loss: 0.5365 - accuracy: 0.8791 - val_loss: 1.2423 - val_accuracy: 0.3483\n"
2022-03-30 09:48:21 +00:00
]
2022-04-05 20:38:05 +00:00
}
],
"source": [
"model = Sequential()\n",
"\n",
"model.add(InputLayer(input_shape=RESIZED_SIZE_PIL))\n",
"\n",
"model.add(Conv2D(32, 3, activation=\"relu\"))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"\n",
"model.add(Conv2D(64, 3, activation=\"relu\"))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"\n",
"model.add(Conv2D(92, 3, activation=\"relu\"))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"\n",
"model.add(Flatten())\n",
"\n",
"model.add(Dense(250, activation=\"relu\"))\n",
"\n",
"model.add(Dense(4, activation=\"softmax\"))\n",
"\n",
"model.summary()\n",
"\n",
"adam = optimizers.Adam(learning_rate=7e-6)\n",
2022-04-11 20:52:16 +00:00
"model.compile(optimizer=adam, loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"])\n",
"history = model.fit(x, y, validation_split=0.15, epochs=10, batch_size=25)\n"
2022-04-05 20:38:05 +00:00
]
},
{
"cell_type": "code",
2022-04-11 20:52:16 +00:00
"execution_count": 9,
2022-04-05 20:38:05 +00:00
"metadata": {},
"outputs": [
{
2022-04-11 20:52:16 +00:00
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAtbklEQVR4nO3dd3xUVf7/8deHFqpIESkBAQUUkQQIoCCCgCuWlVVgBcuKropYAburyGLFH6usq+hiwbqi6IpIWRQV9YuIIGABUQOGLlJDDRA4vz/OJExiyhAS7mTm/Xw85pG5Ze585iZ8OPM5955jzjlERKT0KxN0ACIiUjyU0EVEYoQSuohIjFBCFxGJEUroIiIxQgldRCRGKKHHMDObbmZXFPe+QTKzNDPrWQLHdWZ2Quj5s2Z2XyT7FuF9LjWzD4oap0hBTNehRxcz2xG2WBnYA+wPLQ9yzr1+5KOKHmaWBlztnJtZzMd1QDPnXGpx7WtmjYFfgPLOucxiCVSkAOWCDkBycs5VzXpeUPIys3JKEhIt9PcYHVRyKSXMrJuZrTazO83sV2C8mdUwsylmtsHMtoSeJ4a9ZpaZXR16PtDM/s/MRof2/cXMzinivk3M7DMz225mM83saTN7LZ+4I4nxATObHTreB2ZWO2z75Wa2wsw2mdnfCjg/Hc3sVzMrG7buQjP7NvS8g5nNMbOtZrbOzJ4yswr5HOslM3swbPn20GvWmtlVufY9z8wWmtk2M1tlZiPCNn8W+rnVzHaY2WlZ5zbs9Z3MbJ6ZpYd+dor03Bziea5pZuNDn2GLmU0K29bbzBaFPsMyM+sVWp+jvGVmI7J+z2bWOFR6+quZrQQ+Dq2fGPo9pIf+Rk4Oe30lM/tH6PeZHvobq2RmU83splyf51szuzCvzyr5U0IvXeoCNYHjgGvxv7/xoeVGwG7gqQJe3xH4EagNPAa8YGZWhH3/A3wF1AJGAJcX8J6RxHgJcCVQB6gA3AZgZi2BZ0LHrx96v0Ty4JybC+wEuuc67n9Cz/cDQ0Of5zSgB3B9AXETiqFXKJ6zgGZA7vr9TuAvwNHAecBgM/tTaNsZoZ9HO+eqOufm5Dp2TWAq8GTosz0OTDWzWrk+w+/OTR4KO8+v4kt4J4eO9UQohg7AK8Dtoc9wBpCWz3vkpStwEnB2aHk6/jzVARYA4SXC0UA7oBP+7/gO4ADwMnBZ1k5mlgQ0wJ8bORTOOT2i9IH/h9Uz9LwbsBeoWMD+ycCWsOVZ+JINwEAgNWxbZcABdQ9lX3yyyAQqh21/DXgtws+UV4z3hi1fD/wv9Hw4MCFsW5XQOeiZz7EfBF4MPa+GT7bH5bPvEODdsGUHnBB6/hLwYOj5i8CjYfs1D983j+OOAZ4IPW8c2rdc2PaBwP+Fnl8OfJXr9XOAgYWdm0M5z0A9fOKskcd+/86Kt6C/v9DyiKzfc9hna1pADEeH9qmO/w9nN5CUx34VgS34fgnwiX9sSfybivWHWuilywbnXEbWgplVNrN/h77CbsN/xT86vOyQy69ZT5xzu0JPqx7ivvWBzWHrAFblF3CEMf4a9nxXWEz1w4/tnNsJbMrvvfCt8YvMLAG4CFjgnFsRiqN5qAzxayiOh/Gt9cLkiAFYkevzdTSzT0KljnTgugiPm3XsFbnWrcC3TrPkd25yKOQ8N8T/zrbk8dKGwLII481L9rkxs7Jm9miobLONgy392qFHxbzeK/Q3/SZwmZmVAQbgv1HIIVJCL11yX5J0K9AC6OicO4qDX/HzK6MUh3VATTOrHLauYQH7H06M68KPHXrPWvnt7Jxbgk+I55Cz3AK+dLMU3wo8CrinKDHgv6GE+w8wGWjonKsOPBt23MIuIVuLL5GEawSsiSCu3Ao6z6vwv7Oj83jdKuD4fI65E//tLEvdPPYJ/4yXAL3xZanq+FZ8VgwbgYwC3utl4FJ8KWyXy1WeksgooZdu1fBfY7eG6rH3l/Qbhlq884ERZl
bBzE4D/lhCMb4NnG9mp4c6MEdS+N/sf4Bb8AltYq44tgE7zOxEYHCEMbwFDDSzlqH/UHLHXw3f+s0I1aMvCdu2AV/qaJrPsacBzc3sEjMrZ2YXAy2BKRHGljuOPM+zc24dvrY9NtR5Wt7MshL+C8CVZtbDzMqYWYPQ+QFYBPQP7Z8C9I0ghj34b1GV8d+CsmI4gC9fPW5m9UOt+dNC36YIJfADwD9Q67zIlNBLtzFAJXzr50vgf0fofS/Fdyxuwtet38T/Q87LGIoYo3NuMXADPkmvw9dZVxfysjfwHXUfO+c2hq2/DZ9stwPPhWKOJIbpoc/wMZAa+hnuemCkmW3H1/zfCnvtLuAhYLb5q2tOzXXsTcD5+Nb1Jnwn4fm54o7UGAo+z5cD+/DfUn7D9yHgnPsK3+n6BJAOfMrBbw334VvUW4C/k/MbT15ewX9DWgMsCcUR7jbgO2AesBkYRc4c9ApwCr5PRopANxbJYTOzN4GlzrkS/4YgscvM/gJc65w7PehYSiu10OWQmVl7Mzs+9BW9F75uOingsKQUC5WzrgfGBR1LaaaELkVRF39J3Q78NdSDnXMLA41ISi0zOxvf37Cewss6UgCVXEREYoRa6CIiMSKwwblq167tGjduHNTbi4iUSl9//fVG59wxeW0LLKE3btyY+fPnB/X2IiKlkpnlvrs4m0ouIiIxQgldRCRGKKGLiMSIqJqxaN++faxevZqMjIzCd5ZAVKxYkcTERMqXLx90KCKSS1Ql9NWrV1OtWjUaN25M/vMuSFCcc2zatInVq1fTpEmToMMRkVyiquSSkZFBrVq1lMyjlJlRq1YtfYMSiVJRldABJfMop9+PSPSKuoQuIhJrnIMVK+C992DnzpJ7HyX0MJs2bSI5OZnk5GTq1q1LgwYNspf37t1b4Gvnz5/PzTffXOh7dOrUqdB9RKT0W7oUhgyBbt2gZk1o3Bj+9CdYWILD2EVVp2jQatWqxaJFiwAYMWIEVatW5bbbDk6ynpmZSblyeZ+ylJQUUlJSCn2PL774olhiFZHgbdgA33zjH4sW+Z8jRsBFF8HmzfDcc9C6NfTvD0lJ/pGcXHLxKKEXYuDAgVSsWJGFCxfSuXNn+vfvzy233EJGRgaVKlVi/PjxtGjRglmzZjF69GimTJnCiBEjWLlyJcuXL2flypUMGTIku/VetWpVduzYwaxZsxgxYgS1a9fm+++/p127drz22muYGdOmTWPYsGFUqVKFzp07s3z5cqZMyTkrWVpaGpdffjk7Q9/fnnrqqezW/6hRo3jttdcoU6YM55xzDo8++iipqalcd911bNiwgbJlyzJx4kSOPz6/6R1FJNyBA5Ca6hN2gwbQqZMvoYQPR9WggU/W1ar55Y4dYds2KJvflO0lIKoTerduv1/35z/D9dfDrl1w7rm/3z5woH9s3Ah9c82AOGtW0eJYvXo1X3zxBWXLlmXbtm18/vnnlCtXjpkzZ3LPPffwzjvv/O41S5cu5ZNPPmH79u20aNGCwYMH/+7a7YULF7J48WLq169P586dmT17NikpKQwaNIjPPvuMJk2aMGDAgDxjqlOnDh9++CEVK1bk559/ZsCAAcyfP5/p06fz3nvvMXfuXCpXrszmzZsBuPTSS7nrrru48MILycjI4MCBA0U7GSIxbv9+n4SdgxtvhAUL4Ntvfc4BuPJKn9AbNYLHH/ct8KQkqF0753GOZCLPEtUJPVr069ePsqHfTnp6OldccQU///wzZsa+ffvyfM15551HQkICCQkJ1KlTh/Xr15OYmJhjnw4dOmSvS05OJi0tjapVq9K0adPs67wHDBjAuHG/n8Rl37593HjjjSxatIiyZcvy008/ATBz5kyuvPJKKlf2k7XXrFmT7du3s2bNGi688ELA3xwkIrB2ra9pZ5VMFi2CE06AadPAzCfzhAS45pqD5ZKWLf1rzWDo0ACDz0NUJ/SCWtSVKxe8vXbtorfIc6tSpUr28/vuu48zzz
yTd999l7S0NLrl9TUCSEhIyH5etmxZMjMzi7RPfp544gmOPfZYvvnmGw4cOKAkLVKAAwd8J+WCBbBuHdx+u19/2WX
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA2+klEQVR4nO3dd3hVVdbH8e9KIQkk1FCTYACpCgQINShVujQFQaSIXcdRcUYdHAXrjMqo46voYAGRpgPIoIAgvSOhiFSlhBB6C0mAkMJ+/9iXECAkIdzkpKzP8+TJveece+7KBX7s7LPP3mKMQSmlVMHn4XQBSiml3EMDXSmlCgkNdKWUKiQ00JVSqpDQQFdKqUJCA10ppQoJDXSVIRGZJyJD3X2sk0QkSkQ65sJ5jYjc6nr8mYi8kp1jc/A+g0RkQU7rzOS8bUUkxt3nVXnPy+kClPuISEK6p8WBC0Cq6/ljxpjJ2T2XMaZrbhxb2BljHnfHeUQkFNgHeBtjUlznngxk+89QFT0a6IWIMcb/0mMRiQIeNsYsvPo4EfG6FBJKqcJDu1yKgEu/UovIiyJyBBgvImVE5EcROS4ip12Pg9O9ZqmIPOx6PExEVorIGNex+0Skaw6PrSYiy0UkXkQWisgnIjLpOnVnp8Y3RGSV63wLRCQw3f7BIrJfRE6KyMuZfD7NReSIiHim29ZHRLa4HjcTkTUiEisih0XkYxEpdp1zTRCRN9M9/6vrNYdEZPhVx3YXkU0iEiciB0RkdLrdy13fY0UkQURaXvps072+lYisF5Ezru+tsvvZZEZE6rpeHysi20SkZ7p93URku+ucB0XkL67tga4/n1gROSUiK0RE8yWP6QdedFQCygK3AI9i/+zHu55XBc4DH2fy+ubALiAQeBf4UkQkB8dOAX4BygGjgcGZvGd2arwfeBCoABQDLgVMPeBT1/mruN4vmAwYY9YBZ4H2V513iutxKvCc6+dpCXQAnsykblw1dHHVcxdQE7i6//4sMAQoDXQHnhCR3q59d7q+lzbG+Btj1lx17rLAHOAj18/2PjBHRMpd9TNc89lkUbM38AOwwPW6p4HJIlLbdciX2O67AOB2YLFr+/NADFAeqAiMBHRekTymgV50XARGGWMuGGPOG2NOGmNmGGPOGWPigbeANpm8fr8x5nNjTCrwNVAZ+w8328eKSFWgKfCqMSbJGLMSmH29N8xmjeONMb8bY84D3wFhru33Aj8aY5YbYy4Ar7g+g+uZCgwEEJEAoJtrG8aYDcaYtcaYFGNMFPCfDOrISH9XfVuNMWex/4Gl//mWGmN+M8ZcNMZscb1fds4L9j+AP4wx37jqmgrsBO5Od8z1PpvMtAD8gX+6/owWAz/i+myAZKCeiJQ0xpw2xmxMt70ycIsxJtkYs8LoRFF5TgO96DhujEm89EREiovIf1xdEnHYX/FLp+92uMqRSw+MMedcD/1v8NgqwKl02wAOXK/gbNZ4JN3jc+lqqpL+3K5APXm998K2xvuKiA/QF9hojNnvqqOWqzvhiKuOt7Gt9axcUQOw/6qfr7mILHF1KZ0BHs/meS+de/9V2/YDQemeX++zybJmY0z6//zSn/ce7H92+0VkmYi0dG1/D9gNLBCRvSLyUvZ+DOVOGuhFx9WtpeeB2kBzY0xJLv+Kf71uFHc4DJQVkeLptoVkcvzN1Hg4/bld71nuegcbY7Zjg6srV3a3gO262QnUdNUxMic1YLuN0puC/Q0lxBhTCvgs3Xmzat0ewnZFpVcVOJiNurI6b8hV/d9p5zXGrDfG9MJ2x8zCtvwxxsQbY543xlQHegIjRKTDTdaibpAGetEVgO2TjnX1x47K7Td0tXgjgdEiUszVurs7k5fcTI3TgR4i0tp1AfN1sv77PgV4Bvsfx3+vqiMOSBCROsAT2azhO2CYiNRz/Ydydf0B2N9YEkWkGfY/kkuOY7uIql/n3HOBWiJyv4h4ich9QD1s98jNWI
dtzb8gIt4i0hb7ZzTN9Wc2SERKGWOSsZ/JRQAR6SEit7qulZzBXnfIrItL5QIN9KLrQ8APOAGsBX7Ko/cdhL2weBJ4E/gWO14+Ix+SwxqNMduAp7AhfRg4jb1ol5lLfdiLjTEn0m3/CzZs44HPXTVnp4Z5rp9hMbY7YvFVhzwJvC4i8cCruFq7rteew14zWOUaOdLiqnOfBHpgf4s5CbwA9Liq7htmjEnCBnhX7Oc+FhhijNnpOmQwEOXqenoc++cJ9qLvQiABWAOMNcYsuZla1I0TvW6hnCQi3wI7jTG5/huCUoWdttBVnhKRpiJSQ0Q8XMP6emH7YpVSN0nvFFV5rRIwE3uBMgZ4whizydmSlCoctMtFKaUKCe1yUUqpQsKxLpfAwEATGhrq1NsrpVSBtGHDhhPGmPIZ7XMs0ENDQ4mMjHTq7ZVSqkASkavvEE6jXS5KKVVIaKArpVQhoYGulFKFhI5DV6oISU5OJiYmhsTExKwPVo7y9fUlODgYb2/vbL9GA12pIiQmJoaAgABCQ0O5/vokymnGGE6ePElMTAzVqlXL9uu0y0WpIiQxMZFy5cppmOdzIkK5cuVu+DcpDXSlihgN84IhJ39O2uVSQKVeTGXb8W2sOWCXmuxyaxduKX31egdKqaJEA72AiE2MZW3MWtYcWMPqmNWsi1lHfFL8FcfUK1+P7jW7061mNyJCIvD2zP7FFKXywsmTJ+nQwS5kdOTIETw9PSlf3t70+Msvv1CsWLHrvjYyMpKJEyfy0UcfZfoerVq1YvXq1Tdd69KlSxkzZgw//niza4bkHQ30fOiiucjvJ39n9YHVaQG+/fh2ADzEg/oV6vNAgwdoGdySViGtSL6YzNw/5jL3j7l8uPZD3lv9HiV9StKpRie63dqNrjW7Usm/ksM/lVJQrlw5Nm/eDMDo0aPx9/fnL3/5S9r+lJQUvLwyjqXw8HDCw8OzfA93hHlBVeACfeeJnUz9bSo1ytagRpka1Chbg4olKhbofsGEpAR+OfhLWnivObCG04mnASjtW5qWwS0ZePtAWga3pFlQMwJ8Aq45R53AOoxoOYK4C3Es2rvIBvzuuUzfPh2AJpWb0K1mN7rX7E54lXA8Pa63FrRSeWvYsGH4+vqyadMmIiIiGDBgAM888wyJiYn4+fkxfvx4ateufUWLefTo0URHR7N3716io6N59tln+fOf/wyAv78/CQkJLF26lNGjRxMYGMjWrVtp0qQJkyZNQkSYO3cuI0aMoESJEkRERLB3795MW+KnTp1i+PDh7N27l+LFizNu3DgaNGjAsmXLeOaZZwDb5718+XISEhK47777iIuLIyUlhU8//ZQ77rgjTz7LAhfoW45u4c0Vb3Ix3aLkxb2LU71MdRvwrpCvUaYG1ctU55bSt1DM8/q/xuU1Ywz7Yvdd0frecnRL2s9TN7Aufev2TWt91w6sjYdk/9p1SZ+S9Knbhz51+2CM4dejvzL3j7nM+WMOb614izeWv0Fg8UC63NqF7jW706lGJ8r6lc2tH1flc23bXrutf3948kk4dw66dbt2/7Bh9uvECbj33iv3LV2aszpiYmJYvXo1np6exMXFsWLFCry8vFi4cCEjR45kxowZ17xm586dLFmyhPj4eGrXrs0TTzxxzZjtTZs2sW3bNqpUqUJERASrVq0iPDycxx57jOXLl1OtWjUGDhyYZX2jRo2iUaNGzJo1i8WLFzNkyBA2b97MmDFj+OSTT4iIiCAhIQFfX1/GjRtH586defnll0lNTeXcuXM5+1ByoMAFev/b+tO7Tm+iYqPYc2oPe07vYc+pPeyN3cvuU7tZsGcB51POpx3vIR5ULVU1Leyrl6l+Reu+pE/JXK33fPJ5NhzeYAM8Zg2rD6zm2NljAPgX86d5UHNevuNlWga3pEVwC8r4lXHbe4sIYZXCCKsUxsg7RnLy3EkW7FnAnD/mMO+PeUzaMgkP8aBlcMu0vvcGFRsU6N92VMHUr18/PD
3tb41nzpxh6NCh/PHHH4gIycnJGb6me/fu+Pj44OPjQ4UKFTh69CjBwcFXHNOsWbO0bWFhYURFReHv70/16tXTxnc
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
2022-04-05 20:38:05 +00:00
}
],
"source": [
"def plot_training_analysis():\n",
2022-04-11 20:52:16 +00:00
" acc = history.history[\"accuracy\"]\n",
" val_acc = history.history[\"val_accuracy\"]\n",
" loss = history.history[\"loss\"]\n",
" val_loss = history.history[\"val_loss\"]\n",
"\n",
" epochs = range(len(acc))\n",
2022-04-05 20:38:05 +00:00
"\n",
2022-04-11 20:52:16 +00:00
" plt.plot(epochs, acc, \"b\", linestyle=\"--\", label=\"Training acc\")\n",
" plt.plot(epochs, val_acc, \"g\", label=\"Validation acc\")\n",
" plt.title(\"Training and validation accuracy\")\n",
" plt.legend()\n",
2022-04-05 20:38:05 +00:00
"\n",
2022-04-11 20:52:16 +00:00
" plt.figure()\n",
2022-04-05 20:38:05 +00:00
"\n",
2022-04-11 20:52:16 +00:00
" plt.plot(epochs, loss, \"b\", linestyle=\"--\", label=\"Training loss\")\n",
" plt.plot(epochs, val_loss, \"g\", label=\"Validation loss\")\n",
" plt.title(\"Training and validation loss\")\n",
" plt.legend()\n",
2022-04-05 20:38:05 +00:00
"\n",
2022-04-11 20:52:16 +00:00
" plt.show()\n",
2022-04-05 20:38:05 +00:00
"\n",
"\n",
2022-04-11 20:52:16 +00:00
"plot_training_analysis()\n"
2022-04-05 20:38:05 +00:00
]
},
{
"cell_type": "code",
2022-04-11 20:52:16 +00:00
"execution_count": 10,
2022-04-05 20:38:05 +00:00
"metadata": {},
"outputs": [
{
"name": "stdout",
2022-03-30 09:48:21 +00:00
"output_type": "stream",
"text": [
2022-04-05 20:38:05 +00:00
"/tmp/deepl/data\n",
"[]\n"
2022-03-30 09:48:21 +00:00
]
2022-04-05 20:38:05 +00:00
}
],
"source": [
"IMAGE_SIZE = (400, 150, 3)\n",
"RESIZED_SIZE = (100, 50, 3)\n",
"RESIZED_SIZE_PIL = (RESIZED_SIZE[1], RESIZED_SIZE[0], RESIZED_SIZE[2])\n",
"DATASET_PATH = \"./data/\"\n",
"DATASET_PATH = os.path.abspath(DATASET_PATH)\n",
"CLASSES = next(os.walk(DATASET_PATH))[1]\n",
"\n",
"print(DATASET_PATH)\n",
2022-04-11 20:52:16 +00:00
"print(CLASSES)\n"
2022-04-05 20:38:05 +00:00
]
},
{
"cell_type": "code",
2022-04-11 20:52:16 +00:00
"execution_count": 11,
2022-04-05 21:31:33 +00:00
"metadata": {},
"outputs": [],
"source": [
"import tensorflow\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import InputLayer, Dense, Flatten, Conv2D, MaxPooling2D\n",
2022-04-11 20:52:16 +00:00
"from tensorflow.keras import optimizers\n"
2022-04-05 21:31:33 +00:00
]
},
{
"cell_type": "code",
2022-04-11 20:52:16 +00:00
"execution_count": 12,
2022-04-05 20:38:05 +00:00
"metadata": {},
"outputs": [
2022-03-30 08:59:31 +00:00
{
2022-03-30 09:48:21 +00:00
"name": "stdout",
"output_type": "stream",
"text": [
2022-04-05 21:31:33 +00:00
"dataset_length = 10000\n",
"batch size = 32\n",
"number of batchs = 312\n",
2022-04-05 20:38:05 +00:00
"\n",
2022-04-05 21:31:33 +00:00
"train_size = 250\n",
"validation_size = 9750\n"
2022-03-30 09:48:21 +00:00
]
2022-03-30 08:59:31 +00:00
}
2022-03-30 09:48:21 +00:00
],
"source": [
2022-04-05 20:38:05 +00:00
"import tensorflow as tf\n",
"import tensorflow_addons as tfa\n",
"import sqlite3\n",
2022-03-30 09:48:21 +00:00
"\n",
2022-04-05 20:38:05 +00:00
"AUTOTUNE = tf.data.experimental.AUTOTUNE\n",
2022-04-05 21:31:33 +00:00
"BATCH_SIZE = 32\n",
"SHUFFLE_SIZE = 32\n",
"LIMIT = 10000\n",
2022-03-30 09:48:21 +00:00
"\n",
2022-04-11 20:52:16 +00:00
"\n",
2022-04-05 20:38:05 +00:00
"def customGenerator():\n",
2022-04-11 20:52:16 +00:00
" data = (\n",
" sqlite3.connect(f\"{DATASET_PATH}/index.db\")\n",
" .execute(f\"SELECT uuid, model from data order by random() LIMIT {LIMIT}\")\n",
" .fetchall()\n",
" )\n",
2022-03-30 09:48:21 +00:00
"\n",
2022-04-05 20:38:05 +00:00
" for uuid, model in data:\n",
" img = tf.io.read_file(f\"{DATASET_PATH}/{uuid}.jpg\")\n",
" img = tf.image.decode_jpeg(img, channels=IMAGE_SIZE[2])\n",
" img = tf.image.convert_image_dtype(img, tf.float32)\n",
" img = tf.image.resize(img, RESIZED_SIZE[:-1])\n",
2022-04-11 20:52:16 +00:00
"\n",
2022-04-05 20:38:05 +00:00
" label = tf.convert_to_tensor(model, dtype=tf.uint8)\n",
2022-04-11 20:52:16 +00:00
"\n",
2022-04-05 20:38:05 +00:00
" yield img, label\n",
2022-03-30 09:48:21 +00:00
"\n",
2022-04-11 20:52:16 +00:00
"\n",
2022-04-05 20:38:05 +00:00
"def cutout(image, label):\n",
" img = tfa.image.random_cutout(image, (6, 6), constant_values=1)\n",
" return (img, label)\n",
2022-03-30 09:48:21 +00:00
"\n",
2022-04-11 20:52:16 +00:00
"\n",
"def rotate(image, label):\n",
" img = tfa.image.rotate(image, tf.constant(np.pi))\n",
2022-04-05 20:38:05 +00:00
" return (img, label)\n",
2022-03-30 09:48:21 +00:00
"\n",
2022-04-11 20:52:16 +00:00
"\n",
2022-04-05 21:31:33 +00:00
"def set_shapes(image, label):\n",
" image.set_shape(RESIZED_SIZE)\n",
" label.set_shape([])\n",
" return image, label\n",
"\n",
2022-04-11 20:52:16 +00:00
"\n",
"dataset = tf.data.Dataset.from_generator(generator=customGenerator, output_types=(tf.float32, tf.uint8))\n",
2022-04-05 20:38:05 +00:00
"\n",
"(dataset_length,) = sqlite3.connect(f\"{DATASET_PATH}/index.db\").execute(\"SELECT count(uuid) from data\").fetchone()\n",
2022-04-05 21:31:33 +00:00
"dataset_length = min(dataset_length, LIMIT)\n",
"\n",
2022-04-05 20:38:05 +00:00
"print(f\"dataset_length = {dataset_length}\")\n",
"print(f\"batch size = {BATCH_SIZE}\")\n",
"print(f\"number of batchs = {dataset_length // BATCH_SIZE}\")\n",
"\n",
"print()\n",
"\n",
2022-04-05 21:31:33 +00:00
"train_size = int(0.8 * dataset_length / BATCH_SIZE)\n",
2022-04-05 20:38:05 +00:00
"print(f\"train_size = {train_size}\")\n",
"print(f\"validation_size = {dataset_length - train_size}\")\n",
"\n",
"dataset = (\n",
2022-04-05 21:31:33 +00:00
" dataset.shuffle(SHUFFLE_SIZE)\n",
" .map(set_shapes)\n",
" .batch(BATCH_SIZE)\n",
2022-04-05 20:38:05 +00:00
" # .map(cutout)\n",
2022-04-05 21:31:33 +00:00
" .prefetch(AUTOTUNE)\n",
2022-04-05 20:38:05 +00:00
")\n",
"\n",
"dataset_train = dataset.take(train_size)\n",
"dataset_validate = dataset.skip(train_size)\n",
"\n",
2022-04-05 21:31:33 +00:00
"# print()\n",
2022-04-05 20:38:05 +00:00
"# print(RESIZED_SIZE)\n",
"# for boop in dataset_train.take(2):\n",
2022-04-05 21:31:33 +00:00
"# print(boop)\n",
"\n",
"# for image_batch, label_batch in dataset.take(1):\n",
"# print(label_batch.shape, image_batch.shape)\n",
"# pass\n",
"# for image_batch, label_batch in dataset_train.take(1):\n",
"# print(label_batch.shape, image_batch.shape)\n",
"# pass\n",
"# for image_batch, label_batch in dataset_validate.take(1):\n",
"# print(label_batch.shape, image_batch.shape)\n",
2022-04-11 20:52:16 +00:00
"# pass\n"
2022-04-05 20:38:05 +00:00
]
},
{
"cell_type": "code",
2022-04-11 20:52:16 +00:00
"execution_count": null,
2022-04-05 20:38:05 +00:00
"metadata": {},
2022-04-11 20:52:16 +00:00
"outputs": [],
2022-04-05 20:38:05 +00:00
"source": [
2022-04-11 20:52:16 +00:00
"model = Sequential(\n",
" [\n",
" InputLayer(input_shape=RESIZED_SIZE),\n",
" Conv2D(32, 3, activation=\"relu\"),\n",
" MaxPooling2D(pool_size=(2, 2)),\n",
" Conv2D(64, 3, activation=\"relu\"),\n",
" MaxPooling2D(pool_size=(2, 2)),\n",
" Conv2D(92, 3, activation=\"relu\"),\n",
" MaxPooling2D(pool_size=(2, 2)),\n",
" Flatten(),\n",
" Dense(250, activation=\"relu\"),\n",
" Dense(4, activation=\"softmax\"),\n",
" ]\n",
")\n",
2022-03-30 09:48:21 +00:00
"\n",
"model.summary()\n",
"\n",
2022-04-05 20:38:05 +00:00
"adam = optimizers.Adam(learning_rate=7e-6)\n",
2022-04-11 20:52:16 +00:00
"model.compile(optimizer=adam, loss=\"sparse_categorical_crossentropy\", metrics=[\"accuracy\"])\n",
"history = model.fit(dataset_train, validation_data=dataset_validate, epochs=25, batch_size=BATCH_SIZE)\n"
2022-04-05 20:38:05 +00:00
]
},
{
"cell_type": "code",
2022-04-11 20:52:16 +00:00
"execution_count": null,
2022-04-05 20:38:05 +00:00
"metadata": {},
"outputs": [
{
"data": {
2022-04-05 21:31:33 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAEICAYAAACktLTqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAzGklEQVR4nO3dd3hU1db48e9KIIQSem9SBAQMISGAl6JwUSkiiCAQEAiIoNjbffV6VS6Wnyh6fXmxXHqXIkVEEEVBsdNCB0EIEjqhE0La/v2xJxBCKkxyMjPr8zzz5Mw5Z86sMydZ2bP3PnuLMQallFLexc/pAJRSSrmfJnellPJCmtyVUsoLaXJXSikvpMldKaW8kCZ3pZTyQprcvZiILBeRQe7e10kiEi0id+bBcY2I3Oxa/kREXsnJvtfxPv1F5OvrjVOpnBLt516wiMj5NE+LAZeAZNfz4caYWfkfVcEhItHAUGPMSjcf1wD1jDF73LWviNQC9gGFjTFJbglUqRwq5HQA6mrGmBKpy1klMhEppAlDFRT6+1jwaLWMhxCRdiISIyL/IyJHgCkiUkZElorIcRE55VqunuY1q0VkqGs5UkR+FJExrn33iUjn69y3toj8ICLnRGSliHwoIjMziTsnMb4uIj+5jve1iJRPs32AiOwXkVgReTmLz6eliBwREf8063qIyGbXcgsR+UVETovIYREZJyIBmRxrqoi8keb5C67XHBKRIen2vUdENorIWRE5ICIj02z+wfXztIicF5G/pX62aV7fSkTWisgZ189WOf1scvk5lxWRKa5zOCUii9Ns6y4iUa5z+FNEOrnWX1UFJiIjU6+ziNRyVU89JCJ/Ad+51s93XYczrt+RxmleX1RE3nNdzzOu37GiIvKliDyR7nw2i0iPjM5V5Ywmd89SGSgL3AQMw16/Ka7nNYGLwLgsXt8S2AWUB94BJomIXMe+s4HfgXLASGBAFu+Zkxj7AYOBikAA8DyAiDQCPnYdv6rr/aqTAWPMb8AF4O/pjjvbtZwMPOM6n78BHYARWcSNK4ZOrnjuAuoB6ev7LwADgdLAPcCjInKfa9vtrp+ljTEljDG/pDt2WeBLYKzr3N4HvhSRcunO4ZrPJgPZfc4zsNV8jV3H+o8rhhbAdOAF1zncDkRn8h4ZuQNoCHR0PV+O/ZwqAhuAtNWIY4BmQCvs7/E/gBRgGvBg6k4iEgJUw3426noZY/RRQB/YP7I7XcvtgAQgMIv9mwKn0jxfja3WAYgE9qTZVgwwQOXc7ItNHElAsTTbZwIzc3hOGcX4rzTPRwBfuZZfBeak2Vbc9Rncmcmx3wAmu5aDsIn3pkz2fRpYlOa5AW52LU8F3nAtTwbeTrNf/bT7ZnDcD4D/uJZrufYtlGZ7JPCja3kA8Hu61/8CRGb32eTmcwaqYJNomQz2+29qvFn9/rmej0y9zmnOrU4WMZR27VMK+8/nIhCSwX6BwClsOwbYfwIf5cXflC89tOTuWY4bY+JTn4hIMRH5r+tr7llsNUDptFUT6RxJXTDGxLkWS+Ry36rAyTTrAA5kFnAOYzySZjkuTUxV0x7bGHMBiM3svbCl9PtFpAhwP7DBGLPfFUd9V1XFEVccb2FL8dm5KgZgf7rzaykiq1zVIWeAR3J43NRj70+3bj+21Joqs8/mKtl8zjWw1+xUBi+tAfyZw3gzcvmzERF/EXnbVbVzlivfAMq7HoEZvZfrd3ou8KCI+AER2G8a6gZocvcs6bs2PQc0AFoaY0pypRogs6oWdzgMlBWRYmnW1chi/xuJ8XDaY7ves1xmOxtjtmOTY2eurpIBW72zE1s6LAn883piwH5zSWs2sASoYYwpBXyS5rjZdUU7hK1GSasmcDAHcaWX1ed8AHvNSmfwugNA3UyOeQH7rS1V5Qz2SXuO/YDu2KqrUtjSfWoMJ4D4LN5rGtAfW10WZ9JVYanc0+Tu2YKwX3VPu+pvX8vrN3SVhNcBI0
UkQET+BtybRzF+BnQVkTauxs9RZP87Oxt4Cpvc5qeL4yxwXkRuAR7NYQzzgEgRaeT655I+/iBsqTjeVX/dL82249jqkDqZHHsZUF9E+olIIRHpAzQCluYwtvRxZPg5G2MOY+vCP3I1vBYWkdTkPwkYLCIdRMRPRKq5Ph+AKKCva/9woFcOYriE/XZVDPvtKDWGFGwV1/siUtVVyv+b61sWrmSeAryHltrdQpO7Z/sAKIotFf0KfJVP79sf2ygZi63nnov9o87IB1xnjMaYbcBj2IR9GFsvG5PNyz7FNvJ9Z4w5kWb989jEew6Y4Io5JzEsd53Dd8Ae18+0RgCjROQcto1gXprXxgFvAj+J7aVzW7pjxwJdsaXuWGwDY9d0cefUB2T9OQ8AErHfXo5h2xwwxvyObbD9D3AG+J4r3yZewZa0TwH/5upvQhmZjv3mdBDY7oojreeBLcBa4CQwmqtz0HQgGNuGo26Q3sSkbpiIzAV2GmPy/JuD8l4iMhAYZoxp43Qs3kBL7irXRKS5iNR1fY3vhK1nXexwWMqDuaq8RgDjnY7FW2hyV9ejMrab3nlsH+1HjTEbHY1IeSwR6YhtnzhK9lU/Koe0WkYppbyQltyVUsoLOTZwWPny5U2tWrWcenullPJI69evP2GMqZDdfo4l91q1arFu3Tqn3l4ppTySiKS/qzlDWi2jlFJeSJO7Ukp5IU3uSinlhQrUTEyJiYnExMQQHx+f/c7KEYGBgVSvXp3ChQs7HYpSKgsFKrnHxMQQFBRErVq1yHwOCeUUYwyxsbHExMRQu3Ztp8NRSmUh22oZEZksIsdEZGsm20VExorIHtfUWGHXG0x8fDzlypXTxF5AiQjlypXTb1ZKeYCc1LlPBTplsb0zdlqtetip3z6+kYA0sRdsen2U8gzZVssYY34QkVpZ7NIdmG7sOAa/ikhpEaniGkNaKaU8RlISxMXBpUtXP+rVg8KFYd8+2L376m1JSRARAQEB8MsvsGEDJCdfeaSkwAsvgAgsXQrx8dAru5Hx3cAdde7VuHoashjXumuSu4gMw5buqVkz/YQ2zouNjaVDhw4AHDlyBH9/fypUsDeC/f777wQEBGT62nXr1jF9+nTGjh2b5Xu0atWKn3/+2X1BK6UuS0iAI0fgxAk4eRJOn4YzZ+Dee6FiRfjxR5g40a5L+1ixAurWhbFj4bnnrj3ugQNQvTrMmAGvZTCwdffuNrkvWgTvvnvt9ueeA39/+PJLOHXKc5J7jhljxuMa0jM8PLzAjVhWrlw5oqKiABg5ciQlSpTg+eevTDaflJREoUIZf2Th4eGEh4dn+x6a2JXKOWMgJgZKloRSpWzJeeJEiI298jhxAsaNg9tvhy++yDhxfv+9Te6HD8N339ljlSoFlStDgwaQ+md9xx02ORcpAoGB9meRIlCmjN0+aBB06HBlfZEitkQfFGS3v/IKPP+8TeRpH36uCvCPPrIl+PzgjuR+kKvnmKzO9c0BWSBFRkYSGBjIxo0bad26NX379uWpp54iPj6eokWLMmXKFBo0aMDq1asZM2YMS5cuZeTIkfz111/s3buXv/76i6effponn3wSgBIlSnD+/HlWr17NyJEjKV++PFu3bqVZs2bMnDkTEWHZsmU8++yzFC9enNatW7N3716WLr165rXo6GgGDBjAhQsXABg3bhytWrUCYPTo0cycORM/Pz86d+7M22+/zZ49e3jkkUc4fvw4/v7+zJ8/n7p1M5vOUqn8ZYxNeidOwIcfwq5dsHMn/PEHXLgAU6ZAZCQcPQqjR0PZslCunH3UqWOTLEDz5jb5p24rXdr+Y6hSxW5/4AH7yEyzZvaRmZtuso/MBAVdSfQZyc8mK3ck9yXA4yIyB2gJnHFXfXu7dteu690bRoyw9WJduly7PTLSPk6cuPY/+OrV1xdHTEwMP//8M/7+/pw9e5Y1a9ZQqFAhVq5cyT//+U8WLFhwzWt27tzJqlWrOHfuHA0aNODRRx
+9pm/4xo0b2bZtG1WrVqV169b89NNPhIeHM3z4cH744Qdq165NREREhjFVrFiRb775hsDAQHbv3k1ERATr1q1j+fL
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAEICAYAAABVv+9nAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA6UklEQVR4nO3dd3gUVffA8e9JILRQpIhAQDpICS0B6UU6SBWlCNIFRV6wIK8oIK+8NkR+iMCrFGkKqIBUiXQQkCaEYpBikC5EKRECJLm/P+4GIpAQQpLJbs7nefbJ7szszJksnL25c+dcMcaglFLK83g5HYBSSqnkoQleKaU8lCZ4pZTyUJrglVLKQ2mCV0opD6UJXimlPJQmeHVXIrJCRJ5L6m2dJCKhItIwGfZrRKS46/lkEXkrIdsm4jhdRCQosXHGs996InIiqfernJfO6QBU0hGR8FgvMwPXgCjX6+eNMXMSui9jTLPk2NbTGWP6JcV+RKQw8BuQ3hgT6dr3HCDBn6FSmuA9iDHGN+a5iIQCvY0xq27fTkTSxSQNpZTn0i6aNCDmT3AReV1EzgDTReQhEVkqIudE5C/Xc79Y71knIr1dz7uLyCYRGePa9jcRaZbIbYuIyAYRuSwiq0TkUxGZHUfcCYnxPyLyo2t/QSKSO9b6riJyTETCRGRYPL+faiJyRkS8Yy1rKyLBrudVRWSLiFwQkdMiMkFEfOLY1xci8k6s16+53nNKRHretm0LEflZRC6JyHERGRlr9QbXzwsiEi4i1WN+t7HeX0NEtovIRdfPGgn93cRHRB5zvf+CiOwXkVax1jUXkQOufZ4UkVddy3O7Pp8LIvKniGwUEc0vDtMPIO14BMgJPAr0xX72012vCwFXgQnxvL8acBDIDXwATBURScS2XwLbgFzASKBrPMdMSIydgR7Aw4APEJNwygCTXPvP7zqeH3dhjPkJ+BtocNt+v3Q9jwIGu86nOvAE8EI8ceOKoakrnkZACeD2/v+/gW5ADqAF0F9E2rjW1XH9zGGM8TXGbLlt3zmBZcB417mNBZaJSK7bzuGO3809Yk4PLAGCXO97CZgjIqVcm0zFdvdlBcoBa1zLXwFOAHmAvMAbgNZBcZgm+LQjGhhhjLlmjLlqjAkzxnxrjLlijLkMjAbqxvP+Y8aYz40xUcAMIB/2P3KCtxWRQkAgMNwYc90YswlYHNcBExjjdGPMr8aYq8B8oKJr+VPAUmPMBmPMNeAt1+8gLl8BnQBEJCvQ3LUMY8xOY8xWY0ykMSYU+N9d4ribp13x7TPG/I39Qot9fuuMMXuNMdHGmGDX8RKyX7BfCIeMMbNccX0FhABPxtomrt9NfB4HfIH3XJ/RGmAprt8NcAMoIyLZjDF/GWN2xVqeD3jUGHPDGLPRaKErx2mCTzvOGWMiYl6ISGYR+Z+rC+MStksgR+xuituciXlijLnieup7n9vmB/6MtQzgeFwBJzDGM7GeX4kVU/7Y+3Yl2LC4joVtrbcTkQxAO2CXMeaYK46Sru6HM644/ottzd/LP2IAjt12ftVEZK2rC+oi0C+B+43Z97Hblh0DCsR6Hdfv5p4xG2NifxnG3m977JffMRFZLyLVXcs/BA4DQSJyVESGJuw0VHLSBJ923N6aegUoBVQzxmTjVpdAXN0uSeE0kFNEMsdaVjCe7R8kxtOx9+06Zq64NjbGHMAmsmb8s3sGbFdPCFDCFccbiYkB280U25fYv2AKGmOyA5Nj7fderd9T2K6r2AoBJxMQ1732W/C2/vOb+zXGbDfGtMZ23yzC/mWAMeayMeYVY0xRoBXwsog88YCxqAekCT7tyort077g6s8dkdwHdLWIdwAjRcTH1fp7Mp63PEiM3wAtRaSW64LoKO797/1L4F/YL5Kvb4vjEhAuIqWB/gmMYT7QXUTKuL5gbo8/K/YvmggRqYr9YolxDtulVDSOfS8HSopIZxFJJyLPAGWw3SkP4idsa3+IiK
QXkXrYz2iu6zPrIiLZjTE3sL+TaAARaSkixV3XWi5ir1vE1yWmUoAm+LRrHJAJOA9sBb5PoeN2wV6oDAPeAeZhx+vfzTgSGaMxZj/wIjZpnwb+wl4EjE9MH/gaY8z5WMtfxSbfy8DnrpgTEsMK1zmswXZfrLltkxeAUSJyGRiOqzXseu8V7DWHH10jUx6/bd9hQEvsXzlhwBCg5W1x3zdjzHVsQm+G/b1PBLoZY0Jcm3QFQl1dVf2wnyfYi8irgHBgCzDRGLP2QWJRD070OohykojMA0KMMcn+F4RSaY224FWKEpFAESkmIl6uYYStsX25SqkkpneyqpT2CLAAe8HzBNDfGPOzsyEp5Zm0i0YppTyUdtEopZSHcqyLJnfu3KZw4cJOHV4ppdzSzp07zxtj8iRkW8cSfOHChdmxY4dTh1dKKbckIrffwRyne3bRiEhB1+3UB1yV5f51l23quSra7XY9ht9v0EoppZJWQlrwkcArxphdriJMO0XkB9et3bFtNMa0TPoQlVJKJcY9W/DGmNMxFeNcFf1+4Z8FjZRSSqVC99UHL3YasUrYehW3qy4ie7DFil513Sp++/v7YmuRU6jQ7XWXlFLJ7caNG5w4cYKIiIh7b6wclTFjRvz8/EifPn2i95HgBC8ivsC3wCBjzKXbVu/C1oEOF5Hm2DsTS9y+D2PMZ8BnAAEBAToAX6kUduLECbJmzUrhwoWJe74W5TRjDGFhYZw4cYIiRYokej8JGgfvmuXlW2COMWbBXYK5ZIwJdz1fDqRP6PRgSqmUExERQa5cuTS5p3IiQq5cuR74L62EjKIR7DRdvxhjxsaxzSMxU7K5yp56Ef/kCkoph2hydw9J8TklpIumJrZE6F4R2e1a9gauyQuMMZOx06P1F5FIbP3uju48XZcxhhl7ZpA3S14aF2uMt1dckxwppVTqdc8E75o3M96vEmPMBOKfsNltXI+6Tu/FvZkVPAuAAlkL0L1id3pU7EGxnMUcjk4p9xYWFsYTT9iJns6cOYO3tzd58tibMrdt24aPj0+c792xYwczZ85k/Pjx8R6jRo0abN68+YFjXbduHWPGjGHp0gedQ8U5Wk0ylgsRF2g3rx1rQ9cyqt4oyuQpw7Td03h307uM3jiaeoXr0atSL9o91o7M6TPfe4dKqX/IlSsXu3fvBmDkyJH4+vry6quv3lwfGRlJunR3T0sBAQEEBATc8xhJkdw9hRYbczl24Rg1p9Vk0++bmNV2Fm/VfYv2ZdqzrPMyjg06xjv13+H3i7/TdWFX8n+Un/5L+7Pj1A7cuCdKqVShe/fu9OvXj2rVqjFkyBC2bdtG9erVqVSpEjVq1ODgwYOAbVG3bGnvpRw5ciQ9e/akXr16FC1a9B+tel9f35vb16tXj6eeeorSpUvTpUuXm/9fly9fTunSpalSpQoDBw68ud+4/Pnnn7Rp0wZ/f38ef/xxgoODAVi/fj0VK1akYsWKVKpUicuXL3P69Gnq1KlDxYoVKVeuHBs3bkzy31lCaQse2HV6Fy2+bMHVG1dZ+exK6hep/4/1ftn8GFZnGP+u/W82HNvA1J+n8sWeL5i8czLlHy5Pr0q9eNb/WXJljnNOZ6VSpXr17lz29NPwwgtw5Qo0b37n+u7d7eP8eXjqqX+uW7cucXGcOHGCzZs34+3tzaVLl9i4cSPp0qVj1apVvPHGG3z77bd3vCckJIS1a9dy+fJlSpUqRf/+/e8YM/7zzz+zf/9+8ufPT82aNfnxxx8JCAjg+eefZ8OGDRQpUoROnTrdM74RI0ZQqVIlFi1axJo1a+jWrRu7d+9mzJgxfPrpp9SsWZPw8HAyZszIZ599RpMmTRg2bBhRUVFcuXIlcb+UJJDmW/DLDy2nzvQ6+Hj78GPPH+9I7rF5iRf1CtdjVttZnH7lNJNaTCJDugwMWjmI/GPz0/nbzly6dvstAkqpe+nQoQPe3nYww8WLF+nQoQPlypVj8ODB7N9/xz2TALRo0Y
IMGTKQO3duHn74Yc6ePXvHNlWrVsXPzw8vLy8qVqxIaGgoISEhFC1a9Ob48oQk+E2bNtG1a1cAGjRoQFhYGJcuXaJ
2022-04-05 20:38:05 +00:00
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"def plot_training_analysis():\n",
2022-04-05 21:31:33 +00:00
" acc = history.history[\"accuracy\"]\n",
" val_acc = history.history[\"val_accuracy\"]\n",
" loss = history.history[\"loss\"]\n",
" val_loss = history.history[\"val_loss\"]\n",
2022-04-05 20:38:05 +00:00
"\n",
2022-04-05 21:31:33 +00:00
" epochs = range(len(loss))\n",
2022-04-05 20:38:05 +00:00
"\n",
2022-04-05 21:31:33 +00:00
" plt.plot(epochs, acc, \"b\", linestyle=\"--\", label=\"Training acc\")\n",
" plt.plot(epochs, val_acc, \"g\", label=\"Validation acc\")\n",
" plt.title(\"Training and validation accuracy\")\n",
" plt.legend()\n",
2022-04-05 20:38:05 +00:00
"\n",
2022-04-05 21:31:33 +00:00
" plt.figure()\n",
2022-04-05 20:38:05 +00:00
"\n",
2022-04-05 21:31:33 +00:00
" plt.plot(epochs, loss, \"b\", linestyle=\"--\", label=\"Training loss\")\n",
" plt.plot(epochs, val_loss, \"g\", label=\"Validation loss\")\n",
" plt.title(\"Training and validation loss\")\n",
" plt.legend()\n",
2022-04-05 20:38:05 +00:00
"\n",
2022-04-05 21:31:33 +00:00
" plt.show()\n",
2022-04-05 20:38:05 +00:00
"\n",
2022-04-05 21:31:33 +00:00
"\n",
"plot_training_analysis()\n"
2022-03-30 09:48:21 +00:00
]
}
],
"metadata": {
2022-04-05 20:38:05 +00:00
"interpreter": {
"hash": "e55666fbbf217aa3df372b978577f47b6009e2f78e2ec76a584f49cd54a1e62c"
},
2022-03-30 09:48:21 +00:00
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": ".env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
2022-03-30 08:59:31 +00:00
},
2022-03-30 09:48:21 +00:00
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
2022-03-30 08:59:31 +00:00
}