fixed the bugs on Jan 14, which gives inferior or even no improvement. mistook group_ndims. policy will soon need refactoring.
This commit is contained in:
parent
983cd36074
commit
d599506dc9
@ -24,7 +24,7 @@ if __name__ == '__main__':
|
||||
num_batches = 10
|
||||
batch_size = 512
|
||||
|
||||
seed = 5
|
||||
seed = 0
|
||||
np.random.seed(seed)
|
||||
tf.set_random_seed(seed)
|
||||
|
||||
@ -32,8 +32,8 @@ if __name__ == '__main__':
|
||||
observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
|
||||
|
||||
def my_policy():
|
||||
net = tf.layers.dense(observation_ph, 64, activation=tf.nn.tanh)
|
||||
net = tf.layers.dense(net, 64, activation=tf.nn.tanh)
|
||||
net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
|
||||
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
|
||||
|
||||
action_logits = tf.layers.dense(net, action_dim, activation=None)
|
||||
|
||||
|
@ -174,6 +174,20 @@ class StochasticPolicy(PolicyBase):
|
||||
log_p = self._log_prob(sampled_action)
|
||||
return tf.reduce_sum(log_p, tf.range(-self._group_ndims, 0))
|
||||
|
||||
def log_prob_old(self, sampled_action):
|
||||
"""
|
||||
log_prob(sampled_action)
|
||||
|
||||
Compute log probability density (mass) function at `given` value.
|
||||
|
||||
:param given: A Tensor. The value at which to evaluate log probability
|
||||
density (mass) function. Must be able to broadcast to have a shape
|
||||
of ``(... + )batch_shape + value_shape``.
|
||||
:return: A Tensor of shape ``(... + )batch_shape[:-group_ndims]``.
|
||||
"""
|
||||
log_p = self._log_prob_old(sampled_action)
|
||||
return tf.reduce_sum(log_p, tf.range(-self._group_ndims, 0))
|
||||
|
||||
# @add_name_scope
|
||||
def prob(self, sampled_action):
|
||||
"""
|
||||
@ -195,6 +209,12 @@ class StochasticPolicy(PolicyBase):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _log_prob_old(self, sampled_action):
|
||||
"""
|
||||
Private method for subclasses to rewrite the :meth:`log_prob` method.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _prob(self, sampled_action):
|
||||
"""
|
||||
Private method for subclasses to rewrite the :meth:`prob` method.
|
||||
|
@ -38,7 +38,7 @@ class OnehotCategorical(StochasticPolicy):
|
||||
policy_callable,
|
||||
observation_placeholder,
|
||||
weight_update=1,
|
||||
group_ndims=1,
|
||||
group_ndims=0,
|
||||
**kwargs):
|
||||
self.managed_placeholders = {'observation': observation_placeholder}
|
||||
self.weight_update = weight_update
|
||||
@ -47,7 +47,7 @@ class OnehotCategorical(StochasticPolicy):
|
||||
with tf.variable_scope('network'):
|
||||
logits, value_head = policy_callable()
|
||||
self._logits = tf.convert_to_tensor(logits, dtype=tf.float32)
|
||||
self._action = tf.multinomial(self.logits, num_samples=1)
|
||||
self._action = tf.multinomial(self._logits, num_samples=1)
|
||||
# TODO: self._action should be exactly the action tensor to run that directly gives action_dim
|
||||
|
||||
if value_head is not None:
|
||||
@ -129,7 +129,7 @@ class OnehotCategorical(StochasticPolicy):
|
||||
def _prob(self, sampled_action):
|
||||
return tf.exp(self._log_prob(sampled_action))
|
||||
|
||||
def log_prob_old(self, sampled_action):
|
||||
def _log_prob_old(self, sampled_action):
|
||||
return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self._logits_old)
|
||||
|
||||
def update_weights(self):
|
||||
@ -242,7 +242,7 @@ class Normal(StochasticPolicy):
|
||||
|
||||
@property
|
||||
def action_shape(self):
|
||||
return tuple(self._mean.shape.as_list[1:])
|
||||
return tuple(self._mean.shape.as_list()[1:])
|
||||
|
||||
def _act(self, observation, my_feed_dict):
|
||||
# TODO: getting session like this maybe ugly. also maybe huge problem when parallel
|
||||
@ -265,7 +265,7 @@ class Normal(StochasticPolicy):
|
||||
def _prob(self, sampled_action):
|
||||
return tf.exp(self._log_prob(sampled_action))
|
||||
|
||||
def log_prob_old(self, sampled_action):
|
||||
def _log_prob_old(self, sampled_action):
|
||||
"""
|
||||
return the log_prob of the old policy when constructing tf graphs. Raises error when there's no old policy.
|
||||
:param sampled_action: the placeholder for sampled actions during interaction with the environment.
|
||||
|
Loading…
x
Reference in New Issue
Block a user