auto target network. ppo_cartpole.py runs OK, but results differ from the previous version even with the same random seed; still needs debugging.
This commit is contained in:
parent 3b222f5edb
commit fed3bf2a12
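On the "same seed, different results" issue noted in the commit message: a first check that narrows this down is whether rebuilding the graph with the same seeds even yields identical initial weights. A minimal TF1 sketch (toy two-layer net, not the example script itself; assumes the same TF 1.x API used in the diff below):

# illustrative sketch, not part of the commit
import numpy as np
import tensorflow as tf

def initial_weights(seed):
    # rebuild a small graph from scratch with fixed seeds and return its initial weights
    tf.reset_default_graph()
    np.random.seed(seed)
    tf.set_random_seed(seed)
    obs_ph = tf.placeholder(tf.float32, shape=(None, 4))
    net = tf.layers.dense(obs_ph, 32, activation=tf.nn.tanh)
    tf.layers.dense(net, 1, activation=None)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        return sess.run(tf.trainable_variables())

a, b = initial_weights(0), initial_weights(0)
# True means initialization is deterministic; the discrepancy would then come from
# sampling / data collection rather than from network construction
print(all(np.allclose(x, y) for x, y in zip(a, b)))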
@@ -17,24 +17,9 @@ from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize


def policy_net(observation, action_dim, scope=None):
    """
    Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf
    # for tutorial purpose, placeholders are explicitly appended with '_ph' suffix

    :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
    :param action_dim: int. The number of actions.
    :param scope: str. Specifying the scope of the variables.
    """
    # with tf.variable_scope(scope):
    net = tf.layers.dense(observation, 32, activation=tf.nn.tanh)
    net = tf.layers.dense(net, 32, activation=tf.nn.tanh)

    act_mean = tf.layers.dense(net, action_dim, activation=None)

    return act_mean


if __name__ == '__main__': # a clean version with only policy net, no value net
if __name__ == '__main__':
    env = normalize(CartpoleEnv())
    observation_dim = env.observation_space.shape
    action_dim = env.action_space.flat_dim
@@ -47,44 +32,51 @@ if __name__ == '__main__': # a clean version with only policy net, no value net
    np.random.seed(seed)
    tf.set_random_seed(seed)

    # 1. build network with pure tf
    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
    ### 1. build network with pure tf
    observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input

    with tf.variable_scope('pi'):
        action_mean = policy_net(observation, action_dim, 'pi')
        action_logstd = tf.get_variable('action_logstd', shape=(action_dim,))
    train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
    with tf.variable_scope('pi_old'):
        action_mean_old = policy_net(observation, action_dim, 'pi_old')
        action_logstd_old = tf.get_variable('action_logstd_old', shape=(action_dim,))
    pi_old_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'pi_old')
    def my_policy():
        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
        action_mean = tf.layers.dense(net, action_dim, activation=None)
        action_logstd = tf.get_variable('action_logstd', shape=(action_dim, ))

    # 2. build losses, optimizers
    pi = policy.Normal(action_mean, action_logstd, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc.
    # for continuous action space, you may need to change an environment to run
    pi_old = policy.Normal(action_mean_old, action_logstd_old, observation_placeholder=observation)
        # value = tf.layers.dense(net, 1, activation=None)

    action = tf.placeholder(dtype=tf.float32, shape=(None, action_dim)) # batch of integer actions
    advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients
        return action_mean, action_logstd, None # None value head

    ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict
    # TODO: current implementation of passing function or overriding function has to return a value head
    # to allow network sharing between policy and value networks. This makes 'policy' and 'value_function'
    # imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact
    # with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is
    # not based on passing function or overriding function.

    ### 2. build policy, losses, optimizers
    pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0)

    # action = tf.placeholder(dtype=tf.float32, shape=(None, action_dim)) # batch of integer actions
    # advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients

    ppo_loss_clip = losses.ppo_clip(pi, clip_param) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict

    total_loss = ppo_loss_clip
    optimizer = tf.train.AdamOptimizer(1e-4)
    train_op = optimizer.minimize(total_loss, var_list=train_var_list)
    train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)

    # 3. define data collection
    ### 3. define data collection
    training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper
    # ShihongSong: Replay(), see dqn_example.py
    # maybe a dict to manage the elements to be collected

    # 4. start training
    ### 4. start training
    # init = tf.global_variables_initializer()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        # sync pi and pi_old
        sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)])
        pi.sync_weights() # TODO: automate this for policies with target network

        start_time = time.time()
        for i in range(100): # until some stopping criterion met...
@@ -97,11 +89,10 @@ if __name__ == '__main__': # a clean version with only policy net, no value net

            # update network
            for _ in range(num_batches):
                data = training_data.next_batch(batch_size) # YouQiaoben, ShihongSong
                # TODO: auto managing of the placeholders? or add this to params of data.Batch
                sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], advantage: data['returns']})
                feed_dict = training_data.next_batch(batch_size) # YouQiaoben, ShihongSong
                sess.run(train_op, feed_dict=feed_dict)

            # assigning pi to pi_old
            sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)])
            pi.update_weights()

        print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
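The change above makes Batch.next_batch return a ready-to-use feed_dict instead of a raw data dict, so the training loop no longer touches the placeholders directly. A toy standalone TF1 sketch of that feed_dict-driven update pattern (hypothetical x_ph/y_ph regression graph, not the library's API):

# illustrative sketch, not part of the commit
import numpy as np
import tensorflow as tf

x_ph = tf.placeholder(tf.float32, shape=(None, 4))
y_ph = tf.placeholder(tf.float32, shape=(None,))
pred = tf.squeeze(tf.layers.dense(x_ph, 1), axis=1)
loss = tf.reduce_mean(tf.square(pred - y_ph))
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # whoever builds the batch also owns the placeholders, so the loop only sees a feed_dict
    feed_dict = {x_ph: np.random.randn(32, 4).astype(np.float32),
                 y_ph: np.random.randn(32).astype(np.float32)}
    sess.run(train_op, feed_dict=feed_dict)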
@@ -1,22 +1,26 @@
import tensorflow as tf


def ppo_clip(sampled_action, advantage, clip_param, pi, pi_old):
def ppo_clip(policy, clip_param):
    """
    the clip loss in ppo paper

    :param sampled_action: placeholder of sampled actions during interaction with the environment
    :param advantage: placeholder of estimated advantage values.
    :param clip_param: float or Tensor of type float.
    :param pi: current `policy` to be optimized
    :param policy: current `policy` to be optimized
    :param pi_old: old `policy` for computing the ppo loss as in Eqn. (7) in the paper
    """
    action_ph = tf.placeholder(policy.act_dtype, shape=(None, policy.action_dim), name='ppo_clip_loss/action_placeholder')
    advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder')
    policy.managed_placeholders['action'] = action_ph
    policy.managed_placeholders['processed_reward'] = advantage_ph

    log_pi_act = pi.log_prob(sampled_action)
    log_pi_old_act = pi_old.log_prob(sampled_action)
    log_pi_act = policy.log_prob(action_ph)
    log_pi_old_act = policy.log_prob_old(action_ph)
    ratio = tf.exp(log_pi_act - log_pi_old_act)
    clipped_ratio = tf.clip_by_value(ratio, 1. - clip_param, 1. + clip_param)
    ppo_clip_loss = -tf.reduce_mean(tf.minimum(ratio * advantage, clipped_ratio * advantage))
    ppo_clip_loss = -tf.reduce_mean(tf.minimum(ratio * advantage_ph, clipped_ratio * advantage_ph))
    return ppo_clip_loss
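For reference, ppo_clip implements the clipped surrogate objective L^CLIP = E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)] with r_t = pi(a|s) / pi_old(a|s) (Eqn. (7) of the PPO paper), negated so it can be minimized. A NumPy sketch of the same arithmetic, independent of the policy classes here:

# illustrative sketch, not part of the commit
import numpy as np

def ppo_clip_np(log_pi, log_pi_old, advantage, clip_param=0.2):
    ratio = np.exp(log_pi - log_pi_old)                     # r_t = pi / pi_old
    clipped = np.clip(ratio, 1. - clip_param, 1. + clip_param)
    return -np.mean(np.minimum(ratio * advantage, clipped * advantage))

# the second sample's ratio exp(0.5) ~= 1.65 gets clipped to 1.2 before weighting the advantage
print(ppo_clip_np(np.array([-1.0, -0.5]), np.array([-1.1, -1.0]), np.array([1.0, 2.0])))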
@@ -16,44 +16,11 @@ class PolicyBase(object):
    """
    base class for policy. only provides `act` method with exploration
    """
    def __init__(self, observation_placeholder):
        self._observation_placeholder = observation_placeholder

    def act(self, observation, exploration):
    def act(self, observation):
        raise NotImplementedError()


class QValuePolicy(object):
    """
    The policy as in DQN
    """
    def __init__(self, observation_placeholder):
        self._observation_placeholder = observation_placeholder

    def act(self, observation, exploration=None): # first implement no exploration
        """
        return the action (int) to be executed.
        no exploration when exploration=None.
        """
        self._act(observation, exploration)

    def _act(self, observation, exploration=None):
        raise NotImplementedError()

    def values(self, observation):
        """
        returns the Q(s, a) values (float) for all actions a at observation s
        """
        pass

    def values_tensor(self):
        """
        returns the tensor of the values for all actions a at observation s
        """
        pass


class StochasticPolicy(object):
class StochasticPolicy(PolicyBase):
    """
    The :class:`StochasticPolicy` class is the base class for various probabilistic
    distributions which support batch inputs, generating batches of samples and
@@ -170,7 +137,7 @@ class StochasticPolicy(object):
        return self._group_ndims

    # @add_name_scope
    def act(self, observation):
    def act(self, observation, my_feed_dict={}):
        """
        sample(n_samples=None)

@@ -184,9 +151,9 @@
            samples to draw from the distribution.
        :return: A Tensor of samples.
        """
        return self._act(observation)
        return self._act(observation, my_feed_dict)

    def _act(self, observation):
    def _act(self, observation, my_feed_dict):
        """
        Private method for subclasses to rewrite the :meth:`sample` method.
        """
@@ -94,24 +94,64 @@ class Normal(StochasticPolicy):
    :param observation_placeholder
    """
    def __init__(self,
                 mean = 0.,
                 logstd = 1.,
                 group_ndims = 1,
                 observation_placeholder = None,
                 policy_callable,
                 observation_placeholder,
                 weight_update=1,
                 group_ndims=1,
                 **kwargs):
        self.managed_placeholders = {'observation': observation_placeholder}
        self.weight_update = weight_update
        self.interaction_count = -1 # defaults to -1. only useful if weight_update > 1.

        with tf.variable_scope('network'):
            mean, logstd, value_head = policy_callable()
            self._mean = tf.convert_to_tensor(mean, dtype = tf.float32)
            self._logstd = tf.convert_to_tensor(logstd, dtype = tf.float32)
            self._std = tf.exp(self._logstd)

            shape = tf.broadcast_dynamic_shape(tf.shape(self._mean), tf.shape(self._std))
            self._action = tf.random_normal(tf.concat([[1], shape], 0), dtype = tf.float32) * self._std + self._mean
            # TODO: self._action should be exactly the action tensor to run, without [0, 0] in self._act

            if value_head is not None:
                pass

        self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')

        if self.weight_update == 1:
            self.weight_update_ops = None
            self.sync_weights_ops = None
        else: # then we need to build another tf graph as target network
            with tf.variable_scope('net_old'):
                mean, logstd, value_head = policy_callable()
                self._mean_old = tf.convert_to_tensor(mean, dtype=tf.float32)
                self._logstd_old = tf.convert_to_tensor(logstd, dtype=tf.float32)

                if value_head is not None: # useful in DDPG
                    pass

            network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
            self.network_old_weights = network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
            assert len(network_weights) == len(network_old_weights)
            self.sync_weights_ops = [tf.assign(variable_old, variable)
                                     for (variable_old, variable) in zip(network_old_weights, network_weights)]

            if weight_update == 0:
                self.weight_update_ops = self.sync_weights_ops
            elif 0 < weight_update < 1:
                pass
            else:
                self.interaction_count = 0
                import math
                self.weight_update = math.ceil(weight_update)

        self._mean = tf.convert_to_tensor(mean, dtype = tf.float32)
        self._logstd = tf.convert_to_tensor(logstd, dtype = tf.float32)
        self._std = tf.exp(self._logstd)

        shape = tf.broadcast_dynamic_shape(tf.shape(self._mean), tf.shape(self._std))
        self._action = tf.random_normal(tf.concat([[1], shape], 0), dtype = tf.float32) * self._std + self._mean

        super(Normal, self).__init__(
            act_dtype = tf.float32,
            param_dtype = tf.float32,
            is_continuous = True,
            observation_placeholder = observation_placeholder,
            act_dtype=tf.float32,
            param_dtype=tf.float32,
            is_continuous=True,
            observation_placeholder=observation_placeholder,
            group_ndims = group_ndims,
            **kwargs)
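As readable from the constructor above, weight_update controls the target network: 1 means no old copy at all, 0 makes pi_old refresh from pi only through explicit sync_weights()/update_weights() calls (the PPO case in the example), values greater than 1 are rounded up and treated as an interaction-count interval, and the soft-update case (0 < weight_update < 1) is still a TODO. A standalone TF1 sketch of the hard sync built with tf.assign over the 'network'/'net_old' collections (toy single-variable "networks", not this class):

# illustrative sketch, not part of the commit
import tensorflow as tf

def build(scope):
    # stands in for policy_callable() being invoked under two different variable scopes
    with tf.variable_scope(scope):
        return tf.get_variable('w', shape=(3,), initializer=tf.random_normal_initializer())

w = build('network')
w_old = build('net_old')

net_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
old_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
sync_ops = [tf.assign(v_old, v) for v_old, v in zip(old_vars, net_vars)]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(sync_ops)
    print(sess.run(tf.reduce_all(tf.equal(w, w_old))))  # True: 'net_old' now mirrors 'network'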
@@ -127,13 +167,22 @@ class Normal(StochasticPolicy):
    def logstd(self):
        return self._logstd

    def _act(self, observation):
    @property
    def action(self):
        return self._action

    @property
    def action_dim(self):
        return self.mean.shape.as_list()[-1]

    def _act(self, observation, my_feed_dict):
        # TODO: getting session like this maybe ugly. also maybe huge problem when parallel
        sess = tf.get_default_session()

        # observation[None] adds one dimension at the beginning
        sampled_action = sess.run(self._action,
                                  feed_dict={self._observation_placeholder: observation[None]})
        feed_dict = {self._observation_placeholder: observation[None]}
        feed_dict.update(my_feed_dict)
        sampled_action = sess.run(self._action, feed_dict=feed_dict)
        sampled_action = sampled_action[0, 0]
        return sampled_action
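As the comment in _act notes, observation[None] just prepends a batch axis so a single observation matches the (None, ...) placeholder shape; a one-line NumPy illustration:

# illustrative sketch, not part of the commit
import numpy as np
obs = np.zeros(4)          # a single flat observation, shape (4,)
print(obs[None].shape)     # (1, 4): adds the leading batch dimension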
@@ -145,4 +194,36 @@ class Normal(StochasticPolicy):
        return c - logstd - 0.5 * precision * tf.square(sampled_action - mean)

    def _prob(self, sampled_action):
        return tf.exp(self._log_prob(sampled_action))
        return tf.exp(self._log_prob(sampled_action))

    def log_prob_old(self, sampled_action):
        """
        return the log_prob of the old policy when constructing tf graphs. Raises error when there's no old policy.
        :param sampled_action: the placeholder for sampled actions during interaction with the environment.
        :return: tensor of the log_prob of the old policy
        """
        if self.weight_update == 1:
            raise AttributeError('Policy has no policy_old since it\'s initialized with weight_update=1!')

        mean, logstd = self._mean_old, self._logstd_old
        c = -0.5 * np.log(2 * np.pi)
        precision = tf.exp(-2 * logstd)
        return c - logstd - 0.5 * precision * tf.square(sampled_action - mean)

    def update_weights(self):
        """
        updates the weights of policy_old.
        :return:
        """
        if self.weight_update_ops is not None:
            sess = tf.get_default_session()
            sess.run(self.weight_update_ops)

    def sync_weights(self):
        """
        sync the weights of network_old. Direct copy the weights of network.
        :return:
        """
        if self.sync_weights_ops is not None:
            sess = tf.get_default_session()
            sess.run(self.sync_weights_ops)
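Both _log_prob and log_prob_old above use the same diagonal Gaussian log-density, log N(a; mean, exp(logstd)) = -0.5*log(2*pi) - logstd - (a - mean)^2 / (2*exp(2*logstd)). A quick NumPy cross-check against scipy.stats (scipy is assumed available for the check only; the library code itself does not use it):

# illustrative sketch, not part of the commit
import numpy as np
from scipy.stats import norm

mean, logstd, a = 0.3, -0.5, 1.0
c = -0.5 * np.log(2 * np.pi)
precision = np.exp(-2 * logstd)                     # 1 / sigma^2
log_prob = c - logstd - 0.5 * precision * (a - mean) ** 2
print(np.isclose(log_prob, norm.logpdf(a, loc=mean, scale=np.exp(logstd))))  # True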
@@ -135,7 +135,12 @@ class Batch(object):
        advantage_std = np.std(current_batch['returns'])
        current_batch['returns'] = (current_batch['returns'] - advantage_mean) / advantage_std

        return current_batch
        feed_dict = {}
        feed_dict[self._pi.managed_placeholders['observation']] = current_batch['observations']
        feed_dict[self._pi.managed_placeholders['action']] = current_batch['actions']
        feed_dict[self._pi.managed_placeholders['processed_reward']] = current_batch['returns']

        return feed_dict

    # TODO: this will definitely be refactored with a proper logger
    def statistics(self):
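The returns normalization in next_batch above is plain standardization of the advantage targets; a toy check:

# illustrative sketch, not part of the commit
import numpy as np
returns = np.array([1.0, 3.0, 5.0])
print((returns - returns.mean()) / returns.std())   # -> [-1.2247, 0., 1.2247]: zero mean, unit std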