fixed the bugs on Jan 14, which gives inferior or even no improvement. mistook group_ndims. policy will soon need refactoring.

This commit is contained in:
haoshengzou 2018-01-15 16:32:30 +08:00
parent 983cd36074
commit d599506dc9
3 changed files with 28 additions and 8 deletions

View File

@ -24,7 +24,7 @@ if __name__ == '__main__':
num_batches = 10 num_batches = 10
batch_size = 512 batch_size = 512
seed = 5 seed = 0
np.random.seed(seed) np.random.seed(seed)
tf.set_random_seed(seed) tf.set_random_seed(seed)
@ -32,8 +32,8 @@ if __name__ == '__main__':
observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
def my_policy(): def my_policy():
net = tf.layers.dense(observation_ph, 64, activation=tf.nn.tanh) net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
net = tf.layers.dense(net, 64, activation=tf.nn.tanh) net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
action_logits = tf.layers.dense(net, action_dim, activation=None) action_logits = tf.layers.dense(net, action_dim, activation=None)

View File

@ -174,6 +174,20 @@ class StochasticPolicy(PolicyBase):
log_p = self._log_prob(sampled_action) log_p = self._log_prob(sampled_action)
return tf.reduce_sum(log_p, tf.range(-self._group_ndims, 0)) return tf.reduce_sum(log_p, tf.range(-self._group_ndims, 0))
def log_prob_old(self, sampled_action):
"""
log_prob(sampled_action)
Compute log probability density (mass) function at `given` value.
:param given: A Tensor. The value at which to evaluate log probability
density (mass) function. Must be able to broadcast to have a shape
of ``(... + )batch_shape + value_shape``.
:return: A Tensor of shape ``(... + )batch_shape[:-group_ndims]``.
"""
log_p = self._log_prob_old(sampled_action)
return tf.reduce_sum(log_p, tf.range(-self._group_ndims, 0))
# @add_name_scope # @add_name_scope
def prob(self, sampled_action): def prob(self, sampled_action):
""" """
@ -195,6 +209,12 @@ class StochasticPolicy(PolicyBase):
""" """
raise NotImplementedError() raise NotImplementedError()
def _log_prob_old(self, sampled_action):
"""
Private method for subclasses to rewrite the :meth:`log_prob` method.
"""
raise NotImplementedError()
def _prob(self, sampled_action): def _prob(self, sampled_action):
""" """
Private method for subclasses to rewrite the :meth:`prob` method. Private method for subclasses to rewrite the :meth:`prob` method.

View File

@ -38,7 +38,7 @@ class OnehotCategorical(StochasticPolicy):
policy_callable, policy_callable,
observation_placeholder, observation_placeholder,
weight_update=1, weight_update=1,
group_ndims=1, group_ndims=0,
**kwargs): **kwargs):
self.managed_placeholders = {'observation': observation_placeholder} self.managed_placeholders = {'observation': observation_placeholder}
self.weight_update = weight_update self.weight_update = weight_update
@ -47,7 +47,7 @@ class OnehotCategorical(StochasticPolicy):
with tf.variable_scope('network'): with tf.variable_scope('network'):
logits, value_head = policy_callable() logits, value_head = policy_callable()
self._logits = tf.convert_to_tensor(logits, dtype=tf.float32) self._logits = tf.convert_to_tensor(logits, dtype=tf.float32)
self._action = tf.multinomial(self.logits, num_samples=1) self._action = tf.multinomial(self._logits, num_samples=1)
# TODO: self._action should be exactly the action tensor to run that directly gives action_dim # TODO: self._action should be exactly the action tensor to run that directly gives action_dim
if value_head is not None: if value_head is not None:
@ -129,7 +129,7 @@ class OnehotCategorical(StochasticPolicy):
def _prob(self, sampled_action): def _prob(self, sampled_action):
return tf.exp(self._log_prob(sampled_action)) return tf.exp(self._log_prob(sampled_action))
def log_prob_old(self, sampled_action): def _log_prob_old(self, sampled_action):
return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self._logits_old) return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self._logits_old)
def update_weights(self): def update_weights(self):
@ -242,7 +242,7 @@ class Normal(StochasticPolicy):
@property @property
def action_shape(self): def action_shape(self):
return tuple(self._mean.shape.as_list[1:]) return tuple(self._mean.shape.as_list()[1:])
def _act(self, observation, my_feed_dict): def _act(self, observation, my_feed_dict):
# TODO: getting session like this maybe ugly. also maybe huge problem when parallel # TODO: getting session like this maybe ugly. also maybe huge problem when parallel
@ -265,7 +265,7 @@ class Normal(StochasticPolicy):
def _prob(self, sampled_action): def _prob(self, sampled_action):
return tf.exp(self._log_prob(sampled_action)) return tf.exp(self._log_prob(sampled_action))
def log_prob_old(self, sampled_action): def _log_prob_old(self, sampled_action):
""" """
return the log_prob of the old policy when constructing tf graphs. Raises error when there's no old policy. return the log_prob of the old policy when constructing tf graphs. Raises error when there's no old policy.
:param sampled_action: the placeholder for sampled actions during interaction with the environment. :param sampled_action: the placeholder for sampled actions during interaction with the environment.