TP-reseaux-profond/TP3.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "1qjWcjWFVcs7"
},
"source": [
"# Introduction à la librairie Keras"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "77ojmk9zVgUt"
},
"source": [
"Dans le TP précédent, vous avez implémenté l'apprentissage et l'inférence d'un réseau de neurones. En pratique, il est plus courant de faire appel à des librairies qui masquent la complexité de ces algorithmes (notamment le calcul des gradients, réalisé par différentiation automatique). Dans la suite, nous utiliserons pour les TPs la librairie ***Keras***. Dans un premier temps, pour ce TP nous allons détailler sur un exemple simple (le même que pour le TP précédent) les différentes étapes à mettre en place pour entraîner un réseau à l'aide de cette librairie."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b2Sq7AygNuNL"
},
"source": [
"## Exemple de classification simple"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "luY3XIU7WfWQ",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAj7ElEQVR4nO2df4wlV3Xnv+d1zzRWshKb9mSdYPdOVpBIrE3spYPSsIS2PfEYgoKIFynsjyZMcGOEd3FEhDJjTfqNJsxk8W4yq4VsehxPQico2ZWGrCPkyMaDe+1VPwg9/LANBjRZEWOLhElHCfnlHs/02T9ul191dVW9+nGr7r31vh/p6fV7Xa/qVL163zr31DnniqqCEEJIuPRcG0AIIaQeFHJCCAkcCjkhhAQOhZwQQgKHQk4IIYEz6WKjV199te7fv9/FpgkhJFjOnz//l6q6L/m+NSEXkQkA6wCeV9W35i27f/9+rK+v29o0IYSMBSLyZ2nv2wytfADAMxbXRwghpABWhFxErgXwUwB+y8b6CCGEFMeWR34KwIcAbGUtICKLIrIuIusXL160tFlCCCG1hVxE3grgO6p6Pm85VT2tqrOqOrtv365YPSGEkIrY8MjfAOCnReSbAP4AwC0i8nsW1ksIIaQAtYVcVQ+r6rWquh/AzwL4jKr++9qWEUIIKQQLgggpwGAAnDxpngnxDasFQaq6CmDV5joJcc1gANx6K3DpErB3L3DuHDA359oqQobQIydkBKurRsSvXDHPq6uuLSJkJxRyQkYwP2888YkJ8zw/79oiQnbipNcKISExN2fCKaurRsQZViG+QSEnpABzcxRw4i8MrRBCSOBQyAkhzmF6Zz0YWiGEOIXpnfWhR07ImOPaG2Z6Z33okRMyxvjgDUfpnZENTO8sD4WckDEmzRtuW8iZ3lkfCjkhY4wv3jDTO+tBISekAoOBfQ+yiXWOIiRv2MXxCQUKOSElaSKu7DJWHYI37EMs32eYtUJISZrIsmDmRj48PvlQyAkpSRNNtNiYKx8en3wYWiGkJE3ElUOKVWfRZAy7C8enSURVW9/o7Oysrq+vt75dQkgzMIbdDiJyXlVnk+/XDq2IyMtE5E9E5Msi8hUROVZ3nYSQsGAM2y02YuSbAG5R1R8FcCOA20Xkxy2sl5DO4bocvikYw3ZL7Ri5mtjM322/3LP9aD9eQ4jndDn8wBi2W6zc7BSRCQDnAbwSwMdU9XMpyywCWASAmZkZG5slJCh8KIdvkhDy0buKlfRDVb2iqjcCuBbA60Tk+pRlTqvqrKrO7tu3z8ZmCQkKhh9IU1hNP1TVvxaRxwDcDuBpm+smJHQYfigGS/HLU1vIRWQfgBe3RfwqAD8J4D/XtoyQDsLwQz5dvo/QJDZCKz8A4DEReRLA5wF8WlU/ZWG9hJAxg2mM1bCRtfIkgJss2EIIGXN8aasbGizRJ4R4A+8jVINCTgjxCt5HKA+7HxJCStHV6tSQoZCT4AlJWKrY2u83Zk5poqySo0fNcwjHfBxgaIUETUjpalVtPXbMHzHvenVqqNAjJ0ETUrpaSLZmwepUP6GQk6BxJSxVQiRlbO33ARHzAIZ/2/LMq4ajoqyS48f9Hv1UIaQQXRJOLEGCp+2S7jrhnCq2igA2f6YhhaPaIpRjkjWxBGPkJHjaTlerEyf2IbWu7Th3CL1TQo/9U8gJKUnb1YdLS3bX16b9oXi6oVeUUsgJKUnb1Ye2M1batD8UTzf0ilIKOSEV8CFEUoe27I883c1NE+ufnm5+m1UJ+Ttl1gohI2gym8FlpkTatm3bMzcHnDplMnW2toB77gkzK8R36JETkkM8xjs5Cbz73cDCgh3PzWX8OG3bQDP2bGwYEd/a8ju8EjL0yEmnqethxmO8m5vA8rK90vQyBULx/bDhNadtu6mCJRYRNQ89ctJZbHi8kQi98ILJ5Va151UWzZSI78fEhIk1X75cz2tObnt6GvjiF82oA7AruKHfSAwBG1O9XQdgBcA/A6AATqvqf6u7XkLqYiNjIhKhlRXgzBmzLlsiV1Tg4vuxtWXeq3tBiW97etrErqMLxZ132gsfxbcHDL38tsU8hFz2OtjwyC8D+KCqfkFE/gmA8yLyaVX9qoV1E1IZW7nBUTbDwoJ9MSgicPH9SHrkdS4o0X6dPGnCRltb5gIxM2Nf7Hy7H9A1Mbcx1du3AXx7+++/FZFnALwCAIWcOMX2kL5KetooT7CIyCT3A7B7QZmeHnr6W1vNpAi6zCcPJZe9DlZj5CKyH2b+zs+l/G8RwCIAzMzM2NwsIZm4zA0uItJFRSa5H1X3KbqwTE+bbJL5efPc6xkR7/XMa9u4rJwMvWqzCNaEXES+F8BZAPeo6neT/1fV0wBOA6Zplq3tEuIrRUTahsgUjf+ePg28//3GHlUj2lNTJs97airbBhvxZZc3PMfhZqsVIReRPTAi/glV/aSNdRLiC1WFrIhI1xGZwWD3Tdis+O9gANx9t4mtR0R53Rsb2TbYjC+7HB2FXLVZBBtZKwLgAQDPqOqv1TeJEH+oI2RFRbpq7P3WW4dpkUB+aGZ11Yh9nF5veIHJssF2fNlFy+Eue+IRNjzyNwD4DwCeEpEvbb93RFUfsrBuQpxSV8ia8gQjuyIRF8kPzczPm/DJ5qbJfPmFXwBe/vLRAmczvtx29sg4ZKtE2Mha+b8AxIIthHiHrzfK4nYVaR2QNTqIqkSzBL3IqKKo11v0ohit78IF4IEHstc3inHIVnkJVW398drXvlYJCYW1NdUTJ8yzT+u18fmrrlKdmDDPVdZTZh1Flo0vA9Q75jb2zzcArGuKprJEn5AE/f7OHuAPP2y/J7iNYX/dsI0Nj7XMOop49/H1Ra/r3Fw9dQo4exa4444Oe+Ng0yxCdnHsWP5rGzTVoKoMNppZlV3H3Bxw+HC2qF64sPOm7JEj1SedHgxM64Fz57rfPpceOSEO8CH2biO/2naO9gMPAO95j1nfkSP1Jp0epxi5qM3puQsyOzur6+vrrW+XkCz6/WKe99KSWdZGWlvoqXFNHwORekLexawVETmvqrO73qeQE7KTpIAkX0cCEaXyffSjwOJi+3bapoww2xDJUevo94GDB+tdLEK/WCbJEnKGVggpyerqsFvg1pYpe7/hhrCFoqwwr6wMi5GaulF68KD7G8KhwJudhCRYWsp/PT9vPPGIrS03NyttMRgY73dzs/hsRWfODEcpk5PN3CgtckM4a7Ykl3OhOiEtJ7HpB/PISegsL6tOTqr2em5zlJeW6n0+yrXu9Uzedtb+xHPWT5wY5nmLqN51V73tZ+XCj8oDz/p/F/PHI8A8ckLssbhowiltx1+TMd9jx+rluEdeb9TC9sABs768xlmnTu3MuFlYKG5vkrzQx6iMmKzQzDhlq0RQyAmpSN34a9kbcWlx7Lok0yCTIg7sFsa8bomj7LUZ485K4fQhtbNtKOSEOKCKyMUF9R//EXj96837st3pKEqNLEOa15usbE0TxiIXsaY94yyPfRz6jyehkBPigCoilxTUc+eMm
KsOvfvBoL7XmwzXxEvdb7yx+ATKbXjGWReUcclWiaCQE+KAKiKX18Ew8u4nJoBDh/I7IZYlKnXf3AQeeWQ4s9CoUcQ4esauYPohIQ6YmwPe+U7g+PH0Qpi8z8V7lSwt7fbul5eNsJdJvev3TYgmCtNEf/f7O2+IAuZ5c7NYyuWo3irEDhRyQhxx5ky6yJVp0tXvm4mU4yIcL9Ipsx6TUDhch6p5Pxo99GJqsbVltkv8gEJOSEAkC12isEeUPrhnT3aBTZEimWg0EF82CpEcODC8WPR6JnuFlKOxQqW05PKyDwBnAHwHwNNFlmdBEBlXlpYiX3fn401vSn8/XvCTVugSL86ZmDDFOWkFNkWLZADVQ4fyC216PVMMtbycvY9N0tREH01jo1AJGQVBtjzy3wFwu6V1EdJZskIYq6s7319bA06cMP1GItIyXZJl7gsL6eGaMv3PX/nK9GWj7JWJCTMCyOrx3UT/9ojoxu7Ro+XvA7imyR70VoRcVR8H8Fc21kU6hO1pdTyhjT4ea
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn import datasets\n",
"import matplotlib.pyplot as plt \n",
"\n",
"# Génération des données \n",
"x, y = datasets.make_blobs(n_samples=250, n_features=2, centers=2, center_box=(- 3, 3), random_state=1)\n",
"# Partitionnement des données en apprentissage et test\n",
"x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=1)\n",
"\n",
"# Affichage des données d'apprentissage\n",
"plt.plot(x_train[y_train==0,0], x_train[y_train==0,1], 'b.')\n",
"plt.plot(x_train[y_train==1,0], x_train[y_train==1,1], 'r.')\n",
"\n",
"# Affichage des données de test\n",
"plt.plot(x_test[y_test==0,0], x_test[y_test==0,1], 'b+')\n",
"plt.plot(x_test[y_test==1,0], x_test[y_test==1,1], 'r+')\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "lTGP4a9WXWpU",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" dense (Dense) (None, 1) 3 \n",
" \n",
"=================================================================\n",
"Total params: 3\n",
"Trainable params: 3\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-03-25 18:31:29.078560: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /applications/opam-2.0.4/default/lib/stublibs/::/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/extras/CUPTI/lib64\n",
"2022-03-25 18:31:29.078622: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1850] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n",
"Skipping registering GPU devices...\n",
"2022-03-25 18:31:29.079207: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [
"import tensorflow\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense\n",
"\n",
"# Définition du modèle, auquel on va ensuite ajouter les différentes couches, dans l'ordre\n",
"# NB: c'est exactement ce que nous avons implémenté avec le perceptron multicouche dans le\n",
"# TP précédent ! \n",
"model = Sequential()\n",
"model.add(Dense(1, activation='sigmoid', input_dim=2)) # input_dim indique la dimension de la couche d'entrée, ici 2\n",
"\n",
"model.summary() # affiche un résumé du modèle"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "nZEDm5I-Lu-p",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/15\n",
"9/9 [==============================] - 0s 26ms/step - loss: 0.2186 - accuracy: 0.9444 - val_loss: 0.2752 - val_accuracy: 0.9111\n",
"Epoch 2/15\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.1951 - accuracy: 0.9500 - val_loss: 0.2616 - val_accuracy: 0.9333\n",
"Epoch 3/15\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.1809 - accuracy: 0.9500 - val_loss: 0.2516 - val_accuracy: 0.9333\n",
"Epoch 4/15\n",
"9/9 [==============================] - 0s 4ms/step - loss: 0.1698 - accuracy: 0.9500 - val_loss: 0.2434 - val_accuracy: 0.9333\n",
"Epoch 5/15\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.1611 - accuracy: 0.9500 - val_loss: 0.2363 - val_accuracy: 0.9333\n",
"Epoch 6/15\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.1543 - accuracy: 0.9556 - val_loss: 0.2300 - val_accuracy: 0.9333\n",
"Epoch 7/15\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.1483 - accuracy: 0.9556 - val_loss: 0.2242 - val_accuracy: 0.9333\n",
"Epoch 8/15\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.1428 - accuracy: 0.9611 - val_loss: 0.2190 - val_accuracy: 0.9333\n",
"Epoch 9/15\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.1382 - accuracy: 0.9667 - val_loss: 0.2145 - val_accuracy: 0.9333\n",
"Epoch 10/15\n",
"9/9 [==============================] - 0s 6ms/step - loss: 0.1341 - accuracy: 0.9667 - val_loss: 0.2104 - val_accuracy: 0.9333\n",
"Epoch 11/15\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.1303 - accuracy: 0.9722 - val_loss: 0.2066 - val_accuracy: 0.9556\n",
"Epoch 12/15\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.1270 - accuracy: 0.9722 - val_loss: 0.2030 - val_accuracy: 0.9556\n",
"Epoch 13/15\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.1239 - accuracy: 0.9722 - val_loss: 0.1997 - val_accuracy: 0.9556\n",
"Epoch 14/15\n",
"9/9 [==============================] - 0s 6ms/step - loss: 0.1210 - accuracy: 0.9722 - val_loss: 0.1967 - val_accuracy: 0.9333\n",
"Epoch 15/15\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.1183 - accuracy: 0.9722 - val_loss: 0.1940 - val_accuracy: 0.9333\n"
]
}
],
"source": [
"from tensorflow.keras import optimizers\n",
"\n",
"# Définition de l'optimiseur\n",
"sgd = optimizers.SGD(learning_rate=0.1) # On choisit la descente de gradient stochastique, avec un taux d'apprentissage de 0.1\n",
"\n",
"# On définit ici, pour le modèle introduit plus tôt, l'optimiseur choisi, la fonction de perte (ici\n",
"# l'entropie croisée binaire pour un problème de classification binaire) et les métriques que l'on veut observer pendant\n",
"# l'entraînement. La précision (accuracy) est un indicateur plus simple à interpréter que l'entropie croisée.\n",
"model.compile(optimizer=sgd,\n",
" loss='binary_crossentropy',\n",
" metrics=['accuracy'])\n",
"\n",
"# Entraînement du modèle avec des mini-batchs de taille 20, sur 15 epochs. \n",
"# Le paramètre validation_split signifie qu'on tire aléatoirement une partie des données\n",
"# (ici 20%) pour servir d'ensemble de validation\n",
"history = model.fit(x_train, y_train, validation_split=0.2, epochs=15, batch_size=20)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fD8fXJj0WJID"
},
"source": [
"La cellule suivante introduit un code permettant de visualiser la frontière de décision du modèle appris. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "8iSYRgNaL6F-",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"import numpy as np\n",
"def print_decision_boundaries(model, x, y):\n",
" dx, dy = 0.1, 0.1\n",
" y_grid, x_grid = np.mgrid[slice(np.min(x[:,1]), np.max(x[:,1]) + dy, dy),\n",
" slice(np.min(x[:,0]), np.max(x[:,0]) + dx, dx)]\n",
"\n",
"\n",
" x_gen = np.concatenate((np.expand_dims(np.reshape(x_grid, (-1)),1),np.expand_dims(np.reshape(y_grid, (-1)),1)), axis=1)\n",
" z_gen = model.predict(x_gen).reshape(x_grid.shape)\n",
"\n",
" z_min, z_max = 0, 1\n",
"\n",
" c = plt.pcolor(x_grid, y_grid, z_gen, cmap='RdBu', vmin=z_min, vmax=z_max)\n",
" plt.colorbar(c)\n",
" plt.plot(x[y==0,0], x[y==0,1], 'r.')\n",
" plt.plot(x[y==1,0], x[y==1,1], 'b.')\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "ltsUweGrMPor",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWIAAAD8CAYAAABNR679AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAABl4UlEQVR4nO29eXwjaXnv+32rJFmLd7s3t917T88wC7P27AsDYTsk8yEhCYTkhISbyXK4OSfbvSfhXsiFcz5JblhOOEAOQyAsZw6BEAgDgRByCQwwW8++T09Pr+5ud7fdbVu2JUuqeu8fJaneKqkkWS5Zsv3+Ph9/WlKV3npltx8/9Xt+z+8RUko0NDQ0NNoHo90b0NDQ0Fjv0IFYQ0NDo83QgVhDQ0OjzdCBWENDQ6PN0IFYQ0NDo83QgVhDQ0OjzdCBWENDQ2MJEEJ8VghxVgjxbMBxIYT4mBDikBDiaSHE1fXW1IFYQ0NDY2n4HPDGGsffBOwtft0N/HW9BXUg1tDQ0FgCpJT3A+drnHIX8AXp4CGgXwixpdaakTA32Cj6IhG5OdYFgLTdzj61yU99HUAGPK54n3K09nnB6wW9p+JYjWdBR2r2MTbY5NjKXshW91nqPk6NpWKS3KSUcsNy1jB6RyWFbN3zZGbqOUA98R4p5T1LvNxW4ITyfLz42umgN7QlEG+OdfHXey4FIJ8tlF+3c1b5cX7R8rwnpwTmnC9Iq88tqZ7nva73mPqe4PPUY5YvKqvL+49511v6ebVQa41m1mtm7WbR6J52sshFZDlInCN0tXRPGp2NT3Hs2LIXsRaJXvLWuqflHv+brJTy2mVfb4loSyDW0KiFnSzye5zBRGIh+CibdDDWWDaEYa7UpU4CY8rz0eJrgdAcsUbH4SKymEhMwERyEfVvKTU0akMgDLPuV0i4D/j3RfXEDcCMlDKQloA2ZcTCMOjqdTIcI+Z++ELGpSmMmJeaMJVjsYKXcwiiLUzhpzCU9YSo+h7wUwnVH/ufq3/RfIwIKjOqXrfyNr36fbv/PHUN73nSd17V0wKuXXvtaus3g6A9qfs5SBwLAcWM+CDxZV9XY51DiNACrRDiS8AdwLAQYhx4PxAFkFL+D+DbwJuBQ8AC8Gv11tTUhEbH4QhdfJRNmiPWCA1CCMxoLJS1pJTvqHNcAv9hKWu2JRAbEUFyOAFAbj5ffr0Qd7NetYgHYEaVzNl/TCnyqdlyraKeN3P27s+S1bPlWhlxrcw5OFv2Z5hB2WhjBT5/Nlsrg1U/c6MFtFZmy/6fwXG6OF4KwFpqoRECVpAjXjJ0RqyhobH2ESI10QroQKyhobHmIXBqU52K9lATUYPEcNLZQCJXfj0359IUER/9kIu655kx7zfUUqpwKm2hUhYAZt59rt4KV1IT1Y9V6pLrP/a/T/2b7P9v4aUL1EVqVN1q3rc3RiU0U9Tzw1uEDJdLCGN/ax1ad10POiPW0NBoIbTuugFoakJDY31hpbNTVXdNUXetA7EPQmCEpJpoBdqkmoiQ3DAAQC4+724m7gr3VTWFcyyiHMt5jqn6Y5W2qFRXuH8Ro6ou2Q7WJTdKYaj0Q6Xut/raFdSBcktvKLf6lddtHW3RqBa5FiXQai2y91pB1wn1Mg2jHdmp1l3Xh8MR64xYQ2NdoB3ZqdZdNwBNTWhorB+0Kzs9QpcOwDUhMHQg9sKImCQ29jsbSLn/UXPphfLj6HzG8578vEtbRBLebefnXKpCbQRRKQvwKSoUCsP0nRdVeAaVtqjVINKouqKZ1uoKL84A2qLytr8Z2qK1SotG27NV7JSL7CXLy8Q5IhoLNkE0Uqv5207NTte9qkJoakJDo2nslIv8rsK5fkxuajgYV6y1Qvxtp2WnWlUBAoER0cU6D4xYlNTIRgDyc26xLqpmx7MLnveYcfd5JO7LlgMKeWrLtHPMLQCaitmQ2j4N3qxa1SJHfamuabnPVYMhv9lQUFHPn7UFFfxqt1YrTyqyzWY0xsF7b7QIF6bZ0D6ymNLlXPc2ybmaAi6WXv72YrIcF11rXousVRVojlhDYzl42ce5vrwMzjXMtVYTtKoCdEOHhsYycFTE+ZjctGSOuBqOiK7Q1molwuZzO5W3XlEIEKYOxB6ISBRzw1bncWLa3Uxytvw4mkp43hMNKOoBRJRCnjnrUh35+UXveQmVwqju+gbe9mq14FfhCKe6vimjnWoX9YLd3ILahJttrW7UI9mLxuiNsD2Sa9EDx404x4tZnP9Xaam65COikr/tJC1yq/jcTuOtVxpirWfEQog4cD/QVVzvq1LK9y93XQ2N9QjN57YI64AjXgTulFLOCSGiwI+FEN8pjpHW0GgLmpG8dQI0n9s6mJHOZWKXvbOiG/1c8Wm0+FXzpk5EYkQ2OtSEneotv26nL5QfGwplAV69sZ+2yKXVNmlXopKvpUWOLyqve9upVQoiH3WPqWOdIHi0k1+XHGRW76cwVNoiaKyT/32Nqyu8xxo1qw8e7RSuWX2t9vFaqEaJ7JBZ3hOC5K0drm+az20NhBAIo5auvr0I5U+EEMIEHgP2AJ+QUj4cxroaGs1gr+/2vlnJW7uw3vncVkHUkF+2G6E4JUspLSnllThjo/cLIS7znyOEuFsI8agQ4tFzF6bDuKyGRlWUZGoWrCuZmkZtGIao+9UuhEqaSCmnhRD/BrwReNZ37B7gHoBrrrhUyv7NAJiJnvI5Ip50HyuUBYBIT5cfm6lZz7FIasZ9nAxWV0QVasJDZyS849ojSlNINF5daQHBM/b8DSJBZvX+adQZT4OI+3otN7eg+Xr+9zVKW/j/MgfTFp3n+lbCURHn47gyteMi7lFbtHIadaN71FhhCNY2NSGE2ADki0E4AfwU8BfL3pmGxjJwVMQ5uk4z4XXvK1EFjg3mGg7EwBbg80We2AC+IqX8Vs13mFHsnk0AyNhc+WUj5hbhIr6M2E66mbM9N+05JuIpd2klW471znnOW5xOK+u7v6R5Xzt1PuUW+bxmQ14f5KDRTupYJwj2SLZ8AmEjwCO5sqhH1WOVRT33ca2MuDM8ktsz1sm/fhhopy5Z+0oEQAhMcw3PrJNSPg1cFcJeNDQ0lgmtQw7GWs+INTQ0OgRah1wdQtDWYlw9tCUQ28IgE3HohK6oS0fImFKsi8173mN0ufRDJW2haJEV2sJQdMkAZsJdvzDn0ha5lPc/q+r8Vsv1LWi0kxn13gIFeSRXjHIK8Ej2u74F0Ra1KIyak6U7wiO5fdOom/FIbgbNaqWXAq1DDoboXGZCZ8QaGmsNWodcHZ2sI9aBWEOjjdAKh5WBEAIz0rkpcVsCccGG88Vb75RyG5+IucqIWNTbxixz7nMRTXqOGVFVbaGoKxSlBYClaJENRV0RSU57zoumGjOrj6QURYXq+hb3jXJSW6YDxjqBt9Va1SJbectzXhBtUWsadW3lRfXXIVin3KxZfTBtEa4uedTOsEdmOSTiHBPxUM3qnfcsn1fYg2/6CE4rttYhtwa6WKexpjFiZRizMhwzEpw0E/Xf0GJsl1l+0y4GOCn4p
LGJY6LzilarvRV7VUF45ZedBh2INZaFESvDOxZPlbO6e2MjHDfaG/R2+0Yi7ZHZjgzE63ViSDuwHho6lgzLlpzPOLfbi8p9WKLgPk762oQTXX3lx1E/baE0ghge5UXKc14kwOlNbQgBECml8SPpnlfh+qbSETVbq5UGkYD5euCdo6e2T1eoK5TvTdB8PQBTMav3z9ELmrFXW11R+frOghP0DEAi2SWzTBju96mSwgiiLWrTD2N2ll0ywyHinkBfTRnyivAGuCNGItB037+Gd6/BOwqDtgiaGKLbp1uBdeC+prF+MW4msBBIJDaCEy2gJsbsLO8unMZE8hoEn2Zzzaz7mIjzKXMzu2WWV4oc8Upgh8yylywvycY9kKtNDFktWFWFRq0jrkTetjk95/gB9yhZYJ+SLWYtb4UzoVQ8E5Go51gqMVh+bCqFPOHLnNUWalWXbPh1yUpRz1aMiAwlUwaI9rgFv9y0q0uO+nXJSoasjnIy4z6zoQCPZH9RT82WDeU825cRG2q23CKP5
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"print_decision_boundaries(model, x_train, y_train)"
]
},
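{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cells above only monitor metrics on the training and validation data. As a complement, the held-out test set built earlier (`x_test`, `y_test`) can also be scored. The cell below is a minimal sketch using `model.evaluate`; it assumes the previous cells (data generation, model definition and training) have been run in order."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: score the trained model on the held-out test set.\n",
"# Assumes x_test and y_test come from the train_test_split cell above.\n",
"test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)\n",
"print('Test loss: %.4f - Test accuracy: %.4f' % (test_loss, test_acc))"
]
},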
{
"cell_type": "markdown",
"metadata": {
"id": "mSaSWEnoNxqG"
},
"source": [
"## Exemple de classification plus \"complexe\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W60rDDAzWTpq"
},
"source": [
"Pour manipuler un peu la librairie, voici un second problème légèrement plus complexe. A vous de réutiliser les cellules précédentes pour mettre en place un réseau permettant de résoudre ce problème."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "KvhN3uQaN5ji",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAgmklEQVR4nO2df4wdV3XHv2d3vU4KVJWcqFCSrSsVVbQENWKV6qUVbPOLqKJQCkhFFQ4xtWWJqERqReuk6b7UYNNGQkYkEus0QbEUQakCAqFUIT9YEbQbyjqEAgnQgIQJghJcpRBhO9nd0z9mxzs7Oz/uzNw7c+/M9yM9vX0/dubMvJnvnDn3nHNFVUEIISRcJro2gBBCSDMo5IQQEjgUckIICRwKOSGEBA6FnBBCAmeqi5VecMEFunv37i5WTQghwXLixImfqeqF6fc7EfLdu3djZWWli1UTQkiwiMgPst5naIUQQgKHQk4IIYFDISeEkMChkBNCSOBQyAkhJHAo5ENlPO7aAkKIJSjkQ+XWW7u2gBBiCQo5IYQEDoV8SIzHgEj0ADb/ZpiFkKCRLiaWmJ2dVVZ2dowIwElFCAkKETmhqrPp9+mRE0JI4FDIh8r8fNcWEEIsQSEfKoyLE9IbKOSEEBI4FHJCCAkcCjkhhAQOhZyQvrC8DBw5Ej2TQdHJDEGEEMssLwNXXgm88AIwPQ08/DAwGnVtFWkJeuSE9IHFxUjE19ai58XFri0iLUIhJ8OhaughpFDF3FzkiU9ORs9zc11bRFqEoRUyDKqGHkILVYxGkY2Li5GI+2wrsQ49chImVb3lqqGHEEMVoxFw8CBFfIA09shF5GIAxwH8OgAFcExVP9J0uYTkUsdbjkMP8f+UhR6qfp+QDrERWlkF8Deq+riIvAzACRF5UFWftLBsQraT5S2XCXnV0ANDFSQgGgu5qv4YwI83/v6FiDwF4JUAKOTEDXW95dGomiBX/X6ILC/zYtUDrA52ishuAJcC+ErGZ/sB7AeAmZkZm6sltujipK6zTnrLdghtQJfkYk3IReSlAO4DcKOq/jz9uaoeA3AMiCaWqLUSeg/u6OKkbrLOIXjLrqkToiJeYiVrRUR2IBLxe1X10zaWuY34pL/llug5hNzekOgiSyPEzJA+wdzz3tBYyEVEANwF4ClV/XBzk3LgSe+WLk5q2+sMqYDHB+IQ1aFDDKsEjo3Qyh8CeBeAb4jIExvv3aSq91tY9iZMB3NLF3Fnm+tkvLceDFH1AhtZK18GIBZsKYYDXO7p4qS2tU7Ge8mACauyk5VrJA/f4r2mU+kVhYMYKiKGiGq9BJImzM7O6srKSuvr7T1Dz+rxaftFgLJzqygcxFARyUBETqjqbPp9Ns3qCzzxw4v3FoWDGCoiFQgrtELyYVZP94zHkScuG0NG8d95YZaicJCrUBHDNb2EoZW+4JtH7lOYowtMQitA8X6yvQ99O0ZIZRha6Ts+ZfVQMMwpCgfZDhUxXNNbKOR9wpcYcVuC4bPXPz/ftQXbYS1Gb6GQE/u0IRhJr39yEti7F9izxx9BN00/bBOf7tqIVSjkxD5tCEbS619bAxYWgHvuYRinDF/u2ohVKOTEDa4FI/b6z5yJBhVV+xv39TmERLyAQk7CJPb6jx8HPv5xYHW1n3HfpgPHvAgMAgo5CZfY69+zp79i1WTgmNlDg4FCTsKnz3HfJgPHTDccDBRy4j9DDg80GThmuuFgoJATv2F4oP4dB9MNBwOFnPhNWXhgyN66CX0OO5FzUMiJ3xSFB+itEwKA3Q8J4HdHvKJ5JdnxkRAA9MhJCF5tXniAg3mEAKBHTnz3aovuFnyfBd7nOx3SK+iRDx1XXq2NQUiTuwUfBvPG4+1NskK40yG9gR5527jy0uou14VXG4vYLbdEz3W31fe7hZhbb93+Xii2k15Aj7xNXHlpTZfr4wQGy8vAyZNRi1qgvRi4rXRGxu9Ji9AjbxNXXppv3l9yvsnJyUiQq3jl8YXpzjujKdP27Su/ONno/13lTqJsfk7f4/ekV1DI28TVhLqulluXWMT27YvE7c47q4VYkhem1VVgZqZcCLPCG1WpckEcjzfb5wKbfycvKKMRcPAgRZw4h6GVNnFVMm1zubZCC6NRtJzV1eohlq7CEnNzWJuaBtZfAKamMVlnvaw0JV2gqq0/Xve61ynxkKUl1fPPV52cjJ6Xluot4/Dh6Nl0ecn/KXovzfx87AdvfczPV7d7Y5VvmF7Sm+SwvmF6yXzz4/XF2zsxoTo1pbqwUMsOp5js17bx0SZPAbCiGZpKISebHD4ciS4QPR8+XO3/s4S77CS1cfFQ3TiUm9F08/Xw4UjE4wvKjh1+iZOtfd13mzwmT8gZIyebNI2152WrFMWJPRqobTzUMDcHTCROqbW15ttjM13Vo319Dh9tChArMXIRuRvAmwD8VFVfY2OZpAOaxtrrxLZtxcPn5+v9X4LGQw2jEXDHHcANN0DX1rA6uRPf3jWHS+oaZDtd1ceUSB9tChDReNS9yUJEXg/geQDHTYR8dnZWV1ZWGq+XeEJygA+oroQ9GyD8xrFl/Pt7F/HI+hwe3zmqr79HjkSpkGtr0W3CoUPR3U0TfNzXPtrkKSJyQlVn0+9b8chV9UsistvGskhgZHmNVcXGhzJ7i3z+1AiHdYS1dWCyyQxrLrxVH/e1jzYFRmsxchHZLyIrIrLy7LPPtrVa4pqGMc4+9pUyibUbbTeLioghreWRq+oxAMeAKLTS1nq9oo+3kA28Rp/6Stn8acpi7ZW2O/4gvkDmfbGPxxYxhgVBbeGTatmkwghhWmt8meS9yk9jqpdF0YLS7U6POZQZ19djixhDIW8Ln1RrcRHYtQs4dcqeC1qyjCyt8SVhIS86lBZsW3pZuN3plVx3Xflx48uxRTrDVvrhJwDMAbhARJ4BMK+qd9lYdm/wQbVikTh7Flhfj3Ked+5sxYPL0pqDB/2Y5D390+zalS3YtvSy8CYmvRKg/Ljx4dginWIra+WdNpbTa6omKbuIecYisb4evV5fb82Dy9OaMmc+azfk7Zq6uyz90+QJtk29zN3u9Er27IkeRRvmqocPsY+rsYysck/XD5bol+CqbDnZCwSInsuWb7EPRtVF5VX8Z+0am7usaFlF22BtV7H3SD+xcJAip0SfMXIfcRXzTHpuJjFyy4NoVdOF82LXWbvG5i5LO7hAlCoY76qs5VrdVUUrodcdLg7HMijkPuIy5llFTR0eeFnTXKbJ2w3p9+LJhKamNj9vusvi3WQq0M7HG0PJTOHFJh+H5zWF3AfSB78vMc8KB17V8/fWW8uFPG83pL3lWN8mJ6O5LPbssT+sUCbQzscbQ8hMCeVi0xUOz2sKedfkHfw+lC0bHHjLy8Dx48Ddd0caMz0NHD3qNrMx+d6RI5v6BphNJlSFubnI019fj57zBNr5tTeEzJQQLjZd4+i8ppB3je8Hf8GBF1+DzpzZnPHs7Fngve+NXqedsvF464xs8XSX8/P1p9w00TeTME4RydncinB67fXlLq2IEC42P
YVC3jUBH/zxNSgWOJEovLG2lp3ZmBRUkXJhNMFE38rCOEU1UouL0faobrYX70xDfbhLKyKEi01PoZB3TcAHf/IaNDUFXH89cOmlwI03tntdaqJvZTVSptdZb8b4XBhSZZm+X2z6SlZOousH88jDxHRqTdPpNl3aZjKd59KS6jXXbJ2dLWuat7Zmq2uMC0NcLdM0T5459VsA5+wk56hxcngjVhmU2ZY1nWe6NkrEvEYqTeO5Pm3hwhDby6xyIPl80HVEnpBzzs6hEccSbrklejZsBN6w7bgT4p7ex49Xty3ZrWBiArj6amBhAfjAB6pnzTWe69MWLgypskyTEeUqB5KPB52nMEY+NGpmyfg2JpvM2pyainQGyLYtazrP9PaMx/YqQV2GiAvD1S4MqbJMk+KAKgeSbwedx1iZs7MqnLOzQxoUbbQxoGe6jvR0lvv2RTnkfZ4q1Pt6G9NUpCo7PrQfyTF5c3YyRj5EPB1ASodEFxaKG1QNLXzqTSw+icmoMrEGcmLk9MiJNyS97
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x, y = datasets.make_gaussian_quantiles(n_samples=250, n_features=2, n_classes=2, random_state=1)\n",
"# Partitionnement des données en apprentissage et test\n",
"x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=1)\n",
"\n",
"# Affichage des données d'apprentissage\n",
"plt.plot(x_train[y_train==0,0], x_train[y_train==0,1], 'b.')\n",
"plt.plot(x_train[y_train==1,0], x_train[y_train==1,1], 'r.')\n",
"\n",
"# Affichage des données de test\n",
"plt.plot(x_test[y_test==0,0], x_test[y_test==0,1], 'b+')\n",
"plt.plot(x_test[y_test==1,0], x_test[y_test==1,1], 'r+')\n",
"\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_1\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" dense_1 (Dense) (None, 10) 30 \n",
" \n",
" dense_2 (Dense) (None, 10) 110 \n",
" \n",
" dense_3 (Dense) (None, 1) 11 \n",
" \n",
"=================================================================\n",
"Total params: 151\n",
"Trainable params: 151\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"Epoch 1/20\n",
"9/9 [==============================] - 0s 24ms/step - loss: 0.8085 - accuracy: 0.4444 - val_loss: 0.7453 - val_accuracy: 0.2667\n",
"Epoch 2/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.7223 - accuracy: 0.3500 - val_loss: 0.6925 - val_accuracy: 0.4889\n",
"Epoch 3/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.6862 - accuracy: 0.5333 - val_loss: 0.6650 - val_accuracy: 0.6000\n",
"Epoch 4/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.6658 - accuracy: 0.6000 - val_loss: 0.6499 - val_accuracy: 0.6444\n",
"Epoch 5/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.6495 - accuracy: 0.6500 - val_loss: 0.6355 - val_accuracy: 0.6889\n",
"Epoch 6/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.6341 - accuracy: 0.6944 - val_loss: 0.6214 - val_accuracy: 0.7333\n",
"Epoch 7/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.6157 - accuracy: 0.7833 - val_loss: 0.6037 - val_accuracy: 0.8222\n",
"Epoch 8/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.5970 - accuracy: 0.8222 - val_loss: 0.5851 - val_accuracy: 0.8444\n",
"Epoch 9/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.5773 - accuracy: 0.8556 - val_loss: 0.5629 - val_accuracy: 0.8444\n",
"Epoch 10/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.5555 - accuracy: 0.8722 - val_loss: 0.5383 - val_accuracy: 0.8444\n",
"Epoch 11/20\n",
"9/9 [==============================] - 0s 4ms/step - loss: 0.5328 - accuracy: 0.8778 - val_loss: 0.5116 - val_accuracy: 0.8667\n",
"Epoch 12/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.5093 - accuracy: 0.8611 - val_loss: 0.4857 - val_accuracy: 0.8889\n",
"Epoch 13/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.4856 - accuracy: 0.8944 - val_loss: 0.4567 - val_accuracy: 0.9111\n",
"Epoch 14/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.4606 - accuracy: 0.8944 - val_loss: 0.4294 - val_accuracy: 0.9111\n",
"Epoch 15/20\n",
"9/9 [==============================] - 0s 6ms/step - loss: 0.4375 - accuracy: 0.8889 - val_loss: 0.4041 - val_accuracy: 0.9111\n",
"Epoch 16/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.4195 - accuracy: 0.8944 - val_loss: 0.3805 - val_accuracy: 0.9333\n",
"Epoch 17/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.3965 - accuracy: 0.9056 - val_loss: 0.3602 - val_accuracy: 0.9333\n",
"Epoch 18/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.3770 - accuracy: 0.9000 - val_loss: 0.3413 - val_accuracy: 0.9333\n",
"Epoch 19/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.3600 - accuracy: 0.9222 - val_loss: 0.3242 - val_accuracy: 0.9333\n",
"Epoch 20/20\n",
"9/9 [==============================] - 0s 5ms/step - loss: 0.3447 - accuracy: 0.9167 - val_loss: 0.3099 - val_accuracy: 0.9333\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWIAAAD8CAYAAABNR679AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAABpBklEQVR4nO29eXgcWXnv/zlV3a19lyVZlrzv+ybL23hg2IYlTCBAWJMACTcESG4u2bkJWW4u4ULCDwIhzMAQSIYdBgwzYTIEGI9nbFnexh7b432TLVvWvqu7qs7vj2q1uhapW1JL3ZLO53n02FV96tTpRa9Of8/7fo+QUqJQKBSK9KGlewAKhUIx11GBWKFQKNKMCsQKhUKRZlQgVigUijSjArFCoVCkGRWIFQqFIs2oQKxQKBTjQAjxqBCiRQjx4iiPCyHE54QQl4QQp4QQWxP1qQKxQqFQjI9/Ax4c4/HXAiuiPx8AvpioQxWIFQqFYhxIKQ8A7WM0eQj4urQ5DBQLIeaP1WcglQNMFhHIliKrYDrulESTxG18WyRx3UTuncyZCY1los9zukjF6zkarupR31rSiVSYJnVNEvdOxVgAT5Wsbz8pqqSdxopcOdDWKqWcN5k+tMIaiTGYzL3OAPENH5ZSPjzO2y0AbsYdN0XPNY92QXoCcVYBwdW/Ork+NH362ujeNlqC6yZ6b6FpCdsk14/znN9zSKaf6WIq7y0tc8xjAGl6z02kn8TXWOO+Jtk2ZiScsM1E+55Im1QROfnV65PuxBwiuOZNCZuFj395UEq5fdL3GydpCcQKxUwi0rEAo30RgZKrBIqb0j0cxQSZxonGLaA27rgmem5UlEasUIxBpGMBPUfexcCF++lpfA9GZ026h6SYEAKh6Ql/UsR+4Dei2RM7gS4p5aiyBKgZsUIxJkb7IrB0QANLEmlfpGbFMxEhUhZohRDfBF4GlAshmoCPA0EAKeW/Ak8CrwMuAf3AexP1OWMCcTo14UR6sF+/06n/asHQhPpxP6+p+uoWaa8mfK+a0LzbhMrvjtou3Fo5ZrtkxpeMLurXxhqlTajiNgOXTLAkaBahilux19utK/v16x5zMuNzf07sNpbrePo02tmAEAI9id+TZJBSviPB4xL40Hj6nDGBWDEzCbdV0X7gV8DUQTcpu//HvkE23FpJ2zOJ2003wdLbFO3+DpG2WgKl1wiW3E73kBQTJJ2L0YlQgVgxIcJtVYTvLSA07xahsjujt7u3wA6uaGBhz3j9AvG96qTapYNg6W2CpbfVLHQmk0JpYipQgVgxbuxZ7kOx2Wvpvh+NGoxD826BboIFaCahef4zytC820m1UygmgsBf8skUJh2IhRC1wNeBSuxs8YellJ+dbL+pYKpyhP369ui2geDExpOE/utpk0wecYraAPR3LHHMXo2OJeQu6PW0Awgs6CXw6p8xdLeKrMo7ZFX0AfnedrV96K52Qivy7XOs8VlG2NPGreX6tdGS0JFN13VSm0BesU++clJ5zknkH7s/y1M5g09G+84sZv+M2AA+KqU8LoQoAI4JIZ6WUp5NQd+KDCSr8g5Ct5AWCM0iq3J0aQIgq6KVrIrWxP0m2U6hGDezXZqI5sc1R//fI4Q4h13OpwLxLCVr3j3KX/nUyOx13r10DyljiHRUY7QtJFB2Qy3sZRJCJJVdlC5SqhELIRYDW4AGn8c+gO1EBKF8rN6lyN5ViPzzaPlXUjmMWUekvZpIWy3BspsES6fmlzvcWmEvvlXcIWteS8L2WfPuqQDsItJRTfehd9h5x5pJ4a5vqmCcIdga8SyeEQ8jhMgHvg/8Tyllt/vxqHHGwwAiu0YaFz8KUgdhEljxjyoYj0KkvZqu598W++Uu2v2dlGcThFsraP3FG8DUELpF2QP/mVQwVjgx2hY6ij+MtoUqEGcKs12aABBCBLGD8GNSyh8kvMDKtoMwOkiQvasgLhCnqjDD97oUFGuAdzEumYUv91ejZBbHwjcWx/1yg9mxmEB195jX2ONz3ctnUXL4ur6OxWBpgIa0wOhcTP7iIU97AE1LwsUtmTYTcFtL5t6m6V0kNQ3nQpfvgp5rscmK+CzouV5Tdz9ZlXfpv2iCCegWWZV3Pde4F/iSvTdEfM6NzUQLYGYnIqnf9XSRiqwJAXwFOCel/KekLtIGQZh2joUwEfnnJzuMWUuwvMlO6zIB3bSPU0x2VQvd2sjiW/b88UsOg3dKGWieR878e+RUd6R8jDOBUNkdSvb+gEhrDcHyJkJld0gi2UExHYjZL03sAd4DnBZCnIye+wsp5ZOjXSD0IfQV/6g04iTw++WG1C46ZFW0UvHgzxlqmU/2/HtkV47lee1l8E4pzT+5D2lpdGoW1W88SHbV+PqYLYTK7oxZ4KJIDwLh+XaSSaQia+IgE/AU1/KvOOQIxehMxy93VkUrOdUeaT8pBprnIS0NpC1tDNwun7OBWJGhzAWNeLYxkeKMiRRi6D5/oT3abhL6r9+9PG189FVNd1Ya6QFv5ZFb79V8tN3iZZ10HpdI00LokuJlXeTkZ3najUVy2rP3nHR99TfCPoUYprORaXhfU8vdxue9SaTl+poJuYtAfNoY4QHPOTdJGUbNGb13IqhArJjl5FZ3svith+i7WUZebRu5C7rSPaQY/bdL6L1RSm5NKznzndr1wJ1SBm6Vk7Oglax5qpBkViOS26UmXahAPI1MRz5wusit7iS3ujN6lNbd72L03y7h+vd3IU2B0C1q3/R8LBgP3Cnl9v49yGjK3vzXP6vklFmMUDNiBQwn+zvzgXVVzjul9DeVIU1ha9cm9DeVjwTiW+VIU4s9NtA8TwXi2YzSiBOTMuf8JDbMTMYMPVGOsN85t07r1mjNjiXOfODOpeg1fZ5+9aycMfvxu7efthsIutro3jZawK0Re2ey7uvc14C/bhxPUnnFfhp2MnnDrhxhIzKik5as7OJeQ5x2vbyT7LwgpiEpXtZJx7GRx4qWdJKdG4rrx/ueu7VmGcpypO0Fy71FMG6N2C9H2G0mlMkBYyajBzIi3PmSuSObZbjzgUPzxtxLMOX03y6hv6mM3Jo2cudInm9+bRcr39NI15US8he2kVfTGXssd0EnS952OKZrZ1eO/zVxp+1VvvaXZFW2pfAZKFKFECKpCUG6UIF4mgiWNlO670cuM/XpyWscaC7h5uPDWqlk0a8dmlPBeDTJIXdBJ7kLOgGwjPH37U7bG7xToQJxBjORas7pQgXiaSRdyf79TeVxWqllz4znSCCeSnLm36MzviKxSvlzZDLJSF3pYsYE4qnSzZLSf/10WrdnhatNIOTUev3auPVg+zpn/q1b6/U756fbxuu9xcs7aWsc0UNLVnaRUxDy9ON7r5ArfzU6q+i6VkDnpWKKl3cihaDjUhEly7soWtKT1FdA3dXGfZwsA4POqWy8RjxybnQdeRjLozX7aMSuNsElPQR/7VBM8smq6AWc75/b+yLc730O7nWLdGcDzzzT9yQQya1VpIsZE4gVEyevppOV72mk53opBYvaya+dXJ5v1
7UCXviXjViGhtAsEAJpCrSAxZYPnaZ4mf9uHbORnOqOmLfGROQNxfRg22CqQKxIM/m1XZMOwMN0XirGMjSQAmkJkAIQWKY9M55LgVgxQxDCN3MoU1CBWDFuipd3ogUsLFMghAQB0pJouqRkeeZU1SVD780ieq6VklfbRn5cVoVi9qFmxIpZRdHiHjb93qmYRqy5NOJMqaxLRO/NIs5/bXs0m2QpK97dmFQw7r9dTN/NMnKqW+es5edMQwi1WOfLWItvEzWGT8b03WvW42d0k7gwxL0Y51mI81mscy/OBUJeM6FglvMtCWV536JgVhKLbO6FuFDioo+ckLefgmzXeALRNhUSdthBSNeA+s5oi/ykFt48i3VJpBaFTa+5b++g0zC938f0Z8B1bmjIFnPbmudFq+sE0rQYap7HvNW2AY+7eAPsRb7em0Vc++5w8F7J8nceieUnW6b0XDM04BxfJJnP5Bgm/nE9edp4Fv18doyey/iZRmUKGTw0hWJqKVrehRawQFhoAUnh0s6E1
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"model = Sequential()\n",
"model.add(Dense(10, activation='relu', input_dim=2))\n",
"model.add(Dense(10, activation='relu', input_dim=10))\n",
"model.add(Dense(1, activation='sigmoid', input_dim=10))\n",
"model.summary()\n",
"\n",
"sgd = optimizers.SGD(learning_rate=0.1)\n",
"model.compile(optimizer=sgd, loss='binary_crossentropy', metrics=['accuracy'])\n",
"history = model.fit(x_train, y_train, validation_split=0.2, epochs=20, batch_size=20)\n",
"\n",
"print_decision_boundaries(model, x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XMMppWbnG3dN"
},
"source": [
"# Classification d'images de chiens et de chats\n",
"\n",
"Dans la suite du TP, on s'intéresse au problème simple (en apparence) de reconnaître des chiens et des chats dans des images.\n",
"\n",
"<center> <img src=\"https://drive.google.com/uc?id=11W1SmzrBhL8vyzPCjSkZfHWnxb7kByi5\" style=\"width:1000;height:550px;\"></center>\n",
"<caption><center><b> Figure 1 : Quelques images de la base de données </b></center></caption>\n",
"\n",
"Pour cela nous allons utiliser une base de données de 4000 images, réparties en 2000 images d'apprentissage, 1000 images de validation, et 1000 images de test. Compte-tenu de la variabilité possible des représentations de chiens et chats, cette base de données est d'une taille assez réduite et le problème est complexe. Il correspond bien aux problèmes que nous pouvons rencontrer dans la réalité, lorsque les données sont souvent difficiles à obtenir.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m7K-oLcaXkcY"
},
"source": [
"Il faut définir une résolution commune à toutes les images, qui sera donc la dimension passée en entrée au réseau de neurones. Pour commencer et simplifier le problème, vous pouvez d'abord considérer des images de taille $64 \\times 64$ ; plus tard, lorsque vos réseaux fonctionneront bien, nous pourrons envisager d'augmenter cette résolution pour améliorer les performances. "
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "8th8b32kV2kh",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"IMAGE_SIZE = 128\n",
"CLASSES = ['cat', 'dog']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z3mdNJJXc6Wy"
},
"source": [
"## Chargement des données\n",
"La base de données est à télécharger depuis Git. Ne passez pas trop de temps à regarder les cellules suivantes (mais exécutez les !)."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "n_OkpjrpFXXG",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"fatal: le chemin de destination 'iam' existe déjà et n'est pas un répertoire vide.\n"
]
}
],
"source": [
"!git clone https://github.com/axelcarlier/iam.git\n",
"path = \"./iam/tp3/\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KoSVj5OGXa-4"
},
"source": [
"Chargement des données dans des tenseurs $x$ et $y$ de dimensions respectives $(N, 64, 64, 3)$ et $(N, 1)$, où $N$ désigne le nombre d'éléments de l'ensemble considéré (apprentissage, validation, ou test)."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "VcNp4xl0QfOT",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"import glob\n",
"import PIL\n",
"from PIL import Image\n",
"import numpy as np\n",
"\n",
"def load_data(path, classes, image_size=128):\n",
"\n",
" # Liste les fichiers présents dans le dossier path\n",
" file_path = glob.glob(path)\n",
" \n",
" # Initialise les structures de données\n",
" x = np.zeros((len(file_path), image_size, image_size, 3))\n",
" y = np.zeros((len(file_path), 1))\n",
"\n",
" for i in range(len(file_path)):\n",
" # Lecture de l'image\n",
" img = Image.open(file_path[i])\n",
" # Mise à l'échelle de l'image\n",
" img = img.resize((image_size,image_size), Image.ANTIALIAS)\n",
" # Remplissage de la variable x\n",
" x[i] = np.asarray(img)\n",
"\n",
" img_path_split = file_path[i].split('/')\n",
" img_name_split = img_path_split[-1].split('.')\n",
" class_label = classes.index(img_name_split[-3])\n",
" \n",
" y[i] = class_label\n",
"\n",
" return x, y\n",
"\n",
"x_train, y_train = load_data('./iam/tp3/train/*', CLASSES, image_size=IMAGE_SIZE)\n",
"x_val, y_val = load_data('./iam/tp3/validation/*', CLASSES, image_size=IMAGE_SIZE)\n",
"x_test, y_test = load_data('./iam/tp3/test/*', CLASSES, image_size=IMAGE_SIZE)\n",
"\n",
"# Normalisation des entrées via une division par 255 des valeurs de pixel.\n",
"x_train = x_train/255\n",
"x_val = x_val/255\n",
"x_test = x_test/255"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vwngS1p9V1VN"
},
"source": [
"### Visualisation des images"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "YXUxcuIPOS5W",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1EAAANYCAYAAAAlimdiAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9W6wtSbfnB/1GRGTmvKy1b1X1nct3Tl9oNzJGtjAPthpkCyFz8YPllhHG3CQjGUsIxAtgGgTClhC2/GYJ8dBCCLAQbQGWZcBSC2PgBTCGfkEYtfuAzunu4/NdqmrX3nutecmMiMHDiIjMOddca1+qvvqqDnN836o5d87MyMjIiBHjP66iqlzpSle60pWudKUrXelKV7rSlT6M3K+7A1e60pWudKUrXelKV7rSla70Y6IriLrSla50pStd6UpXutKVrnSlj6AriLrSla50pStd6UpXutKVrnSlj6AriLrSla50pStd6UpXutKVrnSlj6AriLrSla50pStd6UpXutKVrnSlj6AriLrSla50pStd6UpXutKVrnSlj6AriLrSla50pStd6UpXutKV3kMi8j8Wkf/ur7sfV/ph0BVEXenXRiLy+yLy9/26+3GlK13pjz9d+c2VrnSlK13pu6QriLrSla50pStd6UpXutKVrnSlj6AriLrSd0Ii8rsi8i+KyC9F5CsR+e+LyJ8RkX+t/PtLEfmficiLcv4/D/wJ4H8tInci8k/8Wh/gSle60o+GrvzmSle60vdBIvJ3ishfEZF3IvIvAKvFb/95Efk9EflaRP5lEfntxW//YRH5qyLyRkT+ByLyfxKRf+zX8hBX+pXRFURd6VuTiHjgfwP8AfCngJ8CfwkQ4J8Gfhv4dwG/C/yTAKr6nwX+OvAPqOqNqv6z33vHr3SlK/3o6MpvrnSlK30fJCI98C8B/zzwCvhfAP+x8tt/EOM3/zDwWxg/+kvlt8+B/yXw3wA+A/4q8O/7fnt/pe+DRFV/3X240o+cROTPAf8y8FuqGp84788D/x1V/TvLv38f+MdU9V/9Pvp5pStd6cdPV35zpStd6fsgEfl7MWD0Uy3Csoj8n4F/DQNOX6nqP1GO3wCvgT8L/L3Af0FV/1z5TTAlzj+lqv/D7/1BrvQro/Dr7sCV/ljQ7wJ/cC7QiMhvAP8c8PcAt5jl8/X3370rXelKf4zoym+udKUrfR/028Af6qm14Q8Wv/2VelBV70TkK8wy/tvA31j8piLyN7+H/l7pe6arO9+Vvgv6G8CfEJFzUP7fAxT421X1GfCfwVxuKl3NoFe60pU+lq785kpXutL3QX8E/LRYkir9ifL5bwN/sh4UkS3muveH5brfWfwmy39f6Y8PXUHUlb4L+r9hTOOfEZGtiKxE5N+PaYPvgDci8lPgv3Z23c+Bf8f329UrXelKP3K68psrXelK3wf9X4AI/JdFpBORfwj4u8pv/3PgPyci/x4RGTAlzr+uqr8P/G+Bv11E/nxR9vwXgd/8/rt/pV81XUHUlb41qWoC/gHgb8H8fv8m8J8A/ing3wu8wZjKv3h26T8N/LdE5BsR+a9+fz2+0pWu9GOlK7+50pWu9H2Qqo7APwT8o8DXGJ/5F8tv/yrw3wb+V5hS588A/0j57UvgPw78s8BXwN8G/N+B4/f6AFf6ldM1scSVrnSlK13pSle60pWu9CsgEXGYsuc/rar/h193f6703dHVEnWlK13pSle60pWudKUrfUckIv8REXlRXP3+m1h85v/119ytK33HdAVRV7rSla50pStd6UpXutJ3R38O+P8AX2Lux39eVfe/3i5d6bumXxmIEpH/aKnW/Hsi8hd+Vfe50pWudKUrv7nSla70fdGV31zpfaSq/6Sqfqaqt6r6d6vqv/7r7tOVvnv6lcRElYry/xbwH8L8QP8N4D+pqv/md36zK13pSv9/TVd+c6UrXen7oiu/udKVrlTpV1Vs9+8Cfk9V/78AIvKXgH8QuMhknHPqnBnFajZ+VUWzIgIiYn/lOCjOOWRx/kk1EGth/nqS4v/BiY9edk7yxL/m6/WJRi6VLBHEif0mdnnOmZyVlBOqi9bKGNhVetbi+3v3oFeyPK4n50j9ryrK8h04RCCrnV0xeM6PD5y1cOF3KffR97yTx5/gtLHzs0vnnBOcczgRnMiT7XyYSuF0rC72ROc+6dmoXmyjvGcBm69q/Rex+aGqqCo55/k4smj7W5Ke/1NPfnrsDS10MF+q6hffTWc+mj6K3wCsV4Pe3GzbuKpC1kwq42tj7BDnWA0DPgRySqgq4/HYfgclp4SUuVU/Nef2rk5GsPCiXAdO5MIblJlllXVGe+eUOUzrty6uW36cXN+OnzS8+HbKL53IzB/aTU7ncX2Ek/6Xe9R5vFTSXVihp7+dtPehPOGPOz21+n4I9DH9+5B98cPocDwyTdOva2A+mt+suqCboSfnPB8ssg0s1/K81tpqq/t+++mRcRRItZ2cUSB4P8tOi2uNh8i8By/o0n6laFvLqmqyl7M96LSEEk1msK+5tCTtWS71WyrPkwc/FNkDgvOIc3jnGyurY6aq5AUnCyGYjHgiNy76sbxL2UuX8pAuxmGWv2B5QBa8ru3NbSOvz1P4vWvMFFEIyc5zOq8flZPmW1+rKDGLSWUfQHAKTsGnjCRFxwliJqj9TrLxnzyod6R1hwbH1HsyStKMKEg53y3eTy5jkEsbwTkcgq9yYdnfEmXvRO15nZDLHHDe40QIPhS5RpF6n7bH2Gcu+27KimomTgnNmZTsPjnZ+C7X0PLdyeKt5rLvijiTk8q/63A759qaW9L5v+s7Bnj79puLMs6vCkT9lEW1Zkxb83efdk7+ceAfL9+5vb3Be1cmB8Q4MR2PBO8IwTN0HZ33xDiimtiuB4L31HfjQ5msFCGoTAFzWCyLSTzL2TgvrAWZvFIAmpTVpO2cBvZwUF5OFZxQJcUImiHnEyYphUPIcuGXlkI3mEAWAjEm7vZ7docjb97dMSVlShnnPVKYhwh0ZJworqwwmywgesrYZCHIJM1kVWN+TlBnD5y0LgZbTQoFcDhyVFJM9P1A13V0YcA5zxgTKSs5Q05wPI72RFI9RJWsSs4JY2+JOvWdGDjG2cR3au/kqR2xCrYn722x4L339owpld8gpUyMmfVq4Ga7Zjt0DL0vCPBMiGwMbMk+T7f99s60jNPJipOZuWqdN6fzLCOz4Fmfpf4V5iEihNCRcmYcR7quYxgGpmkixsh+v2eaJrqua+92OTbn43I+dm0Ozj+e9Gd5DczAwjnXNiWgAQRjbJBiq+L+66D38hs45TnbzZq//+/7eziOEzEljuPE/njg7f094xQZY2S13jAMa/7Mn/2zvHr5kuNhz/F45K///u8zjSOd96Qpsr+7wzvHugsMXcdm6Dnsdhz3B+I4klMCcaiA+IAiHHO26eE9GSEhjamH4OhCgGygTacIOdMHjxdhHWyuT4eRpMqUsHXnfFM6IfO8qhuHiAk+CAuBo25cRQAQJQRP13mcdzgnxCmScyJHRbPxB
gWyClkhlXWRxfiQK+ATJ+RsyjAnbeIjVQgDBFuvxooEVTuaP1CwlrPN80PpU1QP328m20tQsvL0s1Ol8hX9iAf7FPXLiVh32tpjcr3Me14TYhbn12eZ56y080659Cn9W3/t9z66998hfTS/2fQdf//f8e8kpUTWTCzrTbG1mLE1mFXtnKaAga4LUARizZkUE1IGyPZ8yN6hzrGPkTElDocDOSsvbp8RQiA4gaykacI7Rx86nHN4X+WFGThUGQpxTVzKqkwxoppIccI7T98FvPd4H4rSe7Fqy/ue4gSqBQAJvshPqvO5IQRC8KUth0oZGW98ZBg6uhB4dfucoe+5vbltCquYEzElxhwZNZFUyQIvP/+c9WZN1wUEIY0TokqgAAERKksKzuNDQItMlAp/HDUXyUVRAe+cjVHKiNizeBG8OMZxZBxHpimSUiJ0BuLwjuwdh7U3fp8y/Zh49nokRKU7ZsQ5xHnUC9k7shfUCS7ZOGZn7+DoBS2yXi+OrQTWo3K7SwzfHBhe70l/9
"text/plain": [
"<Figure size 864x864 with 9 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"# Randomisation des indices et affichage de 9 images alétoires de la base d'apprentissage\n",
"indices = np.arange(x_train.shape[0])\n",
"np.random.shuffle(indices)\n",
"plt.figure(figsize=(12, 12))\n",
"for i in range(0, 9):\n",
" plt.subplot(3, 3, i+1)\n",
" plt.title(CLASSES[int(y_train[i])])\n",
" plt.imshow(x_train[i])\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tV5s1T3yWJB6"
},
"source": [
"## Première approche : réseau convolutif de base"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "00T5cUGlif9z"
},
"source": [
"Les images ont toutes été redimensionnées en $64 \\times 64$. \n",
"Vous devez définir un réseau de neurones convolutif en suivant ce schéma por la base convolutive : \n",
"\n",
"<center> <img src=\"https://drive.google.com/uc?id=1bwXaIgO-pKJGs6fVaX0IrLbFbUAlTvNM\" style=\"width:800;height:400px;\"></center>\n",
"<caption><center><b> Figure 2: Vue de l'architecture à implémenter </b></center></caption>\n",
"\n",
"Ce réseau alterne dans une première phase les couches de convolution et de Max Pooling (afin de diviser à chaque fois la dimension des tenseurs par 2). \n",
"\n",
"La première couche comptera 32 filtres de convolution, la seconde 64, la troisième 96 et la 4e 128. Enfin, avant la couche de sortie, vous ajouterez une couche dense comptant 512 neurones. Vous aurez donc construit un réseau à 6 couches, sorte de version simplifiée d'AlexNet.\n",
"\n",
"Pour construire ce réseau, vous pouvez utiliser les fonctions Conv2D, Maxpooling2D, et Flatten de Keras."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_2\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" conv2d (Conv2D) (None, 126, 126, 32) 896 \n",
" \n",
" max_pooling2d (MaxPooling2D (None, 63, 63, 32) 0 \n",
" ) \n",
" \n",
" conv2d_1 (Conv2D) (None, 61, 61, 64) 18496 \n",
" \n",
" max_pooling2d_1 (MaxPooling (None, 30, 30, 64) 0 \n",
" 2D) \n",
" \n",
" conv2d_2 (Conv2D) (None, 28, 28, 96) 55392 \n",
" \n",
" max_pooling2d_2 (MaxPooling (None, 14, 14, 96) 0 \n",
" 2D) \n",
" \n",
" conv2d_3 (Conv2D) (None, 12, 12, 128) 110720 \n",
" \n",
" max_pooling2d_3 (MaxPooling (None, 6, 6, 128) 0 \n",
" 2D) \n",
" \n",
" flatten (Flatten) (None, 4608) 0 \n",
" \n",
" dense_4 (Dense) (None, 512) 2359808 \n",
" \n",
" dense_5 (Dense) (None, 1) 513 \n",
" \n",
"=================================================================\n",
"Total params: 2,545,825\n",
"Trainable params: 2,545,825\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"from tensorflow.keras.layers import Conv2D, MaxPooling2D\n",
"from tensorflow.keras.layers import Dense, Flatten\n",
"\n",
"model = Sequential()\n",
"model.add(Conv2D(32, 3, activation=\"relu\", input_shape=x_train.shape[1:]))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Conv2D(64, 3, activation=\"relu\"))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Conv2D(96, 3, activation=\"relu\"))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Conv2D(128, 3, activation=\"relu\"))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Flatten())\n",
"model.add(Dense(512, activation=\"relu\"))\n",
"model.add(Dense(1, activation=\"sigmoid\"))\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yWqVtzWZIsOY"
},
"source": [
"### Entrainement"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6Q9IQIETzLI-"
},
"source": [
"Pour l'entraînement, vous pouvez utiliser directement les hyperparamètres suivants."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "IJsJ7mMIjCGm",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"adam = optimizers.Adam(learning_rate=3e-4)\n",
"model.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LNbGxTZt4cck"
},
"source": [
"... puis lancer l'entraînement. **Attention : si jamais vous voulez relancer l'entraînement, il faut réinitialiser les poids du réseau. Pour cela il faut re-exécuter les cellules précédentes à partir de la définition du réseau !** Sinon vous risquez de repartir d'un entraînement précédent (qui s'est éventuellement bien, ou mal déroulé) et mal interpréter votre nouvel entraînement."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "fjetcQRljJC8",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"200/200 [==============================] - 16s 75ms/step - loss: 0.6965 - accuracy: 0.5125 - val_loss: 0.6850 - val_accuracy: 0.5610\n",
"Epoch 2/10\n",
"200/200 [==============================] - 13s 65ms/step - loss: 0.6547 - accuracy: 0.6320 - val_loss: 0.6452 - val_accuracy: 0.6310\n",
"Epoch 3/10\n",
"200/200 [==============================] - 13s 64ms/step - loss: 0.5868 - accuracy: 0.7025 - val_loss: 0.6075 - val_accuracy: 0.6700\n",
"Epoch 4/10\n",
"200/200 [==============================] - 13s 64ms/step - loss: 0.5296 - accuracy: 0.7360 - val_loss: 0.6445 - val_accuracy: 0.6800\n",
"Epoch 5/10\n",
"200/200 [==============================] - 13s 64ms/step - loss: 0.4848 - accuracy: 0.7745 - val_loss: 0.6199 - val_accuracy: 0.7170\n",
"Epoch 6/10\n",
"200/200 [==============================] - 13s 64ms/step - loss: 0.4319 - accuracy: 0.7950 - val_loss: 0.5907 - val_accuracy: 0.7200\n",
"Epoch 7/10\n",
"200/200 [==============================] - 13s 63ms/step - loss: 0.3770 - accuracy: 0.8250 - val_loss: 0.6193 - val_accuracy: 0.7200\n",
"Epoch 8/10\n",
"200/200 [==============================] - 12s 62ms/step - loss: 0.3217 - accuracy: 0.8540 - val_loss: 0.6377 - val_accuracy: 0.7030\n",
"Epoch 9/10\n",
"200/200 [==============================] - 13s 64ms/step - loss: 0.2562 - accuracy: 0.8920 - val_loss: 0.7223 - val_accuracy: 0.7300\n",
"Epoch 10/10\n",
"200/200 [==============================] - 12s 62ms/step - loss: 0.1914 - accuracy: 0.9230 - val_loss: 0.7326 - val_accuracy: 0.7200\n"
]
}
],
"source": [
"history = model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10, batch_size=10)"
]
},
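{
"cell_type": "markdown",
"metadata": {},
"source": [
"As mentioned above, calling `model.fit` again on the same `model` object continues from the current weights. One convenient way to guarantee a fresh start is to wrap the architecture in a helper and rebuild the model before each run. The `build_model` function below is a hypothetical convenience, not part of the original lab; it simply repeats the layer definitions and compilation settings from the previous cells."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper (sketch): rebuild the network from scratch so that a new\n",
"# call to fit() starts from freshly initialized weights rather than resuming\n",
"# from the previous run.\n",
"def build_model(input_shape):\n",
"    m = Sequential()\n",
"    m.add(Conv2D(32, 3, activation='relu', input_shape=input_shape))\n",
"    m.add(MaxPooling2D(pool_size=(2, 2)))\n",
"    m.add(Conv2D(64, 3, activation='relu'))\n",
"    m.add(MaxPooling2D(pool_size=(2, 2)))\n",
"    m.add(Conv2D(96, 3, activation='relu'))\n",
"    m.add(MaxPooling2D(pool_size=(2, 2)))\n",
"    m.add(Conv2D(128, 3, activation='relu'))\n",
"    m.add(MaxPooling2D(pool_size=(2, 2)))\n",
"    m.add(Flatten())\n",
"    m.add(Dense(512, activation='relu'))\n",
"    m.add(Dense(1, activation='sigmoid'))\n",
"    m.compile(loss='binary_crossentropy',\n",
"              optimizer=optimizers.Adam(learning_rate=3e-4),\n",
"              metrics=['accuracy'])\n",
"    return m\n",
"\n",
"# Example use before relaunching an experiment:\n",
"# model = build_model(x_train.shape[1:])"
]
},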
{
"cell_type": "markdown",
"metadata": {
"id": "iBPk-patWSYX"
},
"source": [
"### Analyse des résultats du modèle"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "casoAuXmzWYb"
},
"source": [
"Les quelques lignes suivantes permettent d'afficher l'évolution des métriques au cours de l'entraînement, sur les ensembles d'apprentissage et de validation."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "fExCbSI3V6Ur",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"def plot_training_analysis():\n",
" acc = history.history['accuracy']\n",
" val_acc = history.history['val_accuracy']\n",
" loss = history.history['loss']\n",
" val_loss = history.history['val_loss']\n",
"\n",
" epochs = range(len(acc))\n",
"\n",
" plt.plot(epochs, acc, 'b', linestyle=\"--\",label='Training acc')\n",
" plt.plot(epochs, val_acc, 'g', label='Validation acc')\n",
" plt.title('Training and validation accuracy')\n",
" plt.legend()\n",
"\n",
" plt.figure()\n",
"\n",
" plt.plot(epochs, loss, 'b', linestyle=\"--\",label='Training loss')\n",
" plt.plot(epochs, val_loss,'g', label='Validation loss')\n",
" plt.title('Training and validation loss')\n",
" plt.legend()\n",
"\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "ex3AjPOPu2UN",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAzQklEQVR4nO3dd3xUVfrH8c9DKKEEJAGk9yBFSIAIUkRQULCAqCggQkBFsbCwIoKiFGVXEBu76i7dRRAVpaioK+0HKygdFATpEAQEQoeQdn5/nEmYhJQhTHJnJs/79corM3fu3HkygW/OnHvuOWKMQSmllP8r4HQBSimlvEMDXSmlAoQGulJKBQgNdKWUChAa6EopFSA00JVSKkBooAcwEflWRPp4e18nicg+EWmfC8c1IlLbdftfIvKKJ/vm4HUeEZH/5rROpbIiOg7dt4jIObe7xYBLQJLr/pPGmFl5X5XvEJF9wOPGmMVePq4Bwo0xu7y1r4hUB/YChYwxiV4pVKksFHS6AJWWMaZEyu2swktECmpIKF+h/x59g3a5+AkRaSsiMSLyoogcAaaLSGkR+VpEjonISdftym7PWS4ij7tuR4vI/0RkgmvfvSLSKYf71hCRFSJyVkQWi8j7IvJxJnV7UuNrIvKj63j/FZEybo8/KiL7ReSEiLycxfvTXESOiEiQ27auIrLFdbuZiKwWkVMiclhE/ikihTM51gwRed3t/guu5/whIv3S7Xu3iGwUkTMiclBERrk9vML1/ZSInBORFinvrdvzW4rIWhE57fre0tP35irf51ARme76GU6KyHy3x7qIyCbXz7BbRDq6tqfp3hKRUSm/ZxGp7up6ekxEDgBLXds/d/0eTrv+jTRwe35REXnL9fs87fo3VlREvhGR59L9PFtEpGtGP6vKnAa6fykPhALVgP7Y39901/2qwEXgn1k8vzmwAygDjAemiojkYN/ZwBogDBgFPJrFa3pSY0+gL1AOKAwMARCR+sCHruNXdL1eZTJgjPkZOA/clu64s123k4DBrp+nBXA78HQWdeOqoaOrng5AOJC+//480Bu4DrgbGCAi97kea+P6fp0xpoQxZnW6Y4cC3wATXT/b28A3IhKW7me44r3JQHbv80xsF14D17HecdXQDPgP8ILrZ2gD7MvkNTJyK1APuNN1/1vs+1QO2AC4dxFOAJoCLbH/jocCycBHQK+UnUQkAqiEfW/U1TDG6JePfmH/Y7V33W4LxAPBWewfCZx0u78c22UDEA3scnusGGCA8lezLzYsEoFibo9/DHzs4c+UUY0j3O4/DXznuv0qMMftseKu96B9Jsd+HZjmuh2CDdtqmew7CJjndt8AtV23ZwCvu25PA95w26+O+74ZHPdd4B3X7equfQu6PR4N/M91+1FgTbrnrwais3tvruZ9Bipgg7N0Bvv9O6XerP79ue6PSvk9u/1sNbOo4TrXPqWwf3AuAhEZ7BcMnMSelwAb/B/kxv+pQP/SFrp/OWaMiUu5IyLFROTfro+wZ7Af8a9z73ZI50jKDWPMBdfNEle5b0Ug1m0bwMHMCvawxiNuty+41VTR/djGmPPAicxeC9sav19EigD3AxuMMftdddRxdUMccdXxN2xrPTtpagD2p/v5movIMldXx2ngKQ+Pm3Ls/em27ce2TlNk9t6kkc37XAX7OzuZwVOrALs9rDcjqe+NiASJyBuubpszXG7pl3F9BWf0Wq5/058CvUSkANAD+4lCXSUNdP+SfkjS88ANQHNjTEkuf8TPrBvFGw4DoSJSzG1blSz2v5YaD7sf2/WaYZntbIzZhg3ETqTtbgHbdbMd2wosCbyUkxqwn1DczQYWAlWMMaWAf7kdN7shZH9gu0jcVQUOeVBXelm9zwexv7PrMnjeQaBWJsc8j/10lqJ8Bvu4/4w9gS7YbqlS2FZ8Sg3HgbgsXusj4BFsV9gFk657SnlGA92/hWA/xp5y9ceOzO0XdLV41wGjRKSwiLQA7s2lGucC94hIa9cJzDFk/292NvAXbKB9nq6OM8A5EakLDPCwhs+AaBGp7/qDkr7+EGzrN87VH93T7bFj2K6OmpkcexFQR0R6ikhBEXkYqA987WFt6evI8H02xhzG9m1/4Dp5WkhEUgJ/KtBXRG4XkQIiUsn1/gBsArq79o8CHvSghkvYT1HFsJ+CUmpIxnZfvS0iFV2t+RauT1O4AjwZeAttneeYBrp/excoim39/AR8l0ev+wj2xOIJbL/1p9j/yBl5lxzWaIzZCjyDDenD2H7WmGye9gn2RN1SY8xxt+1DsGF7FpjsqtmTGr51/QxLgV2u7+6eBsaIyFlsn/9nbs+9AIwFfhQ7uubmdMc+AdyDbV2fwJ4kvCdd3Z56l6zf50eBBOynlD+x5xAwxqzBnnR9BzgN/B+XPzW8gm1RnwRGk/YTT0b+g/2EdAjY5qrD3RDgF2AtEAuMI20G/QdoiD0no3JALyxS10xEPgW2G2Ny/ROCClwi0hvob4xp7XQt/kpb6OqqichNIlLL9RG9I7bfdL7DZSk/5urOehqY5HQt/kwDXeVEeeyQunPYMdQDjDEbHa1I+S0RuRN7vuEo2XfrqCxol4tSSgUIbaErpVSAcGxyrjJlypjq1as79fJKKeWX1q9ff9wYUzajxxwL9OrVq7Nu3TqnXl4ppfySiKS/ujiVdrkopVSA0EBXSqkAoYGulFIBwqdWLEpISCAmJoa4uLjsd1aOCA4OpnLlyhQqVMjpUpRS6fhUoMfExBASEkL16tXJfN0F5RRjDCdOnCAmJoYaNWo4XY5SKh2f6nKJi4sjLCxMw9xHiQhhYWH6CUopH+VTgQ5omPs4/f0o5bt8LtCVUipQHTkCv/ySe8fXQHdz4sQJIiMjiYyMpHz58lSqVCn1fnx8fJbPXbduHQMHDsz2NVq2bJntPkqpwJIyZdbXX8MAT5dWyQGfOinqtLCwMDZt2gTAqFGjKFGiBEOGXF5kPTExkYIFM37LoqKiiIqKyvY1Vq1a5ZValVK+LTnZBvg770DXrjBwIDzyCNx6a+69prbQsxEdHc1TTz1F8+bNGTp0KGvWrKFFixY0btyYli1bsmPHDgCWL1/OPffcA9g/Bv369aNt27bUrFmTiRMnph6vRIkSqfu3bduWBx98kLp16/LII4+krIDOokWLqFu3Lk2bNmXgwIGpx3W3b98+brnlFpo0aUKTJk3S/KEYN24cDRs2JCIigmHDhgGwa9cu2rdvT0REBE2aNGH37mtZF1gplZlz5+Cf/4QbboAuXWD3bggJsY8VLQrh4bn32j7dQm/b9sptDz0ETz8NFy7AXXdd+Xh0tP06fhweTLcC4vLlOasjJiaGVatWERQUxJkzZ1i5ciUFCxZk8eLFvPTSS3zxxRdXPGf79u0sW7aMs2fPcsMNNzBgwIArxm5v3LiRrVu3UrFiRVq1asWPP/5IVFQUTz75JCtWrKBGjRr06NEjw5rKlSvHDz/8QHBwMDt37qRHjx6sW7eOb7/9lgULFvDzzz9TrFgxYmNjAXjkkUcYN
mwYXbt2JS4ujuTk5Jy9GUqpLHXvDt98A82bw+uvw/33Q15dtuHTge4runXrRlBQEACnT5+mT58+7Ny5ExEhISEhw+fcfffdFClShCJFilCuXDmOHj1K5cqV0+zTrFmz1G2RkZHs27ePEiVKULNmzdRx3j169GDSpCsXcUlISODZZ59l06ZNBAUF8fvvvwOwePFi+vbtS7FidrH20NBQzp49y6FDh+jatStgLw5SSnnHzz/DxInw1ltQvjy88gq8/DK0aJH3tfh0oGfVoi5WLOvHy5TJeYs8veLFi6fefuWVV2jXrh3z5s1j3759tM3oYwRQpEiR1NtBQUEkJibmaJ/MvPPOO1x//fVs3ryZ5ORkDWml8lBiIsybZ/vHV6+GkiWhd28b6M2bO1eX9qFfpdOnT1OpUiUAZsyY4fXj33DDDezZs4d9+/YB8OmnGS9Of/r0aSpUqECBAgWYOXMmSUlJAHTo0IHp06dz4cIFAGJjYwkJCaFy5crMnz8fgEuXLqU+rpS6Ohcv2v7xh
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAEICAYAAABPgw/pAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA2QklEQVR4nO3dZ3hVVfr38e9NKAFCkQBSEjoEEDEkARSko7RIQBmkOIioFEXsjl3GMo+j/hVRBCPgiAMDWOjNoSM6SgKhhyI1NAGBJBBIW8+LdRIDhBDS9snJ/bkuLs4uZ587O/DLytprry3GGJRSShV+xZwuQCmlVN7QQFdKKQ+hga6UUh5CA10ppTyEBrpSSnkIDXSllPIQGugqUyKyREQezOt9nSQiB0Skaz4c14hIA9frSSLyWnb2zcHnDBaRH3JaZxbH7SgiMXl9XFXwijtdgMo7IhKfYbEMcAlIcS2PMMZMz+6xjDE98mNfT2eMGZkXxxGROsB+oIQxJtl17OlAtr+HqujRQPcgxhiftNcicgB4xBiz/Mr9RKR4WkgopTyHdrkUAWm/UovI30TkOPCliNwkIgtF5KSInHG99svwntUi8ojr9VAR+VFEPnDtu19EeuRw37oislZE4kRkuYhMEJF/X6Pu7NT4loisdx3vBxGpnGH7X0XkoIicFpFXsjg/rUXkuIh4ZVjXV0S2uF63EpGfReSsiBwTkU9FpOQ1jvUvEXk7w/LzrvccFZFhV+zbS0Q2iUisiBwWkbEZNq91/X1WROJF5I60c5vh/W1EZIOInHP93Sa75yYrItLE9f6zIrJdRHpn2NZTRHa4jnlERJ5zra/s+v6cFZE/RGSdiGi+FDA94UVHNaASUBsYjv3ef+largUkAJ9m8f7WwC6gMvAeMEVEJAf7zgB+BXyBscBfs/jM7NQ4CHgIqAqUBNICpikw0XX8Gq7P8yMTxphfgPNA5yuOO8P1OgV42vX13AF0AR7Lom5cNXR31XMX0BC4sv/+PDAEqAj0AkaJSB/XtvauvysaY3yMMT9fcexKwCJgvOtr+xBYJCK+V3wNV52b69RcAlgA/OB63xPAdBEJcO0yBdt9Vw5oBqx0rX8WiAGqADcDLwM6r0gB00AvOlKBN4wxl4wxCcaY08aY74wxF4wxccA7QIcs3n/QGPOFMSYF+Aqojv2Pm+19RaQW0BJ43RiTaIz5EZh/rQ/MZo1fGmN2G2MSgNlAoGt9P2ChMWatMeYS8JrrHFzLf4CBACJSDujpWocxJtIY8z9jTLIx5gDweSZ1ZKa/q75txpjz2B9gGb++1caYrcaYVGPMFtfnZee4YH8A7DHGfO2q6z9ANHBPhn2udW6ycjvgA7zr+h6tBBbiOjdAEtBURMobY84YYzZmWF8dqG2MSTLGrDM6UVSB00AvOk4aYy6mLYhIGRH53NUlEYv9Fb9ixm6HKxxPe2GMueB66XOD+9YA/siwDuDwtQrOZo3HM7y+kKGmGhmP7QrU09f6LGxr/F4RKQXcC2w0xhx01dHI1Z1w3FXHP7Ct9eu5rAbg4BVfX2sRWeXqUjoHjMzmcdOOffCKdQeBmhmWr3VurluzMSbjD7+Mx70P+8PuoIisEZE7XOvfB/YCP4jIPhF5MXtfhspLGuhFx5WtpWeBAKC1MaY8f/6Kf61ulLxwDKgkImUyrPPPYv/c1Hgs47Fdn+l7rZ2NMTuwwdWDy7tbwHbdRAMNXXW8nJMasN1GGc3A/obib4ypAEzKcNzrtW6PYruiMqoFHMlGXdc7rv8V/d/pxzXGbDDGhGG7Y+ZiW/4YY+KMMc8aY+oBvYFnRKRLLmtRN0gDvegqh+2TPuvqj30jvz/Q1eKNAMaKSElX6+6eLN6Smxq/BUJF5E7XBcw3uf6/9xnAk9gfHN9cUUcsEC8ijYFR2axhNjBURJq6fqBcWX857G8sF0WkFfYHSZqT2C6ietc49mKgkYgMEpHiInI/0BTbPZIbv2Bb8y+ISAkR6Yj9Hs10fc8Gi0gFY0wS9pykAohIqIg0cF0rOYe97pBVF5fKBxroRdc4oDRwCvgfsLSAPncw9sLiaeBtYBZ2vHxmxpHDGo0x24HHsSF9DDiDvWiXlbQ+7JXGmFMZ1j+HDds44AtXzdmpYYnra1iJ7Y5YecUujwFvikgc8Dqu1q7rvRew1wzWu0aO3H7FsU8DodjfYk4DLwChV9R9w4wxidgA74E9758BQ4wx0a5d/goccHU9jcR+P8Fe9F0OxAM/A58ZY1blphZ140SvWygnicgsINoYk++/ISjl6bSFrgqUiLQUkfoiUsw1rC8M2xerlMolvVNUFbRqwPfYC5QxwChjzCZnS1LKM2iXi1JKeQjtclFKKQ/hWJdL5cqVTZ06dZz6eKWUKpQiIyNPGWOqZLbNsUCvU6cOERERTn28UkoVSiJy5R3C6bTLRSmlPIQGulJKeQgNdKWU8hAa6Eop5SE00JVSykNooCullIfQQFdKKQ+hc7kopVQ+ir0US0xsDIfPHSYmNoaY2Bh6NepFSI2QPP8sDXSllMqhzML6cOzhy/6OvRR71ftu9rlZA10ppQrKlWGdFtDXC+tqPtXwK+9HgG8AXep2wb+8P37l/fAr74d/BX9qlKtBSa+S+VKzBrpSqsjJSVgLws0+N2ca1v4V7N/5GdbZoYGulPJ4F5Iu8PTSp/nx8I+FNqyzQwNdKeXRziScoffM3qw/tJ7QRqGFNqyzQwNdKeWxjsYdpfu/uxN9KpqZ/WbS/5b+TpeUrzTQlVIeac/pPdz977s5deEUiwcvpmu9rk6XlO800JVSHifyaCQ9pvfAYFj14Kp8GSLojvROUaWUR1m5fyUdv+pImRJlWD9sfZEJc8hmoItIdxHZJSJ7ReTFTLZ/JCJRrj+7ReRsnleqlFLX8e2Ob+kxvQe1K9Rm/bD1NPJt5HRJBeq6XS4i4gVMAO4CYoANIjLfGLMjbR9jzNMZ9n8CaJEPtSql1DVNipjEY4se4w7/O1gwcAGVSldyuqQCl50WeitgrzFmnzEmEZgJhGWx/0DgP3lRnFJKXY8xhrfWvMWoRaPo2bAn//3rf4tkmEP2Ar0mcDjDcoxr3VVEpDZQF1h5je3DRSRCRCJOnjx5o7WmO3kSUlNz/HallIdINak8seQJXl/9OkNuG8Kc++dQpkQZp8tyTF5fFB0AfGuMSclsozEm3BgTYowJqVKlSo4/5P77oWlTCA+HhIQcH0YpVYglpiQy6LtBTNgwgWfveJYvw76khFcJp8tyVHYC/Qjgn2HZz7UuMwPI5+4WY+DRR6FsWRgxAmrXhr//3bbalVJFQ3xiPKEzQpm1fRbvdX2PD+7+gGKig/aycwY2AA1FpK6IlMSG9vwrdxKRxsBNwM95W+KVnwMDB0JEBKxaBa1bw9ix8PXX+fmpSil3cerCKTp/1ZkV+1cwtfdUnm/7vNMluY3rBroxJhkYDSwDdgKzjTHbReRNEemdYdcBwExjjMmfUq0zCWfYf2Y/ItCxIyxYADt2wCOP2O3TpkHv3rB2rW3NK6U8x
6Fzh7hz6p1s/X0rc+6fw0MtHnK6JLeSrd9RjDGLjTGNjDH1jTHvuNa9boyZn2GfscaYq8ao57WJERNp8EkD+s7qy6r9qzDG0KQJlC9vt1+6BD/9BB06QKtWMGsWJCfnd1VKqfy2/ffttJnShuPxx/nhgR/oHdD7+m8qYgpdp9OQ24bwYtsXWXdwHZ2ndab5pOZ8EfkFF5IuALZ//dAhmDgRzp2DAQOgTx9na1ZK5c7Ph3+m3ZftSDEprH1oLe1qt3O6JLck+dxDck0hISEmIiIix+9PSErgP9v+w/hfxrP5xGYqla7Eo0GP8ljLx6hVoRZghzYuWADe3tCtmw34d9+Fxx8HP7+8+kqUUvlpyZ4l3Df7PmqWr8kPD/xA3ZvqOl2So0Qk0hiT6XwGhTbQ0xhjWHdoHeN/Gc+c6DkA9G3clzGtx9CuVjtEJH3f+fOhb18oVsy23J99FgIDc12CUiqf/HvLv3lo3kPcWvVWlgxews0+NztdkuM8OtAzOnj2IJ9t+IwvNn7BmYtnC
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_training_analysis()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ochTgkyqwHIe"
},
"source": [
"### Correction du surapprentissage"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zXb2ZxKv4gpi"
},
"source": [
"Vous devriez reconnaître le problème auquel vous avez affaire : **le surapprentissage**. Ce problème est classique dès lors que l'on travaille sur des bases de données de taille réduite en apprentissage profond.\n",
" En effet, le réseau que vous avez créé compte normalement (si vous avez suivi les indications) plus de trois millions de paramètres. Le problème que vous essayez de résoudre pendant l'entraînement consiste à établir 450 000 paramètres avec seulement 2000 exemples : c'est trop peu !\n",
"\n",
"Afin de limiter ce surapprentissage, nous pouvons appliquer les techniques de régularisation vues pendant le 2nd cours. En traitement d'image, une des techniques les plus couramment utilisées est **l'augmentation de la base de données**.\n",
"\n",
"Nous allons introduire un objet *ImageDataGenerator* pour appliquer des transformations supplémentaires aux images de notre base de données. A vous de chercher dans la documentation à quoi correspondent les différents paramètres présentés ci-dessous."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "90Wlyt6Gwm6v",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"from tensorflow.keras.preprocessing.image import ImageDataGenerator\n",
"\n",
"train_datagen = ImageDataGenerator(\n",
" rotation_range=40,\n",
" width_shift_range=0.2,\n",
" height_shift_range=0.2,\n",
" shear_range=0.2,\n",
" zoom_range=0.2,\n",
" horizontal_flip=True)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TkZEqml-4ccl"
},
"source": [
"La cellule suivante vous permet de visualiser des images passées à travers notre boucle d'augmentation de données. Observez comment les valeurs manquantes (par exemple, dans le cas d'une rotation) sont comblées."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"id": "nqDBaNs94ccm",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAEICAYAAABf40E1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9W6ht7dbnB/3ac+i9jzHmnOvw7ne/tasqlUTjpReKxIsoBA8gGs2NRI2GCJG6EgyaowRU8BBvogFB+SBClECpJBDFgKghSG40VhRECyWEJKas5Dvs/a615hxj9P4cmhetPX2Mufa7vvpSX+2qFb7V957vXGuuMceh9/60p7X//9/+TVSVb8e349vxR/cIf7XfwLfj2/Ht+Kt7fAsC345vxx/x41sQ+HZ8O/6IH9+CwLfj2/FH/PgWBL4d344/4se3IPDt+Hb8ET++BYFvx7fjj/jxLQh8O/6yHSLyr4jIf+Sv9vv4dvxbO74FgW/Ht+OP+PEtCHw7fvIQkb9GRP4pEfkdEfk9Efkfici/U0T+Wf/774rIPyEib/3x/3PgTwH/GxF5FpG/76/qB/h2/IEP+SYb/nZ8fohIBP5F4J8F/iGgAf8+4N8A/nrg/wQ8Af8k8C+q6t/tv/evAP8lVf0//JV/19+Ov9Qj/dV+A9+Or/L4G4E/Dvy9qlr9Z/+8f/+X/PvviMg/Avw3/kq/uW/HX97jWxD4dvzU8dcA/+pdAABARH4A/lHgPwg8YuXkr/7Kv71vx1/O4xsm8O34qeP/C/wpEfl8k/jvAgr8u1X1CfgvAHL3799qy38bHt+CwLfjp47/C/AXgH9YRE4isojI34Tt/s/ABxH5E8Df+9nv/ZvAv+Ov7Fv9dvxhj29B4Nvxa4eqNuA/CfwNwL8G/OvAfwb4bwH/XuAD8L8F/qnPfvW/B/xDIvKjiPw9f+Xe8bfjD3N8Ywe+Hd+OP+LHt0zg2/Ht+CN+fAsC345vxx/x4zcWBETkPyYi/28R+ZdE5B/4Tb3Ot+Pb8e34wx2/EUzAFWf/H+A/ioFK/wLwn1PV/9df9hf7dnw7vh1/qOM3JRb6G4F/SVX/ZQAR+TPA3wr8ZBAIcdKYFoJ0gigC9l0UcRZaAO1qRLQqFrsGRW3f/V8Z/6SqdIXmv6cq2LMLlgQJguy/J3L/Xe3/Ov7sz70HzfHeBBEQEX/fAvZ/eu+oKp8HWhF/fQHxZEz9/e6vOT7E/lh7v33/3P5vIqDNP3W3fxGF3hBRgnRiEEIUgsj+PseJfUXy6+0Pur9PEAl07f552H/26vcUJNhrxJgIQZAQuL2UvXYI4s8r+7m9fV67rvt7ErtCyOfn/vbGf/1T3J7n9fn2izJe5+5z3v3n7vv9s332WkH8PMtnj/v193G7h8Yllf1c8Ore0tfn4/4d+TXX/c+6/4OqouM+653un6krtNZpre1rSJXfVdXvP/98v6kg8Ccwwck4/nXg33//ABH508CfBgjxwPtf/AeYwoU5FXKoHKeNFDtL7kTpRFG0NrR3Wun0Br2B+mLufpI60EVpKEU7pSsva2WtsFbYaqRpBI4gCSRji0yQ0BDphNgQbUCDVqA3et3sZPeG9gbqASsIKQg5BWIMzDnZQgjCtq1s60pvtnjEF72EADIjISJhAgmoBGqttNboDVBBZAKyvcd4opO41oBKRGOGFC2WbZ8QKklWUqgkKYT6I0kKb06V0zHzeJpY5okYwx4Mgtz+bO9Raa1bsFUl5myLXSLXdePlfKX5DTlNC4LYdagd7XA8Hpnnmbfv3jLNE3maEF/0KSXylFmWhRgjQYRaC701ai1o62jv0DuoBdMQAilGOhWl07T4vXNbgCEEC2wi+8Jv7RZ8QwiEEAkhIhJQD2atNvpYNNo9INyCKcK+mMB+HIJdu5SzfY+J2wYkvxY+VJXeO2Urdl8SCHmCkDwINKgr9Aq90stG10Ztla5KR2ka6SrUbu+l97FJdNq20WphPb+wrRtlXbmUxlY7H66Fy+XK88sLc4rEIHx6Wf/Vn1qsf9Vkw6r6W8BvAaTlZ1rkDVuJlLaS40ZXIcdK74UchSl15jlYdnAt9DqCQUNVCGoLKRDoKMF36hxhOgXW0jiXxvNlY2tQe0c1obogsiBktEVUAr0HQozE0JAUQBsxBHqraKtoE7Q3Wm9IU3pXWldCaPSuhCAEhN6hk+jSEYEQEr1B2RQQVALQbWcJUFukdSHGjISEhJlOQkk0JjoRDQFCsACgFVpDKAQaMXQiG9JXIitTapyOieMhsSwTyRcM2A6bYvTA1v2GtZvWduxICAkFrmth3SqlNWJISIjEEO1z12prRgOHw4HlsHA4HAgxIGFkEkKeMzllYorgi6O1aov/swAAECQQQ0Ci0Fuju4L5fvMdi398JvVdd3yJBETCHgDAsrPeO623/fEWAGyZ+sm57bp691ohEGLYg8HvFwDuX6vWCjEhOVsACBHaZkFAG3S7r3pvlnGpvxsV/xqL//Z+6UprhVY26nWllkKpja3DtSnPzy+01khBmHMip8inl/Un1+JvKgj8eUx/Po4/6T/7wiFIPkJIVL3QdUVLILdK18AUG7U3ujZiwHaSEAhxBIK+72RdFUGIYDusp2T4DhhEKM2yg9o7tSmqja7Zd94ARLQJtQshWOobothNFQJdAtI9S/CdpHV/bel+Awuop/HBMoGmga6BTqBrQnuwm4OISEb9vmohIQSaZJSIEuka7WZTXyjaQAuijSCNKI0cOtQVbWcOh84yBQ5LZvLsZKygkQGol1UjAIy0XEKwGxWhd+W6rtTWASGmTIzRdtumaFOmPJFyZl5mpnlCYgAPADFGYowWAGK07EFt0fex+MdOrErwVDnGAKK0XnyB3gLA/eL/UgAA9izAMgBb0WNhqu6/5M8/Vrveno/71xsBIHr24WXc7TdfHZZV2aYgMdl1DtHChnZf/PY1sksLAJ2OeukaLADoeB17v7YZNeq2UbdCqZXSLOu9bpVrqdRaEZQUAzklpvTlpf6bCgL/AvDvEpG/Hlv8/1ngb//io0UgzEjI9CL0HtDeqL0g9D0aCp0UYZkiAdtx7eL69w6iDcEWTFQrEyIdkUBMlh2UrtRWCNVuwqod1WoXVm1RqgbU6/tAh+gpqAhBQUUQVS8NdF9QtSohKMqowcN+E/ceUCKQ6BpQDUBCSIhM9GAprUryOz6hYr+j6umu2qKREQioCI2ABQPVgvYrUxaWWZhyJnn6PTbREAwLuX/f++IZnzEES4e7sm11LMN9YdVqNzgKMSXm2QJAzvnVAg3RFk+M0bIDMWyn94Zqt69+W7ji5y2EQKfStd2BFT8dAOB1EODuve54gNfavdvr2d4wlvAANfy5fuL+tLIm7N8/x6Puj/E69qUQEoRogdVT+Vsg6Hs2NrIAC0CWBdxt/DcooHd6q7RS7Kt1aleqwloa21Y9CwikFEkxkuKXicDfSBBQ1Soi/2XgfwdE4H+qqv/PL/+CQVpBomUEaaHXma4bH7dPTHVjihulbeRYqa2SYmdKQp4DeVHqWuhVq
aV7VAUIRAKBSEfp2pEgTKLkR2GrynltnLeNrW40bSiJrgsiE4GMNqGJWokQOiEkJEdEOxKivZaXCWj36B0QJtSBx9o62oUgC0EykieiRsMxYgSv8XsfUd8CkNXj/qWKaEdrIyYlT0qvBW0bUVeCFnS7MMWVaYKHY2aZI9OUiPcLB8MCgNdlQLM6OMTkwF7i+XxhK40OxBCJyfCT3pRWO0JgmmYOxwPH49Hq/RR3HACBaZpIOZFSApRWG61VK+OaBwFVRC0ARC8jOpb5dW2Il3Y3oPLXs4B+91liDF4KxP0G673ti1JV9131VRkwwoLe4o6VE14eegZ6KwN+PWCMAGBZgC1mSRlCtOvYt70EoFW0FXqvNG23MoCwb3xtxw11x6Ra2ahlo1xXSq1stbF2ZW3w6eVM2TZigClHTstMCuGn0xU/fmOYgKr+M8A/8wd6LAbm+W1gN3+Y0C50qVQCdCE0pasQxL5b6t0d7A9IUqKqAWvN0nO7wLZ7gxA9K54kILHTsngar6zVARnPy
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"example_x, example_y = train_datagen.flow(x_train, y_train, batch_size=1).next()\n",
"for i in range(0,1):\n",
" plt.imshow(example_x[i])\n",
" plt.title(CLASSES[int(example_y[i])])\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KdqsPbKm5TkR"
},
"source": [
"Nous pouvons maintenant recréer notre modèle et relancer l'entraînement."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "VJLABQ67yLms",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_3\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" conv2d_4 (Conv2D) (None, 126, 126, 32) 896 \n",
" \n",
" max_pooling2d_4 (MaxPooling (None, 63, 63, 32) 0 \n",
" 2D) \n",
" \n",
" conv2d_5 (Conv2D) (None, 61, 61, 64) 18496 \n",
" \n",
" max_pooling2d_5 (MaxPooling (None, 30, 30, 64) 0 \n",
" 2D) \n",
" \n",
" conv2d_6 (Conv2D) (None, 28, 28, 96) 55392 \n",
" \n",
" max_pooling2d_6 (MaxPooling (None, 14, 14, 96) 0 \n",
" 2D) \n",
" \n",
" conv2d_7 (Conv2D) (None, 12, 12, 128) 110720 \n",
" \n",
" max_pooling2d_7 (MaxPooling (None, 6, 6, 128) 0 \n",
" 2D) \n",
" \n",
" flatten_1 (Flatten) (None, 4608) 0 \n",
" \n",
" dense_6 (Dense) (None, 512) 2359808 \n",
" \n",
" dense_7 (Dense) (None, 1) 513 \n",
" \n",
"=================================================================\n",
"Total params: 2,545,825\n",
"Trainable params: 2,545,825\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"Epoch 1/50\n",
"200/200 [==============================] - 16s 78ms/step - loss: 0.6924 - accuracy: 0.5285 - val_loss: 0.6877 - val_accuracy: 0.5130\n",
"Epoch 2/50\n",
"200/200 [==============================] - 14s 72ms/step - loss: 0.6874 - accuracy: 0.5385 - val_loss: 0.6829 - val_accuracy: 0.5470\n",
"Epoch 3/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.6789 - accuracy: 0.5710 - val_loss: 0.6462 - val_accuracy: 0.6600\n",
"Epoch 4/50\n",
"200/200 [==============================] - 14s 72ms/step - loss: 0.6515 - accuracy: 0.6275 - val_loss: 0.6392 - val_accuracy: 0.6290\n",
"Epoch 5/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.6236 - accuracy: 0.6670 - val_loss: 0.6177 - val_accuracy: 0.6560\n",
"Epoch 6/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.5949 - accuracy: 0.6830 - val_loss: 0.6098 - val_accuracy: 0.6720\n",
"Epoch 7/50\n",
"200/200 [==============================] - 14s 70ms/step - loss: 0.5918 - accuracy: 0.6880 - val_loss: 0.5917 - val_accuracy: 0.6870\n",
"Epoch 8/50\n",
"200/200 [==============================] - 14s 70ms/step - loss: 0.5752 - accuracy: 0.7030 - val_loss: 0.5890 - val_accuracy: 0.6820\n",
"Epoch 9/50\n",
"200/200 [==============================] - 14s 70ms/step - loss: 0.5696 - accuracy: 0.7115 - val_loss: 0.5727 - val_accuracy: 0.7010\n",
"Epoch 10/50\n",
"200/200 [==============================] - 14s 70ms/step - loss: 0.5652 - accuracy: 0.7115 - val_loss: 0.5878 - val_accuracy: 0.6940\n",
"Epoch 11/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.5535 - accuracy: 0.7175 - val_loss: 0.5721 - val_accuracy: 0.7160\n",
"Epoch 12/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.5465 - accuracy: 0.7165 - val_loss: 0.5863 - val_accuracy: 0.7080\n",
"Epoch 13/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.5432 - accuracy: 0.7300 - val_loss: 0.5491 - val_accuracy: 0.7330\n",
"Epoch 14/50\n",
"200/200 [==============================] - 14s 68ms/step - loss: 0.5356 - accuracy: 0.7255 - val_loss: 0.5415 - val_accuracy: 0.7170\n",
"Epoch 15/50\n",
"200/200 [==============================] - 14s 69ms/step - loss: 0.5385 - accuracy: 0.7225 - val_loss: 0.5958 - val_accuracy: 0.7020\n",
"Epoch 16/50\n",
"200/200 [==============================] - 14s 69ms/step - loss: 0.5150 - accuracy: 0.7455 - val_loss: 0.5447 - val_accuracy: 0.7210\n",
"Epoch 17/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.5168 - accuracy: 0.7405 - val_loss: 0.5414 - val_accuracy: 0.7290\n",
"Epoch 18/50\n",
"200/200 [==============================] - 14s 69ms/step - loss: 0.5131 - accuracy: 0.7465 - val_loss: 0.5588 - val_accuracy: 0.7350\n",
"Epoch 19/50\n",
"200/200 [==============================] - 14s 70ms/step - loss: 0.5034 - accuracy: 0.7560 - val_loss: 0.5406 - val_accuracy: 0.7500\n",
"Epoch 20/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.4934 - accuracy: 0.7605 - val_loss: 0.5458 - val_accuracy: 0.7630\n",
"Epoch 21/50\n",
"200/200 [==============================] - 13s 67ms/step - loss: 0.4879 - accuracy: 0.7605 - val_loss: 0.5965 - val_accuracy: 0.7190\n",
"Epoch 22/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.4927 - accuracy: 0.7530 - val_loss: 0.5636 - val_accuracy: 0.7440\n",
"Epoch 23/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.4876 - accuracy: 0.7760 - val_loss: 0.5400 - val_accuracy: 0.7260\n",
"Epoch 24/50\n",
"200/200 [==============================] - 14s 70ms/step - loss: 0.4752 - accuracy: 0.7800 - val_loss: 0.5035 - val_accuracy: 0.7510\n",
"Epoch 25/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.4677 - accuracy: 0.7880 - val_loss: 0.5696 - val_accuracy: 0.7060\n",
"Epoch 26/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.4773 - accuracy: 0.7675 - val_loss: 0.4967 - val_accuracy: 0.7710\n",
"Epoch 27/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.4535 - accuracy: 0.7915 - val_loss: 0.4772 - val_accuracy: 0.7650\n",
"Epoch 28/50\n",
"200/200 [==============================] - 14s 69ms/step - loss: 0.4499 - accuracy: 0.7885 - val_loss: 0.5210 - val_accuracy: 0.7610\n",
"Epoch 29/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.4423 - accuracy: 0.7965 - val_loss: 0.4957 - val_accuracy: 0.7740\n",
"Epoch 30/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.4485 - accuracy: 0.7820 - val_loss: 0.4817 - val_accuracy: 0.7820\n",
"Epoch 31/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.4370 - accuracy: 0.8020 - val_loss: 0.4877 - val_accuracy: 0.7730\n",
"Epoch 32/50\n",
"200/200 [==============================] - 14s 72ms/step - loss: 0.4311 - accuracy: 0.8090 - val_loss: 0.4957 - val_accuracy: 0.7600\n",
"Epoch 33/50\n",
"200/200 [==============================] - 15s 72ms/step - loss: 0.4318 - accuracy: 0.7990 - val_loss: 0.5394 - val_accuracy: 0.7420\n",
"Epoch 34/50\n",
"200/200 [==============================] - 14s 69ms/step - loss: 0.4159 - accuracy: 0.8105 - val_loss: 0.5039 - val_accuracy: 0.7820\n",
"Epoch 35/50\n",
"200/200 [==============================] - 14s 69ms/step - loss: 0.4236 - accuracy: 0.7985 - val_loss: 0.4754 - val_accuracy: 0.8000\n",
"Epoch 36/50\n",
"200/200 [==============================] - 14s 68ms/step - loss: 0.4060 - accuracy: 0.8255 - val_loss: 0.4881 - val_accuracy: 0.7840\n",
"Epoch 37/50\n",
"200/200 [==============================] - 14s 69ms/step - loss: 0.4006 - accuracy: 0.8165 - val_loss: 0.4649 - val_accuracy: 0.7770\n",
"Epoch 38/50\n",
"200/200 [==============================] - 14s 72ms/step - loss: 0.3998 - accuracy: 0.8245 - val_loss: 0.4597 - val_accuracy: 0.7870\n",
"Epoch 39/50\n",
"200/200 [==============================] - 14s 70ms/step - loss: 0.4053 - accuracy: 0.8120 - val_loss: 0.4789 - val_accuracy: 0.7800\n",
"Epoch 40/50\n",
"200/200 [==============================] - 14s 70ms/step - loss: 0.4107 - accuracy: 0.8020 - val_loss: 0.5138 - val_accuracy: 0.7660\n",
"Epoch 41/50\n",
"200/200 [==============================] - 14s 70ms/step - loss: 0.3770 - accuracy: 0.8270 - val_loss: 0.4722 - val_accuracy: 0.7890\n",
"Epoch 42/50\n",
"200/200 [==============================] - 14s 69ms/step - loss: 0.3939 - accuracy: 0.8210 - val_loss: 0.4114 - val_accuracy: 0.8150\n",
"Epoch 43/50\n",
"200/200 [==============================] - 14s 69ms/step - loss: 0.3666 - accuracy: 0.8410 - val_loss: 0.4861 - val_accuracy: 0.7800\n",
"Epoch 44/50\n",
"200/200 [==============================] - 14s 69ms/step - loss: 0.3897 - accuracy: 0.8235 - val_loss: 0.4669 - val_accuracy: 0.7810\n",
"Epoch 45/50\n",
"200/200 [==============================] - 14s 68ms/step - loss: 0.3590 - accuracy: 0.8490 - val_loss: 0.4122 - val_accuracy: 0.8120\n",
"Epoch 46/50\n",
"200/200 [==============================] - 14s 71ms/step - loss: 0.3679 - accuracy: 0.8285 - val_loss: 0.4504 - val_accuracy: 0.7910\n",
"Epoch 47/50\n",
"200/200 [==============================] - 14s 70ms/step - loss: 0.3470 - accuracy: 0.8430 - val_loss: 0.4709 - val_accuracy: 0.8110\n",
"Epoch 48/50\n",
"200/200 [==============================] - 14s 70ms/step - loss: 0.3471 - accuracy: 0.8495 - val_loss: 0.4283 - val_accuracy: 0.8070\n",
"Epoch 49/50\n",
"200/200 [==============================] - 14s 70ms/step - loss: 0.3461 - accuracy: 0.8595 - val_loss: 0.4441 - val_accuracy: 0.7980\n",
"Epoch 50/50\n",
"200/200 [==============================] - 14s 69ms/step - loss: 0.3244 - accuracy: 0.8595 - val_loss: 0.4269 - val_accuracy: 0.8070\n"
]
}
],
"source": [
"model = Sequential()\n",
"model.add(Conv2D(32, 3, activation=\"relu\", input_shape=x_train.shape[1:]))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Conv2D(64, 3, activation=\"relu\"))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Conv2D(96, 3, activation=\"relu\"))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Conv2D(128, 3, activation=\"relu\"))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Flatten())\n",
"model.add(Dense(512, activation=\"relu\"))\n",
"model.add(Dense(1, activation=\"sigmoid\"))\n",
"model.summary()\n",
"\n",
"adam = optimizers.Adam(learning_rate=3e-4)\n",
"model.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])\n",
"\n",
"history = model.fit(train_datagen.flow(x_train, y_train, batch_size=10), validation_data=(x_val, y_val), epochs=50)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vM2CLNX8wbfv"
},
"source": [
"### Analyse des résultats"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"id": "9smZiILLyt8g",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAEICAYAAABRSj9aAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAABOuklEQVR4nO3deVhU1RvA8e8Lihvkirkvlbvmhlu5ZpaamVmZmKlZaZqVtvurFNc2K8ultNRWNXMvNZfUXFITtwyX3FBRVNwFFYF5f3/cAQcYYNgEhvN5nnmYuffce98Lwztnzj33HFFVDMMwDPflkdUBGIZhGJnLJHrDMAw3ZxK9YRiGmzOJ3jAMw82ZRG8YhuHmTKI3DMNwcybR50IiskxEemd02awkIsEicn8m7FdF5C77869E5D1XyqbhOE+JyIq0xmkYyRHTjz5nEJFwh5cFgUggxv66v6r+dOujyj5EJBh4TlVXZfB+FaiiqgczqqyIVAKOAHlVNTpDAjWMZOTJ6gAM16iqd+zz5JKaiOQxycPILsz7MXswTTc5nIi0FpEQEXlLRE4BM0SkqIj8JiJhInLB/rycwzZrReQ5+/M+IrJBRMbZyx4RkQ5pLFtZRNaJyBURWSUik0TkxyTidiXGUSKy0b6/FSJSwmH90yJyVETOicg7yfx+mojIKRHxdFj2qIj8Y3/eWEQ2ichFEQkVkYki4pXEvr4VkdEOr9+wb3NSRPomKPuQiOwQkcsiclxEAhxWr7P/vCgi4SLSLPZ367D9PSKyVUQu2X/e4+rvJpW/52IiMsN+DhdEZKHDukdEZKf9HA6JSHv78njNZCISEPt3FpFK9iasZ0XkGLDavvwX+9/hkv09Usth+wIi8on973nJ/h4rICJLROSlBOfzj4g86uxcjaSZRO8eSgHFgIpAP6y/6wz76wrANWBiMts3AfYDJYCPgGkiImkoOxP4GygOBABPJ3NMV2LsATwDlAS8gNcBRKQm8KV9/2XsxyuHE6q6BYgA7kuw35n25zHAEPv5NAPaAgOTiRt7DO3t8bQDqgAJrw9EAL2AIsBDwAAR6WJf19L+s4iqeqvqpgT7LgYsAb6wn9unwBIRKZ7gHBL9bpxI6ff8A1ZTYC37vj6zx9AY+B54w34OLYHgJI7hTCugBvCg/fUyrN9TSWA74NjUOA5oCNyD9T5+E7AB3wE9YwuJSF2gLNbvxkgNVTWPHPbA+oe73/68NXADyJ9M+XrABYfXa7GafgD6AAcd1hUEFCiVmrJYSSQaKOiw/kfgRxfPyVmM7zq8Hgj8bn8+DJjtsK6Q/XdwfxL7Hg1Mtz/3wUrCFZMoOxhY4PBagbvsz78FRtufTwc+cChX1bGsk/2OBz6zP69kL5vHYX0fYIP9+dPA3wm23wT0Sel3k5rfM1AaK6EWdVJuSmy8yb3/7K8DYv/ODud2RzIxFLGXKYz1QXQNqOukXH7gAtZ1D7A+ECZnxv+Uuz9Mjd49hKnq9dgXIlJQRKbYvwpfxmoqKOLYfJHAqdgnqnrV/tQ7lWXLAOcdlgEcTypgF2M85fD8qkNMZRz3raoRwLmkjoVVe+8qIvmArsB2VT1qj6OqvTnjlD2OsVi1+5TEiwE4muD8mojIGnuTySXgBRf3G7vvowmWHcWqzcZK6ncTTwq/5/JYf7MLTjYtDxxyMV5n4n43IuIpIh/Ym38uc/ObQQn7I7+zY9nf0z8DPUXEA/DH+gZipJJJ9O4hYdep14BqQBNVvY2bTQVJNcdkhFCgmIgUdFhWPpny6Ykx1HHf9mMWT6qwqu7BSpQdiN9sA1YT0D6sWuNtwP/SEgPWNxpHM4HFQHlVLQx85bDflLq6ncRqanFUATjhQlwJJfd7Po71NyviZLvjwJ1J7DMC69tcrFJOyjieYw/gEazmrcJYtf7YGM4C15M51nfAU1hNalc1QTOX4RqT6N2TD9bX4Yv29t7hmX1Aew05EAgQES8RaQY8nEkxzgU6iUhz+4XTkaT8Xp4JvIKV6H5JEMdlIFxEqgMDXIxhDtBHRGraP2gSxu+DVVu+bm/v7uGwLgyryeSOJPa9FKgqIj1EJI+IPAnUBH5zMbaEcTj9PatqKFbb+WT7Rdu8IhL7QTANeEZE2oqIh4iUtf9+AHYC3e3l/YDHXYghEutbV0Gsb02xMdiwmsE+FZEy9tp/M/u3L+yJ3QZ8gqnNp5lJ9O5pPFAAq7a0Gfj9Fh33KawLmuew2sV/xvoHd2Y8aYxRVYOAF7GSdyhWO25ICpvNwrpAuFpVzzosfx0rCV8BvrbH7EoMy+znsBo4aP/paCAwUkSuYF1TmOOw7VVgDLBRrN4+TRPs+xzQCas2fg7r4mSnBHG7ajzJ/56fBqKwvtWcwbpGgar+jXWx9zPgEvAnN79lvIdVA78AjCD+NyRnvsf6RnUC2GOPw9HrwG5gK3Ae+JD4uel7oA7WNR8jDcwNU0amEZGfgX2qmunfKAz3JSK9gH6q2jyrY8mpTI3eyDAi0khE7rR/1W+P1S67MIvDMnIwe7PYQGBqVseSk5lEb2SkUlhd/8Kx+oAPUNUdWRqRkWOJyINY1zNOk3LzkJEM03RjGIbh5kyN3jAMw81lu0HNSpQooZUqVcrqMAzDMHKUbdu2nVVVX2frsl2ir1SpEoGBgVkdhmEYRo4iIgnvpo5jmm4MwzDcnEn0hmEYbs4kesMwDDeX7dronYmKiiIkJITr16+nXNjIEvnz56dcuXLkzZs3q0MxDCOBHJHoQ0JC8PHxoVKlSiQ9H4aRVVSVc+fOERISQuXKlbM6HMMwEsgRTTfXr1+nePHiJslnUyJC8eLFzTcuw8imckSiB0ySz+bM38cwsq8c0XRjGIbhDlRh1izYty/+8tat4b77nG6SIUyid8G5c+do27YtAKdOncLT0xNfX+sGtL///hsvL68ktw0MDOT777/niy++SPYY99xzD3/99VfGBW0YRrZz8iQMGACXL4Pjl2APD5Pos1zx4sXZuXMnAAEBAXh7e/P666/HrY+OjiZPHue/Sj8/P/z8/FI8hknyhuG+YmLA0xPKloXAQLjzTiu53yo5po0+u+nTpw8vvPACTZo04c033+Tvv/+mWbNm1K9fn3vuuYf9+/cDsHbtWjp16gRYHxJ9+/aldevW3HHHHfFq+d7e3nHlW7duzeOPP0716tV56qmniB1hdOnSpVSvXp2GDRvy8ssvx+3XUXBwMC1atKBBgwY0aNAg3gfIhx9+SJ06dahbty5vv/02AAcPHuT++++nbt26NGjQgEOH0jMftGEYCV29Cg8/DB9/bL2uUuXWJnlwsUZvn0Tic8AT+EZVP0iwvgLWJL5F7GXeVtWlIlIJ2AvstxfdrKovpDfo1q0TL+vWDQYOtH6pHTsmXt+nj/U4exYeTzDD5dq1aYsjJCSEv/76C09PTy5fvsz69evJkycPq1at4n//+x/z5s1LtM2+fftYs2YNV65coVq1agwYMCBR3/MdO3YQFBREmTJluPfee9m4cSN+fn7079+fdevWUblyZfz9/Z3GVLJkS
VauXEn+/Pk5cOAA/v7+BAYGsmzZMhYtWsSWLVsoWLAg58+fB+Cpp57i7bff5tFHH+X69evYbLa0/TIMw0jk/HkryW/eDF26ZF0cKSZ6EfEEJgHtsObl3Coii1V1j0Oxd4E5qvqliNTEmty4kn3dIVWtl6FRZxNPPPEEnp6eAFy6dInevXtz4MABRISoqCin2zz00EPky5ePfPnyUbJkSU6fPk25cuXilWncuHHcsnr16hEcHIy3tzd33HFHXD91f39/pk5NPOlOVFQUgwYNYufOnXh6evLff/8BsGrVKp555hkKFiwIQLFixbhy5QonTpzg0UcfBaybngzDyBgnTkD79vDffzBnDjz2WNbF4kqNvjFwUFUPA4jIbKwp4hwTvQK32Z8XBk5mZJAJJVcDL1gw+fUlSqS9Bp9QoUKF4p6/9957tGnThgULFhAcHExrZ187gHz58sU99/T0JDo6Ok1lkvLZZ59x++23s2vXLmw2m0nehpEKa9ZA3rzQPJ2z00ZHQ9u2VrJftixzL
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAEICAYAAABRSj9aAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAABTsElEQVR4nO3dd3gVRffA8e9JgdAhJPTQS+gJhABBiqBUpQiigAKvPxCwYAX1VRHF3hsWELHyUgVBQZogCEpHaQkloYROIHRSz++PvTek56belPk8Tx7unZ3dPZuEcyezszOiqhiGYRiFl4uzAzAMwzByl0n0hmEYhZxJ9IZhGIWcSfSGYRiFnEn0hmEYhZxJ9IZhGIWcSfRGpojIMhEZkdN1nUlEDovIbblwXBWR+rbXX4jIi47UzcJ5honIiqzGmc5xu4hIeE4f18h7bs4OwMh9InIl0duSQBQQZ3s/RlV/dPRYqtorN+oWdqo6NieOIyK1gTDAXVVjbcf+EXD4Z2gUPSbRFwGqWtr+WkQOA6NUdVXyeiLiZk8ehmEUHqbrpgiz/2kuIs+IyClgpohUEJFfROSsiFywva6RaJ+1IjLK9nqkiPwpIu/a6oaJSK8s1q0jIutE5LKIrBKRqSLyQxpxOxLjFBHZYDveChHxSrT9fhE5IiIRIvJ8Ot+ftiJySkRcE5UNEJF/ba8DReQvEYkUkZMi8qmIFEvjWN+IyKuJ3k+w7XNCRB5IVrePiOwQkUsickxEJifavM72b6SIXBGR9vbvbaL9g0Rki4hctP0b5Oj3Jj0i0ti2f6SI7BGRvom29RaRvbZjHheRp23lXrafT6SInBeR9SJi8k4eM99wowrgCdQCHsT6nZhpe18TuA58ms7+bYEQwAt4G5ghIpKFurOAzUBFYDJwfzrndCTGocB/gEpAMcCeeJoAn9uOX812vhqkQlU3AVeBrsmOO8v2Og54wnY97YFuwEPpxI0thp62eG4HGgDJ7w9cBYYD5YE+wDgR6W/b1sn2b3lVLa2qfyU7tifwK/Cx7dreB34VkYrJriHF9yaDmN2BJcAK236PAj+KSCNblRlY3YBlgGbA77byp4BwwBuoDPwXMPOu5DGT6I144CVVjVLV66oaoaoLVPWaql4GXgM6p7P/EVWdrqpxwLdAVaz/0A7XFZGaQBtgkqpGq+qfwOK0TuhgjDNVdb+qXgfmAn628kHAL6q6TlWjgBdt34O0/A8YAiAiZYDetjJUdZuq/q2qsap6GPgylThSM9gW325VvYr1wZb4+taq6i5VjVfVf23nc+S4YH0wHFDV721x/Q8IBu5MVCet70162gGlgTdtP6PfgV+wfW+AGKCJiJRV1Ququj1ReVWglqrGqOp6NRNs5TmT6I2zqnrD/kZESorIl7aujUtYXQXlE3dfJHPK/kJVr9lels5k3WrA+URlAMfSCtjBGE8len0tUUzVEh/blmgj0joXVuv9LhEpDtwFbFfVI7Y4Gtq6JU7Z4ngdq3WfkSQxAEeSXV9bEVlj65q6CIx18Lj2Yx9JVnYEqJ7ofVrfmwxjVtXEH4qJjzsQ60PwiIj8ISLtbeXvAAeBFSISKiLPOnYZRk4yid5I3rp6CmgEtFXVstzsKkirOyYnnAQ8RaRkojKfdOpnJ8aTiY9tO2fFtCqr6l6shNaLpN02YHUBBQMNbHH8NysxYHU/JTYL6y8aH1UtB3yR6LgZtYZPYHVpJVYTOO5AXBkd1ydZ/3rCcVV1i6r2w+rWWYT1lwKqellVn1LVukBf4EkR6ZbNWIxMMoneSK4MVp93pK2/96XcPqGthbwVmCwixWytwTvT2SU7Mc4H7hCRW2w3Tl8h4/8Hs4DHsD5Q5iWL4xJwRUR8gXEOxjAXGCkiTWwfNMnjL4P1F84NEQnE+oCxO4vV1VQ3jWMvBRqKyFARcRORe4AmWN0s2bEJq/U/UUTcRaQL1s9otu1nNkxEyqlqDNb3JB5ARO4Qkfq2ezEXse5rpNdVZuQCk+iN5D4ESgDngL+B3/LovMOwbmhGAK8Cc7DG+6fmQ7IYo6ruAR7GSt4ngQtYNwvTY+8j/11VzyUqfxorCV8GpttidiSGZbZr+B2rW+P3ZFUeAl4RkcvAJGytY9u+17DuSWywjWRpl+zYEcAdWH/1RAATgTuSxZ1pqhqNldh7YX3fPwOGq2qwrcr9wGFbF9ZYrJ8nWDebVwFXgL+Az1R1TXZiMTJPzH0RIz8SkTlAsKrm+l8UhlHYmRa9kS+ISBsRqSciLrbhh/2w+noNw8gm82SskV9UAX7CujEaDoxT1R3ODckwCgfTdWMYhlHIma4bwzCMQi7fdd14eXlp7dq1nR2GYRhGgbJt27Zzquqd2rZ8l+hr167N1q1bnR2GYRhGgSIiyZ+ITmC6bgzDMAo5hxK9iPQUkRAROZjaXBUi8oGI7LR97ReRyETbRojIAdtXvl9tyDAMo7DJsOvGNlHUVKwpVcOBLSKy2DYHCACq+kSi+o8C/rbX9sfTA7Dm6Nhm2/dCjl6FYRiGkSZH+ugDgYOqGgogIrOxHmbZm0b9Idycu6MHsFJVz9v2XQn0xDbNq2EY+UNMTAzh4eHcuHEj48qGU3l4eFCjRg3c3d0d3seRRF+dpFOqhmMtIJGCiNQC6nBz7o7U9q2eyn4PYi16Qc2aySfyMwwjt4WHh1OmTBlq165N2uvGGM6mqkRERBAeHk6dOnUc3i+nb8beC8y3LSzhMFWdpqoBqhrg7Z3q6CDDMHLRjRs3qFixokny+ZyIULFixUz/5eVIoj9O0rmza5D23Nb3krRbJjP7GobhRCbJFwxZ+Tk5kui3AA3EWry5GFYyT7HMm20+7gpYU5HaLQe6i7WYcwWgu60sV1y8CJs25dbRDcMwCqYME72qxgKPYCXofcBcVd0jIq8kXgUe6wNgduL1IG03YadgfVhsAV6x35jNDV9/De3aQfv2MGcOxMTk1pkMw8hJERER+Pn54efnR5UqVahevXrC++jo6HT33bp1K+PHj8/wHEFBQTkS69q1a7njjjty5Fh5xaEnY1V1KdbKNYnLJiV7PzmNfb8Gvs5ifJly9/0XcXcvx0cfwb33Qo0a8MgjMGECuJhHwwwj36pYsSI7d+4EYPLkyZQuXZqnn346YXtsbCxubqmnq4CAAAICAjI8x8aNG3Mk1oKo0KS/89fP03h6DVZ69eOTpSv4eXE8jRrBihUmyRtGQTRy5EjGjh1L27ZtmThxIps3b6Z9+/b4+/sTFBRESEgIkLSFPXnyZB544AG6dOlC3bp1+fjjjxOOV7p06YT6Xbp0YdCgQfj6+jJs2DDsHRFLly7F19eX1q1bM378+Axb7ufPn6d///60aNGCdu3a8e+//wLwxx9/JPxF4u/vz+XLlzl58iSdOnXCz8+PZs2asX79+hz/nqUl3811kx3jA8czfft0FocspoFnA8ZNGcc9viOBCoSFwUMPweefg5kzzTDS16VLyrLBg63/Q9euQe/eKbePHGl9nTsHgwYl3bZ2bdbiCA8PZ+PGjbi6unLp0
iXWr1+Pm5sbq1at4r///S8LFixIsU9wcDBr1qzh8uXLNGrUiHHjxqUYc75jxw727NlDtWrV6NChAxs2bCAgIIAxY8awbt066tSpw5AhQzKM76WXXsLf359Fixbx+++/M3z4cHbu3Mm7777L1KlT6dChA1euXMHDw4Np06bRo0cPnn/+eeLi4rh27VrWvilZUGjaup4lPHmt22sce+IYPwz4Ae9S3jy54knqf1adUYtHsWH3MTZuBD8/mD/f2dEahuGIu+++G1dXVwAuXrzI3XffTbNmzXjiiSfYs2dPqvv06dOH4sWL4+XlRaVKlTh9+nSKOoGBgdSoUQMXFxf8/Pw4fPgwwcHB1K1bN2F8uiOJ/s8//+T+++8HoGvXrkRERHDp0iU6dOjAk08+yccff0xkZCRubm60adOGmTNnMnnyZHbt2kWZMmWy+m3JtELVogco7lacYS2GMazFMHac3MFnWz7jh10/cKrOKXbs+IUhQ+Duu2HMGPjgAyhRwtkRG0b+k14LvGTJ9Ld7e
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_training_analysis()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T_JbNoF46le7"
},
"source": [
"On voit clairement sur les courbes que l'on a limité le sur-apprentissage. Notez aussi d'ailleurs, et c'est important, que l'apprentissage est plus lent : le modèle met plus de temps à prédire correctement l'ensemble d'apprentissage. C'est normal, car on a en quelque sorte \"complexifié le problème\" en introduisant toutes ces déformations de nos images.\n",
"Cette forme de régularisation \"par les données\" s'ajoute aux autres méthodes que nous avons vues précédemment comme la régularisation L1/L2 des poids du réseau et le Dropout. \n",
"\n",
"Vous devriez maintenant atteindre des performances autour de 80% de précision sur l'ensemble de validation, ce qui est bien mais pas complètement satisfaisant : il faudrait pour continuer à s'améliorer probablement s'entraîner plus longtemps, mais également disposer de plus de données.\n",
"\n",
"Une autre solution est d'utiliser le **Transfer Learning**."
]
},
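{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick reminder (purely illustrative, not used in the rest of this TP), the sketch below shows one way Dropout and an L2 weight penalty could be added to a Keras model of the same kind as above. It assumes the layers imported earlier in the TP (`Sequential`, `Conv2D`, `MaxPooling2D`, `Flatten`, `Dense`) and the `x_train` tensor; the layer sizes and regularization strengths are only indicative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.layers import Dropout\n",
"from tensorflow.keras.regularizers import l2\n",
"\n",
"# Illustrative sketch: a small convolutional model with Dropout and an L2 penalty\n",
"# on the dense layer (hyperparameters chosen arbitrarily for the example).\n",
"regularized_model = Sequential()\n",
"regularized_model.add(Conv2D(32, 3, activation=\"relu\", input_shape=x_train.shape[1:]))\n",
"regularized_model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"regularized_model.add(Flatten())\n",
"regularized_model.add(Dropout(0.5))  # randomly zero 50% of activations during training\n",
"regularized_model.add(Dense(512, activation=\"relu\", kernel_regularizer=l2(1e-4)))  # L2 penalty on these weights\n",
"regularized_model.add(Dense(1, activation=\"sigmoid\"))\n",
"regularized_model.summary()"
]
},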
{
"cell_type": "markdown",
"metadata": {
"id": "aVFqfXs9GrKe"
},
"source": [
"## Transfer learning\n",
"\n",
"L'une des raisons qui peut expliquer le fait que nos résultats soient décevants est que les premières couches de notre réseau convolutif, sensées détecter des caractéristiques utiles pour discriminer chiens et chats, n'ont pas appris de filtres suffisamment généraux à partir des 2000 images d'entraînement. Ainsi, même si ces filtres sont pertinents pour les 2000 images d'entraînement, il y a en fait assez peu de chances que ces filtres puissent bien fonctionner pour la généralisation sur de nouvelles données.\n",
"\n",
"C'est la raison pour laquelle nous avons envie de réutiliser un réseau pré-entrainé sur une large base de données, permettant donc de détecter des caractéristiques qui généraliseront mieux à de nouvelles données.\n",
"\n",
"Dans cette partie, nous allons réutiliser un réseau célèbre, et d'ores et déjà entraîné sur la base de données ImageNet : le réseau VGG-16.\n",
"\n",
"Commençons par récupérer les couches de convolution de ce réseau, et s'en remémorer la composition."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"id": "zRWY8mEQuF9O",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"from tensorflow.keras.applications import VGG16\n",
"\n",
"conv_base = VGG16(weights='imagenet', # On utilise les poids du réseau déjà pré-entrainé sur la base de données ImageNet\n",
" include_top=False, # On ne conserve pas la partie Dense du réseau originel\n",
" input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "xv_jCMwkuHY4",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"vgg16\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" input_1 (InputLayer) [(None, 128, 128, 3)] 0 \n",
" \n",
" block1_conv1 (Conv2D) (None, 128, 128, 64) 1792 \n",
" \n",
" block1_conv2 (Conv2D) (None, 128, 128, 64) 36928 \n",
" \n",
" block1_pool (MaxPooling2D) (None, 64, 64, 64) 0 \n",
" \n",
" block2_conv1 (Conv2D) (None, 64, 64, 128) 73856 \n",
" \n",
" block2_conv2 (Conv2D) (None, 64, 64, 128) 147584 \n",
" \n",
" block2_pool (MaxPooling2D) (None, 32, 32, 128) 0 \n",
" \n",
" block3_conv1 (Conv2D) (None, 32, 32, 256) 295168 \n",
" \n",
" block3_conv2 (Conv2D) (None, 32, 32, 256) 590080 \n",
" \n",
" block3_conv3 (Conv2D) (None, 32, 32, 256) 590080 \n",
" \n",
" block3_pool (MaxPooling2D) (None, 16, 16, 256) 0 \n",
" \n",
" block4_conv1 (Conv2D) (None, 16, 16, 512) 1180160 \n",
" \n",
" block4_conv2 (Conv2D) (None, 16, 16, 512) 2359808 \n",
" \n",
" block4_conv3 (Conv2D) (None, 16, 16, 512) 2359808 \n",
" \n",
" block4_pool (MaxPooling2D) (None, 8, 8, 512) 0 \n",
" \n",
" block5_conv1 (Conv2D) (None, 8, 8, 512) 2359808 \n",
" \n",
" block5_conv2 (Conv2D) (None, 8, 8, 512) 2359808 \n",
" \n",
" block5_conv3 (Conv2D) (None, 8, 8, 512) 2359808 \n",
" \n",
" block5_pool (MaxPooling2D) (None, 4, 4, 512) 0 \n",
" \n",
"=================================================================\n",
"Total params: 14,714,688\n",
"Trainable params: 14,714,688\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"conv_base.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bKyLfZcwOYwH"
},
"source": [
"Nous pouvons extraire les caractéristiques, apprises par le réseau de neurones sur ImageNet, de notre base de données d'image de chiens et de chat. L'intérêt, par rapport à la première partie, est qu'il aurait été presque impossible de déduire ces caractéristiques \"générales\" (trouvées sur une immense base de données) depuis notre base de données trop réduite de 2000 images. En revanche, ces caractéristiques générales devraient se révéler utiles pour notre classifieur.\n",
"\n",
"On peut lire sur la structure du réseau VGG résumée grâce à la fonction *summary* ci-dessus que le tenseur de sortie est de dimension $2 \\times 2 \\times 512$, autrement dit que le réseau prédit des caractéristiques de dimension $2 \\times 2 \\times 512$ à partir d'une image de taille $64 \\times 64$.\n",
"\n",
"On va redimensionner cette sortie dans un vecteur de dimension $2048 = 2 \\times 2 \\times 512$. "
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"id": "op4vvD9_ugWL",
"vscode": {
"languageId": "python"
}
},
"outputs": [
{
"ename": "ValueError",
"evalue": "cannot reshape array of size 16384000 into shape (2000,4096)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/home/laurent/Documents/Cours/ENSEEIHT/S8 - Réseau Profond/TP3.ipynb Cell 49'\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[0;34m()\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/laurent/Documents/Cours/ENSEEIHT/S8%20-%20R%C3%A9seau%20Profond/TP3.ipynb#ch0000048?line=0'>1</a>\u001b[0m train_features \u001b[39m=\u001b[39m conv_base\u001b[39m.\u001b[39mpredict(x_train)\n\u001b[0;32m----> <a href='vscode-notebook-cell:/home/laurent/Documents/Cours/ENSEEIHT/S8%20-%20R%C3%A9seau%20Profond/TP3.ipynb#ch0000048?line=1'>2</a>\u001b[0m train_features \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39;49mreshape(train_features,(train_features\u001b[39m.\u001b[39;49mshape[\u001b[39m0\u001b[39;49m],\u001b[39m2\u001b[39;49m\u001b[39m*\u001b[39;49m\u001b[39m2\u001b[39;49m\u001b[39m*\u001b[39;49m\u001b[39m1024\u001b[39;49m))\n\u001b[1;32m <a href='vscode-notebook-cell:/home/laurent/Documents/Cours/ENSEEIHT/S8%20-%20R%C3%A9seau%20Profond/TP3.ipynb#ch0000048?line=3'>4</a>\u001b[0m val_features \u001b[39m=\u001b[39m conv_base\u001b[39m.\u001b[39mpredict(x_val)\n\u001b[1;32m <a href='vscode-notebook-cell:/home/laurent/Documents/Cours/ENSEEIHT/S8%20-%20R%C3%A9seau%20Profond/TP3.ipynb#ch0000048?line=4'>5</a>\u001b[0m val_features \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39mreshape(val_features,(val_features\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m],\u001b[39m2\u001b[39m\u001b[39m*\u001b[39m\u001b[39m2\u001b[39m\u001b[39m*\u001b[39m\u001b[39m512\u001b[39m))\n",
"File \u001b[0;32m<__array_function__ internals>:180\u001b[0m, in \u001b[0;36mreshape\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
"File \u001b[0;32m/tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py:298\u001b[0m, in \u001b[0;36mreshape\u001b[0;34m(a, newshape, order)\u001b[0m\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=197'>198</a>\u001b[0m \u001b[39m@array_function_dispatch\u001b[39m(_reshape_dispatcher)\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=198'>199</a>\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mreshape\u001b[39m(a, newshape, order\u001b[39m=\u001b[39m\u001b[39m'\u001b[39m\u001b[39mC\u001b[39m\u001b[39m'\u001b[39m):\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=199'>200</a>\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=200'>201</a>\u001b[0m \u001b[39m Gives a new shape to an array without changing its data.\u001b[39;00m\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=201'>202</a>\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=295'>296</a>\u001b[0m \u001b[39m [5, 6]])\u001b[39;00m\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=296'>297</a>\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m--> <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=297'>298</a>\u001b[0m \u001b[39mreturn\u001b[39;00m _wrapfunc(a, \u001b[39m'\u001b[39;49m\u001b[39mreshape\u001b[39;49m\u001b[39m'\u001b[39;49m, newshape, order\u001b[39m=\u001b[39;49morder)\n",
"File \u001b[0;32m/tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py:57\u001b[0m, in \u001b[0;36m_wrapfunc\u001b[0;34m(obj, method, *args, **kwds)\u001b[0m\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=53'>54</a>\u001b[0m \u001b[39mreturn\u001b[39;00m _wrapit(obj, method, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwds)\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=55'>56</a>\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m---> <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=56'>57</a>\u001b[0m \u001b[39mreturn\u001b[39;00m bound(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwds)\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=57'>58</a>\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mTypeError\u001b[39;00m:\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=58'>59</a>\u001b[0m \u001b[39m# A TypeError occurs if the object does have such a method in its\u001b[39;00m\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=59'>60</a>\u001b[0m \u001b[39m# class, but its signature is not identical to that of NumPy's. This\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=63'>64</a>\u001b[0m \u001b[39m# Call _wrapit from within the except clause to ensure a potential\u001b[39;00m\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=64'>65</a>\u001b[0m \u001b[39m# exception has a traceback chain.\u001b[39;00m\n\u001b[1;32m <a href='file:///tmp/deepl/.env/lib/python3.8/site-packages/numpy/core/fromnumeric.py?line=65'>66</a>\u001b[0m \u001b[39mreturn\u001b[39;00m _wrapit(obj, method, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwds)\n",
"\u001b[0;31mValueError\u001b[0m: cannot reshape array of size 16384000 into shape (2000,4096)"
]
}
],
"source": [
"train_features = conv_base.predict(x_train)\n",
"train_features = np.reshape(train_features,(train_features.shape[0],2*2*512))\n",
"\n",
"val_features = conv_base.predict(x_val)\n",
"val_features = np.reshape(val_features,(val_features.shape[0],2*2*512))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0A_ayR0dvvwe"
},
"source": [
"Nous pouvons maintenant définir un réseau de neurones simple (par exemple, de 2 couches denses, avec 256 neurones sur la couche cachée) qui va travailler directement sur les caractéristiques prédites par VGG."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lmmBYYmtvUUF",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# A COMPLETER\n",
"model = Sequential()\n",
"model.add(Dense(256, activation=\"relu\", input_dim=train_features.shape[1]))\n",
"model.add(Dense(1, activation=\"sigmoid\"))\n",
"model.summary()\n",
"\n",
"# AJOUTER EGALEMENT LA FONCTION DE COUT\n",
"model.compile(optimizer=optimizers.Adam(lr=3e-4),\n",
" loss='binary_crossentropy',\n",
" metrics=['accuracy'])\n",
"\n",
"# COMPLETER AVEC LES TENSEURS SUR LESQUELS EFFECTUER L'APPRENTISSAGE\n",
"history = model.fit(train_features, y_train,\n",
" epochs=50,\n",
" batch_size=16,\n",
" validation_data=(val_features, y_val))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_5xpyZCS4cco",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"plot_training_analysis()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_CFd7e-dJ-cK"
},
"source": [
"On observe à nouveau beaucoup de sur-apprentissage. Il faudrait trouver un moyen d'intégrer de l'augmentation de données. \n",
"\n",
"Pour cela, on peut connecter notre petit réseau de neurones à l'extrémité de la base convolutionnelle de VGG. L'idée est qu'en réutilisant notre générateur de données augmentées, nous pourrons calculer les caractéristiques de VGG sur les données augmentées, et ainsi classifier ces caractéristiques plutôt que les caractéristiques de notre base de données uniquement."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8W85VMorXPtP"
},
"source": [
"## Intégration de l'augmentation de données"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fCb_itsuXenK"
},
"source": [
"### Définition du nouveau modèle et entrainement\n",
"\n",
"On commence par créer un nouveau modèle qui va s'appuyer sur la base convolutive de VGG, à laquelle on adjoint une couche dense et notre couche de sortie."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jyZZS-GSKyPZ",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"from tensorflow.keras import layers\n",
"\n",
"model = Sequential()\n",
"model.add(conv_base)\n",
"model.add(layers.Flatten())\n",
"model.add(layers.Dense(256, activation='relu'))\n",
"model.add(layers.Dense(1, activation='sigmoid'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iZr_u4s7K4Fi",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dRiiu2EbBNAv"
},
"source": [
"**Attention** : il est important de ne pas commander l'entraînement de la base convolutionnelle de VGG ! Nous ne voulons en aucun cas écraser les bonnes caractéristiques de VGG que nous cherchons justement à réutiliser ! Le réseau aurait en outre un grand nombre de paramètres, ce qui est justement ce que l'on veut éviter ! \n",
"\n",
"Pour cela nous pouvons utiliser l'attribut *trainable* : en le positionnant à *false*, nous pouvons geler les poids et en empêcher la mise à jour pendant l'entraînement."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9h8Fx8P0PId5",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"conv_base.trainable = False\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t7tkMrA14ccp"
},
"source": [
"Observez le décompte des poids : le nombre de poids entraînable est maintenant de 500 000, contre 16 millions précédemment ; on ne va entrainer ici que les poids de notre couche dense et de la couche de sortie."
]
},
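{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check (assuming the `model` defined above), the trainable and non-trainable parameter counts reported by `summary` can also be computed programmatically:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from tensorflow.keras import backend as K\n",
"\n",
"# Count trainable vs frozen parameters of the current model (should match model.summary()).\n",
"trainable_count = int(np.sum([K.count_params(w) for w in model.trainable_weights]))\n",
"non_trainable_count = int(np.sum([K.count_params(w) for w in model.non_trainable_weights]))\n",
"print(f'Trainable parameters: {trainable_count:,}')\n",
"print(f'Non-trainable parameters: {non_trainable_count:,}')"
]
},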
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "go7Uld7sLRdG",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"model.compile(loss='binary_crossentropy',\n",
" optimizer=optimizers.Adam(learning_rate=3e-4),\n",
" metrics=['accuracy'])\n",
"\n",
"history = model.fit(train_datagen.flow(x_train, y_train, batch_size=10), \n",
" validation_data=(x_val, y_val),\n",
" epochs=10,\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "N-RNeMcAXu8h"
},
"source": [
"### Analyse des résultats du nouveau modèle"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tJWBzO-KCCSh"
},
"source": [
"L'entraînement est beaucoup plus lent ! Il faut en effet générer les données augmentées, et leur faire traverser les couches de VGG à chaque itération de gradient. Ceci prend du temps !"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_DHOFSauLyJa",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"plot_training_analysis()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ac4pNlvJCtkY"
},
"source": [
"En revanche, on observe que l'on a bien limité le sur-apprentissage, ce qui était le but recherché. Les résultats sont un peu meilleurs mais pas complètement satisfaisants."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UmykKP9_M1GW"
},
"source": [
"### Fine-tuning\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r4pV9QlhFc_8"
},
"source": [
"Nous pouvons maintenant tester la dernière technique vue en cours : le **fine-tuning**. Pour cela, nous allons repartir du réseau que nous venons d'entraîner, mais nous allons débloquer l'entraînement des poids de l'ensemble du réseau. **ATTENTION : il est important de choisir un taux d'apprentissage très faible afin de ne pas réduire à néant les bénéfices des entraînements précédents.** L'objectif est simplement de faire évoluer les paramètres du réseau \"à la marge\", et ceci ne peut être fait qu'après la première étape de *transfer learning* précédente. Sans cela, les dernières couches ajoutées à la suite de la base convolutive, après leur initialisation aléatoire, auraient engendré de forts gradients qui auraient complètement détruit les filtres généraux de VGG.\n",
"\n",
"\n",
"\n",
"On commence par réactiver l'entraînement des paramètres de la base convolutive de VGG : "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZeZA3eVYEJDY",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"conv_base.trainable = True\n",
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "p9SvUKcWEPVd",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"model.compile(loss='binary_crossentropy',\n",
" optimizer=optimizers.Adam(learning_rate=1e-5), # Taux d'apprentissage réduit pour ne pas tout casser, ni risquer le sur-apprentissage !\n",
" metrics=['accuracy'])\n",
"\n",
"history = model.fit(train_datagen.flow(x_train, y_train, batch_size=10), \n",
" validation_data=(x_val, y_val),\n",
" epochs=10,\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MCK-hm_IN0P4",
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"plot_training_analysis()"
]
},
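{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional variant (not required by the TP): instead of unfreezing the whole convolutional base, a common compromise is to unfreeze only the last convolutional block of VGG-16 (the `block5_*` layers) and keep the earlier, more generic filters frozen. The sketch below assumes the `conv_base`, `model`, `train_datagen` and data tensors defined earlier."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of partial fine-tuning: only the block5 layers of VGG16 are left trainable.\n",
"conv_base.trainable = True\n",
"for layer in conv_base.layers:\n",
"    layer.trainable = layer.name.startswith('block5')\n",
"\n",
"model.compile(loss='binary_crossentropy',\n",
"              optimizer=optimizers.Adam(learning_rate=1e-5),  # keep a very small learning rate\n",
"              metrics=['accuracy'])\n",
"\n",
"history = model.fit(train_datagen.flow(x_train, y_train, batch_size=10),\n",
"                    validation_data=(x_val, y_val),\n",
"                    epochs=10)"
]
},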
{
"cell_type": "markdown",
"metadata": {
"id": "8hCLsD6tpGhm"
},
"source": [
"On atteint un bon résultat, proche des 90% de précision sur l'ensemble de validation, bien au-dessus des performances obtenues sans *transfer learning* ! Vous comprenez maintenant pourquoi en traitement d'image, cette technique est incontournable."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VptFMmkArNqi"
},
"source": [
"**S'il vous reste du temps** :\n",
"\n",
"Vous pouvez maintenant reprendre le travail depuis le début en augmentant la résolution des images (par exemple $128 \\times 128$). A l'issue du *transfer learning* et du *fine-tuning*, vous devriez dépasser les 95\\% de précision sur l'ensemble de validation. \n",
"\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"machine_shape": "hm",
"name": "TP3_Classification_de_chiens_et_chats_Sujet.ipynb",
"provenance": [],
"toc_visible": true
},
"interpreter": {
"hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a"
},
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": ".env"
}
},
"nbformat": 4,
"nbformat_minor": 0
}