minor fix in core/policy
This commit is contained in:
parent
e3c0478fa0
commit
88ecaa332d
@ -148,19 +148,19 @@ class StochasticPolicy(object):
|
||||
"""
|
||||
return self._act(observation)
|
||||
|
||||
if n_samples is None:
|
||||
samples = self._sample(n_samples=1)
|
||||
return tf.squeeze(samples, axis=0)
|
||||
elif isinstance(n_samples, int):
|
||||
return self._sample(n_samples)
|
||||
else:
|
||||
n_samples = tf.convert_to_tensor(n_samples, dtype=tf.int32)
|
||||
_assert_rank_op = tf.assert_rank(
|
||||
n_samples, 0,
|
||||
message="n_samples should be a scalar (0-D Tensor).")
|
||||
with tf.control_dependencies([_assert_rank_op]):
|
||||
samples = self._sample(n_samples)
|
||||
return samples
|
||||
# if n_samples is None:
|
||||
# samples = self._sample(n_samples=1)
|
||||
# return tf.squeeze(samples, axis=0)
|
||||
# elif isinstance(n_samples, int):
|
||||
# return self._sample(n_samples)
|
||||
# else:
|
||||
# n_samples = tf.convert_to_tensor(n_samples, dtype=tf.int32)
|
||||
# _assert_rank_op = tf.assert_rank(
|
||||
# n_samples, 0,
|
||||
# message="n_samples should be a scalar (0-D Tensor).")
|
||||
# with tf.control_dependencies([_assert_rank_op]):
|
||||
# samples = self._sample(n_samples)
|
||||
# return samples
|
||||
|
||||
def _act(self, observation):
|
||||
"""
|
||||
|
@ -76,7 +76,6 @@ class OnehotCategorical(StochasticPolicy):
|
||||
return sampled_action
|
||||
|
||||
def _log_prob(self, sampled_action):
|
||||
sampled_action_onehot = tf.one_hot(sampled_action, self.n_categories, dtype=self.act_dtype)
|
||||
return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self.logits)
|
||||
|
||||
# given = tf.cast(given, self.param_dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user