
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
import tensorflow as tf
# from zhusuan.utils import add_name_scope
# TODO: a even more "base" class for policy
class PolicyBase(object):
    """
    Abstract base class for all policies.

    A policy maps an observation to an action, typically with some
    exploration mixed in. Subclasses must override :meth:`act`.
    """

    def act(self, observation):
        """
        Return an action for ``observation``.

        Must be implemented by subclasses; the base class always raises.

        :param observation: The current observation from the environment.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
class StochasticPolicy(PolicyBase):
"""
The :class:`StochasticPolicy` class is the base class for various probabilistic
distributions which support batch inputs, generating batches of samples and
evaluate probabilities at batches of given values.
The typical input shape for a :class:`Distribution` is like
``batch_shape + input_shape``. where ``input_shape`` represents the shape
of non-batch input parameter, :attr:`batch_shape` represents how many
independent inputs are fed into the distribution.
Samples generated are of shape
``([n_samples]+ )batch_shape + value_shape``. The first additional axis
is omitted only when passed `n_samples` is None (by default), in which
case one sample is generated. :attr:`value_shape` is the non-batch value
shape of the distribution. For a univariate distribution, its
:attr:`value_shape` is [].
There are cases where a batch of random variables are grouped into a
single event so that their probabilities should be computed together. This
is achieved by setting `group_ndims` argument, which defaults to 0.
The last `group_ndims` number of axes in :attr:`batch_shape` are
grouped into a single event. For example,
``Normal(..., group_ndims=1)`` will set the last axis of its
:attr:`batch_shape` to a single event, i.e., a multivariate Normal with
identity covariance matrix.
When evaluating probabilities at given values, the given Tensor should be
broadcastable to shape ``(... + )batch_shape + value_shape``. The returned
Tensor has shape ``(... + )batch_shape[:-group_ndims]``.
.. seealso::
:doc:`/concepts`
For both, the parameter `dtype` represents type of samples. For discrete,
can be set by user. For continuous, automatically determined from parameter
types.
The value type of `prob` and `log_prob` will be `param_dtype` which is
deduced from the parameter(s) when initializating. And `dtype` must be
among `int16`, `int32`, `int64`, `float16`, `float32` and `float64`.
When two or more parameters are tensors and they have different type,
`TypeError` will be raised.
:param dtype: The value type of samples from the distribution.
:param param_dtype: The parameter(s) type of the distribution.
:param is_continuous: Whether the distribution is continuous.
:param is_reparameterized: A bool. Whether the gradients of samples can
and are allowed to propagate back into inputs, using the
reparametrization trick from (Kingma, 2013).
:param use_path_derivative: A bool. Whether when taking the gradients
of the log-probability to propagate them through the parameters
of the distribution (False meaning you do propagate them). This
is based on the paper "Sticking the Landing: Simple,
Lower-Variance Gradient Estimators for Variational Inference"
:param group_ndims: A 0-D `int32` Tensor representing the number of
dimensions in :attr:`batch_shape` (counted from the end) that are
grouped into a single event, so that their probabilities are calculated
together. Default is 0, which means a single value is an event.
See above for more detailed explanation.
"""
def __init__(self,
act_dtype,
param_dtype,
is_continuous,
observation_placeholder,
2017-12-23 17:25:16 +08:00
group_ndims=0, # maybe useful for repeat_action
**kwargs):
self._act_dtype = act_dtype
self._param_dtype = param_dtype
self._is_continuous = is_continuous
self._observation_placeholder = observation_placeholder
if isinstance(group_ndims, int):
if group_ndims < 0:
raise ValueError("group_ndims must be non-negative.")
self._group_ndims = group_ndims
else:
group_ndims = tf.convert_to_tensor(group_ndims, tf.int32)
_assert_rank_op = tf.assert_rank(
group_ndims, 0,
message="group_ndims should be a scalar (0-D Tensor).")
_assert_nonnegative_op = tf.assert_greater_equal(
group_ndims, 0,
message="group_ndims must be non-negative.")
with tf.control_dependencies([_assert_rank_op,
_assert_nonnegative_op]):
self._group_ndims = tf.identity(group_ndims)
@property
def act_dtype(self):
"""The sample data type of the policy."""
return self._act_dtype
@property
def param_dtype(self):
"""The parameter(s) type of the distribution."""
return self._param_dtype
@property
def is_continuous(self):
"""Whether the distribution is continuous."""
return self._is_continuous
@property
def group_ndims(self):
"""
The number of dimensions in :attr:`batch_shape` (counted from the end)
that are grouped into a single event, so that their probabilities are
calculated together. See `Distribution` for more detailed explanation.
"""
return self._group_ndims
# @add_name_scope
def act(self, observation, my_feed_dict={}):
"""
sample(n_samples=None)
Return samples from the distribution. When `n_samples` is None (by
default), one sample of shape ``batch_shape + value_shape`` is
generated. For a scalar `n_samples`, the returned Tensor has a new
sample dimension with size `n_samples` inserted at ``axis=0``, i.e.,
the shape of samples is ``[n_samples] + batch_shape + value_shape``.
:param n_samples: A 0-D `int32` Tensor or None. How many independent
samples to draw from the distribution.
:return: A Tensor of samples.
"""
return self._act(observation, my_feed_dict)
def _act(self, observation, my_feed_dict):
"""
Private method for subclasses to rewrite the :meth:`sample` method.
"""
raise NotImplementedError()
# @add_name_scope
def log_prob(self, sampled_action):
"""
log_prob(sampled_action)
Compute log probability density (mass) function at `given` value.
:param given: A Tensor. The value at which to evaluate log probability
density (mass) function. Must be able to broadcast to have a shape
of ``(... + )batch_shape + value_shape``.
:return: A Tensor of shape ``(... + )batch_shape[:-group_ndims]``.
"""
log_p = self._log_prob(sampled_action)
return tf.reduce_sum(log_p, tf.range(-self._group_ndims, 0))
# @add_name_scope
def prob(self, sampled_action):
"""
prob(given)
Compute probability density (mass) function at `given` value.
:param given: A Tensor. The value at which to evaluate probability
density (mass) function. Must be able to broadcast to have a shape
of ``(... + )batch_shape + value_shape``.
:return: A Tensor of shape ``(... + )batch_shape[:-group_ndims]``.
"""
p = self._prob(sampled_action)
return tf.reduce_prod(p, tf.range(-self._group_ndims, 0))
def _log_prob(self, sampled_action):
"""
Private method for subclasses to rewrite the :meth:`log_prob` method.
"""
raise NotImplementedError()
def _prob(self, sampled_action):
"""
Private method for subclasses to rewrite the :meth:`prob` method.
"""
raise NotImplementedError()