{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "BwpJ5IffzRG6"
},
"source": [
"In this lab, we will train and test an RNN for text generation. More precisely, the model we build will be able, given a sequence of characters, to predict the most likely next character.\n",
"\n",
"![model](https://drive.google.com/uc?id=1syE1phix6Pu-b8y9ktol0thCdC2lzmlV\n",
")\n",
"\n",
"Starting from an initial string of characters, it is then possible to run several successive inferences of the model to generate the rest of the sentence.\n",
"\n",
"![inference](https://drive.google.com/uc?id=1T6J3UgFV4Q2JhJm3984HhJIkH7ukWkb7\n",
")\n",
"\n",
"The sentences below were obtained after training the model on a dataset of Donald Trump's tweets (starting respectively from the prompts 'China', 'Obama', and 'Mo'). Even though the sentences are not entirely correct, the model manages to generate (mostly) real words and to chain them together in a fairly believable way!\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HcygKkEVZBaa"
},
"source": [
"<pre>\n",
"China on dollars are sources to other things!The Fake News State approvement is smart, restlected & unfair \n",
"\n",
"Obama BEAT!Not too late. This is the only requirement. Also, the Fake News is running a big democrats want \n",
"\n",
"More system. See you really weak!Thank you. You and others just doesn’t exist.\n",
"</pre>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_bGsCP9DZFQ5"
},
"source": [
"The rest of this notebook is adapted from the tutorial https://www.tensorflow.org/text/tutorials/text_generation. There is no code to complete; the goal of this lab is to see how to prepare the dataset, implement and train the model, and run inference.\n",
"\n",
"It is up to you to take the code below and try to improve the model's performance so as to generate the most believable sentences possible!"
]
},
{
"cell_type": "markdown",
"source": [
"# Downloading the data"
],
"metadata": {
"id": "YDjEFzJ-ecuJ"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:08.261025Z",
"iopub.status.busy": "2022-05-03T11:14:08.260828Z",
"iopub.status.idle": "2022-05-03T11:14:10.284556Z",
"shell.execute_reply": "2022-05-03T11:14:10.283846Z"
},
"id": "yG_n40gFzf9s"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"import numpy as np\n",
"import os\n",
"import time"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EHDoRoc5PKWz"
},
"source": [
"Downloading the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:10.288588Z",
"iopub.status.busy": "2022-05-03T11:14:10.288339Z",
"iopub.status.idle": "2022-05-03T11:14:10.512538Z",
"shell.execute_reply": "2022-05-03T11:14:10.511842Z"
},
"id": "pD_55cOxLkAb"
},
"outputs": [],
"source": [
"path_to_file = tf.keras.utils.get_file('realdonaltrump.csv', 'https://drive.google.com/uc?export=download&id=1s1isv9TQjGiEr2gG__8bOdBFvQlmepRt')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UHjdCjDuSvX_"
},
"source": [
"We start by extracting the tweets from the CSV file (note that the file contains other metadata, such as the number of retweets, which could be used for other tasks)."
]
},
{
"cell_type": "code",
"source": [
"import csv\n",
"tweets = []\n",
"text = ''\n",
"with open(path_to_file, newline='') as csvfile:\n",
"    reader = csv.DictReader(csvfile)\n",
"    for row in reader:\n",
"        tweets.append(row['content'])\n",
"        text += row['content']\n",
"\n",
"# Display the first 10 tweets\n",
"print(tweets[:10])"
],
"metadata": {
"id": "BPfeRNpwC-Ev"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:10.515842Z",
"iopub.status.busy": "2022-05-03T11:14:10.515602Z",
"iopub.status.idle": "2022-05-03T11:14:10.521336Z",
"shell.execute_reply": "2022-05-03T11:14:10.520758Z"
},
"id": "aavnuByVymwK"
},
"outputs": [],
"source": [
"# Total number of characters in the dataset\n",
"print(f'Total length of the text: {len(text)} characters')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:10.530172Z",
"iopub.status.busy": "2022-05-03T11:14:10.529967Z",
"iopub.status.idle": "2022-05-03T11:14:10.544918Z",
"shell.execute_reply": "2022-05-03T11:14:10.544310Z"
},
"id": "IlCgQBRVymwR"
},
"outputs": [],
"source": [
"# Extract the unique characters of the text\n",
"vocab = sorted(set(text))\n",
"print(f'{len(vocab)} unique characters')"
]
},
{
"cell_type": "code",
"source": [
"# Display the vocabulary\n",
"print(vocab)"
],
"metadata": {
"id": "hSAriIDsEa-J"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "rNnrKn_lL-IJ"
},
"source": [
"# Preparing the data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LFjSVAlWzf-N"
},
"source": [
"The characters must be converted into a representation the model can handle.\n",
"\n",
"The `tf.keras.layers.StringLookup` layer converts strings into numbers, using the index of each character in the vocabulary built above.\n",
"\n",
"First, however, the text must be split into characters, as shown in the example below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:10.547650Z",
"iopub.status.busy": "2022-05-03T11:14:10.547458Z",
"iopub.status.idle": "2022-05-03T11:14:12.216225Z",
"shell.execute_reply": "2022-05-03T11:14:12.215486Z"
},
"id": "a86OoYtO01go"
},
"outputs": [],
"source": [
"example_texts = ['abcdefg', 'xyz']\n",
"\n",
"chars = tf.strings.unicode_split(example_texts, input_encoding='UTF-8')\n",
"chars"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1s4f1q3iqY8f"
},
"source": [
"We can then apply the `tf.keras.layers.StringLookup` layer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.219925Z",
"iopub.status.busy": "2022-05-03T11:14:12.219236Z",
"iopub.status.idle": "2022-05-03T11:14:12.230858Z",
"shell.execute_reply": "2022-05-03T11:14:12.230260Z"
},
"id": "6GMlCe3qzaL9"
},
"outputs": [],
"source": [
"ids_from_chars = tf.keras.layers.StringLookup(\n",
"    vocabulary=list(vocab), mask_token=None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.234414Z",
"iopub.status.busy": "2022-05-03T11:14:12.233898Z",
"iopub.status.idle": "2022-05-03T11:14:12.241111Z",
"shell.execute_reply": "2022-05-03T11:14:12.240543Z"
},
"id": "WLv5Q_2TC2pc"
},
"outputs": [],
"source": [
"ids = ids_from_chars(chars)\n",
"ids"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tZfqhkYCymwX"
},
"source": [
"To recover a text from its numerical representation (this will be useful in the final generation step), we need to be able to invert the process, which can be done with `tf.keras.layers.StringLookup(..., invert=True)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.244473Z",
"iopub.status.busy": "2022-05-03T11:14:12.244010Z",
"iopub.status.idle": "2022-05-03T11:14:12.251817Z",
"shell.execute_reply": "2022-05-03T11:14:12.251224Z"
},
"id": "Wd2m3mqkDjRj"
},
"outputs": [],
"source": [
"chars_from_ids = tf.keras.layers.StringLookup(\n",
"    vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.255016Z",
"iopub.status.busy": "2022-05-03T11:14:12.254548Z",
"iopub.status.idle": "2022-05-03T11:14:12.259845Z",
"shell.execute_reply": "2022-05-03T11:14:12.259298Z"
},
"id": "c2GCh0ySD44s"
},
"outputs": [],
"source": [
"chars = chars_from_ids(ids)\n",
"chars"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-FeW5gqutT3o"
},
"source": [
"Finally, we can rebuild a string of characters:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.263027Z",
"iopub.status.busy": "2022-05-03T11:14:12.262587Z",
"iopub.status.idle": "2022-05-03T11:14:12.273779Z",
"shell.execute_reply": "2022-05-03T11:14:12.273214Z"
},
"id": "zxYI-PeltqKP"
},
"outputs": [],
"source": [
"tf.strings.reduce_join(chars, axis=-1).numpy()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.276992Z",
"iopub.status.busy": "2022-05-03T11:14:12.276505Z",
"iopub.status.idle": "2022-05-03T11:14:12.280017Z",
"shell.execute_reply": "2022-05-03T11:14:12.279421Z"
},
"id": "w5apvBDn9Ind"
},
"outputs": [],
"source": [
"def text_from_ids(ids):\n",
"    return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hgsVvVxnymwf"
},
"source": [
"We now need to create the training examples and their associated labels. To do so, we will split the text into sequences, each made of `seq_length` characters.\n",
"\n",
"For each input sequence in the training set, the corresponding label to predict is a sequence of the same length in which every character has been shifted by one position.\n",
"\n",
"A simple way to build our dataset is therefore to split the text into sequences of length `seq_length+1`, and to use the first `seq_length` characters as the input and the last `seq_length` characters as the label.\n",
"\n",
"\n",
"N.B. This approach is clearly not optimal! Some sequences will span two successive tweets, which may have no connection to each other whatsoever!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.283397Z",
"iopub.status.busy": "2022-05-03T11:14:12.282940Z",
"iopub.status.idle": "2022-05-03T11:14:12.693441Z",
"shell.execute_reply": "2022-05-03T11:14:12.692846Z"
},
"id": "UopbsKi88tm5"
},
"outputs": [],
"source": [
"all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))\n",
"all_ids"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.696712Z",
"iopub.status.busy": "2022-05-03T11:14:12.696216Z",
"iopub.status.idle": "2022-05-03T11:14:12.700189Z",
"shell.execute_reply": "2022-05-03T11:14:12.699616Z"
},
"id": "qmxrYDCTy-eL"
},
"outputs": [],
"source": [
"ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.703062Z",
"iopub.status.busy": "2022-05-03T11:14:12.702570Z",
"iopub.status.idle": "2022-05-03T11:14:12.721321Z",
"shell.execute_reply": "2022-05-03T11:14:12.720696Z"
},
"id": "cjH5v45-yqqH"
},
"outputs": [],
"source": [
"for ids in ids_dataset.take(10):\n",
"    print(chars_from_ids(ids).numpy().decode('utf-8'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.724378Z",
"iopub.status.busy": "2022-05-03T11:14:12.723876Z",
"iopub.status.idle": "2022-05-03T11:14:12.726907Z",
"shell.execute_reply": "2022-05-03T11:14:12.726324Z"
},
"id": "C-G2oaTxy6km"
},
"outputs": [],
"source": [
"seq_length = 50\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-ZSYAcQV8OGP"
},
"source": [
"The `batch` method groups the characters of the text into sequences of the desired length."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.729823Z",
"iopub.status.busy": "2022-05-03T11:14:12.729461Z",
"iopub.status.idle": "2022-05-03T11:14:12.740383Z",
"shell.execute_reply": "2022-05-03T11:14:12.739819Z"
},
"id": "BpdjRO2CzOfZ"
},
"outputs": [],
"source": [
"sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)\n",
"\n",
"for seq in sequences.take(1):\n",
"    print(chars_from_ids(seq))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5PHW902-4oZt"
},
"source": [
"For example, here are the first sequences extracted from the dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.743335Z",
"iopub.status.busy": "2022-05-03T11:14:12.742989Z",
"iopub.status.idle": "2022-05-03T11:14:12.754659Z",
"shell.execute_reply": "2022-05-03T11:14:12.754086Z"
},
"id": "QO32cMWu4a06"
},
"outputs": [],
"source": [
"for seq in sequences.take(5):\n",
"    print(text_from_ids(seq).numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UbLcIPBj_mWZ"
},
"source": [
"We will now generate the (input, label) pairs from the extracted sequences:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.757751Z",
"iopub.status.busy": "2022-05-03T11:14:12.757363Z",
"iopub.status.idle": "2022-05-03T11:14:12.760852Z",
"shell.execute_reply": "2022-05-03T11:14:12.760178Z"
},
"id": "9NGu-FkO_kYU"
},
"outputs": [],
"source": [
"def split_input_target(sequence):\n",
"    input_text = sequence[:-1]\n",
"    target_text = sequence[1:]\n",
"    return input_text, target_text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.763663Z",
"iopub.status.busy": "2022-05-03T11:14:12.763249Z",
"iopub.status.idle": "2022-05-03T11:14:12.767978Z",
"shell.execute_reply": "2022-05-03T11:14:12.767379Z"
},
"id": "WxbDTJTw5u_P"
},
"outputs": [],
"source": [
"split_input_target(list(\"Tensorflow\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.770840Z",
"iopub.status.busy": "2022-05-03T11:14:12.770439Z",
"iopub.status.idle": "2022-05-03T11:14:12.816930Z",
"shell.execute_reply": "2022-05-03T11:14:12.816374Z"
},
"id": "B9iKPXkw5xwa"
},
"outputs": [],
"source": [
"dataset = sequences.map(split_input_target)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.820172Z",
"iopub.status.busy": "2022-05-03T11:14:12.819702Z",
"iopub.status.idle": "2022-05-03T11:14:12.842329Z",
"shell.execute_reply": "2022-05-03T11:14:12.841738Z"
},
"id": "GNbw-iR0ymwj"
},
"outputs": [],
"source": [
"for input_example, target_example in dataset.take(1):\n",
"    print(\"Input :\", text_from_ids(input_example).numpy())\n",
"    print(\"Target:\", text_from_ids(target_example).numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MJdfPmdqzf-R"
},
"source": [
"Before feeding the data to the model, it is important to shuffle it and to group it into batches.\n",
"\n",
"The `prefetch` call lets the input pipeline load the next batch of data while the model is still processing the previous one."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.845656Z",
"iopub.status.busy": "2022-05-03T11:14:12.845278Z",
"iopub.status.idle": "2022-05-03T11:14:12.852723Z",
"shell.execute_reply": "2022-05-03T11:14:12.852173Z"
},
"id": "p2pGotuNzf-S"
},
"outputs": [],
"source": [
"# Batch size\n",
"BATCH_SIZE = 64\n",
"\n",
"# Buffer size to shuffle the dataset\n",
"# (TF data is designed to work with possibly infinite sequences,\n",
"# so it doesn't attempt to shuffle the entire sequence in memory. Instead,\n",
"# it maintains a buffer in which it shuffles elements).\n",
"BUFFER_SIZE = 10000\n",
"\n",
"dataset = (\n",
"    dataset\n",
"    .shuffle(BUFFER_SIZE)\n",
"    .batch(BATCH_SIZE, drop_remainder=True)\n",
"    .prefetch(tf.data.experimental.AUTOTUNE))\n",
"\n",
"dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r6oUuElIMgVx"
},
"source": [
"# Building the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m8gPwEjRzf-Z"
},
"source": [
"The model will be made of only 3 layers:\n",
"\n",
"* `tf.keras.layers.Embedding`: the input layer, which learns an `embedding_dim`-dimensional descriptor for each input character;\n",
"* `tf.keras.layers.GRU`: a recurrent cell with `rnn_units` units (which could perfectly well be replaced by an LSTM);\n",
"* `tf.keras.layers.Dense`: the output layer, with `vocab_size` units. Note that no activation function (`softmax`) is specified, because it is applied directly inside the loss function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.856129Z",
"iopub.status.busy": "2022-05-03T11:14:12.855554Z",
"iopub.status.idle": "2022-05-03T11:14:12.860086Z",
"shell.execute_reply": "2022-05-03T11:14:12.859494Z"
},
"id": "zHT8cLh7EAsg"
},
"outputs": [],
"source": [
"# Vocabulary size\n",
"vocab_size = len(ids_from_chars.get_vocabulary())\n",
"\n",
"# Dimension of the character embeddings\n",
"embedding_dim = 256\n",
"\n",
"# Number of GRU units\n",
"rnn_units = 512"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.862758Z",
"iopub.status.busy": "2022-05-03T11:14:12.862561Z",
"iopub.status.idle": "2022-05-03T11:14:12.868444Z",
"shell.execute_reply": "2022-05-03T11:14:12.867942Z"
},
"id": "wj8HQ2w8z4iO"
},
"outputs": [],
"source": [
"class MyModel(tf.keras.Model):\n",
"    def __init__(self, vocab_size, embedding_dim, rnn_units):\n",
"        super().__init__()\n",
"        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
"        self.gru = tf.keras.layers.GRU(rnn_units,\n",
"                                       return_sequences=True,\n",
"                                       return_state=True)\n",
"        self.dense = tf.keras.layers.Dense(vocab_size)\n",
"\n",
"    def call(self, inputs, states=None, return_state=False, training=False):\n",
"        x = inputs\n",
"        x = self.embedding(x, training=training)\n",
"        if states is None:\n",
"            states = self.gru.get_initial_state(x)\n",
"        x, states = self.gru(x, initial_state=states, training=training)\n",
"        x = self.dense(x, training=training)\n",
"\n",
"        if return_state:\n",
"            return x, states\n",
"        else:\n",
"            return x"
]
},
{
"cell_type": "markdown",
"source": [
"N.B. This somewhat unusual way of defining the model (which, incidentally, closely resembles the PyTorch style) is useful for inference. We could have used a classic model built with `keras.Sequential`, but it would not have given us easy access to the internal states of the GRU. Being able to manipulate this state will be important when we chain several successive predictions, each of which must restart from the state obtained at the previous prediction. This is not possible with the Sequential models we usually use."
],
"metadata": {
"id": "Pjcb4E4jkmi5"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:12.871482Z",
"iopub.status.busy": "2022-05-03T11:14:12.870964Z",
"iopub.status.idle": "2022-05-03T11:14:12.884188Z",
"shell.execute_reply": "2022-05-03T11:14:12.883662Z"
},
"id": "IX58Xj9z47Aw"
},
"outputs": [],
"source": [
"model = MyModel(\n",
"    vocab_size=vocab_size,\n",
"    embedding_dim=embedding_dim,\n",
"    rnn_units=rnn_units)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RkA5upJIJ7W7"
},
"source": [
"For each character in the sequence, the model looks up the associated embedding, runs one time step of the GRU, and finally applies the dense layer to obtain the network's prediction:\n",
"\n",
"![A drawing of the data passing through the model](https://drive.google.com/uc?id=1GYD8U9aF-MTC1XpJ3VKpY1b0clJuO4wb)"
]
},
{
"cell_type": "markdown",
"source": [
"We can test our model on the first training example to check the dimensions:"
],
"metadata": {
"id": "skr44qgVmr3m"
}
},
{
"cell_type": "code",
"source": [
"for input_example_batch, target_example_batch in dataset.take(1):\n",
"    example_batch_predictions = model(input_example_batch)\n",
"    print(example_batch_predictions.shape, \"# (batch_size, sequence_length, vocab_size)\")"
],
"metadata": {
"id": "BqUVnDFTmpUD"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:19.114557Z",
"iopub.status.busy": "2022-05-03T11:14:19.114071Z",
"iopub.status.idle": "2022-05-03T11:14:19.127462Z",
"shell.execute_reply": "2022-05-03T11:14:19.126830Z"
},
"id": "vPGmAAXmVLGC"
},
"outputs": [],
"source": [
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LJL0Q0YPY6Ee"
},
"source": [
"# Training the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YCbHQHiaa4Ic"
},
"source": [
"The problem we are trying to solve is a classification problem with `vocab_size` classes."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UAjbjY03eiQ4"
},
"source": [
"We use the `tf.keras.losses.sparse_categorical_crossentropy` loss function because our labels are given as indices (and not as *one-hot vectors*). Setting the `from_logits` flag to `True` indicates that the network outputs raw logits, so the loss applies the softmax function itself before computing the cross-entropy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:19.186765Z",
"iopub.status.busy": "2022-05-03T11:14:19.186240Z",
"iopub.status.idle": "2022-05-03T11:14:19.197413Z",
"shell.execute_reply": "2022-05-03T11:14:19.196878Z"
},
"id": "DDl1_Een6rL0"
},
"outputs": [],
"source": [
"model.compile(optimizer='adam', loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:14:19.214135Z",
"iopub.status.busy": "2022-05-03T11:14:19.213597Z",
"iopub.status.idle": "2022-05-03T11:16:14.167240Z",
"shell.execute_reply": "2022-05-03T11:16:14.166580Z"
},
"id": "UK-hmKjYVoll"
},
"outputs": [],
"source": [
"history = model.fit(dataset, epochs=5)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kKkD5M6eoSiN"
},
"source": [
"# Text generation"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oIdQ8c8NvMzV"
},
"source": [
"To generate a text, we simply start from an initial sequence, run a prediction, and save the model's internal state so it can be restored for the next inference, whose input will be the initial sequence extended with the previously predicted character.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DjGz1tDkzf-u"
},
"source": [
"The following class performs a single prediction step:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:16:14.171458Z",
"iopub.status.busy": "2022-05-03T11:16:14.170910Z",
"iopub.status.idle": "2022-05-03T11:16:14.180188Z",
"shell.execute_reply": "2022-05-03T11:16:14.179593Z"
},
"id": "iSBU1tHmlUSs"
},
"outputs": [],
"source": [
"class OneStep(tf.keras.Model):\n",
"    def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1):\n",
"        super().__init__()\n",
"        self.temperature = temperature\n",
"        self.model = model\n",
"        self.chars_from_ids = chars_from_ids\n",
"        self.ids_from_chars = ids_from_chars\n",
"\n",
"    @tf.function\n",
"    def generate_one_step(self, inputs, states=None):\n",
"        # Convert strings to token IDs.\n",
"        input_chars = tf.strings.unicode_split(inputs, 'UTF-8')\n",
"        input_ids = self.ids_from_chars(input_chars).to_tensor()\n",
"\n",
"        # Run the model.\n",
"        # predicted_logits.shape is [batch, char, next_char_logits]\n",
"        predicted_logits, states = self.model(inputs=input_ids, states=states,\n",
"                                              return_state=True)\n",
"        # Only use the last prediction.\n",
"        predicted_logits = predicted_logits[:, -1, :]\n",
"        predicted_logits = predicted_logits/self.temperature\n",
"\n",
"        # Sample the output logits to generate token IDs.\n",
"        predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)\n",
"        predicted_ids = tf.squeeze(predicted_ids, axis=-1)\n",
"\n",
"        # Convert from token ids to characters\n",
"        predicted_chars = self.chars_from_ids(predicted_ids)\n",
"\n",
"        # Return the characters and model state.\n",
"        return predicted_chars, states"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:16:14.183384Z",
"iopub.status.busy": "2022-05-03T11:16:14.182965Z",
"iopub.status.idle": "2022-05-03T11:16:14.193938Z",
"shell.execute_reply": "2022-05-03T11:16:14.193266Z"
},
"id": "fqMOuDutnOxK"
},
"outputs": [],
"source": [
"one_step_model = OneStep(model, chars_from_ids, ids_from_chars)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p9yDoa0G3IgQ"
},
"source": [
"It only remains to call this function in a loop to generate a complete text:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2022-05-03T11:16:14.197481Z",
"iopub.status.busy": "2022-05-03T11:16:14.197261Z",
"iopub.status.idle": "2022-05-03T11:16:16.999406Z",
"shell.execute_reply": "2022-05-03T11:16:16.998550Z"
},
"id": "ST7PSyk9t1mT"
},
"outputs": [],
"source": [
"start = time.time()\n",
"states = None\n",
"next_char = tf.constant(['Obama '])\n",
"result = [next_char]\n",
"\n",
"for n in range(100):\n",
"    next_char, states = one_step_model.generate_one_step(next_char, states=states)\n",
"    result.append(next_char)\n",
"\n",
"result = tf.strings.join(result)\n",
"end = time.time()\n",
"print(result[0].numpy().decode('utf-8'), '\\n\\n' + '_'*80)\n",
"print('\\nRun time:', end - start)"
]
},
{
"cell_type": "markdown",
"source": [
"Now it is your turn to improve the results. For example, you can:\n",
"- Play with the temperature parameter in the `OneStep` class to increase or decrease the randomness of the predictions.\n",
"- Modify the network by adding extra layers, or by changing the number of units in the GRU layer.\n",
"- Change the data preparation to avoid the problem of sequences spanning several tweets.\n",
"- Training the model for longer should also help!"
],
"metadata": {
"id": "emMSmEg2qk_c"
}
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"toc_visible": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
}
},
"nbformat": 4,
"nbformat_minor": 0
}