457 lines
26 KiB
Plaintext
457 lines
26 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "Ls4hgfTEHgGR"
|
||
},
|
||
"source": [
|
||
"# Réseaux Génératifs Antagonistes\n",
|
||
"\n",
|
||
"Dans ce TP nous allons mettre en place l'entraînement d'un réseau de neurones génératif, entraîné de manière antagoniste à l'aide d'un réseau discriminateur. \n",
|
||
"\n",
|
||
"<center> <img src=\"https://drive.google.com/uc?id=1_ADmA-Js37z6R-0o476dzX4jMG5WHLtr\" width=600></center>\n",
|
||
"<caption><center> Schéma global de fonctionnement d'un GAN ([Goodfellow 2014]) </center></caption>\n",
|
||
"\n",
|
||
"Dans un premier temps, nous allons illustrer le fonctionnement du GAN sur l'exemple simple, canonique, de la base de données MNIST. \n",
|
||
"Votre objectif sera par la suite d'adapter cet exemple à la base de données *Labelled Faces in the Wild*, et éventuellement d'implémenter quelques astuces permettant d'améliorer l'entraînement.\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {
|
||
"id": "TRziuDJMInpM"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import tensorflow as tf\n",
|
||
"from tensorflow import keras\n",
|
||
"from keras import layers\n",
|
||
"import numpy as np\n",
|
||
"import os\n",
|
||
"import matplotlib.pyplot as plt"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "IFNzLxouIfwy"
|
||
},
|
||
"source": [
|
||
"On commence par définir les réseaux discriminateur et générateur, en suivant les recommandations de DCGAN (activation *LeakyReLU*, *stride*, *Batch Normalization*, activation de sortie *tanh* pour le générateur)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {
|
||
"id": "IfPhxKGLHfD-"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Model: \"discriminator\"\n",
|
||
"_________________________________________________________________\n",
|
||
" Layer (type) Output Shape Param # \n",
|
||
"=================================================================\n",
|
||
" conv2d_3 (Conv2D) (None, 14, 14, 64) 640 \n",
|
||
" \n",
|
||
" batch_normalization_5 (Batc (None, 14, 14, 64) 256 \n",
|
||
" hNormalization) \n",
|
||
" \n",
|
||
" leaky_re_lu_5 (LeakyReLU) (None, 14, 14, 64) 0 \n",
|
||
" \n",
|
||
" conv2d_4 (Conv2D) (None, 7, 7, 128) 73856 \n",
|
||
" \n",
|
||
" batch_normalization_6 (Batc (None, 7, 7, 128) 512 \n",
|
||
" hNormalization) \n",
|
||
" \n",
|
||
" leaky_re_lu_6 (LeakyReLU) (None, 7, 7, 128) 0 \n",
|
||
" \n",
|
||
" global_max_pooling2d_1 (Glo (None, 128) 0 \n",
|
||
" balMaxPooling2D) \n",
|
||
" \n",
|
||
" dense_2 (Dense) (None, 1) 129 \n",
|
||
" \n",
|
||
"=================================================================\n",
|
||
"Total params: 75,393\n",
|
||
"Trainable params: 75,009\n",
|
||
"Non-trainable params: 384\n",
|
||
"_________________________________________________________________\n",
|
||
"Model: \"generator\"\n",
|
||
"_________________________________________________________________\n",
|
||
" Layer (type) Output Shape Param # \n",
|
||
"=================================================================\n",
|
||
" dense_3 (Dense) (None, 6272) 809088 \n",
|
||
" \n",
|
||
" batch_normalization_7 (Batc (None, 6272) 25088 \n",
|
||
" hNormalization) \n",
|
||
" \n",
|
||
" leaky_re_lu_7 (LeakyReLU) (None, 6272) 0 \n",
|
||
" \n",
|
||
" reshape_1 (Reshape) (None, 7, 7, 128) 0 \n",
|
||
" \n",
|
||
" conv2d_transpose_2 (Conv2DT (None, 14, 14, 128) 262272 \n",
|
||
" ranspose) \n",
|
||
" \n",
|
||
" batch_normalization_8 (Batc (None, 14, 14, 128) 512 \n",
|
||
" hNormalization) \n",
|
||
" \n",
|
||
" leaky_re_lu_8 (LeakyReLU) (None, 14, 14, 128) 0 \n",
|
||
" \n",
|
||
" conv2d_transpose_3 (Conv2DT (None, 28, 28, 128) 262272 \n",
|
||
" ranspose) \n",
|
||
" \n",
|
||
" batch_normalization_9 (Batc (None, 28, 28, 128) 512 \n",
|
||
" hNormalization) \n",
|
||
" \n",
|
||
" leaky_re_lu_9 (LeakyReLU) (None, 28, 28, 128) 0 \n",
|
||
" \n",
|
||
" conv2d_5 (Conv2D) (None, 28, 28, 1) 6273 \n",
|
||
" \n",
|
||
"=================================================================\n",
|
||
"Total params: 1,366,017\n",
|
||
"Trainable params: 1,352,961\n",
|
||
"Non-trainable params: 13,056\n",
|
||
"_________________________________________________________________\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"latent_dim = 128\n",
|
||
"discriminator = keras.Sequential(\n",
|
||
" [\n",
|
||
" keras.Input(shape=(28, 28, 1)),\n",
|
||
" layers.Conv2D(64, (3, 3), strides=(2, 2), padding=\"same\"),\n",
|
||
" layers.BatchNormalization(momentum = 0.8),\n",
|
||
" layers.LeakyReLU(alpha=0.2),\n",
|
||
" layers.Conv2D(128, (3, 3), strides=(2, 2), padding=\"same\"),\n",
|
||
" layers.BatchNormalization(momentum = 0.8),\n",
|
||
" layers.LeakyReLU(alpha=0.2),\n",
|
||
" layers.GlobalMaxPooling2D(),\n",
|
||
" layers.Dense(1, activation=\"sigmoid\"),\n",
|
||
" ],\n",
|
||
" name=\"discriminator\",\n",
|
||
")\n",
|
||
"discriminator.summary()\n",
|
||
"\n",
|
||
"generator = keras.Sequential(\n",
|
||
" [\n",
|
||
" keras.Input(shape=(latent_dim,)),\n",
|
||
" layers.Dense(7 * 7 * 128), \n",
|
||
" layers.BatchNormalization(momentum = 0.8),\n",
|
||
" layers.LeakyReLU(alpha=0.2),\n",
|
||
" layers.Reshape((7, 7, 128)),\n",
|
||
" layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n",
|
||
" layers.BatchNormalization(momentum = 0.8),\n",
|
||
" layers.LeakyReLU(alpha=0.2),\n",
|
||
" layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n",
|
||
" layers.BatchNormalization(momentum = 0.8),\n",
|
||
" layers.LeakyReLU(alpha=0.2),\n",
|
||
" layers.Conv2D(1, (7, 7), padding=\"same\", activation=\"tanh\"),\n",
|
||
" ],\n",
|
||
" name=\"generator\",\n",
|
||
")\n",
|
||
"generator.summary()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "kZ0FTcu6yl56"
|
||
},
|
||
"source": [
|
||
"Le code suivant décrit ce qui se passe à chaque itération de l'algorithme, ce qui est également résumé dans le cours sur le slide suivant : \n",
|
||
"\n",
|
||
"<center> <img src=\"https://drive.google.com/uc?id=1I6KesJZeSN_p_mx5nkAsVUeMmUKfIYB_\" width=600></center>\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {
|
||
"id": "_RnxhJX_KJxF"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Instanciation de deux optimiseurs, l'un pour le discriminateur et l'autre pour le générateur\n",
|
||
"d_optimizer = keras.optimizers.Adam(learning_rate=0.0008)\n",
|
||
"g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)\n",
|
||
"\n",
|
||
"# Instanciation d'une fonction de coût entropie croisée\n",
|
||
"loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)\n",
|
||
"\n",
|
||
"\n",
|
||
"# La fonction prend en entrée un mini-batch d'images réelles\n",
|
||
"@tf.function\n",
|
||
"def train_step(real_images):\n",
|
||
" batch_size = tf.shape(real_images)[0]\n",
|
||
"\n",
|
||
" # ENTRAINEMENT DU DISCRIMINATEUR\n",
|
||
" # Échantillonnage d’un mini-batch de bruit\n",
|
||
" random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim,))\n",
|
||
" # Création d'un mini-batch d'images générées à partir du bruit\n",
|
||
" generated_images = generator(random_latent_vectors)\n",
|
||
" # Échantillonnage d’un mini-batch de données combinant images générées et réelles\n",
|
||
" combined_images = tf.concat([generated_images, real_images], axis=0)\n",
|
||
"\n",
|
||
" # Création des labels associés au mini-batch de données créé précédemment\n",
|
||
" # Pour l'entraînement du discriminateur :\n",
|
||
" # - les données générées sont labellisées \"0\" \n",
|
||
" # - les données réelles sont labellisées \"1\" \n",
|
||
" labels = tf.concat([tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], axis=0)\n",
|
||
"\n",
|
||
" # Entraînement du discriminateur\n",
|
||
" with tf.GradientTape() as tape:\n",
|
||
" # L'appel d'un modèle (ici discriminator) à l'intérieur de Tf.GradientTape\n",
|
||
" # permet de récupérer les gradients pour faire la mise à jour\n",
|
||
"\n",
|
||
" # Prédiction du discriminateur sur notre batch d'images réelles et générées\n",
|
||
" predictions = discriminator(combined_images)\n",
|
||
" # Calcul de la fonction de coût\n",
|
||
" d_loss = loss_fn(labels, predictions)\n",
|
||
"\n",
|
||
" # Récupération des gradients de la fonction de coût par rapport aux paramètres du discriminateur\n",
|
||
" grads = tape.gradient(d_loss, discriminator.trainable_weights)\n",
|
||
" # Mise à jour des paramètres par l'optimiseur grâce aux gradients de la fonction de coût\n",
|
||
" d_optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))\n",
|
||
" ### NOTE : ON N'ENTRAINE PAS LE GENERATEUR A CE MOMENT !\n",
|
||
"\n",
|
||
" # ENTRAINEMENT DU GENERATEUR\n",
|
||
" # Échantillonnage d’un mini-batch de bruit\n",
|
||
" random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim,))\n",
|
||
" # Création des labels associés au mini-batch de données créé précédemment\n",
|
||
" # Pour l'entraînement du générateur :\n",
|
||
" # - les données générées sont labellisées ici \"1\" \n",
|
||
" misleading_labels = tf.ones((batch_size, 1))\n",
|
||
"\n",
|
||
" # Entraînement du générateur sans toucher aux paramètres du discriminateur !\n",
|
||
" with tf.GradientTape() as tape:\n",
|
||
" predictions = discriminator(generator(random_latent_vectors))\n",
|
||
" g_loss = loss_fn(misleading_labels, predictions)\n",
|
||
" \n",
|
||
" # Récupération des gradients de la fonction de coût par rapport aux paramètres du générateur\n",
|
||
" grads = tape.gradient(g_loss, generator.trainable_weights)\n",
|
||
" # Mise à jour des paramètres par l'optimiseur grâce aux gradients de la fonction de coût\n",
|
||
" g_optimizer.apply_gradients(zip(grads, generator.trainable_weights))\n",
|
||
"\n",
|
||
" return d_loss, g_loss, generated_images"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "all1LAF92h1u"
|
||
},
|
||
"source": [
|
||
"Il reste à écrire l'algorithme final qui va faire appel au code d'itération écrit précédemment"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {
|
||
"id": "lQJWoazN2pwd"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"Start epoch 0\n"
|
||
]
|
||
},
|
||
{
|
||
"ename": "TypeError",
|
||
"evalue": "in user code:\n\n File \"/tmp/ipykernel_11002/1607979120.py\", line 26, in train_step *\n labels = tf.concat(tf.zeros((batch_size, 1)), tf.ones((batch_size, 1)), axis=0)\n\n TypeError: Got multiple values for argument 'axis'\n",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
||
"\u001b[1;32m/home/laurent/Documents/Cours/ENSEEIHT/S9 - IAM/IAM2022_TP_GAN_Sujet.ipynb Cell 8\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/laurent/Documents/Cours/ENSEEIHT/S9%20-%20IAM/IAM2022_TP_GAN_Sujet.ipynb#X10sZmlsZQ%3D%3D?line=12'>13</a>\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39mStart epoch\u001b[39m\u001b[39m\"\u001b[39m, epoch)\n\u001b[1;32m <a href='vscode-notebook-cell:/home/laurent/Documents/Cours/ENSEEIHT/S9%20-%20IAM/IAM2022_TP_GAN_Sujet.ipynb#X10sZmlsZQ%3D%3D?line=14'>15</a>\u001b[0m \u001b[39mfor\u001b[39;00m step, real_images \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(dataset):\n\u001b[1;32m <a href='vscode-notebook-cell:/home/laurent/Documents/Cours/ENSEEIHT/S9%20-%20IAM/IAM2022_TP_GAN_Sujet.ipynb#X10sZmlsZQ%3D%3D?line=15'>16</a>\u001b[0m \u001b[39m# Descente de gradient simultanée du discrimnateur et du générateur\u001b[39;00m\n\u001b[0;32m---> <a href='vscode-notebook-cell:/home/laurent/Documents/Cours/ENSEEIHT/S9%20-%20IAM/IAM2022_TP_GAN_Sujet.ipynb#X10sZmlsZQ%3D%3D?line=16'>17</a>\u001b[0m d_loss, g_loss, generated_images \u001b[39m=\u001b[39m train_step(real_images)\n\u001b[1;32m <a href='vscode-notebook-cell:/home/laurent/Documents/Cours/ENSEEIHT/S9%20-%20IAM/IAM2022_TP_GAN_Sujet.ipynb#X10sZmlsZQ%3D%3D?line=18'>19</a>\u001b[0m \u001b[39m# Affichage régulier d'images générées.\u001b[39;00m\n\u001b[1;32m <a href='vscode-notebook-cell:/home/laurent/Documents/Cours/ENSEEIHT/S9%20-%20IAM/IAM2022_TP_GAN_Sujet.ipynb#X10sZmlsZQ%3D%3D?line=19'>20</a>\u001b[0m \u001b[39mif\u001b[39;00m step \u001b[39m%\u001b[39m \u001b[39m200\u001b[39m \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m:\n\u001b[1;32m <a href='vscode-notebook-cell:/home/laurent/Documents/Cours/ENSEEIHT/S9%20-%20IAM/IAM2022_TP_GAN_Sujet.ipynb#X10sZmlsZQ%3D%3D?line=20'>21</a>\u001b[0m \u001b[39m# Métriques\u001b[39;00m\n",
|
||
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py:153\u001b[0m, in \u001b[0;36mfilter_traceback.<locals>.error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 152\u001b[0m filtered_tb \u001b[39m=\u001b[39m _process_traceback_frames(e\u001b[39m.\u001b[39m__traceback__)\n\u001b[0;32m--> 153\u001b[0m \u001b[39mraise\u001b[39;00m e\u001b[39m.\u001b[39mwith_traceback(filtered_tb) \u001b[39mfrom\u001b[39;00m \u001b[39mNone\u001b[39m\n\u001b[1;32m 154\u001b[0m \u001b[39mfinally\u001b[39;00m:\n\u001b[1;32m 155\u001b[0m \u001b[39mdel\u001b[39;00m filtered_tb\n",
|
||
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/tensorflow/python/framework/func_graph.py:1147\u001b[0m, in \u001b[0;36mfunc_graph_from_py_func.<locals>.autograph_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1145\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e: \u001b[39m# pylint:disable=broad-except\u001b[39;00m\n\u001b[1;32m 1146\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mhasattr\u001b[39m(e, \u001b[39m\"\u001b[39m\u001b[39mag_error_metadata\u001b[39m\u001b[39m\"\u001b[39m):\n\u001b[0;32m-> 1147\u001b[0m \u001b[39mraise\u001b[39;00m e\u001b[39m.\u001b[39mag_error_metadata\u001b[39m.\u001b[39mto_exception(e)\n\u001b[1;32m 1148\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 1149\u001b[0m \u001b[39mraise\u001b[39;00m\n",
|
||
"\u001b[0;31mTypeError\u001b[0m: in user code:\n\n File \"/tmp/ipykernel_11002/1607979120.py\", line 26, in train_step *\n labels = tf.concat(tf.zeros((batch_size, 1)), tf.ones((batch_size, 1)), axis=0)\n\n TypeError: Got multiple values for argument 'axis'\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# Préparation de la base de données : on utilise toutes les images (entraînement + test) de MNIST\n",
|
||
"batch_size = 32\n",
|
||
"(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n",
|
||
"all_digits = np.concatenate([x_train, x_test])\n",
|
||
"all_digits = (all_digits.astype(\"float32\")-127.5) / 127.5 # Images normalisées\n",
|
||
"all_digits = np.reshape(all_digits, (-1, 28, 28, 1))\n",
|
||
"dataset = tf.data.Dataset.from_tensor_slices(all_digits)\n",
|
||
"dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)\n",
|
||
"\n",
|
||
"epochs = 20 # Une 20aine d'epochs est nécessaire pour voir des chiffres qui semblent réalistes\n",
|
||
"\n",
|
||
"for epoch in range(epochs):\n",
|
||
" print(\"\\nStart epoch\", epoch)\n",
|
||
"\n",
|
||
" for step, real_images in enumerate(dataset):\n",
|
||
" # Descente de gradient simultanée du discriminateur et du générateur\n",
|
||
" d_loss, g_loss, generated_images = train_step(real_images)\n",
|
||
"\n",
|
||
" # Affichage régulier d'images générées.\n",
|
||
" if step % 200 == 0:\n",
|
||
" # Métriques\n",
|
||
" print(\"Perte du discriminateur à l'étape %d: %.2f\" % (step, d_loss))\n",
|
||
" print(\"Perte du générateur à l'étape %d: %.2f\" % (step, g_loss))\n",
|
||
"\n",
|
||
" plt.figure(figsize=(20, 4))\n",
|
||
" for i in range(10):\n",
|
||
" plt.subplot(1,10, i+1)\n",
|
||
" plt.imshow(generated_images[i, :, :, 0]*128+128, cmap='gray')\n",
|
||
" \n",
|
||
" plt.show()\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "kwIc9354oNIV"
|
||
},
|
||
"source": [
|
||
"# Travail à faire :\n",
|
||
"\n",
|
||
"Prenez le temps de lire, de comprendre et de compléter le code qui vous est fourni. Observez attentivement l'évolution des métriques ainsi que les images générées au cours de l'entraînement. L'objectif de ce TP est d'abord de vous fournir un exemple de code implémentant les GANs, mais surtout de vous faire sentir la difficulté d'entraîner ces modèles.\n",
|
||
"\n",
|
||
"Dans la suite du TP, nous vous fournissons ci-dessous un code de chargement de la base de données de visages *Labelled Faces in the Wild*. Votre objectif est donc d'adapter le code précédent pour générer non plus des chiffres mais des visages.\n",
|
||
"\n",
|
||
"Quelques précisions importantes, et indications : \n",
|
||
"\n",
|
||
"\n",
|
||
"* MNIST est une base de données d'images noir et blanc de dimension 28 $\\times$ 28, LFW est une base de données d'images couleur de dimension 32 $\\times$ 32 $\\times$ 3\n",
|
||
"* La diversité des visages est bien plus grande que celle des chiffres ; votre générateur doit donc être un peu plus complexe que celui utilisé ici (plus de couches, et/ou plus de filtres par exemple) \n",
|
||
"* Pour faire fonctionner ce second exemple, il pourrait être nécessaire de modifier quelques hyperparamètres (dimension de l'espace latent, taux d'apprentissage des générateur et discriminateur, etc.)\n",
|
||
"\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"id": "ohexDvCYrahC"
|
||
},
|
||
"source": [
|
||
"Le code suivant télécharge et prépare les données de la base LFW."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "Ot-zkfDBQUkl"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import pandas as pd\n",
|
||
"import tarfile, tqdm, cv2, os\n",
|
||
"from sklearn.model_selection import train_test_split\n",
|
||
"import numpy as np\n",
|
||
"\n",
|
||
"# Télécharger les données de la base de données \"Labelled Faces in the Wild\"\n",
|
||
"!wget http://www.cs.columbia.edu/CAVE/databases/pubfig/download/lfw_attributes.txt\n",
|
||
"!wget http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz\n",
|
||
"!wget http://vis-www.cs.umass.edu/lfw/lfw.tgz\n",
|
||
" \n",
|
||
"ATTRS_NAME = \"lfw_attributes.txt\"\n",
|
||
"IMAGES_NAME = \"lfw-deepfunneled.tgz\"\n",
|
||
"RAW_IMAGES_NAME = \"lfw.tgz\"\n",
|
||
"\n",
|
||
"def decode_image_from_raw_bytes(raw_bytes):\n",
|
||
" img = cv2.imdecode(np.asarray(bytearray(raw_bytes), dtype=np.uint8), 1)\n",
|
||
" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
|
||
" return img\n",
|
||
"\n",
|
||
"def load_lfw_dataset(\n",
|
||
" use_raw=False,\n",
|
||
" dx=80, dy=80,\n",
|
||
" dimx=45, dimy=45):\n",
|
||
"\n",
|
||
" # Read attrs\n",
|
||
" df_attrs = pd.read_csv(ATTRS_NAME, sep='\\t', skiprows=1)\n",
|
||
" df_attrs = pd.DataFrame(df_attrs.iloc[:, :-1].values, columns=df_attrs.columns[1:])\n",
|
||
" imgs_with_attrs = set(map(tuple, df_attrs[[\"person\", \"imagenum\"]].values))\n",
|
||
"\n",
|
||
" # Read photos\n",
|
||
" all_photos = []\n",
|
||
" photo_ids = []\n",
|
||
"\n",
|
||
" # tqdm is used to show progress bar while reading the data in a notebook here, you can change\n",
|
||
" # tqdm_notebook to use it outside a notebook\n",
|
||
" with tarfile.open(RAW_IMAGES_NAME if use_raw else IMAGES_NAME) as f:\n",
|
||
" for m in tqdm.tqdm_notebook(f.getmembers()):\n",
|
||
" # Only process image files from the compressed data\n",
|
||
" if m.isfile() and m.name.endswith(\".jpg\"):\n",
|
||
" # Prepare image\n",
|
||
" img = decode_image_from_raw_bytes(f.extractfile(m).read())\n",
|
||
"\n",
|
||
" # Crop only faces and resize it\n",
|
||
" img = img[dy:-dy, dx:-dx]\n",
|
||
" img = cv2.resize(img, (dimx, dimy))\n",
|
||
"\n",
|
||
" # Parse person and append it to the collected data\n",
|
||
" fname = os.path.split(m.name)[-1]\n",
|
||
" fname_splitted = fname[:-4].replace('_', ' ').split()\n",
|
||
" person_id = ' '.join(fname_splitted[:-1])\n",
|
||
" photo_number = int(fname_splitted[-1])\n",
|
||
" if (person_id, photo_number) in imgs_with_attrs:\n",
|
||
" all_photos.append(img)\n",
|
||
" photo_ids.append({'person': person_id, 'imagenum': photo_number})\n",
|
||
"\n",
|
||
" photo_ids = pd.DataFrame(photo_ids)\n",
|
||
" all_photos = np.stack(all_photos).astype('uint8')\n",
|
||
"\n",
|
||
" # Preserve photo_ids order!\n",
|
||
" all_attrs = photo_ids.merge(df_attrs, on=('person', 'imagenum')).drop([\"person\", \"imagenum\"], axis=1)\n",
|
||
"\n",
|
||
" return all_photos, all_attrs\n",
|
||
"\n",
|
||
"# Prépare le dataset et le charge dans la variable X\n",
|
||
"X, attr = load_lfw_dataset(use_raw=True, dimx=32, dimy=32)\n",
|
||
"# Normalise les images\n",
|
||
"X = (X.astype(\"float32\")-127.5)/127.5\n"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"accelerator": "GPU",
|
||
"colab": {
|
||
"collapsed_sections": [],
|
||
"machine_shape": "hm",
|
||
"provenance": []
|
||
},
|
||
"kernelspec": {
|
||
"display_name": "Python 3.10.8 64-bit",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.8"
|
||
},
|
||
"vscode": {
|
||
"interpreter": {
|
||
"hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90"
|
||
}
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 0
|
||
}
|