diff --git a/analysis.ipynb b/analysis.ipynb index 29dd725..bad0f33 100644 --- a/analysis.ipynb +++ b/analysis.ipynb @@ -1,5 +1,84 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/xiayunsun/code/venvs/py367/lib/python3.6/site-packages/gym/envs/registration.py:14: PkgResourcesDeprecationWarning: Parameters to load are deprecated. Call .resolve and .require separately.\n", + " result = entry_point.load(False)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(210, 160, 3)\n", + "0.78125\n", + "(210, 160)\n", + "(128, 128)\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
import numpy as np
import scipy.ndimage


def dqn_preprocess(frames, frames_to_stack=4, cropped_size=128.):
    """Preprocess raw frames in the style of the DQN Nature paper.

    For each of the last ``frames_to_stack`` frames:

    1. Remove flickering: take the element-wise maximum of the frame and
       the frame immediately before it.
    2. Normalise by 256 and convert RGB to greyscale using the weights
       (0.299, 0.587, 0.114).
    3. Rescale to ``cropped_size`` x ``cropped_size``.

    The processed frames are then stacked along a new last axis.

    :param frames: indexable sequence of ``frames_to_stack + 1`` frames with
        the colour channels on the last axis (e.g. Atari frames of shape
        (210, 160, 3)). One extra leading frame is required so that step 1
        can be applied to the first stacked frame.
    :param frames_to_stack: number of processed frames in the output.
    :param cropped_size: side length of each rescaled frame. Despite the
        name, frames are resized (spline interpolation), not cropped.
    :return: array of shape (cropped_size, cropped_size, frames_to_stack);
        (128, 128, 4) with the defaults.
    :raises ValueError: if ``len(frames) != frames_to_stack + 1`` or a pixel
        value exceeds 256.
    """
    # Values are divided by 256 (not 255), so 8-bit input maps into [0, 1).
    pixel_scale = 256
    if len(frames) != frames_to_stack + 1:
        raise ValueError(
            "expected %d frames (frames_to_stack + 1), got %d"
            % (frames_to_stack + 1, len(frames))
        )

    # Loop-invariant greyscale weights for the R, G, B channels.
    grey_weights = np.array([0.299, 0.587, 0.114])

    processed_frames = []
    for i in range(1, frames_to_stack + 1):
        # step 1: max over the current and previous frame removes flicker
        frame = np.maximum(frames[i], frames[i - 1])
        if np.amax(frame) > pixel_scale:
            raise ValueError(
                "pixel values must not exceed %d, got %s"
                % (pixel_scale, np.amax(frame))
            )
        # step 2: normalise, then collapse the first three channels to grey
        frame = np.divide(frame, pixel_scale)
        frame = np.dot(frame[..., :3], grey_weights)  # (height, width)
        # step 3: per-axis zoom factors bring both axes to cropped_size.
        # scipy.ndimage.zoom is the public entry point; the
        # scipy.ndimage.interpolation namespace is deprecated.
        frame = scipy.ndimage.zoom(
            frame, zoom=np.divide(cropped_size, frame.shape)
        )
        processed_frames.append(frame)

    return np.dstack(processed_frames)