Working on the off-policy test; the other parts of dqn_replay are runnable, but performance is not tested.
parent 24d75fd1aa
commit e68dcd3c64
@@ -1,6 +1,4 @@
#!/usr/bin/env python
from __future__ import absolute_import

import tensorflow as tf
import gym
import numpy as np
@@ -10,11 +8,9 @@ import time
import sys
sys.path.append('..')
from tianshou.core import losses
# from tianshou.data.batch import Batch
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.dqn as policy # TODO: fix imports as zhusuan so that only need to import to policy
import tianshou.core.policy.dqn as policy
import tianshou.core.value_function.action_value as value_function
import sys

from tianshou.data.replay_buffer.vanilla import VanillaReplayBuffer
from tianshou.data.data_collector import DataCollector
@@ -25,14 +21,6 @@ if __name__ == '__main__':
    observation_dim = env.observation_space.shape
    action_dim = env.action_space.n

    clip_param = 0.2
    num_batches = 10
    batch_size = 512

    seed = 0
    np.random.seed(seed)
    tf.set_random_seed(seed)

    ### 1. build network with pure tf
    observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)

@@ -45,7 +33,7 @@ if __name__ == '__main__':
        return None, action_values # no policy head

    ### 2. build policy, loss, optimizer
    dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, weight_update=100)
    dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, weight_update=200)
    pi = policy.DQN(dqn)

    dqn_loss = losses.qlearning(dqn)
@@ -69,6 +57,17 @@ if __name__ == '__main__':
    )

    ### 4. start training
    # hyper-parameters
    batch_size = 256
    replay_buffer_warmup = 1000
    epsilon_decay_interval = 200
    epsilon = 0.3
    test_interval = 1000

    seed = 0
    np.random.seed(seed)
    tf.set_random_seed(seed)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
@@ -78,12 +77,11 @@ if __name__ == '__main__':
        pi.sync_weights() # TODO: automate this for policies with target network

        start_time = time.time()
        epsilon = 0.5
        pi.set_epsilon_train(epsilon)
        data_collector.collect(num_timesteps=int(1e3)) # warm-up
        data_collector.collect(num_timesteps=replay_buffer_warmup) # warm-up
        for i in range(int(1e8)): # number of training steps
            # anneal epsilon step-wise
            if (i + 1) % 1e4 == 0 and epsilon > 0.1:
            if (i + 1) % epsilon_decay_interval == 0 and epsilon > 0.1:
                epsilon -= 0.1
                pi.set_epsilon_train(epsilon)

@@ -91,15 +89,13 @@ if __name__ == '__main__':
            data_collector.collect()

            # update network
            for _ in range(num_batches):
                feed_dict = data_collector.next_batch(batch_size)
                sess.run(train_op, feed_dict=feed_dict)

            print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
                feed_dict = data_collector.next_batch(batch_size)
                sess.run(train_op, feed_dict=feed_dict)

            # test every 1000 training steps
            # tester could share some code with batch!
            if i % 1000 == 0:
            if i % test_interval == 0:
                print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
                # epsilon 0.05 as in nature paper
                pi.set_epsilon_test(0.05)
                #test(env, pi) # go for act_test of pi, not act
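For reference, a standalone sketch of the step-wise epsilon annealing used in the training loop above (the 0.1 floor comes from the `epsilon > 0.1` guard; the helper itself is illustrative, not part of the repository):

def annealed_epsilon(step, start=0.3, floor=0.1, decay=0.1, interval=200):
    """Step-wise schedule: drop epsilon by `decay` every `interval` steps, never below `floor`."""
    drops = (step + 1) // interval
    return max(floor, start - decay * drops)

# e.g. annealed_epsilon(0) -> 0.3, annealed_epsilon(199) -> ~0.2, annealed_epsilon(10**6) -> 0.1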
@@ -69,14 +69,3 @@ def qlearning(action_value_function):

    q_value = action_value_function.value_tensor
    return tf.losses.mean_squared_error(target_value_ph, q_value)


def deterministic_policy_gradient(sampled_state, critic):
    """
    deterministic policy gradient:

    :param sampled_action: placeholder of sampled actions during the interaction with the environment
    :param critic: current `value` function
    :return:
    """
    return tf.reduce_mean(critic.get_value(sampled_state))
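For intuition, the `qlearning` loss above regresses predicted Q-values toward a bootstrapped target; a minimal numpy sketch of a one-step target and its squared error (all numbers and names are illustrative, not repository code):

import numpy as np

# hypothetical numbers: reward r, discount gamma, target-network Q-values at the next state
r, gamma = 1.0, 0.99
q_next = np.array([0.2, 0.7, 0.4])       # Q_target(s', a) for each action
target = r + gamma * q_next.max()        # 1.0 + 0.99 * 0.7 = 1.693
q_pred = 1.5                             # Q(s, a) for the action actually taken
squared_error = (target - q_pred) ** 2   # what the mean_squared_error loss averages over a batch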
@@ -30,8 +30,10 @@ class DQN(PolicyBase):
        feed_dict = {self.action_value._observation_placeholder: observation[None]}
        feed_dict.update(my_feed_dict)
        action = sess.run(self._argmax_action, feed_dict=feed_dict)

        # epsilon_greedy
        if np.random.rand() < self.epsilon_train:
            pass
            action = np.random.randint(self.action_value.num_actions)

        if self.weight_update > 0:
            self.interaction_count += 1
@@ -39,7 +41,23 @@ class DQN(PolicyBase):
        return np.squeeze(action)

    def act_test(self, observation, my_feed_dict={}):
        pass
        sess = tf.get_default_session()
        if self.weight_update > 1:
            if self.interaction_count % self.weight_update == 0:
                self.update_weights()

        feed_dict = {self.action_value._observation_placeholder: observation[None]}
        feed_dict.update(my_feed_dict)
        action = sess.run(self._argmax_action, feed_dict=feed_dict)

        # epsilon_greedy
        if np.random.rand() < self.epsilon_test:
            action = np.random.randint(self.action_value.num_actions)

        if self.weight_update > 0:
            self.interaction_count += 1

        return np.squeeze(action)

    @property
    def q_net(self):
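The epsilon-greedy branch in `act` and `act_test` is easy to check in isolation; a minimal numpy sketch (illustrative helper, not repository code):

import numpy as np

def epsilon_greedy(q_values, epsilon, rng=np.random):
    """With probability epsilon pick a uniformly random action, otherwise the argmax action."""
    if rng.rand() < epsilon:
        return rng.randint(len(q_values))
    return int(np.argmax(q_values))

# e.g. epsilon_greedy(np.array([0.2, 0.7, 0.4]), epsilon=0.0) -> 1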
@@ -114,6 +114,8 @@ class DQN(ValueFunctionBase):

        self._value_tensor_all_actions = value_tensor

        self.num_actions = value_tensor.shape.as_list()[-1]

        batch_size = tf.shape(value_tensor)[0]
        batch_dim_index = tf.range(batch_size)
        indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
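The `indices` tensor above pairs each batch row with the action taken in it so the corresponding Q-value can be gathered; in numpy the same selection is plain advanced indexing (illustrative sketch, not repository code):

import numpy as np

q_all = np.array([[0.2, 0.7, 0.4],
                  [0.1, 0.3, 0.9]])                  # [batch_size, num_actions]
actions = np.array([1, 2])                           # action taken for each row
q_taken = q_all[np.arange(len(actions)), actions]    # -> array([0.7, 0.9])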
@@ -1,5 +1,4 @@
import logging
import tensorflow as tf
import numpy as np

STATE = 0
@@ -105,12 +104,12 @@ class nstep_q_return:
    """
    compute the n-step return for Q-learning targets
    """
    def __init__(self, n, action_value, use_target_network=True):
    def __init__(self, n, action_value, use_target_network=True, discount_factor=0.99):
        self.n = n
        self.action_value = action_value
        self.use_target_network = use_target_network
        self.discount_factor = discount_factor

    # TODO : we should transfer the tf -> numpy/python -> tf into a monolithic compute graph in tf
    def __call__(self, buffer, indexes=None):
        """
        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
@@ -118,41 +117,39 @@ class nstep_q_return:
        each episode.
        :return: dict with key 'return' and value the computed returns corresponding to `index`.
        """
        qvalue = self.action_value._value_tensor_all_actions
        indexes = indexes or buffer.index
        episodes = buffer.data
        discount_factor = 0.99
        returns = []

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            for episode_index in range(len(indexes)):
                index = indexes[episode_index]
                if index:
                    episode = episodes[episode_index]
                    episode_q = []
        for episode_index in range(len(indexes)):
            index = indexes[episode_index]
            if index:
                episode = episodes[episode_index]
                episode_q = []

                    for i in index:
                        current_discount_factor = 1
                        last_frame_index = i
                        target_q = episode[i][REWARD]
                        for lfi in range(i, min(len(episode), i + self.n + 1)):
                            if episode[lfi][DONE]:
                                break
                            target_q += current_discount_factor * episode[lfi][REWARD]
                            current_discount_factor *= discount_factor
                            last_frame_index = lfi
                        if last_frame_index > i:
                            state = episode[last_frame_index][STATE]
                            # the shape of qpredict is [batch_size, action_dimension]
                            qpredict = sess.run(qvalue, feed_dict={self.action_value.managed_placeholders['observation']:
                                                                   state.reshape(1, state.shape[0])})
                            target_q += current_discount_factor * max(qpredict[0])
                        episode_q.append(target_q)
                for i in index:
                    current_discount_factor = 1
                    last_frame_index = i
                    target_q = episode[i][REWARD]
                    for lfi in range(i, min(len(episode), i + self.n + 1)):
                        if episode[lfi][DONE]:
                            break
                        target_q += current_discount_factor * episode[lfi][REWARD]
                        current_discount_factor *= self.discount_factor
                        last_frame_index = lfi
                    if last_frame_index > i:
                        state = episode[last_frame_index][STATE]

                        if self.use_target_network:
                            # [None] adds one dimension to the beginning
                            qpredict = self.action_value.eval_value_all_actions_old(state[None])
                        else:
                            qpredict = self.action_value.eval_value_all_actions(state[None])
                        target_q += current_discount_factor * max(qpredict[0])
                    episode_q.append(target_q)

                    returns.append(episode_q)
                else:
                    returns.append([])

                returns.append(episode_q)
            else:
                returns.append([])
        return {'return': returns}
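The inner loop above builds an n-step Q-learning target, roughly the textbook quantity r_t + gamma*r_{t+1} + ... + gamma^(n-1)*r_{t+n-1} + gamma^n * max_a Q(s_{t+n}, a), truncated at episode end. A compact standalone sketch of that textbook form (illustrative only; it does not reproduce the exact bookkeeping of the loop above):

def nstep_q_target(rewards, bootstrap_q_max, n, gamma=0.99):
    """Textbook n-step target: discounted sum of up to n rewards plus a discounted bootstrap."""
    target, discount = 0.0, 1.0
    for r in rewards[:n]:
        target += discount * r
        discount *= gamma
    return target + discount * bootstrap_q_max

# e.g. nstep_q_target([1.0, 0.0, 1.0], bootstrap_q_max=0.7, n=3)
#      -> 1.0 + 0.99**2 * 1.0 + 0.99**3 * 0.7 = approx. 2.66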
@@ -36,14 +36,20 @@ class DataCollector(object):
            "One and only one collection number specification permitted!"

        if num_timesteps > 0:
            for _ in range(num_timesteps):
            num_timesteps_ = int(num_timesteps)
            for _ in range(num_timesteps_):
                action = self.policy.act(self.current_observation, my_feed_dict=my_feed_dict)
                next_observation, reward, done, _ = self.env.step(action)
                self.data_buffer.add((self.current_observation, action, reward, done))
                self.current_observation = next_observation

                if done:
                    self.current_observation = self.env.reset()
                else:
                    self.current_observation = next_observation

        if num_episodes > 0:
            for _ in range(num_episodes):
            num_episodes_ = int(num_episodes)
            for _ in range(num_episodes_):
                observation = self.env.reset()
                done = False
                while not done:
@@ -56,7 +62,7 @@ class DataCollector(object):
            for processor in self.process_functions:
                self.data.update(processor(self.data_buffer))

    def next_batch(self, batch_size, standardize_advantage=True):
    def next_batch(self, batch_size, standardize_advantage=None):
        sampled_index = self.data_buffer.sample(batch_size)
        if self.process_mode == 'sample':
            for processor in self.process_functions:
@@ -87,7 +93,8 @@ class DataCollector(object):
            else:
                raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))

        if standardize_advantage:
        auto_standardize = (standardize_advantage is None) and self.require_advantage
        if standardize_advantage or auto_standardize:
            if self.require_advantage:
                advantage_value = feed_dict[self.required_placeholders['advantage']]
                advantage_mean = np.mean(advantage_value)
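The standardization branch above presumably goes on to shift and scale the advantages; a minimal numpy sketch of the usual zero-mean, unit-variance normalization (illustrative, not repository code):

import numpy as np

def standardize(advantages, eps=1e-8):
    """Shift to zero mean and scale to unit standard deviation, a common variance-reduction trick."""
    advantages = np.asarray(advantages, dtype=np.float64)
    return (advantages - advantages.mean()) / (advantages.std() + eps)

# e.g. standardize([1.0, 2.0, 3.0]) -> array([-1.2247..., 0., 1.2247...])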
tianshou/data/tester.py (new file, +8)
@@ -0,0 +1,8 @@
from __future__ import absolute_import


def test_policy_in_env(policy, env):
    # make another env as the original is for training data collection
    env_ = env

    pass
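The tester added here is still a stub. As a rough idea of the shape such a function could take, a purely hypothetical sketch follows: the episode count, the use of `gym.make(env.spec.id)` to build a fresh environment, and the call to `policy.act_test` are all assumptions, not repository code.

import gym
import numpy as np

def test_policy_in_env_sketch(policy, env, num_episodes=10):
    # hypothetical: build a separate env so evaluation does not disturb training rollouts
    env_ = gym.make(env.spec.id)
    episode_returns = []
    for _ in range(num_episodes):
        observation, done, episode_return = env_.reset(), False, 0.0
        while not done:
            action = policy.act_test(observation)  # uses epsilon_test instead of epsilon_train
            observation, reward, done, _ = env_.step(action)
            episode_return += reward
        episode_returns.append(episode_return)
    print('mean return over {} episodes: {:.2f}'.format(num_episodes, np.mean(episode_returns)))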