diff --git a/examples/.gitignore b/examples/.gitignore index 15584ca..2a48b2c 100644 --- a/examples/.gitignore +++ b/examples/.gitignore @@ -1 +1,2 @@ .pyc +logs/ diff --git a/examples/ppo_cartpole.py b/examples/ppo_cartpole.py index 6fc986f..418cc52 100755 --- a/examples/ppo_cartpole.py +++ b/examples/ppo_cartpole.py @@ -2,6 +2,8 @@ from __future__ import absolute_import import tensorflow as tf +import time +import numpy as np # our lib imports here! It's ok to append path in examples import sys @@ -39,7 +41,11 @@ if __name__ == '__main__': # a clean version with only policy net, no value net clip_param = 0.2 num_batches = 10 - batch_size = 512 + batch_size = 128 + + seed = 10 + np.random.seed(seed) + tf.set_random_seed(seed) # 1. build network with pure tf observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input @@ -80,9 +86,10 @@ if __name__ == '__main__': # a clean version with only policy net, no value net # sync pi and pi_old sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)]) + start_time = time.time() for i in range(100): # until some stopping criterion met... # collect data - training_data.collect(num_episodes=120) # YouQiaoben, ShihongSong + training_data.collect(num_episodes=20) # YouQiaoben, ShihongSong # print current return print('Epoch {}:'.format(i)) @@ -95,4 +102,6 @@ if __name__ == '__main__': # a clean version with only policy net, no value net sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], advantage: data['returns']}) # assigning pi to pi_old - sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)]) \ No newline at end of file + sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)]) + + print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) \ No newline at end of file diff --git a/examples/ppo_cartpole_gym.py b/examples/ppo_cartpole_gym.py new file mode 100755 index 0000000..35ac275 --- /dev/null +++ b/examples/ppo_cartpole_gym.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +from __future__ import absolute_import + +import tensorflow as tf +import gym +import numpy as np +import time + +# our lib imports here! It's ok to append path in examples +import sys +sys.path.append('..') +from tianshou.core import losses +from tianshou.data.batch import Batch +import tianshou.data.advantage_estimation as advantage_estimation +import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy + +from rllab.envs.box2d.cartpole_env import CartpoleEnv +from rllab.envs.normalized_env import normalize + + +def policy_net(observation, action_dim, scope=None): + """ + Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf + + :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels) + :param action_dim: int. The number of actions. + :param scope: str. Specifying the scope of the variables. + """ + # with tf.variable_scope(scope): + net = tf.layers.dense(observation, 32, activation=tf.nn.tanh) + net = tf.layers.dense(net, 32, activation=tf.nn.tanh) + + act_logits = tf.layers.dense(net, action_dim, activation=None) + + return act_logits + + +if __name__ == '__main__': # a clean version with only policy net, no value net + env = gym.make('CartPole-v0') + observation_dim = env.observation_space.shape + action_dim = env.action_space.n + + clip_param = 0.2 + num_batches = 10 + batch_size = 512 + + seed = 10 + np.random.seed(seed) + tf.set_random_seed(seed) + + # 1. build network with pure tf + observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input + + with tf.variable_scope('pi'): + action_logits = policy_net(observation, action_dim, 'pi') + train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES + with tf.variable_scope('pi_old'): + action_logits_old = policy_net(observation, action_dim, 'pi_old') + pi_old_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'pi_old') + + # 2. build losses, optimizers + pi = policy.OnehotCategorical(action_logits, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc. + # for continuous action space, you may need to change an environment to run + pi_old = policy.OnehotCategorical(action_logits_old, observation_placeholder=observation) + + action = tf.placeholder(dtype=tf.int32, shape=(None,)) # batch of integer actions + advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients + + ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict + + total_loss = ppo_loss_clip + optimizer = tf.train.AdamOptimizer(1e-4) + train_op = optimizer.minimize(total_loss, var_list=train_var_list) + + # 3. define data collection + training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper + # ShihongSong: Replay(), see dqn_example.py + # maybe a dict to manage the elements to be collected + + # 4. start training + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + with tf.Session(config=config) as sess: + sess.run(tf.global_variables_initializer()) + # sync pi and pi_old + sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)]) + + start_time = time.time() + for i in range(100): # until some stopping criterion met... + # collect data + training_data.collect(num_episodes=50) # YouQiaoben, ShihongSong + + # print current return + print('Epoch {}:'.format(i)) + training_data.statistics() + + # update network + for _ in range(num_batches): + data = training_data.next_batch(batch_size) # YouQiaoben, ShihongSong + # TODO: auto managing of the placeholders? or add this to params of data.Batch + sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], + advantage: data['returns']}) + + # assigning pi to pi_old + sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)]) + + print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) \ No newline at end of file diff --git a/tianshou/core/policy/dqn.py b/tianshou/core/policy/dqn.py index 8533549..2f6db5a 100644 --- a/tianshou/core/policy/dqn.py +++ b/tianshou/core/policy/dqn.py @@ -54,7 +54,7 @@ class DQNOld(QValuePolicy): return the action (int) to be executed. no exploration when exploration=None. """ - # TODO: ensure thread safety + # TODO: ensure thread safety, tf.multinomial to init sess = tf.get_default_session() sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1), feed_dict={self._observation_placeholder: observation[None]}) diff --git a/tianshou/core/policy/stochastic.py b/tianshou/core/policy/stochastic.py index e2c2dea..cda204d 100644 --- a/tianshou/core/policy/stochastic.py +++ b/tianshou/core/policy/stochastic.py @@ -35,6 +35,7 @@ class OnehotCategorical(StochasticPolicy): def __init__(self, logits, observation_placeholder, dtype=None, group_ndims=0, **kwargs): self._logits = tf.convert_to_tensor(logits) + self._action = tf.multinomial(self.logits, num_samples=1) if dtype is None: dtype = tf.int32 @@ -65,7 +66,7 @@ class OnehotCategorical(StochasticPolicy): # TODO: this may be ugly. also maybe huge problem when parallel sess = tf.get_default_session() # observation[None] adds one dimension at the beginning - sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1), + sampled_action = sess.run(self._action, feed_dict={self._observation_placeholder: observation[None]}) sampled_action = sampled_action[0, 0] @@ -103,6 +104,9 @@ class Normal(StochasticPolicy): self._logstd = tf.convert_to_tensor(logstd, dtype = tf.float32) self._std = tf.exp(self._logstd) + shape = tf.broadcast_dynamic_shape(tf.shape(self._mean), tf.shape(self._std)) + self._action = tf.random_normal(tf.concat([[1], shape], 0), dtype = tf.float32) * self._std + self._mean + super(Normal, self).__init__( act_dtype = tf.float32, param_dtype = tf.float32, @@ -126,14 +130,9 @@ class Normal(StochasticPolicy): def _act(self, observation): # TODO: getting session like this maybe ugly. also maybe huge problem when parallel sess = tf.get_default_session() - mean, std = self._mean, self._std - shape = tf.broadcast_dynamic_shape(tf.shape(self._mean),\ - tf.shape(self._std)) - # observation[None] adds one dimension at the beginning - sampled_action = sess.run(tf.random_normal(tf.concat([[1], shape], 0), - dtype = tf.float32) * std + mean, + sampled_action = sess.run(self._action, feed_dict={self._observation_placeholder: observation[None]}) sampled_action = sampled_action[0, 0] return sampled_action diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 4d7b1f2..a2a8dde 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -14,19 +14,20 @@ class Batch(object): self._advantage_estimation_function = advantage_estimation_function self._is_first_collect = True + def collect(self, num_timesteps=0, num_episodes=0, + apply_function=True): # specify how many data to collect here, or fix it in __init__() + assert sum( + [num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!" - def collect(self, num_timesteps=0, num_episodes=0, apply_function=True): # specify how many data to collect here, or fix it in __init__() - assert sum([num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!" - - if num_timesteps > 0: # YouQiaoben: finish this implementation, the following code are just from openai/baselines + if num_timesteps > 0: # YouQiaoben: finish this implementation, the following code are just from openai/baselines t = 0 - ac = self.env.action_space.sample() # not used, just so we have the datatype - new = True # marks if we're on first timestep of an episode + ac = self.env.action_space.sample() # not used, just so we have the datatype + new = True # marks if we're on first timestep of an episode if self.is_first_collect: ob = self.env.reset() self.is_first_collect = False else: - ob = self.raw_data['observations'][0] # last observation! + ob = self.raw_data['observations'][0] # last observation! # Initialize history arrays observations = np.array([ob for _ in range(num_timesteps)]) @@ -76,9 +77,11 @@ class Batch(object): rewards = [] episode_start_flags = [] - t_count = 0 + # t_count = 0 for _ in range(num_episodes): + t_count = 0 + ob = self._env.reset() observations.append(ob) episode_start_flags.append(True) @@ -92,7 +95,7 @@ class Batch(object): t_count += 1 if t_count >= 100: # force episode stop, just to test if memory still grows - done = True + break if done: # end of episode, discard s_T break @@ -110,8 +113,9 @@ class Batch(object): del rewards del episode_start_flags - self.raw_data = {'observations': self.observations, 'actions': self.actions, 'rewards': self.rewards, 'episode_start_flags': self.episode_start_flags} - + self.raw_data = {'observations': self.observations, 'actions': self.actions, 'rewards': self.rewards, + 'episode_start_flags': self.episode_start_flags} + self._is_first_collect = False if apply_function: @@ -133,6 +137,7 @@ class Batch(object): return current_batch + # TODO: this will definitely be refactored with a proper logger def statistics(self): """ compute the statistics of the current sampled paths @@ -143,16 +148,21 @@ class Batch(object): num_timesteps = rewards.shape[0] returns = [] + episode_lengths = [] max_return = 0 + num_episodes = 1 episode_start_idx = 0 for i in range(1, num_timesteps): if episode_start_flags[i] or ( i == num_timesteps - 1): # found the start of next episode or the end of all episodes + if episode_start_flags[i]: + num_episodes += 1 if i < rewards.shape[0] - 1: t = i - 1 else: t = i Gt = 0 + episode_lengths.append(t - episode_start_idx) while t >= episode_start_idx: Gt += rewards[t] t -= 1 @@ -163,5 +173,8 @@ class Batch(object): episode_start_idx = i print('AverageReturn: {}'.format(np.mean(returns))) - print('StdReturn: : {}'.format(np.std(returns))) - print('MaxReturn : {}'.format(max_return)) \ No newline at end of file + print('StdReturn : {}'.format(np.std(returns))) + print('NumEpisodes : {}'.format(num_episodes)) + print('MinMaxReturns: {}..., {}'.format(np.sort(returns)[:3], np.sort(returns)[-3:])) + print('AverageLength: {}'.format(np.mean(episode_lengths))) + print('MinMaxLengths: {}..., {}'.format(np.sort(episode_lengths)[:3], np.sort(episode_lengths)[-3:])) diff --git a/tianshou/data/replay_buffer/binary_heap.py b/tianshou/data/replay_buffer/binary_heap.py index e2b1474..2deac14 100644 --- a/tianshou/data/replay_buffer/binary_heap.py +++ b/tianshou/data/replay_buffer/binary_heap.py @@ -7,7 +7,7 @@ import sys import math -import utility +from . import utility class BinaryHeap(object): diff --git a/tianshou/data/replay_buffer/rank_based.py b/tianshou/data/replay_buffer/rank_based.py index b71ca68..0a6641f 100644 --- a/tianshou/data/replay_buffer/rank_based.py +++ b/tianshou/data/replay_buffer/rank_based.py @@ -154,6 +154,7 @@ class RankBasedExperience(ReplayBuffer): target = list() sess = tf.get_default_session() + # TODO: pre-build the thing in sess.run current_datas, current_wis, current_indexs = self.sample({'global_step': sess.run(tf.train.get_global_step())}) for i in range(0, batch_size):