Fixed the bugs from Jan 14 that gave inferior or even no improvement: group_ndims was mistaken. The policy code will soon need refactoring.
This commit is contained in:
parent
983cd36074
commit
d599506dc9
@@ -24,7 +24,7 @@ if __name__ == '__main__':
     num_batches = 10
     batch_size = 512

-    seed = 5
+    seed = 0
     np.random.seed(seed)
     tf.set_random_seed(seed)

@@ -32,8 +32,8 @@ if __name__ == '__main__':
     observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)

     def my_policy():
-        net = tf.layers.dense(observation_ph, 64, activation=tf.nn.tanh)
-        net = tf.layers.dense(net, 64, activation=tf.nn.tanh)
+        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
+        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)

         action_logits = tf.layers.dense(net, action_dim, activation=None)

@@ -174,6 +174,20 @@ class StochasticPolicy(PolicyBase):
         log_p = self._log_prob(sampled_action)
         return tf.reduce_sum(log_p, tf.range(-self._group_ndims, 0))

+    def log_prob_old(self, sampled_action):
+        """
+        log_prob(sampled_action)
+
+        Compute log probability density (mass) function at `given` value.
+
+        :param given: A Tensor. The value at which to evaluate log probability
+            density (mass) function. Must be able to broadcast to have a shape
+            of ``(... + )batch_shape + value_shape``.
+        :return: A Tensor of shape ``(... + )batch_shape[:-group_ndims]``.
+        """
+        log_p = self._log_prob_old(sampled_action)
+        return tf.reduce_sum(log_p, tf.range(-self._group_ndims, 0))
+
     # @add_name_scope
     def prob(self, sampled_action):
         """
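Note (not part of the commit): the trailing-axis sum that both `log_prob` and the new `log_prob_old` perform follows the convention that the last `group_ndims` axes of the per-element log density form one event and are summed out. A minimal sketch of what `tf.reduce_sum(log_p, tf.range(-group_ndims, 0))` computes, assuming a TF 1.x runtime like the code in this diff (shapes and values are made up for illustration):

import numpy as np
import tensorflow as tf  # assumes TF 1.x, matching the code in this diff

# hypothetical per-dimension log densities: batch of 4 samples, 3-dim action
log_p = tf.constant(np.full((4, 3), np.log(0.5), dtype=np.float32))

# group_ndims=1 groups the trailing action axis into one event:
# the sum runs over axis -1 and leaves one log prob per sample, shape (4,)
per_sample = tf.reduce_sum(log_p, tf.range(-1, 0))

with tf.Session() as sess:
    print(sess.run(per_sample))  # each entry equals 3 * log(0.5)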
@@ -195,6 +209,12 @@ class StochasticPolicy(PolicyBase):
         """
         raise NotImplementedError()

+    def _log_prob_old(self, sampled_action):
+        """
+        Private method for subclasses to rewrite the :meth:`log_prob` method.
+        """
+        raise NotImplementedError()
+
     def _prob(self, sampled_action):
         """
         Private method for subclasses to rewrite the :meth:`prob` method.
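Note (not part of the commit): together with the public `log_prob_old` added above, this forms a template-method pair: the base class owns the grouping over trailing axes, and a concrete policy only fills in the per-element hook. A self-contained stand-in sketch of that contract (these classes are illustrative, not the repo's real ones):

import tensorflow as tf  # assumes TF 1.x

class BaseSketch(object):
    """Stand-in for the base-class contract shown in the diff above."""
    def __init__(self, group_ndims=0):
        self._group_ndims = group_ndims

    def log_prob_old(self, sampled_action):
        # public wrapper: sum the per-element log density over the grouped trailing axes
        log_p = self._log_prob_old(sampled_action)
        return tf.reduce_sum(log_p, tf.range(-self._group_ndims, 0))

    def _log_prob_old(self, sampled_action):
        raise NotImplementedError()

class NormalSketch(BaseSketch):
    """Fills in only the per-dimension hook; the grouping stays in the base class."""
    def __init__(self, mean, std):
        super(NormalSketch, self).__init__(group_ndims=1)
        self._old = tf.distributions.Normal(loc=mean, scale=std)

    def _log_prob_old(self, sampled_action):
        return self._old.log_prob(sampled_action)  # per-dimension log density of the old policy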
@@ -38,7 +38,7 @@ class OnehotCategorical(StochasticPolicy):
                  policy_callable,
                  observation_placeholder,
                  weight_update=1,
-                 group_ndims=1,
+                 group_ndims=0,
                  **kwargs):
         self.managed_placeholders = {'observation': observation_placeholder}
         self.weight_update = weight_update
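Note (not part of the commit): this `group_ndims=1` to `group_ndims=0` change is the "group_ndims was mistaken" fix from the commit message. The categorical policy's per-sample log probability appears to already be a scalar per batch element (see the cross-entropy override below), so there is no trailing event axis left to group. With `group_ndims=1` the base-class wrapper sums over the last remaining axis, i.e. over the whole batch, collapsing the per-sample log probs into one number, which would explain the "inferior or even no improvement"; with `group_ndims=0` the sum is a no-op because `tf.range(0, 0)` is empty. A small shape check, assuming TF 1.x:

import tensorflow as tf  # assumes TF 1.x

logits = tf.constant([[2.0, 0.0, 0.0],
                      [0.0, 2.0, 0.0]])            # batch of 2, 3 actions
actions = tf.constant([0, 1])                      # integer action indices
log_p = -tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=actions, logits=logits)                 # shape (2,): one log prob per sample

collapsed = tf.reduce_sum(log_p, tf.range(-1, 0))  # group_ndims=1: a single scalar for the batch
kept = tf.reduce_sum(log_p, tf.range(0, 0))        # group_ndims=0: shape (2,), unchanged

with tf.Session() as sess:
    print(sess.run(collapsed))  # one number for the whole batch
    print(sess.run(kept))       # two per-sample log probs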
@@ -47,7 +47,7 @@ class OnehotCategorical(StochasticPolicy):
         with tf.variable_scope('network'):
             logits, value_head = policy_callable()
             self._logits = tf.convert_to_tensor(logits, dtype=tf.float32)
-            self._action = tf.multinomial(self.logits, num_samples=1)
+            self._action = tf.multinomial(self._logits, num_samples=1)
             # TODO: self._action should be exactly the action tensor to run that directly gives action_dim

         if value_head is not None:
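Note (not part of the commit): the sampling op now reads the private `self._logits` assigned on the line above; whether a public `logits` property also exists is not visible in this diff, so the private attribute is the safe reference. For context, `tf.multinomial` in TF 1.x (later renamed `tf.random.categorical`) draws integer class indices from unnormalized log probabilities and returns shape `(batch, num_samples)`, which is presumably what the TODO about the action tensor shape refers to. A tiny sketch:

import tensorflow as tf  # assumes TF 1.x

logits = tf.constant([[0.0, 0.0, 10.0]])          # one observation, 3 actions, heavily favoring action 2
action = tf.multinomial(logits, num_samples=1)    # int64 indices, shape (1, 1)

with tf.Session() as sess:
    print(sess.run(action))  # almost always [[2]]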
@@ -129,7 +129,7 @@ class OnehotCategorical(StochasticPolicy):
     def _prob(self, sampled_action):
         return tf.exp(self._log_prob(sampled_action))

-    def log_prob_old(self, sampled_action):
+    def _log_prob_old(self, sampled_action):
         return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self._logits_old)

     def update_weights(self):
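Note (not part of the commit): renaming the override to the private hook keeps the grouping logic in one place: the base class's public `log_prob_old` now calls this `_log_prob_old` instead of being shadowed by a subclass method with the public name. The body itself relies on the identity that the negative sparse softmax cross entropy of an integer label equals the log probability the softmax categorical assigns to that label. A quick check, assuming TF 1.x:

import tensorflow as tf  # assumes TF 1.x

logits = tf.constant([[1.0, 2.0, 3.0]])
action = tf.constant([2])

via_xent = -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=action, logits=logits)
via_softmax = tf.log(tf.nn.softmax(logits))[0, 2]

with tf.Session() as sess:
    print(sess.run(via_xent), sess.run(via_softmax))  # both are approximately -0.408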
@@ -242,7 +242,7 @@ class Normal(StochasticPolicy):

     @property
     def action_shape(self):
-        return tuple(self._mean.shape.as_list[1:])
+        return tuple(self._mean.shape.as_list()[1:])

     def _act(self, observation, my_feed_dict):
         # TODO: getting session like this maybe ugly. also maybe huge problem when parallel
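Note (not part of the commit): `TensorShape.as_list` is a method, so the old line sliced the bound method object and would raise a `TypeError` as soon as `action_shape` is accessed; calling it first returns a plain Python list (ints, or `None` for unknown dims) that can be sliced. A tiny check, assuming TF 1.x and a made-up mean tensor:

import tensorflow as tf  # assumes TF 1.x

mean = tf.placeholder(tf.float32, shape=(None, 6))   # hypothetical mean of the Normal policy
print(tuple(mean.shape.as_list()[1:]))               # (6,) -- the per-action shape
# tuple(mean.shape.as_list[1:])                      # TypeError: 'method' object is not subscriptable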
@@ -265,7 +265,7 @@ class Normal(StochasticPolicy):
     def _prob(self, sampled_action):
         return tf.exp(self._log_prob(sampled_action))

-    def log_prob_old(self, sampled_action):
+    def _log_prob_old(self, sampled_action):
         """
         return the log_prob of the old policy when constructing tf graphs. Raises error when there's no old policy.
         :param sampled_action: the placeholder for sampled actions during interaction with the environment.
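Note (not part of the commit): exposing the log prob of the same sampled action under both the current and the old (frozen) parameters is the usual plumbing for an importance ratio between the two policies, as in TRPO/PPO-style surrogate losses; this diff only fixes the naming and grouping, and the commit message says the policy code will be refactored. A hedged sketch of how such a ratio is typically formed, with placeholders standing in for the `log_prob` / `log_prob_old` outputs:

import tensorflow as tf  # assumes TF 1.x

# stand-ins for policy.log_prob(action) and policy.log_prob_old(action), shape (batch,)
log_p = tf.placeholder(tf.float32, shape=(None,))
log_p_old = tf.placeholder(tf.float32, shape=(None,))
advantage = tf.placeholder(tf.float32, shape=(None,))

ratio = tf.exp(log_p - log_p_old)                    # pi(a|s) / pi_old(a|s)
clipped = tf.clip_by_value(ratio, 0.8, 1.2)          # PPO-style clipping with epsilon = 0.2
surrogate = -tf.reduce_mean(tf.minimum(ratio * advantage, clipped * advantage))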