Working on the off-policy test. The other parts of dqn_replay are runnable, but performance is not yet tested.
parent 24d75fd1aa, commit e68dcd3c64
@@ -1,6 +1,4 @@
 #!/usr/bin/env python
-from __future__ import absolute_import
-
 import tensorflow as tf
 import gym
 import numpy as np
@@ -10,11 +8,9 @@ import time
 import sys
 sys.path.append('..')
 from tianshou.core import losses
-# from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy.dqn as policy  # TODO: fix imports as zhusuan so that only need to import to policy
+import tianshou.core.policy.dqn as policy
 import tianshou.core.value_function.action_value as value_function
-import sys
 
 from tianshou.data.replay_buffer.vanilla import VanillaReplayBuffer
 from tianshou.data.data_collector import DataCollector
@@ -25,14 +21,6 @@ if __name__ == '__main__':
     observation_dim = env.observation_space.shape
     action_dim = env.action_space.n
 
-    clip_param = 0.2
-    num_batches = 10
-    batch_size = 512
-
-    seed = 0
-    np.random.seed(seed)
-    tf.set_random_seed(seed)
-
     ### 1. build network with pure tf
     observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
 
@@ -45,7 +33,7 @@ if __name__ == '__main__':
         return None, action_values  # no policy head
 
     ### 2. build policy, loss, optimizer
-    dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, weight_update=100)
+    dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, weight_update=200)
     pi = policy.DQN(dqn)
 
     dqn_loss = losses.qlearning(dqn)
@@ -69,6 +57,17 @@ if __name__ == '__main__':
     )
 
     ### 4. start training
+    # hyper-parameters
+    batch_size = 256
+    replay_buffer_warmup = 1000
+    epsilon_decay_interval = 200
+    epsilon = 0.3
+    test_interval = 1000
+
+    seed = 0
+    np.random.seed(seed)
+    tf.set_random_seed(seed)
+
     config = tf.ConfigProto()
     config.gpu_options.allow_growth = True
     with tf.Session(config=config) as sess:
@@ -78,12 +77,11 @@ if __name__ == '__main__':
         pi.sync_weights()  # TODO: automate this for policies with target network
 
         start_time = time.time()
-        epsilon = 0.5
         pi.set_epsilon_train(epsilon)
-        data_collector.collect(num_timesteps=int(1e3))  # warm-up
+        data_collector.collect(num_timesteps=replay_buffer_warmup)  # warm-up
         for i in range(int(1e8)):  # number of training steps
             # anneal epsilon step-wise
-            if (i + 1) % 1e4 == 0 and epsilon > 0.1:
+            if (i + 1) % epsilon_decay_interval == 0 and epsilon > 0.1:
                 epsilon -= 0.1
                 pi.set_epsilon_train(epsilon)
 
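
With the hyper-parameters introduced above, epsilon starts at 0.3 and is cut by 0.1 every `epsilon_decay_interval` training steps until it bottoms out around 0.1. A minimal standalone sketch of that schedule (plain Python, detached from the session and training code):

    # sketch of the step-wise annealing in the loop above
    epsilon = 0.3
    epsilon_decay_interval = 200

    for i in range(1000):
        if (i + 1) % epsilon_decay_interval == 0 and epsilon > 0.1:
            epsilon -= 0.1  # drops at i = 199 and i = 399, then the guard stops further decay

    print(round(epsilon, 1))  # 0.1
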
@@ -91,15 +89,13 @@ if __name__ == '__main__':
             data_collector.collect()
 
             # update network
-            for _ in range(num_batches):
-                feed_dict = data_collector.next_batch(batch_size)
-                sess.run(train_op, feed_dict=feed_dict)
+            feed_dict = data_collector.next_batch(batch_size)
+            sess.run(train_op, feed_dict=feed_dict)
 
-            print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
-
             # test every 1000 training steps
             # tester could share some code with batch!
-            if i % 1000 == 0:
+            if i % test_interval == 0:
+                print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
                 # epsilon 0.05 as in nature paper
                 pi.set_epsilon_test(0.05)
                 #test(env, pi)  # go for act_test of pi, not act
@@ -69,14 +69,3 @@ def qlearning(action_value_function):
 
     q_value = action_value_function.value_tensor
     return tf.losses.mean_squared_error(target_value_ph, q_value)
-
-
-def deterministic_policy_gradient(sampled_state, critic):
-    """
-    deterministic policy gradient:
-
-    :param sampled_action: placeholder of sampled actions during the interaction with the environment
-    :param critic: current `value` function
-    :return:
-    """
-    return tf.reduce_mean(critic.get_value(sampled_state))
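
For context, `losses.qlearning` (used as `dqn_loss` in the example script above) is a mean-squared error between the predicted Q(s, a) and a bootstrapped target supplied through a placeholder. The same quantity in numpy, with illustrative names rather than the library's:

    import numpy as np

    def qlearning_loss(q_value, target_value):
        # mean squared error between predicted Q(s, a) and the bootstrapped target
        q_value = np.asarray(q_value, dtype=np.float64)
        target_value = np.asarray(target_value, dtype=np.float64)
        return np.mean((target_value - q_value) ** 2)

    # toy batch: predictions vs. targets of the form r + gamma * max_a' Q(s', a')
    print(qlearning_loss([1.0, 2.0, 0.5], [1.5, 1.0, 0.5]))  # (0.25 + 1.0 + 0.0) / 3 ≈ 0.4167
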
@@ -30,8 +30,10 @@ class DQN(PolicyBase):
         feed_dict = {self.action_value._observation_placeholder: observation[None]}
         feed_dict.update(my_feed_dict)
         action = sess.run(self._argmax_action, feed_dict=feed_dict)
+
+        # epsilon_greedy
         if np.random.rand() < self.epsilon_train:
-            pass
+            action = np.random.randint(self.action_value.num_actions)
 
         if self.weight_update > 0:
             self.interaction_count += 1
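
The `interaction_count` incremented here drives periodic target-network syncing: `act_test` below checks `interaction_count % weight_update` and calls `update_weights()`, with `weight_update=200` in the example script. That scheduling logic in isolation (plain Python; the names mirror the diff but the loop itself is only illustrative):

    # sketch: a sync fires every `weight_update` interactions
    weight_update = 200
    interaction_count = 0
    sync_steps = []

    for step in range(1000):
        if weight_update > 1 and interaction_count % weight_update == 0:
            sync_steps.append(step)  # stands in for self.update_weights()
        interaction_count += 1

    print(sync_steps[:3])  # [0, 200, 400]
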
@@ -39,7 +41,23 @@
         return np.squeeze(action)
 
     def act_test(self, observation, my_feed_dict={}):
-        pass
+        sess = tf.get_default_session()
+        if self.weight_update > 1:
+            if self.interaction_count % self.weight_update == 0:
+                self.update_weights()
+
+        feed_dict = {self.action_value._observation_placeholder: observation[None]}
+        feed_dict.update(my_feed_dict)
+        action = sess.run(self._argmax_action, feed_dict=feed_dict)
+
+        # epsilon_greedy
+        if np.random.rand() < self.epsilon_test:
+            action = np.random.randint(self.action_value.num_actions)
+
+        if self.weight_update > 0:
+            self.interaction_count += 1
+
+        return np.squeeze(action)
 
     @property
     def q_net(self):
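
`act` and `act_test` now share the same epsilon-greedy rule and differ only in which epsilon they read (`epsilon_train` vs. `epsilon_test`). The rule as a standalone numpy sketch (names are illustrative):

    import numpy as np

    def epsilon_greedy(q_values, epsilon, rng=np.random):
        # q_values: 1-D array of Q(s, a) over all actions of one state
        if rng.rand() < epsilon:
            return rng.randint(len(q_values))  # explore: uniform random action
        return int(np.argmax(q_values))        # exploit: greedy action

    q = np.array([0.1, 0.7, 0.3])
    print(epsilon_greedy(q, epsilon=0.0))  # always 1 when epsilon is 0
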
@@ -114,6 +114,8 @@ class DQN(ValueFunctionBase):
 
         self._value_tensor_all_actions = value_tensor
 
+        self.num_actions = value_tensor.shape.as_list()[-1]
+
         batch_size = tf.shape(value_tensor)[0]
         batch_dim_index = tf.range(batch_size)
         indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
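
`num_actions` is read off the last dimension of the [batch_size, num_actions] value tensor, and the stacked `indices` pair each batch row with the action taken in it, for a gather-style lookup of Q(s, a). The equivalent selection in numpy, for illustration only:

    import numpy as np

    q_all = np.array([[0.1, 0.9],     # Q(s_0, .)
                      [0.4, 0.2],     # Q(s_1, .)
                      [0.0, 0.5]])    # Q(s_2, .)
    actions = np.array([1, 0, 1])     # action taken in each state

    num_actions = q_all.shape[-1]                      # cf. value_tensor.shape.as_list()[-1]
    q_taken = q_all[np.arange(len(actions)), actions]  # cf. stacking batch indices with actions
    print(q_taken)  # [0.9 0.4 0.5]
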
@@ -1,5 +1,4 @@
 import logging
-import tensorflow as tf
 import numpy as np
 
 STATE = 0
@@ -105,12 +104,12 @@ class nstep_q_return:
     """
    compute the n-step return for Q-learning targets
     """
-    def __init__(self, n, action_value, use_target_network=True):
+    def __init__(self, n, action_value, use_target_network=True, discount_factor=0.99):
         self.n = n
         self.action_value = action_value
         self.use_target_network = use_target_network
+        self.discount_factor = discount_factor
 
-    # TODO : we should transfer the tf -> numpy/python -> tf into a monolithic compute graph in tf
     def __call__(self, buffer, indexes=None):
         """
         :param buffer: buffer with property index and data. index determines the current content in `buffer`.
@@ -118,41 +117,39 @@
             each episode.
         :return: dict with key 'return' and value the computed returns corresponding to `index`.
         """
-        qvalue = self.action_value._value_tensor_all_actions
         indexes = indexes or buffer.index
         episodes = buffer.data
-        discount_factor = 0.99
         returns = []
 
-        config = tf.ConfigProto()
-        config.gpu_options.allow_growth = True
-        with tf.Session(config=config) as sess:
-            sess.run(tf.global_variables_initializer())
-            for episode_index in range(len(indexes)):
-                index = indexes[episode_index]
-                if index:
-                    episode = episodes[episode_index]
-                    episode_q = []
+        for episode_index in range(len(indexes)):
+            index = indexes[episode_index]
+            if index:
+                episode = episodes[episode_index]
+                episode_q = []
 
                 for i in index:
                     current_discount_factor = 1
                     last_frame_index = i
                     target_q = episode[i][REWARD]
                     for lfi in range(i, min(len(episode), i + self.n + 1)):
                         if episode[lfi][DONE]:
                             break
                         target_q += current_discount_factor * episode[lfi][REWARD]
-                        current_discount_factor *= discount_factor
+                        current_discount_factor *= self.discount_factor
                         last_frame_index = lfi
                     if last_frame_index > i:
                         state = episode[last_frame_index][STATE]
-                        # the shape of qpredict is [batch_size, action_dimension]
-                        qpredict = sess.run(qvalue, feed_dict={self.action_value.managed_placeholders['observation']:
-                                                                   state.reshape(1, state.shape[0])})
+                        if self.use_target_network:
+                            # [None] adds one dimension to the beginning
+                            qpredict = self.action_value.eval_value_all_actions_old(state[None])
+                        else:
+                            qpredict = self.action_value.eval_value_all_actions(state[None])
                         target_q += current_discount_factor * max(qpredict[0])
                     episode_q.append(target_q)
 
                 returns.append(episode_q)
             else:
                 returns.append([])
 
         return {'return': returns}
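
The rewritten `__call__` drops the per-call TF session and instead asks the action-value object to evaluate either the target network (`eval_value_all_actions_old`) or the current one, while accumulating a discounted n-step target. A self-contained numpy sketch of the standard n-step Q-learning target this approximates; `q_func` stands in for the network evaluation, and the indexing here is illustrative rather than the class's exact bookkeeping:

    import numpy as np

    def nstep_q_target(rewards, next_states, dones, q_func, t, n, gamma=0.99):
        """Discounted n-step return starting at step t, bootstrapped with max_a Q."""
        target, discount = 0.0, 1.0
        last = t
        for k in range(t, min(len(rewards), t + n)):
            target += discount * rewards[k]
            discount *= gamma
            last = k
            if dones[k]:
                return target  # episode ended within n steps: no bootstrap
        # bootstrap from the state reached after the last accumulated reward
        return target + discount * np.max(q_func(next_states[last]))

    # toy check: two rewards of 1.0, then bootstrap value 2.0, with n=2 and gamma=0.9
    q_func = lambda s: np.array([2.0, 0.0])
    print(nstep_q_target([1.0, 1.0, 1.0], ['s1', 's2', 's3'], [False] * 3, q_func, t=0, n=2, gamma=0.9))
    # 1.0 + 0.9 * 1.0 + 0.81 * 2.0 ≈ 3.52
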
@@ -36,14 +36,20 @@ class DataCollector(object):
             "One and only one collection number specification permitted!"
 
         if num_timesteps > 0:
-            for _ in range(num_timesteps):
+            num_timesteps_ = int(num_timesteps)
+            for _ in range(num_timesteps_):
                 action = self.policy.act(self.current_observation, my_feed_dict=my_feed_dict)
                 next_observation, reward, done, _ = self.env.step(action)
                 self.data_buffer.add((self.current_observation, action, reward, done))
-                self.current_observation = next_observation
+                if done:
+                    self.current_observation = self.env.reset()
+                else:
+                    self.current_observation = next_observation
 
         if num_episodes > 0:
-            for _ in range(num_episodes):
+            num_episodes_ = int(num_episodes)
+            for _ in range(num_episodes_):
                 observation = self.env.reset()
                 done = False
                 while not done:
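
The timestep-based collector now starts a new episode as soon as `done` is observed instead of stepping on from a terminal observation. The control flow in isolation, against a throwaway toy environment (the `ToyEnv` below is purely illustrative):

    class ToyEnv:
        """Illustrative 3-step episodic environment with the gym-style step/reset API."""
        def reset(self):
            self.t = 0
            return self.t
        def step(self, action):
            self.t += 1
            done = self.t >= 3
            return self.t, 1.0, done, {}

    env = ToyEnv()
    obs = env.reset()
    for _ in range(7):  # num_timesteps
        next_obs, reward, done, _ = env.step(0)
        obs = env.reset() if done else next_obs  # the fix: begin a fresh episode after `done`
    print(obs)  # 1: two full episodes finished, one step into the third
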
@@ -56,7 +62,7 @@ class DataCollector(object):
             for processor in self.process_functions:
                 self.data.update(processor(self.data_buffer))
 
-    def next_batch(self, batch_size, standardize_advantage=True):
+    def next_batch(self, batch_size, standardize_advantage=None):
         sampled_index = self.data_buffer.sample(batch_size)
         if self.process_mode == 'sample':
             for processor in self.process_functions:
@@ -87,7 +93,8 @@ class DataCollector(object):
             else:
                 raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))
 
-        if standardize_advantage:
+        auto_standardize = (standardize_advantage is None) and self.require_advantage
+        if standardize_advantage or auto_standardize:
             if self.require_advantage:
                 advantage_value = feed_dict[self.required_placeholders['advantage']]
                 advantage_mean = np.mean(advantage_value)
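
`next_batch` now standardizes advantages when explicitly asked, or automatically when an advantage processor is attached and `standardize_advantage` is left as `None`. The standardization itself is the usual zero-mean, unit-variance rescaling; a numpy sketch (the small `eps` guard here is an extra safeguard, not necessarily what the library does):

    import numpy as np

    def standardize(advantage, eps=1e-8):
        # subtract the batch mean and divide by the batch standard deviation
        advantage = np.asarray(advantage, dtype=np.float64)
        return (advantage - advantage.mean()) / (advantage.std() + eps)

    print(standardize([1.0, 2.0, 3.0]))  # approximately [-1.2247, 0.0, 1.2247]
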
tianshou/data/tester.py (new file, +8)
@@ -0,0 +1,8 @@
+from __future__ import absolute_import
+
+
+def test_policy_in_env(policy, env):
+    # make another env as the original is for training data collection
+    env_ = env
+
+    pass
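
`tianshou/data/tester.py` is added as a stub only. A hypothetical shape such a tester could take, assuming it uses the `act_test` path added above; none of this is the project's eventual implementation:

    import numpy as np

    def test_policy_in_env(policy, env, num_episodes=10):
        """Run evaluation-mode episodes and report the mean undiscounted return."""
        returns = []
        for _ in range(num_episodes):
            observation, done, episode_return = env.reset(), False, 0.0
            while not done:
                action = policy.act_test(observation)  # eval-time action, e.g. epsilon = 0.05
                observation, reward, done, _ = env.step(action)
                episode_return += reward
            returns.append(episode_return)
        return float(np.mean(returns))
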