diff --git a/src/SGAN.ipynb b/src/SGAN.ipynb new file mode 100644 index 0000000..eecb18e --- /dev/null +++ b/src/SGAN.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"widgets":{"application/vnd.jupyter.widget-state+json":{"8449be8723ae49a998eef74ac99830ae":{"model_module":"@jupyter-widgets/output","model_name":"OutputModel","model_module_version":"1.0.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/output","_model_module_version":"1.0.0","_model_name":"OutputModel","_view_count":null,"_view_module":"@jupyter-widgets/output","_view_module_version":"1.0.0","_view_name":"OutputView","layout":"IPY_MODEL_70abdf1919c642108f09f8b05e0cbbd2","msg_id":"","outputs":[{"output_type":"display_data","data":{"text/plain":"Labs... \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[35m100%\u001b[0m \u001b[36m0:00:00\u001b[0m\n","text/html":"
Labs... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00\n
\n"},"metadata":{}}]}},"70abdf1919c642108f09f8b05e0cbbd2":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"140350eb93d342db996b0ef188a99a01":{"model_module":"@jupyter-widgets/output","model_name":"OutputModel","model_module_version":"1.0.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/output","_model_module_version":"1.0.0","_model_name":"OutputModel","_view_count":null,"_view_module":"@jupyter-widgets/output","_view_module_version":"1.0.0","_view_name":"OutputView","layout":"IPY_MODEL_bb5381d4f3f54262899685bd44ad4db6","msg_id":"","outputs":[{"output_type":"display_data","data":{"text/plain":"Unlabs... \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[35m100%\u001b[0m \u001b[36m0:00:00\u001b[0m\n","text/html":"
Unlabs... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00\n
\n"},"metadata":{}}]}},"bb5381d4f3f54262899685bd44ad4db6":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}},"32f04016312a486e8920415c14df928f":{"model_module":"@jupyter-widgets/output","model_name":"OutputModel","model_module_version":"1.0.0","state":{"_dom_classes":[],"_model_module":"@jupyter-widgets/output","_model_module_version":"1.0.0","_model_name":"OutputModel","_view_count":null,"_view_module":"@jupyter-widgets/output","_view_module_version":"1.0.0","_view_name":"OutputView","layout":"IPY_MODEL_00b216d0f53045c5b1533d1d4266f756","msg_id":"","outputs":[{"output_type":"display_data","data":{"text/plain":"Tests... \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[35m100%\u001b[0m \u001b[36m0:00:00\u001b[0m\n","text/html":"
Tests... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00\n
\n"},"metadata":{}}]}},"00b216d0f53045c5b1533d1d4266f756":{"model_module":"@jupyter-widgets/base","model_name":"LayoutModel","model_module_version":"1.2.0","state":{"_model_module":"@jupyter-widgets/base","_model_module_version":"1.2.0","_model_name":"LayoutModel","_view_count":null,"_view_module":"@jupyter-widgets/base","_view_module_version":"1.2.0","_view_name":"LayoutView","align_content":null,"align_items":null,"align_self":null,"border":null,"bottom":null,"display":null,"flex":null,"flex_flow":null,"grid_area":null,"grid_auto_columns":null,"grid_auto_flow":null,"grid_auto_rows":null,"grid_column":null,"grid_gap":null,"grid_row":null,"grid_template_areas":null,"grid_template_columns":null,"grid_template_rows":null,"height":null,"justify_content":null,"justify_items":null,"left":null,"margin":null,"max_height":null,"max_width":null,"min_height":null,"min_width":null,"object_fit":null,"object_position":null,"order":null,"overflow":null,"overflow_x":null,"overflow_y":null,"padding":null,"right":null,"top":null,"visibility":null,"width":null}}}},"accelerator":"GPU","gpuClass":"standard"},"cells":[{"cell_type":"code","source":["!git clone https://github.com/axelcarlier/projsemisup\n","!pip install rich"],"metadata":{"id":"oip8sNddZXqQ","executionInfo":{"status":"ok","timestamp":1674483788419,"user_tz":-60,"elapsed":115397,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}},"colab":{"base_uri":"https://localhost:8080/"},"outputId":"7b061793-d20a-4ba7-d5f0-84c2d72edce4"},"execution_count":1,"outputs":[{"output_type":"stream","name":"stdout","text":["Cloning into 'projsemisup'...\n","remote: Enumerating objects: 48161, done.\u001b[K\n","remote: Total 48161 (delta 0), reused 0 (delta 0), pack-reused 48161\u001b[K\n","Receiving objects: 100% (48161/48161), 2.96 GiB | 28.92 MiB/s, done.\n","Resolving deltas: 100% (44/44), done.\n","Updating files: 100% (22857/22857), done.\n","Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting rich\n"," Downloading rich-13.2.0-py3-none-any.whl (238 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m238.9/238.9 KB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: pygments<3.0.0,>=2.6.0 in /usr/local/lib/python3.8/dist-packages (from rich) (2.6.1)\n","Collecting markdown-it-py<3.0.0,>=2.1.0\n"," Downloading markdown_it_py-2.1.0-py3-none-any.whl (84 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.5/84.5 KB\u001b[0m \u001b[31m12.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: typing-extensions<5.0,>=4.0.0 in /usr/local/lib/python3.8/dist-packages (from rich) (4.4.0)\n","Collecting mdurl~=0.1\n"," Downloading mdurl-0.1.2-py3-none-any.whl (10.0 kB)\n","Installing collected packages: mdurl, markdown-it-py, rich\n","Successfully installed markdown-it-py-2.1.0 mdurl-0.1.2 rich-13.2.0\n"]}]},{"cell_type":"code","execution_count":2,"metadata":{"id":"qwlA5PJlI7NM","executionInfo":{"status":"ok","timestamp":1674483795463,"user_tz":-60,"elapsed":7072,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"outputs":[],"source":["import os\n","\n","from PIL.Image import open\n","from PIL.Image import ANTIALIAS\n","\n","import numpy as np\n","from numpy import expand_dims\n","from numpy import zeros\n","from numpy import ones\n","from numpy import asarray\n","from numpy.random import randn\n","from numpy.random import randint\n","\n","from keras import backend\n","from keras.models import Model\n","from keras.models import load_model\n","from keras.layers import Input\n","from keras.layers import Dense\n","from keras.layers import Reshape\n","from keras.layers import Flatten\n","from keras.layers import Conv2D\n","from keras.layers import Conv2DTranspose\n","from keras.layers import LeakyReLU\n","from keras.layers import Dropout\n","from keras.layers import Lambda\n","from keras.layers import Activation\n","from keras.layers import Concatenate\n","from keras.layers import BatchNormalization\n","from keras.metrics import SparseTopKCategoricalAccuracy\n","from keras.metrics import SparseCategoricalAccuracy\n","from keras.metrics import BinaryAccuracy\n","from keras.optimizers import Adam\n","from keras.applications import MobileNet\n","\n","import tensorflow as tf\n","import tensorflow_datasets.public_api as tfds\n","\n","import imgaug.augmenters as iaa\n","\n","from matplotlib import pyplot\n","\n","from rich.progress import track"]},{"cell_type":"code","source":["IMAGE_SIZE = 64\n","LATENT_DIM = 512\n","BATCH_SIZE = 128\n","LEARNING_RATE = 3e-5\n","\n","PATH = '/content/projsemisup/'\n","CLASSES = os.listdir(PATH + 'Lab/')\n","NB_CLASSES = len(CLASSES)\n","LAB_COUNT = len(os.listdir(PATH + 'Lab/' + CLASSES[0] + '/'))\n","TEST_COUNT = len(os.listdir(PATH + 'Test/' + CLASSES[0] + '/'))"],"metadata":{"id":"86O4uJQYY8Vd","executionInfo":{"status":"ok","timestamp":1674483944644,"user_tz":-60,"elapsed":8,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"execution_count":16,"outputs":[]},{"cell_type":"code","source":["def load_semisup_data():\n","\n"," classes = os.listdir(PATH + 'Lab/')\n","\n"," # Initialise les structures de données\n"," x_lab = np.zeros((LAB_COUNT * NB_CLASSES, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)\n"," y_lab = np.zeros((LAB_COUNT * NB_CLASSES, 1))\n"," i = 0\n"," for c in track(classes, description='Labs...'):\n","\n"," class_label = classes.index(c)\n"," list_images = os.listdir(PATH + 'Lab/' + c + '/')\n","\n"," for img_name in list_images:\n"," # Lecture de l'image\n"," img = open(PATH + 'Lab/' + c + '/' + img_name)\n","\n"," # Mise à l'échelle de l'image\n"," img = img.resize((IMAGE_SIZE, IMAGE_SIZE), ANTIALIAS)\n"," img = img.convert('RGB')\n","\n"," # Remplissage de la variable x\n"," x_lab[i] = np.asarray(img, dtype=np.uint8)\n"," y_lab[i] = class_label\n"," i = i + 1\n","\n"," list_images = os.listdir(PATH + 'Unlab/')\n"," nb_unlab = len(list_images)\n"," x_unlab = np.zeros((nb_unlab, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)\n"," i = 0\n"," for img_name in track(list_images, description='Unlabs...'):\n"," # Lecture de l'image\n"," img = open(PATH + 'Unlab/' + img_name)\n","\n"," # Mise à l'échelle de l'image\n"," img = img.resize((IMAGE_SIZE, IMAGE_SIZE), ANTIALIAS)\n"," img = img.convert('RGB')\n","\n"," # Remplissage de la variable x\n"," x_unlab[i] = np.asarray(img, dtype=np.uint8)\n"," i = i + 1\n","\n"," file_PATH_test = os.listdir(PATH + 'Test/')\n","\n"," # Initialise les structures de données\n"," x_test = np.zeros((TEST_COUNT * NB_CLASSES, IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.uint8)\n"," y_test = np.zeros((TEST_COUNT * NB_CLASSES, 1))\n"," i = 0\n"," for c in track(file_PATH_test, description='Tests...'):\n","\n"," class_label = classes.index(c)\n"," list_images = os.listdir(PATH + 'Test/' + c + '/')\n","\n"," for img_name in list_images:\n"," # Lecture de l'image\n"," img = open(PATH + 'Test/' + c + '/' + img_name, )\n","\n"," # Mise à l'échelle de l'image\n"," img = img.resize((IMAGE_SIZE, IMAGE_SIZE), ANTIALIAS)\n"," img = img.convert('RGB')\n","\n"," # Remplissage de la variable x\n"," x_test[i] = np.asarray(img, dtype=np.uint8)\n"," y_test[i] = class_label\n"," i = i + 1\n","\n"," return (\n"," x_lab, y_lab,\n"," x_unlab,\n"," x_test, y_test\n"," )"],"metadata":{"id":"UwadDSV1Y_Sc","executionInfo":{"status":"ok","timestamp":1674483795466,"user_tz":-60,"elapsed":26,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"execution_count":4,"outputs":[]},{"cell_type":"code","source":["def define_discriminator():\n","\n"," in_image = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))\n"," fe = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(in_image)\n"," fe = LeakyReLU(alpha=0.2)(fe)\n"," fe = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(fe)\n"," fe = LeakyReLU(alpha=0.2)(fe)\n"," fe = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(fe)\n"," fe = LeakyReLU(alpha=0.2)(fe)\n"," fe = Conv2D(128, (3, 3), strides=(2, 2), padding='same')(fe)\n"," fe = LeakyReLU(alpha=0.2)(fe)\n"," fe = Flatten()(fe)\n"," fe = Dropout(0.4)(fe)\n"," fe = Dense(128)(fe)\n","\n"," # mbn = MobileNet(\n"," # input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),\n"," # classes=NB_CLASSES,\n"," # include_top=False,\n"," # weights=None\n"," # )\n"," # fe = Flatten()(mbn.output)\n","\n"," # Unsupervised output\n"," d_out_layer = Dense(1)(fe)\n"," d_out_layer = Activation('sigmoid')(d_out_layer)\n"," d_model = Model(in_image, d_out_layer)\n"," # d_model = Model(mbn.input, d_out_layer)\n"," opt = Adam(learning_rate=LEARNING_RATE)\n"," d_model.compile(loss='binary_crossentropy',\n"," optimizer=opt, metrics=[BinaryAccuracy()])\n","\n"," # Supervised output\n"," c_out_layer = Dense(NB_CLASSES)(fe)\n"," c_out_layer = Activation('softmax')(c_out_layer)\n"," c_model = Model(in_image, c_out_layer)\n"," # c_model = Model(mbn.input, c_out_layer)\n"," opt = Adam(learning_rate=LEARNING_RATE)\n"," c_model.compile(loss='sparse_categorical_crossentropy',\n"," optimizer=opt, metrics=[SparseCategoricalAccuracy(), SparseTopKCategoricalAccuracy(k=3)])\n","\n"," return d_model, c_model"],"metadata":{"id":"mIfaYQKnZE2T","executionInfo":{"status":"ok","timestamp":1674483795467,"user_tz":-60,"elapsed":26,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"execution_count":5,"outputs":[]},{"cell_type":"code","source":["def define_generator():\n","\n"," in_lat = Input(shape=(LATENT_DIM,))\n"," gen = Dense(4*4*256)(in_lat)\n"," gen = BatchNormalization(momentum=0.8)(gen)\n"," gen = LeakyReLU(alpha=0.2)(gen)\n"," gen = Reshape((4, 4, 256))(gen)\n"," gen = Conv2DTranspose(128, 4, strides=2, padding=\"same\")(gen)\n"," gen = BatchNormalization(momentum=0.8)(gen)\n"," gen = LeakyReLU(alpha=0.2)(gen)\n"," gen = Conv2DTranspose(128, 4, strides=2, padding=\"same\")(gen)\n"," gen = BatchNormalization(momentum=0.8)(gen)\n"," gen = LeakyReLU(alpha=0.2)(gen)\n"," gen = Conv2DTranspose(128, 4, strides=2, padding=\"same\")(gen)\n"," gen = BatchNormalization(momentum=0.8)(gen)\n"," gen = LeakyReLU(alpha=0.2)(gen)\n"," gen = Conv2DTranspose(128, 4, strides=2, padding=\"same\")(gen)\n"," gen = BatchNormalization(momentum=0.8)(gen)\n"," gen = LeakyReLU(alpha=0.2)(gen)\n"," out_layer = Conv2D(3, (3, 3), padding=\"same\", activation=\"tanh\")(gen)\n","\n"," model = Model(in_lat, out_layer)\n","\n"," return model"],"metadata":{"id":"5Bcd3eRCZHkG","executionInfo":{"status":"ok","timestamp":1674483795468,"user_tz":-60,"elapsed":26,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"execution_count":6,"outputs":[]},{"cell_type":"code","source":["def define_gan(g_model, d_model):\n","\n"," d_model.trainable = False\n"," gan_output = d_model(g_model.output)\n","\n"," model = Model(g_model.input, gan_output)\n","\n"," opt = Adam(learning_rate=2*LEARNING_RATE)\n"," model.compile(loss='binary_crossentropy',\n"," optimizer=opt, metrics=[BinaryAccuracy()])\n","\n"," return model"],"metadata":{"id":"0cAFqWRCZJLW","executionInfo":{"status":"ok","timestamp":1674483795469,"user_tz":-60,"elapsed":26,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["def select_supervised_samples(dataset, n_samples=100):\n"," X, y = dataset\n"," X_list, y_list = list(), list()\n"," n_per_class = int(n_samples / NB_CLASSES)\n"," for i in range(NB_CLASSES):\n"," # get all images for this class\n"," X_with_class = X[y == i]\n"," # choose random instances\n"," ix = randint(0, len(X_with_class), n_per_class)\n"," # add to list\n"," [X_list.append(X_with_class[j]) for j in ix]\n"," [y_list.append(i) for j in ix]\n"," return asarray(X_list), asarray(y_list)"],"metadata":{"id":"pSDIItbJZLaR","executionInfo":{"status":"ok","timestamp":1674483795470,"user_tz":-60,"elapsed":26,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"execution_count":8,"outputs":[]},{"cell_type":"code","source":["def generate_real_samples(dataset, n_samples, aug=None):\n"," # split into images and labels\n"," images, labels = dataset\n"," # choose random instances\n"," ix = randint(0, images.shape[0], n_samples)\n"," # select images and labels\n"," X = images[ix]\n"," if labels is not None:\n"," labels = labels[ix]\n"," # generate class labels\n"," y = ones((n_samples, 1))\n"," # apply augmentation\n"," if aug is not None:\n"," X = aug(images=X)\n"," # normalize batch\n"," X = X.astype(\"double\") / 127.5 - 1.0\n"," return X, labels, y"],"metadata":{"id":"x_ZiI1KdZNVT","executionInfo":{"status":"ok","timestamp":1674483795794,"user_tz":-60,"elapsed":350,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"execution_count":9,"outputs":[]},{"cell_type":"code","source":["def generate_latent_points(n_samples):\n"," # generate points in the latent space\n"," z_input = randn(n_samples * LATENT_DIM)\n"," # reshape into a batch of inputs for the network\n"," z_input = z_input.reshape(n_samples, LATENT_DIM)\n"," return z_input"],"metadata":{"id":"lkUFZ6xOZPJS","executionInfo":{"status":"ok","timestamp":1674483795797,"user_tz":-60,"elapsed":30,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"execution_count":10,"outputs":[]},{"cell_type":"code","source":["def generate_fake_samples(generator, n_samples):\n"," # generate points in latent space\n"," z_input = generate_latent_points(n_samples)\n"," # predict outputs\n"," images = generator.predict(z_input, verbose=0)\n"," # create class labels\n"," y = zeros((n_samples, 1))\n"," return images, y"],"metadata":{"id":"XvOWaVWsZQla","executionInfo":{"status":"ok","timestamp":1674483795799,"user_tz":-60,"elapsed":28,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"execution_count":11,"outputs":[]},{"cell_type":"code","source":["def summarize_performance(step, d_model, g_model, c_model, dataset, n_samples=100):\n"," # prepare fake examples\n"," X, _ = generate_fake_samples(g_model, n_samples)\n"," # scale from [-1,1] to [0,1]\n"," X = (X + 1) / 2.0\n"," # create figure\n"," pyplot.figure(figsize=(10, 4))\n"," # plot images\n"," for i in range(10):\n"," # define subplot\n"," pyplot.subplot(2, 5, 1 + i)\n"," # turn off axis\n"," pyplot.axis('off')\n"," # plot raw pixel data\n"," pyplot.imshow(X[i, :, :, :])\n"," pyplot.show()\n"," \n"," # get test images\n"," _, _, _, x_test, y_test = dataset\n"," # normalize images\n"," x_test = x_test.astype(\"double\") / 127.5 - 1.0\n"," # evaluate the classifier model\n"," _, acc, acc3 = c_model.evaluate(x_test, y_test, verbose=0)\n"," print(f\"Test accuracy:\\n top 1: {acc*100:.3f}%\\n top 3: {acc3*100:.3f}%\")\n"," # save the discriminator model\n"," filename1 = f\"d_model_{step+1:04d}.h5\"\n"," d_model.save_weights(filename1)\n"," # save the generator model\n"," filename2 = f\"g_model_{step+1:04d}.h5\"\n"," g_model.save_weights(filename2)\n"," # save the classifier model\n"," filename3 = f\"c_model_{step+1:04d}.h5\"\n"," c_model.save_weights(filename3)\n"," print(f\"Saved: {filename1}, {filename2}, {filename3}\\n\")\n","\n"," return acc, acc3"],"metadata":{"id":"6xybda-3ZTFc","executionInfo":{"status":"ok","timestamp":1674483795801,"user_tz":-60,"elapsed":29,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"execution_count":12,"outputs":[]},{"cell_type":"code","source":["def train(g_model, d_model, c_model, gan_model, dataset, aug=None, nb_epochs=100):\n"," # select supervised dataset\n"," x_lab, y_lab, x_unlab, _, _ = dataset\n"," # calculate the number of batches per training epoch\n"," bat_per_epo = int(x_unlab.shape[0] / BATCH_SIZE)\n"," # store accuracies\n"," accuracies = {'train': {'top1': [], 'top3': []}, 'val': {'top1': [], 'top3': []}}\n","\n"," print(f\"nb_epochs={nb_epochs}, batch_size={BATCH_SIZE}, b/e={bat_per_epo}\\n\")\n"," for epoch in range(nb_epochs):\n","\n"," print(f\"epoch {epoch+1:3}/{nb_epochs}:\")\n"," for step in range(bat_per_epo):\n"," # update supervised discriminator (c)\n"," X, Y, _ = generate_real_samples([x_lab, y_lab], BATCH_SIZE, aug)\n"," c_loss, c_acc, c_acc3 = c_model.train_on_batch(X, Y)\n"," # update unsupervised discriminator (d)\n"," X_real, _, Y_real = generate_real_samples([x_unlab, None], BATCH_SIZE//2)\n"," X_fake, Y_fake = generate_fake_samples(g_model, BATCH_SIZE//2)\n"," X = np.concatenate([X_real, X_fake], axis=0)\n"," Y = np.concatenate([Y_real, Y_fake], axis=0)\n"," d_loss, d_acc = d_model.train_on_batch(X, Y)\n"," # update generator (g)\n"," X_gan, y_gan = generate_latent_points(BATCH_SIZE), ones((BATCH_SIZE, 1))\n"," g_loss, g_acc = gan_model.train_on_batch(X_gan, y_gan)\n"," # show losses and accuracies\n"," print(f\"\\rstep {step+1:5}/{bat_per_epo}:\" +\n"," f\" c[{c_loss:7.3f}, {c_acc*100:3.0f}%, {c_acc3*100:3.0f}%]\" +\n"," f\" d[{d_loss:7.3f}, {d_acc*100:3.0f}%]\" +\n"," f\" g[{g_loss:7.3f}, {g_acc*100:3.0f}%]\",\n"," end='')\n"," # evaluate the model\n"," acc, acc3 = summarize_performance(epoch, d_model, g_model, c_model, dataset)\n"," accuracies['train']['top1'].append(c_acc)\n"," accuracies['train']['top3'].append(c_acc3)\n"," accuracies['val']['top1'].append(acc)\n"," accuracies['val']['top3'].append(acc3)\n"," print(accuracies)"],"metadata":{"id":"2MoC8sysY0Lc","executionInfo":{"status":"ok","timestamp":1674483795803,"user_tz":-60,"elapsed":29,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"execution_count":13,"outputs":[]},{"cell_type":"code","source":["# load image data\n","dataset = load_semisup_data()"],"metadata":{"id":"_GU7lrzCaVzF","colab":{"base_uri":"https://localhost:8080/","height":65,"referenced_widgets":["8449be8723ae49a998eef74ac99830ae","70abdf1919c642108f09f8b05e0cbbd2","140350eb93d342db996b0ef188a99a01","bb5381d4f3f54262899685bd44ad4db6","32f04016312a486e8920415c14df928f","00b216d0f53045c5b1533d1d4266f756"]},"executionInfo":{"status":"ok","timestamp":1674483943912,"user_tz":-60,"elapsed":148137,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}},"outputId":"5934c561-0c6f-43c7-a9a7-dad5f29f8490"},"execution_count":14,"outputs":[{"output_type":"display_data","data":{"text/plain":["Output()"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"8449be8723ae49a998eef74ac99830ae"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":[],"text/html":["
\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\n"],"text/html":["
\n","
\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["Output()"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"140350eb93d342db996b0ef188a99a01"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":[],"text/html":["
\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\n"],"text/html":["
\n","
\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["Output()"],"application/vnd.jupyter.widget-view+json":{"version_major":2,"version_minor":0,"model_id":"32f04016312a486e8920415c14df928f"}},"metadata":{}},{"output_type":"display_data","data":{"text/plain":[],"text/html":["
\n"]},"metadata":{}},{"output_type":"display_data","data":{"text/plain":["\n"],"text/html":["
\n","
\n"]},"metadata":{}}]},{"cell_type":"code","source":["# augmentation\n","aug = iaa.RandAugment(n=2, m=9)"],"metadata":{"id":"vlC87HYvWUY8","executionInfo":{"status":"ok","timestamp":1674483943913,"user_tz":-60,"elapsed":15,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"execution_count":15,"outputs":[]},{"cell_type":"code","source":["# create the discriminator models\n","d_model, c_model = define_discriminator()\n","# create the generator\n","g_model = define_generator()\n","# create the gan\n","gan_model = define_gan(g_model, d_model)\n","# train model\n","train(g_model, d_model, c_model, gan_model, dataset, aug)"],"metadata":{"id":"zuWf9tXqZVE-"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["backup_name = '0040'\n","\n","# create the discriminator models\n","d_model, c_model = define_discriminator()\n","# create the generator\n","g_model = define_generator()\n","\n","d_model.load_weights(f\"d_model_{backup_name}.h5\")\n","c_model.load_weights(f\"c_model_{backup_name}.h5\")\n","g_model.load_weights(f\"g_model_{backup_name}.h5\")\n","\n","# create the gan\n","gan_model = define_gan(g_model, d_model)\n","\n","# create the gan\n","gan_model = define_gan(g_model, d_model)\n","# train model\n","train(g_model, d_model, c_model, gan_model, dataset, aug)"],"metadata":{"id":"7btFJZgciMoR","colab":{"base_uri":"https://localhost:8080/","height":1000,"output_embedded_package_id":"12w46u3QlSkjLBeQWCKr_T3su-v3efvre"},"outputId":"17cc1f09-88dc-4080-ec60-fb07b874e32d","executionInfo":{"status":"ok","timestamp":1674493900342,"user_tz":-60,"elapsed":2138671,"user":{"displayName":"Damien Guillotin","userId":"16184557771337807994"}}},"execution_count":18,"outputs":[{"output_type":"display_data","data":{"text/plain":"Output hidden; open in https://colab.research.google.com to view."},"metadata":{}}]}]} \ No newline at end of file