From 88ecaa332d01bdd8c142abefb82ef3cf591a36e2 Mon Sep 17 00:00:00 2001 From: haosheng Date: Mon, 11 Dec 2017 13:25:22 +0800 Subject: [PATCH] minor fix in core/policy --- tianshou/core/policy/base.py | 26 +++++++++++++------------- tianshou/core/policy/stochastic.py | 1 - 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py index a61661c..9b149d7 100644 --- a/tianshou/core/policy/base.py +++ b/tianshou/core/policy/base.py @@ -148,19 +148,19 @@ class StochasticPolicy(object): """ return self._act(observation) - if n_samples is None: - samples = self._sample(n_samples=1) - return tf.squeeze(samples, axis=0) - elif isinstance(n_samples, int): - return self._sample(n_samples) - else: - n_samples = tf.convert_to_tensor(n_samples, dtype=tf.int32) - _assert_rank_op = tf.assert_rank( - n_samples, 0, - message="n_samples should be a scalar (0-D Tensor).") - with tf.control_dependencies([_assert_rank_op]): - samples = self._sample(n_samples) - return samples + # if n_samples is None: + # samples = self._sample(n_samples=1) + # return tf.squeeze(samples, axis=0) + # elif isinstance(n_samples, int): + # return self._sample(n_samples) + # else: + # n_samples = tf.convert_to_tensor(n_samples, dtype=tf.int32) + # _assert_rank_op = tf.assert_rank( + # n_samples, 0, + # message="n_samples should be a scalar (0-D Tensor).") + # with tf.control_dependencies([_assert_rank_op]): + # samples = self._sample(n_samples) + # return samples def _act(self, observation): """ diff --git a/tianshou/core/policy/stochastic.py b/tianshou/core/policy/stochastic.py index 822600a..37eb1be 100644 --- a/tianshou/core/policy/stochastic.py +++ b/tianshou/core/policy/stochastic.py @@ -76,7 +76,6 @@ class OnehotCategorical(StochasticPolicy): return sampled_action def _log_prob(self, sampled_action): - sampled_action_onehot = tf.one_hot(sampled_action, self.n_categories, dtype=self.act_dtype) return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self.logits) # given = tf.cast(given, self.param_dtype)