TP-intelligence-artificiell.../IAM2022_Apprentissage_Semi_Supervise_Sujet.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "XMMppWbnG3dN"
},
"source": [
"# Apprentissage Semi-Supervisé\n",
"\n",
"On se propose dans ce TP d'illustrer certains techniques d'apprentissage semi-supervisé vues en cours.\n",
"\n",
"Dans tout ce qui suit, on considère que l'on dispose d'un ensemble de données $x_{lab}$ labellisées et d'un ensemble de donnés $x_{unlab}$ non labellisées."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2onzaW7mJrgG"
},
"source": [
"## Datasets"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4BejOODdKZ70"
},
"source": [
"Commencez par exécuter ces codes qui vos permettront de charger les datasets que nous allons utiliser et de les partager en données labellisées et non labellisées, ainsi qu'en données de test."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V2nYQ2X5JW2k"
},
"source": [
"### Dataset des deux clusters"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "Pkv-k9qIJyXH"
},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import datasets\n",
"import matplotlib.pyplot as plt \n",
"\n",
"def generate_2clusters_dataset(num_lab = 10, num_unlab=740, num_test=250):\n",
" num_samples = num_lab + num_unlab + num_test\n",
" # Génération de 1000 données du dataset des 2 lunes\n",
" x, y = datasets.make_blobs(n_samples=[round(num_samples/2), round(num_samples/2)], n_features=2, center_box=(- 3, 3), random_state=1)\n",
"\n",
" x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=num_test/num_samples, random_state=1)\n",
" x_train_lab, x_train_unlab, y_train_lab, y_train_unlab = train_test_split(x_train, y_train, test_size=num_unlab/(num_unlab+num_lab), random_state=6)\n",
"\n",
" return x_train_lab, y_train_lab, x_train_unlab, y_train_unlab, x_test, y_test\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "OBwkuDKFLKdH"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(10, 2) (740, 2) (250, 2)\n",
"(10,) (740,) (250,)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA+GElEQVR4nO2df5AdV3Xnv/e9N/M0EgQS2WAU/1QEXsx6sZFRZWQbDyPviLUF1pY2m02R2MkYi0nZRs4ST0X+UTNCZhzDVqw/WKSnwk7ZCUUqtU7Cks3GwY4HsXpDQCZYJCEh3kAcIFkcVUgUI3mkeWf/uHPUt++7t/t2v36/z6eqa2bedN++fbvf954+99xzFRFBEARB6F9K3a6AIAiC0Boi5IIgCH2OCLkgCEKfI0IuCILQ54iQC4Ig9DmVbpz0vPPOo0svvbQbpxYEQehbnn/++X8kovPtz7si5JdeeimOHTvWjVMLgiD0LUqpv3V9Lq4VQRCEPkeEXBAEoc8RIRcEQehzRMgFQRD6HBFyQRCEPkeEXBAEoc8RIRf6iqUl4OGH9U9BEDRdiSMXhDwsLQHbtgHLy8DoKPDss8D4eLdrJQjdRyxyoW9YXNQivrKify4udrtGgtAbiJALfcPEhLbEy2X9c2Ki2zUShN5AXCtC3zA+rt0pi4taxMWtIggaEXKhrxgfFwEXBBtxrQiCIPQ5IuSCIAh9jgi5IAhCnyNCLgiC0OeIkAuCIPQ5IuSCIAh9jgi5IAhCnyNCLgiC0OeIkAuCIPQ5IuSCIAh9jgi5IAhCnyNCLgiC0OeIkAuCIPQ5IuSCIAh9jgi5IAhCnyNCLgiC0OeIkAuCIPQ5IuSCIAh9jgi5IAhCn1OYkCulykqpP1VK/X5RZQqCIAjpFGmR7wHwjQLLEwRBEAIoRMiVUhcCuBnAp4ooTxAEQQinKIv8AIBZAI2CyhMEQRACaVnIlVI7AHyfiJ5P2W+3UuqYUurYyy+/3OppBUEQhFWKsMivBfA+pdS3AfwWgEml1G/aOxHRYSK6hoiuOf/88ws4rSBkZ2kJePhh/VMQBoVKqwUQ0V4AewFAKTUB4JeJ6GdbLVcQiubwYeCuu4CVFaBaBZ59Fhgf73atBKF1JI5cGAqWloA77wTOnAEaDeDVV4HFxW7XqnvIm8lg0bJFbkJEiwAWiyxTEIpgcVELOFMuAxMT3apNd7n9duAznwGWl4HRUXkzGQTEIheGgokJ7U4plYBKBfjEJ4ZXvB5/XIv4yor+OcxvJoNCoRa5IPQq4+Pa8lxc1KI+rCLOjI5GFvmwvpkMEmKRC0PD+Diwd+9wivj8PKCU3gDg1Cltkf/Mz2Rrj/n5dtROaBVFRB0/6TXXXEPHjh3r+HkFQdBinvdr38qxQusopZ4nomvsz8UiF4QeIy2iRCJOBBsRckHoICEivW0b8OCD+qe9X9r/Q5iby7a/7Zbh38XN0juIkAtChwgR4cXF5IiStP+HkFWA5+e1O4VdKvy7qxwR9+4gQi7EGKYvYpEuipCyQkR4YkJHkpTL7oiStP93m337ul2D4UTCD4UY+/YNh5izdVzEpJjQsiYmgJERbc2WSsC3vgUcOABs2aL3Vyo9TLLbYZRZ3TJCZxCLXBhKinBRZCnrzBnghReAH/sxPSEJAD71KeCXfxm48UbgJ34COHRI75cWJtnJMEq7U/e5U8SH3l1EyIWh/CLmcVH43CdpZf3rvwKTk8CHPwx873ta7M+c0Zb5yoqO6f7Wt/T/t23T+/tIuidFuIrs8kNcJVl86EJ7kDhyIUa/xwkvLYW7HbLum+Q+8ZV15owW8a98RSfqSqNa1a6WZ5/Vbhgb3/2x63fgAHDihP/afPW1y8/6PPT789Pr+OLIQUQd3zZv3kxCbwJ0uwb5qdeJxsaIymX9s14vruyFBV0uoH8uLIQdd/Ag0dq1bKOGbWvXEh065C7Pd3/M+pVKRCMj+u9qlWhmJt4WSe0EEM3Nues1N5d+vSH7CPkBcIwcmiquFSFGPw9m5fF7h7oj2H1SKmmrc/369HIXFoD9+4Ef/jDwAlb54Q+Bj30ssmxDXF+me6dU0m2wsqLfAmq1eLij3U4PPBAvn90p/CykuUrMNhR3SpdwqXu7N7HIhXaQ1SLPun+tpi3dUil5fy63VMpmiZvbunVER482l530xlSva8u8VtPnVyoqz3yLSLPIQ8+XVpZQPPBY5BJ+KAwMWUPzXBZ80jEnTuic5o1G8v5cbqOFpcjPntV+9a1bw48ZH4/qc+WVwJNP6pS1KyvxQdgs7ZT2hpa1DYX2IEIuDBSmmKXB7ojQdK6h+/N+p0/nH/hbXgZOnmz+nIXVHKx8+ulmlwa3w623ugXb107T09pNwvsnuVMWF7WLSVLidh+JWhEGlpColCyRK1n2X1oCPvpR4I/+SEeuZKVaBR55BPinf2oWUztC5dSpYiJFQic2ZY2QCWF+XvzrIUjUijBU9ILv9uhR7evO4yNfs0Yf7/JR2xE0RUUahUbm5I3gSaKfo6U6CSRqRRgm0iJY2m39seX+utflOz4p5nxiIoowWVnRP7NM4so7sSnrfkIHcal7uzexyIV2k2aRt9MC5HObUSNFbGaMNkeo1Ovp12IfZ7dLo6Gt/0cfJbrjDqKbb9aRL41G8jXy+fPSSrz6sAKPRS4+cmFgOXwYeOopYNcuYPfu+P/aOQPx4Yd1qtqVFX2e170O+MEPspWxZg3wx3+so1bS6pl2Leb/zbqVSsB73wscPw58//val3/mjJ5ROjICvOENwOwscPvt7lmmRSIzQsPw+cglakXoObIOQLqOA4B77tFulS9+UYfjPf10PHcIuyfm5tJdEmaURtrA3sSEdjs0Glqc/vmfw+q/di3wlrcAt9wCbN8efu2uSJakuo2OatcNkW6T06fj+ywv6+1b3wI+9CHg4x8H9uzROWDe/W4JL+xJXGZ6uzdxrQg+8g5S2sfNzCQPyLE7IuQ13p7goxRRpaLdD779q9Vml4FrglC1qgdEN27U0/KXl+Nl+dwpadc/Pe13Wxw5QnTJJUSjo9ndO5WKPr5V7GsRd0oY8LhWRMiFniJvRIR93MxMmI88xFdulm1uIyPN5c7NNe/Pwu8Sxu3bIz91ErWaLoNnlU5PJ1//woL/GvPkfzG3Sy4hOnmy+bpD6YWIon7FJ+QStSL0FHkjIuzjbr1Vx0Hv3++Oh86SU8bMs2KystIcDbNvX7wu1SrwwQ8CR45EUghEP+fmdBlf+pL//EtLwF136dmejYZ2izz+ePL1c7vZLiMinccla/4Xk+99D7jppnh8fJaVgbLmxJHFpgNwqXu7N7HIhSTyRkSEHpcnWoItYtNNkmTp++riOzdANDvrPpata9O9YVrZvP+OHUSTk8nX1kpsu7mtXavrGxo5Y5LFIhfrPQ7EtSL0EkX4RIsIgUsTID6H6XMvlYimpuL+3SwdA3/uE3NbuFgok/zWXLerrnLvc8MN+pyPPprPN55lCx13CLl37Zh81M+IkAs9RRYLzkVRllpSPcxzVKtaAH3nY2EKv
S6fOG/a5BYuIBI+PpbPafrjFxaSO5Y77iheuIucXepqJ7HII0TIhfaRwzT2ffHbaam5LMUk69E1gGrWzU4by/uGNMPCgnvCkGmR+wZI+fOxMaKtW937mFErTL2uB2iLFHGldD1avZ9JFFHGoCBCLrSHDCZTmgsipKgs+9qYVixbtkkCkXQO8392hEpapzI3Fz+ehZnDGe16sQXO5+IOoFSKIllYoBcWdGdgdhKmr72VHOkuEb/xRl2+q0MMuUeDItKdCp8UIRfaQ04npsuCCynKtjCziABbseWydpNUq5HI1Grusnzn8IUkVqvNgm93HHwNO3Ykd2zmuXwuDRZm+6e5mQtKrFlTnJBXq0QHDvjbO+1+DpLbpF2upebztEnIAVwE4DkAfwHgzwHsSTtGhHyAyPltdD34IUVl/cIkRYmYli2vcRl6Ca58Kkpp94vrekwfO1vf5rW6rst8+6hUml0xV1wRdUY+l0mlEr+eWo3orW8tZsDTt4qR6/pd7WoKvd12/cYgCPmbALxj9
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x_train_lab, y_train_lab, x_train_unlab, y_train_unlab, x_test, y_test = generate_2clusters_dataset(num_lab = 10, num_unlab=740, num_test=250)\n",
"\n",
"print(x_train_lab.shape, x_train_unlab.shape, x_test.shape)\n",
"print(y_train_lab.shape, y_train_unlab.shape, y_test.shape)\n",
"\n",
"# Affichage des données\n",
"plt.plot(x_train_unlab[y_train_unlab==0,0], x_train_unlab[y_train_unlab==0,1], 'b.')\n",
"plt.plot(x_train_unlab[y_train_unlab==1,0], x_train_unlab[y_train_unlab==1,1], 'r.')\n",
"\n",
"plt.plot(x_test[y_test==0,0], x_test[y_test==0,1], 'b+')\n",
"plt.plot(x_test[y_test==1,0], x_test[y_test==1,1], 'r+')\n",
"\n",
"plt.plot(x_train_lab[y_train_lab==0,0], x_train_lab[y_train_lab==0,1], 'b.', markersize=30)\n",
"plt.plot(x_train_lab[y_train_lab==1,0], x_train_lab[y_train_lab==1,1], 'r.', markersize=30)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sKR9vNgsLp_J"
},
"source": [
"### Dataset des 2 lunes"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_AFhsTUQwIxt"
},
"source": [
"\n",
"<img src=\"https://drive.google.com/uc?id=1xb_gasBJ6sEmbyvCWTnVEAsbspyDCyFL\">\n",
"<caption><center> Figure 1: Comparaison de différents algorithmes semi-supervisés sur le dataset des 2 lunes</center></caption>"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "tCw5v2JDLwau"
},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn import datasets\n",
"import matplotlib.pyplot as plt \n",
"\n",
"def generate_2moons_dataset(num_lab = 10, num_unlab=740, num_test=250):\n",
" num_samples = num_lab + num_unlab + num_test\n",
" # Génération de 1000 données du dataset des 2 lunes\n",
" x, y = datasets.make_moons(n_samples=num_samples, noise=0.1, random_state=1)\n",
"\n",
" x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=num_test/num_samples, random_state=1)\n",
" x_train_lab, x_train_unlab, y_train_lab, y_train_unlab = train_test_split(x_train, y_train, test_size=num_unlab/(num_unlab+num_lab), random_state=6)\n",
"\n",
" return x_train_lab, y_train_lab, x_train_unlab, y_train_unlab, x_test, y_test\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "FkQ1L5I1MBkH"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(10, 2) (740, 2) (250, 2)\n",
"(10,) (740,) (250,)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAABGzUlEQVR4nO2de5DlVXXvv/uc7j4zjC9oLIHIOLQoEeXKY+zQExjagRoCEukUJtFYGWAQaAMGkiq7RMFuGDOTkCrpG+EyPQqGyet6E28lkApOwqPFzGmFwdf4KBMwCVfLVyZXHS/DNDNn3z9WL3/rt8/ev8d5/s4561P1q/P4vfZv/36/tddee621jbUWiqIoSv9T6nYBFEVRlM6gAl9RFGVAUIGvKIoyIKjAVxRFGRBU4CuKogwIQ90uQIjjjz/erlu3rtvFUBRF6Smefvrp/7TWvtK3rrACf926ddi3b1+3i6EoitJTGGP+I7ROTTqKoigDggp8RVGUAUEFvqIoyoCgAl9RFGVAUIGvKIoyIKjA7zPm5rpdAkVRiooK/D7j9tu7XQJFUYqKCnxFUZQBQQV+HzA3BxhDCxB9V/OOoigSU9QJUNavX2810jY/xgAFvaWKonQAY8zT1tr1vnWq4SuKogwIKvD7jNnZbpdAUZSiogK/z1C7vaIoIVTgK4qiDAgq8BVFUQYEFfiKoigDggp8pedZWgJ27KBPRVHCFHbGK0VZWgIWF4HJSWBiIrzNhRcCy8vAyAjw6KPhbRVl0FGBr3SELMLb3T4kyOWxFhdpm6NH6XNxkbbznS9vGRSl31CBr7SdNC18bq7enTRJkMtjzc/TJ/+enPSfD+hcT8B3PYpSBNSGr7Qdn/CW+DJ8Tk6SYC6XI0HuO9aBAyS8t22LhLjvfGllaCWasVQpKqrhK22HhbfUwtOYmCABvnt3+rEmJuLaeuh8ecugKP2GavhK22HhLbXwrBk+H3gA+PjHyRyztOQ/lu988/O0z/x81CCk7dcMmrFU6QU0W6bSdUIZPnfsAG67jcww5TIJ61tuST9etz13NGOp0k00W6bSk4Ts+Glktder/74yaKgNX8lFO1wbQxk+XTv+/v3Zzp1lzKCdvQDNWKoUFRX4SmYaEZJZGgjXzi33AciOf/gwUKsBpRJQqSSfm234n/40cMUV/u1Cbp+tQO32SlFRga9kRgrJw4cjf/NWRsG6+1x5JX2v1Wh9rZYuoJeWgJtvpu0+9zngjDPq/f4vvli9dpTBQ234ys9JsmkvLQHPPUf29FKJBO8jj0TeMz7y+r5bS9r8Cy/QPi+8AHz/+8DwMJ0ToM80Ab24SA0SN0w+v/92e+0oShFRDV8BkJ7K4K1vpXVDQ8D69cC+fX5tW5pjsvrfv/gicN99wJ13koBnDxdrgT17gOOOA97yFtLKf/xjOs6ePWEhPToa7xGMjvqvN8nUlGesQiNrlZ7BWlvI5ZxzzrFK59i+3dpy2VqAPrdvj9ZNT9P/vJx+urUjI7Td6tXWVqu0XbVKv+X/1Sodi7dxOXjQ2vPOs/aYY+LncJdjjrH2/PNpe2vpv6RrKZVom1KJfs/O+o9bKtH1yfL5riOJpLIoSqcBsM8G5KqadBQA+Vwgv/EN8jW/9tp4TyA0EMpJzlzTz4svApdcAjz1FPD888nle/554MkngUsvpf2AsPlpcpIGdstl+pycJA2cxbykVgN27oybptzr2L07m/umunkqRUdNOgqAyKbtM2Ns2QLcfz8JP+bIEWDt2vSUBktLwMaNJGhdU9F99wFf/CLZ2bNw+DANwo6M0O8PfpA+t24FTj45Mqv40jK4nj8ubgPF1zE0RNd+9Gi8/HNz8Zw5MsJ21SodF1AKSkj17/aiJp3GmJ1tzz7VKpk+KhUygwwPW7uwEF+/fTv9J00427dH5hNj6BjWWlurWXvKKclmnNBy7LHRdzY/uWYVaZYZGaFys4kGoN/GRGYd13TD1zM9HZm6SiVrN2+uN/EAySYxRekkSDDptEQ4A7gfwA8BfC2w3gD4EwDPAPgqgLPTjqkCPz9sp85Lnn0WFqwdGooLySSbd7UaF9blMh1j715r16xpTOAPDUXH4vO51yAFMAt233LZZf4xhtlZa7dutXZqio6T1DgA+e3+itIukgR+q0w6fwrgbgC7A+svAfC6leWXANy78qm0kHal5ZXmkAMHSFRKDx2g3na/Z4+/PEePAjfcQAvb4vNy5AjwS78ErFkDPPYYsGED/c9mldnZuJ99uUzrjhyh/w4dSs9145adk6Hxde/eHdXJ7GyySUxRikJLBL619gljzLqETS4HsHul9fm8MeYVxpgTrbXfa8X5B52QPZlD/H0ug0n7yO1dd80LLyQBWqvRZyj1MNu5+djDw5GAr9WAr3+9cYEPAJddRgL8scf8610BDETfuYHIg7V0DXzdn/xk1IDwBCtummZFKRwh1T/vAmAdwiadvwdwnvj9KID1nu2uA7APwL61a9e2sdPTP4TcDdkuz/Zln4lBbhNCmkbY1XF4mEwclUp03IUFsm/77PpA3BRUqVj7pjc1Zs5Juk73upLqLE9dSlOUa9d37fWNjKEoSitBu234tkUCXy5qw88PCyfpF+/aut3t5ScjfefZNs3C3hWCvJ3P/57/GxqK/puepkHUZgX93r3+a5HfZ2fT4wDS6nJqKl6nXD8he30jYyiK0kqSBH6n/PC/C+Bk8fvVK/8pDeJLOAaQ+WTjRvpk08XRo2S3vvVW/7Fkdkc24dx2G30CwLveFUWuSowhE8mttyZPKWht5PK4di391yyu6cSXofL22+PXktU/Xm63Z090PvazB8iMc+21lOtHUXqFTgn8BwFsMcS5AH5i1X7/cxoJ2HEHFRcXo4k3rAW2bweqVVpXLgOrVwMf+Yh/Zqbbb48akN2747lsdu8mf/lqlY4BUDDT9DTwxBMkCB97rD5oKxTI9cwz9FuWIQ/HHAO87W3+fd3rAqJryZLLh++D9N8/dCg63oYNwIc+RI3H/v3RbFzcwOpsV0rhCan+eRYAfwXgewBeBPAdANcAmAYwvbLeALgHwLMA9iPFnGMHyKST1Z3PNU0k+Z3L4yTZ8H3HcM0t5bK14+PWzsxEKRZ8piGf6cT3H28LWHvVVfnMOJUKpVdYXg7X0aZN/n3ZrBRC1l+lEqWO4P1dP/vNm+NunxxfoCYdpdugEzb8Vi+DIvCzBOywMAr5k/NAoU/AuoOIcpsk33XfYkyUe2br1uSyhHDHDQBrTzghXdi7uXRc5KCwbFTc8YhQ4+feh+npeNCY26AuLMQbRx7AVoGvdBsV+AUmi4bvaxQaESzuubZu9a/3DdC6gt/tRSSR5P0il+OPj/8eHqbgrLExa3fuTNbsh4fj+7p1lFbP7iCzr3xbt8YbjOnpqBHm+3LBBdnuQyMDyYqShSSBr7l0ukyWgJ3JScrpUqvRZ6OTdbhJwU49NVyWH/8YeOgh4JvfrD+Ota2fJQqgiUre+U7gwQeBk04C3vhGYHwcOPfcZHv/4mLcp5/raHY2Chp77jl/c
Jibf8e9DzwuYgyNZUi2bCE7vow/4Pw+Ibo9wboy2KjALwBZAnasjT737wc2bQJ27aLI11BDwcKMP0P56X2531koDQ0BZ59N637603jAEe+fNoerzBcvBWi1SgOhpRIN7r7zncB119GSBy4HH+fuu/3XwYOpHDC2YUN8YDVr4JSsr9NOA37jN7JH1+adWrEdcwgrA0xI9e/2MigmHUmoq+8GP0lbtS+3C8OmiKRz+HLjbN8emSqMiY8r+AKssuAGSMnf8vyumSmpXtxrdcvl1hvnxEmzt/P5QuMUQ0PxRGwhk5VrAuJjZ825o/l5lEaA2vCLj7Sfu5koFxbigssdvA0N9voEvrWRsF1YiNvrebKQmZn48Wdm4mVkTxY5cUjWwVp3u02b4uMTvgZKZr3kcyZ55HCj4drlk8Ym5MB3yNvJ2nq7vVte/p2WUC6LDV8zcCqNkCTwdQKULiN9vw8fJjv9iy9ScrGlJVre974oWKlWi0wiQP0cr66fPRD3C19aIr/7XbuAG2+MB1RxI
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x_train_lab, y_train_lab, x_train_unlab, y_train_unlab, x_test, y_test = generate_2moons_dataset(num_lab = 10, num_unlab=740, num_test=250)\n",
"\n",
"print(x_train_lab.shape, x_train_unlab.shape, x_test.shape)\n",
"print(y_train_lab.shape, y_train_unlab.shape, y_test.shape)\n",
"\n",
"# Affichage des données\n",
"plt.plot(x_train_unlab[y_train_unlab==0,0], x_train_unlab[y_train_unlab==0,1], 'b.')\n",
"plt.plot(x_train_unlab[y_train_unlab==1,0], x_train_unlab[y_train_unlab==1,1], 'r.')\n",
"\n",
"plt.plot(x_test[y_test==0,0], x_test[y_test==0,1], 'b+')\n",
"plt.plot(x_test[y_test==1,0], x_test[y_test==1,1], 'r+')\n",
"\n",
"plt.plot(x_train_lab[y_train_lab==0,0], x_train_lab[y_train_lab==0,1], 'b.', markersize=30)\n",
"plt.plot(x_train_lab[y_train_lab==1,0], x_train_lab[y_train_lab==1,1], 'r.', markersize=30)\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NIGZe-yAQq-A"
},
"source": [
"## Modèles"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jaPezVmtK5tC"
},
"source": [
"Nous allons dès maintenant préparer les modèles que nous utiliserons dans la suite.\n",
"\n",
"**Travail à faire** Complétez les modèles ci-dessous : "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mfw8EKUuUpt6"
},
"source": [
"Pour le dataset des 2 clusters, un simple perceptron monocouche suffira :"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "BryV3CDKQytA"
},
"outputs": [],
"source": [
"from keras.layers import Input, Dense\n",
"from keras.models import Model\n",
"\n",
"# Ici, écrire un simple perceptron monocouche\n",
"def create_model_2clusters():\n",
"\n",
" inputs = Input(shape=(2,))\n",
"\n",
" outputs = Dense(1, activation='sigmoid')(inputs)\n",
"\n",
" model = Model(inputs=inputs, outputs=outputs) \n",
"\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NrE8ZQCpUuxg"
},
"source": [
"Pour le dataset des 2 lunes, implémentez un perceptron multi-couches à une couche cachée, par exemple de 20 neurones."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"id": "o1jcG_4pyGlx"
},
"outputs": [],
"source": [
"# Ici, écrire un perceptron multi-couches à une seule couche cachée comprenant 20 neurones\n",
"def create_model_2moons():\n",
"\n",
" inputs = Input(shape=(2,))\n",
"\n",
" inter = Dense(20, activation=\"relu\")(inputs) \n",
"\n",
" outputs = Dense(1, activation=\"sigmoid\")(inter)\n",
" \n",
" model = Model(inputs=inputs, outputs=outputs) \n",
"\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JMaTgZJcQbIh"
},
"source": [
"## Apprentissage supervisé"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hGTfv5YfMAXY"
},
"source": [
"Commencez par bien lire le code ci-dessous, qui vous permet de mettre en place un apprentissage supervisé en détaillant la boucle d'apprentissage. Cela nous permettra d'avoir plus de contrôle dans la suite pour implémenter les algorithmes semi-supervisés. Cela vous fournira également une base contre laquelle comparer les algorithmes semi-supervisés.\n",
"\n",
"En quelques mots, le code est organisé autour d'une double boucle : une sur les *epochs*, et la 2nde sur les *mini-batches*.\n",
"\n",
"Pour chaque nouveau batch de données, on réalise la succession d'étapes suivantes dans un bloc **GradientTape** qui permet le calcul automatique des gradients : \n",
"\n",
"\n",
"1. Prédiction de la sortie du modèle sur les données du batch\n",
"2. Calcul de la fonction de perte entre sortie du réseau et labels réels associés aux élements du batch\n",
"3. Calcul des gradients de la perte par rapport aux paramètres du réseau (par différentiation automatique)\n",
"4. Mise à jour des paramètres grâce aux gradients calculés. \n",
"\n"
]
},
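{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make these four steps concrete, here is a minimal sketch of a single gradient step on a toy batch. The tiny model, loss, optimizer and random data below (`toy_model`, `toy_loss_fn`, `toy_optimizer`, `x_batch`, `y_batch`) are illustrative assumptions, not part of the lab's code:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"\n",
"# Illustrative toy model, loss and optimizer (assumptions, not the lab's models)\n",
"toy_model = keras.Sequential([keras.layers.Dense(1, activation=\"sigmoid\", input_shape=(2,))])\n",
"toy_loss_fn = keras.losses.BinaryCrossentropy()\n",
"toy_optimizer = keras.optimizers.Adam(learning_rate=1e-2)\n",
"\n",
"# A random toy batch: 4 points in 2-D with binary labels of shape (4, 1)\n",
"x_batch = np.random.randn(4, 2).astype(\"float32\")\n",
"y_batch = np.array([[0.0], [1.0], [0.0], [1.0]], dtype=\"float32\")\n",
"\n",
"with tf.GradientTape() as tape:\n",
"    # 1. Predict the model output on the batch data\n",
"    y_pred = toy_model(x_batch, training=True)\n",
"    # 2. Compute the loss between the predictions and the true labels\n",
"    loss_value = toy_loss_fn(y_batch, y_pred)\n",
"\n",
"# 3. Compute the gradients of the loss w.r.t. the trainable weights\n",
"grads = tape.gradient(loss_value, toy_model.trainable_weights)\n",
"\n",
"# 4. Apply one gradient-descent update to the parameters\n",
"toy_optimizer.apply_gradients(zip(grads, toy_model.trainable_weights))"
]
},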
{
"cell_type": "markdown",
"metadata": {
"id": "fbmhai8PVXVd"
},
"source": [
"### Dataset des 2 clusters"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"id": "XP5XgJRQQm5_"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0 : Loss : 1.5559, Acc : 0.1000, Test Acc : 0.0560\n",
"Epoch 1 : Loss : 1.5317, Acc : 0.1000, Test Acc : 0.0520\n",
"Epoch 2 : Loss : 1.5078, Acc : 0.1000, Test Acc : 0.0520\n",
"Epoch 3 : Loss : 1.4840, Acc : 0.1000, Test Acc : 0.0480\n",
"Epoch 4 : Loss : 1.4604, Acc : 0.1000, Test Acc : 0.0440\n",
"Epoch 5 : Loss : 1.4371, Acc : 0.0000, Test Acc : 0.0400\n",
"Epoch 6 : Loss : 1.4139, Acc : 0.0000, Test Acc : 0.0400\n",
"Epoch 7 : Loss : 1.3910, Acc : 0.0000, Test Acc : 0.0400\n",
"Epoch 8 : Loss : 1.3683, Acc : 0.0000, Test Acc : 0.0440\n",
"Epoch 9 : Loss : 1.3459, Acc : 0.0000, Test Acc : 0.0440\n",
"Epoch 10 : Loss : 1.3237, Acc : 0.0000, Test Acc : 0.0480\n",
"Epoch 11 : Loss : 1.3018, Acc : 0.0000, Test Acc : 0.0480\n",
"Epoch 12 : Loss : 1.2802, Acc : 0.0000, Test Acc : 0.0480\n",
"Epoch 13 : Loss : 1.2588, Acc : 0.0000, Test Acc : 0.0480\n",
"Epoch 14 : Loss : 1.2377, Acc : 0.0000, Test Acc : 0.0480\n",
"Epoch 15 : Loss : 1.2169, Acc : 0.0000, Test Acc : 0.0480\n",
"Epoch 16 : Loss : 1.1964, Acc : 0.0000, Test Acc : 0.0440\n",
"Epoch 17 : Loss : 1.1762, Acc : 0.0000, Test Acc : 0.0400\n",
"Epoch 18 : Loss : 1.1563, Acc : 0.0000, Test Acc : 0.0440\n",
"Epoch 19 : Loss : 1.1367, Acc : 0.0000, Test Acc : 0.0440\n",
"Epoch 20 : Loss : 1.1175, Acc : 0.0000, Test Acc : 0.0440\n",
"Epoch 21 : Loss : 1.0985, Acc : 0.0000, Test Acc : 0.0480\n",
"Epoch 22 : Loss : 1.0798, Acc : 0.0000, Test Acc : 0.0480\n",
"Epoch 23 : Loss : 1.0614, Acc : 0.0000, Test Acc : 0.0480\n",
"Epoch 24 : Loss : 1.0433, Acc : 0.0000, Test Acc : 0.0520\n",
"Epoch 25 : Loss : 1.0256, Acc : 0.0000, Test Acc : 0.0560\n",
"Epoch 26 : Loss : 1.0081, Acc : 0.0000, Test Acc : 0.0600\n",
"Epoch 27 : Loss : 0.9908, Acc : 0.0000, Test Acc : 0.0560\n",
"Epoch 28 : Loss : 0.9739, Acc : 0.0000, Test Acc : 0.0680\n",
"Epoch 29 : Loss : 0.9572, Acc : 0.0000, Test Acc : 0.0760\n",
"Epoch 30 : Loss : 0.9408, Acc : 0.1000, Test Acc : 0.0800\n",
"Epoch 31 : Loss : 0.9247, Acc : 0.1000, Test Acc : 0.0840\n",
"Epoch 32 : Loss : 0.9088, Acc : 0.1000, Test Acc : 0.0880\n",
"Epoch 33 : Loss : 0.8932, Acc : 0.1000, Test Acc : 0.0880\n",
"Epoch 34 : Loss : 0.8778, Acc : 0.1000, Test Acc : 0.1080\n",
"Epoch 35 : Loss : 0.8627, Acc : 0.1000, Test Acc : 0.1040\n",
"Epoch 36 : Loss : 0.8479, Acc : 0.1000, Test Acc : 0.1200\n",
"Epoch 37 : Loss : 0.8333, Acc : 0.1000, Test Acc : 0.1360\n",
"Epoch 38 : Loss : 0.8190, Acc : 0.1000, Test Acc : 0.1600\n",
"Epoch 39 : Loss : 0.8049, Acc : 0.3000, Test Acc : 0.1760\n",
"Epoch 40 : Loss : 0.7910, Acc : 0.3000, Test Acc : 0.2080\n",
"Epoch 41 : Loss : 0.7775, Acc : 0.4000, Test Acc : 0.2360\n",
"Epoch 42 : Loss : 0.7641, Acc : 0.4000, Test Acc : 0.3080\n",
"Epoch 43 : Loss : 0.7510, Acc : 0.4000, Test Acc : 0.3680\n",
"Epoch 44 : Loss : 0.7382, Acc : 0.4000, Test Acc : 0.4280\n",
"Epoch 45 : Loss : 0.7256, Acc : 0.5000, Test Acc : 0.4680\n",
"Epoch 46 : Loss : 0.7133, Acc : 0.5000, Test Acc : 0.5120\n",
"Epoch 47 : Loss : 0.7012, Acc : 0.5000, Test Acc : 0.5240\n",
"Epoch 48 : Loss : 0.6893, Acc : 0.5000, Test Acc : 0.5400\n",
"Epoch 49 : Loss : 0.6777, Acc : 0.5000, Test Acc : 0.5480\n",
"Epoch 50 : Loss : 0.6663, Acc : 0.5000, Test Acc : 0.5800\n",
"Epoch 51 : Loss : 0.6552, Acc : 0.5000, Test Acc : 0.6120\n",
"Epoch 52 : Loss : 0.6443, Acc : 0.5000, Test Acc : 0.6520\n",
"Epoch 53 : Loss : 0.6336, Acc : 0.7000, Test Acc : 0.6960\n",
"Epoch 54 : Loss : 0.6232, Acc : 0.7000, Test Acc : 0.7560\n",
"Epoch 55 : Loss : 0.6130, Acc : 0.7000, Test Acc : 0.7920\n",
"Epoch 56 : Loss : 0.6030, Acc : 0.8000, Test Acc : 0.8120\n",
"Epoch 57 : Loss : 0.5932, Acc : 1.0000, Test Acc : 0.8280\n",
"Epoch 58 : Loss : 0.5836, Acc : 1.0000, Test Acc : 0.8520\n",
"Epoch 59 : Loss : 0.5743, Acc : 1.0000, Test Acc : 0.8720\n",
"Epoch 60 : Loss : 0.5651, Acc : 1.0000, Test Acc : 0.8840\n",
"Epoch 61 : Loss : 0.5562, Acc : 1.0000, Test Acc : 0.8920\n",
"Epoch 62 : Loss : 0.5474, Acc : 1.0000, Test Acc : 0.9000\n",
"Epoch 63 : Loss : 0.5389, Acc : 1.0000, Test Acc : 0.9040\n",
"Epoch 64 : Loss : 0.5306, Acc : 1.0000, Test Acc : 0.9080\n",
"Epoch 65 : Loss : 0.5224, Acc : 1.0000, Test Acc : 0.9120\n",
"Epoch 66 : Loss : 0.5144, Acc : 1.0000, Test Acc : 0.9160\n",
"Epoch 67 : Loss : 0.5067, Acc : 1.0000, Test Acc : 0.9120\n",
"Epoch 68 : Loss : 0.4991, Acc : 1.0000, Test Acc : 0.9160\n",
"Epoch 69 : Loss : 0.4916, Acc : 1.0000, Test Acc : 0.9200\n",
"Epoch 70 : Loss : 0.4844, Acc : 1.0000, Test Acc : 0.9200\n",
"Epoch 71 : Loss : 0.4773, Acc : 1.0000, Test Acc : 0.9240\n",
"Epoch 72 : Loss : 0.4704, Acc : 1.0000, Test Acc : 0.9280\n",
"Epoch 73 : Loss : 0.4636, Acc : 1.0000, Test Acc : 0.9360\n",
"Epoch 74 : Loss : 0.4570, Acc : 1.0000, Test Acc : 0.9400\n",
"Epoch 75 : Loss : 0.4505, Acc : 1.0000, Test Acc : 0.9400\n",
"Epoch 76 : Loss : 0.4442, Acc : 1.0000, Test Acc : 0.9440\n",
"Epoch 77 : Loss : 0.4380, Acc : 1.0000, Test Acc : 0.9440\n",
"Epoch 78 : Loss : 0.4320, Acc : 1.0000, Test Acc : 0.9440\n",
"Epoch 79 : Loss : 0.4261, Acc : 1.0000, Test Acc : 0.9440\n",
"Epoch 80 : Loss : 0.4204, Acc : 1.0000, Test Acc : 0.9440\n",
"Epoch 81 : Loss : 0.4147, Acc : 1.0000, Test Acc : 0.9440\n",
"Epoch 82 : Loss : 0.4092, Acc : 1.0000, Test Acc : 0.9440\n",
"Epoch 83 : Loss : 0.4039, Acc : 1.0000, Test Acc : 0.9440\n",
"Epoch 84 : Loss : 0.3986, Acc : 1.0000, Test Acc : 0.9440\n",
"Epoch 85 : Loss : 0.3935, Acc : 1.0000, Test Acc : 0.9480\n",
"Epoch 86 : Loss : 0.3885, Acc : 1.0000, Test Acc : 0.9520\n",
"Epoch 87 : Loss : 0.3836, Acc : 1.0000, Test Acc : 0.9520\n",
"Epoch 88 : Loss : 0.3788, Acc : 1.0000, Test Acc : 0.9520\n",
"Epoch 89 : Loss : 0.3741, Acc : 1.0000, Test Acc : 0.9520\n",
"Epoch 90 : Loss : 0.3695, Acc : 1.0000, Test Acc : 0.9520\n",
"Epoch 91 : Loss : 0.3650, Acc : 1.0000, Test Acc : 0.9520\n",
"Epoch 92 : Loss : 0.3606, Acc : 1.0000, Test Acc : 0.9520\n",
"Epoch 93 : Loss : 0.3563, Acc : 1.0000, Test Acc : 0.9520\n",
"Epoch 94 : Loss : 0.3521, Acc : 1.0000, Test Acc : 0.9520\n",
"Epoch 95 : Loss : 0.3480, Acc : 1.0000, Test Acc : 0.9520\n",
"Epoch 96 : Loss : 0.3440, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 97 : Loss : 0.3400, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 98 : Loss : 0.3362, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 99 : Loss : 0.3324, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 100 : Loss : 0.3287, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 101 : Loss : 0.3251, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 102 : Loss : 0.3215, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 103 : Loss : 0.3181, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 104 : Loss : 0.3147, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 105 : Loss : 0.3113, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 106 : Loss : 0.3081, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 107 : Loss : 0.3049, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 108 : Loss : 0.3017, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 109 : Loss : 0.2987, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 110 : Loss : 0.2956, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 111 : Loss : 0.2927, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 112 : Loss : 0.2898, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 113 : Loss : 0.2870, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 114 : Loss : 0.2842, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 115 : Loss : 0.2814, Acc : 1.0000, Test Acc : 0.9560\n",
"Epoch 116 : Loss : 0.2787, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 117 : Loss : 0.2761, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 118 : Loss : 0.2735, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 119 : Loss : 0.2710, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 120 : Loss : 0.2685, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 121 : Loss : 0.2661, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 122 : Loss : 0.2637, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 123 : Loss : 0.2613, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 124 : Loss : 0.2590, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 125 : Loss : 0.2567, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 126 : Loss : 0.2545, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 127 : Loss : 0.2523, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 128 : Loss : 0.2501, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 129 : Loss : 0.2480, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 130 : Loss : 0.2459, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 131 : Loss : 0.2439, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 132 : Loss : 0.2418, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 133 : Loss : 0.2399, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 134 : Loss : 0.2379, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 135 : Loss : 0.2360, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 136 : Loss : 0.2341, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 137 : Loss : 0.2323, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 138 : Loss : 0.2304, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 139 : Loss : 0.2286, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 140 : Loss : 0.2269, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 141 : Loss : 0.2251, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 142 : Loss : 0.2234, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 143 : Loss : 0.2217, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 144 : Loss : 0.2201, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 145 : Loss : 0.2184, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 146 : Loss : 0.2168, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 147 : Loss : 0.2152, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 148 : Loss : 0.2137, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 149 : Loss : 0.2121, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 150 : Loss : 0.2106, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 151 : Loss : 0.2091, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 152 : Loss : 0.2076, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 153 : Loss : 0.2062, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 154 : Loss : 0.2048, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 155 : Loss : 0.2034, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 156 : Loss : 0.2020, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 157 : Loss : 0.2006, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 158 : Loss : 0.1992, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 159 : Loss : 0.1979, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 160 : Loss : 0.1966, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 161 : Loss : 0.1953, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 162 : Loss : 0.1940, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 163 : Loss : 0.1928, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 164 : Loss : 0.1915, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 165 : Loss : 0.1903, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 166 : Loss : 0.1891, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 167 : Loss : 0.1879, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 168 : Loss : 0.1867, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 169 : Loss : 0.1855, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 170 : Loss : 0.1844, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 171 : Loss : 0.1833, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 172 : Loss : 0.1821, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 173 : Loss : 0.1810, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 174 : Loss : 0.1799, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 175 : Loss : 0.1789, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 176 : Loss : 0.1778, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 177 : Loss : 0.1767, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 178 : Loss : 0.1757, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 179 : Loss : 0.1747, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 180 : Loss : 0.1737, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 181 : Loss : 0.1727, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 182 : Loss : 0.1717, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 183 : Loss : 0.1707, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 184 : Loss : 0.1697, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 185 : Loss : 0.1688, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 186 : Loss : 0.1678, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 187 : Loss : 0.1669, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 188 : Loss : 0.1660, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 189 : Loss : 0.1651, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 190 : Loss : 0.1642, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 191 : Loss : 0.1633, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 192 : Loss : 0.1624, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 193 : Loss : 0.1615, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 194 : Loss : 0.1607, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 195 : Loss : 0.1598, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 196 : Loss : 0.1590, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 197 : Loss : 0.1581, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 198 : Loss : 0.1573, Acc : 1.0000, Test Acc : 0.9600\n",
"Epoch 199 : Loss : 0.1565, Acc : 1.0000, Test Acc : 0.9600\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"import math\n",
"\n",
"# Données et modèle du problème des 2 clusters\n",
"x_train_lab, y_train_lab, x_train_unlab, y_train_unlab, x_test, y_test = generate_2clusters_dataset(num_lab = 10, num_unlab=740, num_test=250)\n",
"model = create_model_2clusters()\n",
"\n",
"# Hyperparamètres de l'apprentissage\n",
"epochs = 200\n",
"batch_size = 32\n",
"if batch_size < x_train_lab.shape[0]:\n",
" steps_per_epoch = math.floor(x_train_lab.shape[0]/batch_size)\n",
"else:\n",
" steps_per_epoch = 1\n",
" batch_size = x_train_lab.shape[0]\n",
"\n",
"# Instanciation d'un optimiseur et d'une fonction de coût.\n",
"optimizer = keras.optimizers.Adam(learning_rate=1e-2)\n",
"loss_fn = keras.losses.BinaryCrossentropy()\n",
"\n",
"# Préparation des métriques pour le suivi de la performance du modèle.\n",
"train_acc_metric = keras.metrics.BinaryAccuracy()\n",
"test_acc_metric = keras.metrics.BinaryAccuracy()\n",
"\n",
"# Indices de l'ensemble labellisé\n",
"indices = np.arange(x_train_lab.shape[0])\n",
"\n",
"# Boucle sur les epochs\n",
"for epoch in range(epochs):\n",
"\n",
" # A chaque nouvelle epoch, on randomise les indices de l'ensemble labellisé\n",
" np.random.shuffle(indices) \n",
"\n",
" # Et on recommence à cumuler la loss\n",
" cum_loss_value = 0\n",
"\n",
" for step in range(steps_per_epoch):\n",
"\n",
" # Sélection des données du prochain batch\n",
" x_batch = x_train_lab[indices[step*batch_size: (step+1)*batch_size]]\n",
" y_batch = y_train_lab[indices[step*batch_size: (step+1)*batch_size]]\n",
"\n",
" # Etape nécessaire pour comparer y_batch à la sortie du réseau\n",
" y_batch = np.expand_dims(y_batch, 1)\n",
"\n",
" # Les opérations effectuées par le modèle dans ce bloc sont suivies et permettront\n",
" # la différentiation automatique.\n",
" with tf.GradientTape() as tape:\n",
"\n",
" # Application du réseau aux données d'entrée\n",
" y_pred = model(x_batch, training=True) # Logits for this minibatch\n",
"\n",
" # Calcul de la fonction de perte sur ce batch\n",
" loss_value = loss_fn(y_batch, y_pred)\n",
"\n",
" # Calcul des gradients par différentiation automatique\n",
" grads = tape.gradient(loss_value, model.trainable_weights)\n",
"\n",
" # Réalisation d'une itération de la descente de gradient (mise à jour des paramètres du réseau)\n",
" optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
"\n",
" # Mise à jour de la métrique\n",
" train_acc_metric.update_state(y_batch, y_pred)\n",
"\n",
" cum_loss_value = cum_loss_value + loss_value\n",
"\n",
" # Calcul de la précision à la fin de l'epoch\n",
" train_acc = train_acc_metric.result()\n",
"\n",
" # Calcul de la précision sur l'ensemble de test à la fin de l'epoch\n",
" test_logits = model(x_test, training=False)\n",
" test_acc_metric.update_state(np.expand_dims(y_test, 1), test_logits)\n",
" test_acc = test_acc_metric.result()\n",
"\n",
" print(\"Epoch %4d : Loss : %.4f, Acc : %.4f, Test Acc : %.4f\" % (epoch, float(cum_loss_value/steps_per_epoch), float(train_acc), float(test_acc)))\n",
"\n",
" # Remise à zéro des métriques pour la prochaine epoch\n",
" train_acc_metric.reset_states()\n",
" test_acc_metric.reset_states() "
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"id": "FcnTF5WWVacl"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA+GElEQVR4nO2df5AdV3Xnv/e9N/M0EgQS2WAU/1QEXsx6sZFRZWQbDyPviLUF1pY2m02R2MkYi0nZRs4ST0X+UTNCZhzDVqw/WKSnwk7ZCUUqtU7Cks3GwY4HsXpDQCZYJCEh3kAcIFkcVUgUI3mkeWf/uHPUt++7t/t2v36/z6eqa2bedN++fbvf954+99xzFRFBEARB6F9K3a6AIAiC0Boi5IIgCH2OCLkgCEKfI0IuCILQ54iQC4Ig9DmVbpz0vPPOo0svvbQbpxYEQehbnn/++X8kovPtz7si5JdeeimOHTvWjVMLgiD0LUqpv3V9Lq4VQRCEPkeEXBAEoc8RIRcEQehzRMgFQRD6HBFyQRCEPkeEXBAEoc8RIRf6iqUl4OGH9U9BEDRdiSMXhDwsLQHbtgHLy8DoKPDss8D4eLdrJQjdRyxyoW9YXNQivrKify4udrtGgtAbiJALfcPEhLbEy2X9c2Ki2zUShN5AXCtC3zA+rt0pi4taxMWtIggaEXKhrxgfFwEXBBtxrQiCIPQ5IuSCIAh9jgi5IAhCnyNCLgiC0OeIkAuCIPQ5IuSCIAh9jgi5IAhCnyNCLgiC0OeIkAuCIPQ5IuSCIAh9jgi5IAhCnyNCLgiC0OeIkAuCIPQ5IuSCIAh9jgi5IAhCnyNCLgiC0OeIkAuCIPQ5IuSCIAh9jgi5IAhCn1OYkCulykqpP1VK/X5RZQqCIAjpFGmR7wHwjQLLEwRBEAIoRMiVUhcCuBnAp4ooTxAEQQinKIv8AIBZAI2CyhMEQRACaVnIlVI7AHyfiJ5P2W+3UuqYUurYyy+/3OppBUEQhFWKsMivBfA+pdS3AfwWgEml1G/aOxHRYSK6hoiuOf/88ws4rSBkZ2kJePhh/VMQBoVKqwUQ0V4AewFAKTUB4JeJ6GdbLVcQiubwYeCuu4CVFaBaBZ59Fhgf73atBKF1JI5cGAqWloA77wTOnAEaDeDVV4HFxW7XqnvIm8lg0bJFbkJEiwAWiyxTEIpgcVELOFMuAxMT3apNd7n9duAznwGWl4HRUXkzGQTEIheGgokJ7U4plYBKBfjEJ4ZXvB5/XIv4yor+OcxvJoNCoRa5IPQq4+Pa8lxc1KI+rCLOjI5GFvmwvpkMEmKRC0PD+Diwd+9wivj8PKCU3gDg1Cltkf/Mz2Rrj/n5dtROaBVFRB0/6TXXXEPHjh3r+HkFQdBinvdr38qxQusopZ4nomvsz8UiF4QeIy2iRCJOBBsRckHoICEivW0b8OCD+qe9X9r/Q5iby7a/7Zbh38XN0juIkAtChwgR4cXF5IiStP+HkFWA5+e1O4VdKvy7qxwR9+4gQi7EGKYvYpEuipCyQkR4YkJHkpTL7oiStP93m337ul2D4UTCD4UY+/YNh5izdVzEpJjQsiYmgJERbc2WSsC3vgUcOABs2aL3Vyo9TLLbYZRZ3TJCZxCLXBhKinBRZCnrzBnghReAH/sxPSEJAD71KeCXfxm48UbgJ34COHRI75cWJtnJMEq7U/e5U8SH3l1EyIWh/CLmcVH43CdpZf3rvwKTk8CHPwx873ta7M+c0Zb5yoqO6f7Wt/T/t23T+/tIuidFuIrs8kNcJVl86EJ7kDhyIUa/xwkvLYW7HbLum+Q+8ZV15owW8a98RSfqSqNa1a6WZ5/Vbhgb3/2x63fgAHDihP/afPW1y8/6PPT789Pr+OLIQUQd3zZv3kxCbwJ0uwb5qdeJxsaIymX9s14vruyFBV0uoH8uLIQdd/Ag0dq1bKOGbWvXEh065C7Pd3/M+pVKRCMj+u9qlWhmJt4WSe0EEM3Nues1N5d+vSH7CPkBcIwcmiquFSFGPw9m5fF7h7oj2H1SKmmrc/369HIXFoD9+4Ef/jDwAlb54Q+Bj30ssmxDXF+me6dU0m2wsqLfAmq1eLij3U4PPBAvn90p/CykuUrMNhR3SpdwqXu7N7HIhXaQ1SLPun+tpi3dUil5fy63VMpmiZvbunVER482l530xlSva8u8VtPnVyoqz3yLSLPIQ8+XVpZQPPBY5BJ+KAwMWUPzXBZ80jEnTuic5o1G8v5cbqOFpcjPntV+9a1bw48ZH4/qc+WVwJNP6pS1KyvxQdgs7ZT2hpa1DYX2IEIuDBSmmKXB7ojQdK6h+/N+p0/nH/hbXgZOnmz+nIXVHKx8+ulmlwa3w623ugXb107T09pNwvsnuVMWF7WLSVLidh+JWhEGlpColCyRK1n2X1oCPvpR4I/+SEeuZKVaBR55BPinf2oWUztC5dSpYiJFQic2ZY2QCWF+XvzrIUjUijBU9ILv9uhR7evO4yNfs0Yf7/JR2xE0RUUahUbm5I3gSaKfo6U6CSRqRRgm0iJY2m39seX+utflOz4p5nxiIoowWVnRP7NM4so7sSnrfkIHcal7uzexyIV2k2aRt9MC5HObUSNFbGaMNkeo1Ovp12IfZ7dLo6Gt/0cfJbrjDqKbb9aRL41G8jXy+fPSSrz6sAKPRS4+cmFgOXwYeOopYNcuYPfu+P/aOQPx4Yd1qtqVFX2e170O+MEPspWxZg3wx3+so1bS6pl2Leb/zbqVSsB73wscPw58//val3/mjJ5ROjICvOENwOwscPvt7lmmRSIzQsPw+cglakXoObIOQLqOA4B77tFulS9+UYfjPf10PHcIuyfm5tJdEmaURtrA3sSEdjs0Glqc/vmfw+q/di3wlrcAt9wCbN8efu2uSJakuo2OatcNkW6T06fj+ywv6+1b3wI+9CHg4x8H9uzROWDe/W4JL+xJXGZ6uzdxrQg+8g5S2sfNzCQPyLE7IuQ13p7goxRRpaLdD779q9Vml4FrglC1qgdEN27U0/KXl+Nl+dwpadc/Pe13Wxw5QnTJJUSjo9ndO5WKPr5V7GsRd0oY8LhWRMiFniJvRIR93MxMmI88xFdulm1uIyPN5c7NNe/Pwu8Sxu3bIz91ErWaLoNnlU5PJ1//woL/GvPkfzG3Sy4hOnmy+bpD6YWIon7FJ+QStSL0FHkjIuzjbr1Vx0Hv3++Oh86SU8bMs2KystIcDbNvX7wu1SrwwQ8CR45EUghEP+fmdBlf+pL//EtLwF136dmejYZ2izz+ePL1c7vZLiMinccla/4Xk+99D7jppnh8fJaVgbLmxJHFpgNwqXu7N7HIhSTyRkSEHpcnWoItYtNNkmTp++riOzdANDvrPpata9O9YVrZvP+OHUSTk8nX1kpsu7mtXavrGxo5Y5LFIhfrPQ7EtSL0EkX4RIsIgUsTID6H6XMvlYimpuL+3SwdA3/uE3NbuFgok/zWXLerrnLvc8MN+pyPPprPN55lCx13CLl37Zh81M+IkAs9RRYLzkVRllpSPcxzVKtaAH3nY2EKv
S6fOG/a5BYuIBI+PpbPafrjFxaSO5Y77iheuIucXepqJ7HII0TIhfaRwzT2ffHbaam5LMUk69E1gGrWzU4by/uGNMPCgnvCkGmR+wZI+fOxMaKtW937mFErTL2uB2iLFHGldD1avZ9JFFHGoCBCLrSHDCZTmgsipKgs+9qYVixbtkkCkXQO8392hEpapzI3Fz+ehZnDGe16sQXO5+IOoFSKIllYoBcWdGdgdhKmr72VHOkuEb/xRl2+q0MMuUeDItKdCp8UIRfaQ04npsuCCynKtjCziABbseWydpNUq5HI1Grusnzn8IUkVqvNgm93HHwNO3Ykd2zmuXwuDRZm+6e5mQtKrFlTnJBXq0QHDvjbO+1+DpLbpF2upebztEnIAVwE4DkAfwHgzwHsSTtGhHyAyPltdD34IUVl/cIkRYmYli2vcRl6Ca58Kkpp94vrekwfO1vf5rW6rst8+6hUml0xV1wRdUY+l0mlEr+eWo3orW8tZsDTt4qR6/pd7WoKvd12/cYgCPmbALxj9
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAACIlUlEQVR4nO2dd5gUVdaH31td3T05JxiGHAUBUWRXRcFMUERUjBjWLCb4FrPrrlkXDCQDq4iCEgwoqGAAE0ZQQYIgcRiY0JN7Qoeq+/1R0z3dk4dpmAHqfZ59Vrq7bt2q7jn31Lnn/I6QUmJiYmJicviitPYETExMTExahmnITUxMTA5zTENuYmJicphjGnITExOTwxzTkJuYmJgc5qitctatK0OSKrNw5U/kRvek35BhoRjOxMTEpM0SF2HluI7xoq73RKukH66ZHpKTSil568s/+GynxuBL7iIsIjIUw5qYmJi0OdJiwzhvQPs6DflhHVoRQnDVGcfyzNiurJ/3EDt++7a1p2RiYmJyyAmJIRdCxAkhlgghtgghNgsh/h6KcZtKWmIMr9x2Or0Kv+WbuU9QWV52KE9vYmJi0qqEyiN/AfhUStkbGABsDtG4TUYIwTVnHcvTF3Tm93kPsfP3NYd6CiYmJiatQotj5EKIWOA3oKts6mB1xMh1BGWWBDQ1DKgzDNRkJJLvt+xjawF0Pv50VKutReOBEY+v1KDYoyBbOD8TExOT5tJQjDwUWStdgDzgdSHEAGAtcKeUMii+IYS4EbgR4OUp47lxzMlBg5RZErBGxRElNEQI7OTIQZ0YUlzO/G/eI673KaR26d2yAaUkTHqhrJIij6XlEzQxMTEJEaHwyE8AfgBOllL+KIR4ASiRUj5U70F1eORF9nRiw60hMeKBSAmf/7qLPwoEPU4+D6vN3qLB3BVOsitNQ25iYnJoOdhZK3uBvVLKH6v+vQQY1PxhRMiNOIAQcNagzkwYnMSOL94kZ9efLRpMHIxJmpiYmLSAFhtyKWU2kCmE6FX10hnAppaOG2oSYyO5fdQA0op/Z+OqJXjcrtaekomJiUlICFXWyu3AfCHEemAg8ESIxg0pQsDZg7pw9eAk/vpsHjm7a3vnv3z7Jf847xSuHfl3Fs6Z3gqzNDExMWkeITHkUsrfpJQnSCn7SykvkFIWhmLcg0VibCR3njeQtMLf2LTqXTwewzvXNI2Zj9/PY7Pm88rSr1j9yQfs3t6CUIyJiYnJIaB1tFZayIlXPoCjuKLW60mx4fz01uNNGkMIOPv4rhxXVMb8lfNI7HcqjsJS2nXsTLuMTgCcNmIM369aQaduvRoZzcTExKT1OCwNuaO4gr43PVfr9Y0v393ssZLjDO98xbrf+OLLDSSlpPnfS0ptx5/rf23RXE1MTEwONoe11kqoEALOPb4rp3aLpjhrG7l7trX2lExMTEyazGHpkR8s+nRtT6xVklywjk07NpC7by+JqWmNH2hiYmLSipiGPIDB/Xqwbc8+eqVGcGxGGP99/A1u/9czrT0tkxDx5MTLcDpLa70eFRXNfTPeboUZmZiEBtOQB6CqFmY8cBPn3PAImq5zyyVn0T+ykM1fvU+Pk0ajWq2tPUWTFuB0ltL1+toppTvm3N4KszExCR2HpSFPig2vc2MzKTa8xWOPPO0ERp52QtBrJxQ6WbDyDZIGDCMu0Qy1mJiYtC0OS0Pe1BTDUJESH8Wd5w3g419+4a8ddqw9hmG1t0CzxcTExCSEmFkrTUQIGDW4G6P7RPPTnPvI3GymJZqYmLQNTEPeTBJjInn9jtNJ3bWMNQummZotJiYmrc5hGVppbYQQ3HbeIHbsy+ffr95LpzOuJqP3wNaelkkjREVF17mxGRUV3QqzMTEJHS3WIz8g6tQj70BcRNtfV4rKvcS59vr/res6s5b9ytriaAaPu7VleucmJiYm9XCw9ciPahRFYeL5x/Pg6Qn88Mo9ZG75rbWnZGJicpRhGvIArnvgBVJOuYp+509s9rHd0pOYe8cZJG1fypq3nzNj5yYmJocM05AHcM3YM/j0lUcO+HhFUbhjzAk8MDyeH165h6yt60M3ORMTE5N6OKwNuaOwhHET/0N+UUlIxjv1hH4kxEa1eJzuHZKZe8cZxG99j+/feR6vxx2C2ZmYmJjUzWFtyOe9t4LCrL94490VrT2VWiiKwp0XDOa+YXGseWkKWVs3tPaUTExMjlAOW0PuKCxh2WermH1hKss+WxUyrzzU9OiQzBt3nkHsn0v4/p0XTO/cxMQk5By2hnzeeysY3U3QKzWM0d1Em/TKfSiKwt1jT+Te02JY8/I9ZP31R2tPycTE5AjisDTkPm98wvExAEw4PqZNe+U+emakMPf204nZvJgfFk3H6/W09pRMTEyOAA5LQ+7zxpOijAKipCg1JF75Zf/3LH+/bAp/7sqiw/Br+d+7K0Mx3SAsFoVJY09kyilRrHlpCvu2bwz5OUxMTI4uDsvKzvNvepB9+7Nrvd6+XRofvvxYy+fXADUrO1uCpuk8v/QXNrmSOWHsTaiqqXduYmJSNw1Vdh6Whrw1CaUh97F5dy5PvL+eHiNvpF3XPiEd28TE5MjALNFv4/TpZMTOIza8w4+LZ5ixcxMTk2bRhgy5pDUeDpqDMb+DM0mLReH/xp3IpJMi+O6le9i/c8tBOY+JicmRR5sx5BZvJS5pabPGXEpwSQsWb+VBPc8xnVJ54/bhhP2+gB+XzELzeg/q+UxMTA5/2kyMXEdQZklAU8OAOsNArYzE4q0kUitAOUheeU027srhyQ/+oMeom2jfpdchOaeJiUnb5LDY7DSpG69XY+oHv/CXlsbxY27Eorb9DWETE5PQc0g2O4UQFiHEr0KIZaEa0wRU1cI9Fw3hzhPtfDP7n+zb+WdrT8nExKSNEcoY+Z3A5hCOZxJAvy5pvHH7cGy/vsWP775kxs5NTEz8hMSQCyE6AKOAOaEYz6RuVNXCfZf8jTsHW/nmpX+yf5fpnZuYmITOI38emALoIRrPpAGO7ZrGGxOHY137Jj+997LpnZuYHOW0eOdMCDEayJVSrhVCDGvgczcCNwK8PGU8N445uaWnPqpRVQv3jf8763fs5+nZU+h93s2kde7Z2tM6Ynly4mU4naW1Xo+Kiua+GW+3woxMTKppcdaKEOJJ4CrAC4QBMcB7Usor6z3IzFoJKR6vxrNLfmKnpSMnnP8PFIultad0RPHkxMvI3LWdtEuDdXwsFgvlK57j8blHx/6+uZi1Lg1lrbTYI5dS3gfcB1Dlkf9fg0bcJORYVQv3X/p3ft++n2dm/ZPe599CWqcerT2tIwansxRrVAL2pI5Br7sce1ppRq2D01lK1+un13p9x5zbW2E2JoG0mcpOk5YzoFs75t4+DOXnN/j5/VfRNa21p2RiYnIICKkhl1KullKODuWYJs3Dqlp44NK/c+txgq9n/5OczB2tPSUTE5ODjOmRH6EM7N6OuROHIX/4H78s/Z/pnZuYHMGY9d5HMFbVwkOXncSv2/bx39lTOOaCiaR06NLa0zoss
YRFsG/uXUGveZwFZHTu1joTMjEJwDTkRwHH9WjP3ImpPLV4Dr/YuzJo9LUoivkw1lSioqLBWQphwX8uUUnd2ky2RmMZJaHIOImKiq5zYzMqKrr5EzYJKaYhP0rweefrtu5j6qx/0veCiSSb3nmTaAvGujFD3FhGSSgyTtrCfTCpG9OQH2UM6tme17uk8NTiOawN785xo642vfPDADP1z6QhTEN+FGKzqjx8+Un8snUfz82eYnjn6Z1be1qHDaEsjDGLbExCgWnIj2JOqPLOn1j0MmsjenHcqAmmd94EQukdm562SSgwDflRjs2q8sgVp/DTliyen/VP+o293fTOQ0xDXveBkJ25A60qnbTQkcsD14ym0JHLxjmT6Xv91GaNVZzv4IFrapd+mE8EhxemITcB4MTe6bzetco7j+rFcSNN7zxUhNrr1jTNLxdgjUqg6/XTyc7cQdb8+2qN6Vss6ss4kbq3SXMzQ0BtG9OQm/ix26z8+8pT+HHzXp6fNYX+F95OUvtOrT2tQ0JbN
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from mlxtend.plotting import plot_decision_regions\n",
"\n",
"# Affichage des données\n",
"plt.plot(x_train_unlab[y_train_unlab==0,0], x_train_unlab[y_train_unlab==0,1], 'b.')\n",
"plt.plot(x_train_unlab[y_train_unlab==1,0], x_train_unlab[y_train_unlab==1,1], 'r.')\n",
"\n",
"plt.plot(x_test[y_test==0,0], x_test[y_test==0,1], 'b+')\n",
"plt.plot(x_test[y_test==1,0], x_test[y_test==1,1], 'r+')\n",
"\n",
"plt.plot(x_train_lab[y_train_lab==0,0], x_train_lab[y_train_lab==0,1], 'b.', markersize=30)\n",
"plt.plot(x_train_lab[y_train_lab==1,0], x_train_lab[y_train_lab==1,1], 'r.', markersize=30)\n",
"\n",
"plt.show()\n",
"\n",
"#Affichage de la frontière de décision\n",
"plot_decision_regions(x_train_unlab, y_train_unlab, clf=model, legend=2)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "caF0geTEx5Zv"
},
"source": [
"### Dataset des 2 lunes"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GdFbVzKbMcYE"
},
"source": [
"**Travail à faire** : Mettez en place le même apprentissage pour le dataset des 2 lunes. "
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0 : Loss : 0.6756, Acc : 0.7000, Test Acc : 0.7200\n",
"Epoch 1 : Loss : 0.6550, Acc : 0.9000, Test Acc : 0.7560\n",
"Epoch 2 : Loss : 0.6354, Acc : 0.9000, Test Acc : 0.7600\n",
"Epoch 3 : Loss : 0.6172, Acc : 0.8000, Test Acc : 0.8000\n",
"Epoch 4 : Loss : 0.6000, Acc : 0.8000, Test Acc : 0.8120\n",
"Epoch 5 : Loss : 0.5837, Acc : 0.8000, Test Acc : 0.8200\n",
"Epoch 6 : Loss : 0.5677, Acc : 0.8000, Test Acc : 0.8280\n",
"Epoch 7 : Loss : 0.5520, Acc : 0.8000, Test Acc : 0.8360\n",
"Epoch 8 : Loss : 0.5363, Acc : 0.8000, Test Acc : 0.8360\n",
"Epoch 9 : Loss : 0.5205, Acc : 0.8000, Test Acc : 0.8360\n",
"Epoch 10 : Loss : 0.5050, Acc : 0.8000, Test Acc : 0.8400\n",
"Epoch 11 : Loss : 0.4898, Acc : 0.8000, Test Acc : 0.8400\n",
"Epoch 12 : Loss : 0.4749, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 13 : Loss : 0.4607, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 14 : Loss : 0.4475, Acc : 0.8000, Test Acc : 0.8400\n",
"Epoch 15 : Loss : 0.4347, Acc : 0.8000, Test Acc : 0.8360\n",
"Epoch 16 : Loss : 0.4223, Acc : 0.8000, Test Acc : 0.8360\n",
"Epoch 17 : Loss : 0.4104, Acc : 0.8000, Test Acc : 0.8360\n",
"Epoch 18 : Loss : 0.3991, Acc : 0.8000, Test Acc : 0.8360\n",
"Epoch 19 : Loss : 0.3886, Acc : 0.8000, Test Acc : 0.8360\n",
"Epoch 20 : Loss : 0.3787, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 21 : Loss : 0.3694, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 22 : Loss : 0.3606, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 23 : Loss : 0.3523, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 24 : Loss : 0.3446, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 25 : Loss : 0.3374, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 26 : Loss : 0.3306, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 27 : Loss : 0.3242, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 28 : Loss : 0.3181, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 29 : Loss : 0.3125, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 30 : Loss : 0.3072, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 31 : Loss : 0.3021, Acc : 0.8000, Test Acc : 0.8440\n",
"Epoch 32 : Loss : 0.2972, Acc : 0.8000, Test Acc : 0.8520\n",
"Epoch 33 : Loss : 0.2924, Acc : 0.8000, Test Acc : 0.8520\n",
"Epoch 34 : Loss : 0.2878, Acc : 0.8000, Test Acc : 0.8560\n",
"Epoch 35 : Loss : 0.2834, Acc : 0.8000, Test Acc : 0.8560\n",
"Epoch 36 : Loss : 0.2790, Acc : 0.8000, Test Acc : 0.8560\n",
"Epoch 37 : Loss : 0.2747, Acc : 0.8000, Test Acc : 0.8560\n",
"Epoch 38 : Loss : 0.2705, Acc : 0.8000, Test Acc : 0.8560\n",
"Epoch 39 : Loss : 0.2664, Acc : 0.8000, Test Acc : 0.8560\n",
"Epoch 40 : Loss : 0.2623, Acc : 0.8000, Test Acc : 0.8600\n",
"Epoch 41 : Loss : 0.2583, Acc : 0.8000, Test Acc : 0.8640\n",
"Epoch 42 : Loss : 0.2544, Acc : 0.8000, Test Acc : 0.8680\n",
"Epoch 43 : Loss : 0.2506, Acc : 0.8000, Test Acc : 0.8800\n",
"Epoch 44 : Loss : 0.2468, Acc : 0.8000, Test Acc : 0.8800\n",
"Epoch 45 : Loss : 0.2431, Acc : 0.8000, Test Acc : 0.8880\n",
"Epoch 46 : Loss : 0.2394, Acc : 0.8000, Test Acc : 0.8880\n",
"Epoch 47 : Loss : 0.2362, Acc : 0.8000, Test Acc : 0.8880\n",
"Epoch 48 : Loss : 0.2334, Acc : 0.8000, Test Acc : 0.8880\n",
"Epoch 49 : Loss : 0.2306, Acc : 0.8000, Test Acc : 0.8880\n",
"Epoch 50 : Loss : 0.2281, Acc : 0.8000, Test Acc : 0.8920\n",
"Epoch 51 : Loss : 0.2258, Acc : 0.8000, Test Acc : 0.8960\n",
"Epoch 52 : Loss : 0.2237, Acc : 0.8000, Test Acc : 0.9080\n",
"Epoch 53 : Loss : 0.2215, Acc : 0.8000, Test Acc : 0.9040\n",
"Epoch 54 : Loss : 0.2194, Acc : 0.8000, Test Acc : 0.9040\n",
"Epoch 55 : Loss : 0.2174, Acc : 0.8000, Test Acc : 0.9080\n",
"Epoch 56 : Loss : 0.2154, Acc : 0.8000, Test Acc : 0.9080\n",
"Epoch 57 : Loss : 0.2134, Acc : 0.8000, Test Acc : 0.9080\n",
"Epoch 58 : Loss : 0.2115, Acc : 0.8000, Test Acc : 0.9120\n",
"Epoch 59 : Loss : 0.2096, Acc : 0.8000, Test Acc : 0.9120\n",
"Epoch 60 : Loss : 0.2078, Acc : 0.8000, Test Acc : 0.9120\n",
"Epoch 61 : Loss : 0.2061, Acc : 0.8000, Test Acc : 0.9120\n",
"Epoch 62 : Loss : 0.2044, Acc : 0.8000, Test Acc : 0.9120\n",
"Epoch 63 : Loss : 0.2028, Acc : 0.8000, Test Acc : 0.9120\n",
"Epoch 64 : Loss : 0.2012, Acc : 0.8000, Test Acc : 0.9120\n",
"Epoch 65 : Loss : 0.1997, Acc : 0.9000, Test Acc : 0.9120\n",
"Epoch 66 : Loss : 0.1982, Acc : 0.9000, Test Acc : 0.9120\n",
"Epoch 67 : Loss : 0.1968, Acc : 0.9000, Test Acc : 0.9120\n",
"Epoch 68 : Loss : 0.1954, Acc : 0.9000, Test Acc : 0.9120\n",
"Epoch 69 : Loss : 0.1941, Acc : 0.9000, Test Acc : 0.9120\n",
"Epoch 70 : Loss : 0.1928, Acc : 0.9000, Test Acc : 0.9120\n",
"Epoch 71 : Loss : 0.1916, Acc : 0.9000, Test Acc : 0.9080\n",
"Epoch 72 : Loss : 0.1904, Acc : 0.9000, Test Acc : 0.9080\n",
"Epoch 73 : Loss : 0.1893, Acc : 0.9000, Test Acc : 0.9080\n",
"Epoch 74 : Loss : 0.1882, Acc : 0.9000, Test Acc : 0.9080\n",
"Epoch 75 : Loss : 0.1872, Acc : 0.9000, Test Acc : 0.9000\n",
"Epoch 76 : Loss : 0.1862, Acc : 0.9000, Test Acc : 0.9000\n",
"Epoch 77 : Loss : 0.1854, Acc : 0.9000, Test Acc : 0.8960\n",
"Epoch 78 : Loss : 0.1844, Acc : 0.9000, Test Acc : 0.8960\n",
"Epoch 79 : Loss : 0.1835, Acc : 0.9000, Test Acc : 0.8960\n",
"Epoch 80 : Loss : 0.1826, Acc : 0.9000, Test Acc : 0.8960\n",
"Epoch 81 : Loss : 0.1818, Acc : 0.9000, Test Acc : 0.8960\n",
"Epoch 82 : Loss : 0.1812, Acc : 0.9000, Test Acc : 0.8960\n",
"Epoch 83 : Loss : 0.1805, Acc : 0.9000, Test Acc : 0.9000\n",
"Epoch 84 : Loss : 0.1797, Acc : 0.9000, Test Acc : 0.9000\n",
"Epoch 85 : Loss : 0.1789, Acc : 0.9000, Test Acc : 0.9000\n",
"Epoch 86 : Loss : 0.1782, Acc : 0.9000, Test Acc : 0.8960\n",
"Epoch 87 : Loss : 0.1776, Acc : 0.9000, Test Acc : 0.8960\n",
"Epoch 88 : Loss : 0.1769, Acc : 0.9000, Test Acc : 0.8920\n",
"Epoch 89 : Loss : 0.1763, Acc : 0.9000, Test Acc : 0.8920\n",
"Epoch 90 : Loss : 0.1757, Acc : 0.9000, Test Acc : 0.8800\n",
"Epoch 91 : Loss : 0.1750, Acc : 0.9000, Test Acc : 0.8800\n",
"Epoch 92 : Loss : 0.1744, Acc : 0.9000, Test Acc : 0.8800\n",
"Epoch 93 : Loss : 0.1739, Acc : 0.9000, Test Acc : 0.8800\n",
"Epoch 94 : Loss : 0.1733, Acc : 0.9000, Test Acc : 0.8800\n",
"Epoch 95 : Loss : 0.1727, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 96 : Loss : 0.1722, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 97 : Loss : 0.1716, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 98 : Loss : 0.1711, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 99 : Loss : 0.1706, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 100 : Loss : 0.1701, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 101 : Loss : 0.1696, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 102 : Loss : 0.1691, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 103 : Loss : 0.1687, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 104 : Loss : 0.1682, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 105 : Loss : 0.1677, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 106 : Loss : 0.1673, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 107 : Loss : 0.1669, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 108 : Loss : 0.1664, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 109 : Loss : 0.1660, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 110 : Loss : 0.1656, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 111 : Loss : 0.1652, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 112 : Loss : 0.1648, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 113 : Loss : 0.1644, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 114 : Loss : 0.1640, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 115 : Loss : 0.1636, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 116 : Loss : 0.1632, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 117 : Loss : 0.1628, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 118 : Loss : 0.1624, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 119 : Loss : 0.1620, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 120 : Loss : 0.1617, Acc : 0.9000, Test Acc : 0.8760\n",
"Epoch 121 : Loss : 0.1613, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 122 : Loss : 0.1610, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 123 : Loss : 0.1606, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 124 : Loss : 0.1602, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 125 : Loss : 0.1599, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 126 : Loss : 0.1595, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 127 : Loss : 0.1591, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 128 : Loss : 0.1588, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 129 : Loss : 0.1585, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 130 : Loss : 0.1581, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 131 : Loss : 0.1578, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 132 : Loss : 0.1574, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 133 : Loss : 0.1571, Acc : 0.9000, Test Acc : 0.8720\n",
"Epoch 134 : Loss : 0.1567, Acc : 0.9000, Test Acc : 0.8680\n",
"Epoch 135 : Loss : 0.1564, Acc : 0.9000, Test Acc : 0.8680\n",
"Epoch 136 : Loss : 0.1561, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 137 : Loss : 0.1558, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 138 : Loss : 0.1554, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 139 : Loss : 0.1551, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 140 : Loss : 0.1548, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 141 : Loss : 0.1544, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 142 : Loss : 0.1541, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 143 : Loss : 0.1538, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 144 : Loss : 0.1535, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 145 : Loss : 0.1531, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 146 : Loss : 0.1528, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 147 : Loss : 0.1525, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 148 : Loss : 0.1522, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 149 : Loss : 0.1519, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 150 : Loss : 0.1515, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 151 : Loss : 0.1512, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 152 : Loss : 0.1509, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 153 : Loss : 0.1506, Acc : 0.9000, Test Acc : 0.8640\n",
"Epoch 154 : Loss : 0.1503, Acc : 0.9000, Test Acc : 0.8600\n",
"Epoch 155 : Loss : 0.1499, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 156 : Loss : 0.1496, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 157 : Loss : 0.1493, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 158 : Loss : 0.1490, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 159 : Loss : 0.1487, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 160 : Loss : 0.1484, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 161 : Loss : 0.1481, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 162 : Loss : 0.1477, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 163 : Loss : 0.1474, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 164 : Loss : 0.1471, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 165 : Loss : 0.1468, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 166 : Loss : 0.1465, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 167 : Loss : 0.1462, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 168 : Loss : 0.1458, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 169 : Loss : 0.1455, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 170 : Loss : 0.1452, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 171 : Loss : 0.1449, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 172 : Loss : 0.1446, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 173 : Loss : 0.1443, Acc : 0.9000, Test Acc : 0.8560\n",
"Epoch 174 : Loss : 0.1440, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 175 : Loss : 0.1437, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 176 : Loss : 0.1433, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 177 : Loss : 0.1430, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 178 : Loss : 0.1427, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 179 : Loss : 0.1424, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 180 : Loss : 0.1421, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 181 : Loss : 0.1417, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 182 : Loss : 0.1414, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 183 : Loss : 0.1411, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 184 : Loss : 0.1408, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 185 : Loss : 0.1405, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 186 : Loss : 0.1402, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 187 : Loss : 0.1398, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 188 : Loss : 0.1395, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 189 : Loss : 0.1392, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 190 : Loss : 0.1389, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 191 : Loss : 0.1386, Acc : 0.9000, Test Acc : 0.8520\n",
"Epoch 192 : Loss : 0.1382, Acc : 0.9000, Test Acc : 0.8480\n",
"Epoch 193 : Loss : 0.1379, Acc : 0.9000, Test Acc : 0.8480\n",
"Epoch 194 : Loss : 0.1376, Acc : 0.9000, Test Acc : 0.8480\n",
"Epoch 195 : Loss : 0.1373, Acc : 0.9000, Test Acc : 0.8480\n",
"Epoch 196 : Loss : 0.1369, Acc : 0.9000, Test Acc : 0.8480\n",
"Epoch 197 : Loss : 0.1366, Acc : 0.9000, Test Acc : 0.8480\n",
"Epoch 198 : Loss : 0.1363, Acc : 0.9000, Test Acc : 0.8480\n",
"Epoch 199 : Loss : 0.1360, Acc : 0.9000, Test Acc : 0.8480\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"import math\n",
"\n",
"# Données et modèle du problème des 2 clusters\n",
"x_train_lab, y_train_lab, x_train_unlab, y_train_unlab, x_test, y_test = generate_2moons_dataset(num_lab = 10, num_unlab=740, num_test=250)\n",
"model = create_model_2moons()\n",
"\n",
"# Hyperparamètres de l'apprentissage\n",
"epochs = 200\n",
"batch_size = 32\n",
"if batch_size < x_train_lab.shape[0]:\n",
" steps_per_epoch = math.floor(x_train_lab.shape[0]/batch_size)\n",
"else:\n",
" steps_per_epoch = 1\n",
" batch_size = x_train_lab.shape[0]\n",
"\n",
"# Instanciation d'un optimiseur et d'une fonction de coût.\n",
"optimizer = keras.optimizers.Adam(learning_rate=1e-2)\n",
"loss_fn = keras.losses.BinaryCrossentropy()\n",
"\n",
"# Préparation des métriques pour le suivi de la performance du modèle.\n",
"train_acc_metric = keras.metrics.BinaryAccuracy()\n",
"test_acc_metric = keras.metrics.BinaryAccuracy()\n",
"\n",
"# Indices de l'ensemble labellisé\n",
"indices = np.arange(x_train_lab.shape[0])\n",
"\n",
"# Boucle sur les epochs\n",
"for epoch in range(epochs):\n",
"\n",
" # A chaque nouvelle epoch, on randomise les indices de l'ensemble labellisé\n",
" np.random.shuffle(indices) \n",
"\n",
" # Et on recommence à cumuler la loss\n",
" cum_loss_value = 0\n",
"\n",
" for step in range(steps_per_epoch):\n",
"\n",
" # Sélection des données du prochain batch\n",
" x_batch = x_train_lab[indices[step*batch_size: (step+1)*batch_size]]\n",
" y_batch = y_train_lab[indices[step*batch_size: (step+1)*batch_size]]\n",
"\n",
" # Etape nécessaire pour comparer y_batch à la sortie du réseau\n",
" y_batch = np.expand_dims(y_batch, 1)\n",
"\n",
" # Les opérations effectuées par le modèle dans ce bloc sont suivies et permettront\n",
" # la différentiation automatique.\n",
" with tf.GradientTape() as tape:\n",
"\n",
" # Application du réseau aux données d'entrée\n",
" y_pred = model(x_batch, training=True) # Logits for this minibatch\n",
"\n",
" # Calcul de la fonction de perte sur ce batch\n",
" loss_value = loss_fn(y_batch, y_pred)\n",
"\n",
" # Calcul des gradients par différentiation automatique\n",
" grads = tape.gradient(loss_value, model.trainable_weights)\n",
"\n",
" # Réalisation d'une itération de la descente de gradient (mise à jour des paramètres du réseau)\n",
" optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
"\n",
" # Mise à jour de la métrique\n",
" train_acc_metric.update_state(y_batch, y_pred)\n",
"\n",
" cum_loss_value = cum_loss_value + loss_value\n",
"\n",
" # Calcul de la précision à la fin de l'epoch\n",
" train_acc = train_acc_metric.result()\n",
"\n",
" # Calcul de la précision sur l'ensemble de test à la fin de l'epoch\n",
" test_logits = model(x_test, training=False)\n",
" test_acc_metric.update_state(np.expand_dims(y_test, 1), test_logits)\n",
" test_acc = test_acc_metric.result()\n",
"\n",
" print(\"Epoch %4d : Loss : %.4f, Acc : %.4f, Test Acc : %.4f\" % (epoch, float(cum_loss_value/steps_per_epoch), float(train_acc), float(test_acc)))\n",
"\n",
" # Remise à zéro des métriques pour la prochaine epoch\n",
" train_acc_metric.reset_states()\n",
" test_acc_metric.reset_states() "
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAABGzUlEQVR4nO2de5DlVXXvv/uc7j4zjC9oLIHIOLQoEeXKY+zQExjagRoCEukUJtFYGWAQaAMGkiq7RMFuGDOTkCrpG+EyPQqGyet6E28lkApOwqPFzGmFwdf4KBMwCVfLVyZXHS/DNDNn3z9WL3/rt8/ev8d5/s4561P1q/P4vfZv/36/tddee621jbUWiqIoSv9T6nYBFEVRlM6gAl9RFGVAUIGvKIoyIKjAVxRFGRBU4CuKogwIQ90uQIjjjz/erlu3rtvFUBRF6Smefvrp/7TWvtK3rrACf926ddi3b1+3i6EoitJTGGP+I7ROTTqKoigDggp8RVGUAUEFvqIoyoCgAl9RFGVAUIGvKIoyIKjA7zPm5rpdAkVRiooK/D7j9tu7XQJFUYqKCnxFUZQBQQV+HzA3BxhDCxB9V/OOoigSU9QJUNavX2810jY/xgAFvaWKonQAY8zT1tr1vnWq4SuKogwIKvD7jNnZbpdAUZSiogK/z1C7vaIoIVTgK4qiDAgq8BVFUQYEFfiKoigDggp8pedZWgJ27KBPRVHCFHbGK0VZWgIWF4HJSWBiIrzNhRcCy8vAyAjw6KPhbRVl0FGBr3SELMLb3T4kyOWxFhdpm6NH6XNxkbbznS9vGRSl31CBr7SdNC18bq7enTRJkMtjzc/TJ/+enPSfD+hcT8B3PYpSBNSGr7Qdn/CW+DJ8Tk6SYC6XI0HuO9aBAyS8t22LhLjvfGllaCWasVQpKqrhK22HhbfUwtOYmCABvnt3+rEmJuLaeuh8ecugKP2GavhK22HhLbXwrBk+H3gA+PjHyRyztOQ/lu988/O0z/x81CCk7dcMmrFU6QU0W6bSdUIZPnfsAG67jcww5TIJ61tuST9etz13NGOp0k00W6bSk4Ts+Glktder/74yaKgNX8lFO1wbQxk+XTv+/v3Zzp1lzKCdvQDNWKoUFRX4SmYaEZJZGgjXzi33AciOf/gwUKsBpRJQqSSfm234n/40cMUV/u1Cbp+tQO32SlFRga9kRgrJw4cjf/NWRsG6+1x5JX2v1Wh9rZYuoJeWgJtvpu0+9zngjDPq/f4vvli9dpTBQ234ys9JsmkvLQHPPUf29FKJBO8jj0TeMz7y+r5bS9r8Cy/QPi+8AHz/+8DwMJ0ToM80Ab24SA0SN0w+v/92e+0oShFRDV8BkJ7K4K1vpXVDQ8D69cC+fX5tW5pjsvrfv/gicN99wJ13koBnDxdrgT17gOOOA97yFtLKf/xjOs6ePWEhPToa7xGMjvqvN8nUlGesQiNrlZ7BWlvI5ZxzzrFK59i+3dpy2VqAPrdvj9ZNT9P/vJx+urUjI7Td6tXWVqu0XbVKv+X/1Sodi7dxOXjQ2vPOs/aYY+LncJdjjrH2/PNpe2vpv6RrKZVom1KJfs/O+o9bKtH1yfL5riOJpLIoSqcBsM8G5KqadBQA+Vwgv/EN8jW/9tp4TyA0EMpJzlzTz4svApdcAjz1FPD888nle/554MkngUsvpf2AsPlpcpIGdstl+pycJA2cxbykVgN27oybptzr2L07m/umunkqRUdNOgqAyKbtM2Ns2QLcfz8JP+bIEWDt2vSUBktLwMaNJGhdU9F99wFf/CLZ2bNw+DANwo6M0O8PfpA+t24FTj45Mqv40jK4nj8ubgPF1zE0RNd+9Gi8/HNz8Zw5MsJ21SodF1AKSkj17/aiJp3GmJ1tzz7VKpk+KhUygwwPW7uwEF+/fTv9J00427dH5hNj6BjWWlurWXvKKclmnNBy7LHRdzY/uWYVaZYZGaFys4kGoN/GRGYd13TD1zM9HZm6SiVrN2+uN/EAySYxRekkSDDptEQ4A7gfwA8BfC2w3gD4EwDPAPgqgLPTjqkCPz9sp85Lnn0WFqwdGooLySSbd7UaF9blMh1j715r16xpTOAPDUXH4vO51yAFMAt233LZZf4xhtlZa7dutXZqio6T1DgA+e3+itIukgR+q0w6fwrgbgC7A+svAfC6leWXANy78qm0kHal5ZXmkAMHSFRKDx2g3na/Z4+/PEePAjfcQAvb4vNy5AjwS78ErFkDPPYYsGED/c9mldnZuJ99uUzrjhyh/w4dSs9145adk6Hxde/eHdXJ7GyySUxRikJLBL619gljzLqETS4HsHul9fm8MeYVxpgTrbXfa8X5B52QPZlD/H0ug0n7yO1dd80LLyQBWqvRZyj1MNu5+djDw5GAr9WAr3+9cYEPAJddRgL8scf8610BDETfuYHIg7V0DXzdn/xk1IDwBCtummZFKRwh1T/vAmAdwiadvwdwnvj9KID1nu2uA7APwL61a9e2sdPTP4TcDdkuz/Zln4lBbhNCmkbY1XF4mEwclUp03IUFsm/77PpA3BRUqVj7pjc1Zs5Juk73upLqLE9dSlOUa9d37fWNjKEoSitBu234tkUCXy5qw88PCyfpF+/aut3t5ScjfefZNs3C3hWCvJ3P/57/GxqK/puepkHUZgX93r3+a5HfZ2fT4wDS6nJqKl6nXD8he30jYyiK0kqSBH6n/PC/C+Bk8fvVK/8pDeJLOAaQ+WTjRvpk08XRo2S3vvVW/7Fkdkc24dx2G30CwLveFUWuSowhE8mttyZPKWht5PK4di391yyu6cSXofL22+PXktU/Xm63Z090PvazB8iMc+21lOtHUXqFTgn8BwFsMcS5AH5i1X7/cxoJ2HEHFRcXo4k3rAW2bweqVVpXLgOrVwMf+Yh/Zqbbb48akN2747lsdu8mf/lqlY4BUDDT9DTwxBMkCB97rD5oKxTI9cwz9FuWIQ/HHAO87W3+fd3rAqJryZLLh++D9N8/dCg63oYNwIc+RI3H/v3RbFzcwOpsV0rhCan+eRYAfwXgewBeBPAdANcAmAYwvbLeALgHwLMA9iPFnGMHyKST1Z3PNU0k+Z3L4yTZ8H3HcM0t5bK14+PWzsxEKRZ8piGf6cT3H28LWHvVVfnMOJUKpVdYXg7X0aZN/n3ZrBRC1l+lEqWO4P1dP/vNm+NunxxfoCYdpdugEzb8Vi+DIvCzBOywMAr5k/NAoU/AuoOIcpsk33XfYkyUe2br1uSyhHDHDQBrTzghXdi7uXRc5KCwbFTc8YhQ4+feh+npeNCY26AuLMQbRx7AVoGvdBsV+AUmi4bvaxQaESzuubZu9a/3DdC6gt/tRSSR5P0il+OPj/8eHqbgrLExa3fuTNbsh4fj+7p1lFbP7iCzr3xbt8YbjOnpqBHm+3LBBdnuQyMDyYqShSSBr7l0ukyWgJ3JScrpUqvRZ6OTdbhJwU49NVyWH/8YeOgh4JvfrD+Ota2fJQqgiUre+U7gwQeBk04C3vhGYHwcOPfcZHv/4mLcp5/raHY2Chp77jl/c
Jibf8e9DzwuYgyNZUi2bCE7vow/4Pw+Ibo9wboy2KjALwBZAnasjT737wc2bQJ27aLI11BDwcKMP0P56X2531koDQ0BZ59N637603jAEe+fNoerzBcvBWi1SgOhpRIN7r7zncB119GSBy4HH+fuu/3XwYOpHDC2YUN8YDVr4JSsr9NOA37jN7JH1+adWrEdcwgrA0xI9e/2MigmHUmoq+8GP0lbtS+3C8OmiKRz+HLjbN8emSqMiY8r+AKssuAGSMnf8vyumSmpXtxrdcvl1hvnxEmzt/P5QuMUQ0PxRGwhk5VrAuJjZ825o/l5lEaA2vCLj7Sfu5koFxbigssdvA0N9voEvrWRsF1YiNvrebKQmZn48Wdm4mVkTxY5cUjWwVp3u02b4uMTvgZKZr3kcyZ55HCj4drlk8Ym5MB3yNvJ2nq7vVte/p2WUC6LDV8zcCqNkCTwdQKULiN9vw8fJjv9iy9ScrGlJVre974oWKlWi0wiQP0cr66fPRD3C19aIr/7XbuAG2+MB1RxI
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAABuOUlEQVR4nO2dd5wU5f3H38/MbLveC0cvIoJdMZZYYokgir1gN2qCwVhITKwxiYn+TCBRsASNBUWK2BDEDjZUQEQUEaTDcb3u3taZeX5/zO7eHRyK3nGH3PN+ve51u7OzM8/M7n6e5/k+3yKklCgUCoVi70fr6gYoFAqFonNQgq9QKBTdBCX4CoVC0U1Qgq9QKBTdBCX4CoVC0U0wuroB38WCbyqVC5FCoVD8AE7Yt0Ds7LU9WvDXVga6ugkKhULxk+KEfQt2+poy6SgUCkU3QQm+QqFQdBOU4CsUCkU3YY+24beFQJLpsvHqIMRO1ya6DCklYQsaYhqSPa99CoWi+/KTE/xMl01WqhdbGLAHCj5S4pUmNIWpj+ld3RqFQqFI8pMz6Xh19lyxBxACWxh4ldYrFIo9jHYLvhCilxBigRDiayHESiHEDW3sI4QQDwoh1gohVgghDmnH+fZcsU8gxB5pblIoFN2bjjDpmMB4KeUyIUQ68JkQ4i0p5dct9hkBDIr/HQE8Ev+vUCgUik6i3SN8KWWZlHJZ/LEfWAWUbLfbaGCqdPgEyBJCFLf33F3F0g/f5VenH8OVI49k5uOTuro5CoVCsUt0qA1fCNEXOBj4dLuXSoAtLZ5vZcdOIXGMa4UQS4UQS9+fM70jm9chWJbFQ3+/jXsensaUV95j4fyX2bRudVc3S6FQKL6XDvPSEUKkAS8AN0opG3/scaSUU4ApAI+9v75duXRuuOwsGhp3bEpmRgYPTH3pRx1z9ZefU9y7L8W9+gBw3IjRfLzgDfoMGNyepioUCsVup0MEXwjhwhH7aVLKF9vYpRTo1eJ5z/i23UpDYyODrp28w/Zvp4z70cesqSwnv6h5cpJXWMzqFZ//6OMpFApFZ9ERXjoC+B+wSko5cSe7zQEui3vr/AxokFKWtffcCoVCodh1OmKEfzRwKfClEGJ5fNttQG8AKeWjwGvASGAtEASu7IDzdgm5BUVUlTdPTqorysgtLOrCFikUCsWu0W7Bl1J+CN+dQ0BKKYHftvdcewKDhx3Etk0bKN+6mdzCIt6b/wp//L+Hu7pZCoVC8b385FIrdDW6YXDdbf/g9t9chG1ZnHLWhfQdqBZsFQrFns9eLfiZGRltLtBmZmS067jDjz2R4cee2K5jKBQKRWezVwv+j3W9VCgUir2Rn1zyNIVCoVD8OJTgKxQKRTdBCb5CoVB0E5TgKxQKRTdBCb5CoVB0E5Tg/wgm3nkTFxw3jF+fdXxXN0WhUCh2GSX4P4KTR5/PPY8819XNUCgUih9EtxD8hroa/v67S2isr+2Q4+1/2JGkZ2Z3yLEUCoWis+gWgv/uy9Owt33BOy8929VNUSgUii5jrxf8hroaPn9rNv85uyefvzW7w0b5CoVC8VNjrxf8d1+exukDYVChj9MHokb5CoWi27JXC35idD/m0EwAxhyaqUb5CoWi27JXC35idJ+b5gKc/x0xyr/3lrHcdMkotm5cxyUnHsLrLyqPHYVCseezV2fL/HLxB3xQFmb6iq2ttmdVfcBZV/7uRx/31vsfaW/TFAqFotPZqwX/rkee7+omKBQKxR7DXm3SUSgUCkUzHSL4QognhBCVQoivdvL68UKIBiHE8vjfXT/2XFJKkPLHN7YzkNJpp0KhUOxBdJRJ5ylgMjD1O/b5QEo5qr0nClvglSY2BojvrJ3eNUiJJk3CVlc3RKFQKFrTIYIvpXxfCNG3I471fTTENGgK49VB7IGCL6UkbMXbqVAoFHsQnbloe6QQ4gtgG/B7KeXKtnYSQlwLXAtwyfh7OPaMi1q9LhHUx3SI7e7mKhQKxd5FZwn+MqCPlDIghBgJvAwMamtHKeUUYArAY++vV4ZwhUKh6CA6xe4gpWyUUgbij18DXEKIvM44t0KhUCgcOkXwhRBFIm5wF0IMj5+3pjPOrVAoFAqHDjHpCCGmA8cDeUKIrcCfAReAlPJR4FxgrBDCBELAhVL5LSoUCkWn0lFeOhd9z+uTcdw2FQqFQtFFKN9BhUKh6CYowVcoFIpughJ8hUKh6CYowVcoFIpughJ8hUKh6CYowVcoFIpughJ8hUKh6CYowVcoFIpuwl5d4lChuHfcRQQC/h22p6Wlc+vk6V3QIoWi61CCr9hj2RWx3n6f+qpypNDQhEZmbh511ZW40nLQvSkAWOEgADVlq7hu1HCA5L7bH1uh2NtQgq/YYwkE/PS/etIO29c/fv1O91kxeSw9rvgPkerNlPQdROnGb/Hk9WbbUzcC0OOK/wCw+fHf0vvqhwCS+25/bIVib0MJvuInz8rHxydH7lF/LaHKzUjbpHzL+i5umUKxZ6EEX/GTxwoHkyP3LU/egCuvF9KMYjVWtNovGqgjVLk5+TwWjTjvN81Oa6tC0ZUowVd0CR29mGpGI9hSIqVE2jZSSiwzlnwNAClx5fUCQGg6wnAnt5du/BaAuupKbr9iVLvaolDsqSjBV3QJgYCflF/ehGVZrbZvmXEH945zsm3XVVeyYvLYVq/r3hR88cf3jruIuupKhDcdu2V5hURte81ACEG0tpRYoBbbjFD21A0AWE11lD19Y3x/nV5XTATAlZaTXBNQ9nzF3oYSfMVuo61RfH1VOUIzsKVN/Qv/QOjOV1Bzp1By8b240nKS7ym5+N4dOoTyGXeQ33cA4HQa3ryehKq2YjY45huhG5Q9+TtAYIca0HwZyLCfXn0HsGndanJH3ugcqEUHUTHrruSibsKbR6HYG1GCr9httOVl8/kD15B72s2O4OrNX7/KmXew5YnfYfqriQWc7Z642Ou6TlGv/gBE8gpamVmGXj2BFZPH4soqQkoovnQC0erNuPN6U/b0jRScPp7I2w9y6+Tp/GbkobgL+rXRUknUH6+46a9h6b3nA2DHIow7/WdJl01QZh7FTxsl+IpORUqJ0AykbeLK6YnQnGBvPTWHossnUvrwFbjScpxtGYVIIFpb2srGfu+4i1qJru5NoWzqzYAzwjcbqxGahrRtKl7+P2TY79jlW4zqYzVbkdKOPxPoqTkgQLhTKLroH0jbZutDl5F7zp3ohiv5vi0z7uD2K0aRlpYOoIK6FD8plOArughHfBOljaVtOSJs20T9tQCUTb8VGYsCMmn6kRK2bnDEv76qnKX3XRgXctnq2IUX/gNXbk+kGcVurKCk7yBq/n4udiwSP46NK6cnAHpqFkWX/xshRLNdvwWV8x5AxsIA2JZFbXWVs3YgbQ659fkd9le2f8WeSkcVMX8CGAVUSimHtfG6AB4ARgJB4Aop5bKOOLfip0PCL17aJkiQVqzFqxIjsxA9NZv8M26het4EsCyKLv1X/HVnJdas20rFzLv47WmHYUmJ7ktzXnV5yB3hLMhWT
L8N24wQrdwAgNvjSR5CCJH0zknMLpzTxz18pMSOdzIARnYJWCbFVzqmqfJnf4+MdxpWsD65qKx7Uxh69YQOu1cKxe6go0b4T+EUKZ+6k9dHAIPif0cAj8T/K7oRlmUhNL2VaaUlZkMlIJ0OQTOwGisxa0tb2frBEWrLdswxVlO90xdYJjVzHU8bbBMhNHoPHLJ9AyibejNCN7ACtehx05EwPMnjCiHQXG7sWLS5Q4h3FAAyFqHo0gkIIYjVlpJS5KwtJBZ9FYo9mQ4RfCnl+0KIvt+xy2hgqnTm758IIbKEEMVSyrKOOL+i89je86ahphpb2ghpk5VflNyelpZOWlo66x+/PpnfxrYtsEwqZt2F0DS01GxHSCVo3jRammUKL/wH5dP+AEIkbf12LILZUIGWmk1e3NvGyO4BQPkz4ylOBF9NuqTNthtuL6a/BmG4kbadnBG03QHJVg/lTjopheKnRGfZ8EuALS2eb41v20HwhRDXAtcCXDL+Ho4946JOaWB35YcGQG3vedMyV
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from mlxtend.plotting import plot_decision_regions\n",
"\n",
"# Affichage des données\n",
"plt.plot(x_train_unlab[y_train_unlab==0,0], x_train_unlab[y_train_unlab==0,1], 'b.')\n",
"plt.plot(x_train_unlab[y_train_unlab==1,0], x_train_unlab[y_train_unlab==1,1], 'r.')\n",
"\n",
"plt.plot(x_test[y_test==0,0], x_test[y_test==0,1], 'b+')\n",
"plt.plot(x_test[y_test==1,0], x_test[y_test==1,1], 'r+')\n",
"\n",
"plt.plot(x_train_lab[y_train_lab==0,0], x_train_lab[y_train_lab==0,1], 'b.', markersize=30)\n",
"plt.plot(x_train_lab[y_train_lab==1,0], x_train_lab[y_train_lab==1,1], 'r.', markersize=30)\n",
"\n",
"plt.show()\n",
"\n",
"#Affichage de la frontière de décision\n",
"plot_decision_regions(x_train_unlab, y_train_unlab, clf=model, legend=2)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YPiuBS36V8EG"
},
"source": [
"# Minimisation de l'entropie"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UlBFMsFLMtEp"
},
"source": [
"Nous allons dans cette partie mettre en place le mécanisme de minimisation d'entropie, conjointement à la minimisation de l'entropie croisée.\n",
"\n",
"Pour commencer, implémentez la fonction de coût qui calcule l'entropie $H$ des prédictions du réseau $\\hat{y}$ :\n",
"$$ H(\\hat{y}) = - ∑_{i=1}^N \\hat{y}_i log(\\hat{y}_i) $$\n",
"\n",
"Pour les exemples simples des datasets des 2 clusters et des 2 lunes, il faut implémenter une entropie binaire ! (plus tard, sur MNIST, il faudra implémenter une version multi-classe de l'entropie)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"id": "1gEt2x_sXFin"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# Calcul de l'entropie de y_pred\n",
"def binary_entropy_loss(y_pred):\n",
" return -tf.reduce_sum(y_pred * tf.math.log(y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-L1Li1YtNN87"
},
"source": [
"**Travail à faire** : Reprenez maintenant la boucle d'apprentissage supervisé et introduisez la minimisation d'entropie pour régulariser l'apprentissage.\n",
"\n",
"La difficulté principale va être l'introduction des données non labellisées dans la boucle. Ainsi, un batch devra maintenant être composé de données labellisées et non labellisées. Je vous suggère de conserver le même nombre de données labellisées par batch que précédemment (i.e. 16) et de prendre un plus grand nombre de données non labellisées, par exemple 90.\n",
"\n",
"N'oubliez pas également d'introduire un hyperparamètre λ pour contrôler l'équilibre entre perte supervisée et non supervisée. Utilisez un λ constant dans un premier temps, et testez ensuite des variantes qui consisteraient à augmenter progressivement sa valeur au fil des epochs. \n",
"\n",
"La fonction objectif à minimiser aura donc la forme : \n",
"$$ J = \\sum_{(x,y) \\in \\mathcal{L}} CE(y, \\hat{y}) + \\lambda \\sum_{x \\in \\mathcal{U}} H(\\hat{y})\t$$"
]
},
{
"cell_type": "code",
"execution_count": 188,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0 : Loss : 1.5490, Acc : 0.7000, Test Acc : 0.7800\n",
"Epoch 1 : Loss : 1.5385, Acc : 0.7000, Test Acc : 0.8000\n",
"Epoch 2 : Loss : 1.5271, Acc : 0.7000, Test Acc : 0.8600\n",
"Epoch 3 : Loss : 1.5151, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 4 : Loss : 1.5028, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 5 : Loss : 1.4902, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 6 : Loss : 1.4767, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 7 : Loss : 1.4626, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 8 : Loss : 1.4474, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 9 : Loss : 1.4315, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 10 : Loss : 1.4145, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 11 : Loss : 1.3964, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 12 : Loss : 1.3772, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 13 : Loss : 1.3567, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 14 : Loss : 1.3348, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 15 : Loss : 1.3121, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 16 : Loss : 1.2887, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 17 : Loss : 1.2647, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 18 : Loss : 1.2405, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 19 : Loss : 1.2163, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 20 : Loss : 1.1924, Acc : 0.7000, Test Acc : 0.8600\n",
"Epoch 21 : Loss : 1.1691, Acc : 0.7000, Test Acc : 0.8600\n",
"Epoch 22 : Loss : 1.1466, Acc : 0.7000, Test Acc : 0.8600\n",
"Epoch 23 : Loss : 1.1253, Acc : 0.7000, Test Acc : 0.8600\n",
"Epoch 24 : Loss : 1.1053, Acc : 0.7000, Test Acc : 0.8600\n",
"Epoch 25 : Loss : 1.0868, Acc : 0.7000, Test Acc : 0.8600\n",
"Epoch 26 : Loss : 1.0698, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 27 : Loss : 1.0543, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 28 : Loss : 1.0402, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 29 : Loss : 1.0273, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 30 : Loss : 1.0154, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 31 : Loss : 1.0045, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 32 : Loss : 0.9943, Acc : 0.7000, Test Acc : 0.8800\n",
"Epoch 33 : Loss : 0.9845, Acc : 0.8000, Test Acc : 0.8800\n",
"Epoch 34 : Loss : 0.9751, Acc : 0.8000, Test Acc : 0.8800\n",
"Epoch 35 : Loss : 0.9660, Acc : 0.8000, Test Acc : 0.8800\n",
"Epoch 36 : Loss : 0.9570, Acc : 0.8000, Test Acc : 0.8800\n",
"Epoch 37 : Loss : 0.9477, Acc : 0.8000, Test Acc : 0.8800\n",
"Epoch 38 : Loss : 0.9383, Acc : 0.8000, Test Acc : 0.9000\n",
"Epoch 39 : Loss : 0.9286, Acc : 0.8000, Test Acc : 0.9200\n",
"Epoch 40 : Loss : 0.9187, Acc : 0.8000, Test Acc : 0.9200\n",
"Epoch 41 : Loss : 0.9087, Acc : 0.8000, Test Acc : 0.9200\n",
"Epoch 42 : Loss : 0.8986, Acc : 0.8000, Test Acc : 0.9000\n",
"Epoch 43 : Loss : 0.8883, Acc : 0.8000, Test Acc : 0.9200\n",
"Epoch 44 : Loss : 0.8776, Acc : 0.8000, Test Acc : 0.9000\n",
"Epoch 45 : Loss : 0.8667, Acc : 0.8000, Test Acc : 0.9000\n",
"Epoch 46 : Loss : 0.8560, Acc : 0.8000, Test Acc : 0.9000\n",
"Epoch 47 : Loss : 0.8452, Acc : 0.8000, Test Acc : 0.9000\n",
"Epoch 48 : Loss : 0.8344, Acc : 0.8000, Test Acc : 0.9000\n",
"Epoch 49 : Loss : 0.8235, Acc : 0.8000, Test Acc : 0.8800\n",
"Epoch 50 : Loss : 0.8123, Acc : 0.8000, Test Acc : 0.8800\n",
"Epoch 51 : Loss : 0.8011, Acc : 0.8000, Test Acc : 0.8400\n",
"Epoch 52 : Loss : 0.7900, Acc : 0.8000, Test Acc : 0.8400\n",
"Epoch 53 : Loss : 0.7789, Acc : 0.8000, Test Acc : 0.8400\n",
"Epoch 54 : Loss : 0.7677, Acc : 0.8000, Test Acc : 0.8400\n",
"Epoch 55 : Loss : 0.7566, Acc : 0.8000, Test Acc : 0.8400\n",
"Epoch 56 : Loss : 0.7455, Acc : 0.8000, Test Acc : 0.8400\n",
"Epoch 57 : Loss : 0.7345, Acc : 0.8000, Test Acc : 0.8200\n",
"Epoch 58 : Loss : 0.7237, Acc : 0.8000, Test Acc : 0.8200\n",
"Epoch 59 : Loss : 0.7133, Acc : 0.8000, Test Acc : 0.8200\n",
"Epoch 60 : Loss : 0.7028, Acc : 0.8000, Test Acc : 0.8200\n",
"Epoch 61 : Loss : 0.6923, Acc : 0.8000, Test Acc : 0.8200\n",
"Epoch 62 : Loss : 0.6819, Acc : 0.8000, Test Acc : 0.8200\n",
"Epoch 63 : Loss : 0.6715, Acc : 0.8000, Test Acc : 0.8000\n",
"Epoch 64 : Loss : 0.6611, Acc : 0.8000, Test Acc : 0.8000\n",
"Epoch 65 : Loss : 0.6509, Acc : 0.8000, Test Acc : 0.8000\n",
"Epoch 66 : Loss : 0.6408, Acc : 0.8000, Test Acc : 0.8000\n",
"Epoch 67 : Loss : 0.6307, Acc : 0.8000, Test Acc : 0.8000\n",
"Epoch 68 : Loss : 0.6207, Acc : 0.8000, Test Acc : 0.8000\n",
"Epoch 69 : Loss : 0.6108, Acc : 0.8000, Test Acc : 0.8000\n",
"Epoch 70 : Loss : 0.6011, Acc : 0.8000, Test Acc : 0.8000\n",
"Epoch 71 : Loss : 0.5914, Acc : 0.8000, Test Acc : 0.8000\n",
"Epoch 72 : Loss : 0.5816, Acc : 0.8000, Test Acc : 0.8000\n",
"Epoch 73 : Loss : 0.5716, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 74 : Loss : 0.5616, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 75 : Loss : 0.5515, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 76 : Loss : 0.5415, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 77 : Loss : 0.5313, Acc : 0.9000, Test Acc : 0.7800\n",
"Epoch 78 : Loss : 0.5210, Acc : 0.9000, Test Acc : 0.7800\n",
"Epoch 79 : Loss : 0.5108, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 80 : Loss : 0.5006, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 81 : Loss : 0.4904, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 82 : Loss : 0.4802, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 83 : Loss : 0.4701, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 84 : Loss : 0.4600, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 85 : Loss : 0.4501, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 86 : Loss : 0.4402, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 87 : Loss : 0.4305, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 88 : Loss : 0.4210, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 89 : Loss : 0.4115, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 90 : Loss : 0.4020, Acc : 0.9000, Test Acc : 0.7800\n",
"Epoch 91 : Loss : 0.3926, Acc : 0.9000, Test Acc : 0.7800\n",
"Epoch 92 : Loss : 0.3833, Acc : 0.9000, Test Acc : 0.7800\n",
"Epoch 93 : Loss : 0.3741, Acc : 0.9000, Test Acc : 0.7800\n",
"Epoch 94 : Loss : 0.3650, Acc : 0.9000, Test Acc : 0.7800\n",
"Epoch 95 : Loss : 0.3560, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 96 : Loss : 0.3471, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 97 : Loss : 0.3381, Acc : 0.9000, Test Acc : 0.8000\n",
"Epoch 98 : Loss : 0.3292, Acc : 0.9000, Test Acc : 0.7800\n",
"Epoch 99 : Loss : 0.3204, Acc : 0.9000, Test Acc : 0.7800\n",
"Epoch 100 : Loss : 0.3116, Acc : 0.9000, Test Acc : 0.7600\n",
"Epoch 101 : Loss : 0.3028, Acc : 0.9000, Test Acc : 0.7600\n",
"Epoch 102 : Loss : 0.2940, Acc : 0.9000, Test Acc : 0.7600\n",
"Epoch 103 : Loss : 0.2853, Acc : 0.9000, Test Acc : 0.7600\n",
"Epoch 104 : Loss : 0.2766, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 105 : Loss : 0.2681, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 106 : Loss : 0.2597, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 107 : Loss : 0.2513, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 108 : Loss : 0.2429, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 109 : Loss : 0.2348, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 110 : Loss : 0.2279, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 111 : Loss : 0.2209, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 112 : Loss : 0.2134, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 113 : Loss : 0.2061, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 114 : Loss : 0.1991, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 115 : Loss : 0.1921, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 116 : Loss : 0.1852, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 117 : Loss : 0.1783, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 118 : Loss : 0.1716, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 119 : Loss : 0.1650, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 120 : Loss : 0.1587, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 121 : Loss : 0.1527, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 122 : Loss : 0.1468, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 123 : Loss : 0.1412, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 124 : Loss : 0.1360, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 125 : Loss : 0.1309, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 126 : Loss : 0.1258, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 127 : Loss : 0.1212, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 128 : Loss : 0.1171, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 129 : Loss : 0.1130, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 130 : Loss : 0.1089, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 131 : Loss : 0.1050, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 132 : Loss : 0.1012, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 133 : Loss : 0.0976, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 134 : Loss : 0.0942, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 135 : Loss : 0.0910, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 136 : Loss : 0.0878, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 137 : Loss : 0.0850, Acc : 1.0000, Test Acc : 0.7400\n",
"Epoch 138 : Loss : 0.0823, Acc : 1.0000, Test Acc : 0.7400\n",
"Epoch 139 : Loss : 0.0796, Acc : 1.0000, Test Acc : 0.7400\n",
"Epoch 140 : Loss : 0.0768, Acc : 1.0000, Test Acc : 0.7400\n",
"Epoch 141 : Loss : 0.0743, Acc : 1.0000, Test Acc : 0.7400\n",
"Epoch 142 : Loss : 0.0719, Acc : 1.0000, Test Acc : 0.7400\n",
"Epoch 143 : Loss : 0.0696, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 144 : Loss : 0.0675, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 145 : Loss : 0.0655, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 146 : Loss : 0.0635, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 147 : Loss : 0.0615, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 148 : Loss : 0.0597, Acc : 1.0000, Test Acc : 0.7600\n",
"Epoch 149 : Loss : 0.0580, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 150 : Loss : 0.0563, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 151 : Loss : 0.0547, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 152 : Loss : 0.0532, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 153 : Loss : 0.0518, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 154 : Loss : 0.0504, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 155 : Loss : 0.0491, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 156 : Loss : 0.0478, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 157 : Loss : 0.0465, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 158 : Loss : 0.0454, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 159 : Loss : 0.0443, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 160 : Loss : 0.0433, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 161 : Loss : 0.0423, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 162 : Loss : 0.0412, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 163 : Loss : 0.0402, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 164 : Loss : 0.0394, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 165 : Loss : 0.0385, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 166 : Loss : 0.0377, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 167 : Loss : 0.0368, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 168 : Loss : 0.0360, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 169 : Loss : 0.0353, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 170 : Loss : 0.0345, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 171 : Loss : 0.0338, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 172 : Loss : 0.0332, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 173 : Loss : 0.0325, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 174 : Loss : 0.0318, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 175 : Loss : 0.0312, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 176 : Loss : 0.0306, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 177 : Loss : 0.0301, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 178 : Loss : 0.0295, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 179 : Loss : 0.0290, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 180 : Loss : 0.0285, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 181 : Loss : 0.0280, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 182 : Loss : 0.0275, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 183 : Loss : 0.0270, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 184 : Loss : 0.0265, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 185 : Loss : 0.0261, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 186 : Loss : 0.0256, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 187 : Loss : 0.0252, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 188 : Loss : 0.0248, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 189 : Loss : 0.0244, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 190 : Loss : 0.0240, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 191 : Loss : 0.0236, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 192 : Loss : 0.0232, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 193 : Loss : 0.0229, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 194 : Loss : 0.0225, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 195 : Loss : 0.0222, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 196 : Loss : 0.0218, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 197 : Loss : 0.0215, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 198 : Loss : 0.0212, Acc : 1.0000, Test Acc : 0.7800\n",
"Epoch 199 : Loss : 0.0209, Acc : 1.0000, Test Acc : 0.7800\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"import math\n",
"\n",
"# Données et modèle du problème des 2 clusters\n",
"x_train_lab, y_train_lab, x_train_unlab, y_train_unlab, x_test, y_test = generate_2moons_dataset(num_lab=10, num_unlab=90, num_test=50)\n",
"model = create_model_2moons()\n",
"\n",
"# Hyperparamètres de l'apprentissage\n",
"lambdaa = 0.25\n",
"epochs = 200\n",
"batch_size = 16\n",
"if batch_size < x_train_lab.shape[0]:\n",
" steps_per_epoch = math.floor(x_train_lab.shape[0]/batch_size)\n",
"else:\n",
" steps_per_epoch = 1\n",
" batch_size = x_train_lab.shape[0]\n",
"\n",
"# Instanciation d'un optimiseur et d'une fonction de coût.\n",
"optimizer = keras.optimizers.Adam(learning_rate=1e-2)\n",
"loss_fn = keras.losses.BinaryCrossentropy()\n",
"\n",
"# Préparation des métriques pour le suivi de la performance du modèle.\n",
"train_acc_metric = keras.metrics.BinaryAccuracy()\n",
"test_acc_metric = keras.metrics.BinaryAccuracy()\n",
"\n",
"# Indices de l'ensemble labellisé\n",
"indices = np.arange(x_train_lab.shape[0])\n",
"indices_unlab = np.arange(x_train_unlab.shape[0])\n",
"\n",
"# Boucle sur les epochs\n",
"for epoch in range(epochs):\n",
"\n",
" # A chaque nouvelle epoch, on randomise les indices de l'ensemble labellisé\n",
" np.random.shuffle(indices)\n",
" np.random.shuffle(indices_unlab)\n",
"\n",
" # Et on recommence à cumuler la loss\n",
" cum_loss_value = 0\n",
"\n",
" for step in range(steps_per_epoch):\n",
"\n",
" # Sélection des données du prochain batch\n",
" x_batch = x_train_lab[indices[step*batch_size: (step+1)*batch_size]]\n",
" x_batch_unlab = x_train_unlab[indices_unlab[step*batch_size: (step+1)*batch_size]]\n",
" y_batch = y_train_lab[indices[step*batch_size: (step+1)*batch_size]]\n",
"\n",
" # Etape nécessaire pour comparer y_batch à la sortie du réseau\n",
" y_batch = np.expand_dims(y_batch, 1)\n",
"\n",
" # Les opérations effectuées par le modèle dans ce bloc sont suivies et permettront\n",
" # la différentiation automatique.\n",
" with tf.GradientTape() as tape:\n",
"\n",
" # Application du réseau aux données d'entrée\n",
" y_pred = model(x_batch, training=True) # Logits for this minibatch\n",
" y_pred_unlab = model(x_batch_unlab, training=True)\n",
"\n",
" # Calcul de la fonction de perte sur ce batch\n",
" loss_value = loss_fn(y_batch, y_pred) + lambdaa * binary_entropy_loss(y_pred)\n",
"\n",
" # Calcul des gradients par différentiation automatique\n",
" grads = tape.gradient(loss_value, model.trainable_weights)\n",
"\n",
" # Réalisation d'une itération de la descente de gradient (mise à jour des paramètres du réseau)\n",
" optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
"\n",
" # Mise à jour de la métrique\n",
" train_acc_metric.update_state(y_batch, y_pred)\n",
"\n",
" cum_loss_value = cum_loss_value + loss_value\n",
"\n",
" # Calcul de la précision à la fin de l'epoch\n",
" train_acc = train_acc_metric.result()\n",
"\n",
" # Calcul de la précision sur l'ensemble de test à la fin de l'epoch\n",
" test_logits = model(x_test, training=False)\n",
" test_acc_metric.update_state(np.expand_dims(y_test, 1), test_logits)\n",
" test_acc = test_acc_metric.result()\n",
"\n",
" print(\"Epoch %4d : Loss : %.4f, Acc : %.4f, Test Acc : %.4f\" % (epoch, float(cum_loss_value/steps_per_epoch), float(train_acc), float(test_acc)))\n",
"\n",
" # Remise à zéro des métriques pour la prochaine epoch\n",
" train_acc_metric.reset_states()\n",
" test_acc_metric.reset_states()"
]
},
{
"cell_type": "code",
"execution_count": 189,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD4CAYAAADhNOGaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAnQUlEQVR4nO3df5Ac9Xnn8fej1UqwghgQAhN+C3PBzqXKtmQFCYhlwBzoD/BVfAl25cCHEsmJfbYrqiMQjLWgMxBSJSspO1Z0SGdIOcYJuXKUWBQ2AsVV0RIkKH4YbAcBSYyChQw+DiHQrqTn/vh2W63R/Oie7pnpnv68qqZmpqd757uzs/3099fzNXdHRETqa9qgCyAiIoOlQCAiUnMKBCIiNadAICJScwoEIiI1N33QBejGiSee6GedddagiyEiUimPPfbYT919TuP2SgaCs846i+3btw+6GCIilWJm/9psu5qGRERqToFARKTmFAhERGpOgUBEpOYUCEREak6BQESk5hQIpDYmJuD228O9iBxSyXkEIllNTMAll8DkJMyYAZs3w8KFgy6VSDmoRiC1sGVLCAIHDoT7LVsGXSKR8lAgkNyKbnIZHy/m5yQtXhxqAiMj4X7x4uLfQ6SqrIorlM2fP9+VYqIcetHkYgZ5v5YTE+Gqf/HiQ+Vptk2kTszsMXef37hdfQSSy5YtsG8fHDwY7rdsGfxJtlVwim8icjg1DdVYEU0ws2eHIADhfvbs7stiFm5w6HE3ZczaH+AOW7fCmjWwalW437o1f61EpCpUIxgyWZo/brklfzB49VWYNi0EgWnTwvNujI8fKkvepqG4PyCuEbTqD5iagvXr4c474ZVXwvOpKRgdDbeTToLrr4elS8NzkWGlQDBEBjFEcvFimDmz80m3nxYuDL97u4C4Zw9ccQU8/jjs3Xv4a5OT4fbii7BiBfzlX8KmTXDMMf0ovUj/qWloiKRpEimyCQYOnXRXrWodeLL+7JUruytLY7luvLF5eaamQhDYtu3IINBo71549FFYsiQcJzKMNGpoiGStERQxOieNfr1PsnmpnbVrw5V+pyCQNDYGq1fD8uXdlk5k8FqNGlIgGDJZ+giGLRCkeR93OOec0OyT1dy5sGPHodqUSNW0CgRqGhoy7ZpEGhXRBNNK0U1QRZmYCB3D3di1S3mKZDgVEgjMbIOZvWJm32/xupnZn5rZDjN7yszen3jtWjN7LrpdW0R5JJ2iT8rJGcbj4+HqO75Cjx8X/Z5ZA86jj3bf1r9/f+hXEBk2RY0a+hrwZeCeFq9fAZwb3X4V+Crwq2Z2ArASmA848JiZbXT3nxVUrqFVtlmycf/Evn0wfTr87u+GOQXHHhted+9Nk0rWYadvvNF9IJicDMeLDJtCAoG7f8/Mzmqzy1XAPR46JB4xs+PM7BRgMfBdd38NwMy+C1wOfKOIcg2rMmbS3LwZ3n47nIgnJ+HLXw5zC+Ix+eecU44x+cceG95/cjL7sTNmHApsIsOkX30EpwI/Tjx/KdrWavsRzGyZmW03s+27d+/uWUGroGyZNPfsgfvuO/xq/MCBQ0FhaurQmPxLLgn792JtgDR9HgsWdB+Ipk+HD3ygu2NFyqwyE8rcfR2wDsKooQEXZ6DSzpzth3hM/g9/2HnfeEz+hRfCj34Ujs1ao2nXJJam/2HhQnjHO+DNN9O9X9LJJw++5iXSC/2qEewETk88Py3a1mq7tJFmEle/rF8fZufu25du/3374Omnw33WGk3cJHbzzeE+a21ifDykwfj3f892HIR5BNdfr6GjMpz6FQg2AtdEo4fOB15395eBB4DLzOx4MzseuCzaJh1kGSbaTBGjd9xDnp4sE7PgUJK6adOy1WjyNonFI5m+971sx82cCfPmwXXXZTtOpCqKGj76DWAC+CUze8nMlprZJ83sk9Eum4AXgB3A/wJ+DyDqJF4FbItut8Ydx9Jbt9yS/2fkGZM/cyb8zu9kq9EUtbjMRReF+zPPhKOOar/v2FjoV9i0SYnnZHgVNWroYx1ed+BTLV7bAGwoohzSX3nG5L/9Nvz5n8M735k+EKRJJpfWypVw002wYUOo1ezaFeYJxP0u06eHPoHrrw81AQUBGWaV6SyW/MbHD68JxO3dK1d211SUZ0y+Gdx6K3z+89mOK2pxmfj3Xb4cli0LtZtt28LvdOyxoRZw/vnqE5B6UCCokSJz/sPwjMk3g0WLwk2kjpRrSLqmMfkiw0GBoKaKyvl/0kndHasx+SLloUAwhNK097faJ8uMX7PQmTo2lqV0GpMvUjZaj2AIddv+300Oo6kpuPji0NGaZlLZzJmhSWnz5nKPxClbUj+RImg9Aumomwlbo6Nw//3h5N6pZlCVMfl5ZzCLVI0CQYW0a8655JL8C8F0O2HrmGPCFf7q1WEVr1mzwpW/WbifNStsX7067Ff2ReDLltRPpNfUNFQhzZp8mjXnLFrU/dDQvE0i7tUfk1/GNN8iRWjVNKR5BBVX9NVr3glbwzAmv8gZzCJVoECQEF/NPvro4VezCxcO7mq202zgZimpe7kWcV0UNYNZpArUNEQY+bJ+fcg588or4fnU1KHVtU46abCra8UB6oILQlqGxgClES7N6XMROVyrpqHaB4I9e8LCKo8/3j6d8thYSEW8aVP/OjsbA9Sbb4YTf1kCVJmpnV/kSBo+2kS8uta2bZ1z6seray1Z0j7RWlFLMO7ZE8bnr1gRlnmMV9SKl398880jl3+UQzTyRyS9WgeCblbXeuyxkLq4maLGn/ciQNVNUWsXZNWLtZhFeq22gaDb1bX27g3HNWtRK+oqtOgAVUd5lvPs9mSuiWhSVbUNBHlW19q1q/k/edar0GYnnF4EqDyqfIXbzXKeeU7mao6Sqqrt8NE8q2vt3x+abRrHymcZf96qM7OIAFXUGP46drg2O5knf+d2I5GaDeXtdIxIGdQ2EORZXWtyMhzfTNrx561OOL0IUN3qdFIcRq1O5tA5MDa7EKhjMJXqqW3TULy6VjeKWF2rVTNSrwJUN0tRDqrDdZDa9S2kafppbI5Sc5FUQSE1AjO7HPgTYAS4y93vaHj9S8CHoqdjwEnuflz02gHg6ei1f3P3K4soUyfx6lrdLLNYxOparZqRerX84y23ZA8GdU210KpW16620Eo3x4j0W+4JZWY2Avwz8GHgJWAb8DF3f7bF/v8deJ+7Xxc93+PumaZoFTGhzB3OOSeMxc9q7lzYsaM3aSe2boXLLjs0byCLWbPgO99p3jRUxBrF0l17v/oIpCx6OaFsAbDD3V9w90ngXuCqNvt/DPhGAe+bS1lX1ypy+cfx8fypqeVw3YxE6uYYkX4qIhCcCvw48fylaNsRzOxM4GzgocTmo8xsu5k9YmYfafUmZrYs2m/77t27Cyh2SM3w/veHnPlpzJwZ0kxcd10hb99UkQFqfDzUAuKaQPxYgUBEkvrdWXw1cJ+7H0hsOzOqqnwcWGNm5zQ70N3Xuft8d58/Z86cQgpT1tW1yhig5HAKpjJMiggEO4HTE89Pi7Y1c
zUNzULuvjO6fwHYAryvgDKllmV1rQcfhKeegjVrwqiSNWtCm37Rbe9FBajkZDClpi5WMjX4IFV5wp+URxGjhrYB55rZ2YQAcDXh6v4wZnYecDwwkdh2PLDX3feZ2YnABcCdBZQpk9FRWL4cli1rvrrWvHkhfcN55/UvTXUcoDZsCDOGd+0K8wTi0SfTp4c+geuvDzWBZkGgcfy6DJd16+BTn4KDB8NFi+YoSNfcPfcNWEIYOfQ8cFO07VbgysQ+48AdDcctIgwdfTK6X5rm/ebNm+f98sYb7hde6D42FrewN7+NjblfdFHYP7ZyZTFlOHjQ/R//0X3NGvdVq8L91q1heyu33eY+MhLKNjISnks+K1c2/9sX9XfOYutW99HRQ2WYNk1/Y+kM2O5Nzqm1X4+gnampkAp627Z0CeBmzgw1iM2bwxX6IIdsakZrbw16OO7tt8PnPx9qAxBqiN/7nv7G0p7WI+hClbOA5sm+KeW3eHG48Jg2LVx0fOUr+htL91QjaCHPhLNm4jWGZ
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAuoElEQVR4nO3dd3xUVfr48c+THkhCCx2kCEoHUREsCLhLM8KyKgKi2EARdHf1Z2Gx7ndXVnfFVUA6AqsiuoogoIhKtwDSOwgICTWQhDRCkjm/P5KJkzBJJpnJlNzn/XrxInPn5p4TynPPnPPc54gxBqWUUpVfkK87oJRSyjs04CullEVowFdKKYvQgK+UUhahAV8ppSwixNcdKMmqfWc0hUgppcqgZ6s6Utx7fh3wD51J83UXlFIqoPRsVafY93RKRymlLEIDvlJKWYTbAV9EGovIKhHZIyK7ReRPTs4REXlHRA6JyA4R6exuu0oppcrGE3P4OcDTxpgtIhIN/CwiK40xexzO6Qe0zP91AzA1//cyEwzVQm1EBINIsWsTPmOM4WIupGQHYfC//imlrMvtgG+MOQmczP86VUT2Ag0Bx4A/EJhv8gr3/Cgi1UWkfv73lkm1UBvVq0ZgkxDww4CPMUSYHEi/SHJ2sK97o5RSBTw6hy8iTYFrgJ+KvNUQOO7wOj7/WJlFBOO/wR5ABJuEEKGxXinlZzwW8EUkCvgU+LMx5oIb1xklIptFZPPaJQucve+/wd5OxC+nm5RS1uaRPHwRCSUv2H9gjPnMySkJQGOH143yj13GGDMDmAEwc+1hffBKKaU8xBNZOgLMBvYaYyYWc9oS4P78bJ2uQEp55u/9xeb13/HwHTfzYP9uLJw1ydfdUUopl3hihH8TcB+wU0S25R/7K3AFgDFmGrAc6A8cAjKABz3Qrk/k5uYy5R9/5bUZC4mtV58nh/Sja8/eNLnyal93TSmlSuSJLJ31UHL+YX52zhh32yqrP90/iJQLly8nVIuJ4e35i8p1zf07t1L/iqbUb9wEgFv7DeSHVSs04Cul/J5f19JxV8qFC7QcNfmy4wdnjC33Nc+dOUXter8lGMXWrc/+HVvLfT2llPIWLa2glFIWoQG/jGrVqcfZU78lGCWePkmtuvV82COllHKNBvwyurpdJ078eoRT8cfIzr7Emi8X07VHH193SymlSlWp5/ArQnBICI//9TXGPzYUW24uvQcNoWkLXbBVSvm/Sh3wq8XEOF2grRYT49Z1u3S/jS7db3PrGkop5W2VOuCXN/VSKaUqI53DV0opi9CAr5RSFqEBXymlLEIDvlJKWYQGfKWUsggN+OUw8cW/cM+t7Xh0UA9fd0UppVymAb8cfj9wMH+f+qGvu6GUUmViiYCfknSOfzw5nAvJ5z1yvfbXdSO6Wg2PXEsppbzFEgH/u88/wHZiO98uet/XXVFKKZ+p9AE/JekcW1f+j//8sRFbV/7PY6N8pZQKNJU+4H/3+Qfc0QJa1o3kjhboKF8pZVmVOuDbR/fDrq0GwLBrq+koXyllWZU64NtH97WiQoG83z0xyp/w7Gj+MjyO+KO/MPy2znz1mWbsKKX8X6Wulrlz4zrWnbzIgh3xhY5XP7uOQQ8+We7rjntjqrtdU0opr/NIwBeROUAccMYY087J+z2AxcCR/EOfGWP+5om2S/LS1E8qugmllAoYnhrhzwUmA/NLOGedMSbOQ+0ppZQqI4/M4Rtj1gJeWQk1xoAx3miq/IzJ66dSSvkRby7adhOR7SLypYi0Le4kERklIptFZPPaJQsue/9iLgSZHP8N+sYQZHK4mOvrjiilVGHeWrTdAjQxxqSJSH/gc6ClsxONMTOAGQAz1x6+LKqnZAdB+kUigkFEKrDL5WOM4WJufj+VUsqPeCXgG2MuOHy9XETeFZFYY0xima+FkJwdDNme7aNSSlV2XhmGikg9yR+Oi0iX/HbPeaNtpZRSeTyVlrkA6AHEikg88DIQCmCMmQbcBYwWkRwgExhidFVTKaW8yiMB3xgztJT3J5OXtqmUUspHdGVRKaUsQgO+UkpZhAZ8pZSyCA34SillERrwlVLKIjTgK6WURWjAV0opi9CAr5RSFqEBXymlLEIDvlJKWYQGfKWUsggN+EopZREa8JVSyiK8teOVUj41YexQ0tJSLzseFRXNuMmXb6WpVGWkAV95nS+Cb1paKs0fmXTZ8cOznqiQ9pTyRxrwldelpaVSpc9fyM0tvNP78Y9eYMLYoR4L+o43lqTEMyQcPQhAcHAw9Ro390gbSgUSDfjKJZ4elefm5hIee0WhY6FRNZ22UV6Oo/odk0cXtJeVeMxjbSgVSDTgK5folIhSgU+zdJRSyiJ0hK98JudSFo472efkZJOUeN6j8/h2wRFVODH3zwBkp50nK7YOkDclpZRVeCTgi8gcIA44Y4xp5+R9Ad4G+gMZwAPGmC2eaFtVjKJz9vZFT08seEZFRXP8oxeQiGgk+Ld/gsGRMUhIqEfn8e3aPvJmwdeHZz3BP+Yu9XgbSvk7T43w5wKTgfnFvN8PaJn/6wZgav7vyk8VnbO3L3p6YsFz3OQFTBg7lONHfyEkquZvb2RnEhxRBbgEuL9QHBUVfdkaQ8q5RIwth/EPxJXrmkoFMo8EfGPMWhFpWsIpA4H5xhgD/Cgi1UWkvjHmpCfaVxXPPiXiOB0C5Z8SGTd5AeMfiCtxIdiVheKy3hRKa1Opysxbc/gNgeMOr+Pzj10W8EVkFDAKYPjTf6f7gKFe6aAqmX1KxN+mQzR7SCnX+d2irTFmBjADYObaw6aU05VSSrnIWwE/AWjs8LpR/jFlYc7m2O3HlVKe562AvwQYKyIfkbdYm6Lz9/7NnWDs6ry6LpIq5V2eSstcAPQAYkUkHngZCAUwxkwDlpOXknmIvLTMBz3Rrqo49mBcNHinpaUy/oG4ErNaPDWvXhGfAIq7Zur5s5dl7tjP1xuTqiw8laVT4spqfnbOGE+0pbzLl4ui5U29tB8vyzU1e0dZgd8t2ipVFjr6Vsp1GvArId3sI4/+OShVmAb8Skhz0/N4+s9BbyAq0GnAtxhvBC3HefWUc4nYjA0AMbaChVF/CpITxg4ttEGKXXBwcKHXeiNVgU4DvsWUNWiVJ1PGMZAHwmJoWloqEbGNOLd0YsGxSxfOYoCQ4NCCm1RS4hl2z3q64Knj3bOeJvdiBtlp5wtl+PjTzUwpRxrwVYmsErgcq2lCXrG4WnFP0bBpy4JjCUcPFrop5F7MoMED/yEr8Vih8/zpZqaUI90ARSmlLEJH+JVQSdMwFVFr3l9p6QalCtOAXwmVNA3j7GnSsgikTBWP75oVHEx22vmCm0h22nmyEo9dtrjriicH3ECO7fLagCFBwjtLfnK7r0o5owHfYtwd9Xpj0dfbnPUxO+38ZYG8XuPmZMTWKSgPPf6BuEJz92WRYzM0GXv5fkG/Tr6/XNdTyhUa8C3G26Nwfxv1O+OsjxPGDiVtxVscLnLc8UYVCDczpRxpwFfKCVduVMXeKPILzDnyxykvZT0a8
JXyIH04S/kzTctUSimL0BG+KhOdt/aMkCBxukAbEiQ+6I2yCg34AcQfUiLL2056agpnE34l9dQR0k4dhaxUwslGgEzCqNepF806dEXEGgFPUy+VL2jADyCBND+cfSmL/T99Q8rBTVQLyqReTCjXNahGy+bVaXpjY6KqhBecm5tr4/PvN/DF9EVUb9Od1jf1I6gcue2BwJWbtj/c2FXlpAFfedS5Uwns/24BMVmnGXpTM27ocV2po/bg4CDuvKU1d97SmjXbj/DBnOcJbtiB9rcNJjQ8vMTv9TelTXm5ctMOpBu7v0hNPs9H/3qGoc/+m6hqNXzdHb+lAV+5zWazcXDzahK3f0u7OsG8Nag91aNbl+tat3Zsxq0dm7H7yEmmf/gCGVFNaNf3PqpGV/Nwr4vnzghbR+C+senLhYSc3snG5R/Ra+hoX3fHb2nAV+WWlpLEnm8/JvjcIQZd34jej3Xz2Bx83RrRJPyyj1dHXsUHSyZwylaDVn3uo2adBh65fkl0hB1YUpPPs3/tI
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from mlxtend.plotting import plot_decision_regions\n",
"\n",
"# Affichage des données\n",
"plt.plot(x_train_unlab[y_train_unlab==0,0], x_train_unlab[y_train_unlab==0,1], 'b.')\n",
"plt.plot(x_train_unlab[y_train_unlab==1,0], x_train_unlab[y_train_unlab==1,1], 'r.')\n",
"\n",
"plt.plot(x_test[y_test==0,0], x_test[y_test==0,1], 'b+')\n",
"plt.plot(x_test[y_test==1,0], x_test[y_test==1,1], 'r+')\n",
"\n",
"plt.plot(x_train_lab[y_train_lab==0,0], x_train_lab[y_train_lab==0,1], 'b.', markersize=30)\n",
"plt.plot(x_train_lab[y_train_lab==1,0], x_train_lab[y_train_lab==1,1], 'r.', markersize=30)\n",
"\n",
"plt.show()\n",
"\n",
"#Affichage de la frontière de décision\n",
"plot_decision_regions(x_train_unlab, y_train_unlab, clf=model, legend=2)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0 : Loss : 0.4852, Acc : 0.7000, Test Acc : 0.6111\n",
"Epoch 1 : Loss : 0.4787, Acc : 0.7000, Test Acc : 0.6444\n",
"Epoch 2 : Loss : 0.4723, Acc : 0.7000, Test Acc : 0.6444\n",
"Epoch 3 : Loss : 0.4660, Acc : 0.7000, Test Acc : 0.6444\n",
"Epoch 4 : Loss : 0.4598, Acc : 0.7000, Test Acc : 0.6556\n",
"Epoch 5 : Loss : 0.4537, Acc : 0.7000, Test Acc : 0.6667\n",
"Epoch 6 : Loss : 0.4478, Acc : 0.7000, Test Acc : 0.6667\n",
"Epoch 7 : Loss : 0.4420, Acc : 0.7000, Test Acc : 0.6667\n",
"Epoch 8 : Loss : 0.4363, Acc : 0.7000, Test Acc : 0.6667\n",
"Epoch 9 : Loss : 0.4307, Acc : 0.7000, Test Acc : 0.6667\n",
"Epoch 10 : Loss : 0.4252, Acc : 0.7000, Test Acc : 0.6667\n",
"Epoch 11 : Loss : 0.4199, Acc : 0.7000, Test Acc : 0.6778\n",
"Epoch 12 : Loss : 0.4147, Acc : 0.7000, Test Acc : 0.6889\n",
"Epoch 13 : Loss : 0.4095, Acc : 0.7000, Test Acc : 0.6889\n",
"Epoch 14 : Loss : 0.4045, Acc : 0.7000, Test Acc : 0.6889\n",
"Epoch 15 : Loss : 0.3996, Acc : 0.7000, Test Acc : 0.6889\n",
"Epoch 16 : Loss : 0.3948, Acc : 0.7000, Test Acc : 0.7111\n",
"Epoch 17 : Loss : 0.3900, Acc : 0.7000, Test Acc : 0.7222\n",
"Epoch 18 : Loss : 0.3854, Acc : 0.7000, Test Acc : 0.7222\n",
"Epoch 19 : Loss : 0.3808, Acc : 0.7000, Test Acc : 0.7333\n",
"Epoch 20 : Loss : 0.3764, Acc : 0.7000, Test Acc : 0.7333\n",
"Epoch 21 : Loss : 0.3720, Acc : 0.7000, Test Acc : 0.7333\n",
"Epoch 22 : Loss : 0.3677, Acc : 0.7000, Test Acc : 0.7444\n",
"Epoch 23 : Loss : 0.3635, Acc : 0.7000, Test Acc : 0.7444\n",
"Epoch 24 : Loss : 0.3593, Acc : 0.7000, Test Acc : 0.7444\n",
"Epoch 25 : Loss : 0.3552, Acc : 0.7000, Test Acc : 0.7444\n",
"Epoch 26 : Loss : 0.3512, Acc : 0.7000, Test Acc : 0.7556\n",
"Epoch 27 : Loss : 0.3473, Acc : 0.7000, Test Acc : 0.7556\n",
"Epoch 28 : Loss : 0.3434, Acc : 0.7000, Test Acc : 0.7556\n",
"Epoch 29 : Loss : 0.3396, Acc : 0.7000, Test Acc : 0.7667\n",
"Epoch 30 : Loss : 0.3358, Acc : 0.7000, Test Acc : 0.7667\n",
"Epoch 31 : Loss : 0.3321, Acc : 0.7000, Test Acc : 0.7667\n",
"Epoch 32 : Loss : 0.3285, Acc : 0.7000, Test Acc : 0.7667\n",
"Epoch 33 : Loss : 0.3249, Acc : 0.7000, Test Acc : 0.7667\n",
"Epoch 34 : Loss : 0.3214, Acc : 0.7000, Test Acc : 0.7667\n",
"Epoch 35 : Loss : 0.3179, Acc : 0.7000, Test Acc : 0.7778\n",
"Epoch 36 : Loss : 0.3144, Acc : 0.7000, Test Acc : 0.7778\n",
"Epoch 37 : Loss : 0.3110, Acc : 0.7000, Test Acc : 0.7778\n",
"Epoch 38 : Loss : 0.3077, Acc : 0.7000, Test Acc : 0.7778\n",
"Epoch 39 : Loss : 0.3044, Acc : 0.7000, Test Acc : 0.7778\n",
"Epoch 40 : Loss : 0.3012, Acc : 0.7000, Test Acc : 0.7778\n",
"Epoch 41 : Loss : 0.2980, Acc : 0.8000, Test Acc : 0.7778\n",
"Epoch 42 : Loss : 0.2948, Acc : 0.8000, Test Acc : 0.7778\n",
"Epoch 43 : Loss : 0.2917, Acc : 0.8000, Test Acc : 0.7889\n",
"Epoch 44 : Loss : 0.2886, Acc : 0.8000, Test Acc : 0.7889\n",
"Epoch 45 : Loss : 0.2856, Acc : 0.8000, Test Acc : 0.7889\n",
"Epoch 46 : Loss : 0.2826, Acc : 0.8000, Test Acc : 0.7889\n",
"Epoch 47 : Loss : 0.2796, Acc : 0.8000, Test Acc : 0.7889\n",
"Epoch 48 : Loss : 0.2767, Acc : 0.8000, Test Acc : 0.7889\n",
"Epoch 49 : Loss : 0.2738, Acc : 0.8000, Test Acc : 0.7889\n",
"Epoch 50 : Loss : 0.2710, Acc : 0.8000, Test Acc : 0.8000\n",
"Epoch 51 : Loss : 0.6722, Acc : 0.8000, Test Acc : 0.8111\n",
"Epoch 52 : Loss : 0.6680, Acc : 0.8000, Test Acc : 0.8111\n",
"Epoch 53 : Loss : 0.6631, Acc : 0.8000, Test Acc : 0.8111\n",
"Epoch 54 : Loss : 0.6577, Acc : 0.8000, Test Acc : 0.8111\n",
"Epoch 55 : Loss : 0.6520, Acc : 0.8000, Test Acc : 0.8222\n",
"Epoch 56 : Loss : 0.6462, Acc : 0.8000, Test Acc : 0.8222\n",
"Epoch 57 : Loss : 0.6403, Acc : 0.8000, Test Acc : 0.8222\n",
"Epoch 58 : Loss : 0.6344, Acc : 0.8000, Test Acc : 0.8222\n",
"Epoch 59 : Loss : 0.6285, Acc : 0.8000, Test Acc : 0.8222\n",
"Epoch 60 : Loss : 0.6227, Acc : 0.8000, Test Acc : 0.8222\n",
"Epoch 61 : Loss : 0.6171, Acc : 0.8000, Test Acc : 0.8222\n",
"Epoch 62 : Loss : 0.6116, Acc : 0.8000, Test Acc : 0.8222\n",
"Epoch 63 : Loss : 0.6062, Acc : 0.8000, Test Acc : 0.8222\n",
"Epoch 64 : Loss : 0.6011, Acc : 0.8000, Test Acc : 0.8222\n",
"Epoch 65 : Loss : 0.5961, Acc : 0.8000, Test Acc : 0.8222\n",
"Epoch 66 : Loss : 0.5913, Acc : 0.8000, Test Acc : 0.8222\n",
"Epoch 67 : Loss : 0.5866, Acc : 0.8000, Test Acc : 0.8333\n",
"Epoch 68 : Loss : 0.5822, Acc : 0.8000, Test Acc : 0.8333\n",
"Epoch 69 : Loss : 0.5779, Acc : 0.8000, Test Acc : 0.8333\n",
"Epoch 70 : Loss : 0.5738, Acc : 0.8000, Test Acc : 0.8333\n",
"Epoch 71 : Loss : 0.5698, Acc : 0.8000, Test Acc : 0.8333\n",
"Epoch 72 : Loss : 0.5660, Acc : 0.8000, Test Acc : 0.8333\n",
"Epoch 73 : Loss : 0.5624, Acc : 0.8000, Test Acc : 0.8333\n",
"Epoch 74 : Loss : 0.5589, Acc : 0.8000, Test Acc : 0.8333\n",
"Epoch 75 : Loss : 0.5555, Acc : 0.8000, Test Acc : 0.8333\n",
"Epoch 76 : Loss : 0.5523, Acc : 0.8000, Test Acc : 0.8333\n",
"Epoch 77 : Loss : 0.5491, Acc : 0.8000, Test Acc : 0.8333\n",
"Epoch 78 : Loss : 0.5461, Acc : 0.8000, Test Acc : 0.8333\n",
"Epoch 79 : Loss : 0.5432, Acc : 0.8000, Test Acc : 0.8444\n",
"Epoch 80 : Loss : 0.5404, Acc : 0.8000, Test Acc : 0.8444\n",
"Epoch 81 : Loss : 0.5377, Acc : 0.8000, Test Acc : 0.8444\n",
"Epoch 82 : Loss : 0.5350, Acc : 0.8000, Test Acc : 0.8556\n",
"Epoch 83 : Loss : 0.5324, Acc : 0.8000, Test Acc : 0.8556\n",
"Epoch 84 : Loss : 0.5299, Acc : 0.9000, Test Acc : 0.8444\n",
"Epoch 85 : Loss : 0.5275, Acc : 0.9000, Test Acc : 0.8444\n",
"Epoch 86 : Loss : 0.5252, Acc : 0.9000, Test Acc : 0.8444\n",
"Epoch 87 : Loss : 0.5228, Acc : 0.9000, Test Acc : 0.8444\n",
"Epoch 88 : Loss : 0.5206, Acc : 0.9000, Test Acc : 0.8556\n",
"Epoch 89 : Loss : 0.5184, Acc : 0.9000, Test Acc : 0.8556\n",
"Epoch 90 : Loss : 0.5162, Acc : 0.9000, Test Acc : 0.8667\n",
"Epoch 91 : Loss : 0.5141, Acc : 0.9000, Test Acc : 0.8667\n",
"Epoch 92 : Loss : 0.5120, Acc : 0.9000, Test Acc : 0.8667\n",
"Epoch 93 : Loss : 0.5100, Acc : 0.9000, Test Acc : 0.8667\n",
"Epoch 94 : Loss : 0.5080, Acc : 0.9000, Test Acc : 0.8667\n",
"Epoch 95 : Loss : 0.5060, Acc : 0.9000, Test Acc : 0.8667\n",
"Epoch 96 : Loss : 0.5041, Acc : 0.9000, Test Acc : 0.8667\n",
"Epoch 97 : Loss : 0.5022, Acc : 0.9000, Test Acc : 0.8667\n",
"Epoch 98 : Loss : 0.5003, Acc : 0.9000, Test Acc : 0.8667\n",
"Epoch 99 : Loss : 0.4984, Acc : 0.9000, Test Acc : 0.8667\n",
"Epoch 100 : Loss : 0.4965, Acc : 0.9000, Test Acc : 0.8667\n",
"Epoch 101 : Loss : 0.4947, Acc : 0.9000, Test Acc : 0.8778\n",
"Epoch 102 : Loss : 0.4929, Acc : 0.9000, Test Acc : 0.8778\n",
"Epoch 103 : Loss : 0.4911, Acc : 1.0000, Test Acc : 0.8778\n",
"Epoch 104 : Loss : 0.4893, Acc : 1.0000, Test Acc : 0.8778\n",
"Epoch 105 : Loss : 0.4875, Acc : 1.0000, Test Acc : 0.8778\n",
"Epoch 106 : Loss : 0.4858, Acc : 1.0000, Test Acc : 0.8778\n",
"Epoch 107 : Loss : 0.4840, Acc : 1.0000, Test Acc : 0.8778\n",
"Epoch 108 : Loss : 0.4822, Acc : 1.0000, Test Acc : 0.8778\n",
"Epoch 109 : Loss : 0.4805, Acc : 1.0000, Test Acc : 0.8778\n",
"Epoch 110 : Loss : 0.4788, Acc : 1.0000, Test Acc : 0.8778\n",
"Epoch 111 : Loss : 0.4770, Acc : 1.0000, Test Acc : 0.8889\n",
"Epoch 112 : Loss : 0.4753, Acc : 1.0000, Test Acc : 0.8889\n",
"Epoch 113 : Loss : 0.4736, Acc : 1.0000, Test Acc : 0.8889\n",
"Epoch 114 : Loss : 0.4719, Acc : 1.0000, Test Acc : 0.8889\n",
"Epoch 115 : Loss : 0.4702, Acc : 1.0000, Test Acc : 0.8889\n",
"Epoch 116 : Loss : 0.4685, Acc : 1.0000, Test Acc : 0.9000\n",
"Epoch 117 : Loss : 0.4668, Acc : 1.0000, Test Acc : 0.9000\n",
"Epoch 118 : Loss : 0.4650, Acc : 1.0000, Test Acc : 0.9000\n",
"Epoch 119 : Loss : 0.4633, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 120 : Loss : 0.4616, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 121 : Loss : 0.4599, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 122 : Loss : 0.4582, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 123 : Loss : 0.4565, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 124 : Loss : 0.4548, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 125 : Loss : 0.4531, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 126 : Loss : 0.4513, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 127 : Loss : 0.4496, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 128 : Loss : 0.4479, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 129 : Loss : 0.4461, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 130 : Loss : 0.4444, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 131 : Loss : 0.4427, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 132 : Loss : 0.4409, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 133 : Loss : 0.4392, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 134 : Loss : 0.4374, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 135 : Loss : 0.4356, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 136 : Loss : 0.4339, Acc : 1.0000, Test Acc : 0.9111\n",
"Epoch 137 : Loss : 0.4321, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 138 : Loss : 0.4303, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 139 : Loss : 0.4285, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 140 : Loss : 0.4267, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 141 : Loss : 0.4249, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 142 : Loss : 0.4231, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 143 : Loss : 0.4213, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 144 : Loss : 0.4194, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 145 : Loss : 0.4176, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 146 : Loss : 0.4158, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 147 : Loss : 0.4139, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 148 : Loss : 0.4120, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 149 : Loss : 0.4102, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 150 : Loss : 0.4083, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 151 : Loss : 0.4064, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 152 : Loss : 0.4045, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 153 : Loss : 0.4027, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 154 : Loss : 0.4008, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 155 : Loss : 0.3988, Acc : 1.0000, Test Acc : 0.9222\n",
"Epoch 156 : Loss : 0.3969, Acc : 1.0000, Test Acc : 0.9333\n",
"Epoch 157 : Loss : 0.3950, Acc : 1.0000, Test Acc : 0.9333\n",
"Epoch 158 : Loss : 0.3931, Acc : 1.0000, Test Acc : 0.9444\n",
"Epoch 159 : Loss : 0.3912, Acc : 1.0000, Test Acc : 0.9556\n",
"Epoch 160 : Loss : 0.3892, Acc : 1.0000, Test Acc : 0.9556\n",
"Epoch 161 : Loss : 0.3873, Acc : 1.0000, Test Acc : 0.9556\n",
"Epoch 162 : Loss : 0.3853, Acc : 1.0000, Test Acc : 0.9556\n",
"Epoch 163 : Loss : 0.3834, Acc : 1.0000, Test Acc : 0.9556\n",
"Epoch 164 : Loss : 0.3814, Acc : 1.0000, Test Acc : 0.9556\n",
"Epoch 165 : Loss : 0.3794, Acc : 1.0000, Test Acc : 0.9556\n",
"Epoch 166 : Loss : 0.3775, Acc : 1.0000, Test Acc : 0.9556\n",
"Epoch 167 : Loss : 0.3755, Acc : 1.0000, Test Acc : 0.9556\n",
"Epoch 168 : Loss : 0.3735, Acc : 1.0000, Test Acc : 0.9556\n",
"Epoch 169 : Loss : 0.3715, Acc : 1.0000, Test Acc : 0.9556\n",
"Epoch 170 : Loss : 0.3695, Acc : 1.0000, Test Acc : 0.9556\n",
"Epoch 171 : Loss : 0.3675, Acc : 1.0000, Test Acc : 0.9556\n",
"Epoch 172 : Loss : 0.3656, Acc : 1.0000, Test Acc : 0.9667\n",
"Epoch 173 : Loss : 0.3636, Acc : 1.0000, Test Acc : 0.9667\n",
"Epoch 174 : Loss : 0.3616, Acc : 1.0000, Test Acc : 0.9667\n",
"Epoch 175 : Loss : 0.3596, Acc : 1.0000, Test Acc : 0.9667\n",
"Epoch 176 : Loss : 0.3576, Acc : 1.0000, Test Acc : 0.9667\n",
"Epoch 177 : Loss : 0.3556, Acc : 1.0000, Test Acc : 0.9667\n",
"Epoch 178 : Loss : 0.3536, Acc : 1.0000, Test Acc : 0.9667\n",
"Epoch 179 : Loss : 0.3516, Acc : 1.0000, Test Acc : 0.9667\n",
"Epoch 180 : Loss : 0.3496, Acc : 1.0000, Test Acc : 0.9667\n",
"Epoch 181 : Loss : 0.3476, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 182 : Loss : 0.3456, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 183 : Loss : 0.3436, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 184 : Loss : 0.3416, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 185 : Loss : 0.3396, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 186 : Loss : 0.3376, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 187 : Loss : 0.3356, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 188 : Loss : 0.3336, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 189 : Loss : 0.3316, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 190 : Loss : 0.3296, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 191 : Loss : 0.3277, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 192 : Loss : 0.3257, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 193 : Loss : 0.3237, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 194 : Loss : 0.3218, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 195 : Loss : 0.3198, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 196 : Loss : 0.3179, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 197 : Loss : 0.3159, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 198 : Loss : 0.3140, Acc : 1.0000, Test Acc : 0.9778\n",
"Epoch 199 : Loss : 0.3121, Acc : 1.0000, Test Acc : 0.9778\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"import math\n",
"\n",
"# Données et modèle du problème des 2 clusters\n",
"x_train_lab, y_train_lab, x_train_unlab, y_train_unlab, x_test, y_test = generate_2clusters_dataset(num_lab=10, num_unlab=90, num_test=90)\n",
"model = create_model_2clusters()\n",
"\n",
"# Hyperparamètres de l'apprentissage\n",
"epochs = 200\n",
"batch_size = 16\n",
"if batch_size < x_train_lab.shape[0]:\n",
" steps_per_epoch = math.floor(x_train_lab.shape[0]/batch_size)\n",
"else:\n",
" steps_per_epoch = 1\n",
" batch_size = x_train_lab.shape[0]\n",
"\n",
"# Instanciation d'un optimiseur et d'une fonction de coût.\n",
"optimizer = keras.optimizers.Adam(learning_rate=1e-2)\n",
"loss_fn = keras.losses.BinaryCrossentropy()\n",
"\n",
"# Préparation des métriques pour le suivi de la performance du modèle.\n",
"train_acc_metric = keras.metrics.BinaryAccuracy()\n",
"test_acc_metric = keras.metrics.BinaryAccuracy()\n",
"\n",
"# Indices de l'ensemble labellisé\n",
"indices = np.arange(x_train_lab.shape[0])\n",
"indices_unlab = np.arange(x_train_unlab.shape[0])\n",
"\n",
"# Boucle sur les epochs\n",
"for epoch in range(epochs):\n",
"\n",
" if epoch > 50:\n",
" lambdaa = 0.25\n",
" else:\n",
" lambdaa = 0\n",
"\n",
" # A chaque nouvelle epoch, on randomise les indices de l'ensemble labellisé\n",
" np.random.shuffle(indices)\n",
" np.random.shuffle(indices_unlab)\n",
"\n",
" # Et on recommence à cumuler la loss\n",
" cum_loss_value = 0\n",
"\n",
" for step in range(steps_per_epoch):\n",
"\n",
" # Sélection des données du prochain batch\n",
" x_batch = x_train_lab[indices[step*batch_size: (step+1)*batch_size]]\n",
" x_batch_unlab = x_train_unlab[indices_unlab[step*batch_size: (step+1)*batch_size]]\n",
" y_batch = y_train_lab[indices[step*batch_size: (step+1)*batch_size]]\n",
"\n",
" # Etape nécessaire pour comparer y_batch à la sortie du réseau\n",
" y_batch = np.expand_dims(y_batch, 1)\n",
"\n",
" # Les opérations effectuées par le modèle dans ce bloc sont suivies et permettront\n",
" # la différentiation automatique.\n",
" with tf.GradientTape() as tape:\n",
"\n",
" # Application du réseau aux données d'entrée\n",
" y_pred = model(x_batch, training=True) # Logits for this minibatch\n",
" y_pred_unlab = model(x_batch_unlab, training=True)\n",
"\n",
" # Calcul de la fonction de perte sur ce batch\n",
" loss_value = loss_fn(y_batch, y_pred) + lambdaa * binary_entropy_loss(y_pred)\n",
"\n",
" # Calcul des gradients par différentiation automatique\n",
" grads = tape.gradient(loss_value, model.trainable_weights)\n",
"\n",
" # Réalisation d'une itération de la descente de gradient (mise à jour des paramètres du réseau)\n",
" optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
"\n",
" # Mise à jour de la métrique\n",
" train_acc_metric.update_state(y_batch, y_pred)\n",
"\n",
" cum_loss_value = cum_loss_value + loss_value\n",
"\n",
" # Calcul de la précision à la fin de l'epoch\n",
" train_acc = train_acc_metric.result()\n",
"\n",
" # Calcul de la précision sur l'ensemble de test à la fin de l'epoch\n",
" test_logits = model(x_test, training=False)\n",
" test_acc_metric.update_state(np.expand_dims(y_test, 1), test_logits)\n",
" test_acc = test_acc_metric.result()\n",
"\n",
" print(\"Epoch %4d : Loss : %.4f, Acc : %.4f, Test Acc : %.4f\" % (epoch, float(cum_loss_value/steps_per_epoch), float(train_acc), float(test_acc)))\n",
"\n",
" # Remise à zéro des métriques pour la prochaine epoch\n",
" train_acc_metric.reset_states()\n",
" test_acc_metric.reset_states()"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAlf0lEQVR4nO2de4xc9ZXnv8fu7sLVtsQOGM8QYO0mYSTIsgx+yO0FhZhJjK1VUCaKtGS0I8YejKMhCpIjZhk0uD0WDjgTx5mdkR1n7ZlhxTxWgllWkS0g2RiStKG7jQghgZm1xgOBYe0eR5v4Ibra+Owfp2+6uri36j5+9/Gr+n6kUvWtuo9Tt6vOPff7O79zRFVBCCHEX+aVbQAhhJBs0JETQojn0JETQojn0JETQojn0JETQojn9JVx0Msvv1yXLl1axqEJIcRbjh079q+qurj19VIc+dKlSzExMVHGoQkhxFtE5M2w1ymtEEKI59CRE0KI59CRE0KI59CRE0KI59CRE0KI59CRE0KI59CRE0JK5+hR4MtftmeSnFLyyAkhJODoUeD224FGAxgYAL7zHWB4uGyr/IIROSGkVI4cMSf+/vv2fORI2Rb5Bx05IT1O2bLGbbdZJD5/vj3fdls5dvgMpRVCepgqyBrDw3bcI0fMiVNWSQ4dOSE9TJisUYYjHR6mA88CpRVCehifZI2yJaAqw4ickB7GF1mjChJQlaEjJ6TH8UHWqIoEVFUorRBCKo9PElAZMCInhFQeXySgsqAjJ4R4gQ8SUFlQWiGEEM/J7MhF5BIRGRORH4rIj0VkuwvDCOlGmEJH8sCFtDIFYK2qnhWRfgDfF5HDqvqig30T0jUwhY7kReaIXI2zM4v9Mw/Nul9Cug0WhyJ54UQjF5H5IvIKgFMAnlPVl0LW2SwiEyIyMTk56eKwhHgFU+hIXjhx5Kr6vqreBOAqAKtE5KMh6+xX1RWqumLx4sUuDkuIVwQpdDt2UFbpxMhI2Rb4hai6VUFE5GEA51X1T6LWWbFihU5MTDg9LiGkexABHLumrkBEjqnqitbXXWStLBaRS2f+XgDgEwDeyLpfQggh8XAhrfwagO+KyKsAxmEa+bcc7JcQ0kOMjFgkLmLLwd+UWTrjXFqJA6UVQkg7KK2Ek5u0QgghpFzoyAkhlWPbtrIt8As6ckIKxnfNt4gyA76fo6KhRk685+hRv8qb+qz/ssxAuURp5CxjS7ym2x1L1S5S7NRTTSitEK8ps35Jktv/NKl1wUXqj/7InqtQMZFlBqoJHTnxmjIdy/YEBZtHRkxOCSSV4O92jjzPi1RanZtlBqoJpRXiNd3cAiy4SAWykauLVFY5ip16qgcjcuI9w8PAgw8W41xczD6Mm1qXV/TLcrrdB7NWCEmJr9kn3T5A3M0wa4UQAqC75ahWVO3CNTYGnDkDLFoErFplnzm4q+oG6MgJSYnPsw+L1rmLTqOcngYOHAB27QJOnbLl6Wmgv98eV1wBPPAAsGmTLfsOpRVCEjAy4nbWYdXyxPOgaCnn7Flg/Xrg5ZeB8+ej16vXgeXLgUOHgIUL87PHJSyaRYgDkqQcdiJLnriLi0lYCmIe0++LHFydnjYnPj7e3okD9v7YGLBhg23nM5RWCCmJLLMkt2/P5szDomQgn8g5rzTKMA4csEh8aire+lNTwLFjwMGDwL335mdX3jAiJ6QDeTU8KHMyU9hFJK/IuTmN8q678pNVVE0T7xSJt3L+vG3nYwZSAB05IR1IMyszDknzxF1eUFovIpddBrz1li3ncWEJcv0PHnS3z1aOHrWBzTScPFmNEghpySytiMjVAB4HsASAAtivql/Pul9CeoEge2RkJJ4jD5x21hz25hTEyy4D7r/fovC+PuCee4Df+R3/Bl/HxtJr3RcumK6+Zo1bm4rCRUR+AcBWVb0ewGoAvy8i1zvYLyGVI6+UQ5eDqHEJouTTp00rfv99c4TXXOPWiRfVi/PMmfSOvNGw7X0lsyNX1XdV9eWZv88AeB3Ah7Lul5AqUpWGBy4vKJddBly8aH9fvGjLrWT53HlJU60sWpQ+J3xgwLb3FacauYgsBfAbAF4KeW+ziEyIyMTk5KTLwxLiJVki1bROcNMmSy/cv382zfD0aWDejCeYN8+WWynjjiEpq1ald+R9fcDKlW7tKRJn6YcishDAkwDuV9VftL6vqvsB7AdsQpCr4xLiK64077iTio4etcHGefMs8p43D6jVgD177LmI9MA8Z8MOD9uMzRMnkm+7ZIl/YwLNOInIRaQf5sSfUNWnXOyTkKpQRI/KNBw9Cnz+88DHPx5vUlGQTtgsozQaFoGHZc/koW27nhnbjIhNu6/Xk21Xr9t2XtdeUdVMDwACy1rZE3eb5cuXKyE+MDqqumCB6vz59jw6ms9xtm1Ltn5gl0igOJuNO3eG73tWmZ77iPuZgGT2FbWvVhoN1VtuUa3Voj9z86NWU731VtvOBwBMaIhPdRGR/wcA/xnAWhF5ZeaxwcF+CSmdoqaXJ41SA7sCOUYkWhZpHWzcuRP4xjfsuaolbNNG7f39wOHDppd3iszrdVvv0KEuKJwV5t3zfjAiJ75QVESelGa7BgZUt2xJF1mPjloU32nbpHcMYduHRcRR+233XhwaDdV9+1SHhlQHBy3yFrHnwUF7fd8+fyLxAERE5Kx+SEgH8qpQmHW/abZv1qjLajARZ2A30KuzuqegHvn4+Nx65KtX+6mJs7EEISnJo3a3Cyeaxq5mySJL0a48GBn5YJqjiGW6pJVaXnwReP757i4TDNCREzKH1ig3r2i8Ck60yKqEzUSlIAbOutWZB8tJnXkvtbSjIydkhtYf/p49szVIwhxBllS6spxoM2W1fGt3zoJz6kJaqcLFsihY/ZCQGVp/+E8+2T5jJctsx6SVD/MiqLeSVqPPml8ftQ8XE4fKLBNcNIzICZmhNUr+zGeA732vfdScJSovum+mS1zIFu320Twgm/aOoZeaTDMiJ2SG1ih58+YPRs2tsx23b8+nkl/RJLX/8ceB997Lll/fKUd/3br0rfACstxxeEVYTmLeD+aRk24gyIXuBuJ8jiCve3TUctebZ0emya/vlKO/c6e9127Waq+BHGd2EtJTNA/GAfnV1253/DIIxgSOHLEoGrDP/bu/my7i7TRO0Esad1Y4IYiQlAR5z0X/hIIJNVlTI8PytoHovO3m48bVx8uY9NTNRE0IoiMnJANFlJ8NO+boaHxnGjjfsbG5sxuHh+dWNgz7HFHOfuNG4MMfbm97L+VxF0WUI6dG7gNZC12Q3Ej7r0law6VdBcMo/bjRUN27V3XZMqsvMjBg9UYGBmx52TJ7v9GIp5EnHQ+gxu0eUCP3GB/as/QoafXqpFUVWysYjo4CCxbM1Y8DW0ZGgLNngbVrga1brdHCuXOz1RIbDVs+ccLev/12y+wAOueGJ8kdp8ZdHJRWfCBry3RSObLIDlEaefC6CHDLLVYoamqq8/5qNZNaHnnEUv6ibBKxi0cSm6lxuyVKWmFEXlWKaj1OSiHNzE5Vi
8TXrbPtXnoJ+NjHrJJf69fi5ZfjOXHA1jt2DPjqVzvfJSStzd4zedwlw4jcBxiR9zTT08CBA8CuXcCpU7Y8PW3NEFTtbxdceSXws5/Z/oKI+5lnwpW9vj7ghRfooIuGETkhHtJJ62514s89l/5YP/858PWvz94lPPNMuDYP0IlXDVfNlw+KyCkRec3F/kgLebYeJ5VlehpYv9607vPn423ziU+kP96FC8BTT81KIc2ReCDdrFkz+0ylLzl5NfJ2VTTrLwH8GawJM3ENfy09yYEDybRuwOSWtFJLo2FReBjNGTGuuvekxdcB1Dzz6p04clV9QUSWutgXIVWnCEeiapp43Eg8IItePjBgF43W8gNA/C49eZ8bnycZ5VkfnRo5IQkIHEmWinxxj3PqVD77jiIs8h8dtYtKa8MH4IOJVEWcm6T591Uiz7z6why5iGwWkQkRmZicnCzqsIQ4xZUj6aSVjo25y0aJy9CQPc+fP/ta8+drHfgM/g4ceRFO1udJRnk2EymssYSq7gewH
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAqP0lEQVR4nO3dd3SU1dbA4d/JpJFiAgSIEIoUlSIoBBAhVBULyEURwUK/Qdq1flwhoKjYu4CYARS52LBwUWzoVYqiCDZAkCKidAyQQEACJOf7IwkGSJ15Z877zuxnLdYyk8w5OyPsObNPU1prhBBCOFeI6QCEEEJ4RxK5EEI4nCRyIYRwOEnkQgjhcJLIhRDC4UKN9LpxkSyVKUFubi5ps//HpQPuwuVymQ5HCGET8VFhXFSnsirue8rI8sPlUySRl2LNlj08/OGvpAy5l7DwCNPhCCFsIDEukp4tahabyKW0YkMX1K/BI33OZ6k7jaNHsk2HI4SwOUnkNlUvsQpTBibz9awJHMrcbzocIYSNSSK3sWqVY0lPTeHHuZPYv2eH6XCEEDZlZrKzGHkoDruqkBsaCRRbBjJM4zpxlOjc/YTgvxJ/bHQkM0Z25bYZT3D8ihHUqHue3/oWQjiDbRL5YVcVwmLiiVG5KBvmca0hR0dyOBtic/f5te+I8DCm3dqNsS/PJKfNDdRpkuzX/oUQ9mab0kpuaCQRNk3iAEpBhMot+MTgfy5XCE8O7UTo2nfZvPJ/RmIQQtiTbRI5KNsm8UL58ZkLUinFfTe2J+nPL1n7xXxjcQgh7MVGiVyU15hrkkl2beS792ebDkUIYQOWJXKllEsp9YNSaqFVbZrw8bLvOO+qETTsnsqjM942HU6JbunajJ41s/j6jeeQM+WFCG5WjshvA9Zb2J7f5ebmMmpyOh+l38e696fx+odLWbf5D9Nhlejqtg0Z1jKCZbMfIi8313Q4QghDLFm1opRKAq4GHgLutKLN0rS5OY2MrL/OeDwhrhLfzn3I43a/XbOJhnXOpn7tRAD6XZnCgs9X0KRhHY/b9LVLmtQmLiqcyTMmkjLkPtnSL0QQsmr54bPAWCC2pB9QSqUCqQDpY28gtVd7jzvLyPqLpsOfOePxn9Pv8LhNgB179lE7MeHk10mJCaxYvcGrNv2hab0aPNYnjP97cTzth06iUnSJ/xuEEAHI69KKUqoHsFdr/V1pP6e1dmutk7XWyd4kcVG8OolVmDakDV/PmsjBAxmmwxFC+JEVNfL2wDVKqa3AG0BXpdRcC9r1u1o1qrJt999JcPvuDGpVr2owoopJiI9hxq0dWfPqg+zbtc10OEIIP/E6kWutx2mtk7TW9YB+wOda65u9jsyA1s0asen3nfy2fTfHjh3njY+WcU2XtqbDqpCYqAjco7qy9b2n2PXbL6bDEUL4gawjLyI01MXUtOF0/+ckGvccRd/uHWjayL4TnSUJDwvlhRHdOLh0FlvXfms6HCGEj1l61orWejGw2Mo2i5MQV6nYic2EuEpet31Vp2Su6uT8s0xCQkJ4YkgnJr+xgE2Hs2jU9jLTIQkhfMQ2NwRlRiQRH2WbM7xKlHnkBPE5202HUSHT3v+ejeGNadatj+lQhBAekhuCgtyoni1pG76F795/yXQoQggfkEQeJG7s0pReSYdY/vozsqVfiABj/1qGsMyVrRsSH7OdqS9NpuOg8YS4XKZDEgHgkdH9yc4+dMbjMTGxjJv6uoGIgo8k8iDTrnEScVFhPOieQMqQSYRFyJZ+4Z3s7EPUHzbljMe3zBxjIJrgJIk8CDWpW4MnbgjnrvRxtB96v2zpDyAyOg5OksiLGJL2HAuXrKJ6lTjWvjfVdDg+lVS9Mi8MvZgxsybQ6uaJnFUloewnCduT0XFwksnOIgb17sbH7kmmw/CbqnHRuG/txNrXHyRjl32P6xVClM7RI/KMAwcZPvFZ3JNvp2r8WV631zG5GVt37LEgMueIiYogfWRXbpvxDMcvTeXs+o1NhyQCkJR8fMvRiXzOu59wYMdmXnnnE+4cer3pcBwrPCyUabd25Z7ZL/H74euoe8HFpkMSDhITE1ts6SYm5u+5Fyn5+JZjE3nGgYMs/PQLpl9bgxEffMHA67pbMioPViEhITw2uBMPv7mQDdlZnNeuu+mQhEPIiNo8xybyOe9+Qo8GivNqRNKjwREZlVtAKUVav3ZMX7iS1Z8dpPml8no6TXlGxyLwODKRF47G5/WNA2BAq7PoO09G5VYZ0eMi3liyji/+O5PkfwwzHY6oABkdBydHrlopHI0nxOS/DyXEhNKjgeKVdz7xqt3+dz9Bu/5j2bB1B0ldBjPrnUVWhOtI/To14dpz/mL5a0/Lln4hbM6RI/LF3/7Ezl05vLZm1ymP18z4yavyyutP/p+3oQWU7q3qEx+1nednPUjKoPG4Qh3510XYgJR8fMvrY2yVUpHAUiCC/DeGt7XW95X6JDnG1lHW/76X+xdsoOPQ+2VLfylkiZ3wpdKOsbUic+YAXbXW2UqpMOBLpdRHWutvLGhb2EDjutV5qn84d7nH027wJKJkFFUsWWInTLHizk6ttc4u+DKs4I8Hw3yN3Uux+fHZPEgfqVUtnheGtuXblyaSte9P0+EIIYqwZLJTKeVSSv0I7AU+1VqvqGgbrhNHydEu2yZzrSFHu3CdOGo6FGOqnBWNe0RHfn5jMhk7fzcdjhCigCVFaa11LnChUioemK+Uaqa1Xlv0Z5RSqUAqQPrYG0jt1f6UNqJz93M4G46GRgLFloEM07hOHCI6d7/pQIyKrhSBe1Q3bpvxLMe6DqNmg6amQxIi6Fl9+XKmUuoL4Apg7WnfcwNuoNjJzhA0sbn7INfKiIQvhIW6mDq8K+PnvMLWw9dSr7ls6RfCJK8TuVKqGnC8IIlXAi4DHvM6MmFrISEhPDKwI4++9QG/HM7k/HZXmA7JOFliJ0yxYvlhc+AVwEV+zX2e1vqBUp9UzIhcONf0D35gfUgjLrjsBtOhCBGwSlt+aMWqldVa64u01s211s3KTOIi4Iy4+iJSorexcv4M06EIEZQcuUVf2M/1HRvTp/5ffDn3SdnSL4Sf2X8rpXCMy1s1ID56B8/OvJ+OgyfIln6bk52ogUP+pQlLtTm/FvdFh3Nvehodh91PeESk6ZBECWQnauCQRC4sd17tajxzYzh3phds6Y+17mjhQBhFBsLvIOxFErnwiZrV4pg+rB1jZk6kxU1pxFetbkm7gTCKDITfQdiLJHLhM5XPisI9sjOj3Q/TsNftVKtVz3RIAUlG+EISufCpqMhw0kd25Y6Zz3O802BqNrrAdEgBR0b4QhK58LmwUBfPp3Yhbc5cth7pRb0Wl5gOSSA7UQOJJHLhFyEhITw8MIXH3/mYX7IzOb/9VaZDCnpSdgkcksiF3yil+Heftrg/+pHVi7Jofnn/CrcRCKPIQPgdhL1IIhd+l3plC97+8hc+eTed1r1TUar8xxYHwijSib+DTKjamyRyYUSfDucT/8Nv/GfuE1xy092EhATWaRH+THwxMbF8/
8j1aHXqa6jzchl1dTLx1RK9jkEmVO1NErkw5tKLziE+eidPz3qAlMFphIaGmQ7JMv5MfOOmvk7aoB5n9Ldj6yb2LXz6jMcl+QaewBoGCcdJPrcm9/esx5L0CeQc/ct0OEI4kozIhXGNkqrxzE3h3OEez8WDJxEdG2c6JEfbvW0Lubm55J44wbFD+1k9dQQArsgomg57ynB0whdkRC5s4eyEOF785yWsmn0vmRl7TIfjaLm5uUQk1CGsSi1cMZWpOehZag56ltyjR0yHJnzEiqveagNzgBqABtxa6+e8bVcEn/jYKGaM7Mzo9Iepf81tVE+qbzokR/rzg+fQeSdAa3IPH2Dby7cBkHsow+M2ZcmkvVlRWjkB3KW1/l4pFQt8p5T6VGu9zoK2RZCpFBFO+shu3DFzKsc7DqLWuc1Nh+QRfye+ov0dz9pNYr9H0GiUUoTG1QBg+4tD2DJzjEcxyBJDe/M6kWutdwG7Cv77kFJqPVALkEQuPBIa6uL54V2ZMOdVthw5SP0LO5gOq
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from mlxtend.plotting import plot_decision_regions\n",
"\n",
"# Affichage des données\n",
"plt.plot(x_train_unlab[y_train_unlab==0,0], x_train_unlab[y_train_unlab==0,1], 'b.')\n",
"plt.plot(x_train_unlab[y_train_unlab==1,0], x_train_unlab[y_train_unlab==1,1], 'r.')\n",
"\n",
"plt.plot(x_test[y_test==0,0], x_test[y_test==0,1], 'b+')\n",
"plt.plot(x_test[y_test==1,0], x_test[y_test==1,1], 'r+')\n",
"\n",
"plt.plot(x_train_lab[y_train_lab==0,0], x_train_lab[y_train_lab==0,1], 'b.', markersize=30)\n",
"plt.plot(x_train_lab[y_train_lab==1,0], x_train_lab[y_train_lab==1,1], 'r.', markersize=30)\n",
"\n",
"plt.show()\n",
"\n",
"#Affichage de la frontière de décision\n",
"plot_decision_regions(x_train_unlab, y_train_unlab, clf=model, legend=2)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "y25fa_IIZeuH"
},
"source": [
"Une fois cette étape réalisée, vous pouvez tester l'algorithme sur le dataset des 2 lunes ; comme annoncé en cours, vous devriez avoir beaucoup de mal à faire fonctionner l'algorithme sur ces données.\n",
"\n",
"S'il vous reste du temps, vous pouvez également tester votre algorithme sur les données MNIST, cela vous avancera pour la prochaine séance."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T3qJ5s-NUPnT"
},
"source": [
"# MNIST"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kH_eMruIMMVF"
},
"source": [
"## Chargement des données\n"
]
},
{
"cell_type": "code",
"execution_count": 166,
"metadata": {
"id": "ebv6WLB1MOkU"
},
"outputs": [],
"source": [
"from keras.datasets import mnist\n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def generate_mnist_dataset(num_lab = 10, seed=10):\n",
"\n",
" # Chargement et normalisation (entre 0 et 1) des données de la base de données MNIST\n",
" (x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
"\n",
" x_train = np.expand_dims(x_train.astype('float32') / 255., 3)\n",
" x_test = np.expand_dims(x_test.astype('float32') / 255., 3)\n",
"\n",
" x_train_lab, x_train_unlab, y_train_lab, y_train_unlab = train_test_split(x_train, y_train, test_size=(x_train.shape[0]-num_lab)/x_train.shape[0], random_state=seed)\n",
"\n",
" return x_train_lab, y_train_lab, x_train_unlab, y_train_unlab, x_test, y_test\n"
]
},
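{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of the split (a minimal sketch; the expected shapes assume the standard MNIST split of 60 000 training and 10 000 test images):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# With num_lab=10, ten labelled images are kept and the remaining\n",
"# 59 990 training images are treated as unlabelled.\n",
"x_train_lab, y_train_lab, x_train_unlab, y_train_unlab, x_test, y_test = generate_mnist_dataset(num_lab=10, seed=10)\n",
"\n",
"print(x_train_lab.shape, x_train_unlab.shape, x_test.shape)  # (10, 28, 28, 1) (59990, 28, 28, 1) (10000, 28, 28, 1)\n",
"print(y_train_lab)  # the 10 labels available for supervised training"
]
},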
{
"cell_type": "code",
"execution_count": 184,
"metadata": {
"id": "ZTsEZ2pzMpiU"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[7 5 9 4 4 1 8 1 0 7]\n",
"[7 2 8 3 1 4 6 9 5 0]\n",
"[0 5 6 4 5 5 8 2 1 4]\n"
]
}
],
"source": [
"# for i in range(1000, 10000):\n",
"# x_train_lab, y_train_lab, x_train_unlab, y_train_unlab, x_test, y_test = generate_mnist_dataset(num_lab = 10, seed=i)\n",
"# if len(set(y_train_lab)) >= 10:\n",
"# print(i)\n",
"# break\n",
"\n",
"# print(y_train_lab)\n",
"# print(len(y_train_lab))\n",
"# print(set(y_train_lab))\n",
"# test = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
"# print(set(test))\n",
"# print(len(set(y_train_lab)))\n",
"\n",
"# print(x_train_lab.shape, x_train_unlab.shape, x_test.shape)\n",
"# print(y_train_lab.shape, y_train_unlab.shape, y_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cl7GCcHvUa-l"
},
"source": [
"## Définition du modèle"
]
},
{
"cell_type": "code",
"execution_count": 142,
"metadata": {
"id": "RiYrQU0NUcKs"
},
"outputs": [],
"source": [
"from keras.layers import Conv2D, MaxPooling2D, Flatten\n",
"\n",
"# Ici, on implémentera le modèle LeNet-5 :\n",
"def create_model_mnist():\n",
"\n",
" inputs = keras.Input(shape=(28, 28, 1,))\n",
"\n",
" # 1 couche de convolution 5x5 à 6 filtres suivie d'un max pooling\n",
" x = Conv2D(6, 5, activation = 'relu')(inputs)\n",
" x = MaxPooling2D(pool_size=(2, 2))(x)\n",
"\n",
" # puis 1 couche de convolution 5x5 à 16 filtres suivie d'un max pooling\n",
" x = Conv2D(16, 5, activation = 'relu')(x)\n",
" x = MaxPooling2D(pool_size=(2, 2))(x)\n",
"\n",
" # et d'un Flatten\n",
" x = Flatten()(x)\n",
"\n",
" # Enfin 2 couches denses de 120 et 84 neurones\n",
" x = Dense(120, activation='relu')(x)\n",
" x = Dense(84, activation='relu')(x)\n",
"\n",
" # avant la couche de sortie à 10 neurones.\n",
" outputs = Dense(10, activation='softmax')(x)\n",
"\n",
" model = keras.Model(inputs=inputs, outputs=outputs) \n",
"\n",
" return model"
]
},
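{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training loop below adds an entropy-minimisation term, `binary_entropy_loss`, computed on the predictions for the unlabelled batch. In case that function is no longer in memory, here is a minimal sketch of such a term (an assumption about its shape, not necessarily the exact definition used earlier in the TP):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"def binary_entropy_loss(y_pred, eps=1e-7):\n",
"    # Mean Shannon entropy of the predicted class distributions:\n",
"    # H(p) = -sum_k p_k log p_k, averaged over the batch.\n",
"    # Despite the name (kept to match the call in the training loop),\n",
"    # this works for any number of classes. Low entropy means confident\n",
"    # predictions, which is what entropy minimisation encourages on the\n",
"    # unlabelled data.\n",
"    p = tf.clip_by_value(y_pred, eps, 1.0)\n",
"    return -tf.reduce_mean(tf.reduce_sum(p * tf.math.log(p), axis=1))"
]
},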
{
"cell_type": "code",
"execution_count": 190,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0 : Loss : 2.3014, Acc : 0.1000, Test Acc : 0.1074\n",
"Epoch 1 : Loss : 2.2547, Acc : 0.3000, Test Acc : 0.1335\n",
"Epoch 2 : Loss : 2.2146, Acc : 0.4000, Test Acc : 0.1968\n",
"Epoch 3 : Loss : 2.1732, Acc : 0.6000, Test Acc : 0.2695\n",
"Epoch 4 : Loss : 2.1275, Acc : 0.7000, Test Acc : 0.3152\n",
"Epoch 5 : Loss : 2.0814, Acc : 0.8000, Test Acc : 0.3496\n",
"Epoch 6 : Loss : 2.0271, Acc : 0.8000, Test Acc : 0.3671\n",
"Epoch 7 : Loss : 1.9674, Acc : 0.8000, Test Acc : 0.3805\n",
"Epoch 8 : Loss : 1.9000, Acc : 0.8000, Test Acc : 0.3940\n",
"Epoch 9 : Loss : 1.8275, Acc : 0.8000, Test Acc : 0.4053\n",
"Epoch 10 : Loss : 1.7498, Acc : 0.9000, Test Acc : 0.4102\n",
"Epoch 11 : Loss : 1.6640, Acc : 0.9000, Test Acc : 0.4171\n",
"Epoch 12 : Loss : 1.5701, Acc : 0.9000, Test Acc : 0.4208\n",
"Epoch 13 : Loss : 1.4696, Acc : 0.9000, Test Acc : 0.4260\n",
"Epoch 14 : Loss : 1.3669, Acc : 0.9000, Test Acc : 0.4301\n",
"Epoch 15 : Loss : 1.2611, Acc : 0.9000, Test Acc : 0.4359\n",
"Epoch 16 : Loss : 1.1529, Acc : 0.9000, Test Acc : 0.4425\n",
"Epoch 17 : Loss : 1.0442, Acc : 0.9000, Test Acc : 0.4470\n",
"Epoch 18 : Loss : 0.9338, Acc : 1.0000, Test Acc : 0.4501\n",
"Epoch 19 : Loss : 0.8274, Acc : 1.0000, Test Acc : 0.4498\n",
"Epoch 20 : Loss : 0.7254, Acc : 1.0000, Test Acc : 0.4487\n",
"Epoch 21 : Loss : 0.6296, Acc : 1.0000, Test Acc : 0.4457\n",
"Epoch 22 : Loss : 0.5413, Acc : 1.0000, Test Acc : 0.4438\n",
"Epoch 23 : Loss : 0.4580, Acc : 1.0000, Test Acc : 0.4425\n",
"Epoch 24 : Loss : 0.3813, Acc : 1.0000, Test Acc : 0.4404\n",
"Epoch 25 : Loss : 0.3130, Acc : 1.0000, Test Acc : 0.4395\n",
"Epoch 26 : Loss : 1.8660, Acc : 1.0000, Test Acc : 0.4432\n",
"Epoch 27 : Loss : 1.5501, Acc : 1.0000, Test Acc : 0.4450\n",
"Epoch 28 : Loss : 1.2730, Acc : 1.0000, Test Acc : 0.4461\n",
"Epoch 29 : Loss : 1.0235, Acc : 1.0000, Test Acc : 0.4431\n",
"Epoch 30 : Loss : 0.7895, Acc : 1.0000, Test Acc : 0.4384\n",
"Epoch 31 : Loss : 0.5879, Acc : 1.0000, Test Acc : 0.4381\n",
"Epoch 32 : Loss : 0.4273, Acc : 1.0000, Test Acc : 0.4399\n",
"Epoch 33 : Loss : 0.3016, Acc : 1.0000, Test Acc : 0.4425\n",
"Epoch 34 : Loss : 0.2109, Acc : 1.0000, Test Acc : 0.4444\n",
"Epoch 35 : Loss : 0.1455, Acc : 1.0000, Test Acc : 0.4435\n",
"Epoch 36 : Loss : 0.0974, Acc : 1.0000, Test Acc : 0.4459\n",
"Epoch 37 : Loss : 0.0641, Acc : 1.0000, Test Acc : 0.4455\n",
"Epoch 38 : Loss : 0.0427, Acc : 1.0000, Test Acc : 0.4473\n",
"Epoch 39 : Loss : 0.0290, Acc : 1.0000, Test Acc : 0.4475\n",
"Epoch 40 : Loss : 0.0200, Acc : 1.0000, Test Acc : 0.4470\n",
"Epoch 41 : Loss : 0.0138, Acc : 1.0000, Test Acc : 0.4471\n",
"Epoch 42 : Loss : 0.0097, Acc : 1.0000, Test Acc : 0.4473\n",
"Epoch 43 : Loss : 0.0069, Acc : 1.0000, Test Acc : 0.4474\n",
"Epoch 44 : Loss : 0.0050, Acc : 1.0000, Test Acc : 0.4475\n",
"Epoch 45 : Loss : 0.0037, Acc : 1.0000, Test Acc : 0.4478\n",
"Epoch 46 : Loss : 0.0029, Acc : 1.0000, Test Acc : 0.4492\n",
"Epoch 47 : Loss : 0.0022, Acc : 1.0000, Test Acc : 0.4495\n",
"Epoch 48 : Loss : 0.0018, Acc : 1.0000, Test Acc : 0.4498\n",
"Epoch 49 : Loss : 0.0014, Acc : 1.0000, Test Acc : 0.4497\n",
"Epoch 50 : Loss : 0.0011, Acc : 1.0000, Test Acc : 0.4508\n",
"Epoch 51 : Loss : 0.0009, Acc : 1.0000, Test Acc : 0.4520\n",
"Epoch 52 : Loss : 0.0008, Acc : 1.0000, Test Acc : 0.4524\n",
"Epoch 53 : Loss : 0.0007, Acc : 1.0000, Test Acc : 0.4528\n",
"Epoch 54 : Loss : 0.0006, Acc : 1.0000, Test Acc : 0.4534\n",
"Epoch 55 : Loss : 0.0005, Acc : 1.0000, Test Acc : 0.4537\n",
"Epoch 56 : Loss : 0.0004, Acc : 1.0000, Test Acc : 0.4540\n",
"Epoch 57 : Loss : 0.0004, Acc : 1.0000, Test Acc : 0.4550\n",
"Epoch 58 : Loss : 0.0003, Acc : 1.0000, Test Acc : 0.4559\n",
"Epoch 59 : Loss : 0.0003, Acc : 1.0000, Test Acc : 0.4560\n",
"Epoch 60 : Loss : 0.0003, Acc : 1.0000, Test Acc : 0.4559\n",
"Epoch 61 : Loss : 0.0002, Acc : 1.0000, Test Acc : 0.4559\n",
"Epoch 62 : Loss : 0.0002, Acc : 1.0000, Test Acc : 0.4560\n",
"Epoch 63 : Loss : 0.0002, Acc : 1.0000, Test Acc : 0.4563\n",
"Epoch 64 : Loss : 0.0002, Acc : 1.0000, Test Acc : 0.4562\n",
"Epoch 65 : Loss : 0.0002, Acc : 1.0000, Test Acc : 0.4566\n",
"Epoch 66 : Loss : 0.0002, Acc : 1.0000, Test Acc : 0.4572\n",
"Epoch 67 : Loss : 0.0002, Acc : 1.0000, Test Acc : 0.4579\n",
"Epoch 68 : Loss : 0.0002, Acc : 1.0000, Test Acc : 0.4580\n",
"Epoch 69 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4582\n",
"Epoch 70 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4582\n",
"Epoch 71 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4585\n",
"Epoch 72 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4585\n",
"Epoch 73 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4585\n",
"Epoch 74 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4586\n",
"Epoch 75 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4583\n",
"Epoch 76 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4583\n",
"Epoch 77 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4580\n",
"Epoch 78 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4577\n",
"Epoch 79 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4573\n",
"Epoch 80 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4577\n",
"Epoch 81 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4576\n",
"Epoch 82 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4576\n",
"Epoch 83 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4574\n",
"Epoch 84 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4573\n",
"Epoch 85 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4572\n",
"Epoch 86 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4574\n",
"Epoch 87 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4573\n",
"Epoch 88 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4571\n",
"Epoch 89 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4570\n",
"Epoch 90 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4569\n",
"Epoch 91 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4567\n",
"Epoch 92 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4567\n",
"Epoch 93 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4567\n",
"Epoch 94 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4567\n",
"Epoch 95 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4569\n",
"Epoch 96 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4567\n",
"Epoch 97 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4568\n",
"Epoch 98 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4568\n",
"Epoch 99 : Loss : 0.0001, Acc : 1.0000, Test Acc : 0.4571\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"import math\n",
"\n",
"# Données et modèle du problème des 2 clusters\n",
"x_train_lab, y_train_lab, x_train_unlab, y_train_unlab, x_test, y_test = generate_mnist_dataset(num_lab=10, seed=1164)\n",
"model = create_model_mnist()\n",
"\n",
"# Hyperparamètres de l'apprentissage\n",
"epochs = 100\n",
"batch_size = 64\n",
"if batch_size < x_train_lab.shape[0]:\n",
" steps_per_epoch = math.floor(x_train_lab.shape[0]/batch_size)\n",
"else:\n",
" steps_per_epoch = 1\n",
" batch_size = x_train_lab.shape[0]\n",
"\n",
"# Instanciation d'un optimiseur et d'une fonction de coût.\n",
"optimizer = keras.optimizers.Adam(learning_rate=1e-3)\n",
"loss_fn = keras.losses.SparseCategoricalCrossentropy()\n",
"\n",
"# Préparation des métriques pour le suivi de la performance du modèle.\n",
"train_acc_metric = keras.metrics.SparseCategoricalAccuracy()\n",
"test_acc_metric = keras.metrics.SparseCategoricalAccuracy()\n",
"\n",
"# Indices de l'ensemble labellisé\n",
"indices = np.arange(x_train_lab.shape[0])\n",
"indices_unlab = np.arange(x_train_unlab.shape[0])\n",
"\n",
"# Boucle sur les epochs\n",
"for epoch in range(epochs):\n",
"\n",
" if epoch > 25:\n",
" lambdaa = 0.2\n",
" else:\n",
" lambdaa = 0\n",
"\n",
" # A chaque nouvelle epoch, on randomise les indices de l'ensemble labellisé\n",
" np.random.shuffle(indices)\n",
" np.random.shuffle(indices_unlab)\n",
"\n",
" # Et on recommence à cumuler la loss\n",
" cum_loss_value = 0\n",
"\n",
" for step in range(steps_per_epoch):\n",
"\n",
" # Sélection des données du prochain batch\n",
" x_batch = x_train_lab[indices[step*batch_size: (step+1)*batch_size]]\n",
" x_batch_unlab = x_train_unlab[indices_unlab[step*batch_size: (step+1)*batch_size]]\n",
" y_batch = y_train_lab[indices[step*batch_size: (step+1)*batch_size]]\n",
"\n",
" # Etape nécessaire pour comparer y_batch à la sortie du réseau\n",
" y_batch = np.expand_dims(y_batch, 1)\n",
"\n",
" # Les opérations effectuées par le modèle dans ce bloc sont suivies et permettront\n",
" # la différentiation automatique.\n",
" with tf.GradientTape() as tape:\n",
"\n",
" # Application du réseau aux données d'entrée\n",
" y_pred = model(x_batch, training=True) # Logits for this minibatch\n",
" y_pred_unlab = model(x_batch_unlab, training=True)\n",
"\n",
" # Calcul de la fonction de perte sur ce batch\n",
" # print(y_batch)\n",
" # print(y_pred)\n",
" loss_value = loss_fn(y_batch, y_pred) + lambdaa * binary_entropy_loss(y_pred)\n",
"\n",
" # Calcul des gradients par différentiation automatique\n",
" grads = tape.gradient(loss_value, model.trainable_weights)\n",
"\n",
" # Réalisation d'une itération de la descente de gradient (mise à jour des paramètres du réseau)\n",
" optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
"\n",
" # Mise à jour de la métrique\n",
" train_acc_metric.update_state(y_batch, y_pred)\n",
"\n",
" cum_loss_value = cum_loss_value + loss_value\n",
"\n",
" # Calcul de la précision à la fin de l'epoch\n",
" train_acc = train_acc_metric.result()\n",
"\n",
" # Calcul de la précision sur l'ensemble de test à la fin de l'epoch\n",
" test_logits = model(x_test, training=False)\n",
" test_acc_metric.update_state(np.expand_dims(y_test, 1), test_logits)\n",
" test_acc = test_acc_metric.result()\n",
"\n",
" print(\"Epoch %4d : Loss : %.4f, Acc : %.4f, Test Acc : %.4f\" % (epoch, float(cum_loss_value/steps_per_epoch), float(train_acc), float(test_acc)))\n",
"\n",
" # Remise à zéro des métriques pour la prochaine epoch\n",
" train_acc_metric.reset_states()\n",
" test_acc_metric.reset_states()"
]
},
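{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a possible follow-up (a sketch, assuming `model`, `x_test` and `y_test` from the cells above are still in memory), the per-class behaviour of the trained network can be inspected with a confusion matrix:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# Predicted class = argmax of the softmax output\n",
"y_test_pred = np.argmax(model.predict(x_test, verbose=0), axis=1)\n",
"\n",
"# Rows: true digits, columns: predicted digits\n",
"print(confusion_matrix(y_test, y_test_pred))"
]
}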
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"fbmhai8PVXVd",
"kH_eMruIMMVF"
],
"machine_shape": "hm",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3.10.7 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.7"
},
"vscode": {
"interpreter": {
"hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}