import tensorflow as tf
import numpy as np
import baselines.common.tf_util as U
from baselines.a2c.utils import fc
from tensorflow.python.ops import math_ops

class Pd(object):
    """
    A particular probability distribution
    """
    def flatparam(self):
        raise NotImplementedError
    def mode(self):
        raise NotImplementedError
    def neglogp(self, x):
        # Usually it's easier to define the negative logprob
        raise NotImplementedError
    def kl(self, other):
        raise NotImplementedError
    def entropy(self):
        raise NotImplementedError
    def sample(self):
        raise NotImplementedError
    def logp(self, x):
        return -self.neglogp(x)
    def p(self, x):
        return tf.exp(self.logp(x))
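
# Usage sketch (illustrative, not part of the original API surface): every
# concrete Pd implements neglogp(), so logp() and p() above come for free, e.g.
#
#     pd = CategoricalPd(logits)   # logits: [batch, ncat]
#     ac = pd.sample()             # [batch] sampled actions
#     nll = pd.neglogp(ac)         # [batch] negative log-likelihoods
#     ent = pd.entropy()           # [batch] entropies
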
class PdType(object):
    """
    Parametrized family of probability distributions
    """
    def pdclass(self):
        raise NotImplementedError
    def pdfromflat(self, flat):
        return self.pdclass()(flat)
    def pdfromlatent(self, latent_vector, **kwargs):
        raise NotImplementedError
    def param_shape(self):
        raise NotImplementedError
    def sample_shape(self):
        raise NotImplementedError
    def sample_dtype(self):
        raise NotImplementedError

    def param_placeholder(self, prepend_shape, name=None):
        return tf.placeholder(dtype=tf.float32, shape=prepend_shape + self.param_shape(), name=name)
    def sample_placeholder(self, prepend_shape, name=None, default_value=None):
        if default_value is None:
            return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape + self.sample_shape(), name=name)
        else:
            return tf.placeholder_with_default(default_value, shape=prepend_shape + self.sample_shape(), name=name)
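
# Sketch of how a PdType ties placeholders and distributions together in a
# TF1 graph (an assumed usage pattern, not code from this file):
#
#     pdtype = CategoricalPdType(ncat=6)
#     M = pdtype.param_placeholder([None])    # float32 [batch, 6] logits
#     A = pdtype.sample_placeholder([None])   # int32 [batch] actions
#     pd = pdtype.pdfromflat(M)
#     loss = tf.reduce_mean(pd.neglogp(A))
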
class CategoricalPdType(PdType):
    def __init__(self, ncat):
        self.ncat = ncat

    def pdclass(self):
        return CategoricalPd

    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0, **kwargs):
        pdparam = fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
        return self.pdfromflat(pdparam), pdparam

    def param_shape(self):
        return [self.ncat]

    def sample_shape(self):
        return []

    def sample_dtype(self):
        return tf.int32

class MultiCategoricalPdType(PdType):
    def __init__(self, nvec):
        self.ncats = nvec
    def pdclass(self):
        return MultiCategoricalPd
    def pdfromflat(self, flat):
        return MultiCategoricalPd(self.ncats, flat)
    def param_shape(self):
        return [sum(self.ncats)]
    def sample_shape(self):
        return [len(self.ncats)]
    def sample_dtype(self):
        return tf.int32

class DiagGaussianPdType(PdType):
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return DiagGaussianPd

    def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0, latent_logstd=None, logstd_initial=0.):
        # TODO: Add **kwargs to the other PdTypes
        mean = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
        if latent_logstd is not None:
            logstd = fc(latent_logstd, 'logstd', self.size, init_scale=0.0, init_bias=logstd_initial)
        else:
            logstd = tf.get_variable(name='logstd', shape=[1, self.size],
                                     initializer=tf.constant_initializer(logstd_initial))
        self.logstd = logstd
        # `mean * 0.0 + logstd` broadcasts the [1, size] logstd up to the batch
        # shape of mean so the two halves can be concatenated.
        pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
        return self.pdfromflat(pdparam), mean

    def param_shape(self):
        return [2 * self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.float32
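
# Note on pdfromlatent above: the flat parameter layout is [mean, logstd],
# matching param_shape() == [2 * size]. A typical policy-head usage might be
# (illustrative; `latent` is assumed to be a [batch, nh] tensor, `ac_dim` an
# assumed action dimensionality):
#
#     pdtype = DiagGaussianPdType(ac_dim)
#     pd, mean = pdtype.pdfromlatent(latent, init_scale=0.01)
#     ac = pd.sample()
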
class BetaPdType(PdType):
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return BetaPd
    # def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
    #     mean = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
    #     logstd = tf.get_variable(name='logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
    #     pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
    #     return self.pdfromflat(pdparam), mean

    def param_shape(self):
        return [2 * self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.float32

class MultimodalPdType(PdType):
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return MultimodalPd
    # def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
    #     mean = fc(latent_vector, 'pi', 2 * self.size, init_scale=init_scale, init_bias=init_bias)
    #     logstd = tf.get_variable(name='logstd', shape=[1, 2 * self.size], initializer=tf.zeros_initializer())
    #     alpha = tf.get_variable(name='alpha', shape=[1, self.size], initializer=tf.random_normal_initializer())
    #     pdparam = tf.concat([mean, mean * 0.0 + logstd, alpha], axis=1)
    #     return self.pdfromflat(pdparam), mean

    def param_shape(self):
        return [5 * self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.float32

class BernoulliPdType(PdType):
    def __init__(self, size):
        self.size = size
    def pdclass(self):
        return BernoulliPd
    def param_shape(self):
        return [self.size]
    def sample_shape(self):
        return [self.size]
    def sample_dtype(self):
        return tf.int32

# WRONG SECOND DERIVATIVES
# class CategoricalPd(Pd):
#     def __init__(self, logits):
#         self.logits = logits
#         self.ps = tf.nn.softmax(logits)
#     @classmethod
#     def fromflat(cls, flat):
#         return cls(flat)
#     def flatparam(self):
#         return self.logits
#     def mode(self):
#         return U.argmax(self.logits, axis=-1)
#     def logp(self, x):
#         return -tf.nn.sparse_softmax_cross_entropy_with_logits(self.logits, x)
#     def kl(self, other):
#         return tf.nn.softmax_cross_entropy_with_logits(other.logits, self.ps) \
#             - tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
#     def entropy(self):
#         return tf.nn.softmax_cross_entropy_with_logits(self.logits, self.ps)
#     def sample(self):
#         u = tf.random_uniform(tf.shape(self.logits))
#         return U.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)

class CategoricalPd(Pd):
    def __init__(self, logits):
        self.logits = logits
    def flatparam(self):
        return self.logits
    def mode(self):
        return tf.argmax(self.logits, axis=-1)
    def neglogp(self, x):
        # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
        # Note: we can't use sparse_softmax_cross_entropy_with_logits because
        # that implementation does not allow second-order derivatives...
        one_hot_actions = tf.one_hot(x, self.logits.get_shape().as_list()[-1])
        return tf.nn.softmax_cross_entropy_with_logits(
            logits=self.logits,
            labels=one_hot_actions)
    def kl(self, other):
        # Subtract the max logit before exponentiating (log-sum-exp trick)
        # so the softmax terms stay numerically stable.
        a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keep_dims=True)
        a1 = other.logits - tf.reduce_max(other.logits, axis=-1, keep_dims=True)
        ea0 = tf.exp(a0)
        ea1 = tf.exp(a1)
        z0 = tf.reduce_sum(ea0, axis=-1, keep_dims=True)
        z1 = tf.reduce_sum(ea1, axis=-1, keep_dims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)), axis=-1)
    def entropy(self):
        a0 = self.logits - tf.reduce_max(self.logits, axis=-1, keep_dims=True)
        ea0 = tf.exp(a0)
        z0 = tf.reduce_sum(ea0, axis=-1, keep_dims=True)
        p0 = ea0 / z0
        return tf.reduce_sum(p0 * (tf.log(z0) - a0), axis=-1)
    def sample(self, sample_shape=None):
        shape_new = tf.shape(self.logits)
        if sample_shape is not None:
            if not isinstance(sample_shape, (list, tuple)):
                sample_shape = [sample_shape]
            shape_new = tf.concat((sample_shape, shape_new), axis=0)
        # Gumbel-max trick: argmax(logits - log(-log(u))) samples from softmax(logits)
        u = tf.random_uniform(shape_new)
        return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
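
# The sampler above is the Gumbel-max trick: for u ~ Uniform(0, 1),
# -log(-log(u)) is standard Gumbel noise, and argmax over perturbed logits is
# distributed as softmax(logits). In TF1 the same draw could presumably be
# taken with tf.multinomial(logits, 1); the explicit form keeps the noise
# tensor visible in the graph.
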
class MultiCategoricalPd(Pd):
    def __init__(self, nvec, flat):
        self.flat = flat
        self.categoricals = list(map(CategoricalPd, tf.split(flat, nvec, axis=-1)))
    def flatparam(self):
        return self.flat
    def mode(self):
        return tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)
    def neglogp(self, x):
        return tf.add_n([p.neglogp(px) for p, px in zip(self.categoricals, tf.unstack(x, axis=-1))])
    def kl(self, other):
        return tf.add_n([p.kl(q) for p, q in zip(self.categoricals, other.categoricals)])
    def entropy(self):
        return tf.add_n([p.entropy() for p in self.categoricals])
    def sample(self):
        return tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)
    @classmethod
    def fromflat(cls, flat):
        raise NotImplementedError
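
# Example (illustrative): with nvec = [2, 3], a flat logits tensor of width 5
# is split into a 2-way and a 3-way CategoricalPd; samples and modes are
# stacked into an int32 vector of length len(nvec) per batch row, and
# neglogp/kl/entropy sum the per-categorical terms.
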
class DiagGaussianPd(Pd):
    def __init__(self, flat, var_type='logstd'):
        self.flat = flat
        mean, var_term = tf.split(axis=len(flat.shape) - 1, num_or_size_splits=2, value=flat)
        self.mean = mean
        '''
        var_type selects how var_term parametrizes the spread:
        std: standard deviation
        var: variance, var = std^2
        logstd: log(std)
        logvar: log(var)
        '''
        if var_type == 'logstd':
            logstd = var_term
            std = tf.exp(var_term)
        elif var_type == 'std':
            std = var_term
            logstd = tf.log(std)
        elif var_type == 'logvar':
            logstd = var_term / 2
            std = tf.exp(logstd)
        elif var_type == 'var':
            std = tf.sqrt(var_term)
            logstd = tf.log(std)
        else:
            raise NotImplementedError
        self.logstd = logstd
        self.std = std

    def flatparam(self):
        return self.flat
    def mode(self):
        return self.mean
    def neglogp(self, x):
        # Gaussian negative log-likelihood, summed over action dimensions:
        # 0.5 * ||(x - mean) / std||^2 + 0.5 * d * log(2 * pi) + sum(logstd)
        return 0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std), axis=-1) \
               + 0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[-1]) \
               + tf.reduce_sum(self.logstd, axis=-1)
    def kl(self, other):
        assert isinstance(other, DiagGaussianPd)
        return tf.reduce_sum(other.logstd - self.logstd + (tf.square(self.std) + tf.square(self.mean - other.mean)) / (2.0 * tf.square(other.std)) - 0.5, axis=-1)

    def wasserstein(self, other):
        # Squared 2-Wasserstein distance between diagonal Gaussians.
        assert isinstance(other, DiagGaussianPd)
        return tf.reduce_sum(tf.square(self.mean - other.mean) + tf.square(self.std - other.std), axis=-1)

    def entropy(self):
        return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e), axis=-1)
    def sample(self, sample_shape=None):
        shape_new = tf.shape(self.mean)
        if sample_shape is not None:
            if not isinstance(sample_shape, (list, tuple)):
                sample_shape = [sample_shape]
            shape_new = tf.concat((sample_shape, shape_new), axis=0)
        # Reparameterized draw: mean + std * eps, with eps ~ N(0, I)
        return self.mean + self.std * tf.random_normal(shape_new)
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
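
# Reference formulas for the diagonal Gaussian above (standard closed forms,
# matching the code term by term):
#   entropy    H        = sum_i(logstd_i + 0.5 * log(2 * pi * e))
#   KL(p || q)          = sum_i(logstd_q - logstd_p
#                                + (std_p^2 + (mean_p - mean_q)^2) / (2 * std_q^2) - 0.5)
#   W2^2(p, q)          = sum_i((mean_p - mean_q)^2 + (std_p - std_q)^2)
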
class BetaPd(Pd):
    '''
    The flat parameter is split into (alpha, beta), the concentration
    parameters of a Beta distribution whose (0, 1) support is affinely
    mapped to (low, high).
    '''
    def __init__(self, flat, std_type='log'):
        self.flat = flat
        alpha, beta = tf.split(axis=len(flat.shape) - 1, num_or_size_splits=2, value=flat)
        self.alpha = alpha
        tfd = tf.contrib.distributions
        self._pd = tfd.Beta(alpha, beta)
        self.low, self.high = 0., 5.  # map the beta dist support (0, 1) to (low, high)

    def flatparam(self):
        return self.flat

    def neglogp(self, x):
        # Change of variables: for y = low + (high - low) * x,
        # p_y(y) = p_x((y - low) / (high - low)) / (high - low).
        return -tf.log(self._pd.prob(self._inverse_map_to_support(x)) / (self.high - self.low))

    def sample(self, sample_shape=None):
        return tf.expand_dims(self._map_to_support(self._pd.sample()), 0)

    def _map_to_support(self, x):
        # map the beta dist support (0, 1) to (low, high)
        return self.low + x * (self.high - self.low)

    def _inverse_map_to_support(self, x):
        return (x - self.low) / (self.high - self.low)

    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
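
# Sanity sketch for the support mapping (an assumed check, not original code):
# with low=0 and high=5, a Beta sample x in (0, 1) maps to y = 5 * x, and
# neglogp divides the Beta density by (high - low) = 5, the Jacobian of the
# affine change of variables, so the transformed density still integrates to 1.
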
class MultimodalPd(Pd):
    def __init__(self, flat, std_type='log'):
        self.flat = flat
        self.mean1, self.mean2, self.logstd1, self.logstd2, self.alpha = tf.split(
            axis=len(flat.shape) - 1, num_or_size_splits=5, value=flat)
        self.std1 = tf.exp(self.logstd1)
        self.std2 = tf.exp(self.logstd2)
        self.alpha = tf.sigmoid(self.alpha)  # mixture weight in (0, 1)
        self.alpha = tf.squeeze(self.alpha, 0)

        self._pdtype = DiagGaussianPdType(self.mean1.shape[0])
        self._pd1 = self._pdtype.pdfromflat(tf.concat([self.mean1, self.logstd1], axis=-1))
        self._pd2 = self._pdtype.pdfromflat(tf.concat([self.mean2, self.logstd2], axis=-1))

    def flatparam(self):
        return self.flat
    def mode(self):
        return self.mean1, self.mean2
    def neglogp(self, x):
        # Negative log-likelihood of a two-component diagonal-Gaussian mixture.
        return -tf.log(self.alpha * self._pd1.p(x) + (1 - self.alpha) * self._pd2.p(x))

    def sample(self, sample_shape=None):
        return self.alpha * self._pd1.sample() + (1 - self.alpha) * self._pd2.sample()

    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
class BernoulliPd(Pd):
    def __init__(self, logits):
        self.logits = logits
        self.ps = tf.sigmoid(logits)
    def flatparam(self):
        return self.logits
    def mode(self):
        return tf.round(self.ps)
    def neglogp(self, x):
        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=tf.to_float(x)), axis=-1)
    def kl(self, other):
        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=other.logits, labels=self.ps), axis=-1) - tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
    def entropy(self):
        return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.logits, labels=self.ps), axis=-1)
    def sample(self):
        u = tf.random_uniform(tf.shape(self.ps))
        return tf.to_float(math_ops.less(u, self.ps))
    @classmethod
    def fromflat(cls, flat):
        return cls(flat)
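
# Note on BernoulliPd above: with labels equal to the probabilities ps,
# sigmoid_cross_entropy_with_logits(logits, ps) equals the entropy H(ps),
# which is why entropy() is a cross-entropy with itself and kl() is
# cross-entropy(p, q) minus entropy(p).
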

def make_pdtype(ac_space):
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
        assert len(ac_space.shape) == 1
        return DiagGaussianPdType(ac_space.shape[0])
    elif isinstance(ac_space, spaces.Discrete):
        return CategoricalPdType(ac_space.n)
    elif isinstance(ac_space, spaces.MultiDiscrete):
        return MultiCategoricalPdType(ac_space.nvec)
    elif isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    else:
        raise NotImplementedError
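
# Example (illustrative): a gym Box(low=-1, high=1, shape=(3,)) action space
# yields DiagGaussianPdType(3), while Discrete(6) yields CategoricalPdType(6).
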

def make_pdtype_bandit(ac_space, policy_type):
    from gym import spaces
    if isinstance(ac_space, spaces.Box):
        assert len(ac_space.shape) == 1
        if policy_type == 'Gaussian':
            return DiagGaussianPdType(ac_space.shape[0])
        elif policy_type == 'Beta':
            return BetaPdType(ac_space.shape[0])
        elif policy_type == 'Multimodal':
            return MultimodalPdType(ac_space.shape[0])
        else:
            raise NotImplementedError
    elif isinstance(ac_space, spaces.Discrete):
        return CategoricalPdType(ac_space.n)
    elif isinstance(ac_space, spaces.MultiDiscrete):
        return MultiCategoricalPdType(ac_space.nvec)
    elif isinstance(ac_space, spaces.MultiBinary):
        return BernoulliPdType(ac_space.n)
    else:
        raise NotImplementedError

def shape_el(v, i):
    maybe = v.get_shape()[i]
    if maybe is not None:
        return maybe
    else:
        return tf.shape(v)[i]

@U.in_session
def test_probtypes():
    np.random.seed(0)

    pdparam_diag_gauss = np.array([-.2, .3, .4, -.5, .1, -.5, .1, 0.8])
    diag_gauss = DiagGaussianPdType(pdparam_diag_gauss.size // 2) #pylint: disable=E1101
    validate_probtype(diag_gauss, pdparam_diag_gauss)

    pdparam_categorical = np.array([-.2, .3, .5])
    categorical = CategoricalPdType(pdparam_categorical.size) #pylint: disable=E1101
    validate_probtype(categorical, pdparam_categorical)

    nvec = [1, 2, 3]
    pdparam_multicategorical = np.array([-.2, .3, .5, .1, 1, -.1])
    multicategorical = MultiCategoricalPdType(nvec) #pylint: disable=E1101
    validate_probtype(multicategorical, pdparam_multicategorical)

    pdparam_bernoulli = np.array([-.2, .3, .5])
    bernoulli = BernoulliPdType(pdparam_bernoulli.size) #pylint: disable=E1101
    validate_probtype(bernoulli, pdparam_bernoulli)

def validate_probtype(probtype, pdparam):
    N = 100000
    # Check to see if mean negative log likelihood == differential entropy
    Mval = np.repeat(pdparam[None, :], N, axis=0)
    M = probtype.param_placeholder([N])
    X = probtype.sample_placeholder([N])
    pd = probtype.pdfromflat(M)
    calcloglik = U.function([X, M], pd.logp(X))
    calcent = U.function([M], pd.entropy())
    Xval = tf.get_default_session().run(pd.sample(), feed_dict={M: Mval})
    logliks = calcloglik(Xval, Mval)
    entval_ll = -logliks.mean() #pylint: disable=E1101
    entval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
    entval = calcent(Mval).mean() #pylint: disable=E1101
    assert np.abs(entval - entval_ll) < 3 * entval_ll_stderr  # within 3 sigmas

    # Check to see if kldiv[p, q] = -ent[p] - E_p[log q]
    M2 = probtype.param_placeholder([N])
    pd2 = probtype.pdfromflat(M2)
    q = pdparam + np.random.randn(pdparam.size) * 0.1
    Mval2 = np.repeat(q[None, :], N, axis=0)
    calckl = U.function([M, M2], pd.kl(pd2))
    klval = calckl(Mval, Mval2).mean() #pylint: disable=E1101
    logliks = calcloglik(Xval, Mval2)
    klval_ll = -entval - logliks.mean() #pylint: disable=E1101
    klval_ll_stderr = logliks.std() / np.sqrt(N) #pylint: disable=E1101
    assert np.abs(klval - klval_ll) < 3 * klval_ll_stderr  # within 3 sigmas
    print('ok on', probtype, pdparam)