diff --git a/examples/ppo_cartpole.py b/examples/ppo_cartpole.py new file mode 100755 index 0000000..6fc986f --- /dev/null +++ b/examples/ppo_cartpole.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +from __future__ import absolute_import + +import tensorflow as tf + +# our lib imports here! It's ok to append path in examples +import sys +sys.path.append('..') +from tianshou.core import losses +from tianshou.data.batch import Batch +import tianshou.data.advantage_estimation as advantage_estimation +import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy + +from rllab.envs.box2d.cartpole_env import CartpoleEnv +from rllab.envs.normalized_env import normalize + + +def policy_net(observation, action_dim, scope=None): + """ + Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf + + :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels) + :param action_dim: int. The number of actions. + :param scope: str. Specifying the scope of the variables. + """ + # with tf.variable_scope(scope): + net = tf.layers.dense(observation, 32, activation=tf.nn.tanh) + net = tf.layers.dense(net, 32, activation=tf.nn.tanh) + + act_mean = tf.layers.dense(net, action_dim, activation=None) + + return act_mean + + +if __name__ == '__main__': # a clean version with only policy net, no value net + env = normalize(CartpoleEnv()) + observation_dim = env.observation_space.shape + action_dim = env.action_space.flat_dim + + clip_param = 0.2 + num_batches = 10 + batch_size = 512 + + # 1. build network with pure tf + observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input + + with tf.variable_scope('pi'): + action_mean = policy_net(observation, action_dim, 'pi') + action_logstd = tf.get_variable('action_logstd', shape=(action_dim,)) + train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES + with tf.variable_scope('pi_old'): + action_mean_old = policy_net(observation, action_dim, 'pi_old') + action_logstd_old = tf.get_variable('action_logstd_old', shape=(action_dim,)) + pi_old_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'pi_old') + + # 2. build losses, optimizers + pi = policy.Normal(action_mean, action_logstd, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc. + # for continuous action space, you may need to change an environment to run + pi_old = policy.Normal(action_mean_old, action_logstd_old, observation_placeholder=observation) + + action = tf.placeholder(dtype=tf.float32, shape=(None, action_dim)) # batch of integer actions + advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients + + ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict + + total_loss = ppo_loss_clip + optimizer = tf.train.AdamOptimizer(1e-4) + train_op = optimizer.minimize(total_loss, var_list=train_var_list) + + # 3. define data collection + training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper + # ShihongSong: Replay(), see dqn_example.py + # maybe a dict to manage the elements to be collected + + # 4. start training + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + with tf.Session(config=config) as sess: + sess.run(tf.global_variables_initializer()) + # sync pi and pi_old + sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)]) + + for i in range(100): # until some stopping criterion met... + # collect data + training_data.collect(num_episodes=120) # YouQiaoben, ShihongSong + + # print current return + print('Epoch {}:'.format(i)) + training_data.statistics() + + # update network + for _ in range(num_batches): + data = training_data.next_batch(batch_size) # YouQiaoben, ShihongSong + # TODO: auto managing of the placeholders? or add this to params of data.Batch + sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], advantage: data['returns']}) + + # assigning pi to pi_old + sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)]) \ No newline at end of file diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py index 1c1e1c5..8d4b2a1 100644 --- a/tianshou/core/policy/base.py +++ b/tianshou/core/policy/base.py @@ -55,7 +55,7 @@ class QValuePolicy(object): class StochasticPolicy(object): """ - The :class:`Distribution` class is the base class for various probabilistic + The :class:`StochasticPolicy` class is the base class for various probabilistic distributions which support batch inputs, generating batches of samples and evaluate probabilities at batches of given values. diff --git a/tianshou/core/policy/stochastic.py b/tianshou/core/policy/stochastic.py index d7a75d7..e2c2dea 100644 --- a/tianshou/core/policy/stochastic.py +++ b/tianshou/core/policy/stochastic.py @@ -62,9 +62,11 @@ class OnehotCategorical(StochasticPolicy): return self._n_categories def _act(self, observation): - sess = tf.get_default_session() # TODO: this may be ugly. also maybe huge problem when parallel - sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1), feed_dict={self._observation_placeholder: observation[None]}) + # TODO: this may be ugly. also maybe huge problem when parallel + sess = tf.get_default_session() # observation[None] adds one dimension at the beginning + sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1), + feed_dict={self._observation_placeholder: observation[None]}) sampled_action = sampled_action[0, 0] @@ -73,28 +75,75 @@ class OnehotCategorical(StochasticPolicy): def _log_prob(self, sampled_action): return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self.logits) - # given = tf.cast(given, self.param_dtype) - # given, logits = maybe_explicit_broadcast( - # given, self.logits, 'given', 'logits') - # if (given.get_shape().ndims == 2) or (logits.get_shape().ndims == 2): - # given_flat = given - # logits_flat = logits - # else: - # given_flat = tf.reshape(given, [-1, self.n_categories]) - # logits_flat = tf.reshape(logits, [-1, self.n_categories]) - # log_p_flat = -tf.nn.softmax_cross_entropy_with_logits( - # labels=given_flat, logits=logits_flat) - # if (given.get_shape().ndims == 2) or (logits.get_shape().ndims == 2): - # log_p = log_p_flat - # else: - # log_p = tf.reshape(log_p_flat, tf.shape(logits)[:-1]) - # if given.get_shape() and logits.get_shape(): - # log_p.set_shape(tf.broadcast_static_shape( - # given.get_shape(), logits.get_shape())[:-1]) - # return log_p def _prob(self, sampled_action): return tf.exp(self._log_prob(sampled_action)) -OnehotDiscrete = OnehotCategorical \ No newline at end of file +OnehotDiscrete = OnehotCategorical + + +class Normal(StochasticPolicy): + """ + The :class:`Normal' class is the Normal policy + + :param mean: + :param std: + :param group_ndims + :param observation_placeholder + """ + def __init__(self, + mean = 0., + logstd = 1., + group_ndims = 1, + observation_placeholder = None, + **kwargs): + + self._mean = tf.convert_to_tensor(mean, dtype = tf.float32) + self._logstd = tf.convert_to_tensor(logstd, dtype = tf.float32) + self._std = tf.exp(self._logstd) + + super(Normal, self).__init__( + act_dtype = tf.float32, + param_dtype = tf.float32, + is_continuous = True, + observation_placeholder = observation_placeholder, + group_ndims = group_ndims, + **kwargs) + + @property + def mean(self): + return self._mean + + @property + def std(self): + return self._std + + @property + def logstd(self): + return self._logstd + + def _act(self, observation): + # TODO: getting session like this maybe ugly. also maybe huge problem when parallel + sess = tf.get_default_session() + mean, std = self._mean, self._std + shape = tf.broadcast_dynamic_shape(tf.shape(self._mean),\ + tf.shape(self._std)) + + + # observation[None] adds one dimension at the beginning + sampled_action = sess.run(tf.random_normal(tf.concat([[1], shape], 0), + dtype = tf.float32) * std + mean, + feed_dict={self._observation_placeholder: observation[None]}) + sampled_action = sampled_action[0, 0] + return sampled_action + + + def _log_prob(self, sampled_action): + mean, logstd = self._mean, self._logstd + c = -0.5 * np.log(2 * np.pi) + precision = tf.exp(-2 * logstd) + return c - logstd - 0.5 * precision * tf.square(sampled_action - mean) + + def _prob(self, sampled_action): + return tf.exp(self._log_prob(sampled_action)) \ No newline at end of file diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 7b28966..4d7b1f2 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -8,7 +8,7 @@ class Batch(object): class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy. """ - def __init__(self, env, pi, advantage_estimation_function): # how to name the function? + def __init__(self, env, pi, advantage_estimation_function): # how to name the function? self._env = env self._pi = pi self._advantage_estimation_function = advantage_estimation_function @@ -63,7 +63,7 @@ class Batch(object): ob = env.reset() t += 1 - if num_episodes > 0: # YouQiaoben: fix memory growth, both del and gc.collect() fail + if num_episodes > 0: # YouQiaoben: fix memory growth, both del and gc.collect() fail # initialize rawdata lists if not self._is_first_collect: del self.observations @@ -91,10 +91,10 @@ class Batch(object): rewards.append(reward) t_count += 1 - if t_count >= 200: # force episode stop, just to test if memory still grows - break + if t_count >= 100: # force episode stop, just to test if memory still grows + done = True - if done: # end of episode, discard s_T + if done: # end of episode, discard s_T break else: observations.append(ob) @@ -122,7 +122,46 @@ class Batch(object): def apply_advantage_estimation_function(self): self.data = self._advantage_estimation_function(self.raw_data) - def next_batch(self, batch_size): # YouQiaoben: referencing other iterate over batches + def next_batch(self, batch_size, standardize_advantage=True): # YouQiaoben: referencing other iterate over batches rand_idx = np.random.choice(self.data['observations'].shape[0], batch_size) - return {key: value[rand_idx] for key, value in self.data.items()} + current_batch = {key: value[rand_idx] for key, value in self.data.items()} + if standardize_advantage: + advantage_mean = np.mean(current_batch['returns']) + advantage_std = np.std(current_batch['returns']) + current_batch['returns'] = (current_batch['returns'] - advantage_mean) / advantage_std + + return current_batch + + def statistics(self): + """ + compute the statistics of the current sampled paths + :return: + """ + rewards = self.raw_data['rewards'] + episode_start_flags = self.raw_data['episode_start_flags'] + num_timesteps = rewards.shape[0] + + returns = [] + max_return = 0 + episode_start_idx = 0 + for i in range(1, num_timesteps): + if episode_start_flags[i] or ( + i == num_timesteps - 1): # found the start of next episode or the end of all episodes + if i < rewards.shape[0] - 1: + t = i - 1 + else: + t = i + Gt = 0 + while t >= episode_start_idx: + Gt += rewards[t] + t -= 1 + + returns.append(Gt) + if Gt > max_return: + max_return = Gt + episode_start_idx = i + + print('AverageReturn: {}'.format(np.mean(returns))) + print('StdReturn: : {}'.format(np.std(returns))) + print('MaxReturn : {}'.format(max_return)) \ No newline at end of file