diff --git a/analysis.ipynb b/analysis.ipynb index 29dd725..bad0f33 100644 --- a/analysis.ipynb +++ b/analysis.ipynb @@ -1,5 +1,84 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/xiayunsun/code/venvs/py367/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated. Call .resolve and .require separately.\n", + " result = entry_point.load(False)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(210, 160, 3)\n", + "0.78125\n", + "(210, 160)\n", + "(128, 128)\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAADqCAYAAACssY5nAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAH1RJREFUeJzt3XuQXGW57/Hvr7tnBjJJCCSawiRK9t4RCxWjpiJ1vFQ0W3bgeIi34oAC4VbREs7ebpSb2zpg7bIKjwKHLW4kSA6himuBHCKFXBVl1xYkhBSXQDTcJDkhIZN7JsncnvNHr247c+2Z7p7uXvl9qlK91tvr8vTknWfe9a6316uIwMzM0itT7wDMzKy2nOjNzFLOid7MLOWc6M3MUs6J3sws5ZzozcxSrmaJXtIiSeskrZd0Wa3OY2Zmw1MtxtFLygJ/Aj4PbACeAU6PiLVVP5mZmQ2rVi36+cD6iHgtIrqAO4HFNTqXmZkNI1ej484A3ipZ3wB8YqiNJfnruVZrWyPiXfUOwqweapXoRyRpKbC0Xue3Q86b9Q7ArF5qleg3ArNK1mcmZUURsQxYBm7Rm5nVUq366J8B5kiaLakVOA1YWaNzmZnZMGrSoo+IHkkXAg8DWWB5RLxUi3OZmdnwajK8ctRBNEjXzTXXXDOq7S+66KIx71/NfStVz3MPF0eVz/VsRMyr5gHNmoW/GWtmlnJ1G3XTDCppdffffzyvFipRyxa7mdWHW/RmZinnFr0dZKQrB7f4zZqPE70Nm7zHq8vIzGrHXTdmZinnFv0wKm3NVrL/eLak3Wo3Sze36M3MUs5fmLJDhb8wZYcst+jNzFKuIfroZ86c6WF7VlOuX3Yoc4vezCzlnOjNzFLOid7MLOWc6M3MUm7MiV7SLEm/lbRW0kuS/ikpv1LSRklrkn8nVy9cs/qStEjSOknrJV1W73jMylHJqJse4DsRsVrSJOBZSY8m710bET+pPDyzxiEpC/wM+DywAXhG0sqIWFvfyMyGN+YWfURsiojVyfJu4GVgRrUCM2tA84H1EfFaRHQBdwKL6xyT2YiqMo5e0jHAR4GngU8CF0o6C1hFvtW/fZB9lgJLAY488shqhGFWazOAt0rWNwCf6L9Rad3OZDIfP+yww8YnuhGU+y14SWPed7D9K9m3UpV85lrGUa3zHThwgJ6enhEPVnGilzQRuBf4dkTsknQD8K9AJK9XA+f23y8ilgHLAGbNmuVHIFhqlNbt9vb2+MAHPlDniPL27NlDJpMhIoZNxm1tbQC0tLQQEfT29rJv376D9ildLuwrib6+PiZMmEAmkylus3//fnp6egbsV7q/JCKCiRMnDohnLArn2bNnT/HYwyXXXC6fCtva2io+d399fX10dnaSyeQ7UAqfsxrneeWVV8rarqJRN5JayCf52yLilwARsTkieiOiD7iJ/OWuWRpsBGaVrM9Myswa2phb9Mr/ebwZeDkirikpPzoiNiWrXwJerCxEs4bxDDBH0mzyCf404Gv1DWlofX19tLa2FluSTz75JFu2bKGlpYW+vr6Dti20didOnMj1118PwAMPPEAul2P+/Pn84Ac/YMeOHQBkMpmD9i8cv7u7m+nTp3PGGWewY8cOtm3bBsAll1zCwoUL6e7uJiIGtGSz2SxdXV0cccQRnHjiiQAcccQRHDhwoHjs0XzmtrY2du/eDcBjjz3Gjh07aG1tpbe3d8D2kmhpaeHXv/41ANdff31VWtuS6O3tJZvN8r73vY+bbrqJTZvyaXHChAksWrSIadOm0d3dXdF5ylVJ180ngTOBFyStScq+B5wuaS75rps3gG9UFKFZg4iIHkkXAg8DWWB5RLxU57DKVuieaG9v5+qrr+bwww8HYPv27Zx//vlMnjyZlpaWAV0chWRYSLq9vb389Kc/5cgjj2T37t1cfPHFxfLCOUplMhlyuRyS+NOf/sSvfvWrYhxtbW1861vfYuvWrYPuW6mWlhYAJk2axI033khXVxfAQV05mUyGN998s7h9tbtuJJHL5Yo/v1p8zpGM+YwR8R/AYJ1eD449HLPGFhEP0qR1vKWlhWw2S3t7O0888QSTJ08GYOPGjZx11llks9liQi7VP9F3d3fz2GOPMWPGDN555x2+//3vA9DZ2TnoH4pCos9kMmzbto3HHnsMgGnTptHe3s7FF1/M9u3byeVyVU+ymUyGTCZDa2srr7/+ejHR9/T0kM1mkYQktm/PjxepRRIuJPrCH51sNlt8r6+vb9RXLWPREE+vHImfPGgj8SxZ5ZNEd3d3MemVdh+UOxqkq6vroGOUq7e3t3i+7u5uuru7i8m2miNfCl1LkyZN4sCBA/T19XHttdcWzyGJffv20dvbS1tbG/fffz8AP/7xj5k0aVJNWvWly+M9D0hTJHozq1whmWYyGXbs2FFMNtu2baO3t3fYZFv6XqFlPmnSJLZv315skQ63b2m3T6HvPpfLFfuxh9t/NDKZDAcOHCiO3pk+fTq5XK7Yci7EsXXrVm699VYWLlxIRAy4Z1GpwRJ5rYdwDsfPujEzSzm36M0OES0tLUyYMIFdu3bxyiuvHDT+fc+ePcV+60ILu/T9XC5H4Ytf7e3tbNq0qdhKLozGmTBhQrEfuq+vr7h/4Zh9fX2cdNJJfOUrXykePyLo6Ojg8MMPr0offWHUzc6dOwHYuXMnHR0dtLa2HnRsSXR2dtLb23tQS7807mpqbW0t9v8XfkbjyYnerMbqMcqioLQ/+IwzzqCnp6f4xab+2xW6WApdK4UvPq1du5azzjrroH0GG14ZEeRyOXbu3Ek2my1+4/22227jvvvuK27TP5GWfomrcIMY8slxrApdN1/+8peL3VL9lXZH7d27F4ApU6ZU/eboli1b+PrXv37Ql8aOOuoooLLPOBpO9GY1dMQRR3DyyY3xANfSlmT/RF+a3ArjzQvbFFr0pQZL9AUf/vCHD0rm2Wz2oG2GO/dHPvKR8j5MmQrfGeg/9r//eQvvDTbWvlKD/fyq9TnfeuutkTfCid6spqZOncqZZ55Z7zAspe65556ytnOiN6shSQeNmzarpnJH8njUjZlZyjnRm5mlnLtuzGpsvL8FadafW/RmZinnRG9mlnJO9GZmKedEb2aWctWYM/YNYDfQC/RExDxJRwF3AceQn3zk1MEmCDdrRJJmAbcC08lPoLMsIq6rZr0uzEBUi29iWroUnhVUyU39ao26+WxEbC1Zvwx4PCKuknRZsn5plc5lVms9wHciYrWkScCzkh4FzqZK9TqTyfDOO++wZcuW4rpZqcJjGd797nczffr0ihoFtRpeuRhYkCyvAJ7Aid6aRDLn8aZkebekl4EZVFCvC99gLLTK2traeOSRR1ixYgUAhx9+eF0mpLDGU6gH+/btA2DJkiWcf/75dHZ2Ft+H0Q3brUaiD+ARSQHcGBHLgOklE4S/Tf4S+CCSlgJLgeJT7swajaRjgI8CT1NGvU72Kdbt97znPcMde8CkHfWcnMIax2B1oxLVSPSfioiNkt4NPCrpldI3IyKSPwL0K18GLAOYNWuWmzHWcCRNBO4Fvh0Ru0p/4Yaq18l7xbp9/PHHl1W33XVjtVRx7YqIjcnrFuA+YD6wWdLRAMnrlkrPYzaeJLWQT/K3RcQvk+Ka1mu35g1qUw8qSvSS2pObVUhqB04EXgRWAkuSzZYA91dyHrPxpPxv2s3AyxFROuu467U1pUq7bqYD9yV/gXLA7RHxkKRngLslnQe8CZxa4XnMxtMngTOBFyStScq+B1yF67U1oYoSfUS8BgyYKiUiOoCFlRzbrF4i4j+Aoa6fXa+t6TTF0yufWrSo3iFYg/vPegdg1sB8q9/MLOWc6M3MUs6J3sws5ZzozcxSzonezCzlmmLUTd/f7ap3CGZmTaspEr1ZGpU+uKp/uR2aSp9I2WgPNTOzMShN9NX8pbbmVfoU02rWCffRm5mlnFv0ZnXS3d1NV1cXkG/RFyaccKv+0FX4/4+IYt3o7u6u+LhO9GbjoP9sQPv37+eUU05hwYIF9QnImsbkyZPZv39/cX0ss5A1RaLfNrmz3iGYVZUkuru72bt3L+CJR2ygwpyxhWkmK9EUid4srQqtM88Va/1Vs0440ZvVQUTQ2trKxIkTAbfobaBCi761tbXipD/m2iXpWElrSv7tkvRtSVdK2lhSfnJFEZrViaSspOckPZCsz5b0tKT1ku6S1FrvGM3KMeYWfUSsA+ZC/hcC2Eh+zthzgGsj4idVidCsfv4JeBmYnKz/iHzdvlPSz4HzgBvGcuCIoKWlpdiiNxtKS0tLxS36anXdLARejYg3azE0bNsHuqp+TEuZrdU9nKSZwH8FfghclMwj+znga8kmK4ArGWOiH+R81TiMpUg1++ir1TF4GnBHyfqFkp6XtFzSkYPtIGmppFWSVhVGHpg1kP8NXAL0JetTgR0R0ZOsbwBmDLZjad3u6OiofaRmI6i4RZ/0U54CXJ4U3QD8KxDJ69XAuf33i4hlwDKAWbNmeciBNQxJXwC2RMSzkhaMdv/Sun388ccPWbczmYwffWBDKrToq3GjvhpdNycBqyNiM0DhFUDSTcADVTiH2Xj6JHBKMpDgMPJ99NcBUyTlklb9TPL3pcbMQyqtHNWoJ9Xoujmdkm4bSUeXvPcl4MUqnMNs3ETE5RExMyKOId8t+ZuI+DrwW+CryWZLgPvrFKLZqFSU6CW1A58HfllS/L8kvSDpeeCzwD9Xcg6zBnIp+Ruz68n32d9cycEkFf+Z9VfN+lFR101E7CVf4UvLzqwookHc3vfeah/SUubEGh03Ip4AnkiWXwPm1+I8TvbWn78Za9bkIoIJEyYwZcqUg8qc8K1/Pejq6qKnp2eYPUbmRG9WR/2fdeMbtAbV/6PvRG9WR07sNpRG/MKUmZk1KLfozepAEp2dncVZhMyG0traSltbW0Ut/KZI9F13XlnvEKzRnfif9Y5gVCTR1dVFZ2dncd2sVGliP+yww9Kf6M3Syn30NhQPrzRrchFBLpejpaUFcIveBiok+lwu1zCPKTazUchms/zlL39h3bp1xXWzUr29vQAce+yxHHfccRWNpfeoGzOzlHOL3qxOstksuVz+V7DwalZQ6M6rxtVeU9Su3zx0Qr1DsAb3hROvqXcIo9bX11e8PHcfvfVXqBuFScIr0RSJ3iyNuru7i8Mr3aK3/gp98t3d3RUfy7XLbBwUWuyF0ROtra2sXr2aFStWADBhwoS6xWaNqdAIWLJkCfPmzSsm/v51qRxlJXpJy4HC9GofSsqOAu4CjgHeAE6NiO3JJMrXAScDncDZEbG67IjMGoCkKcAvgA+RnxbzXGAdg9T5sZ6jtI++MK1gNS7TrbkV6kGhblSjj77cUTe3AIv6lV0GPB4Rc4DHk3XITy04J/m3lPwcsmbN5jrgoYj4APAR4GWGrvNmDa2sFn1E/F7SMf2KFwMLkuUV5CdnuDQpvzXy1xVPSZoi6eiI2FSNgM1qTdIRwGeAswEiogvokjRUna9YJpNBUlUmgrbmVot6UMnRppck77eB6cnyDOCtku02JGVmzWI28A7wfyQ9J+kXybSZQ9V5s4ZWlT8bSet9VN/RlbRU0ipJq/bu3VuNMMyqJQd8DLghIj4K7KVfN81wdb60bnd0dNQ8WLORVJLoN0s6GiB53ZKUbwRmlWw3Myk7SEQsi4h5ETGvvb29gjDMqm4DsCEink7W7yGf+Ieq8wcprdtTp04dbBOzcVVJol8JLEmWlwD3l5SfpbwTgJ3un7dmEhFvA29JOjYpWgisZeg6b9bQyh1eeQf5m1DTJG0ArgCuAu6WdB7wJnBqsvmD5IdWric/vPKcKsdsNh7+B3CbpFbgNfL1OMPgdd6soZU76ub0Id5aOMi2AVxQSVBm9RYRa4B5g7w1oM6bNTqP5TIzSzknejOzlHOiNzNLOSd6M7OUc6I3M0s5J3ozs5RzojczSzknejOzlHOiNzNLOSd6M7OUc6I3M0s5J3ozs5RzojczSzknejOzlHOiNzNLuRETvaTlkrZIerGk7MeSXpH0vKT7JE1Jyo+RtE/SmuTfz2sZvFmtSPpnSS9JelHSHZIOkzRb0tOS1ku6K5mUxKzhldOivwVY1K/sUeBDEXE88Cfg8pL3Xo2Iucm/b1YnTLPxI2kG8I/AvIj4EJAFTgN+BFwbEX8HbAfOq1+UZuUbMdFHxO+Bbf3KHomInmT1KfITgJulSQ44XFIOmABsAj5HfqJwgBXAF+sUm9moVKOP/lzg1yXrsyU9J+l3kj491E6SlkpaJWnV3r17qxCGWXVExEbgJ8BfyCf4ncCzwI6SBs4GYEZ9IjQbnYoSvaR/AXqA25KiTcB7I+KjwEXA7ZImD7ZvRCyLiHkRMa+9vb2SMMyqStKRwGJgNvAeoJ2B3ZfD7V9sxHR0dNQoSrPyjTnRSzob+ALw9WRCcCLiQER0JMvPAq8C769CnGbj6e+B1yPinYjoBn4JfBKYknTlQL67cuNgO5c2YqZOnTo+EZsNY0yJXtIi4BLglIjoLCl/l6Rssvw3wBzgtWoEajaO/gKcIGmCJAELgbXAb4GvJtssAe6vU3xmo1LO8Mo7gD8Ax0raIOk84HpgEvBov2GUnwGel7SG/E2rb0bEtkEPbNagIuJp8vV3NfAC+d+TZcClwEWS1gNTgZvrFqTZKORG2iAiTh+keNAKHhH3AvdWGpRZvUXEFcAV/YpfA+bXIRyzivibsWZmKedEb2aWck70ZmYp50RvZpZyTvRmZinnRG9mlnJO9GZmKedEb2aWck70ZmYp50RvZpZyTvRmZinnRG9mlnJO9GZmKedEb2aWck70ZmYpV87EI8slbZH0YknZlZI2JpOOrJF0csl7l0taL2mdpH+oVeBmlRqibh8l6VFJf05ej0zKJenfkrr9vKSP1S9ys9Epp0V/C4NPjHxtRMxN/j0IIOk44DTgg8k+/16YWtCsAd3CwLp9GfB4RMwBHk/WAU4iPzXmHGApcMM4xWhWsRETfUT8Hih3OsDFwJ3JJOGvA+vxjDzWoIao24uBFcnyCuCLJeW3Rt5T5CcKP3p8IjWrTCV99Bcml7DLC5e3wAzgrZJtNiRlA0haKmmVpFV79+6tIAyzqpoeEZuS5beB6cnymOp2R0dH7SI1K9NYE/0NwN8Cc4FNwNWjPUBELIuIeRExr729fYxhmNVORAQQY9ivWLenTp1ag8jMRmdMiT4iNkdEb0T0ATfx1+6ZjcCskk1nJmVmzWJzoUsmed2SlLtuW9MaU6Lv1zf5JaAwamElcJqkNkmzyd+4+mNlIZqNq5XAkmR5CXB/SflZyeibE4CdJV08Zg0tN9IGku4AFgDTJG0ArgAWSJpL/rL2DeAbABHxkqS7gbVAD3BBRPTWJnSzygxRt68C7pZ0HvAmcGqy+YPAyeQHGHQC54x7wGZjNGKij4jTBym+eZjtfwj8sJKgzMbDEHUbYOEg2wZwQW0jMqsNfzPWzCzlnOjNzFLOid7MLOWc6M3MUs6J3sws5ZzozcxSzonezCzlnOjNzFJuxC9MmVnl8t+3MqvcWOqSW/Rm40ASkuodhqXAWOqSW/RmNZbJ/LU9VfgFzWQyTvxV1tfXx65duw76eY9FRDBlypSGuQrLZDLkcjlyubGnayd6sxrq6enhnXfeKa5HBJJoa2ujs7Oz4qR0qOvr6wMgl8tx1FFHcfvtt9PR0VFMiuUm68If3e7ubiZMmMDXvvY1pkyZQldXF8C4/j8VYi58hs2bN/Pcc89x4MCBAdvu27evrGO6lpmZpZxb9GY1VtoaLLTos9msu26qrLe3l46ODnbs2EE2mx3VvqUt+gMHDhzUxVYvhZZ9Npsll8vR2zvwie/l1iEnerMa6//L6Buz1VNIxBHBvn37WLlyJfv37x/1z7ewfV9fH7lcjpaWlmK3UL0UYurt7aWnp4eenp4B25TbNVXOxCPLgS8AWyLiQ0nZXcCxySZTgB0RMVfSMcDLwLrkvaci4ptlRWI2zoao2z8G/hvQBbwKnBMRO5L3LgfOA3qBf4yIh8s4x6A30VpbW0fd6rShRQTd3d387ne/q0qCbm1trXuiL9i/fz/btm0btI9+sFb+YMpp0d8CXA/cWiiIiP9eWJZ0NbCzZPtXI2JuWWc3q69b6Fe3gUeByyOiR9KPgMuBSyUdB5wGfBB4D/CYpPePNIPahg0b+O53vzugXBJ79uyhtbW1WNYoozyaVUQM2uptVoW68Yc//IEnn3xy0G6kt99+u6xjlTPD1O+TlvoAyl9bnAp8rqyzmTWQwep2RDxSsvoU8NVkeTFwZ0QcAF6XtB6YD/xhuHP09PSwffv2AeWS6O3t9agbG1FheGVFx6gwhk8DmyPizyVlsyU9J+l3kj491I6SlkpaJWnV3r17KwzDrCbOBX6dLM8A3ip5b0NSNkBp3U5TC9OaV6U3Y08H7ihZ3wS8NyI6JH0c+L+SPhgRu/rvGBHLgGUAs2bN8jWrNRRJ/0J+gvvbRrtvad1ub28PJ3ur1FBXfjUfdSMpB3wZ+HihLLmsPZAsPyvpVeD9wKqxnsdsvEk6m/xN2oXx147zjcCsks1mJmVmDa+Srpu/B16JiA2FAknvkpRNlv8GmAO8VlmIZuNH0iLgEuCUiOgseWslcJqkNkmzydftP9YjRrPRGjHRS7qD/A2nYyVtkHRe8tZpHNxtA/AZ4HlJa4B7gG9GxLZqBmxWLUPU7euBScCjktZI+jlARLwE3A2sBR4CLhhpxI1Zoyhn1M3pQ5SfPUjZvcC9lYdlVntD1O2bh9n+h8APaxeRWW14bJeZWco50ZuZpZwTvZlZyvmhZmY11NnZuXX16tV7ga31jmWcTMOfdTy9r5yNnOjNaigi3iVpVUTMq3cs48GftTG568bMLOWc6M3MUq4hum52Zvt4YMqeeodxSHpq0aIx73vCQw9VMZLK/JdHHhl5o/pZVu8AxpE/awNyi96sxpKHnB0S/FkbkxO9mVnKOdGbmaVcQ/TRW/00Uj972iRPwrwOyAK/iIir6hxSVUl6A9hNfg7dnoiYJ+ko4C7gGOAN4NSIGDjFVhMYYk7hQT9fMtvedcDJQCdwdkSsrkfcg3GL3qwGksd1/ww4CTgOOD2ZdzZtPhsRc0vGk18GPB4Rc4DHk/VmdQvQf7TCUJ/vJPKPrp4DLAVuGKcYy+IWvaVCA16ZzAfWR8RrAJLuJD/v7Nq6RlV7i4EFyfIK4Ang0noFU4kh5sse6vMtBm5NJqp5StIUSUdHxKbxiXZ4IyZ6SbOAW4HpQADLIuK6al7C7H7z//Gbc/9nJZ/DrNEMNsfsJ+oUS60E8IikAG5MRqFML0lub5PPG2ky1Ocbak7hhkj05XTd9ADfiYjjgBOAC5JL0Ka8hDGzqvlURHyM/O/8BZI+U/pm0rpN7XzQzfT5Rkz0EbGp0CKPiN3Ay+T/Ui0mf+lC8vrFZLl4CRMRTwFTJB1d9cjNGlvq55iNiI3J6xbgPvLdVZsLv+/J65b6RVgTQ32+hv7/HtXN2KS/6qPA04z+EsbsUPIMMEfSbEmt5KfeXFnnmKpGUrukSYVl4ETgRfKfcUmy2RLg/vpEWDNDfb6VwFnKOwHY2Sj98zCKm7GSJpKfJvDbEbEr3xWfFxGR9NOVTdJS8l07ZqkTET2SLgQeJj+8cnky72xaTAfuS/JADrg9Ih6S9AxwdzL/7pvAqXWMsSLJnMILgGmSNgBXAFcx+Od7kPx9yfXk702eM+4BD0P5bqYRNpJagAeAhyPimqRsHbAgIjYllzBPRMSxkm5Mlu/ov90wx2+Kfi5ras82yyNlzaptxK6bZBTNzcDLhSSfaMpLGDOzQ82ILXpJnwKeBF4A+pLi75Hvp78beC/JJUxEbEv+MFxP/osGncA5EbFqhHO4RW+15ha9HbLK6rqpeRBO9FZ7TvR2yPIjEMzMUs6J3sws5ZzozcxSzonezCzlGuXplVuBvclrM5pG88YOzR1/ubG/r9aBmDWqhhh1AyBpVbOOimjm2KG542/m2M3Gi7tuzMxSzonezCzlGinRL6t3ABVo5tihueNv5tjNxkXD9NGbmVltNFKL3szMaqDuiV7SIknrJK2X1BQzxkt6Q9ILktZIWpWUHSXpUUl/Tl6PrHecBZKWS9oi6cWSskHjTZ46+m/J/8fzkj5Wv8iHjP1KSRuTn/8aSSeXvHd5Evs6Sf9Qn6jNGktdE72kLPAz8nNOHgecnsxH2ww+GxFzS4b2DTWHbiO4hfzTREs1y5y/tzAwdoBrk5//3Ih4ECCpO6cBH0z2+fekjpkd0urdop8PrI+I1yKiC7iT/JyzzWioOXTrLiJ+D2zrV9wUc/4OEftQFgN3RsSBiHid/Gw/82sWnFmTqHeib9b5ZQN4RNKzyZSIMPQcuo2q2ef8vTDpWlpe0k3WLLGbjat6J/pm9amI+Bj5bo4LJH2m9M3ID2VqmuFMzRYv+e6kvwXmApuAq+sbjlljq3ei3wjMKlmfmZQ1tIjYmLxuAe4j3z2wudDFkbxuqV+EZRkq3ob/P4mIzRHRGxF9wE38tXum4WM3q4d6J/pngDmSZktqJX8jbWWdYxqWpHZJkwrLwInAiww9h26jato5f/vdM/gS+Z8/5GM/TVKbpNnkbyj/cbzjM2s0dX16ZUT0SLoQeBjIAssj4qV6xlSG6cB9+alxyQG3R8RDkp4B7pZ0HskcunWM8SCS7gAWANMkbQCuAK5i8HgfBE4mfyOzEzhn3AMuMUTsCyTNJd/d9AbwDYCIeEnS3cBaoAe4ICJ66xG3WSPxN2PNzFKu3l03ZmZWY070ZmYp50RvZpZyTvRmZinnRG9mlnJO9GZmKedEb2aWck70ZmYp9/8Bm9qAMTRCeMMAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# visualise DQN preprocess\n", + "import gym\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "import numpy as np\n", + "import scipy.ndimage\n", + "\n", + "env = gym.make('BreakoutNoFrameskip-v4')\n", + "env.reset()\n", + "# accumulate 5 frames\n", + "frames = []\n", + "for i in range(5):\n", + " frame, reward, is_done, info = env.step(env.action_space.sample())\n", + " frames.append(frame)\n", + "print(frames[0].shape)\n", + "\n", + "plt.figure(1)\n", + "plt.subplot(121)\n", + "plt.imshow(frames[0])\n", + "plt.subplot(122)\n", + "\n", + "frame = np.maximum(frames[0], frames[1])\n", + "frame = np.divide(frame, 256)\n", + "print(np.amax(frame))\n", + "# greyscale\n", + "frame = np.dot(frame[...,:3], [0.299, 0.587, 0.114])\n", + "print(frame.shape)\n", + "# rescale\n", + "frame = scipy.ndimage.interpolation.zoom(frame, zoom = np.divide(128, frame.shape))\n", + "print(frame.shape)\n", + "\n", + "plt.imshow(frame, cmap = plt.get_cmap('gray'))\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": 12, @@ -47,13 +126,6 @@ "ax.legend(loc=4)\n", "plt.show()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/common/preprocess.py b/common/preprocess.py new file mode 100644 index 0000000..4bbe9fc --- /dev/null +++ b/common/preprocess.py @@ -0,0 +1,34 @@ +import numpy as np +import scipy.ndimage + + +def dqn_preprocess(frames, frames_to_stack=4, cropped_size=128.): + ''' + preprocess following dqn nature paper + 1. remove flickering: max(pixel) over previous frame + 2. Convert RGB to greyscale + 3. rescale to 84x84 + 4. do 1-4 for 4 frames and stack them + :param frames: last 5 frames of shape (210, 160, 3), we need 5 to perform the remove flickering step + :return: (84,84,4) + ''' + RGB = 256 + assert len(frames) == frames_to_stack + 1 + + processed_frames = [] + for i in range(1, frames_to_stack + 1): + current_frame = frames[i] + prev_frame = frames[i - 1] + # step 1 + frame = np.maximum(current_frame, prev_frame) + # step 2 + # first normalise + assert np.amax(frame) <= RGB + frame = np.divide(frame, RGB) + frame = np.dot(frame[..., :3], [0.299, 0.587, 0.114]) # (210, 160) + # step 3 + frame = scipy.ndimage.interpolation.zoom(frame, zoom=np.divide(cropped_size, frame.shape)) + + processed_frames.append(frame) + + return np.dstack(processed_frames) diff --git a/tests/test_proprocess.py b/tests/test_proprocess.py new file mode 100644 index 0000000..4a419e9 --- /dev/null +++ b/tests/test_proprocess.py @@ -0,0 +1,17 @@ +import unittest + +import gym + +from common.preprocess import dqn_preprocess + + +class TestPreprocess(unittest.TestCase): + def test_dqn_preprocess(self): + env = gym.make('BreakoutNoFrameskip-v0') + env.reset() + frames = [] + for i in range(5): + observation, _, _, _ = env.step(env.action_space.sample()) + frames.append(observation) + result = dqn_preprocess(frames) + assert result.shape == (128, 128, 4)