From a86354834c08e4f038d7c85c01d06a696db7ac82 Mon Sep 17 00:00:00 2001
From: haoshengzou
Date: Sun, 11 Mar 2018 15:07:41 +0800
Subject: [PATCH] actor critic also works. fix some bugs in nstep_q_return. dqn
 still trains slow.

---
 ...tor_critic_cartpole.py => actor_critic.py} | 41 ++++++-----
 examples/{dqn_replay.py => dqn.py}            |  0
 examples/{ppo_cartpole.py => ppo.py}          |  0
 tianshou/data/advantage_estimation.py         | 70 +++++++++++++++----
 4 files changed, 79 insertions(+), 32 deletions(-)
 rename examples/{actor_critic_cartpole.py => actor_critic.py} (72%)
 rename examples/{dqn_replay.py => dqn.py} (100%)
 rename examples/{ppo_cartpole.py => ppo.py} (100%)

diff --git a/examples/actor_critic_cartpole.py b/examples/actor_critic.py
similarity index 72%
rename from examples/actor_critic_cartpole.py
rename to examples/actor_critic.py
index 4aa1020..16b60bb 100755
--- a/examples/actor_critic_cartpole.py
+++ b/examples/actor_critic.py
@@ -5,18 +5,20 @@ import tensorflow as tf
 import time
 import numpy as np
 import gym
+import logging
+logging.basicConfig(level=logging.INFO)
 
 # our lib imports here! It's ok to append path in examples
 import sys
 sys.path.append('..')
 from tianshou.core import losses
-from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy.stochastic as policy  # TODO: fix imports as zhusuan so that only need to import to policy
+import tianshou.core.policy.stochastic as policy
 import tianshou.core.value_function.state_value as value_function
+from tianshou.data.data_buffer.batch_set import BatchSet
+from tianshou.data.data_collector import DataCollector
 
-# for tutorial purpose, placeholders are explicitly appended with '_ph' suffix
 
 
 if __name__ == '__main__':
     env = gym.make('CartPole-v0')
@@ -25,9 +27,9 @@ if __name__ == '__main__':
 
     clip_param = 0.2
     num_batches = 10
-    batch_size = 128
+    batch_size = 512
 
-    seed = 10
+    seed = 0
     np.random.seed(seed)
     tf.set_random_seed(seed)
 
@@ -36,13 +38,13 @@ if __name__ == '__main__':
 
     def my_network():
         # placeholders defined in this function would be very difficult to manage
-        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
-        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
+        net = tf.layers.dense(observation_ph, 64, activation=tf.nn.tanh)
+        net = tf.layers.dense(net, 64, activation=tf.nn.tanh)
 
-        action_logtis = tf.layers.dense(net, action_dim, activation=None)
+        action_logits = tf.layers.dense(net, action_dim, activation=None)
         value = tf.layers.dense(net, 1, activation=None)
 
-        return action_logtis, value
+        return action_logits, value
     # TODO: overriding seems not able to handle shared layers, unless a new class `SharedPolicyValue`
     # maybe the most desired thing is to freely build policy and value function from any tensor?
     # but for now, only the outputs of the network matters
@@ -53,7 +55,7 @@ if __name__ == '__main__':
 
     actor_loss = losses.REINFORCE(actor)
     critic_loss = losses.value_mse(critic)
-    total_loss = actor_loss + critic_loss
+    total_loss = actor_loss + 1e-2 * critic_loss
 
     optimizer = tf.train.AdamOptimizer(1e-4)
 
@@ -63,10 +65,15 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=var_list)
 
     ### 3. define data collection
-    data_collector = Batch(env, actor,
-                           [advantage_estimation.gae_lambda(1, critic), advantage_estimation.nstep_return(1, critic)],
-                           [actor, critic])
-    # TODO: refactor this, data_collector should be just the top-level abstraction
+    data_buffer = BatchSet()
+
+    data_collector = DataCollector(
+        env=env,
+        policy=actor,
+        data_buffer=data_buffer,
+        process_functions=[advantage_estimation.nstep_return(n=3, value_function=critic, return_advantage=True)],
+        managed_networks=[actor, critic],
+    )
 
     ### 4. start training
     config = tf.ConfigProto()
@@ -75,13 +82,13 @@ if __name__ == '__main__':
         sess.run(tf.global_variables_initializer())
 
         start_time = time.time()
-        for i in range(100):
+        for i in range(int(1e6)):
             # collect data
-            data_collector.collect(num_episodes=20)
+            data_collector.collect(num_episodes=50)
 
             # print current return
             print('Epoch {}:'.format(i))
-            data_collector.statistics()
+            data_buffer.statistics()
 
             # update network
             for _ in range(num_batches):
diff --git a/examples/dqn_replay.py b/examples/dqn.py
similarity index 100%
rename from examples/dqn_replay.py
rename to examples/dqn.py
diff --git a/examples/ppo_cartpole.py b/examples/ppo.py
similarity index 100%
rename from examples/ppo_cartpole.py
rename to examples/ppo.py
diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py
index 151b260..621684e 100644
--- a/tianshou/data/advantage_estimation.py
+++ b/tianshou/data/advantage_estimation.py
@@ -6,7 +6,8 @@ ACTION = 1
 REWARD = 2
 DONE = 3
 
-# modified for new interfaces
+
+# TODO: add discount_factor... maybe make it to be a global config?
 def full_return(buffer, indexes=None):
     """
     naively compute full return
@@ -67,18 +68,59 @@ class nstep_return:
     """
     compute the n-step return from n-step rewards and bootstrapped value function
     """
-    def __init__(self, n, value_function):
+    def __init__(self, n, value_function, return_advantage=False, discount_factor=0.99):
         self.n = n
         self.value_function = value_function
+        self.return_advantage = return_advantage
+        self.discount_factor = discount_factor
 
-    def __call__(self, buffer, index=None):
+    def __call__(self, buffer, indexes=None):
         """
         :param buffer: buffer with property index and data. index determines the current content in `buffer`.
-        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+        :param indexes: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
             each episode.
         :return: dict with key 'return' and value the computed returns corresponding to `index`.
         """
-        pass
+        indexes = indexes or buffer.index
+        episodes = buffer.data
+        returns = []
+        advantages = []
+
+        for i_episode in range(len(indexes)):
+            index_this = indexes[i_episode]
+            if index_this:
+                episode = episodes[i_episode]
+                returns_this = []
+                advantages_this = []
+
+                for i in index_this:
+                    current_discount_factor = 1.
+                    last_frame_index = i
+                    return_ = 0.
+                    for last_frame_index in range(i, min(len(episode), i + self.n)):
+                        return_ += current_discount_factor * episode[last_frame_index][REWARD]
+                        current_discount_factor *= self.discount_factor
+                        if episode[last_frame_index][DONE]:
+                            break
+                    if not episode[last_frame_index][DONE]:
+                        state = episode[last_frame_index + 1][STATE]
+                        v_sT = self.value_function.eval_value(state[None])
+                        return_ += current_discount_factor * v_sT
+                    returns_this.append(return_)
+                    if self.return_advantage:
+                        v_s0 = self.value_function.eval_value(episode[i][STATE][None])
+                        advantages_this.append(return_ - v_s0)
+
+                returns.append(returns_this)
+                advantages.append(advantages_this)
+            else:
+                returns.append([])
+                advantages.append([])
+
+        if self.return_advantage:
+            return {'return': returns, 'advantage':advantages}
+        else:
+            return {'return': returns}
 
 
 class ddpg_return:
@@ -128,18 +170,16 @@
                 episode_q = []
 
                 for i in index:
-                    current_discount_factor = 1
+                    current_discount_factor = 1.
                     last_frame_index = i
-                    target_q = episode[i][REWARD]
-                    for lfi in range(i, min(len(episode), i + self.n + 1)):
-                        if episode[lfi][DONE]:
-                            break
-                        target_q += current_discount_factor * episode[lfi][REWARD]
+                    target_q = 0.
+                    for last_frame_index in range(i, min(len(episode), i + self.n)):
+                        target_q += current_discount_factor * episode[last_frame_index][REWARD]
                         current_discount_factor *= self.discount_factor
-                        last_frame_index = lfi
-                    if last_frame_index > i:
-                        state = episode[last_frame_index][STATE]
-
+                        if episode[last_frame_index][DONE]:
+                            break
+                    if not episode[last_frame_index][DONE]:  # not done will definitely have one frame later
+                        state = episode[last_frame_index + 1][STATE]
                     if self.use_target_network:
                         # [None] adds one dimension to the beginning
                         qpredict = self.action_value.eval_value_all_actions_old(state[None])
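Note: the n-step target computed above, in both `nstep_return` and the fixed `nstep_q_return`, is the standard truncated return G_t^(n) = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1} + gamma^n * V(s_{t+n}), where the bootstrap term is added only if the episode does not terminate inside the n-step window. The following is a minimal standalone sketch of that computation, kept independent of the tianshou buffer and value-function classes; the (state, action, reward, done) episode layout, `nstep_return_sketch`, and `value_fn` are hypothetical names used for illustration only, not tianshou API.

# Illustrative sketch only -- not part of the patch.
# `episode` is assumed to be a list of (state, action, reward, done) tuples;
# `value_fn` is assumed to be a callable returning V(s) for a single state.
def nstep_return_sketch(episode, i, n, value_fn, gamma=0.99):
    """Discounted sum of up to n rewards starting at step i, bootstrapped with V."""
    ret, discount = 0.0, 1.0
    last = i
    for last in range(i, min(len(episode), i + n)):
        _, _, reward, done = episode[last]
        ret += discount * reward        # accumulate gamma^k * r_{i+k}
        discount *= gamma
        if done:                        # episode ended inside the window: no bootstrap
            break
    _, _, _, done = episode[last]
    if not done and last + 1 < len(episode):
        next_state = episode[last + 1][0]
        ret += discount * value_fn(next_state)  # discount == gamma^(number of rewards summed)
    return ret

# usage, e.g.: nstep_return_sketch(transitions, 0, 3, lambda s: 0.0)

The in-buffer versions in the patch differ only in that they index frames through the module-level STATE/REWARD/DONE constants and evaluate the critic via value_function.eval_value(state[None]); the bounds check on `last + 1` above is extra defensiveness where the patch instead relies on its "not done will definitely have one frame later" assumption.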