From ed25bf75861ccf4644f7a23903de061dda0b7d7e Mon Sep 17 00:00:00 2001 From: haoshengzou Date: Wed, 17 Jan 2018 11:55:51 +0800 Subject: [PATCH] fixed the bugs on Jan 14, which gives inferior or even no improvement. mistook group_ndims. policy will soon need refactoring. --- README.md | 24 ++++ examples/actor_critic_cartpole.py | 94 +++++++++++++++ examples/actor_critic_fail_cartpole.py | 98 +++++++++++++++ examples/actor_critic_separate_cartpole.py | 90 ++++++++++++++ examples/contrib_dqn_example.py | 95 +++++++++++++++ examples/dqn_example.py | 126 +++++++++----------- examples/ppo_cartpole.py | 6 +- examples/ppo_cartpole_alternative.py | 15 ++- examples/ppo_cartpole_gym.py | 6 +- internal_keys.md | 15 +++ tianshou/core/losses.py | 31 ++++- tianshou/core/policy/stochastic.py | 12 +- tianshou/core/value_function/state_value.py | 7 +- tianshou/data/advantage_estimation.py | 41 +++++-- tianshou/data/batch.py | 78 ++++++++---- 15 files changed, 619 insertions(+), 119 deletions(-) create mode 100755 examples/actor_critic_cartpole.py create mode 100755 examples/actor_critic_fail_cartpole.py create mode 100755 examples/actor_critic_separate_cartpole.py create mode 100644 examples/contrib_dqn_example.py create mode 100644 internal_keys.md diff --git a/README.md b/README.md index fc7d494..817a5d6 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,30 @@ Tianshou(天授) is a reinforcement learning platform. The following image illus     Specific network architectures in original paper of DQN, TRPO, A3C, etc. Policy-Value Network of AlphaGo Zero +#### brief intro of current implementation: + +how to write your own network: +- define the observation placeholder yourself, pass it to `observation_placeholder` when initializing a policy instance +- pass a callable when initializing a policy instance. The callable should satisfy only three conditions: + - it accepts no parameters + - it does not create any new placeholders + - it returns `action-related tensors, value_head` + +Our lib will take care of your observation placeholder from now on, as well as +all the placeholders that will be created by our lib. + +The other placeholders, such as `keep_prob` in dropout and `clip_param` in ppo loss +should be managed by your own (see examples/ppo_cartpole_alternative.py) + +The `weight_update` parameter: +- 0 means manually update target network +- 1 means no target network (the target network is updated every 1 minibatch) +- (0, 1) is the target network as used in DDPG +- greater than 1 is the target network as used in DQN + +Other comments are in the python files in example/ and in the lib codes. +Refactor is definitely needed so don't dwell too much on annoying details... + ### Algorithm #### losses diff --git a/examples/actor_critic_cartpole.py b/examples/actor_critic_cartpole.py new file mode 100755 index 0000000..c43c99b --- /dev/null +++ b/examples/actor_critic_cartpole.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +from __future__ import absolute_import + +import tensorflow as tf +import time +import numpy as np + +# our lib imports here! 
It's ok to append path in examples +import sys +sys.path.append('..') +from tianshou.core import losses +from tianshou.data.batch import Batch +import tianshou.data.advantage_estimation as advantage_estimation +import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy +import tianshou.core.value_function.state_value as value_function + +from rllab.envs.box2d.cartpole_env import CartpoleEnv +from rllab.envs.normalized_env import normalize + + +# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix + +if __name__ == '__main__': + env = normalize(CartpoleEnv()) + observation_dim = env.observation_space.shape + action_dim = env.action_space.flat_dim + + clip_param = 0.2 + num_batches = 10 + batch_size = 128 + + seed = 10 + np.random.seed(seed) + tf.set_random_seed(seed) + + ### 1. build network with pure tf + observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) + + def my_network(): + # placeholders defined in this function would be very difficult to manage + net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh) + net = tf.layers.dense(net, 32, activation=tf.nn.tanh) + + action_mean = tf.layers.dense(net, action_dim, activation=None) + action_logstd = tf.get_variable('action_logstd', shape=(action_dim, )) + value = tf.layers.dense(net, 1, activation=None) + + return action_mean, action_logstd, value + # TODO: overriding seems not able to handle shared layers, unless a new class `SharedPolicyValue` + # maybe the most desired thing is to freely build policy and value function from any tensor? + # but for now, only the outputs of the network matters + + ### 2. build policy, critic, loss, optimizer + actor = policy.Normal(my_network, observation_placeholder=observation_ph, weight_update=1) + critic = value_function.StateValue(my_network, observation_placeholder=observation_ph) + + actor_loss = losses.REINFORCE(actor) + critic_loss = losses.state_value_mse(critic) + total_loss = actor_loss + critic_loss + + optimizer = tf.train.AdamOptimizer(1e-4) + + # this hack would be unnecessary if we have a `SharedPolicyValue` class, or hack the trainable_variables management + var_list = list(set(actor.trainable_variables + critic.trainable_variables)) + + train_op = optimizer.minimize(total_loss, var_list=var_list) + + ### 3. define data collection + data_collector = Batch(env, actor, + [advantage_estimation.gae_lambda(1, critic), advantage_estimation.nstep_return(1, critic)], + [actor, critic]) + # TODO: refactor this, data_collector should be just the top-level abstraction + + ### 4. 
start training + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + with tf.Session(config=config) as sess: + sess.run(tf.global_variables_initializer()) + + start_time = time.time() + for i in range(100): + # collect data + data_collector.collect(num_episodes=20) + + # print current return + print('Epoch {}:'.format(i)) + data_collector.statistics() + + # update network + for _ in range(num_batches): + feed_dict = data_collector.next_batch(batch_size) + sess.run(train_op, feed_dict=feed_dict) + + print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) \ No newline at end of file diff --git a/examples/actor_critic_fail_cartpole.py b/examples/actor_critic_fail_cartpole.py new file mode 100755 index 0000000..c942052 --- /dev/null +++ b/examples/actor_critic_fail_cartpole.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +from __future__ import absolute_import + +import tensorflow as tf +import time +import numpy as np + +# our lib imports here! It's ok to append path in examples +import sys +sys.path.append('..') +from tianshou.core import losses +from tianshou.data.batch import Batch +import tianshou.data.advantage_estimation as advantage_estimation +import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy +import tianshou.core.value_function.state_value as value_function + +from rllab.envs.box2d.cartpole_env import CartpoleEnv +from rllab.envs.normalized_env import normalize + + +# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix + +if __name__ == '__main__': + env = normalize(CartpoleEnv()) + observation_dim = env.observation_space.shape + action_dim = env.action_space.flat_dim + + clip_param = 0.2 + num_batches = 10 + batch_size = 128 + + seed = 10 + np.random.seed(seed) + tf.set_random_seed(seed) + + ### 1. build network with pure tf + observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) + + def my_actor(): + net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh) + net = tf.layers.dense(net, 32, activation=tf.nn.tanh) + + action_mean = tf.layers.dense(net, action_dim, activation=None) + action_logstd = tf.get_variable('action_logstd', shape=(action_dim, )) + + return action_mean, action_logstd, None + + def my_critic(): + net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh) + net = tf.layers.dense(net, 32, activation=tf.nn.tanh) + value = tf.layers.dense(net, 1, activation=None) + + return None, value + + ### 2. build policy, critic, loss, optimizer + actor = policy.Normal(my_actor, observation_placeholder=observation_ph, weight_update=1) + critic = value_function.StateValue(my_critic, observation_placeholder=observation_ph) + + print('actor and critic will share variables in this case') + sys.exit() + + actor_loss = losses.vanilla_policy_gradient(actor) + critic_loss = losses.state_value_mse(critic) + total_loss = actor_loss + critic_loss + + optimizer = tf.train.AdamOptimizer(1e-4) + train_op = optimizer.minimize(total_loss, var_list=actor.trainable_variables) + + ### 3. define data collection + training_data = Batch(env, actor, advantage_estimation.full_return) + + ### 4. 
start training + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + with tf.Session(config=config) as sess: + sess.run(tf.global_variables_initializer()) + + # assign actor to pi_old + actor.sync_weights() # TODO: automate this for policies with target network + + start_time = time.time() + for i in range(100): + # collect data + training_data.collect(num_episodes=20) + + # print current return + print('Epoch {}:'.format(i)) + training_data.statistics() + + # update network + for _ in range(num_batches): + feed_dict = training_data.next_batch(batch_size) + sess.run(train_op, feed_dict=feed_dict) + + # assigning actor to pi_old + actor.update_weights() + + print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) \ No newline at end of file diff --git a/examples/actor_critic_separate_cartpole.py b/examples/actor_critic_separate_cartpole.py new file mode 100755 index 0000000..87d6c43 --- /dev/null +++ b/examples/actor_critic_separate_cartpole.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python +from __future__ import absolute_import + +import tensorflow as tf +import time +import numpy as np + +# our lib imports here! It's ok to append path in examples +import sys +sys.path.append('..') +from tianshou.core import losses +from tianshou.data.batch import Batch +import tianshou.data.advantage_estimation as advantage_estimation +import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan so that only need to import to policy +import tianshou.core.value_function.state_value as value_function + +from rllab.envs.box2d.cartpole_env import CartpoleEnv +from rllab.envs.normalized_env import normalize + + +# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix + +if __name__ == '__main__': + env = normalize(CartpoleEnv()) + observation_dim = env.observation_space.shape + action_dim = env.action_space.flat_dim + + clip_param = 0.2 + num_batches = 10 + batch_size = 128 + + seed = 10 + np.random.seed(seed) + tf.set_random_seed(seed) + + ### 1. build network with pure tf + observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) + + def my_network(): + net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh) + net = tf.layers.dense(net, 32, activation=tf.nn.tanh) + + action_mean = tf.layers.dense(net, action_dim, activation=None) + action_logstd = tf.get_variable('action_logstd', shape=(action_dim, )) + + net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh) + net = tf.layers.dense(net, 32, activation=tf.nn.tanh) + value = tf.layers.dense(net, 1, activation=None) + + return action_mean, action_logstd, value + + ### 2. build policy, critic, loss, optimizer + actor = policy.Normal(my_network, observation_placeholder=observation_ph, weight_update=1) + critic = value_function.StateValue(my_network, observation_placeholder=observation_ph) + + actor_loss = losses.REINFORCE(actor) + critic_loss = losses.state_value_mse(critic) + + actor_optimizer = tf.train.AdamOptimizer(1e-4) + actor_train_op = actor_optimizer.minimize(actor_loss, var_list=actor.trainable_variables) + + critic_optimizer = tf.train.RMSPropOptimizer(1e-4) + critic_train_op = critic_optimizer.minimize(critic_loss, var_list=critic.trainable_variables) + + ### 3. define data collection + data_collector = Batch(env, actor, + [advantage_estimation.gae_lambda(1, critic), advantage_estimation.nstep_return(1, critic)], + [actor, critic]) + + ### 4. 
start training + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + with tf.Session(config=config) as sess: + sess.run(tf.global_variables_initializer()) + + start_time = time.time() + for i in range(100): + # collect data + data_collector.collect(num_episodes=20) + + # print current return + print('Epoch {}:'.format(i)) + data_collector.statistics() + + # update network + for _ in range(num_batches): + feed_dict = data_collector.next_batch(batch_size) + sess.run([actor_train_op, critic_train_op], feed_dict=feed_dict) + + print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) \ No newline at end of file diff --git a/examples/contrib_dqn_example.py b/examples/contrib_dqn_example.py new file mode 100644 index 0000000..4b97ea8 --- /dev/null +++ b/examples/contrib_dqn_example.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python + +import tensorflow as tf +import gym + +# our lib imports here! +import sys +sys.path.append('..') +import tianshou.core.losses as losses +from tianshou.data.replay_buffer.utils import get_replay_buffer +import tianshou.core.policy.dqn as policy + + +# THIS EXAMPLE IS NOT FINISHED YET!!! + + +def policy_net(observation, action_dim): + """ + Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf + + :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels) + :param action_dim: int. The number of actions. + :param scope: str. Specifying the scope of the variables. + """ + net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu) + net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu) + net = tf.layers.flatten(net) + net = tf.layers.dense(net, 256, activation=tf.nn.relu) + + q_values = tf.layers.dense(net, action_dim) + + return q_values + + +if __name__ == '__main__': + env = gym.make('PongNoFrameskip-v4') + observation_dim = env.observation_space.shape + action_dim = env.action_space.n + + # 1. build network with pure tf + # TODO: + # pass the observation variable to the replay buffer or find a more reasonable way to help replay buffer + # access this observation variable. + observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input + action = tf.placeholder(dtype=tf.int32, shape=(None,)) # batch of integer actions + + + with tf.variable_scope('q_net'): + q_values = policy_net(observation, action_dim) + with tf.variable_scope('target_net'): + q_values_target = policy_net(observation, action_dim) + + # 2. build losses, optimizers + q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action) # YongRen: policy.DQN + target_net = policy.DQNRefactor(q_values_target, observation_placeholder=observation, action_placeholder=action) + + target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN + + dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen + global_step = tf.Variable(0, name='global_step', trainable=False) + train_var_list = tf.get_collection( + tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES + total_loss = dqn_loss + optimizer = tf.train.AdamOptimizer(1e-3) + train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=tf.train.get_global_step()) + # 3. define data collection + # configuration should be given as parameters, different replay buffer has different parameters. 
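The dict handed to `get_replay_buffer` below (`size`, `batch_size`, `learn_start`) is buffer-specific configuration. For orientation only, here is a toy uniform-sampling buffer exposing the same `next_batch` surface the training loop relies on; the class and its `add` method are hypothetical and stand apart from the repo's rank-based implementation.

```python
import random
from collections import deque

import numpy as np


class ToyUniformReplay(object):
    """Toy FIFO replay with uniform sampling -- illustration only, not the
    rank-based buffer constructed below."""

    def __init__(self, size):
        self.buffer = deque(maxlen=size)

    def add(self, observation, action, target):
        # called once per environment step while collecting
        self.buffer.append((observation, action, target))

    def next_batch(self, batch_size):
        # uniform sampling; a rank-based buffer would sample by priority instead
        obs, act, tgt = zip(*random.sample(self.buffer, batch_size))
        return {'observations': np.array(obs),
                'actions': np.array(act),
                'target': np.array(tgt)}
```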
+ replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net, + {'size': 1000, 'batch_size': 64, 'learn_start': 20}) + # ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN + # maybe a dict to manage the elements to be collected + + # 4. start training + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + + minibatch_count = 0 + collection_count = 0 + # need to first collect some then sample, collect_freq must be larger than batch_size + collect_freq = 100 + while True: # until some stopping criterion met... + # collect data + for i in range(0, collect_freq): + replay_memory.collect() # ShihongSong + collection_count += 1 + print('Collected {} times.'.format(collection_count)) + + # update network + data = replay_memory.next_batch(10) # YouQiaoben, ShihongSong + # TODO: auto managing of the placeholders? or add this to params of data.Batch + sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], target: data['target']}) + minibatch_count += 1 + print('Trained {} minibatches.'.format(minibatch_count)) + + # TODO: assigning pi to pi_old is not implemented yet \ No newline at end of file diff --git a/examples/dqn_example.py b/examples/dqn_example.py index 4b97ea8..f822e59 100644 --- a/examples/dqn_example.py +++ b/examples/dqn_example.py @@ -1,95 +1,83 @@ #!/usr/bin/env python +from __future__ import absolute_import import tensorflow as tf import gym +import numpy as np +import time -# our lib imports here! +# our lib imports here! It's ok to append path in examples import sys sys.path.append('..') -import tianshou.core.losses as losses -from tianshou.data.replay_buffer.utils import get_replay_buffer -import tianshou.core.policy.dqn as policy - - -# THIS EXAMPLE IS NOT FINISHED YET!!! - - -def policy_net(observation, action_dim): - """ - Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf - - :param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels) - :param action_dim: int. The number of actions. - :param scope: str. Specifying the scope of the variables. - """ - net = tf.layers.conv2d(observation, 16, 8, 4, 'valid', activation=tf.nn.relu) - net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu) - net = tf.layers.flatten(net) - net = tf.layers.dense(net, 256, activation=tf.nn.relu) - - q_values = tf.layers.dense(net, action_dim) - - return q_values +from tianshou.core import losses +from tianshou.data.batch import Batch +import tianshou.data.advantage_estimation as advantage_estimation +import tianshou.core.policy.dqn as policy # TODO: fix imports as zhusuan so that only need to import to policy if __name__ == '__main__': - env = gym.make('PongNoFrameskip-v4') + env = gym.make('CartPole-v0') observation_dim = env.observation_space.shape action_dim = env.action_space.n - # 1. build network with pure tf - # TODO: - # pass the observation variable to the replay buffer or find a more reasonable way to help replay buffer - # access this observation variable. 
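The contrib example above instantiates `policy_net` twice, under the `q_net` and `target_net` variable scopes, so the target network can be refreshed from the online one. A minimal TensorFlow sketch of how such a copy op is typically assembled from the two scopes; the scope names follow the example, but the helper itself is not part of the library.

```python
import tensorflow as tf


def build_sync_op(online_scope='q_net', target_scope='target_net'):
    """Group of assign ops copying every trainable variable in online_scope
    onto its counterpart in target_scope (matched by sorted variable name)."""
    online_vars = sorted(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=online_scope),
        key=lambda v: v.name)
    target_vars = sorted(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=target_scope),
        key=lambda v: v.name)
    return tf.group(*[tf.assign(t, o) for o, t in zip(online_vars, target_vars)])
```

Running the returned op every K minibatches gives the DQN-style periodic update; the `weight_update` argument in the refactored example below abstracts this choice away.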
- observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input - action = tf.placeholder(dtype=tf.int32, shape=(None,)) # batch of integer actions + clip_param = 0.2 + num_batches = 10 + batch_size = 512 + seed = 0 + np.random.seed(seed) + tf.set_random_seed(seed) - with tf.variable_scope('q_net'): - q_values = policy_net(observation, action_dim) - with tf.variable_scope('target_net'): - q_values_target = policy_net(observation, action_dim) + ### 1. build network with pure tf + observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) - # 2. build losses, optimizers - q_net = policy.DQNRefactor(q_values, observation_placeholder=observation, action_placeholder=action) # YongRen: policy.DQN - target_net = policy.DQNRefactor(q_values_target, observation_placeholder=observation, action_placeholder=action) + def my_policy(): + net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh) + net = tf.layers.dense(net, 32, activation=tf.nn.tanh) - target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN + action_values = tf.layers.dense(net, action_dim, activation=None) + + return action_values, None # None value head + + # TODO: current implementation of passing function or overriding function has to return a value head + # to allow network sharing between policy and value networks. This makes 'policy' and 'value_function' + # imbalanced semantically (though they are naturally imbalanced since 'policy' is required to interact + # with the environment and 'value_function' is not). I have an idea to solve this imbalance, which is + # not based on passing function or overriding function. + + ### 2. build policy, loss, optimizer + pi = policy.DQN(my_policy, observation_placeholder=observation_ph, weight_update=10) + + dqn_loss = losses.qlearning(pi) - dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen - global_step = tf.Variable(0, name='global_step', trainable=False) - train_var_list = tf.get_collection( - tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES total_loss = dqn_loss - optimizer = tf.train.AdamOptimizer(1e-3) - train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=tf.train.get_global_step()) - # 3. define data collection - # configuration should be given as parameters, different replay buffer has different parameters. - replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net, - {'size': 1000, 'batch_size': 64, 'learn_start': 20}) - # ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN - # maybe a dict to manage the elements to be collected + optimizer = tf.train.AdamOptimizer(1e-4) + train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) - # 4. start training - with tf.Session() as sess: + ### 3. define data collection + data_collector = Batch(env, pi, [advantage_estimation.nstep_q_return(1, pi.target_network)], [pi]) + + ### 4. start training + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + with tf.Session(config=config) as sess: sess.run(tf.global_variables_initializer()) - minibatch_count = 0 - collection_count = 0 - # need to first collect some then sample, collect_freq must be larger than batch_size - collect_freq = 100 - while True: # until some stopping criterion met... 
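The refactored example passes `weight_update=10`, following the convention described in this patch's README addition: 0 means manual target updates, 1 means no target network, a value in (0, 1) is a DDPG-style soft update, and an integer greater than 1 is a DQN-style periodic copy. A small numpy sketch of the latter two rules; the function names are illustrative, not library API.

```python
import numpy as np


def soft_update(target_params, online_params, tau):
    """weight_update in (0, 1): DDPG-style target <- tau * online + (1 - tau) * target."""
    return [tau * o + (1.0 - tau) * t for o, t in zip(online_params, target_params)]


def periodic_update(target_params, online_params, step, interval):
    """weight_update > 1: hard copy of the online parameters every `interval` minibatches."""
    if step % interval == 0:
        return [np.array(o, copy=True) for o in online_params]
    return target_params


# weight_update=10 in the example corresponds to interval=10 here
```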
+ # assign actor to pi_old + pi.sync_weights() # TODO: automate this for policies with target network + + start_time = time.time() + for i in range(100): # collect data - for i in range(0, collect_freq): - replay_memory.collect() # ShihongSong - collection_count += 1 - print('Collected {} times.'.format(collection_count)) + data_collector.collect(num_episodes=50) + + # print current return + print('Epoch {}:'.format(i)) + data_collector.statistics() # update network - data = replay_memory.next_batch(10) # YouQiaoben, ShihongSong - # TODO: auto managing of the placeholders? or add this to params of data.Batch - sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], target: data['target']}) - minibatch_count += 1 - print('Trained {} minibatches.'.format(minibatch_count)) + for _ in range(num_batches): + feed_dict = data_collector.next_batch(batch_size) + sess.run(train_op, feed_dict=feed_dict) - # TODO: assigning pi to pi_old is not implemented yet \ No newline at end of file + print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) \ No newline at end of file diff --git a/examples/ppo_cartpole.py b/examples/ppo_cartpole.py index 05d317a..faee623 100755 --- a/examples/ppo_cartpole.py +++ b/examples/ppo_cartpole.py @@ -61,7 +61,7 @@ if __name__ == '__main__': train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) ### 3. define data collection - training_data = Batch(env, pi, advantage_estimation.full_return) + training_data = Batch(env, pi, [advantage_estimation.full_return], [pi]) ### 4. start training config = tf.ConfigProto() @@ -69,7 +69,7 @@ if __name__ == '__main__': with tf.Session(config=config) as sess: sess.run(tf.global_variables_initializer()) - # assign pi to pi_old + # assign actor to pi_old pi.sync_weights() # TODO: automate this for policies with target network start_time = time.time() @@ -86,7 +86,7 @@ if __name__ == '__main__': feed_dict = training_data.next_batch(batch_size) sess.run(train_op, feed_dict=feed_dict) - # assigning pi to pi_old + # assigning actor to pi_old pi.update_weights() print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) \ No newline at end of file diff --git a/examples/ppo_cartpole_alternative.py b/examples/ppo_cartpole_alternative.py index b76e634..e989cd8 100755 --- a/examples/ppo_cartpole_alternative.py +++ b/examples/ppo_cartpole_alternative.py @@ -47,7 +47,7 @@ if __name__ == '__main__': observation_dim = env.observation_space.shape action_dim = env.action_space.flat_dim - clip_param = 0.2 + # clip_param = 0.2 num_batches = 10 batch_size = 128 @@ -65,6 +65,7 @@ if __name__ == '__main__': ### 2. build policy, loss, optimizer pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0) + clip_param = tf.placeholder(tf.float32, shape=(), name='ppo_loss_clip_param') ppo_loss_clip = losses.ppo_clip(pi, clip_param) total_loss = ppo_loss_clip @@ -72,7 +73,7 @@ if __name__ == '__main__': train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) ### 3. define data collection - training_data = Batch(env, pi, advantage_estimation.full_return) + training_data = Batch(env, pi, [advantage_estimation.full_return], [pi]) ### 4. 
start training feed_dict_train = {is_training_ph: True, keep_prob_ph: 0.8} @@ -83,7 +84,7 @@ if __name__ == '__main__': with tf.Session(config=config) as sess: sess.run(tf.global_variables_initializer()) - # assign pi to pi_old + # assign actor to pi_old pi.sync_weights() # TODO: automate this for policies with target network start_time = time.time() @@ -95,13 +96,19 @@ if __name__ == '__main__': print('Epoch {}:'.format(i)) training_data.statistics() + # manipulate decay_param + if i < 30: + feed_dict_train[clip_param] = 0.2 + else: + feed_dict_train[clip_param] = 0.1 + # update network for _ in range(num_batches): feed_dict = training_data.next_batch(batch_size) feed_dict.update(feed_dict_train) sess.run(train_op, feed_dict=feed_dict) - # assigning pi to pi_old + # assigning actor to pi_old pi.update_weights() # approximate test mode diff --git a/examples/ppo_cartpole_gym.py b/examples/ppo_cartpole_gym.py index 2710c98..46f7fad 100755 --- a/examples/ppo_cartpole_gym.py +++ b/examples/ppo_cartpole_gym.py @@ -55,7 +55,7 @@ if __name__ == '__main__': train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) ### 3. define data collection - training_data = Batch(env, pi, advantage_estimation.full_return) + training_data = Batch(env, pi, [advantage_estimation.full_return], [pi]) ### 4. start training config = tf.ConfigProto() @@ -63,7 +63,7 @@ if __name__ == '__main__': with tf.Session(config=config) as sess: sess.run(tf.global_variables_initializer()) - # assign pi to pi_old + # assign actor to pi_old pi.sync_weights() # TODO: automate this for policies with target network start_time = time.time() @@ -80,7 +80,7 @@ if __name__ == '__main__': feed_dict = training_data.next_batch(batch_size) sess.run(train_op, feed_dict=feed_dict) - # assigning pi to pi_old + # assigning actor to pi_old pi.update_weights() print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60)) \ No newline at end of file diff --git a/internal_keys.md b/internal_keys.md new file mode 100644 index 0000000..9f7e4bd --- /dev/null +++ b/internal_keys.md @@ -0,0 +1,15 @@ +network.managed_placeholders.keys() + +data_collector.raw_data.keys() + +data_collector.data.keys() + +['observation'] + +['action'] + +['reward'] + +['start_flag'] + +['advantage'] > ['return'] # they may appear simultaneously \ No newline at end of file diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py index e1c48d4..69ac79e 100644 --- a/tianshou/core/losses.py +++ b/tianshou/core/losses.py @@ -14,7 +14,7 @@ def ppo_clip(policy, clip_param): action_ph = tf.placeholder(policy.act_dtype, shape=(None,) + policy.action_shape, name='ppo_clip_loss/action_placeholder') advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder') policy.managed_placeholders['action'] = action_ph - policy.managed_placeholders['processed_reward'] = advantage_ph + policy.managed_placeholders['advantage'] = advantage_ph log_pi_act = policy.log_prob(action_ph) log_pi_old_act = policy.log_prob_old(action_ph) @@ -24,7 +24,7 @@ def ppo_clip(policy, clip_param): return ppo_clip_loss -def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"): +def REINFORCE(policy): """ vanilla policy gradient @@ -34,10 +34,29 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"): :param baseline: the baseline method used to reduce the variance, default is 'None' :return: """ - log_pi_act = pi.log_prob(sampled_action) - vanilla_policy_gradient_loss = tf.reduce_mean(reward * log_pi_act) - # 
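For reference, a standalone numpy illustration of the two surrogate objectives touched in this hunk: the clipped PPO loss that `ppo_clip` builds from `log_prob` / `log_prob_old`, and the advantage-weighted log-likelihood loss that replaces `vanilla_policy_gradient` as `REINFORCE`. Both are written as losses to minimize; this is the textbook form, not the repo's exact tensors.

```python
import numpy as np


def reinforce_loss(log_pi, advantage):
    """L = -mean(A * log pi(a|s)); minimizing it ascends the policy gradient."""
    return -np.mean(advantage * log_pi)


def ppo_clip_loss(log_pi, log_pi_old, advantage, clip_param=0.2):
    """L = -mean(min(r * A, clip(r, 1 - eps, 1 + eps) * A)) with r = pi / pi_old."""
    ratio = np.exp(log_pi - log_pi_old)
    clipped_ratio = np.clip(ratio, 1.0 - clip_param, 1.0 + clip_param)
    return -np.mean(np.minimum(ratio * advantage, clipped_ratio * advantage))
```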
TODO: Different baseline methods like REINFORCE, etc. - return vanilla_policy_gradient_loss + action_ph = tf.placeholder(policy.act_dtype, shape=(None,) + policy.action_shape, + name='REINFORCE/action_placeholder') + advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='REINFORCE/advantage_placeholder') + policy.managed_placeholders['action'] = action_ph + policy.managed_placeholders['advantage'] = advantage_ph + + log_pi_act = policy.log_prob(action_ph) + REINFORCE_loss = -tf.reduce_mean(advantage_ph * log_pi_act) + return REINFORCE_loss + + +def state_value_mse(state_value_function): + """ + L2 loss of state value + :param state_value_function: instance of StateValue + :return: tensor of the mse loss + """ + state_value_ph = tf.placeholder(tf.float32, shape=(None,), name='state_value_mse/state_value_placeholder') + state_value_function.managed_placeholders['return'] = state_value_ph + + state_value = state_value_function.value_tensor + return tf.losses.mean_squared_error(state_value_ph, state_value) + def dqn_loss(sampled_action, sampled_target, policy): """ diff --git a/tianshou/core/policy/stochastic.py b/tianshou/core/policy/stochastic.py index 294c21f..3ac82f0 100644 --- a/tianshou/core/policy/stochastic.py +++ b/tianshou/core/policy/stochastic.py @@ -44,7 +44,8 @@ class OnehotCategorical(StochasticPolicy): self.weight_update = weight_update self.interaction_count = -1 # defaults to -1. only useful if weight_update > 1. - with tf.variable_scope('network'): + # build network, action and value + with tf.variable_scope('network', reuse=tf.AUTO_REUSE): logits, value_head = policy_callable() self._logits = tf.convert_to_tensor(logits, dtype=tf.float32) self._action = tf.multinomial(self._logits, num_samples=1) @@ -55,11 +56,12 @@ class OnehotCategorical(StochasticPolicy): self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network') + # deal with target network if self.weight_update == 1: self.weight_update_ops = None self.sync_weights_ops = None else: # then we need to build another tf graph as target network - with tf.variable_scope('net_old'): + with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE): logits, value_head = policy_callable() self._logits_old = tf.convert_to_tensor(logits, dtype=tf.float32) @@ -173,7 +175,8 @@ class Normal(StochasticPolicy): self.weight_update = weight_update self.interaction_count = -1 # defaults to -1. only useful if weight_update > 1. 
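Both policy classes now open their graphs with `reuse=tf.AUTO_REUSE`, which is what lets `StateValue` (below) call the same network function under the same `'network'` scope and pick up the actor's layers instead of duplicating them. A self-contained sketch of that mechanism outside the library; the explicit layer names are mine and make the sharing deterministic.

```python
import tensorflow as tf


def shared_net(obs):
    # explicit names => variables live at network/fc1/*, network/mean/*, network/value/*
    h = tf.layers.dense(obs, 32, activation=tf.nn.tanh, name='fc1')
    action_mean = tf.layers.dense(h, 2, activation=None, name='mean')
    value = tf.layers.dense(h, 1, activation=None, name='value')
    return action_mean, value


obs_ph = tf.placeholder(tf.float32, shape=(None, 4))

with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
    action_mean, _ = shared_net(obs_ph)   # first call creates the variables
with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
    _, value = shared_net(obs_ph)         # second call reuses them

# each variable appears exactly once, so actor and critic train the same weights
print([v.name for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')])
```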
- with tf.variable_scope('network'): + # build network, action and value + with tf.variable_scope('network', reuse=tf.AUTO_REUSE): mean, logstd, value_head = policy_callable() self._mean = tf.convert_to_tensor(mean, dtype = tf.float32) self._logstd = tf.convert_to_tensor(logstd, dtype = tf.float32) @@ -188,11 +191,12 @@ class Normal(StochasticPolicy): self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network') + # deal with target network if self.weight_update == 1: self.weight_update_ops = None self.sync_weights_ops = None else: # then we need to build another tf graph as target network - with tf.variable_scope('net_old'): + with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE): mean, logstd, value_head = policy_callable() self._mean_old = tf.convert_to_tensor(mean, dtype=tf.float32) self._logstd_old = tf.convert_to_tensor(logstd, dtype=tf.float32) diff --git a/tianshou/core/value_function/state_value.py b/tianshou/core/value_function/state_value.py index 02c12fe..79464d8 100644 --- a/tianshou/core/value_function/state_value.py +++ b/tianshou/core/value_function/state_value.py @@ -8,7 +8,12 @@ class StateValue(ValueFunctionBase): """ class of state values V(s). """ - def __init__(self, value_tensor, observation_placeholder): + def __init__(self, policy_callable, observation_placeholder): + self.managed_placeholders = {'observation': observation_placeholder} + with tf.variable_scope('network', reuse=tf.AUTO_REUSE): + value_tensor = policy_callable()[-1] + self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) + super(StateValue, self).__init__( value_tensor=value_tensor, observation_placeholder=observation_placeholder diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py index 3c2d644..a1f1978 100644 --- a/tianshou/data/advantage_estimation.py +++ b/tianshou/data/advantage_estimation.py @@ -6,15 +6,13 @@ def full_return(raw_data): naively compute full return :param raw_data: dict of specified keys and values. 
""" - observations = raw_data['observations'] - actions = raw_data['actions'] - rewards = raw_data['rewards'] - episode_start_flags = raw_data['episode_start_flags'] + observations = raw_data['observation'] + actions = raw_data['action'] + rewards = raw_data['reward'] + episode_start_flags = raw_data['end_flag'] num_timesteps = rewards.shape[0] data = {} - data['observations'] = observations - data['actions'] = actions returns = rewards.copy() episode_start_idx = 0 @@ -33,11 +31,39 @@ def full_return(raw_data): episode_start_idx = i - data['returns'] = returns + data['return'] = returns return data +class gae_lambda: + """ + Generalized Advantage Estimation (Schulman, 15) to compute advantage + """ + def __init__(self, T, value_function): + self.T = T + self.value_function = value_function + + def __call__(self, raw_data): + reward = raw_data['reward'] + + return {'advantage': reward} + + +class nstep_return: + """ + compute the n-step return from n-step rewards and bootstrapped value function + """ + def __init__(self, n, value_function): + self.n = n + self.value_function = value_function + + def __call__(self, raw_data): + reward = raw_data['reward'] + + return {'return': reward} + + class QLearningTarget: def __init__(self, policy, gamma): self._policy = policy @@ -68,3 +94,4 @@ class QLearningTarget: data['rewards'] = np.array(rewards) return data + diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 24d0af7..1dce932 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -1,6 +1,7 @@ import numpy as np import gc - +import logging +from . import utils # TODO: Refactor with tf.train.slice_input_producer, tf.train.Coordinator, tf.train.QueueRunner class Batch(object): @@ -8,14 +9,31 @@ class Batch(object): class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy. """ - def __init__(self, env, pi, advantage_estimation_function): # how to name the function? + def __init__(self, env, pi, reward_processors, networks): # how to name the function? + """ + constructor + :param env: + :param pi: + :param reward_processors: list of functions to process reward + :param networks: list of networks to be optimized, so as to match data in feed_dict + """ self._env = env self._pi = pi - self._advantage_estimation_function = advantage_estimation_function + self.raw_data = {} + self.data = {} + + self.reward_processors = reward_processors + self.networks = networks + + self.required_placeholders = {} + for net in self.networks: + self.required_placeholders.update(net.managed_placeholders) + self.require_advantage = 'advantage' in self.required_placeholders.keys() + self._is_first_collect = True def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, - apply_function=True): # specify how many data to collect here, or fix it in __init__() + process_reward=True): # specify how many data to collect here, or fix it in __init__() assert sum( [num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!" @@ -98,6 +116,7 @@ class Batch(object): break if done: # end of episode, discard s_T + # TODO: for num_timesteps collection, has to store terminal flag instead of start flag! 
break else: observations.append(ob) @@ -113,33 +132,48 @@ class Batch(object): del rewards del episode_start_flags - self.raw_data = {'observations': self.observations, 'actions': self.actions, 'rewards': self.rewards, - 'episode_start_flags': self.episode_start_flags} + self.raw_data = {'observation': self.observations, 'action': self.actions, 'reward': self.rewards, + 'end_flag': self.episode_start_flags} self._is_first_collect = False - if apply_function: + if process_reward: self.apply_advantage_estimation_function() gc.collect() def apply_advantage_estimation_function(self): - self.data = self._advantage_estimation_function(self.raw_data) + for processor in self.reward_processors: + self.data.update(processor(self.raw_data)) - def next_batch(self, batch_size, standardize_advantage=True): # YouQiaoben: referencing other iterate over batches - rand_idx = np.random.choice(self.data['observations'].shape[0], batch_size) - current_batch = {key: value[rand_idx] for key, value in self.data.items()} - - if standardize_advantage: - advantage_mean = np.mean(current_batch['returns']) - advantage_std = np.std(current_batch['returns']) - current_batch['returns'] = (current_batch['returns'] - advantage_mean) / advantage_std + def next_batch(self, batch_size, standardize_advantage=True): + rand_idx = np.random.choice(self.raw_data['observation'].shape[0], batch_size) feed_dict = {} - feed_dict[self._pi.managed_placeholders['observation']] = current_batch['observations'] - feed_dict[self._pi.managed_placeholders['action']] = current_batch['actions'] - feed_dict[self._pi.managed_placeholders['processed_reward']] = current_batch['returns'] - # TODO: should use the keys in pi.managed_placeholders to find values in self.data and self.raw_data + for key, placeholder in self.required_placeholders.items(): + found, data_key = utils.internal_key_match(key, self.raw_data.keys()) + if found: + feed_dict[placeholder] = self.raw_data[data_key][rand_idx] + else: + found, data_key = utils.internal_key_match(key, self.data.keys()) + if found: + feed_dict[placeholder] = self.data[data_key][rand_idx] + + if not found: + raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name))) + + if standardize_advantage: + if self.require_advantage: + advantage_value = feed_dict[self.required_placeholders['advantage']] + advantage_mean = np.mean(advantage_value) + advantage_std = np.std(advantage_value) + if advantage_std < 1e-3: + logging.warning('advantage_std too small (< 1e-3) for advantage standardization. may cause numerical issues') + feed_dict[self.required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std + + # TODO: maybe move all advantage estimation functions to tf, as in tensorforce (though haven't + # understood tensorforce after reading) maybe tf.stop_gradient for targets/advantages + # this will simplify data collector as it only needs to collect raw data, (s, a, r, done) only return feed_dict @@ -149,8 +183,8 @@ class Batch(object): compute the statistics of the current sampled paths :return: """ - rewards = self.raw_data['rewards'] - episode_start_flags = self.raw_data['episode_start_flags'] + rewards = self.raw_data['reward'] + episode_start_flags = self.raw_data['end_flag'] num_timesteps = rewards.shape[0] returns = []
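The reworked `next_batch` above resolves every key in `required_placeholders` first against `raw_data` and then against the processed `data`, and standardizes the advantage column before feeding it. A stripped-down numpy version of that lookup, with plain exact-key matching standing in for `utils.internal_key_match` (whose behaviour this patch does not show).

```python
import numpy as np


def assemble_feed_dict(required_placeholders, raw_data, data, batch_size,
                       standardize_advantage=True):
    """Simplified Batch.next_batch: sample a minibatch index set, then resolve each
    managed placeholder against raw_data first and the processed data second."""
    idx = np.random.choice(raw_data['observation'].shape[0], batch_size)
    feed_dict = {}
    for key, placeholder in required_placeholders.items():
        if key in raw_data:
            feed_dict[placeholder] = raw_data[key][idx]
        elif key in data:
            feed_dict[placeholder] = data[key][idx]
        else:
            raise TypeError('Placeholder {} has no value to feed!'.format(key))

    if standardize_advantage and 'advantage' in required_placeholders:
        adv = feed_dict[required_placeholders['advantage']]
        std = np.std(adv)
        if std < 1e-3:
            print('warning: advantage std < 1e-3, standardization may be unstable')
        feed_dict[required_placeholders['advantage']] = (adv - np.mean(adv)) / std
    return feed_dict
```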