{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 405 DQN Reinforcement Learning\n",
    "\n",
    "For more, visit my tutorial page: https://morvanzhou.github.io/tutorials/\n",
    "\n",
    "My YouTube Channel: https://www.youtube.com/user/MorvanZhou\n",
    "\n",
    "More about reinforcement learning: https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/\n",
    "\n",
    "Dependencies:\n",
    "* torch: 0.1.11\n",
    "* gym: 0.8.1\n",
    "* numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import gym"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[2017-06-20 22:23:40,418] Making new env: CartPole-v0\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameters\n",
    "BATCH_SIZE = 32\n",
    "LR = 0.01                   # learning rate\n",
    "EPSILON = 0.9               # probability of choosing the greedy action\n",
    "GAMMA = 0.9                 # reward discount\n",
    "TARGET_REPLACE_ITER = 100   # target network update frequency (in learning steps)\n",
    "MEMORY_CAPACITY = 2000      # replay memory size (transitions)\n",
    "env = gym.make('CartPole-v0')\n",
    "env = env.unwrapped         # remove the TimeLimit wrapper so episodes are not cut off at 200 steps\n",
    "N_ACTIONS = env.action_space.n\n",
    "N_STATES = env.observation_space.shape[0]"
   ]
  },
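  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the values read from the environment: CartPole-v0 has a 4-dimensional observation (cart position, cart velocity, pole angle, pole angular velocity) and 2 discrete actions (push left / push right), so `N_STATES` and `N_ACTIONS` should come out as 4 and 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Sanity check of the quantities read from the environment above.\n",
    "# For CartPole-v0 this should print: 4 states, 2 actions.\n",
    "print(N_STATES, 'states,', N_ACTIONS, 'actions')"
   ]
  },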
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(N_STATES, 10)\n",
    "        self.fc1.weight.data.normal_(0, 0.1)   # initialization\n",
    "        self.out = nn.Linear(10, N_ACTIONS)\n",
    "        self.out.weight.data.normal_(0, 0.1)   # initialization\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        actions_value = self.out(x)\n",
    "        return actions_value"
   ]
  },
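  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustration, the network can be probed with a dummy state: `forward` maps a `(batch, N_STATES)` input through one hidden ReLU layer of 10 units to a `(batch, N_ACTIONS)` output, i.e. one Q-value estimate per action. The sketch below (the names `probe_net` and `probe_x` are ad-hoc, not part of the tutorial) uses the same old-style `Variable` API as the rest of this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Minimal sketch: push one all-zero CartPole state through an untrained Net\n",
    "# and confirm the output holds one (still meaningless) Q-value per action.\n",
    "probe_net = Net()\n",
    "probe_x = Variable(torch.zeros(1, N_STATES))\n",
    "print(probe_net(probe_x).size())   # expected: torch.Size([1, 2]) for CartPole-v0"
   ]
  },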
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class DQN(object):\n",
    "    def __init__(self):\n",
    "        self.eval_net, self.target_net = Net(), Net()\n",
    "\n",
    "        self.learn_step_counter = 0                                   # for target updating\n",
    "        self.memory_counter = 0                                       # for storing memory\n",
    "        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))   # initialize memory\n",
    "        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)\n",
    "        self.loss_func = nn.MSELoss()\n",
    "\n",
    "    def choose_action(self, x):\n",
    "        x = Variable(torch.unsqueeze(torch.FloatTensor(x), 0))   # input only one sample\n",
    "        if np.random.uniform() < EPSILON:   # greedy\n",
    "            actions_value = self.eval_net.forward(x)\n",
    "            action = torch.max(actions_value, 1)[1].data.numpy()[0, 0]   # return the argmax\n",
    "        else:   # random\n",
    "            action = np.random.randint(0, N_ACTIONS)\n",
    "        return action\n",
    "\n",
    "    def store_transition(self, s, a, r, s_):\n",
    "        transition = np.hstack((s, [a, r], s_))\n",
    "        # replace the old memory with new memory\n",
    "        index = self.memory_counter % MEMORY_CAPACITY\n",
    "        self.memory[index, :] = transition\n",
    "        self.memory_counter += 1\n",
    "\n",
    "    def learn(self):\n",
    "        # target parameter update\n",
    "        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:\n",
    "            self.target_net.load_state_dict(self.eval_net.state_dict())\n",
    "        self.learn_step_counter += 1\n",
    "\n",
    "        # sample batch transitions\n",
    "        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)\n",
    "        b_memory = self.memory[sample_index, :]\n",
    "        b_s = Variable(torch.FloatTensor(b_memory[:, :N_STATES]))\n",
    "        b_a = Variable(torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int)))\n",
    "        b_r = Variable(torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2]))\n",
    "        b_s_ = Variable(torch.FloatTensor(b_memory[:, -N_STATES:]))\n",
    "\n",
    "        # q_eval w.r.t. the action in experience\n",
    "        q_eval = self.eval_net(b_s).gather(1, b_a)    # shape (batch, 1)\n",
    "        q_next = self.target_net(b_s_).detach()       # detach from graph, don't backpropagate\n",
    "        q_target = b_r + GAMMA * q_next.max(1)[0]     # shape (batch, 1)\n",
    "        loss = self.loss_func(q_eval, q_target)\n",
    "\n",
    "        self.optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        self.optimizer.step()"
   ]
  },
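  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the `DQN` class above. Action selection in `choose_action` is $\\epsilon$-greedy: with probability `EPSILON` (0.9) the action with the largest Q-value under `eval_net` is taken, otherwise a uniformly random action. In `learn()`, the value `q_eval` of the stored action is regressed towards the standard Q-learning bootstrap target\n",
    "\n",
    "$$y = r + \\gamma \\max_{a'} Q_{\\text{target}}(s', a'),$$\n",
    "\n",
    "where $Q_{\\text{target}}$ is `target_net`, a copy of `eval_net` that is only refreshed every `TARGET_REPLACE_ITER` learning steps; `detach()` keeps gradients from flowing into it, so the optimizer updates `eval_net` alone."
   ]
  },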
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "dqn = DQN()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Collecting experience...\n",
      "Ep: 201 | Ep_r: 1.59\n",
      "Ep: 202 | Ep_r: 4.18\n",
      "Ep: 203 | Ep_r: 2.73\n",
      "Ep: 204 | Ep_r: 1.97\n",
      "Ep: 205 | Ep_r: 1.18\n",
      "Ep: 206 | Ep_r: 0.86\n",
      "Ep: 207 | Ep_r: 2.88\n",
      "Ep: 208 | Ep_r: 1.63\n",
      "Ep: 209 | Ep_r: 3.91\n",
      "Ep: 210 | Ep_r: 3.6\n",
      "Ep: 211 | Ep_r: 0.98\n",
      "Ep: 212 | Ep_r: 3.85\n",
      "Ep: 213 | Ep_r: 1.81\n",
      "Ep: 214 | Ep_r: 2.32\n",
      "Ep: 215 | Ep_r: 3.75\n",
      "Ep: 216 | Ep_r: 3.53\n",
      "Ep: 217 | Ep_r: 4.75\n",
      "Ep: 218 | Ep_r: 2.4\n",
      "Ep: 219 | Ep_r: 0.64\n",
      "Ep: 220 | Ep_r: 1.15\n",
      "Ep: 221 | Ep_r: 2.3\n",
      "Ep: 222 | Ep_r: 7.37\n",
      "Ep: 223 | Ep_r: 1.25\n",
      "Ep: 224 | Ep_r: 5.02\n",
      "Ep: 225 | Ep_r: 10.29\n",
      "Ep: 226 | Ep_r: 17.54\n",
      "Ep: 227 | Ep_r: 36.2\n",
      "Ep: 228 | Ep_r: 6.61\n",
      "Ep: 229 | Ep_r: 10.04\n",
      "Ep: 230 | Ep_r: 55.19\n",
      "Ep: 231 | Ep_r: 10.03\n",
      "Ep: 232 | Ep_r: 13.25\n",
      "Ep: 233 | Ep_r: 8.75\n",
      "Ep: 234 | Ep_r: 3.83\n",
      "Ep: 235 | Ep_r: -0.92\n",
      "Ep: 236 | Ep_r: 5.12\n",
      "Ep: 237 | Ep_r: 3.56\n",
      "Ep: 238 | Ep_r: 5.69\n",
      "Ep: 239 | Ep_r: 8.43\n",
      "Ep: 240 | Ep_r: 29.27\n",
      "Ep: 241 | Ep_r: 17.95\n",
      "Ep: 242 | Ep_r: 44.77\n",
      "Ep: 243 | Ep_r: 98.0\n",
      "Ep: 244 | Ep_r: 38.78\n",
      "Ep: 245 | Ep_r: 45.02\n",
      "Ep: 246 | Ep_r: 27.73\n",
      "Ep: 247 | Ep_r: 36.96\n",
      "Ep: 248 | Ep_r: 48.98\n",
      "Ep: 249 | Ep_r: 111.36\n",
      "Ep: 250 | Ep_r: 95.61\n",
      "Ep: 251 | Ep_r: 149.77\n",
      "Ep: 252 | Ep_r: 29.96\n",
      "Ep: 253 | Ep_r: 2.79\n",
      "Ep: 254 | Ep_r: 20.1\n",
      "Ep: 255 | Ep_r: 24.25\n",
      "Ep: 256 | Ep_r: 3074.75\n",
      "Ep: 257 | Ep_r: 1258.49\n",
      "Ep: 258 | Ep_r: 127.39\n",
      "Ep: 259 | Ep_r: 283.46\n",
      "Ep: 260 | Ep_r: 166.96\n",
      "Ep: 261 | Ep_r: 101.71\n",
      "Ep: 262 | Ep_r: 63.45\n",
      "Ep: 263 | Ep_r: 288.94\n",
      "Ep: 264 | Ep_r: 130.49\n",
      "Ep: 265 | Ep_r: 207.05\n",
      "Ep: 266 | Ep_r: 183.71\n",
      "Ep: 267 | Ep_r: 142.75\n",
      "Ep: 268 | Ep_r: 126.53\n",
      "Ep: 269 | Ep_r: 310.79\n",
      "Ep: 270 | Ep_r: 863.2\n",
      "Ep: 271 | Ep_r: 365.12\n",
      "Ep: 272 | Ep_r: 659.52\n",
      "Ep: 273 | Ep_r: 103.98\n",
      "Ep: 274 | Ep_r: 554.83\n",
      "Ep: 275 | Ep_r: 246.01\n",
      "Ep: 276 | Ep_r: 332.23\n",
      "Ep: 277 | Ep_r: 323.35\n",
      "Ep: 278 | Ep_r: 278.71\n",
      "Ep: 279 | Ep_r: 613.6\n",
      "Ep: 280 | Ep_r: 152.21\n",
      "Ep: 281 | Ep_r: 402.02\n",
      "Ep: 282 | Ep_r: 351.4\n",
      "Ep: 283 | Ep_r: 115.87\n",
      "Ep: 284 | Ep_r: 163.26\n",
      "Ep: 285 | Ep_r: 631.0\n",
      "Ep: 286 | Ep_r: 263.47\n",
      "Ep: 287 | Ep_r: 511.21\n",
      "Ep: 288 | Ep_r: 337.18\n",
      "Ep: 289 | Ep_r: 819.76\n",
      "Ep: 290 | Ep_r: 190.83\n",
      "Ep: 291 | Ep_r: 442.98\n",
      "Ep: 292 | Ep_r: 537.24\n",
      "Ep: 293 | Ep_r: 1101.12\n",
      "Ep: 294 | Ep_r: 178.42\n",
      "Ep: 295 | Ep_r: 225.61\n",
      "Ep: 296 | Ep_r: 252.62\n",
      "Ep: 297 | Ep_r: 617.5\n",
      "Ep: 298 | Ep_r: 617.8\n",
      "Ep: 299 | Ep_r: 244.01\n",
      "Ep: 300 | Ep_r: 687.91\n",
      "Ep: 301 | Ep_r: 618.51\n",
      "Ep: 302 | Ep_r: 1405.07\n",
      "Ep: 303 | Ep_r: 456.95\n",
      "Ep: 304 | Ep_r: 340.33\n",
      "Ep: 305 | Ep_r: 502.91\n",
      "Ep: 306 | Ep_r: 441.21\n",
      "Ep: 307 | Ep_r: 255.81\n",
      "Ep: 308 | Ep_r: 403.03\n",
      "Ep: 309 | Ep_r: 229.1\n",
      "Ep: 310 | Ep_r: 308.49\n",
      "Ep: 311 | Ep_r: 165.37\n",
      "Ep: 312 | Ep_r: 153.76\n",
      "Ep: 313 | Ep_r: 442.05\n",
      "Ep: 314 | Ep_r: 229.23\n",
      "Ep: 315 | Ep_r: 128.52\n",
      "Ep: 316 | Ep_r: 358.18\n",
      "Ep: 317 | Ep_r: 319.03\n",
      "Ep: 318 | Ep_r: 381.76\n",
      "Ep: 319 | Ep_r: 199.19\n",
      "Ep: 320 | Ep_r: 418.63\n",
      "Ep: 321 | Ep_r: 223.95\n",
      "Ep: 322 | Ep_r: 222.37\n",
      "Ep: 323 | Ep_r: 405.4\n",
      "Ep: 324 | Ep_r: 311.32\n",
      "Ep: 325 | Ep_r: 184.85\n",
      "Ep: 326 | Ep_r: 1026.71\n",
      "Ep: 327 | Ep_r: 252.41\n",
      "Ep: 328 | Ep_r: 224.93\n",
      "Ep: 329 | Ep_r: 620.02\n",
      "Ep: 330 | Ep_r: 174.54\n",
      "Ep: 331 | Ep_r: 782.45\n",
      "Ep: 332 | Ep_r: 263.79\n",
      "Ep: 333 | Ep_r: 178.63\n",
      "Ep: 334 | Ep_r: 242.84\n",
      "Ep: 335 | Ep_r: 635.43\n",
      "Ep: 336 | Ep_r: 668.89\n",
      "Ep: 337 | Ep_r: 265.42\n",
      "Ep: 338 | Ep_r: 207.81\n",
      "Ep: 339 | Ep_r: 293.09\n",
      "Ep: 340 | Ep_r: 530.23\n",
      "Ep: 341 | Ep_r: 479.26\n",
      "Ep: 342 | Ep_r: 559.77\n",
      "Ep: 343 | Ep_r: 241.39\n",
      "Ep: 344 | Ep_r: 158.83\n",
      "Ep: 345 | Ep_r: 1510.69\n",
      "Ep: 346 | Ep_r: 425.17\n",
      "Ep: 347 | Ep_r: 266.94\n",
      "Ep: 348 | Ep_r: 166.08\n",
      "Ep: 349 | Ep_r: 630.52\n",
      "Ep: 350 | Ep_r: 250.95\n",
      "Ep: 351 | Ep_r: 625.88\n",
      "Ep: 352 | Ep_r: 417.7\n",
      "Ep: 353 | Ep_r: 867.81\n",
      "Ep: 354 | Ep_r: 150.62\n",
      "Ep: 355 | Ep_r: 230.89\n",
      "Ep: 356 | Ep_r: 1017.52\n",
      "Ep: 357 | Ep_r: 190.28\n",
      "Ep: 358 | Ep_r: 396.91\n",
      "Ep: 359 | Ep_r: 305.53\n",
      "Ep: 360 | Ep_r: 131.61\n",
      "Ep: 361 | Ep_r: 387.54\n",
      "Ep: 362 | Ep_r: 298.82\n",
      "Ep: 363 | Ep_r: 207.56\n",
      "Ep: 364 | Ep_r: 248.56\n",
      "Ep: 365 | Ep_r: 589.12\n",
      "Ep: 366 | Ep_r: 179.52\n",
      "Ep: 367 | Ep_r: 130.19\n",
      "Ep: 368 | Ep_r: 1220.84\n",
      "Ep: 369 | Ep_r: 126.35\n",
      "Ep: 370 | Ep_r: 133.31\n",
      "Ep: 371 | Ep_r: 485.81\n",
      "Ep: 372 | Ep_r: 823.4\n",
      "Ep: 373 | Ep_r: 253.26\n",
      "Ep: 374 | Ep_r: 466.06\n",
      "Ep: 375 | Ep_r: 203.27\n",
      "Ep: 376 | Ep_r: 386.5\n",
      "Ep: 377 | Ep_r: 491.02\n",
      "Ep: 378 | Ep_r: 239.45\n",
      "Ep: 379 | Ep_r: 276.93\n",
      "Ep: 380 | Ep_r: 331.98\n",
      "Ep: 381 | Ep_r: 764.79\n",
      "Ep: 382 | Ep_r: 198.29\n",
      "Ep: 383 | Ep_r: 717.18\n",
      "Ep: 384 | Ep_r: 562.15\n",
      "Ep: 385 | Ep_r: 29.44\n",
      "Ep: 386 | Ep_r: 344.95\n",
      "Ep: 387 | Ep_r: 671.87\n",
      "Ep: 388 | Ep_r: 299.81\n",
      "Ep: 389 | Ep_r: 899.76\n",
      "Ep: 390 | Ep_r: 319.04\n",
      "Ep: 391 | Ep_r: 252.11\n",
      "Ep: 392 | Ep_r: 865.62\n",
      "Ep: 393 | Ep_r: 255.64\n",
      "Ep: 394 | Ep_r: 81.74\n",
      "Ep: 395 | Ep_r: 213.13\n",
      "Ep: 396 | Ep_r: 422.33\n",
      "Ep: 397 | Ep_r: 167.47\n",
      "Ep: 398 | Ep_r: 507.34\n",
      "Ep: 399 | Ep_r: 614.0\n"
     ]
    }
   ],
   "source": [
    "print('\\nCollecting experience...')\n",
    "for i_episode in range(400):\n",
    "    s = env.reset()\n",
    "    ep_r = 0\n",
    "    while True:\n",
    "        env.render()\n",
    "        a = dqn.choose_action(s)\n",
    "\n",
    "        # take action\n",
    "        s_, r, done, info = env.step(a)\n",
    "\n",
    "        # modify the reward\n",
    "        x, x_dot, theta, theta_dot = s_\n",
    "        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8\n",
    "        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5\n",
    "        r = r1 + r2\n",
    "\n",
    "        dqn.store_transition(s, a, r, s_)\n",
    "\n",
    "        ep_r += r\n",
    "        # start learning (and reporting) once the replay memory has been filled\n",
    "        if dqn.memory_counter > MEMORY_CAPACITY:\n",
    "            dqn.learn()\n",
    "            if done:\n",
    "                print('Ep: ', i_episode,\n",
    "                      '| Ep_r: ', round(ep_r, 2))\n",
    "\n",
    "        if done:\n",
    "            break\n",
    "        s = s_"
   ]
  },
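  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the shaped reward in the loop above. Instead of the environment's constant +1 per surviving step, the reward is recomputed from the next state: `r1` is largest when the cart sits at the centre of the track and `r2` is largest when the pole is upright,\n",
    "\n",
    "$$r_1 = \\frac{x_{\\mathrm{threshold}} - |x|}{x_{\\mathrm{threshold}}} - 0.8, \\qquad r_2 = \\frac{\\theta_{\\mathrm{threshold}} - |\\theta|}{\\theta_{\\mathrm{threshold}}} - 0.5,$$\n",
    "\n",
    "so `r = r1 + r2` runs from roughly $-1.3$ near the failure thresholds up to $0.7$ in the ideal pose. The intent is a denser, more informative signal than the raw reward, and it also means the episode returns printed in the log are not simple step counts."
   ]
  },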
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}