From fed3bf2a1271f6a104070563efbab09e72e00df9 Mon Sep 17 00:00:00 2001
From: haoshengzou
Date: Sun, 14 Jan 2018 20:58:28 +0800
Subject: [PATCH] auto target network. ppo_cartpole.py runs OK, but results
 differ from the previous version even with the same random seed; still needs
 debugging.

---
 examples/ppo_cartpole.py           |  75 +++++++++----------
 tianshou/core/losses.py            |  14 ++--
 tianshou/core/policy/base.py       |  43 ++---------
 tianshou/core/policy/stochastic.py | 115 ++++++++++++++++++++++++-----
 tianshou/data/batch.py             |   7 +-
 5 files changed, 151 insertions(+), 103 deletions(-)

diff --git a/examples/ppo_cartpole.py b/examples/ppo_cartpole.py
index 418cc52..ae7de41 100755
--- a/examples/ppo_cartpole.py
+++ b/examples/ppo_cartpole.py
@@ -17,24 +17,9 @@ from rllab.envs.box2d.cartpole_env import CartpoleEnv
 from rllab.envs.normalized_env import normalize
 
 
-def policy_net(observation, action_dim, scope=None):
-    """
-    Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf
+# for tutorial purposes, placeholders are explicitly appended with the '_ph' suffix
 
-    :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
-    :param action_dim: int. The number of actions.
-    :param scope: str. Specifying the scope of the variables.
-    """
-    # with tf.variable_scope(scope):
-    net = tf.layers.dense(observation, 32, activation=tf.nn.tanh)
-    net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
-
-    act_mean = tf.layers.dense(net, action_dim, activation=None)
-
-    return act_mean
-
-
-if __name__ == '__main__': # a clean version with only policy net, no value net
+if __name__ == '__main__':
     env = normalize(CartpoleEnv())
     observation_dim = env.observation_space.shape
     action_dim = env.action_space.flat_dim
@@ -47,44 +32,51 @@ if __name__ == '__main__': # a clean version with only policy net, no value net
     np.random.seed(seed)
     tf.set_random_seed(seed)
 
-    # 1. build network with pure tf
-    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
+    ### 1. build network with pure tf
+    observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
 
-    with tf.variable_scope('pi'):
-        action_mean = policy_net(observation, action_dim, 'pi')
-        action_logstd = tf.get_variable('action_logstd', shape=(action_dim,))
-    train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
-    with tf.variable_scope('pi_old'):
-        action_mean_old = policy_net(observation, action_dim, 'pi_old')
-        action_logstd_old = tf.get_variable('action_logstd_old', shape=(action_dim,))
-    pi_old_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'pi_old')
+    def my_policy():
+        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
+        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
+        action_mean = tf.layers.dense(net, action_dim, activation=None)
+        action_logstd = tf.get_variable('action_logstd', shape=(action_dim, ))
 
-    # 2. build losses, optimizers
-    pi = policy.Normal(action_mean, action_logstd, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc.
-                      # for continuous action space, you may need to change an environment to run
-    pi_old = policy.Normal(action_mean_old, action_logstd_old, observation_placeholder=observation)
+        # value = tf.layers.dense(net, 1, activation=None)
 
-    action = tf.placeholder(dtype=tf.float32, shape=(None, action_dim)) # batch of integer actions
-    advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients
+        return action_mean, action_logstd, None  # None value head
 
-    ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict
+    # TODO: current implementation of passing function or overriding function has to return a value head
+    # to allow network sharing between policy and value networks. This makes 'policy' and 'value_function'
+    # imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact
+    # with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is
+    # not based on passing function or overriding function.
+
+    ### 2. build policy, losses, optimizers
+    pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0)
+
+    # action = tf.placeholder(dtype=tf.float32, shape=(None, action_dim)) # batch of integer actions
+    # advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients
+
+    ppo_loss_clip = losses.ppo_clip(pi, clip_param) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict
 
     total_loss = ppo_loss_clip
     optimizer = tf.train.AdamOptimizer(1e-4)
-    train_op = optimizer.minimize(total_loss, var_list=train_var_list)
+    train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
 
-    # 3. define data collection
+    ### 3. define data collection
     training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper
     # ShihongSong: Replay(), see dqn_example.py
     # maybe a dict to manage the elements to be collected
 
-    # 4. start training
+    ### 4. start training
+    # init = tf.global_variables_initializer()
     config = tf.ConfigProto()
     config.gpu_options.allow_growth = True
     with tf.Session(config=config) as sess:
         sess.run(tf.global_variables_initializer())
+
         # sync pi and pi_old
-        sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)])
+        pi.sync_weights()  # TODO: automate this for policies with target network
 
         start_time = time.time()
         for i in range(100): # until some stopping criterion met...
@@ -97,11 +89,10 @@ if __name__ == '__main__': # a clean version with only policy net, no value net
 
             # update network
             for _ in range(num_batches):
-                data = training_data.next_batch(batch_size)  # YouQiaoben, ShihongSong
-                # TODO: auto managing of the placeholders? or add this to params of data.Batch
-                sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], advantage: data['returns']})
+                feed_dict = training_data.next_batch(batch_size)  # YouQiaoben, ShihongSong
+                sess.run(train_op, feed_dict=feed_dict)
 
             # assigning pi to pi_old
-            sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)])
+            pi.update_weights()
 
     print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
\ No newline at end of file
diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py
index 5d5d2f3..e4f56ce 100644
--- a/tianshou/core/losses.py
+++ b/tianshou/core/losses.py
@@ -1,22 +1,26 @@
 import tensorflow as tf
 
 
-def ppo_clip(sampled_action, advantage, clip_param, pi, pi_old):
+def ppo_clip(policy, clip_param):
     """
     the clip loss in ppo paper
 
     :param sampled_action: placeholder of sampled actions during interaction with the environment
     :param advantage: placeholder of estimated advantage values.
     :param clip param: float or Tensor of type float.
-    :param pi: current `policy` to be optimized
+    :param policy: current `policy` to be optimized
    :param pi_old: old `policy` for computing the ppo loss as in Eqn. (7) in the paper
     """
+    action_ph = tf.placeholder(policy.act_dtype, shape=(None, policy.action_dim), name='ppo_clip_loss/action_placeholder')
+    advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder')
+    policy.managed_placeholders['action'] = action_ph
+    policy.managed_placeholders['processed_reward'] = advantage_ph
 
-    log_pi_act = pi.log_prob(sampled_action)
-    log_pi_old_act = pi_old.log_prob(sampled_action)
+    log_pi_act = policy.log_prob(action_ph)
+    log_pi_old_act = policy.log_prob_old(action_ph)
 
     ratio = tf.exp(log_pi_act - log_pi_old_act)
     clipped_ratio = tf.clip_by_value(ratio, 1. - clip_param, 1. + clip_param)
-    ppo_clip_loss = -tf.reduce_mean(tf.minimum(ratio * advantage, clipped_ratio * advantage))
+    ppo_clip_loss = -tf.reduce_mean(tf.minimum(ratio * advantage_ph, clipped_ratio * advantage_ph))
 
     return ppo_clip_loss
diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index 8d4b2a1..5657940 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -16,44 +16,11 @@ class PolicyBase(object):
     """
     base class for policy. only provides `act` method with exploration
     """
-    def __init__(self, observation_placeholder):
-        self._observation_placeholder = observation_placeholder
-
-    def act(self, observation, exploration):
+    def act(self, observation):
         raise NotImplementedError()
 
 
-class QValuePolicy(object):
-    """
-    The policy as in DQN
-    """
-    def __init__(self, observation_placeholder):
-        self._observation_placeholder = observation_placeholder
-
-    def act(self, observation, exploration=None):  # first implement no exploration
-        """
-        return the action (int) to be executed.
-        no exploration when exploration=None.
-        """
-        self._act(observation, exploration)
-
-    def _act(self, observation, exploration=None):
-        raise NotImplementedError()
-
-    def values(self, observation):
-        """
-        returns the Q(s, a) values (float) for all actions a at observation s
-        """
-        pass
-
-    def values_tensor(self):
-        """
-        returns the tensor of the values for all actions a at observation s
-        """
-        pass
-
-
-class StochasticPolicy(object):
+class StochasticPolicy(PolicyBase):
     """
     The :class:`StochasticPolicy` class is the base class for various probabilistic
     distributions which support batch inputs, generating batches of samples and
@@ -170,7 +137,7 @@ class StochasticPolicy(object):
         return self._group_ndims
 
     # @add_name_scope
-    def act(self, observation):
+    def act(self, observation, my_feed_dict={}):
         """
         sample(n_samples=None)
 
@@ -184,9 +151,9 @@ class StochasticPolicy(object):
             samples to draw from the distribution.
         :return: A Tensor of samples.
         """
-        return self._act(observation)
+        return self._act(observation, my_feed_dict)
 
-    def _act(self, observation):
+    def _act(self, observation, my_feed_dict):
         """
         Private method for subclasses to rewrite the :meth:`sample` method.
         """
diff --git a/tianshou/core/policy/stochastic.py b/tianshou/core/policy/stochastic.py
index cda204d..d3ab8e7 100644
--- a/tianshou/core/policy/stochastic.py
+++ b/tianshou/core/policy/stochastic.py
@@ -94,24 +94,64 @@ class Normal(StochasticPolicy):
     :param observation_placeholder
     """
     def __init__(self,
-                 mean = 0.,
-                 logstd = 1.,
-                 group_ndims = 1,
-                 observation_placeholder = None,
+                 policy_callable,
+                 observation_placeholder,
+                 weight_update=1,
+                 group_ndims=1,
                  **kwargs):
+        self.managed_placeholders = {'observation': observation_placeholder}
+        self.weight_update = weight_update
+        self.interaction_count = -1  # defaults to -1. only useful if weight_update > 1.
+
+        with tf.variable_scope('network'):
+            mean, logstd, value_head = policy_callable()
+            self._mean = tf.convert_to_tensor(mean, dtype = tf.float32)
+            self._logstd = tf.convert_to_tensor(logstd, dtype = tf.float32)
+            self._std = tf.exp(self._logstd)
+
+            shape = tf.broadcast_dynamic_shape(tf.shape(self._mean), tf.shape(self._std))
+            self._action = tf.random_normal(tf.concat([[1], shape], 0), dtype = tf.float32) * self._std + self._mean
+            # TODO: self._action should be exactly the action tensor to run, without [0, 0] in self._act
+
+        if value_head is not None:
+            pass
+
+        self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')
+
+        if self.weight_update == 1:
+            self.weight_update_ops = None
+            self.sync_weights_ops = None
+        else:  # then we need to build another tf graph as target network
+            with tf.variable_scope('net_old'):
+                mean, logstd, value_head = policy_callable()
+                self._mean_old = tf.convert_to_tensor(mean, dtype=tf.float32)
+                self._logstd_old = tf.convert_to_tensor(logstd, dtype=tf.float32)
+
+                if value_head is not None:  # useful in DDPG
+                    pass
+
+            network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
+            self.network_old_weights = network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
+            assert len(network_weights) == len(network_old_weights)
+            self.sync_weights_ops = [tf.assign(variable_old, variable)
+                                     for (variable_old, variable) in zip(network_old_weights, network_weights)]
+
+            if weight_update == 0:
+                self.weight_update_ops = self.sync_weights_ops
+            elif 0 < weight_update < 1:
+                pass
+            else:
+                self.interaction_count = 0
+                import math
+                self.weight_update = math.ceil(weight_update)
 
-        self._mean = tf.convert_to_tensor(mean, dtype = tf.float32)
-        self._logstd = tf.convert_to_tensor(logstd, dtype = tf.float32)
-        self._std = tf.exp(self._logstd)
-
-        shape = tf.broadcast_dynamic_shape(tf.shape(self._mean), tf.shape(self._std))
-        self._action = tf.random_normal(tf.concat([[1], shape], 0), dtype = tf.float32) * self._std + self._mean
         super(Normal, self).__init__(
-            act_dtype = tf.float32,
-            param_dtype = tf.float32,
-            is_continuous = True,
-            observation_placeholder = observation_placeholder,
+            act_dtype=tf.float32,
+            param_dtype=tf.float32,
+            is_continuous=True,
+            observation_placeholder=observation_placeholder,
             group_ndims = group_ndims,
             **kwargs)
 
@@ -127,13 +167,22 @@ class Normal(StochasticPolicy):
     def logstd(self):
         return self._logstd
 
-    def _act(self, observation):
+    @property
+    def action(self):
+        return self._action
+
+    @property
+    def action_dim(self):
+        return self.mean.shape.as_list()[-1]
+
+    def _act(self, observation, my_feed_dict):
         # TODO: getting session like this maybe ugly. also maybe huge problem when parallel
         sess = tf.get_default_session()
 
         # observation[None] adds one dimension at the beginning
-        sampled_action = sess.run(self._action,
-                                  feed_dict={self._observation_placeholder: observation[None]})
+        feed_dict = {self._observation_placeholder: observation[None]}
+        feed_dict.update(my_feed_dict)
+        sampled_action = sess.run(self._action, feed_dict=feed_dict)
         sampled_action = sampled_action[0, 0]
 
         return sampled_action
@@ -145,4 +194,36 @@ class Normal(StochasticPolicy):
         return c - logstd - 0.5 * precision * tf.square(sampled_action - mean)
 
     def _prob(self, sampled_action):
-        return tf.exp(self._log_prob(sampled_action))
\ No newline at end of file
+        return tf.exp(self._log_prob(sampled_action))
+
+    def log_prob_old(self, sampled_action):
+        """
+        returns the log_prob of the old policy when constructing tf graphs. Raises an error when there's no old policy.
+        :param sampled_action: the placeholder for sampled actions during interaction with the environment.
+        :return: tensor of the log_prob of the old policy
+        """
+        if self.weight_update == 1:
+            raise AttributeError('Policy has no policy_old since it\'s initialized with weight_update=1!')
+
+        mean, logstd = self._mean_old, self._logstd_old
+        c = -0.5 * np.log(2 * np.pi)
+        precision = tf.exp(-2 * logstd)
+        return c - logstd - 0.5 * precision * tf.square(sampled_action - mean)
+
+    def update_weights(self):
+        """
+        updates the weights of policy_old.
+        :return:
+        """
+        if self.weight_update_ops is not None:
+            sess = tf.get_default_session()
+            sess.run(self.weight_update_ops)
+
+    def sync_weights(self):
+        """
+        syncs the weights of network_old by directly copying the weights of network.
+        :return:
+        """
+        if self.sync_weights_ops is not None:
+            sess = tf.get_default_session()
+            sess.run(self.sync_weights_ops)
\ No newline at end of file
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index a2a8dde..2e372ea 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -135,7 +135,12 @@ class Batch(object):
             advantage_std = np.std(current_batch['returns'])
             current_batch['returns'] = (current_batch['returns'] - advantage_mean) / advantage_std
 
-        return current_batch
+        feed_dict = {}
+        feed_dict[self._pi.managed_placeholders['observation']] = current_batch['observations']
+        feed_dict[self._pi.managed_placeholders['action']] = current_batch['actions']
+        feed_dict[self._pi.managed_placeholders['processed_reward']] = current_batch['returns']
+
+        return feed_dict
 
     # TODO: this will definitely be refactored with a proper logger
     def statistics(self):
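
Illustration, not part of the patch: the surrogate that the reworked losses.ppo_clip builds is the standard PPO clipped objective, written out below on its own so the math can be checked against the diff. The helper name ppo_clip_sketch and the explicit log-prob/advantage arguments are hypothetical stand-ins for the tensors and placeholders that the loss registers in policy.managed_placeholders; only the arithmetic mirrors the patch.

import tensorflow as tf

def ppo_clip_sketch(log_pi_act, log_pi_old_act, advantage_ph, clip_param=0.2):
    # probability ratio r = pi(a|s) / pi_old(a|s), computed in log space
    ratio = tf.exp(log_pi_act - log_pi_old_act)
    # keep the ratio inside [1 - clip_param, 1 + clip_param]
    clipped_ratio = tf.clip_by_value(ratio, 1. - clip_param, 1. + clip_param)
    # pessimistic (min) surrogate, negated because the optimizer minimizes
    return -tf.reduce_mean(tf.minimum(ratio * advantage_ph, clipped_ratio * advantage_ph))

On the target-network side, as implemented in this patch: weight_update=1 builds no 'net_old' graph (log_prob_old then raises AttributeError); weight_update=0 makes update_weights() run the same assign ops as sync_weights(), so pi_old is overwritten with pi on every call; values greater than 1 are rounded up with math.ceil and reserved for syncing every N interactions via interaction_count; and 0 < weight_update < 1 (soft updates) is left unimplemented.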