106 lines
3.8 KiB
Python
106 lines
3.8 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
from .base import StochasticPolicy
|
|
|
|
|
|
# Names exported by a star-import of this module.
__all__ = [
    'OnehotCategorical',
    'OnehotDiscrete',
]
|
|
|
|
|
|
class OnehotCategorical(StochasticPolicy):
    """
    Categorical policy over a discrete action set, parameterized by logits.

    Despite the name, actions are handled as integer category indices, not
    one-hot vectors: :meth:`_act` returns a single index and
    :meth:`_log_prob` feeds its argument as sparse labels to
    `tf.nn.sparse_softmax_cross_entropy_with_logits`.

    :param logits: A 2-D `float` Tensor (or value convertible to one) of
        shape (batch_size, n_categories). Each row holds the un-normalized
        log probabilities for all categories.

        .. math:: \\mathrm{logits} \\propto \\log p

    :param observation_placeholder: Forwarded to the base class. In
        :meth:`_act` the observation is fed through
        `self._observation_placeholder`, which is presumably this same
        placeholder stored by :class:`StochasticPolicy` -- confirm against
        the base class.
    :param dtype: Forwarded to the base class as `act_dtype`. Defaults to
        `tf.int32`. Note that the value returned by :meth:`_act` is not cast
        to this dtype.
    :param group_ndims: Forwarded unchanged to the base class. Default is 0.
    :param kwargs: Extra keyword arguments forwarded to the base class.
    """

    def __init__(self, logits, observation_placeholder, dtype=None,
                 group_ndims=0, **kwargs):
        self._logits = tf.convert_to_tensor(logits)

        if dtype is None:
            dtype = tf.int32
        # TODO: validate dtype, e.g. assert_same_float_and_int_dtype([], dtype)

        # TODO: flexible policy output rank?
        # NOTE(review): the op returned by tf.assert_rank is discarded, so
        # only a check performed at graph-construction time can fire here --
        # confirm whether a runtime assertion is wanted.
        tf.assert_rank(self._logits, rank=2)
        self._n_categories = self._logits.get_shape()[-1].value

        # Build the sampling op exactly once. Creating it inside _act() would
        # add a new node to the graph on every action, growing the graph
        # without bound and failing on a finalized graph.
        self._sample_op = tf.multinomial(self._logits, num_samples=1)

        super(OnehotCategorical, self).__init__(
            act_dtype=dtype,
            param_dtype=self._logits.dtype,
            is_continuous=False,
            observation_placeholder=observation_placeholder,
            group_ndims=group_ndims,
            **kwargs)

    @property
    def logits(self):
        """The un-normalized log probabilities."""
        return self._logits

    @property
    def n_categories(self):
        """The number of categories in the distribution."""
        return self._n_categories

    def _act(self, observation):
        """
        Sample one action for a single observation.

        :param observation: A single (un-batched) observation that supports
            `observation[None]` indexing, e.g. a numpy array.
        :return: The sampled category index, taken from position `[0, 0]`
            of the sampling op's output.
        :raises RuntimeError: If there is no default TensorFlow session.
        """
        # TODO: this may be ugly. also maybe huge problem when parallel
        sess = tf.get_default_session()
        if sess is None:
            raise RuntimeError(
                'OnehotCategorical._act() requires a default TensorFlow '
                'session; call it inside a `with sess.as_default():` block.')

        # observation[None] adds one dimension at the beginning, turning the
        # single observation into a batch of size 1.
        sampled_action = sess.run(
            self._sample_op,
            feed_dict={self._observation_placeholder: observation[None]})

        # One batch entry, one sample: pick the only element.
        return sampled_action[0, 0]

    def _log_prob(self, sampled_action):
        """
        Log-probability of the given actions under the current logits.

        :param sampled_action: Integer category indices, used as the sparse
            `labels` argument (one index per row of `logits`).
        :return: A Tensor equal to the negated sparse softmax cross entropy
            between `sampled_action` and `self.logits`.
        """
        # Actions are integer indices, so the sparse cross-entropy is used.
        # A dense one-hot variant (softmax_cross_entropy_with_logits with
        # explicit broadcasting/reshaping) previously sat here as
        # commented-out code and has been dropped.
        return -tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=sampled_action, logits=self.logits)

    def _prob(self, sampled_action):
        """Probability of the given actions: `exp` of :meth:`_log_prob`."""
        return tf.exp(self._log_prob(sampled_action))
|
|
|
|
|
|
# Alternate public name bound to the very same class object.
OnehotDiscrete = OnehotCategorical