From 54a7b1343d8798b783342377820904337c9afdd1 Mon Sep 17 00:00:00 2001
From: haoshengzou
Date: Sun, 4 Mar 2018 13:53:29 +0800
Subject: [PATCH] design exploration and evaluators for off-policy algos

---
 examples/ddpg_example.py        | 15 ++++++++-------
 examples/dqn_replay.py          | 21 ++++++++++++++++-----
 tianshou/data/data_collector.py |  2 +-
 3 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/examples/ddpg_example.py b/examples/ddpg_example.py
index 3ae4ae8..faa8dcd 100644
--- a/examples/ddpg_example.py
+++ b/examples/ddpg_example.py
@@ -78,17 +78,18 @@ if __name__ == '__main__':
     critic.sync_weights()
 
     start_time = time.time()
-    for i in range(100):
+    data_collector.collect(num_timesteps=1e3)  # warm-up
+    for i in range(int(1e8)):
         # collect data
-        data_collector.collect(num_episodes=50)
-
-        # print current return
-        print('Epoch {}:'.format(i))
-        data_collector.statistics()
+        data_collector.collect()
 
         # update network
         for _ in range(num_batches):
             feed_dict = data_collector.next_batch(batch_size)
             sess.run([actor_train_op, critic_train_op], feed_dict=feed_dict)
 
-    print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
\ No newline at end of file
+        print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
+
+        # test every 1000 training steps
+        if i % 1000 == 0:
+            test(env, actor)
diff --git a/examples/dqn_replay.py b/examples/dqn_replay.py
index 9f14ddb..70657b0 100644
--- a/examples/dqn_replay.py
+++ b/examples/dqn_replay.py
@@ -77,17 +77,28 @@ if __name__ == '__main__':
     pi.sync_weights()  # TODO: automate this for policies with target network
 
     start_time = time.time()
-    for i in range(100):
+    epsilon = 0.5
+    pi.set_epsilon_train(epsilon)
+    data_collector.collect(num_timesteps=1e3)  # warm-up
+    for i in range(int(1e8)):  # number of training steps
+        # anneal epsilon step-wise
+        if (i + 1) % 1e4 == 0 and epsilon > 0.1:
+            epsilon -= 0.1
+            pi.set_epsilon_train(epsilon)
+
         # collect data
         data_collector.collect()
 
-        # print current return
-        print('Epoch {}:'.format(i))
-        data_collector.statistics()
-
         # update network
         for _ in range(num_batches):
             feed_dict = data_collector.next_batch(batch_size)
             sess.run(train_op, feed_dict=feed_dict)
 
         print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
+
+        # test every 1000 training steps
+        # tester could share some code with batch!
+        if i % 1000 == 0:
+            # epsilon 0.05 as in nature paper
+            pi.set_epsilon_test(0.05)
+            test(env, pi)  # go for act_test of pi, not act
diff --git a/tianshou/data/data_collector.py b/tianshou/data/data_collector.py
index 91920eb..5ad5484 100644
--- a/tianshou/data/data_collector.py
+++ b/tianshou/data/data_collector.py
@@ -59,7 +59,7 @@ class DataCollector(object):
 
         if self.process_mode == 'minibatch':
             pass
 
-        # flatten rank-2 list to numpy array
+        # flatten rank-2 list to numpy array, construct feed_dict
 
         return
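
Both examples call a test helper that this patch does not define (test(env, actor) in ddpg_example.py, test(env, pi) in dqn_replay.py). Below is a minimal sketch of what such an evaluator could look like, assuming a 2018-era Gym environment API and a policy exposing a deterministic act_test(observation) method, as hinted by the comment "go for act_test of pi, not act"; the function body, the num_episodes parameter, and the reporting format are assumptions, not code from the patch.

    # Hypothetical evaluator sketch -- not defined anywhere in this patch.
    import numpy as np

    def test(env, policy, num_episodes=10):
        """Roll out the eval-mode policy and print its mean undiscounted return."""
        returns = []
        for _ in range(num_episodes):
            observation = env.reset()  # old Gym API: reset() returns obs only
            done, episode_return = False, 0.0
            while not done:
                action = policy.act_test(observation)  # assumed eval-mode action method
                observation, reward, done, _ = env.step(action)
                episode_return += reward
            returns.append(episode_return)
        print('Mean return over {} test episodes: {:.2f}'.format(
            num_episodes, np.mean(returns)))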
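The epsilon annealing in dqn_replay.py decrements epsilon by 0.1 every 1e4 training steps, starting from 0.5 with an intended floor of 0.1. One floating-point subtlety in the loop as committed: repeated epsilon -= 0.1 accumulates rounding error (after four decrements epsilon is slightly greater than 0.1), so the epsilon > 0.1 guard admits a fifth decrement and epsilon ends up near zero rather than at 0.1. A closed-form schedule sidesteps this; the sketch below is an alternative formulation, not code from the patch.

    # Closed-form epsilon schedule (sketch); avoids cumulative float error.
    def epsilon_at(i, start=0.5, decrement=0.1, floor=0.1, interval=10000):
        """Epsilon in effect during training step i under the patch's schedule."""
        return max(floor, start - decrement * ((i + 1) // interval))

    assert epsilon_at(0) == 0.5                 # before the first decrement
    assert abs(epsilon_at(9999) - 0.4) < 1e-12  # first decrement fires at i == 9999
    assert epsilon_at(10 ** 6) == 0.1           # clamped at the floor thereafter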
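In data_collector.py, the revised comment says the batch-processing step should flatten a rank-2 list (one inner list per episode) into numpy arrays and then construct the feed_dict; the method body is still a stub. The snippet below only illustrates that flattening step on made-up data; the field and placeholder names are hypothetical.

    # Illustration of "flatten rank-2 list to numpy array" (made-up data).
    import numpy as np

    # rewards gathered per episode: 3 steps in the first episode, 2 in the second
    rewards = [[1.0, 0.0, 1.0], [0.0, 1.0]]
    flat_rewards = np.concatenate([np.asarray(episode) for episode in rewards])
    assert flat_rewards.shape == (5,)

    # the feed_dict would then map placeholders to the flat arrays, e.g.:
    # feed_dict = {reward_ph: flat_rewards, ...}  # reward_ph is hypothetical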