TP-reinforcement-learning/TP2/TP2.ipynb

674 lines
442 KiB
Plaintext
Raw Permalink Normal View History

2023-06-23 18:10:32 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TP2 : Bandits\n",
"\n",
"Laurent Fainsin"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.37454012 0.95071431 0.73199394 0.59865848 0.15601864 0.15599452\n",
" 0.05808361 0.86617615 0.60111501 0.70807258]\n"
]
}
],
"source": [
"K = 10\n",
"T = 1000\n",
"EPOCHS = 5000\n",
"EPSILON = 0.15\n",
"np.random.seed(42)\n",
"PROB = np.random.rand(K)\n",
"BEST = np.argmax(PROB)\n",
"print(PROB)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# $\\epsilon$-greedy"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_2937633/564047854.py:12: RuntimeWarning: invalid value encountered in divide\n",
" a = np.argmax(rewards / counts)\n"
]
}
],
"source": [
"def epsilon_greedy(epochs=EPOCHS, t_max=T, epsilon=EPSILON):\n",
" A = -np.ones((epochs, t_max), dtype=np.int32)\n",
" R = np.zeros((epochs, t_max), dtype=np.float32)\n",
"\n",
" for epoch in range(epochs):\n",
" rewards = np.zeros(K)\n",
" counts = np.zeros(K)\n",
" for t in range(t_max):\n",
" if np.random.rand() < epsilon:\n",
" a = np.random.randint(K)\n",
" else:\n",
" a = np.argmax(rewards / counts)\n",
"\n",
" reward = np.random.rand() < PROB[a]\n",
" rewards[a] += reward\n",
" counts[a] += 1\n",
" \n",
" A[epoch, t] = a\n",
" R[epoch, t] = reward\n",
" \n",
" return A, R\n",
"\n",
"A_eps, R_eps = epsilon_greedy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Decaying $\\epsilon$-greedy"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_2937633/3008607987.py:13: RuntimeWarning: invalid value encountered in divide\n",
" a = np.argmax(rewards / counts)\n"
]
}
],
"source": [
"def decay_epsilon_greedy(epochs=EPOCHS, t_max=T):\n",
" A = -np.ones((epochs, t_max), dtype=np.int32)\n",
" R = np.zeros((epochs, t_max), dtype=np.float32)\n",
"\n",
" for epoch in range(epochs):\n",
" rewards = np.zeros(K)\n",
" counts = np.zeros(K)\n",
" for t in range(t_max):\n",
" epsilon = 1.0 / (0.1 * t + 1)\n",
" if np.random.rand() < epsilon:\n",
" a = np.random.randint(K)\n",
" else:\n",
" a = np.argmax(rewards / counts)\n",
"\n",
" reward = np.random.rand() < PROB[a]\n",
" rewards[a] += reward\n",
" counts[a] += 1\n",
" \n",
" A[epoch, t] = a\n",
" R[epoch, t] = reward\n",
" \n",
" return A, R\n",
"\n",
"A_deps, R_deps = decay_epsilon_greedy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Upper confidence interval"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_2937633/3010206541.py:12: RuntimeWarning: invalid value encountered in divide\n",
" a = np.argmax(rewards / counts + epsilons)\n"
]
}
],
"source": [
"def UCB(epochs=EPOCHS, t_max=T, c=2):\n",
" A = -np.ones((epochs, t_max), dtype=np.int32)\n",
" R = np.zeros((epochs, t_max), dtype=np.float32)\n",
"\n",
" for epoch in range(epochs):\n",
" rewards = np.zeros(K)\n",
" counts = np.zeros(K)\n",
" for t in range(t_max):\n",
" \n",
" epsilons = np.sqrt(c * np.log(t + 1) / (counts + 1e-5))\n",
" \n",
" a = np.argmax(rewards / counts + epsilons)\n",
"\n",
" reward = np.random.rand() < PROB[a]\n",
" rewards[a] += reward\n",
" counts[a] += 1\n",
" \n",
" A[epoch, t] = a\n",
" R[epoch, t] = reward\n",
" \n",
" return A, R\n",
"\n",
"A_ucb, R_ucb = UCB()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Optimistic greedy"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def optimistic_greedy(epochs=EPOCHS, t_max=T, Q0=100):\n",
" A = -np.ones((epochs, t_max), dtype=np.int32)\n",
" R = np.zeros((epochs, t_max), dtype=np.float32)\n",
"\n",
" for epoch in range(epochs):\n",
" value = Q0 * np.ones(K)\n",
" counts = np.zeros(K)\n",
" for t in range(t_max):\n",
" a = np.argmax(value)\n",
"\n",
" reward = np.random.rand() < PROB[a]\n",
" counts[a] += 1\n",
" value[a] += (reward - value[a]) / counts[a]\n",
" \n",
" A[epoch, t] = a\n",
" R[epoch, t] = reward\n",
" \n",
" return A, R\n",
"\n",
"A_opt, R_opt = optimistic_greedy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Gradient method"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def gradient(epochs=EPOCHS, t_max=T, alpha=0.1):\n",
" A = -np.ones((epochs, t_max), dtype=np.int32)\n",
" R = np.zeros((epochs, t_max), dtype=np.float32)\n",
"\n",
" for epoch in range(epochs):\n",
" H = np.zeros(K)\n",
" rewards = np.zeros(K)\n",
" counts = np.zeros(K)\n",
"\n",
" for t in range(t_max):\n",
" # softmax(H) -> probs -> action\n",
" probs = np.exp(H) / np.sum(np.exp(H))\n",
" a = np.random.choice(K, p=probs)\n",
"\n",
" # update rewards\n",
" reward = np.random.rand() < PROB[a]\n",
" rewards[a] += reward\n",
" counts[a] += 1\n",
"\n",
" # one hot vector for action\n",
" one_hot = np.zeros(K)\n",
" one_hot[a] = 1\n",
"\n",
" # update H\n",
" H += alpha * (reward - rewards / (counts + 1e-5)) * (one_hot - probs)\n",
" \n",
" # update A and R\n",
" A[epoch, t] = a\n",
" R[epoch, t] = reward\n",
" \n",
" return A, R\n",
"\n",
"A_grad, R_grad = gradient()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Mean rewards"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean reward Epsilon-greedy: 871.75\n",
"Mean reward Optimistic-greedy: 920.57\n",
"Mean reward Decaying Epsilon-greedy: 913.69\n",
"Mean reward UCB: 809.02\n",
"Mean reward Gradient: 807.94\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAD8JUlEQVR4nOydd3QUVRuHn9me3ntCQiD00HvvSC8qzQKiIioWsBcEFcGKKIqCBQsogqDSEZAmvfdOElJI722zZb4/JtnNZjchQZp+85zDIXvn3jt3J5ud37z3LYIoiiIyMjIyMjIyMv8RFLd7ATIyMjIyMjIyNxJZ3MjIyMjIyMj8p5DFjYyMjIyMjMx/ClncyMjIyMjIyPynkMWNjIyMjIyMzH8KWdzIyMjIyMjI/KeQxY2MjIyMjIzMfwpZ3MjIyMjIyMj8p5DFjYyMjIyMjMx/ClncyMjIyFSTiIgIxo8ff7uXISMjcw1kcSMj8w+YP38+giDQrl27270UGZkasXv3bmbMmEF2dvbtXoqMzA1HFjcyMv+AJUuWEBERwf79+7l48eLtXo6MTLXZvXs3b775pixuZP6TyOJGRuY6iYmJYffu3cyZMwc/Pz+WLFlyy9dgNpspLi6+5ee9Hv4Nay0oKLjdS6g2/6a1ysjcamRxIyNznSxZsgQvLy8GDhzIPffcYyNuDAYD3t7ePPTQQ3bjcnNz0el0PP/885Y2vV7P9OnTqVu3LlqtlrCwMF588UX0er3NWEEQmDx5MkuWLKFx48ZotVo2bNgAwIcffkjHjh3x8fHBycmJVq1a8euvv9qdv6ioiKeffhpfX1/c3NwYMmQIiYmJCILAjBkzbPomJiYyYcIEAgIC0Gq1NG7cmG+//bZa16eqtV5rXlEU8fX1ZerUqZY2s9mMp6cnSqXSxtrw3nvvoVKpyM/PB+D48eOMHz+eyMhIdDodgYGBTJgwgYyMDJv1zZgxA0EQOH36NGPHjsXLy4vOnTtbzj9z5kxCQ0NxdnamR48enDp1qlrvOzY2FkEQ+PDDD/n4448JDw/HycmJbt26cfLkSbv+Z8+e5Z577sHb2xudTkfr1q1ZtWqVTZ/vvvsOQRDYvn07TzzxBP7+/oSGhla5jnnz5tG4cWOcnZ3x8vKidevW/PTTT5b3/sILLwBQu3ZtBEFAEARiY2Mt4xcvXkyrVq1wcnLC29ub0aNHEx8fb3OO7t2706RJEw4dOkTHjh1xcnKidu3afPnll9W6VjIyNwvV7V6AjMy/lSVLljBixAg0Gg1jxozhiy++4MCBA7Rp0wa1Ws3w4cNZuXIlCxYsQKPRWMb9/vvv6PV6Ro8eDUg37SFDhvD3338zceJEGjZsyIkTJ/j44485f/48v//+u815//rrL5YtW8bkyZPx9fUlIiICgE8++YQhQ4Zw3333UVJSwtKlS7n33ntZs2YNAwcOtIwfP348y5Yt44EHHqB9+/Zs377d5ngZKSkptG/f3iJS/Pz8WL9+PQ8//DC5ubk8++yz17xGjtZanXkFQaBTp07s2LHDMtfx48fJyclBoVCwa9cuy5p37txJixYtcHV1BWDTpk1cvnyZhx56iMDAQE6dOsXChQs5deoUe/fuRRAEmzXee++9REVFMWvWLERRBOCNN95g5syZDBgwgAEDBnD48GH69u1LSUnJNd9zGT/88AN5eXk8+eSTFBcX88knn9CzZ09OnDhBQEAAAKdOnaJTp06EhITw8ssv4+LiwrJlyxg2bBgrVqxg+PDhNnM+8cQT+Pn58cYbb1Rpufnqq694+umnueeee3jmmWcoLi7m+PHj7Nu3j7FjxzJixAjOnz/Pzz//zMcff4yvry8Afn5+ALzzzjtMmzaNkSNH8sgjj5CWlsa8efPo2rUrR44cwdPT03KurKwsBgwYwMiRIxkzZgzLli3j8ccfR6PRMGHChGpfLxmZG4ooIyNTYw4ePCgC4qZNm0RRFEWz2SyGhoaKzz
zzjKXPxo0bRUBcvXq1zdgBAwaIkZGRltc//vijqFAoxJ07d9r0+/LLL0VA3LVrl6UNEBUKhXjq1Cm7NRUWFtq8LikpEZs0aSL27NnT0nbo0CEREJ999lmbvuPHjxcBcfr06Za2hx9+WAwKChLT09Nt+o4ePVr08PCwO19FKltrdef94IMPRKVSKebm5oqiKIqffvqpGB4eLrZt21Z86aWXRFEURZPJJHp6eopTpkyp9DqIoij+/PPPIiDu2LHD0jZ9+nQREMeMGWPTNzU1VdRoNOLAgQNFs9lsaX/11VdFQBw3blyV7zsmJkYERCcnJzEhIcHSvm/fPhGwWWuvXr3E6Ohosbi42NJmNpvFjh07ilFRUZa2RYsWiYDYuXNn0Wg0Vnl+URTFoUOHio0bN66yzwcffCACYkxMjE17bGysqFQqxXfeecem/cSJE6JKpbJp79atmwiIH330kaVNr9eLzZs3F/39/cWSkpJrrlVG5mYgb0vJyFwHS5YsISAggB49egDSFsyoUaNYunQpJpMJgJ49e+Lr68svv/xiGZeVlcWmTZsYNWqUpW358uU0bNiQBg0akJ6ebvnXs2dPALZu3Wpz7m7dutGoUSO7NTk5OdmcJycnhy5dunD48GFLe9m20BNPPGEz9qmnnrJ5LYoiK1asYPDgwYiiaLOufv36kZOTYzNvZVRca03m7dKlCyaTid27dwOShaZLly506dKFnTt3AnDy5Emys7Pp0qWLw+tQXFxMeno67du3B3C45kmTJtm83rx5MyUlJTz11FM2Vp7qWKrKM2zYMEJCQiyv27ZtS7t27Vi3bh0AmZmZ/PXXX4wcOZK8vDzLdcjIyKBfv35cuHCBxMREmzkfffRRlErlNc/t6elJQkICBw4cqNGaAVauXInZbGbkyJE2v5/AwECioqLsPo8qlYrHHnvM8lqj0fDYY4+RmprKoUOHanx+GZkbgSxuZGRqiMlkYunSpfTo0YOYmBguXrzIxYsXadeuHSkpKWzZsgWQvvTvvvtu/vjjD4vvzMqVKzEYDDbi5sKFC5w6dQo/Pz+bf/Xq1QMgNTXV5vy1a9d2uK41a9bQvn17dDod3t7e+Pn58cUXX5CTk2PpExcXh0KhsJujbt26Nq/T0tLIzs5m4cKFdusq8yOquC5HVDxPTeZt2bIlzs7OFiFTJm66du3KwYMHKS4uthwr85UBSTQ888wzBAQE4OTkhJ+fn2Ud5a9FZWuMi4sDICoqyqbdz88PLy+va77nMiqOB6hXr57Fr+XixYuIosi0adPsrsX06dNtrkVla62Ml156CVdXV9q2bUtUVBRPPvkku3btqtbYCxcuIIoiUVFRdus6c+aM3ZqCg4NxcXGxe5+AjQ+PjMytRPa5kZGpIX/99RdXr15l6dKlLF261O74kiVL6Nu3LwCjR49mwYIFrF+/nmHDhrFs2TIaNGhAs2bNLP3NZjPR0dHMmTPH4fnCwsJsXpe3TJSxc+dOhgwZQteuXZk/fz5BQUGo1WoWLVpkcSKtCWazGYD777+fcePGOezTtGnTa85Tca01mVetVtOuXTt27NjBxYsXSU5OpkuXLgQEBGAwGNi3bx87d+6kQYMGFl8RgJEjR7J7925eeOEFmjdvjqurK2azmbvuusty/qrWeKsoW8vzzz9Pv379HPapKDqru9aGDRty7tw51qxZw4YNG1ixYgXz58/njTfe4M0337zmugRBYP369Q6tRGW+TTIydzKyuJGRqSFLlizB39+fzz//3O7YypUr+e233/jyyy9xcnKia9euBAUF8csvv9C5c2f++usvXnvtNZsxderU4dixY/Tq1cvO2bW6rFixAp1Ox8aNG9FqtZb2RYsW2fQLDw/HbDYTExNjY1momKPHz88PNzc3TCYTvXv3vq41OaKm83bp0oX33nuPzZs34+vrS4MGDRAEgcaNG7Nz50527tzJoEGDLP2zsrLYsmULb775Jm+88Yal/cKFC9VeY3h4uG
VMZGSkpT0tLY2srKxqz+PonOfPn7c4gJfNrVarb+g1LsPFxYVRo0YxatQoSkpKGDFiBO+88w6vvPIKOp2u0s9anTp
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"mean_reward_eps = np.mean(np.sum(R_eps, axis=1))\n",
"mean_reward_opt = np.mean(np.sum(R_opt, axis=1))\n",
"mean_reward_deps = np.mean(np.sum(R_deps, axis=1))\n",
"mean_reward_ucb = np.mean(np.sum(R_ucb, axis=1))\n",
"mean_reward_grad = np.mean(np.sum(R_grad, axis=1))\n",
"\n",
"asymptotic_reward_eps = EPSILON * (np.sum(PROB) - np.max(PROB)) / (K - 1) + (1 - EPSILON) * np.max(PROB)\n",
"\n",
"print(f\"Mean reward Epsilon-greedy: {mean_reward_eps:.02f}\")\n",
"print(f\"Mean reward Optimistic-greedy: {mean_reward_opt:.02f}\")\n",
"print(f\"Mean reward Decaying Epsilon-greedy: {mean_reward_deps:.02f}\")\n",
"print(f\"Mean reward UCB: {mean_reward_ucb:.02f}\")\n",
"print(f\"Mean reward Gradient: {mean_reward_grad:.02f}\")\n",
"\n",
"plt.plot([0, T], [asymptotic_reward_eps]*2, 'b--', label=\"Asymptotic Epsilon-greedy\")\n",
"plt.plot(np.mean(R_eps, axis=0), label=\"Epsilon-greedy\")\n",
"plt.plot(np.mean(R_opt, axis=0), label=\"Optimistic-greedy\")\n",
"plt.plot(np.mean(R_deps, axis=0), label=\"Decaying Epsilon-greedy\")\n",
"plt.plot(np.mean(R_ucb, axis=0), label=\"UCB\")\n",
"plt.plot(np.mean(R_grad, axis=0), label=\"Gradient\")\n",
"\n",
"plt.title(\"Average reward per step.\")\n",
"plt.xlabel(\"Steps\")\n",
"# plt.ylim(0.4, 0.7)\n",
"plt.ylabel(\"Average reward\")\n",
"plt.legend()\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Cumulative regret"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjsAAAHHCAYAAABZbpmkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAADbbUlEQVR4nOzdd3QUVRvA4d+WZNMT0guBJBASem/SpaMooiBNelGagBQRpTcpooiIAoYiCPqpWOggvffeSQHSe0+2zPdHyOqSAAkk2ZT7nLMHdmZ25p2bzebdW2WSJEkIgiAIgiCUUnJjByAIgiAIglCYRLIjCIIgCEKpJpIdQRAEQRBKNZHsCIIgCIJQqolkRxAEQRCEUk0kO4IgCIIglGoi2REEQRAEoVQTyY4gCIIgCKWaSHYEQRAEQSjVRLIjCGXMwYMHkclkHDx40NihCIIgFAmR7AhCEZDJZHl65CUBmT9/Ptu2bSv0mEua48ePM3PmTOLj440dSpHbsWMHM2fONHYYglBsKY0dgCCUBRs3bjR4vmHDBvbu3Ztje9WqVZ97rvnz5/POO+/QrVu3ggyxxDt+/DizZs1i4MCB2NnZGTucIrVjxw6++eYbkfAIwlOIZEcQikC/fv0Mnp88eZK9e/fm2F7cpaSkYGlpWSTXSk9Px9TUFLm89FdAF2W5CkJZVPo/RQShhEhJSeGjjz7C09MTlUqFn58fS5YsQZIk/TEymYyUlBTWr1+vb/oaOHAgAMHBwYwcORI/Pz/Mzc1xcHCgR48eBAUFvVA8M2fORCaTcf36dfr06UO5cuVo3ry5fv+PP/5I/fr1MTc3x97enl69evHgwYMc5/nmm2/w8fHB3NycRo0aceTIEVq3bk3r1q31x2T3I9qyZQuffvopHh4eWFhYkJiYCMCpU6fo1KkTtra2WFhY0KpVK44dO2YQ66RJkwDw9vbWl83z7v2XX37R34OjoyP9+vXj0aNH+v1LlixBJpMRHByc47VTp07F1NSUuLg4/bbnxZmXcn2SWq1m1qxZ+Pr6YmZmhoODA82bN2fv3r0ADBw4kG+++QYwbC7NptPp+PLLL6levTpmZma4uLgwYsQIg7gBvLy8eP3119mzZw916tTBzMyMatWq8dtvv+UrHkEojkTNjiAUA5Ik8cYbb3DgwAGGDBlCnTp12L17N5MmTeLRo0csW7YMyGoOGzp0KI0aNWL48OEAVKpUCYAzZ85w/PhxevXqRfny5QkKCuLbb7+ldevWXL9+HQsLixeKrUePHvj6+jJ//nx94jVv3jw+++wzevbsydChQ4mKiuLrr7+mZcuWXLhwQd+M9O233zJ69GhatGjB+PHjCQoKolu3bpQrV47y5cvnuNacOXMwNTVl4sSJZGRkYGpqyj///EPnzp2pX78+M2bMQC6XExAQwKuvvsqRI0do1KgR3bt35/bt2/z0008sW7YMR0dHAJycnJ56X+vWrWPQoEE0bNiQBQsWEBERwVdffcWxY8f099CzZ08mT57Mzz//rE+msv3888906NCBcuXKAeQpzueVa25mzpzJggUL9D/3xMREzp49y/nz52nfvj0jRowgNDQ012ZRgBEjRujvdezYsQQGBrJixQouXLjAsWPHMDEx0R97584d3n33Xd5//30GDBhAQEAAPXr0YNeuXbRv3z5P8QhCsSQJglDkRo0aJf3312/btm0SIM2dO9fguHfeeUeSyWTS3bt39dssLS2lAQMG5Dhnampqjm0nTpyQAGnDhg36bQcOHJAA6cCBA8+MccaMGRIg9e7d22B7UFCQpFAopHnz5hlsv3LliqRUKvXbMzIyJAcHB6lhw4aSWq3WH7du3ToJkFq1apUjJh8fH4P70Ol0kq+vr9SxY0dJp9MZ3Ku3t7fUvn17/bbFixdLgBQYGPjM+5IkScrMzJScnZ2lGjVqSGlpafrtf//9twRI06dP129r2rSpVL9+fYPXnz592qBc8xPn08
r1aWrXri299tprzzzmyfdTtiNHjkiAtGnTJoPtu3btyrG9YsWKEiD9+uuv+m0JCQmSm5ubVLdu3XzFIwjFjWjGEoRiYMeOHSgUCsaOHWuw/aOPPkKSJHbu3Pncc5ibm+v/r1ariYmJoXLlytjZ2XH+/PkXju399983eP7bb7+h0+no2bMn0dHR+oerqyu+vr4cOHAAgLNnzxITE8OwYcNQKv+tRO7bt6++NuRJAwYMMLiPixcvcufOHfr06UNMTIz+WikpKbRt25bDhw+j0+nyfU9nz54lMjKSkSNHYmZmpt/+2muv4e/vz/bt2/Xb3n33Xc6dO8e9e/f027Zu3YpKpeLNN9984TifLNensbOz49q1a9y5cyff9/nLL79ga2tL+/btDX5W9evXx8rKSv+zyubu7s5bb72lf25jY0P//v25cOEC4eHhLx2PIBiLaMYShGIgODgYd3d3rK2tDbZnj87Krc/Ik9LS0liwYAEBAQE8evTIoGkkISHhhWPz9vY2eH7nzh0kScLX1zfX47ObRbJjrly5ssF+pVKJl5dXnq8FWUnQ0yQkJDw1eXqa7Nj8/Pxy7PP39+fo0aP65z169GDChAls3bqVTz75BEmS+OWXX+jcuTM2NjYvHOeT9/o0s2fP5s0336RKlSrUqFGDTp068d5771GrVq3nvvbOnTskJCTg7Oyc6/7IyEiD55UrVzbo7wNQpUoVAIKCgnB1dX2peATBWESyIwilxJgxYwgICGDcuHE0bdoUW1tbZDIZvXr1eqHaj2z/rWmBrA6vMpmMnTt3olAochxvZWVVoNcCWLx4MXXq1Mn1NS9zvbxwd3enRYsW/Pzzz3zyySecPHmSkJAQPv/885eK88l7fZqWLVty7949/vjjD/bs2cOaNWtYtmwZq1atYujQoc98rU6nw9nZmU2bNuW6/1l9mgojHkEwFpHsCEIxULFiRfbt20dSUpJB7c7Nmzf1+7M9+c072//+9z8GDBjA0qVL9dvS09MLfJK9SpUqIUkS3t7e+m/9ucmO+e7du7Rp00a/XaPREBQUlKeagOzO1zY2NrRr1+6Zxz6tXJ4V261bt3j11VcN9t26dcugvCGrKWvkyJHcunWLrVu3YmFhQdeuXV8ozhdhb2/PoEGDGDRoEMnJybRs2ZKZM2fqk4un3XulSpXYt28fzZo1y1NydffuXSRJMjjf7du3AQxq454XjyAUN6LPjiAUA126dEGr1bJixQqD7cuWLUMmk9G5c2f9NktLy1wTGIVCkWNUz9dff41Wqy3QWLt3745CoWDWrFk5ridJEjExMQA0aNAABwcHVq9ejUaj0R+zadOmHMOen6Z+/fpUqlSJJUuWkJycnGN/VFSU/v/Z89TkJblr0KABzs7OrFq1ioyMDP32nTt3cuPGDV577TWD499++20UCgU//fQTv/zyC6+//rrBvDj5iTO/ssszm5WVFZUrVzaI+2n33rNnT7RaLXPmzMlxXo1Gk+P40NBQfv/9d/3zxMRENmzYQJ06dXB1dc1zPAkJCdy8efOlmk8FoSCJmh1BKAa6du1KmzZtmDZtGkFBQdSuXZs9e/bwxx9/MG7cOH3NAWT9Yd23bx9ffPEF7u7ueHt707hxY15//XU2btyIra0t1apV48SJE+zbtw8HB4cCjbVSpUrMnTuXqVOn6oeSW1tbExgYyO+//87w4cOZOHEipqamzJw5kzFjxvDqq6/Ss2dPgoKCWLduHZUqVcpTTYxcLmfNmjV07tyZ6tWrM2jQIDw8PHj06BEHDhzAxsaGv/76S18uANOmTaNXr16YmJjQtWvXXCfrMzEx4fPPP2fQoEG0atWK3r1764eee3l5MX78eIPjnZ2dadOmDV988QVJSUm8++67LxxnflWrVo3WrVtTv3597O3tOXv2LP/73/8YPXq0/pjsex87diwdO3ZEoVDQq1cvWrVqxYgRI1iwYAEXL16kQ4cOmJiYcOfOHX755Re++uor3nnnHf
15qlSpwpAhQzhz5gwuLi788MMPREREEBAQkK94fv/9dwYNGkRAQIB+HihBMCqjjQMThDIst6HCSUlJ0vjx4yV3d3f
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"mean_cum_reward_eps = np.arange(1, T+1) - np.cumsum(np.mean(R_eps, axis=0))\n",
"mean_cum_reward_opt = np.arange(1, T+1) - np.cumsum(np.mean(R_opt, axis=0))\n",
"mean_cum_reward_deps = np.arange(1, T+1) - np.cumsum(np.mean(R_deps, axis=0))\n",
"mean_cum_reward_ucb = np.arange(1, T+1) - np.cumsum(np.mean(R_ucb, axis=0))\n",
"mean_cum_reward_grad = np.arange(1, T+1) - np.cumsum(np.mean(R_grad, axis=0))\n",
"\n",
"plt.plot(mean_cum_reward_eps, label=\"Epsilon-greedy\")\n",
"plt.plot(mean_cum_reward_opt, label=\"Optimistic-greedy\")\n",
"plt.plot(mean_cum_reward_deps, label=\"Decaying Epsilon-greedy\")\n",
"plt.plot(mean_cum_reward_ucb, label=\"UCB\")\n",
"plt.plot(mean_cum_reward_grad, label=\"Gradient\")\n",
"\n",
"plt.xlabel(\"Steps\")\n",
"plt.ylabel(\"Total regret\")\n",
"plt.title(\"Total regret over steps.\")\n",
"plt.legend()\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Verification convergence"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAHDCAYAAAATEUquAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAABikUlEQVR4nO3deXhM5///8ddIZCGILbJYEltC7YKiLS0V2lLaomgllI+1dkWVWEpQVO2tJVSppYuqWqrWUq1dtY2dRhGhCJFKSM7vD7/MtyPLZCIxKc/Hdc3VzH2fc5/3HJMmr5z73GMyDMMQAAAAACBNuexdAAAAAADkdAQnAAAAALCC4AQAAAAAVhCcAAAAAMAKghMAAAAAWEFwAgAAAAArCE4AAAAAYAXBCQAAAACsIDgBAAAAgBUEJwDIQRo2bKiGDRvauwwL27Ztk8lk0rZt2+xdSrqWLFmigIAA5c6dW+7u7jbvf/bsWZlMJk2ePDnri3uMhYSEyNfX195lAMADIzgBQAYsWrRIJpMpzcfPP/+c4bH++OMPjRo1SmfPns2+gjNh9uzZWrRokb3LyJSjR48qJCREZcqU0bx58/TJJ5+kue26des0atSoh1fcfeLi4jRq1KgcH0QBAJYc7V0AAPyXjBkzRn5+finay5Ytm+Ex/vjjD40ePVoNGzZM8Zf477///kFLzLTZs2erSJEiCgkJsWh/5pln9M8//8jJyck+hWXAtm3blJSUpI8++sjqv8W6des0a9Ysu4WnuLg4jR49WpJy3NVFAEDaCE4AYINmzZopMDAw28bPieEkV65ccnFxsXcZ6YqOjpakTE3Ry+lu3bqlvHnz2rsMAHjsMVUPALLY8uXLVbNmTeXLl0/58+dX5cqV9dFHH0m6N+WvdevWkqRnn33WPNUvedrW/fc4Jd9ftHLlSo0ePVo+Pj7Kly+fXnvtNcXExCg+Pl79+vWTh4eH3Nzc1KlTJ8XHx1vUEx4erueee04eHh5ydnZWxYoVNWfOHIttfH199fvvv2v79u3mmpLrSOsep1WrVqlmzZpydXVVkSJF9MYbb+j8+fMW24SEhMjNzU3nz59Xy5Yt5ebmpqJFi2rQoEFKTEzM0PmcPXu2nnjiCTk7O8vb21u9evXS9evXLWoPDQ2VJBUtWlQmkynNq0khISGaNWuWJFlMtbzfJ598ojJlysjZ2Vm1atXS3r17U2xz9OhRvfbaaypUqJBcXFwUGBioNWvWpPtazp49q6JFi0qSRo8ebT5+cr3J5+vUqVN64YUXlC9fPnXo0MH8Ou+/Giilfl9cfHy8QkNDVbZsWTk7O6tEiRJ65513Urw37te7d2+5ubkpLi4uRV+7du3k6elp/nf75ptv9OKLL8rb21vOzs4qU6aMxo4da/XfNa33U/I9ZvdPF83MeQaA7MAVJwCwQUxMjK5cuWLRZjKZVLhwYUnSpk2b1K5dOzVq1EgTJ06UJEVERGjXrl3q27evnnnmGfXp00fTp0/Xu+++qwoVKkiS+b9pCQsLk6urq4YOHaqTJ09qxowZyp07t3LlyqVr165p1KhR+vnnn7Vo0SL5+flp5MiR5n3nzJmjJ554Qi1atJCjo6O+/fZb9ezZU0lJSerVq5ckadq0aXr77bfl5uam4cOHS5KKFSuWZj2LFi1Sp06dVKtWLYWFhenSpUv66KOPtGvXLh08eNDiyk9iYqKCgoJUp04dTZ48WT/88IOmTJmiMmXKqEePHum+7lGjRmn06NFq3LixevTooWPHjmnOnDnau3evdu3apdy5c2vatGn69NNP9fXXX2vOnDlyc3NTlSpVUh2vW7duunDhgjZt2qQlS5akus2yZct08+ZNdevWTSaTSZMmTdIrr7yi06dPK3fu3JKk33//XfXr15ePj4+GDh2qvHnzauXKlWrZsqW+/PJLtWrVKtWxixYtqjlz5qhHjx5q1aqVXnnlFUmyqPfu3bsKCgrSU089pcmTJy
tPnjzpnqP7JSUlqUWLFtq5c6f+97//qUKFCjpy5Ig+/PBDHT9+XKtXr05z37Zt22rWrFn67rvvzAFfuje98Ntvv1VISIgcHBwk3XsPuLm5acCAAXJzc9OWLVs0cuRI3bhxQx988IFNNacls+cZALKFAQCwKjw83JCU6sPZ2dm8Xd++fY38+fMbd+/eTXOsVatWGZKMrVu3puhr0KCB0aBBA/PzrVu3GpKMSpUqGQkJCeb2du3aGSaTyWjWrJnF/nXr1jVKlSpl0RYXF5fiOEFBQUbp0qUt2p544gmLY99fQ3K9CQkJhoeHh1GpUiXjn3/+MW+3du1aQ5IxcuRIc1twcLAhyRgzZozFmNWrVzdq1qyZ4lj/Fh0dbTg5ORlNmjQxEhMTze0zZ840JBkLFy40t4WGhhqSjMuXL6c7pmEYRq9evYzUfvydOXPGkGQULlzYuHr1qrn9m2++MSQZ3377rbmtUaNGRuXKlY3bt2+b25KSkox69eoZ5cqVS/f4ly9fNiQZoaGhKfqSz9fQoUNT9JUqVcoIDg5O0X7/e2bJkiVGrly5jB9//NFiu7lz5xqSjF27dqVZW1JSkuHj42O8+uqrFu0rV640JBk7duwwt6X2vurWrZuRJ08ei/MSHBxs8Z68//2ULPn8h4eHm9se5DwDQFZjqh4A2GDWrFnatGmTxWP9+vXmfnd3d926dUubNm3K0uN27NjRfLVDkurUqSPDMNS5c2eL7erUqaNz587p7t275jZXV1fz18lXzBo0aKDTp08rJibG5lr27dun6Oho9ezZ0+LepxdffFEBAQH67rvvUuzTvXt3i+dPP/20Tp8+ne5xfvjhByUkJKhfv37Klev/flx17dpV+fPnT/U4WaFt27YqWLCgRa2SzPVevXpVW7ZsUZs2bXTz5k1duXJFV65c0d9//62goCCdOHEixZRFW1m7EpeeVatWqUKFCgoICDDXduXKFT333HOSpK1bt6a5r8lkUuvWrbVu3TrFxsaa21esWCEfHx899dRT5rZ/v6+Sz8PTTz+tuLg4HT16NNP1J3sY5xkAbMFUPQCwQe3atdNdHKJnz55auXKlmjVrJh8fHzVp0kRt2rRR06ZNH+i4JUuWtHheoEABSVKJEiVStCclJSkmJsY8fXDXrl0KDQ3V7t27U9y7EhMTYx4ro/78809Jkr+/f4q+gIAA7dy506LNxcXFfF9PsoIFC+ratWuZOo6Tk5NKly5t7s9q95/r5BCVXO/JkydlGIZGjBihESNGpDpGdHS0fHx8MnV8R0dHFS9ePFP7StKJEycUERGR4pz/u7b0tG3bVtOmTdOaNWvUvn17xcbGat26deapi8l+//13vffee9qyZYtu3LhhMUZmAvn9svs8A4CtCE4AkIU8PDx06NAhbdy4UevXr9f69esVHh6ujh07avHixZkeN/m+koy2G4YhSTp16pQaNWqkgIAATZ06VSVKlJCTk5PWrVunDz/8UElJSZmuKaPSqjGnsnZOk8/ZoEGDFBQUlOq2tixPfz9nZ2eLK2zJUlvEQrp3D9m/a05KSlLlypU1derUVLe/P2zf78knn5Svr69Wrlyp9u3b69tvv9U///yjtm3bmre5fv26GjRooPz582vMmDEqU6aMXFxcdODAAQ0ZMiTd91V6r+Pfsvs8A4CtCE4AkMWcnJzUvHlzNW/eXElJSerZs6c+/vhjjRgxQmXLlk3zF8fs8O233yo+Pl5r1qyxuJKS2nStjNZVqlQpSdKxY8fM07+SHTt2zNz/oP59nNKlS5vbExISdObMGTVu3DhT4z7o+U+uJXfu3JmqIbPHL1iwoMVqgsn+/PNPi/NTpkwZHT58WI0aNcr0sdq0aaOPPvpIN27c0IoVK+Tr66snn3zS3L9t2zb9/fff+uqrr/TMM8+Y28+cOZOh1yEpxWu5/wrig55nAMhq3OMEAFno77//tnieK1cu84ppyUtBJ38mT2q/BGe15CsRyVdLpHvTqM
LDw1Nsmzdv3gzVFBgYKA8PD82dO9dieev169crIiJCL7744oMXLqlx48ZycnLS9OnTLepfsGCBYmJiMn2cBz3/Hh4
"text/plain": [
"<Figure size 1000x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 5))\n",
"\n",
"plt.bar(np.arange(1, K+1)*2, PROB, 0.2)\n",
"plt.bar(np.arange(1, K+1)*2+0.2, [np.sum(R_eps[A_eps == k]) / np.sum(A_eps == k) for k in range(K)], 0.2, label=\"Epsilon-greedy\")\n",
"plt.bar(np.arange(1, K+1)*2+0.4, [np.sum(R_opt[A_opt == k]) / np.sum(A_opt == k) for k in range(K)], 0.2, label=\"Optimistic-greedy\")\n",
"plt.bar(np.arange(1, K+1)*2+0.6, [np.sum(R_deps[A_deps == k]) / np.sum(A_deps == k) for k in range(K)], 0.2, label=\"Decaying Epsilon-greedy\")\n",
"plt.bar(np.arange(1, K+1)*2+0.8, [np.sum(R_ucb[A_ucb == k]) / np.sum(A_ucb == k) for k in range(K)], 0.2, label=\"UCB\")\n",
"plt.bar(np.arange(1, K+1)*2+1.0, [np.sum(R_grad[A_grad == k]) / np.sum(A_grad == k) for k in range(K)], 0.2, label=\"Gradient\")\n",
"\n",
"plt.xticks(np.arange(1, K+1)*2, np.arange(1, K+1))\n",
"plt.title('Estimation of the true value.')\n",
"plt.ylabel('Average reward')\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Best arm selection"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlEAAAHHCAYAAACfqw0dAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAADprElEQVR4nOzdd3xT5f7A8U+SZnXvSRdt2XvIBpWloIK4UO+9oPeqP/e8V70u3Osq7oHz6nXvhQLiQvbeo9CWVbr3StLk+f2RNjR00JZ0wffdV185OefJOU9OkpNvnqlRSimEEEIIIUSLaDs6A0IIIYQQXZEEUUIIIYQQrSBBlBBCCCFEK0gQJYQQQgjRChJECSGEEEK0ggRRQgghhBCtIEGUEEIIIUQrSBAlhBBCCNEKEkQJIYQQQrRCpwyifvvtNzQaDZ9//nlHZ0V0UqmpqUyZMoWAgAA0Gg1ff/11i/cxd+5cEhISPJ63jpCRkYFGo+E///nPSXEc0TnNmzcPjUbT0dlolYSEBObOndvR2WjUu+++i0ajISMjo6Oz0mV1ROzQbkGURqNp1v9vv/3WXlk6JVRUVDBv3ryT7rzOmTOHrVu38uijj/L+++8zbNiwBtNlZmYyb948Nm3a1L4ZbCMLFy5k3rx5HZ2NdneqPu9TzanwOj/22GOt+tEnOiev9jrQ+++/73b/vffeY8mSJfXW9+7dm507d7ZXtk56FRUVPPjggwCcfvrpHZsZD6msrGTlypXcc8893HDDDU2mzczM5MEHHyQhIYFBgwa5bXvjjTdwOBxtmFPPW7hwIS+//PJJ/0VzrFP1eZ9qToXX+bHHHuPCCy9k5syZbuv/+te/Mnv2bIxGY8dkTLRKuwVRf/nLX9zur1q1iiVLltRbD5x0QVR5eTk+Pj4dnY2TRm5uLgCBgYEntB+9Xu+B3IiTTUVFBd7e3h2dDXGK0el06HS6js6GaKFO2SaqlsPh4NFHH6Vbt26YTCYmTpzI3r1766VbvXo1Z511FgEBAXh7ezNhwgSWL19+3P3X1p9+8skn/Pvf/yYyMhIfHx/OO+88Dh482Krj1LYZ2LFjB5dddhlBQUGMHTvWtf1///sfp512Gt7e3gQFBTF+/HgWL17sto8ff/yRcePG4ePjg5+fH9OnT2f79u1uaebOnYuvry+HDx9m5syZ+Pr6EhYWxh133IHdbgec7VfCwsIAePDBB11VprW/8rZs2cLcuXPp3r07JpOJyMhIrrzySvLz8xs8V8OGDcNkMpGUlMTrr7/eaPuI//3vfwwdOhSz2UxwcDCzZ89u8Hw2ZOPGjZx99tn4+/vj6+vLxIkTWbVqldv5jY+PB+Cf//wnGo2m0XZNv/32G8OHDwfgiiuucD3/d99913UO6z62bnufl19+me7du+Pt7c2UKVM4ePAgSikefvhhunXrhtlsZsaMGRQUFNQ7bnNev6ysLK644gq6deuG0WgkKiqKGTNmNNkeYu7cubz88suAe/X4sRYsWEBSUhJGo5Hhw4ezdu3aeml27drFhRdeSHBwMCaTiWHDhvHtt982euyGzJ8/n/j4eMxmMxMmTGDbtm2tOo7NZuPBBx8kJSUFk8lESEgIY8eOZcmSJS163nV98803TJ8+nejoaIxGI0lJSTz88MOuz0at008/nX79+rF+/XrGjx+Pt7c3//73vz32Xqjr22+/RaPRsGXLFte6L774Ao1Gw6xZs9zS9u7dm0suucR1/5133uHMM88kPDwco9FInz59ePXVV+sdY926dUydOpXQ0FDMZjOJiYlceeWVTearVnPet41p7md+9erVTJs2jaCgIHx8fBgwYADPP/88cPzX2eFw8Nxzz9G3b19MJhMRERFcc801FBYWuh1DKcUjjzxCt27d8Pb25owzzmj28wD4z3/+w+jRowkJCcFsNjN06NAG29gsWbKEsWPHEhgYiK+vLz179uTf//53k/vWaDSUl5fz3//+1/X8at
tpNdQmKiEhgXPOOcd1/TWbzfTv39/VPOPLL7+kf//+mEwmhg4dysaNG+sd80Q+6/n5+fz1r3/F39+fwMBA5syZw+bNm92uo7V++eUX1/snMDCQGTNmNFggcvjwYa688koiIiIwGo307duXt99+u166F198kb59+7q+K4cNG8aHH37YrHzb7fYmv9MfeOAB9Hq96wd5XVdffTWBgYFUVVU161ioDnL99derxg7/66+/KkANHjxYDR06VM2fP1/NmzdPeXt7q9NOO80t7dKlS5XBYFCjRo1SzzzzjJo/f74aMGCAMhgMavXq1U3mofY4/fv3VwMGDFDPPvusuuuuu5TJZFI9evRQFRUVLT7OAw88oADVp08fNWPGDPXKK6+ol19+WSml1Lx58xSgRo8erZ5++mn1/PPPq8suu0zdeeedrse/9957SqPRqLPOOku9+OKL6sknn1QJCQkqMDBQpaenu9LNmTNHmUwm1bdvX3XllVeqV199VV1wwQUKUK+88opSSqmysjL16quvKkCdf/756v3331fvv/++2rx5s1JKqf/85z9q3Lhx6qGHHlILFixQN998szKbzeq0005TDofDdawNGzYoo9GoEhIS1BNPPKEeffRRFR0drQYOHFjvNXzkkUeURqNRl1xyiXrllVfUgw8+qEJDQ1VCQoIqLCxs8vXYtm2b8vHxUVFRUerhhx9WTzzxhEpMTFRGo1GtWrVKKaXU5s2b1fz58xWgLr30UvX++++rr776qsH9ZWVlqYceekgB6uqrr3Y9/3379rnOYXx8vCt9enq6AtSgQYNUnz591LPPPqvuvfdeZTAY1MiRI9W///1vNXr0aPXCCy+om266SWk0GnXFFVe4HbO5r9/o0aNVQECAuvfee9Wbb76pHnvsMXXGGWeo33//vdHzs2LFCjV58mQFuJ7L+++/75b3wYMHq+TkZPXkk0+qp556SoWGhqpu3bopq9Xqdp4DAgJUnz591JNPPqleeuklNX78eKXRaNSXX37Z5GtUe5z+/furhIQE9eSTT6oHH3xQBQcHq7CwMJWVldXi4/z73/9WGo1GXXXVVeqNN95QzzzzjLr00kvVE088cdzn3ZiZM2eqiy++WD399NPq1VdfVRdddJEC1B133OGWbsKECSoyMlKFhYWpG2+8Ub3++uvq66+/9sh74Vj5+flKo9GoF1980bXu5ptvVlqtVoWFhbnW5eTkKEC99NJLrnXDhw9Xc+fOVfPnz1cvvviimjJlSr002dnZKigoSPXo0UM9/fTT6o033lD33HOP6t27d5P5Uqr579va61tdzf3ML168WBkMBhUfH68eeOAB9eqrr6qbbrpJTZo0SSl1/Nf5H//4h/Ly8lJXXXWVeu2119Sdd96pfHx81PDhw93e3/fee68C1LRp09RLL72krrzyShUdHa1CQ0PVnDlzjnsuunXrpq677jr10ksvqWeffVaddtppClDff/+9K822bduUwWBQw4YNU88//7x67bXX1B133KHGjx/f5L7ff/99ZTQa1bhx41zPb8WKFUoppd555x0FuJ3v+Ph41bNnTxUVFaXmzZun5s+fr2JiYpSvr6/63//+p+Li4tQTTzyhnnjiCRUQEKCSk5OV3W53y2drP+t2u12NGjVK6XQ6dcMNN6iXXnpJTZ482XXdf+edd1xplyxZory8vFSPHj3UU0895XoPBAUFuT2frKws1a1bNxUbG6seeugh9eqrr6rzzjtPAWr+/PmudAsWLFCAuvDCC9Xrr7+unn/+efX3v/9d3XTTTU3mubnf6ampqQpw+ywqpZTFYlFBQUHqyiuvbPI4dXXqIKp3797KYrG41j///PMKUFu3blVKKeVwOFRKSoqaOnWq25d+RUWFSkxMVJMnT24yD7XHiYmJUSUlJa71n376qQLU888/3+Lj1F5kLr30UrdjpaamKq1Wq84//3y3N3nt/pVSqrS0VAUGBqqrrrrKbXtWVpYKCAhwWz9nzh
wFqIceesgtbW3gWSs3N1cB6oEHHqj3/OsGibU++ugjBag//vjDte7cc89V3t7e6vDhw27Px8vLy+01zMjIUDqdTj3
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"selected_best_cum_eps = np.mean(np.cumsum(A_eps == BEST, axis=1), axis=0) / np.arange(1, T+1)\n",
"selected_best_cum_opt = np.mean(np.cumsum(A_opt == BEST, axis=1), axis=0) / np.arange(1, T+1)\n",
"selected_best_cum_deps = np.mean(np.cumsum(A_deps == BEST, axis=1), axis=0) / np.arange(1, T+1)\n",
"selected_best_cum_ucb = np.mean(np.cumsum(A_ucb == BEST, axis=1), axis=0) / np.arange(1, T+1)\n",
"selected_best_cum_grad = np.mean(np.cumsum(A_grad == BEST, axis=1), axis=0) / np.arange(1, T+1)\n",
"\n",
"asymptote = (1 - EPSILON) + EPSILON * 1/K\n",
"\n",
"plt.plot([0, T], [100*asymptote]*2, 'b--', label=\"Asymptotic Epsilon-greedy\")\n",
"plt.plot(100*selected_best_cum_eps, label=\"Epsilon-greedy\")\n",
"plt.plot(100*selected_best_cum_opt, label=\"Optimistic-greedy\")\n",
"plt.plot(100*selected_best_cum_deps, label=\"Decaying Epsilon-greedy\")\n",
"plt.plot(100*selected_best_cum_ucb, label=\"UCB\")\n",
"plt.plot(100*selected_best_cum_grad, label=\"Gradient\")\n",
"\n",
"plt.title(\"The percentage of times the best arm was elected as time goes by\")\n",
"plt.xlabel(\"Steps\")\n",
"plt.ylabel(\"Optimal action (%)\")\n",
"plt.legend()\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MeanReward($\\epsilon$)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_2937633/564047854.py:12: RuntimeWarning: invalid value encountered in divide\n",
" a = np.argmax(rewards / counts)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd1iT19vA8W8SCHuDIIiAW0RFrQMVt+LCUa2z7tm696hVHNU662qte29rrQMHTrQo7roVFYoioKDsTZ73D37klbpAgSCez3Xlkjzr3DlEcuc8Z8gkSZIQBEEQBEEoJOSaDkAQBEEQBCE3ieRGEARBEIRCRSQ3giAIgiAUKiK5EQRBEAShUBHJjSAIgiAIhYpIbgRBEARBKFREciMIgiAIQqEikhtBEARBEAoVkdwIgiAIglCoiORGEAQcHR3p3bu3psP4ZEeOHMHV1RVdXV1kMhlRUVGaDumtZDIZXl5emg7jvS5dukTt2rUxMDBAJpNx/fp1TYeURYMGDWjQoIH6eVBQEDKZjA0bNmgsJqHgEMmNkC82bNiATCZDJpNx7ty5N/ZLkoS9vT0ymYzWrVtrIELhcxcZGUmnTp3Q09Pj119/ZfPmzRgYGGgsHm9v7wKfwLxLamoq33zzDS9fvuSXX35h8+bNODg4aDosQcg2LU0HIHxZdHV12bZtG3Xr1s2y/cyZMzx9+hQdHR0NRSZ87i5dukRsbCwzZ86kSZMmmg4Hb29vfv3117cmOImJiWhpFdw/v48ePeLff/9l9erV9O/fX9PhvNWxY8c0HYJQgImWGyFftWzZkt27d5OWlpZl+7Zt26hWrRo2NjYaiuzjqVQqkpKSNB3Ge8XHx2s6hDz3/PlzAExNTTUbSDbo6uoW6OTmc6hLpVKJUqnUdBhCASWSGyFfde3alcjISHx8fNTbUlJS2LNnD926dXvrOSqVisWLF1OhQgV0dXWxtrZm0KBBvHr1Kstxf/31F61atcLW1hYdHR1KlizJzJkzSU9Pz3JcgwYNcHFx4c6dOzRs2BB9fX3s7OyYN29etl6DTCZj6NChbN26lQoVKqCjo8ORI0cACAkJoW/fvlhbW6Ojo0OFChVYt26d+lxJkrC0tGT06NFZXp+pqSkKhSJLH5G5c+eipaVFXFwcADdu3KB3796UKFECXV1dbGxs6Nu3L5GRkVni8/LyQiaTcefOHbp164aZmZm6pUySJGbNmkWxYsXQ19enYcOG3L59O1uvG2DBggXUrl0bCwsL9PT0qFatGnv27HnjOB8fH+rWrYupqSmGhoaULVuWyZMnf/D669evp1GjRhQpUgQdHR2cnZ1ZsWLFB89r0KABvXr1AqB69erIZDJ1H6J39Sf6b5+N06dPI5PJ2LVrFz/99BPFihVDV1eXxo0b8/DhwzfO9/f3p2XLlpiZmWFgYEClSpVYsmQJAL179+bXX38FUN+Olclk6nPf1ufm2rVrtGjRAmNjYwwNDWncuDEXLlzIckzm7d2///6b0aNHY2VlhYGBAe3bt+fFixcfrCeAkydP4u7ujoGBAaamprRt25a7d++q9/fu3Zv69esD8M033yCTybLU09tERUUxcuRI7O3t0dHRoVSpUsydOxeVSqU+JrNPzIIFC/jll19wcHBAT0+P+vXrc+vWrSzXCwsLo0+fPhQrVgwdHR2KFi1K27ZtCQoKUh/z39/fx75e+P//Mw8fPqR3796YmppiYmJCnz59SEhI+GAZQsFTcL86CIWSo6Mjbm5ubN++nRYtWgBw+PBhoqOj6dKlC0uXLn3jnEGDBrFhwwb69OnD8OHDCQwMZPny5Vy7do2///4bbW1tIOMPv6GhIaNHj8bQ0JCTJ08ydepUYmJimD9/fpZrvnr1iubNm/P111/TqVMn9uzZw4QJE6hYsaI6rvc5efIku3btYujQoVhaWuLo6Eh4eDi1atVSJz9WVlYcPnyYfv36ERMTw8iRI5HJZNSpUwdfX1
/1tW7cuEF0dDRyuZy///6bVq1aAXD27FmqVKmCoaEhkJEwPH78mD59+mBjY8Pt27dZtWoVt2/f5sKFC1k+PCHjg6l06dLMnj0bSZIAmDp1KrNmzaJly5a0bNmSq1ev0qxZM1JSUrL1+1uyZAlt2rShe/fupKSksGPHDr755hsOHjyojvv27du0bt2aSpUqMWPGDHR0dHj48CF///33B6+/YsUKKlSoQJs2bdDS0uLAgQN8//33qFQqhgwZ8s7zfvjhB8qWLcuqVauYMWMGTk5OlCxZMluv6b9+/vln5HI5Y8eOJTo6mnnz5tG9e3f8/f3Vx/j4+NC6dWuKFi3KiBEjsLGx4e7duxw8eJARI0YwaNAgnj17ho+PD5s3b/5gmbdv38bd3R1jY2PGjx+PtrY2K1eupEGDBpw5c4aaNWtmOX7YsGGYmZkxbdo0goKCWLx4MUOHDmXnzp3vLef48eO0aNGCEiVK4OXlRWJiIsuWLaNOnTpcvXoVR0dHBg0ahJ2dHbNnz2b48OFUr14da2vrd14zISGB+vXrExISwqBBgyhevDh+fn5MmjSJ0NBQFi9enOX4TZs2ERsby5AhQ0hKSmLJkiU0atSImzdvqsvp0KEDt2/fZtiwYTg6OvL8+XN8fHwIDg7G0dHxg/WZk9f7uk6dOuHk5MScOXO4evUqa9asoUiRIsydOzfbZQoFhCQI+WD9+vUSIF26dElavny5ZGRkJCUkJEiSJEnffPON1LBhQ0mSJMnBwUFq1aqV+ryzZ89KgLR169Ys1zty5Mgb2zOv97pBgwZJ+vr6UlJSknpb/fr1JUDatGmTeltycrJkY2MjdejQ4YOvBZDkcrl0+/btLNv79esnFS1aVIqIiMiyvUuXLpKJiYk6vvnz50sKhUKKiYmRJEmSli5dKjk4OEg1atSQJkyYIEmSJKWnp0umpqbSqFGj3vv6tm/fLgGSr6+vetu0adMkQOratWuWY58/fy4plUqpVatWkkqlUm+fPHmyBEi9evX64Gv/bwwpKSmSi4uL1KhRI/W2X375RQKkFy9efPB6H7q+JEmSh4eHVKJEiQ+e+/p77HUODg5vfW3169eX6tevr35+6tQpCZDKly8vJScnq7cvWbJEAqSbN29KkiRJaWlpkpOTk+Tg4CC9evUqyzVfr9chQ4ZI7/oTC0jTpk1TP2/Xrp2kVCqlR48eqbc9e/ZMMjIykurVq/fGa2zSpEmWskaNGiUpFAopKirqreVlcnV1lYoUKSJFRkaqt/3zzz+SXC6Xevbs+UZd7N69+73XkyRJmjlzpmRgYCA9ePAgy/aJEydKCoVCCg4OliRJkgIDAyVA0tPTk54+fao+zt/fXwLU7/VXr15JgDR//vz3lvvf31/m9devX5/j15v5f6Zv375Zymjfvr1kYWHxwToQCh5xW0rId506dSIxMZGDBw8SGxvLwYMH33lLavfu3ZiYmNC0aVMiIiLUj2rVqmFoaMipU6fUx+rp6al/jo2NJSIiAnd3dxISErh3716W6xoaGvLtt9+qnyuVSmrUqMHjx4+z9Rrq16+Ps7Oz+rkkSfzxxx94enoiSVKWWD08PIiOjubq1asAuLu7k56ejp+fH5DRQuPu7o67uztnz54F4NatW0RFReHu7v7W15eUlERERAS1atUCUF/7dYMHD87y/Pjx46SkpDBs2LAsrTwjR47M1mv+bwyvXr0iOjoad3f3LOVn9tP466+/styWyOn1o6OjiYiIoH79+jx+/Jjo6OgcXetj9enTJ0tfjszfQeZ749q1awQGBjJy5Mg3+qT8t/UsO9LT0zl27Bjt2rWjRIkS6u1FixalW7dunDt3jpiYmCznDBw4MEtZme+pf//9953lhIaGcv36dXr37o25ubl6e6VKlWjatCne3t45jh0y/o+6u7tjZmaW5X3fpEkT0tPTs7RSArRr1w47Ozv18xo1alCzZk11+Xp6eiiVSk6fPv3Greec+JjX+9//M+7u7k
RGRr5R/0LBJ5IbId9ZWVnRpEkTtm3bxt69e0lPT6djx45vPTYgIIDo6GiKFCmClZVVlkdcXJy64yNkNO23b98eExM
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"epsilons = np.linspace(0.0, 1.0, 20)\n",
"for t_max in [10, 30, 50, 70, 100, 1000]:\n",
" mean_rewards = []\n",
" for epsilon in epsilons:\n",
" A, R = epsilon_greedy(epochs=100, t_max=t_max, epsilon=epsilon)\n",
" mean_rewards.append(np.mean(np.sum(R, axis=1)) / t_max)\n",
" plt.plot(epsilons, mean_rewards, label=f\"t_max={t_max}\")\n",
"plt.legend()\n",
"plt.title(\"Mean reward as a function of epsilon.\")\n",
"plt.xlabel(\"Epsilon\")\n",
"plt.ylabel(\"Mean reward\")\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Questions\n",
"\n",
"2. With ε-greedy, what is the asymptotic probability of taking the optimal action?\n",
"\n",
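    "Si l'on suppose que la politique optimale a été apprise (le bras optimal a la plus grande estimation), le bras optimal est choisi soit lors d'un pas d'exploitation (probabilité $1 - \\epsilon$), soit lors d'un pas d'exploration, où il est tiré uniformément parmi les $K$ bras (probabilité $\\epsilon \\cdot \\frac{1}{K}$). La probabilité de choisir le bon bras est donc :\n",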
"\n",
    "$\\displaystyle (1 - \\epsilon) + \\frac{\\epsilon}{K}$\n",
"\n",
    "3. Which ε is better for a relatively small value of T? And for a large T?\n",
"\n",
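    "Une valeur relativement faible de ε est meilleure pour une valeur relativement faible de T : avec peu d'itérations, chaque pas d'exploration aléatoire a peu de chances d'être rentabilisé, et l'algorithme a intérêt à exploiter rapidement les bras qui semblent les meilleurs. Une valeur plus grande de ε peut être préférable pour une grande valeur de T, car l'algorithme a alors le temps de rentabiliser l'exploration et d'identifier de manière fiable le meilleur bras. À très long terme, en revanche, une exploration constante plafonne la probabilité de choisir le bras optimal à $(1 - \\epsilon) + \\frac{\\epsilon}{K}$ (question 2), ce qui pénalise les trop grandes valeurs de ε.\n",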
"\n",
    "4. Do you observe some spikes in the plot of average rewards? If yes, please provide an explanation.\n",
"\n",
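    "On observe des instabilités lors des premières itérations, tant que toutes les machines n'ont pas encore été explorées : l'estimation d'un bras jamais tiré vaut $0/0$ (NaN), et `np.argmax` renvoie l'indice du premier NaN. Lors des pas d'exploitation, le choix est alors déterministe (le premier bras non exploré), si bien que la plupart des epochs tirent le même bras au même instant et que la récompense moyenne à cet instant reflète directement la probabilité de ce bras (par exemple au deuxième pas, où le bras 1, de probabilité 0.95, est tiré par la majorité des epochs, ce qui peut produire un pic). De plus, au début, les estimations reposent sur très peu de tirages et sont donc très bruitées. On observe aussi des pics lorsque l'algorithme explore un bras avec une faible valeur attendue et reçoit une récompense élevée, ce qui provoque un pic temporaire dans la récompense moyenne.\n",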
"\n",
"6. What are your conclusions in terms of methods? Give some intuition.\n",
"\n",
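    "La méthode ε-greedy est une méthode simple et largement utilisée pour équilibrer l'exploration et l'exploitation dans l'apprentissage par renforcement. Sa variante decaying ε-greedy, qui diminue ε au cours du temps, permet toutefois d'obtenir de meilleures performances : elle explore beaucoup au début, lorsque les estimations sont encore peu fiables, puis exploite de plus en plus une fois le meilleur bras identifié, au lieu de continuer à explorer au même rythme."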
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Parameter study by Learning curve"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"parameters = [1/128, 1/64, 1/32, 1/16, 1/8, 1/4, 1/2, 1, 2, 4]\n",
"\n",
"epsilon_greedy_rewards = []\n",
"gradient_rewards = []\n",
"ucb_rewards = []\n",
"optimistic_greedy_rewards = []\n",
"\n",
"for p in parameters:\n",
"\n",
" A, R = epsilon_greedy(epsilon=p)\n",
" epsilon_greedy_rewards.append(np.mean(np.sum(R, axis=1)))\n",
"\n",
" A, R = gradient(alpha=p)\n",
" gradient_rewards.append(np.mean(np.sum(R, axis=1)))\n",
"\n",
" A, R = UCB(c=p)\n",
" ucb_rewards.append(np.mean(np.sum(R, axis=1)))\n",
"\n",
" A, R = optimistic_greedy(Q0=p)\n",
" optimistic_greedy_rewards.append(np.mean(np.sum(R, axis=1)))\n"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjsAAAHLCAYAAAAurFnfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAADCMklEQVR4nOzddXxTVxvA8d+NNHWlCkUKhaJDhhR31+EybBtsQ8YYsDEFpmzvYMgMtsGGbdhwd5fhTotTSkupe5rc94/QQGmBFtretD3frR+Sm3vPfU6TJk/OPSLJsiwjCIIgCIJQSKmUDkAQBEEQBCEviWRHEARBEIRCTSQ7giAIgiAUaiLZEQRBEAShUBPJjiAIgiAIhZpIdgRBEARBKNREsiMIgiAIQqEmkh1BEARBEAo1kewIgiAIglCoiWRHECxc6dKlGTx4sNJhvLBNmzZRvXp1rK2tkSSJ6OhopUPKkiRJTJo0Sekwnuro0aPUr18fOzs7JEni5MmTOS6jdOnSdOzYMfeDEwQLJJIdgfnz5yNJEpIksW/fvkyPy7KMr68vkiSJN0fhudy/f59evXphY2PDjz/+yIIFC7Czs1Msng0bNlh8QvMker2enj17EhkZyfTp01mwYAGlSpXKct/z588zadIkrl+/nr9BCmaLFy/mhx9+UDqMIk+jdACC5bC2tmbx4sU0bNgww/bdu3dz+/ZtdDqdQpEJBd3Ro0eJi4vj888/p2XLlkqHw4YNG/jxxx+zTHiSkpLQaCz3rfHKlSvcuHGDuXPn8vrrrz913/PnzzN58mSaNm1K6dKl8ydAIYPFixdz9uxZxowZo3QoRZpo2RHM2rdvz7Jly0hLS8uwffHixdSqVQsvLy+FInt+RqOR5ORkpcN4qoSEBKVDyHPh4eEAODs7KxtINlhbW1t0slOQfpf5JTk5GaPRqHQY+aYgvK9ZGpHsCGZ9+/bl/v37bN261bwtNTWV5cuX069fvyyPMRqN/PDDD1SuXBlra2s8PT0ZPnw4UVFRGfZbvXo1HTp0wMfHB51OR9myZfn8888xGAwZ9mvatClVqlTh/PnzNGvWDFtbW4oXL863336brTpIksTIkSNZtGgRlStXRqfTsWnTJgBCQkIYOnQonp6e6HQ6KleuzB9//GE+VpZlihUrxtixYzPUz9nZGbVanaGPydSpU9FoNMTHxwNw+vRpBg8ejJ+fH9bW1nh5eTF06FDu37+fIb5JkyYhSRLnz5+nX79+uLi4mFvSZFnmiy++oESJEtja2tKsWTPOnTuXrXoD/O9//6N+/fq4ublhY2NDrVq1WL58eab9tm7dSsOGDXF2dsbe3p4KFSrw4YcfPrP8efPm0bx5czw8PNDpdFSqVImff/75mcc1bdqUQYMGAVC7dm0kSTL3QXpSf6SmTZvStGlT8/1du3YhSRJLly7lyy+/pESJElhbW9OiRQuCg4MzHX/48GHat2+Pi4sLdnZ2VKtWjRkzZgAwePBgfvzxRwDz5VtJkszHZtVn58SJE7Rr1w5HR0fs7e1p0aIFhw4dyrBP+uXg/fv3M3bsWNzd3bGzs6Nbt27cu3fvmb8ngB07dtCoUSPs7OxwdnamS5cuXLhwwfz44MGDadKkCQA9e/ZEkqQMv6fH4+nZsycAzZo1M9dz165dGfbbt28fderUwdraGj8/P/76669MZUVHRzNmzBh8fX3R6XSUK1eOqVOnZivBSO8btGXLFnOfrUqVKrFy5coM+0VGRjJu3DiqVq2Kvb09jo6OtGvXjlOnTmXYL/218Pfff/Pxxx9TvHhxbG1tiY2NzXEZS5cuZfLkyRQvXhwHBwd69OhBTEwMKSkpjBkzBg8PD+zt7RkyZAgpKSmZ6rZw4UJq1aqFjY0Nrq6u9OnTh1u3bpkfb9q0KevXr+fGjRvm3/+jLWwpKSl89tlnlCtXDp1Oh6+vLxMmTMh0rqe9rwnZY7lfX4R8V7p0aQIDA1myZAnt2r
UDYOPGjcTExNCnTx9mzpyZ6Zjhw4czf/58hgwZwujRo7l27RqzZ8/mxIkT7N+/H61WC5jeeO3t7Rk7diz29vbs2LGDTz/9lNjYWL777rsMZUZFRdG2bVteeeUVevXqxfLly3n//fepWrWqOa6n2bFjB0uXLmXkyJEUK1aM0qVLExYWRr169cxvGu7u7mzcuJHXXnuN2NhYxowZgyRJNGjQgD179pjLOn36NDExMahUKvbv30+HDh0A2Lt3LzVq1MDe3h4wJRBXr15lyJAheHl5ce7cOebMmcO5c+c4dOhQhg9TMH1Q+fv789VXXyHLMgCffvopX3zxBe3bt6d9+/YcP36c1q1bk5qamq3nb8aMGXTu3Jn+/fuTmprK33//Tc+ePVm3bp057nPnztGxY0eqVavGlClT0Ol0BAcHs3///meW//PPP1O5cmU6d+6MRqNh7dq1vP322xiNRkaMGPHE4z766CMqVKjAnDlzmDJlCmXKlKFs2bLZqtPjvvnmG1QqFePGjSMmJoZvv/2W/v37c/jwYfM+W7dupWPHjnh7e/POO+/g5eXFhQsXWLduHe+88w7Dhw/nzp07bN26lQULFjzznOfOnaNRo0Y4OjoyYcIEtFotv/76K02bNmX37t3UrVs3w/6jRo3CxcWFzz77jOvXr/PDDz8wcuRI/vnnn6eeZ9u2bbRr1w4/Pz8mTZpEUlISs2bNokGDBhw/fpzSpUszfPhwihcvzldffcXo0aOpXbs2np6eWZbXuHFjRo8ezcyZM/nwww+pWLEigPlfgODgYHr06MFrr73GoEGD+OOPPxg8eDC1atWicuXKACQmJtKkSRNCQkIYPnw4JUuW5MCBA0ycOJHQ0NBs9UcJCgqid+/evPnmmwwaNIh58+bRs2dPNm3aRKtWrQC4evUqq1atomfPnpQpU4awsDB+/fVXmjRpwvnz5/Hx8clQ5ueff46VlRXjxo0jJSUFKysrzp8/n6Myvv76a2xsbPjggw8IDg5m1qxZaLVaVCoVUVFRTJo0iUOHDjF//nzKlCnDp59+aj72yy+/5JNPPqFXr168/vrr3Lt3j1mzZtG4cWNOnDiBs7MzH330ETExMdy+fZvp06cDmN8zjEYjnTt3Zt++fQwbNoyKFSty5swZpk+fzuXLl1m1alWGWLN6XxNyQBaKvHnz5smAfPToUXn27Nmyg4ODnJiYKMuyLPfs2VNu1qyZLMuyXKpUKblDhw7m4/bu3SsD8qJFizKUt2nTpkzb08t71PDhw2VbW1s5OTnZvK1JkyYyIP/111/mbSkpKbKXl5fcvXv3Z9YFkFUqlXzu3LkM21977TXZ29tbjoiIyLC9T58+spOTkzm+7777Tlar1XJsbKwsy7I8c+ZMuVSpUnKdOnXk999/X5ZlWTYYDLKzs7P87rvvPrV+S5YskQF5z5495m2fffaZDMh9+/bNsG94eLhsZWUld+jQQTYajebtH374oQzIgwYNembdH48hNTVVrlKlity8eXPztunTp8uAfO/evWeW96zyZVmW27RpI/v5+T3z2EdfY48qVapUlnVr0qSJ3KRJE/P9nTt3yoBcsWJFOSUlxbx9xowZMiCfOXNGlmVZTktLk8uUKSOXKlVKjoqKylDmo7/XESNGyE96+wPkzz77zHy/a9euspWVlXzlyhXztjt37sgODg5y48aNM9WxZcuWGc717rvvymq1Wo6Ojs7yfOmqV68ue3h4yPfv3zdvO3XqlKxSqeSBAwdm+l0sW7bsqeXJsiwvW7ZMBuSdO3dmeqxUqVKZXp/h4eGyTqeT33vvPfO2zz//XLazs5MvX76c4fgPPvhAVqvV8s2bN58aQ/p5VqxYYd4WExMje3t7yzVq1DBvS05Olg0GQ4Zjr127Jut0OnnKlCnmben19/Pzy/SazGkZVapUkVNTU83b+/btK0uSJLdr1y5DGYGBgXKpUqXM969fvy6r1Wr5yy+/zLDfmTNnZI1Gk2F7hw4dMh
ybbsGCBbJKpZL37t2bYfsvv/wiA/L+/fvN2570viZkn7iMJWTQq1cvkpKSWLduHXFxcaxbt+6Jl7CWLVuGk5MTrVq
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(epsilon_greedy_rewards, label=\"Epsilon-greedy\")\n",
"plt.plot(gradient_rewards, label=\"Gradient\")\n",
"plt.plot(ucb_rewards, label=\"UCB\")\n",
"plt.plot(optimistic_greedy_rewards, label=\"Optimistic-greedy\")\n",
"plt.title(\"Mean reward as a function of the parameter.\")\n",
"plt.xlabel(\"Parameter\")\n",
"plt.ylabel(\"Mean reward\")\n",
"plt.xticks(range(10), [rf\"$2^{{ {i} }}$\" for i in range(-7, 3)])\n",
"plt.legend()\n",
"plt.grid()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
    "1. What is the best algorithm for this example?\n",
"\n",
"UCB\n",
"\n",
    "2. Can you comment on the shape of the curves? What is the optimal parameter for each algorithm?\n",
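    "\n",
    "Les paramètres optimaux obtenus (voir la sortie de la cellule suivante) sont : $\\epsilon = 2^{-7}$ pour ε-greedy, $\\alpha = 2^{-1}$ pour gradient, $c = 2^{-4}$ pour UCB et $Q_0 = 2^{-1}$ pour optimistic-greedy."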
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best parameter for epsilon-greedy: 0.0078125\n",
"Best parameter for gradient: 0.5\n",
"Best parameter for UCB: 0.0625\n",
"Best parameter for optimistic-greedy: 0.5\n"
]
}
],
"source": [
"print(\"Best parameter for epsilon-greedy:\", parameters[np.argmax(epsilon_greedy_rewards)])\n",
"print(\"Best parameter for gradient:\", parameters[np.argmax(gradient_rewards)])\n",
"print(\"Best parameter for UCB:\", parameters[np.argmax(ucb_rewards)])\n",
"print(\"Best parameter for optimistic-greedy:\", parameters[np.argmax(optimistic_greedy_rewards)])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"vscode": {
"interpreter": {
"hash": "2e7007663510ed9db58bd00c4b6768d5d37222230e04b44bb35e77da9185a2df"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}