# TrulyPPO/baselines/common/distributions.py
import tensorflow as tf
import numpy as np
import baselines.common.tf_util as U
from baselines.a2c.utils import fc
from tensorflow.python.ops import math_ops
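
# Probability distributions over action spaces, plus the PdType factories that
# build them from flat parameter tensors or network latents. This is the
# TensorFlow 1.x `distributions` module from OpenAI Baselines, as adapted in
# TrulyPPO.
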
class Pd(object):
    """
    A particular probability distribution
    """
    def flatparam(self):
        raise NotImplementedError
    def mode(self):
        raise NotImplementedError
    def neglogp(self, x):
        # Usually it's easier to define the negative logprob
        raise NotImplementedError
    def kl(self, other):
        raise NotImplementedError
    def entropy(self):
        raise NotImplementedError
    def sample(self):
        raise NotImplementedError
    def logp(self, x):
        return -self.neglogp(x)
    def p(self, x):
        return tf.exp(self.logp(x))

class PdType(object):
    """
    Parametrized family of probability distributions
    """
    def pdclass(self):
        raise NotImplementedError
    def pdfromflat(self, flat):
        return self.pdclass()(flat)
    def pdfromlatent(self, latent_vector, **kwargs):
        raise NotImplementedError
    def param_shape(self):
        raise NotImplementedError
    def sample_shape(self):
        raise NotImplementedError
    def sample_dtype(self):
        raise NotImplementedError
    def param_placeholder(self, prepend_shape, name=None):
        return tf.placeholder(dtype=tf.float32, shape=prepend_shape + self.param_shape(), name=name)
    def sample_placeholder(self, prepend_shape, name=None, default_value=None):
        if default_value is None:
            return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape + self.sample_shape(), name=name)
        else:
            return tf.placeholder_with_default(default_value, shape=prepend_shape + self.sample_shape(), name=name)

class CategoricalPdType(PdType):
    def __init__(self, ncat):
        self.ncat = ncat
    def pdclass(self):
        return CategoricalPd
    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0, **kwargs):
        pdparam = fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam
    def param_shape(self):
        return [self.ncat]
    def sample_shape(self):
        return []
    def sample_dtype(self):
        return tf.int32

class MultiCategoricalPdType(PdType):
    def __init__(self, nvec):
        self.ncats = nvec
    def pdclass(self):
        return MultiCategoricalPd
    def pdfromflat(self, flat):
        return MultiCategoricalPd(self.ncats, flat)
    def param_shape(self):
        return [sum(self.ncats)]
    def sample_shape(self):
        return [len(self.ncats)]
    def sample_dtype(self):
        return tf.int32

class DiagGaussianPdType(PdType):
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return DiagGaussianPd
    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0, latent_logstd=None, logstd_initial=0.):
        # TODO: Add **kwargs to other PdType
        mean = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
        if latent_logstd is not None:
            logstd = fc(latent_logstd, 'logstd', self.size, init_scale=0.0, init_bias=logstd_initial)
        else:
            logstd = tf.get_variable(name='logstd', shape=[1, self.size], initializer=tf.constant_initializer(logstd_initial))
        self.logstd = logstd
        # `mean * 0.0 + logstd` broadcasts the [1, size] logstd to the batch shape of `mean`.
        pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
        return self.pdfromflat(pdparam), mean
    def param_shape(self):
        return [2 * self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.float32

class BetaPdType(PdType):
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return BetaPd
    # def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
    #     mean = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
    #     logstd = tf.get_variable(name='logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
    #     pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
    #     return self.pdfromflat(pdparam), mean
    def param_shape(self):
        return [2 * self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.float32

class MultimodalPdType(PdType):
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return MultimodalPd
    # def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
    #     mean = fc(latent_vector, 'pi', 2 * self.size, init_scale=init_scale, init_bias=init_bias)
    #     logstd = tf.get_variable(name='logstd', shape=[1, 2 * self.size], initializer=tf.zeros_initializer())
    #     alpha = tf.get_variable(name='alpha', shape=[1, self.size], initializer=tf.random_normal_initializer())
    #     pdparam = tf.concat([mean, mean * 0.0 + logstd, alpha], axis=1)
    #     return self.pdfromflat(pdparam), mean
    def param_shape(self):
        return [5 * self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.float32

class BernoulliPdType(PdType):
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return BernoulliPd
    def param_shape(self):
        return [self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.int32

# WRONG SECOND DERIVATIVES
# class CategoricalPd(Pd):
#     def __init__(self, logits):
#         self.logits = logits
#         self.ps = tf.nn.softmax(logits)
#     @classmethod
#     def fromflat(cls, flat):
#         return cls(flat)
#     def flatparam(self):
#         return self.logits
#     def mode(self):
#         return U.argmax(self.logits, axis=-1)
#     def logp(self, x):
#         return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)
#     def kl(self, other):
#         return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \
#             - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
#     def entropy(self):
#         return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
#     def sample(self):
#         u = tf.random_uniform(tf.shape(self.logits))
#         return U.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)

class CategoricalPd(Pd):
    def __init__(self, logits):
        self.logits = logits
    def flatparam(self):
        return self.logits
    def mode(self):
        return tf.argmax(self.logits, axis=-1)
    def neglogp(self, x):
        # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
        # Note: we can't use sparse_softmax_cross_entropy_with_logits because
        # the implementation does not allow second-order derivatives...
        one_hot_actions = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
        return tf.nn.softmax_cross_entropy_with_logits(
            logits=self.logits,
            labels=one_hot_actions)
    def kl(self, other):
        a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keep_dims=True)
        a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keep_dims=True)
        ea0 = tf.exp(a0)
        ea1 = tf.exp(a1)
        z0 = tf.reduce_sum(ea0, axis=-1, keep_dims=True)
        z1 = tf.reduce_sum(ea1, axis=-1, keep_dims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1)
    def entropy(self):
        a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keep_dims=True)
        ea0 = tf.exp(a0)
        z0 = tf.reduce_sum(ea0, axis=-1, keep_dims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1)
    def sample(self, sample_shape=None):
        shape_new = tf.shape(self.logits)
        if sample_shape is not None:
            if not isinstance(sample_shape, (list, tuple)):
                sample_shape = [sample_shape]
            shape_new = tf.concat((sample_shape, shape_new), axis=0)
        # Gumbel-max trick: adding Gumbel noise -log(-log(u)) to the logits and
        # taking the argmax draws a sample from the softmax distribution.
        u = tf.random_uniform(shape_new)
        return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)

class MultiCategoricalPd(Pd):
    def __init__(self, nvec, flat):
        self.flat = flat
        self.categoricals = list(map(CategoricalPd, tf.split(flat, nvec, axis=-1)))
    def flatparam(self):
        return self.flat
    def mode(self):
        return tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)
    def neglogp(self, x):
        return tf.add_n([p.neglogp(px) for p, px in zip(self.categoricals, tf.unstack(x, axis=-1))])
    def kl(self, other):
        return tf.add_n([p.kl(q) for p, q in zip(self.categoricals, other.categoricals)])
    def entropy(self):
        return tf.add_n([p.entropy() for p in self.categoricals])
    def sample(self):
        return tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)
    @classmethod
    def fromflat(cls, flat):
        raise NotImplementedError

class DiagGaussianPd(Pd):
    def __init__(self, flat, var_type='logstd'):
        self.flat = flat
        mean, var_term = tf.split(axis=len(flat.shape) - 1, num_or_size_splits=2, value=flat)
        self.mean = mean
        '''
        std: standard deviation
        var: variance, var = std^2
        logstd: log(std)
        logvar: log(var) = 2 * logstd
        '''
        if var_type == 'logstd':
            logstd = var_term
            std = tf.exp(var_term)
        elif var_type == 'std':
            std = var_term
            logstd = tf.log(std)
        elif var_type == 'logvar':
            logstd = var_term / 2
            std = tf.exp(logstd)
        elif var_type == 'var':
            std = tf.sqrt(var_term)
            logstd = tf.log(std)
        else:
            raise NotImplementedError
        self.logstd = logstd
        self.std = std
    def flatparam(self):
        return self.flat
    def mode(self):
        return self.mean
    def neglogp(self, x):
        return 0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), axis=-1) \
            + 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[-1]) \
            + tf.reduce_sum(self.logstd, axis=-1)
    def kl(self, other):
        assert isinstance(other, DiagGaussianPd)
        return tf.reduce_sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=-1)
    def wasserstein(self, other):
        # Squared 2-Wasserstein distance between diagonal Gaussians.
        assert isinstance(other, DiagGaussianPd)
        return tf.reduce_sum(tf.square(self.mean - other.mean) + tf.square(self.std - other.std), axis=-1)
    def entropy(self):
        # H(N(mean, diag(std^2))) = sum(logstd) + 0.5 * k * log(2 * pi * e)
        return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1)
    def sample(self, sample_shape=None):
        shape_new = tf.shape(self.mean)
        if sample_shape is not None:
            if not isinstance(sample_shape, (list, tuple)):
                sample_shape = [sample_shape]
            shape_new = tf.concat((sample_shape, shape_new), axis=0)
        return self.mean + self.std * tf.random_normal(shape_new)
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)

class BetaPd(Pd):
    '''
    Beta distribution whose (0, 1) support is affinely mapped to (low, high).
    `flat` concatenates the alpha and beta concentration parameters.
    '''
    def __init__(self, flat, std_type='log'):
        self.flat = flat
        alpha, beta = tf.split(axis=len(flat.shape) - 1, num_or_size_splits=2, value=flat)
        self.alpha = alpha
        tfd = tf.contrib.distributions  # tf.contrib exists only in TF 1.x
        self._pd = tfd.Beta(alpha, beta)
        self.low, self.high = 0., 5.  # map the beta dist support (0, 1) to (low, high)
    def flatparam(self):
        return self.flat
    def neglogp(self, x):
        # Change of variables: the density on (low, high) is the Beta density
        # at the rescaled point, divided by (high - low).
        return -tf.log(self._pd.prob(self._inverse_map_to_support(x)) / (self.high - self.low))
    def sample(self, sample_shape=None):
        return tf.expand_dims(self._map_to_support(self._pd.sample()), 0)
    def _map_to_support(self, x):
        # map the beta dist support (0, 1) to (low, high)
        return self.low + x * (self.high - self.low)
    def _inverse_map_to_support(self, x):
        return (x - self.low) / (self.high - self.low)
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)

class MultimodalPd(Pd):
    def __init__(self, flat, std_type='log'):
        self.flat = flat
        self.mean1, self.mean2, self.logstd1, self.logstd2, self.alpha = tf.split(
            axis=len(flat.shape) - 1, num_or_size_splits=5, value=flat)
        self.std1 = tf.exp(self.logstd1)
        self.std2 = tf.exp(self.logstd2)
        self.alpha = tf.sigmoid(self.alpha)
        self.alpha = tf.squeeze(self.alpha, 0)
        # size = per-dimension width of each Gaussian component (the last axis)
        self._pdtype = DiagGaussianPdType(self.mean1.shape[-1])
        self._pd1 = self._pdtype.pdfromflat(tf.concat([self.mean1, self.logstd1], axis=-1))
        self._pd2 = self._pdtype.pdfromflat(tf.concat([self.mean2, self.logstd2], axis=-1))
    def flatparam(self):
        return self.flat
    def mode(self):
        return self.mean1, self.mean2
    def neglogp(self, x):
        # Mixture density: alpha * p1(x) + (1 - alpha) * p2(x)
        return -tf.log(self.alpha * self._pd1.p(x) + (1 - self.alpha) * self._pd2.p(x))
    def sample(self, sample_shape=None):
        # Note: this returns the alpha-weighted average of one sample from each
        # component, not an exact draw from the mixture (which would first pick
        # a component at random).
        return self.alpha * self._pd1.sample() + (1 - self.alpha) * self._pd2.sample()
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)

class BernoulliPd(Pd):
    def __init__(self, logits):
        self.logits = logits
        self.ps = tf.sigmoid(logits)
    def flatparam(self):
        return self.logits
    def mode(self):
        return tf.round(self.ps)
    def neglogp(self, x):
        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=-1)
    def kl(self, other):
        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=-1) - tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
    def entropy(self):
        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
    def sample(self):
        u = tf.random_uniform(tf.shape(self.ps))
        return tf.to_float(math_ops.less(u, self.ps))
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)

def make_pdtype(ac_space):
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
        assert len(ac_space.shape) == 1
        return DiagGaussianPdType(ac_space.shape[0])
    elif isinstance(ac_space, spaces.Discrete):
        return CategoricalPdType(ac_space.n)
    elif isinstance(ac_space, spaces.MultiDiscrete):
        return MultiCategoricalPdType(ac_space.nvec)
    elif isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    else:
        raise NotImplementedError

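# A minimal usage sketch (assuming `ac_space` is a gym.spaces.Box of shape (2,)):
#
#     pdtype = make_pdtype(ac_space)            # -> DiagGaussianPdType(2)
#     flat = pdtype.param_placeholder([None])   # [batch, 2 * size]: means ++ logstds
#     pd = pdtype.pdfromflat(flat)
#     ac = pd.sample()
#     neglogp = pd.neglogp(ac)
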
def make_pdtype_bandit(ac_space, policy_type):
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
        assert len(ac_space.shape) == 1
        if policy_type == 'Gaussian':
            return DiagGaussianPdType(ac_space.shape[0])
        elif policy_type == 'Beta':
            return BetaPdType(ac_space.shape[0])
        elif policy_type == 'Multimodal':
            return MultimodalPdType(ac_space.shape[0])
        else:
            raise NotImplementedError
    elif isinstance(ac_space, spaces.Discrete):
        return CategoricalPdType(ac_space.n)
    elif isinstance(ac_space, spaces.MultiDiscrete):
        return MultiCategoricalPdType(ac_space.nvec)
    elif isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    else:
        raise NotImplementedError

def shape_el(v, i):
    maybe = v.get_shape()[i]
    if maybe is not None:
        return maybe
    else:
        return tf.shape(v)[i]

@U.in_session
def test_probtypes():
    np.random.seed(0)

    pdparam_diag_gauss = np.array([-.2, .3, .4, -.5, .1, -.5, .1, 0.8])
    diag_gauss = DiagGaussianPdType(pdparam_diag_gauss.size // 2) #pylint: disable=E1101
    validate_probtype(diag_gauss, pdparam_diag_gauss)

    pdparam_categorical = np.array([-.2, .3, .5])
    categorical = CategoricalPdType(pdparam_categorical.size) #pylint: disable=E1101
    validate_probtype(categorical, pdparam_categorical)

    nvec = [1, 2, 3]
    pdparam_multicategorical = np.array([-.2, .3, .5, .1, 1, -.1])
    multicategorical = MultiCategoricalPdType(nvec) #pylint: disable=E1101
    validate_probtype(multicategorical, pdparam_multicategorical)

    pdparam_bernoulli = np.array([-.2, .3, .5])
    bernoulli = BernoulliPdType(pdparam_bernoulli.size) #pylint: disable=E1101
    validate_probtype(bernoulli, pdparam_bernoulli)

def validate_probtype(probtype, pdparam):
    N = 100000
    # Check to see if mean negative log likelihood == differential entropy
    Mval = np.repeat(pdparam[None, :], N, axis=0)
    M = probtype.param_placeholder([N])
    X = probtype.sample_placeholder([N])
    pd = probtype.pdfromflat(M)
    calcloglik = U.function([X, M], pd.logp(X))
    calcent = U.function([M], pd.entropy())
    Xval = tf.get_default_session().run(pd.sample(), feed_dict={M: Mval})
    logliks = calcloglik(Xval, Mval)
    entval_ll = - logliks.mean() #pylint: disable=E1101
    entval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
    entval = calcent(Mval).mean() #pylint: disable=E1101
    assert np.abs(entval - entval_ll) < 3 * entval_ll_stderr # within 3 sigmas

    # Check to see if kldiv[p,q] = - ent[p] - E_p[log q]
    M2 = probtype.param_placeholder([N])
    pd2 = probtype.pdfromflat(M2)
    q = pdparam + np.random.randn(pdparam.size) * 0.1
    Mval2 = np.repeat(q[None, :], N, axis=0)
    calckl = U.function([M, M2], pd.kl(pd2))
    klval = calckl(Mval, Mval2).mean() #pylint: disable=E1101
    logliks = calcloglik(Xval, Mval2)
    klval_ll = - entval - logliks.mean() #pylint: disable=E1101
    klval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
    assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr # within 3 sigmas
    print('ok on', probtype, pdparam)
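
if __name__ == '__main__':
    # Run the distribution self-checks above as a script.
    test_probtypes()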