update to torch 0.4
@@ -4,13 +4,12 @@ My Youtube Channel: https://www.youtube.com/user/MorvanZhou
 More about Reinforcement learning: https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/
 
 Dependencies:
-torch: 0.3
+torch: 0.4
 gym: 0.8.1
 numpy
 """
 import torch
 import torch.nn as nn
-from torch.autograd import Variable
 import torch.nn.functional as F
 import numpy as np
 import gym
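Reviewer note: dropping this import reflects PyTorch 0.4's merge of Variable into torch.Tensor, so plain tensors now carry autograd state themselves. A minimal sketch of that behavior (illustration only, not part of this diff):

import torch

x = torch.ones(3, requires_grad=True)  # since 0.4, a plain Tensor can track gradients
y = (x * 2).sum()
y.backward()
print(x.grad)                          # tensor([2., 2., 2.])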
@@ -55,7 +54,7 @@ class DQN(object):
         self.loss_func = nn.MSELoss()
 
     def choose_action(self, x):
-        x = Variable(torch.unsqueeze(torch.FloatTensor(x), 0))
+        x = torch.unsqueeze(torch.FloatTensor(x), 0)
         # input only one sample
         if np.random.uniform() < EPSILON:   # greedy
             actions_value = self.eval_net.forward(x)
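Reviewer note: with the Variable wrapper gone, choose_action still needs the unsqueeze to turn a single observation into a batch of one, the shape the network's linear layers expect. A quick sketch assuming a 4-dimensional state such as CartPole's:

import torch

obs = torch.FloatTensor([0.1, -0.2, 0.05, 0.3])  # hypothetical single observation
batched = torch.unsqueeze(obs, 0)
print(obs.shape, batched.shape)                  # torch.Size([4]) torch.Size([1, 4])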
@@ -82,10 +81,10 @@ class DQN(object):
         # sample batch transitions
         sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
         b_memory = self.memory[sample_index, :]
-        b_s = Variable(torch.FloatTensor(b_memory[:, :N_STATES]))
-        b_a = Variable(torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int)))
-        b_r = Variable(torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2]))
-        b_s_ = Variable(torch.FloatTensor(b_memory[:, -N_STATES:]))
+        b_s = torch.FloatTensor(b_memory[:, :N_STATES])
+        b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int))
+        b_r = torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2])
+        b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])
 
         # q_eval w.r.t the action in experience
         q_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1)
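Reviewer note: the unwrapped batch tensors feed the unchanged gather call, which selects each row's Q-value at the action actually taken. A small sketch with made-up numbers (batch of 3, 2 actions):

import torch

q_all = torch.tensor([[1.0, 5.0],
                      [2.0, 0.5],
                      [4.0, 3.0]])    # stand-in for eval_net output: Q-values per action
b_a = torch.tensor([[1], [0], [0]])   # stored actions, shape (batch, 1)
q_eval = q_all.gather(1, b_a)         # each row picks column b_a -> [[5.0], [2.0], [4.0]]
print(q_eval)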