ppo_cartpole.py seems to be working with param: bs128, num_ep20, max_time500; manually merged Normal from branch policy_wrapper
This commit is contained in:
parent
88648f0c4b
commit
4333ee5d39
98
examples/ppo_cartpole.py
Executable file
98
examples/ppo_cartpole.py
Executable file
@ -0,0 +1,98 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
# our lib imports here! It's ok to append path in examples
|
||||||
|
import sys
|
||||||
|
sys.path.append('..')
|
||||||
|
from tianshou.core import losses
|
||||||
|
from tianshou.data.batch import Batch
|
||||||
|
import tianshou.data.advantage_estimation as advantage_estimation
|
||||||
|
import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy
|
||||||
|
|
||||||
|
from rllab.envs.box2d.cartpole_env import CartpoleEnv
|
||||||
|
from rllab.envs.normalized_env import normalize
|
||||||
|
|
||||||
|
|
||||||
|
def policy_net(observation, action_dim, scope=None):
|
||||||
|
"""
|
||||||
|
Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf
|
||||||
|
|
||||||
|
:param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
|
||||||
|
:param action_dim: int. The number of actions.
|
||||||
|
:param scope: str. Specifying the scope of the variables.
|
||||||
|
"""
|
||||||
|
# with tf.variable_scope(scope):
|
||||||
|
net = tf.layers.dense(observation, 32, activation=tf.nn.tanh)
|
||||||
|
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
|
||||||
|
|
||||||
|
act_mean = tf.layers.dense(net, action_dim, activation=None)
|
||||||
|
|
||||||
|
return act_mean
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__': # a clean version with only policy net, no value net
|
||||||
|
env = normalize(CartpoleEnv())
|
||||||
|
observation_dim = env.observation_space.shape
|
||||||
|
action_dim = env.action_space.flat_dim
|
||||||
|
|
||||||
|
clip_param = 0.2
|
||||||
|
num_batches = 10
|
||||||
|
batch_size = 512
|
||||||
|
|
||||||
|
# 1. build network with pure tf
|
||||||
|
observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
|
||||||
|
|
||||||
|
with tf.variable_scope('pi'):
|
||||||
|
action_mean = policy_net(observation, action_dim, 'pi')
|
||||||
|
action_logstd = tf.get_variable('action_logstd', shape=(action_dim,))
|
||||||
|
train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
|
||||||
|
with tf.variable_scope('pi_old'):
|
||||||
|
action_mean_old = policy_net(observation, action_dim, 'pi_old')
|
||||||
|
action_logstd_old = tf.get_variable('action_logstd_old', shape=(action_dim,))
|
||||||
|
pi_old_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'pi_old')
|
||||||
|
|
||||||
|
# 2. build losses, optimizers
|
||||||
|
pi = policy.Normal(action_mean, action_logstd, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc.
|
||||||
|
# for continuous action space, you may need to change an environment to run
|
||||||
|
pi_old = policy.Normal(action_mean_old, action_logstd_old, observation_placeholder=observation)
|
||||||
|
|
||||||
|
action = tf.placeholder(dtype=tf.float32, shape=(None, action_dim)) # batch of integer actions
|
||||||
|
advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients
|
||||||
|
|
||||||
|
ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict
|
||||||
|
|
||||||
|
total_loss = ppo_loss_clip
|
||||||
|
optimizer = tf.train.AdamOptimizer(1e-4)
|
||||||
|
train_op = optimizer.minimize(total_loss, var_list=train_var_list)
|
||||||
|
|
||||||
|
# 3. define data collection
|
||||||
|
training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper
|
||||||
|
# ShihongSong: Replay(), see dqn_example.py
|
||||||
|
# maybe a dict to manage the elements to be collected
|
||||||
|
|
||||||
|
# 4. start training
|
||||||
|
config = tf.ConfigProto()
|
||||||
|
config.gpu_options.allow_growth = True
|
||||||
|
with tf.Session(config=config) as sess:
|
||||||
|
sess.run(tf.global_variables_initializer())
|
||||||
|
# sync pi and pi_old
|
||||||
|
sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)])
|
||||||
|
|
||||||
|
for i in range(100): # until some stopping criterion met...
|
||||||
|
# collect data
|
||||||
|
training_data.collect(num_episodes=120) # YouQiaoben, ShihongSong
|
||||||
|
|
||||||
|
# print current return
|
||||||
|
print('Epoch {}:'.format(i))
|
||||||
|
training_data.statistics()
|
||||||
|
|
||||||
|
# update network
|
||||||
|
for _ in range(num_batches):
|
||||||
|
data = training_data.next_batch(batch_size) # YouQiaoben, ShihongSong
|
||||||
|
# TODO: auto managing of the placeholders? or add this to params of data.Batch
|
||||||
|
sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], advantage: data['returns']})
|
||||||
|
|
||||||
|
# assigning pi to pi_old
|
||||||
|
sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)])
|
@ -55,7 +55,7 @@ class QValuePolicy(object):
|
|||||||
|
|
||||||
class StochasticPolicy(object):
|
class StochasticPolicy(object):
|
||||||
"""
|
"""
|
||||||
The :class:`Distribution` class is the base class for various probabilistic
|
The :class:`StochasticPolicy` class is the base class for various probabilistic
|
||||||
distributions which support batch inputs, generating batches of samples and
|
distributions which support batch inputs, generating batches of samples and
|
||||||
evaluate probabilities at batches of given values.
|
evaluate probabilities at batches of given values.
|
||||||
|
|
||||||
|
@ -62,9 +62,11 @@ class OnehotCategorical(StochasticPolicy):
|
|||||||
return self._n_categories
|
return self._n_categories
|
||||||
|
|
||||||
def _act(self, observation):
|
def _act(self, observation):
|
||||||
sess = tf.get_default_session() # TODO: this may be ugly. also maybe huge problem when parallel
|
# TODO: this may be ugly. also maybe huge problem when parallel
|
||||||
sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1), feed_dict={self._observation_placeholder: observation[None]})
|
sess = tf.get_default_session()
|
||||||
# observation[None] adds one dimension at the beginning
|
# observation[None] adds one dimension at the beginning
|
||||||
|
sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1),
|
||||||
|
feed_dict={self._observation_placeholder: observation[None]})
|
||||||
|
|
||||||
sampled_action = sampled_action[0, 0]
|
sampled_action = sampled_action[0, 0]
|
||||||
|
|
||||||
@ -73,28 +75,75 @@ class OnehotCategorical(StochasticPolicy):
|
|||||||
def _log_prob(self, sampled_action):
|
def _log_prob(self, sampled_action):
|
||||||
return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self.logits)
|
return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self.logits)
|
||||||
|
|
||||||
# given = tf.cast(given, self.param_dtype)
|
|
||||||
# given, logits = maybe_explicit_broadcast(
|
|
||||||
# given, self.logits, 'given', 'logits')
|
|
||||||
# if (given.get_shape().ndims == 2) or (logits.get_shape().ndims == 2):
|
|
||||||
# given_flat = given
|
|
||||||
# logits_flat = logits
|
|
||||||
# else:
|
|
||||||
# given_flat = tf.reshape(given, [-1, self.n_categories])
|
|
||||||
# logits_flat = tf.reshape(logits, [-1, self.n_categories])
|
|
||||||
# log_p_flat = -tf.nn.softmax_cross_entropy_with_logits(
|
|
||||||
# labels=given_flat, logits=logits_flat)
|
|
||||||
# if (given.get_shape().ndims == 2) or (logits.get_shape().ndims == 2):
|
|
||||||
# log_p = log_p_flat
|
|
||||||
# else:
|
|
||||||
# log_p = tf.reshape(log_p_flat, tf.shape(logits)[:-1])
|
|
||||||
# if given.get_shape() and logits.get_shape():
|
|
||||||
# log_p.set_shape(tf.broadcast_static_shape(
|
|
||||||
# given.get_shape(), logits.get_shape())[:-1])
|
|
||||||
# return log_p
|
|
||||||
|
|
||||||
def _prob(self, sampled_action):
|
def _prob(self, sampled_action):
|
||||||
return tf.exp(self._log_prob(sampled_action))
|
return tf.exp(self._log_prob(sampled_action))
|
||||||
|
|
||||||
|
|
||||||
OnehotDiscrete = OnehotCategorical
|
OnehotDiscrete = OnehotCategorical
|
||||||
|
|
||||||
|
|
||||||
|
class Normal(StochasticPolicy):
|
||||||
|
"""
|
||||||
|
The :class:`Normal' class is the Normal policy
|
||||||
|
|
||||||
|
:param mean:
|
||||||
|
:param std:
|
||||||
|
:param group_ndims
|
||||||
|
:param observation_placeholder
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
mean = 0.,
|
||||||
|
logstd = 1.,
|
||||||
|
group_ndims = 1,
|
||||||
|
observation_placeholder = None,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
self._mean = tf.convert_to_tensor(mean, dtype = tf.float32)
|
||||||
|
self._logstd = tf.convert_to_tensor(logstd, dtype = tf.float32)
|
||||||
|
self._std = tf.exp(self._logstd)
|
||||||
|
|
||||||
|
super(Normal, self).__init__(
|
||||||
|
act_dtype = tf.float32,
|
||||||
|
param_dtype = tf.float32,
|
||||||
|
is_continuous = True,
|
||||||
|
observation_placeholder = observation_placeholder,
|
||||||
|
group_ndims = group_ndims,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mean(self):
|
||||||
|
return self._mean
|
||||||
|
|
||||||
|
@property
|
||||||
|
def std(self):
|
||||||
|
return self._std
|
||||||
|
|
||||||
|
@property
|
||||||
|
def logstd(self):
|
||||||
|
return self._logstd
|
||||||
|
|
||||||
|
def _act(self, observation):
|
||||||
|
# TODO: getting session like this maybe ugly. also maybe huge problem when parallel
|
||||||
|
sess = tf.get_default_session()
|
||||||
|
mean, std = self._mean, self._std
|
||||||
|
shape = tf.broadcast_dynamic_shape(tf.shape(self._mean),\
|
||||||
|
tf.shape(self._std))
|
||||||
|
|
||||||
|
|
||||||
|
# observation[None] adds one dimension at the beginning
|
||||||
|
sampled_action = sess.run(tf.random_normal(tf.concat([[1], shape], 0),
|
||||||
|
dtype = tf.float32) * std + mean,
|
||||||
|
feed_dict={self._observation_placeholder: observation[None]})
|
||||||
|
sampled_action = sampled_action[0, 0]
|
||||||
|
return sampled_action
|
||||||
|
|
||||||
|
|
||||||
|
def _log_prob(self, sampled_action):
|
||||||
|
mean, logstd = self._mean, self._logstd
|
||||||
|
c = -0.5 * np.log(2 * np.pi)
|
||||||
|
precision = tf.exp(-2 * logstd)
|
||||||
|
return c - logstd - 0.5 * precision * tf.square(sampled_action - mean)
|
||||||
|
|
||||||
|
def _prob(self, sampled_action):
|
||||||
|
return tf.exp(self._log_prob(sampled_action))
|
@ -8,7 +8,7 @@ class Batch(object):
|
|||||||
class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy.
|
class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, env, pi, advantage_estimation_function): # how to name the function?
|
def __init__(self, env, pi, advantage_estimation_function): # how to name the function?
|
||||||
self._env = env
|
self._env = env
|
||||||
self._pi = pi
|
self._pi = pi
|
||||||
self._advantage_estimation_function = advantage_estimation_function
|
self._advantage_estimation_function = advantage_estimation_function
|
||||||
@ -63,7 +63,7 @@ class Batch(object):
|
|||||||
ob = env.reset()
|
ob = env.reset()
|
||||||
t += 1
|
t += 1
|
||||||
|
|
||||||
if num_episodes > 0: # YouQiaoben: fix memory growth, both del and gc.collect() fail
|
if num_episodes > 0: # YouQiaoben: fix memory growth, both del and gc.collect() fail
|
||||||
# initialize rawdata lists
|
# initialize rawdata lists
|
||||||
if not self._is_first_collect:
|
if not self._is_first_collect:
|
||||||
del self.observations
|
del self.observations
|
||||||
@ -91,10 +91,10 @@ class Batch(object):
|
|||||||
rewards.append(reward)
|
rewards.append(reward)
|
||||||
|
|
||||||
t_count += 1
|
t_count += 1
|
||||||
if t_count >= 200: # force episode stop, just to test if memory still grows
|
if t_count >= 100: # force episode stop, just to test if memory still grows
|
||||||
break
|
done = True
|
||||||
|
|
||||||
if done: # end of episode, discard s_T
|
if done: # end of episode, discard s_T
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
observations.append(ob)
|
observations.append(ob)
|
||||||
@ -122,7 +122,46 @@ class Batch(object):
|
|||||||
def apply_advantage_estimation_function(self):
|
def apply_advantage_estimation_function(self):
|
||||||
self.data = self._advantage_estimation_function(self.raw_data)
|
self.data = self._advantage_estimation_function(self.raw_data)
|
||||||
|
|
||||||
def next_batch(self, batch_size): # YouQiaoben: referencing other iterate over batches
|
def next_batch(self, batch_size, standardize_advantage=True): # YouQiaoben: referencing other iterate over batches
|
||||||
rand_idx = np.random.choice(self.data['observations'].shape[0], batch_size)
|
rand_idx = np.random.choice(self.data['observations'].shape[0], batch_size)
|
||||||
return {key: value[rand_idx] for key, value in self.data.items()}
|
current_batch = {key: value[rand_idx] for key, value in self.data.items()}
|
||||||
|
|
||||||
|
if standardize_advantage:
|
||||||
|
advantage_mean = np.mean(current_batch['returns'])
|
||||||
|
advantage_std = np.std(current_batch['returns'])
|
||||||
|
current_batch['returns'] = (current_batch['returns'] - advantage_mean) / advantage_std
|
||||||
|
|
||||||
|
return current_batch
|
||||||
|
|
||||||
|
def statistics(self):
|
||||||
|
"""
|
||||||
|
compute the statistics of the current sampled paths
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
rewards = self.raw_data['rewards']
|
||||||
|
episode_start_flags = self.raw_data['episode_start_flags']
|
||||||
|
num_timesteps = rewards.shape[0]
|
||||||
|
|
||||||
|
returns = []
|
||||||
|
max_return = 0
|
||||||
|
episode_start_idx = 0
|
||||||
|
for i in range(1, num_timesteps):
|
||||||
|
if episode_start_flags[i] or (
|
||||||
|
i == num_timesteps - 1): # found the start of next episode or the end of all episodes
|
||||||
|
if i < rewards.shape[0] - 1:
|
||||||
|
t = i - 1
|
||||||
|
else:
|
||||||
|
t = i
|
||||||
|
Gt = 0
|
||||||
|
while t >= episode_start_idx:
|
||||||
|
Gt += rewards[t]
|
||||||
|
t -= 1
|
||||||
|
|
||||||
|
returns.append(Gt)
|
||||||
|
if Gt > max_return:
|
||||||
|
max_return = Gt
|
||||||
|
episode_start_idx = i
|
||||||
|
|
||||||
|
print('AverageReturn: {}'.format(np.mean(returns)))
|
||||||
|
print('StdReturn: : {}'.format(np.std(returns)))
|
||||||
|
print('MaxReturn : {}'.format(max_return))
|
Loading…
x
Reference in New Issue
Block a user