From d599506dc9f24e9ce7656037e25fcf126279cb98 Mon Sep 17 00:00:00 2001
From: haoshengzou
Date: Mon, 15 Jan 2018 16:32:30 +0800
Subject: [PATCH] Fix the bugs introduced on Jan 14, which gave inferior or
 even no improvement: group_ndims was mistaken. The policy module will soon
 need refactoring.

---
 examples/ppo_cartpole_gym.py       |  6 +++---
 tianshou/core/policy/base.py       | 20 ++++++++++++++++++++
 tianshou/core/policy/stochastic.py | 10 +++++-----
 3 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/examples/ppo_cartpole_gym.py b/examples/ppo_cartpole_gym.py
index 42d1c13..2710c98 100755
--- a/examples/ppo_cartpole_gym.py
+++ b/examples/ppo_cartpole_gym.py
@@ -24,7 +24,7 @@ if __name__ == '__main__':
     num_batches = 10
     batch_size = 512
 
-    seed = 5
+    seed = 0
     np.random.seed(seed)
     tf.set_random_seed(seed)
 
@@ -32,8 +32,8 @@ if __name__ == '__main__':
     observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
 
     def my_policy():
-        net = tf.layers.dense(observation_ph, 64, activation=tf.nn.tanh)
-        net = tf.layers.dense(net, 64, activation=tf.nn.tanh)
+        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
+        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
 
         action_logits = tf.layers.dense(net, action_dim, activation=None)
 
diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index 5657940..23cd45d 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -174,6 +174,20 @@ class StochasticPolicy(PolicyBase):
         log_p = self._log_prob(sampled_action)
         return tf.reduce_sum(log_p, tf.range(-self._group_ndims, 0))
 
+    def log_prob_old(self, sampled_action):
+        """
+        log_prob_old(sampled_action)
+
+        Compute the log probability density (mass) of the old policy at `sampled_action`.
+
+        :param sampled_action: A Tensor. The value at which to evaluate the log
+            probability density (mass) function. Must be able to broadcast to
+            have a shape of ``(... + )batch_shape + value_shape``.
+        :return: A Tensor of shape ``(... + )batch_shape[:-group_ndims]``.
+        """
+        log_p = self._log_prob_old(sampled_action)
+        return tf.reduce_sum(log_p, tf.range(-self._group_ndims, 0))
+
     # @add_name_scope
     def prob(self, sampled_action):
         """
@@ -195,6 +209,12 @@
         """
         raise NotImplementedError()
 
+    def _log_prob_old(self, sampled_action):
+        """
+        Private method for subclasses to rewrite the :meth:`log_prob_old` method.
+        """
+        raise NotImplementedError()
+
     def _prob(self, sampled_action):
         """
         Private method for subclasses to rewrite the :meth:`prob` method.
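[Editor's note] The group_ndims fix that this patch makes below is easier to see with a
small standalone example. The sketch that follows is not part of the patch; it only
illustrates, under TF 1.x semantics, why summing over tf.range(-group_ndims, 0) with
group_ndims=1 is wrong when _log_prob already returns one value per sample.

    # Minimal sketch (not part of the patch), assuming TF 1.x semantics.
    # For OnehotCategorical, _log_prob already returns one log-probability per
    # sample, shape (batch,). With group_ndims=1 the reduction in log_prob runs
    # over axis -1, i.e. over the batch, collapsing all samples into a scalar
    # and breaking the per-sample ratios PPO needs; with group_ndims=0 the axis
    # list tf.range(0, 0) is empty and the reduction is a no-op.
    import tensorflow as tf

    log_p = tf.constant([-0.1, -2.3, -0.7])             # per-sample log-probs, shape (3,)
    collapsed = tf.reduce_sum(log_p, tf.range(-1, 0))   # group_ndims=1 -> scalar -3.1
    unchanged = tf.reduce_sum(log_p, tf.range(0, 0))    # group_ndims=0 -> shape (3,), identity

    with tf.Session() as sess:
        print(sess.run([collapsed, unchanged]))         # scalar vs. the untouched vector
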
diff --git a/tianshou/core/policy/stochastic.py b/tianshou/core/policy/stochastic.py
index 33ee36a..294c21f 100644
--- a/tianshou/core/policy/stochastic.py
+++ b/tianshou/core/policy/stochastic.py
@@ -38,7 +38,7 @@ class OnehotCategorical(StochasticPolicy):
                  policy_callable,
                  observation_placeholder,
                  weight_update=1,
-                 group_ndims=1,
+                 group_ndims=0,
                  **kwargs):
         self.managed_placeholders = {'observation': observation_placeholder}
         self.weight_update = weight_update
@@ -47,7 +47,7 @@ class OnehotCategorical(StochasticPolicy):
         with tf.variable_scope('network'):
             logits, value_head = policy_callable()
             self._logits = tf.convert_to_tensor(logits, dtype=tf.float32)
-            self._action = tf.multinomial(self.logits, num_samples=1)
+            self._action = tf.multinomial(self._logits, num_samples=1)
             # TODO: self._action should be exactly the action tensor to run that directly gives action_dim
 
         if value_head is not None:
@@ -129,7 +129,7 @@ class OnehotCategorical(StochasticPolicy):
     def _prob(self, sampled_action):
         return tf.exp(self._log_prob(sampled_action))
 
-    def log_prob_old(self, sampled_action):
+    def _log_prob_old(self, sampled_action):
         return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self._logits_old)
 
     def update_weights(self):
@@ -242,7 +242,7 @@ class Normal(StochasticPolicy):
 
     @property
     def action_shape(self):
-        return tuple(self._mean.shape.as_list[1:])
+        return tuple(self._mean.shape.as_list()[1:])
 
     def _act(self, observation, my_feed_dict):
         # TODO: getting session like this maybe ugly. also maybe huge problem when parallel
@@ -265,7 +265,7 @@ class Normal(StochasticPolicy):
     def _prob(self, sampled_action):
        return tf.exp(self._log_prob(sampled_action))
 
-    def log_prob_old(self, sampled_action):
+    def _log_prob_old(self, sampled_action):
         """
         return the log_prob of the old policy when constructing tf graphs. Raises error when there's no old policy.
         :param sampled_action: the placeholder for sampled actions during interaction with the environment.
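
[Editor's note] For context on how the renamed _log_prob_old hooks are consumed: the
public log_prob/log_prob_old pair feeds the PPO clipped surrogate, which needs
per-sample importance ratios (hence the per-sample log-probs that group_ndims=0
preserves). The sketch below is not from this repository; log_prob_ph,
log_prob_old_ph, advantage_ph and clip_param are illustrative names standing in
for the tensors a policy such as OnehotCategorical would provide.

    # Hedged sketch (not part of the patch) of the PPO clipped surrogate loss,
    # built from per-sample log-probs of the current and old policies. The
    # placeholders stand in for pi.log_prob(action) and pi.log_prob_old(action).
    import tensorflow as tf

    log_prob_ph = tf.placeholder(tf.float32, shape=(None,))      # current policy log-probs
    log_prob_old_ph = tf.placeholder(tf.float32, shape=(None,))  # old (behavior) policy log-probs
    advantage_ph = tf.placeholder(tf.float32, shape=(None,))     # estimated advantages
    clip_param = 0.2                                             # illustrative clipping range

    ratio = tf.exp(log_prob_ph - log_prob_old_ph)                # per-sample importance ratio
    surr1 = ratio * advantage_ph
    surr2 = tf.clip_by_value(ratio, 1. - clip_param, 1. + clip_param) * advantage_ph
    ppo_clip_loss = -tf.reduce_mean(tf.minimum(surr1, surr2))    # minimize the negative clipped surrogate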