update to torch 0.4

This commit is contained in:
Morvan Zhou
2018-05-30 01:39:53 +08:00
parent 7e7c9bb383
commit 921b69a582
15 changed files with 82 additions and 104 deletions

View File

@@ -4,13 +4,12 @@ My Youtube Channel: https://www.youtube.com/user/MorvanZhou
More about Reinforcement learning: https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/
Dependencies:
torch: 0.3
torch: 0.4
gym: 0.8.1
numpy
"""
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import gym
@@ -55,7 +54,7 @@ class DQN(object):
self.loss_func = nn.MSELoss()
def choose_action(self, x):
x = Variable(torch.unsqueeze(torch.FloatTensor(x), 0))
x = torch.unsqueeze(torch.FloatTensor(x), 0)
# input only one sample
if np.random.uniform() < EPSILON: # greedy
actions_value = self.eval_net.forward(x)
@@ -82,10 +81,10 @@ class DQN(object):
# sample batch transitions
sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
b_memory = self.memory[sample_index, :]
b_s = Variable(torch.FloatTensor(b_memory[:, :N_STATES]))
b_a = Variable(torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int)))
b_r = Variable(torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2]))
b_s_ = Variable(torch.FloatTensor(b_memory[:, -N_STATES:]))
b_s = torch.FloatTensor(b_memory[:, :N_STATES])
b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int))
b_r = torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2])
b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])
# q_eval w.r.t the action in experience
q_eval = self.eval_net(b_s).gather(1, b_a) # shape (batch, 1)