diff --git a/examples/dqn_example.py b/examples/dqn_example.py
index f822e59..56e84bf 100644
--- a/examples/dqn_example.py
+++ b/examples/dqn_example.py
@@ -13,6 +13,7 @@ from tianshou.core import losses
 from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy.dqn as policy  # TODO: fix imports as zhusuan so that only need to import to policy
+import tianshou.core.value_function.action_value as value_function
 
 
 if __name__ == '__main__':
@@ -31,31 +32,26 @@ if __name__ == '__main__':
     ### 1. build network with pure tf
     observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
 
-    def my_policy():
+    def my_network():
         net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
         net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
 
         action_values = tf.layers.dense(net, action_dim, activation=None)
 
-        return action_values, None  # None value head
-
-    # TODO: current implementation of passing function or overriding function has to return a value head
-    # to allow network sharing between policy and value networks. This makes 'policy' and 'value_function'
-    # imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact
-    # with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is
-    # not based on passing function or overriding function.
+        return None, action_values  # no policy head
 
     ### 2. build policy, loss, optimizer
-    pi = policy.DQN(my_policy, observation_placeholder=observation_ph, weight_update=10)
+    dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, weight_update=100)
+    pi = policy.DQN(dqn)
 
-    dqn_loss = losses.qlearning(pi)
+    dqn_loss = losses.qlearning(dqn)
 
     total_loss = dqn_loss
     optimizer = tf.train.AdamOptimizer(1e-4)
-    train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
+    train_op = optimizer.minimize(total_loss, var_list=dqn.trainable_variables)
 
     ### 3. define data collection
-    data_collector = Batch(env, pi, [advantage_estimation.nstep_q_return(1, pi.target_network)], [pi])
+    data_collector = Batch(env, pi, [advantage_estimation.nstep_q_return(1, dqn)], [dqn])
 
     ### 4. start training
     config = tf.ConfigProto()
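Note: `policy.DQN.act` (see the policy diff below) currently implements only the greedy branch, so the example as written always acts greedily. A minimal sketch of epsilon-greedy exploration layered on top of `pi.act`, assuming a gym-style `env.action_space.sample()`; the helper is hypothetical and not part of this diff:

```python
import numpy as np

def epsilon_greedy_act(pi, env, observation, epsilon=0.1):
    """Hypothetical helper, not in this diff: random action with prob. epsilon, else greedy."""
    if np.random.rand() < epsilon:
        return env.action_space.sample()   # explore
    return pi.act(observation)             # exploit: argmax over Q(s, *), as in policy.DQN
```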
diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py
index 69ac79e..f1e77c7 100644
--- a/tianshou/core/losses.py
+++ b/tianshou/core/losses.py
@@ -51,24 +51,25 @@ def state_value_mse(state_value_function):
     :param state_value_function: instance of StateValue
     :return: tensor of the mse loss
     """
-    state_value_ph = tf.placeholder(tf.float32, shape=(None,), name='state_value_mse/state_value_placeholder')
-    state_value_function.managed_placeholders['return'] = state_value_ph
+    target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='state_value_mse/state_value_placeholder')
+    state_value_function.managed_placeholders['return'] = target_value_ph
 
     state_value = state_value_function.value_tensor
-    return tf.losses.mean_squared_error(state_value_ph, state_value)
+    return tf.losses.mean_squared_error(target_value_ph, state_value)
 
 
-def dqn_loss(sampled_action, sampled_target, policy):
+def qlearning(action_value_function):
     """
     deep q-network
-
-    :param sampled_action: placeholder of sampled actions during the interaction with the environment
-    :param sampled_target: estimated Q(s,a)
-    :param policy: current `policy` to be optimized
+    :param action_value_function: current `action_value` to be optimized
 
     :return:
     """
-    sampled_q = policy.q_net.value_tensor
-    return tf.reduce_mean(tf.square(sampled_target - sampled_q))
+    target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='qlearning/action_value_placeholder')
+    action_value_function.managed_placeholders['return'] = target_value_ph
+
+    q_value = action_value_function.value_tensor
+    return tf.losses.mean_squared_error(target_value_ph, q_value)
+
 
 def deterministic_policy_gradient(sampled_state, critic):
     """
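For reference, a standalone TF 1.x sketch of the computation `losses.qlearning` now sets up: the loss owns a target placeholder (registered under `managed_placeholders['return']` so the data collector can feed it) and takes the MSE against Q(s, a) of the taken actions. The placeholder names below are illustrative, not the library's:

```python
import tensorflow as tf

q_all = tf.placeholder(tf.float32, shape=(None, 4), name='q_all_actions')      # Q(s, *), 4 actions
action_ph = tf.placeholder(tf.int32, shape=(None,), name='taken_actions')
target_ph = tf.placeholder(tf.float32, shape=(None,), name='qlearning_target')  # fed via 'return'

# pick Q(s, a) of the taken action in each row, as value_function.DQN does with tf.gather_nd
batch_idx = tf.range(tf.shape(q_all)[0])
q_sa = tf.gather_nd(q_all, tf.stack([batch_idx, action_ph], axis=1))

loss = tf.losses.mean_squared_error(target_ph, q_sa)
```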
diff --git a/tianshou/core/policy/dqn.py b/tianshou/core/policy/dqn.py
index 2f6db5a..bc5db67 100644
--- a/tianshou/core/policy/dqn.py
+++ b/tianshou/core/policy/dqn.py
@@ -2,88 +2,53 @@ from __future__ import absolute_import
 
 from .base import PolicyBase
 import tensorflow as tf
-from ..value_function.action_value import DQN
+import numpy as np
 
 
-class DQNRefactor(PolicyBase):
+class DQN(PolicyBase):
     """
     use DQN from value_function as a member
     """
-    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
-        self._q_net = DQN(value_tensor, observation_placeholder, action_placeholder)
-        self._argmax_action = tf.argmax(value_tensor, axis=1)
-
-        super(DQNRefactor, self).__init__(observation_placeholder=observation_placeholder)
+    def __init__(self, dqn):
+        self.action_value = dqn
+        self._argmax_action = tf.argmax(dqn.value_tensor_all_actions, axis=1)
+        self.weight_update = dqn.weight_update
+        if self.weight_update > 1:
+            self.interaction_count = 0
+        else:
+            self.interaction_count = -1
 
     def act(self, observation, exploration=None):
         sess = tf.get_default_session()
-        if not exploration:  # no exploration
-            action = sess.run(self._argmax_action, feed_dict={self._observation_placeholder: observation})
+        if self.weight_update > 1:
+            if self.interaction_count % self.weight_update == 0:
+                self.update_weights()
+        feed_dict = {self.action_value._observation_placeholder: observation[None]}
+        action = sess.run(self._argmax_action, feed_dict=feed_dict)
 
-        return action
+        if self.weight_update > 0:
+            self.interaction_count += 1
+
+        if not exploration:
+            return np.squeeze(action)
 
     @property
     def q_net(self):
-        return self._q_net
+        return self.action_value
 
+    def sync_weights(self):
+        """
+        sync the weights of network_old by directly copying the weights of network.
+        :return:
+        """
+        if self.action_value.sync_weights_ops is not None:
+            self.action_value.sync_weights()
 
-class DQNOld(QValuePolicy):
-    """
-    The policy as in DQN
-    """
-
-    def __init__(self, logits, observation_placeholder, dtype=None, **kwargs):
-        # TODO: this version only support non-continuous action space, extend it to support continuous action space
-        self._logits = tf.convert_to_tensor(logits)
-        if dtype is None:
-            dtype = tf.int32
-        self._n_categories = self._logits.get_shape()[-1].value
-
-        super(DQN, self).__init__(observation_placeholder)
-
-        # TODO: put the net definition outside of the class
-        net = tf.layers.conv2d(self._observation_placeholder, 16, 8, 4, 'valid', activation=tf.nn.relu)
-        net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
-        net = tf.layers.flatten(net)
-        net = tf.layers.dense(net, 256, activation=tf.nn.relu, use_bias=True)
-        self._value = tf.layers.dense(net, self._n_categories)
-
-    def _act(self, observation, exploration=None):  # first implement no exploration
+    def update_weights(self):
         """
-        return the action (int) to be executed.
-        no exploration when exploration=None.
+        updates the weights of network_old.
+        :return:
         """
-        # TODO: ensure thread safety, tf.multinomial to init
-        sess = tf.get_default_session()
-        sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1),
-                                  feed_dict={self._observation_placeholder: observation[None]})
-        return sampled_action
-
-    @property
-    def logits(self):
-        """
-        :return: action values
-        """
-        return self._logits
-
-    @property
-    def n_categories(self):
-        """
-        :return: dimension of action space if not continuous
-        """
-        return self._n_categories
-
-    def values(self, observation):
-        """
-        returns the Q(s, a) values (float) for all actions a at observation s
-        """
-        sess = tf.get_default_session()
-        value = sess.run(self._value, feed_dict={self._observation_placeholder: observation[None]})
-        return value
-
-    def values_tensor(self):
-        """
-        returns the tensor of the values for all actions a at observation s
-        """
-        return self._value
+        if self.action_value.weight_update_ops is not None:
+            self.action_value.update_weights()
\ No newline at end of file
diff --git a/tianshou/core/policy/stochastic.py b/tianshou/core/policy/stochastic.py
index 3ac82f0..c867035 100644
--- a/tianshou/core/policy/stochastic.py
+++ b/tianshou/core/policy/stochastic.py
@@ -65,7 +65,7 @@ class OnehotCategorical(StochasticPolicy):
             logits, value_head = policy_callable()
             self._logits_old = tf.convert_to_tensor(logits, dtype=tf.float32)
 
-            if value_head is not None:  # useful in DDPG
+            if value_head is not None:
                 pass
 
         network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
@@ -201,7 +201,7 @@ class Normal(StochasticPolicy):
             self._mean_old = tf.convert_to_tensor(mean, dtype=tf.float32)
             self._logstd_old = tf.convert_to_tensor(logstd, dtype=tf.float32)
 
-            if value_head is not None:  # useful in DDPG
+            if value_head is not None:
                 pass
 
         network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
@@ -215,7 +215,7 @@ class Normal(StochasticPolicy):
 
         if weight_update == 0:
             self.weight_update_ops = self.sync_weights_ops
-        elif 0 < weight_update < 1:
+        elif 0 < weight_update < 1:  # useful in DDPG
            pass
         else:
            self.interaction_count = 0
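The `weight_update` argument selects between target-network update styles: `weight_update > 1` means a hard copy every that many interactions, while `0 < weight_update < 1` is reserved for a soft (Polyak) update as in DDPG and is still a `pass` in this diff. A standalone sketch of the two op styles, with scope names mirroring `network`/`net_old`:

```python
import tensorflow as tf

# Illustration only; not the library's code. Two variables standing in for the two scopes.
with tf.variable_scope('network'):
    w = tf.get_variable('w', shape=(3,), initializer=tf.zeros_initializer())
with tf.variable_scope('net_old'):
    w_old = tf.get_variable('w', shape=(3,), initializer=tf.zeros_initializer())

# hard copy: run once every `weight_update` interactions (the weight_update > 1 case)
hard_sync_op = tf.assign(w_old, w)

# soft / Polyak update for 0 < weight_update < 1 (the DDPG case, not yet implemented here)
tau = 0.01
soft_update_op = tf.assign(w_old, (1.0 - tau) * w_old + tau * w)
```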
diff --git a/tianshou/core/value_function/action_value.py b/tianshou/core/value_function/action_value.py
index c62dae6..e6145ec 100644
--- a/tianshou/core/value_function/action_value.py
+++ b/tianshou/core/value_function/action_value.py
@@ -28,27 +28,70 @@ class ActionValue(ValueFunctionBase):
                         {self._observation_placeholder: observation, self._action_placeholder: action})
 
 
-class DQN(ActionValue):
+class DQN(ValueFunctionBase):
    """
    class of the very DQN architecture. Instead of feeding s and a to the network to get a value, DQN feed s to the
    network and the last layer is Q(s, *) for all actions
    """
-    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
+    def __init__(self, network_callable, observation_placeholder, weight_update=1):
        """
        :param value_tensor: of shape (batchsize, num_actions)
        :param observation_placeholder: of shape (batchsize, observation_dim)
        :param action_placeholder: of shape (batchsize, )
        """
-        self._value_tensor_all_actions = value_tensor
+        self._observation_placeholder = observation_placeholder
+        self.action_placeholder = action_placeholder = tf.placeholder(tf.int32, shape=(None,), name='action_value.DQN/action_placeholder')
+        self.managed_placeholders = {'observation': observation_placeholder, 'action': action_placeholder}
+        self.weight_update = weight_update
+        self.interaction_count = -1  # defaults to -1. only useful if weight_update > 1.
 
-        batch_size = tf.shape(value_tensor)[0]
-        batch_dim_index = tf.range(batch_size)
-        indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
-        canonical_value_tensor = tf.gather_nd(value_tensor, indices)
+        with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
+            value_tensor = network_callable()[-1]
+
+        self._value_tensor_all_actions = value_tensor
+
+        batch_size = tf.shape(value_tensor)[0]
+        batch_dim_index = tf.range(batch_size)
+        indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
+        canonical_value_tensor = tf.gather_nd(value_tensor, indices)
+
+        self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')
+
+        super(DQN, self).__init__(canonical_value_tensor, observation_placeholder=observation_placeholder)
+
+        # deal with target network
+        if self.weight_update == 1:
+            self.weight_update_ops = None
+            self.sync_weights_ops = None
+        else:  # then we need to build another tf graph as target network
+            with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE):
+                value_tensor = network_callable()[-1]
+            self.value_tensor_all_actions_old = value_tensor
+
+            batch_size = tf.shape(value_tensor)[0]
+            batch_dim_index = tf.range(batch_size)
+            indices = tf.stack([batch_dim_index, action_placeholder], axis=1)
+            canonical_value_tensor = tf.gather_nd(value_tensor, indices)
+
+            self.value_tensor_old = canonical_value_tensor
+
+            network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
+            network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
+
+            assert len(network_weights) == len(network_old_weights)
+            self.sync_weights_ops = [tf.assign(variable_old, variable)
+                                     for (variable_old, variable) in zip(network_old_weights, network_weights)]
+
+            if weight_update == 0:
+                self.weight_update_ops = self.sync_weights_ops
+            elif 0 < weight_update < 1:  # useful in DDPG
+                pass
+            else:
+                self.interaction_count = 0
+                import math
+                self.weight_update = math.ceil(weight_update)
+                self.weight_update_ops = self.sync_weights_ops
 
-        super(DQN, self).__init__(value_tensor=canonical_value_tensor,
-                                  observation_placeholder=observation_placeholder,
-                                  action_placeholder=action_placeholder)
 
     def eval_value_all_actions(self, observation):
        """
@@ -60,4 +103,41 @@ class DQN(ActionValue):
 
     @property
     def value_tensor_all_actions(self):
-        return self._value_tensor_all_actions
\ No newline at end of file
+        return self._value_tensor_all_actions
+
+    def eval_value_old(self, observation, action):
+        """
+        eval value using target network
+        :param observation: numpy array of obs
+        :param action: numpy array of action
+        :return: numpy array of action value
+        """
+        sess = tf.get_default_session()
+        feed_dict = {self._observation_placeholder: observation, self.action_placeholder: action}
+        return sess.run(self.value_tensor_old, feed_dict=feed_dict)
+
+    def eval_value_all_actions_old(self, observation):
+        """
+        :param observation:
+        :return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
+        """
+        sess = tf.get_default_session()
+        return sess.run(self.value_tensor_all_actions_old, feed_dict={self._observation_placeholder: observation})
+
+    def sync_weights(self):
+        """
+        sync the weights of network_old by directly copying the weights of network.
+        :return:
+        """
+        if self.sync_weights_ops is not None:
+            sess = tf.get_default_session()
+            sess.run(self.sync_weights_ops)
+
+    def update_weights(self):
+        """
+        updates the weights of network_old.
+        :return:
+        """
+        if self.weight_update_ops is not None:
+            sess = tf.get_default_session()
+            sess.run(self.weight_update_ops)
\ No newline at end of file
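The `tf.gather_nd` construction above selects Q(s, a) of the taken action from the (batchsize, num_actions) output of the network. The same indexing, spelled out in numpy with concrete numbers:

```python
import numpy as np

q_all = np.array([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])        # Q(s, *) for a batch of 2 states, 3 actions
actions = np.array([2, 0])                 # action taken in each state
q_sa = q_all[np.arange(len(actions)), actions]
print(q_sa)                                # [3. 4.] -> Q(s0, a=2) and Q(s1, a=0)
```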
diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py
index a1f1978..78467b9 100644
--- a/tianshou/data/advantage_estimation.py
+++ b/tianshou/data/advantage_estimation.py
@@ -46,8 +46,14 @@ class gae_lambda:
 
     def __call__(self, raw_data):
         reward = raw_data['reward']
+        observation = raw_data['observation']
 
-        return {'advantage': reward}
+        state_value = self.value_function.eval_value(observation)
+
+        # wrong version of advantage just to run
+        advantage = reward + state_value
+
+        return {'advantage': advantage}
 
 
 class nstep_return:
@@ -60,8 +66,41 @@ class nstep_return:
 
     def __call__(self, raw_data):
         reward = raw_data['reward']
+        observation = raw_data['observation']
 
-        return {'return': reward}
+        state_value = self.value_function.eval_value(observation)
+
+        # wrong version of return just to run
+        return_ = reward + state_value
+
+        return {'return': return_}
+
+
+class nstep_q_return:
+    """
+    compute the n-step return for Q-learning targets
+    """
+    def __init__(self, n, action_value, use_target_network=True):
+        self.n = n
+        self.action_value = action_value
+        self.use_target_network = use_target_network
+
+    def __call__(self, raw_data):
+        # raw_data should contain 'next_observation' from replay memory...?
+        # maybe the main difference between Batch and Replay is the stored data format?
+        reward = raw_data['reward']
+        observation = raw_data['observation']
+
+        if self.use_target_network:
+            action_value_all_actions = self.action_value.eval_value_all_actions_old(observation)
+        else:
+            action_value_all_actions = self.action_value.eval_value_all_actions(observation)
+
+        action_value_max = np.max(action_value_all_actions, axis=1)
+
+        return_ = reward + action_value_max
+
+        return {'return': return_}
 
 
 class QLearningTarget:
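The `nstep_q_return` body above is an interim placeholder: it adds the max Q-value of the current observation to the reward, with no discount factor or terminal mask. For comparison, a sketch of the standard 1-step Q-learning target; `next_q_all_old` and `done` are assumed inputs that the collector does not produce yet:

```python
import numpy as np

def one_step_q_target(reward, next_q_all_old, done, gamma=0.99):
    """Standard target: r_t + gamma * max_a Q_old(s_{t+1}, a), zeroed at terminal steps.

    reward, done: shape (T,); next_q_all_old: shape (T, num_actions) from the target network.
    """
    return reward + gamma * (1.0 - done) * np.max(next_q_all_old, axis=1)
```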