diff --git a/examples/dqn_example.py b/examples/dqn_example.py
deleted file mode 100644
index f453311..0000000
--- a/examples/dqn_example.py
+++ /dev/null
@@ -1,95 +0,0 @@
-#!/usr/bin/env python
-from __future__ import absolute_import
-
-import tensorflow as tf
-import gym
-import numpy as np
-import time
-
-# our lib imports here! It's ok to append path in examples
-import sys
-sys.path.append('..')
-from tianshou.core import losses
-from tianshou.data.batch import Batch
-import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy.dqn as policy  # TODO: fix imports as zhusuan so that only need to import to policy
-import tianshou.core.value_function.action_value as value_function
-import tianshou.data.replay_buffer.proportional as proportional
-import tianshou.data.replay_buffer.rank_based as rank_based
-import tianshou.data.replay_buffer.naive as naive
-import tianshou.data.replay_buffer.Replay as Replay
-
-
-# TODO: why this solves cartpole even without training?
-
-
-if __name__ == '__main__':
-    env = gym.make('CartPole-v0')
-    observation_dim = env.observation_space.shape
-    action_dim = env.action_space.n
-
-    clip_param = 0.2
-    num_batches = 10
-    batch_size = 512
-
-    seed = 0
-    np.random.seed(seed)
-    tf.set_random_seed(seed)
-
-    ### 1. build network with pure tf
-    observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim)
-
-    def my_network():
-        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
-        net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
-
-        action_values = tf.layers.dense(net, action_dim, activation=None)
-
-        return None, action_values  # no policy head
-
-    ### 2. build policy, loss, optimizer
-    dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, weight_update=100)
-    pi = policy.DQN(dqn)
-
-    dqn_loss = losses.qlearning(dqn)
-
-    total_loss = dqn_loss
-    global_step = tf.Variable(0, name='global_step', trainable=False)
-    optimizer = tf.train.AdamOptimizer(1e-4)
-    train_op = optimizer.minimize(total_loss, var_list=dqn.trainable_variables, global_step=tf.train.get_global_step())
-
-    # replay_memory = naive.NaiveExperience({'size': 1000})
-    replay_memory = rank_based.RankBasedExperience({'size': 30})
-    # replay_memory = proportional.PropotionalExperience({'size': 100, 'batch_size': 10})
-    data_collector = Replay.Replay(replay_memory, env, pi, [advantage_estimation.ReplayMemoryQReturn(1, dqn)], [dqn])
-
-    ### 3. define data collection
-    # data_collector = Batch(env, pi, [advantage_estimation.nstep_q_return(1, dqn)], [dqn])
-
-    ### 4. start training
-    config = tf.ConfigProto()
-    config.gpu_options.allow_growth = True
-    with tf.Session(config=config) as sess:
-        sess.run(tf.global_variables_initializer())
-
-        # assign actor to pi_old
-        pi.sync_weights()  # TODO: automate this for policies with target network
-
-        start_time = time.time()
-        #TODO : repeat_num shoulde be defined in some configuration files
-        repeat_num = 100
-        for i in range(repeat_num):
-            # collect data
-            # data_collector.collect(nums=50)
-            data_collector.collect(num_episodes=50, epsilon_greedy= (repeat_num - i + 0.0) / repeat_num)
-
-            # print current return
-            print('Epoch {}:'.format(i))
-            data_collector.statistics()
-
-            # update network
-            for _ in range(num_batches):
-                feed_dict = data_collector.next_batch(batch_size, tf.train.global_step(sess, global_step))
-                sess.run(train_op, feed_dict=feed_dict)
-
-        print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
diff --git a/examples/dqn_replay.py b/examples/dqn_replay.py
index 70657b0..127ea00 100644
--- a/examples/dqn_replay.py
+++ b/examples/dqn_replay.py
@@ -14,6 +14,7 @@ from tianshou.core import losses
 import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy.dqn as policy  # TODO: fix imports as zhusuan so that only need to import to policy
 import tianshou.core.value_function.action_value as value_function
+import sys
 from tianshou.data.replay_buffer.vanilla import VanillaReplayBuffer
 from tianshou.data.data_collector import DataCollector
 
@@ -79,7 +80,7 @@ if __name__ == '__main__':
     start_time = time.time()
     epsilon = 0.5
     pi.set_epsilon_train(epsilon)
-    data_collector.collect(num_timesteps=1e3)  # warm-up
+    data_collector.collect(num_timesteps=int(1e3))  # warm-up
     for i in range(int(1e8)):  # number of training steps
         # anneal epsilon step-wise
         if (i + 1) % 1e4 == 0 and epsilon > 0.1:
@@ -101,4 +102,4 @@ if __name__ == '__main__':
         if i % 1000 == 0:
             # epsilon 0.05 as in nature paper
             pi.set_epsilon_test(0.05)
-            test(env, pi)  # go for act_test of pi, not act
+            #test(env, pi)  # go for act_test of pi, not act
diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py
index 5ffa544..f86fe8c 100644
--- a/tianshou/data/advantage_estimation.py
+++ b/tianshou/data/advantage_estimation.py
@@ -8,20 +8,20 @@ REWARD = 2
 DONE = 3
 
 # modified for new interfaces
-def full_return(buffer, index=None):
+def full_return(buffer, indexes=None):
     """
     naively compute full return
     :param buffer: buffer with property index and data. index determines the current content in `buffer`.
-    :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+    :param indexes: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
     each episode.
     :return: dict with key 'return' and value the computed returns corresponding to `index`.
     """
-    index = index or buffer.index
+    indexes = indexes or buffer.index
     raw_data = buffer.data
 
     returns = []
-    for i_episode in range(len(index)):
-        index_this = index[i_episode]
+    for i_episode in range(len(indexes)):
+        index_this = indexes[i_episode]
         if index_this:
             episode = raw_data[i_episode]
             if not episode[-1][DONE]:
@@ -111,7 +111,7 @@ class nstep_q_return:
         self.use_target_network = use_target_network
 
     # TODO : we should transfer the tf -> numpy/python -> tf into a monolithic compute graph in tf
-    def __call__(self, buffer, index=None):
+    def __call__(self, buffer, indexes=None):
         """
         :param buffer: buffer with property index and data. index determines the current content in `buffer`.
         :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
@@ -119,7 +119,7 @@ class nstep_q_return:
         :return: dict with key 'return' and value the computed returns corresponding to `index`.
         """
         qvalue = self.action_value._value_tensor_all_actions
-        index = index or buffer.index
+        indexes = indexes or buffer.index
         episodes = buffer.data
         discount_factor = 0.99
         returns = []
@@ -128,8 +128,8 @@
         config.gpu_options.allow_growth = True
         with tf.Session(config=config) as sess:
             sess.run(tf.global_variables_initializer())
-            for episode_index in range(len(index)):
-                index = index[episode_index]
+            for episode_index in range(len(indexes)):
+                index = indexes[episode_index]
                 if index:
                     episode = episodes[episode_index]
                     episode_q = []
@@ -145,9 +145,11 @@ class nstep_q_return:
                             current_discount_factor *= discount_factor
                             last_frame_index = lfi
                         if last_frame_index > i:
-                            target_q += current_discount_factor * \
-                                max(sess.run(qvalue, feed_dict={self.action_value.managed_placeholders['observation']:
-                                    episode[last_frame_index][STATE]}))
+                            state = episode[last_frame_index][STATE]
+                            # the shape of qpredict is [batch_size, action_dimension]
+                            qpredict = sess.run(qvalue, feed_dict={self.action_value.managed_placeholders['observation']:
+                                                                   state.reshape(1, state.shape[0])})
+                            target_q += current_discount_factor * max(qpredict[0])
                         episode_q.append(target_q)
 
             returns.append(episode_q)
diff --git a/tianshou/data/data_collector.py b/tianshou/data/data_collector.py
index aa0eda1..610d0f3 100644
--- a/tianshou/data/data_collector.py
+++ b/tianshou/data/data_collector.py
@@ -1,6 +1,7 @@
 import numpy as np
 import logging
 import itertools
+import sys
 
 from .replay_buffer.base import ReplayBufferBase
 
@@ -59,7 +60,7 @@ class DataCollector(object):
         sampled_index = self.data_buffer.sample(batch_size)
         if self.process_mode == 'sample':
             for processor in self.process_functions:
-                self.data_batch.update(processor(self.data_buffer, index=sampled_index))
+                self.data_batch.update(processor(self.data_buffer, indexes=sampled_index))
 
         # flatten rank-2 list to numpy array, construct feed_dict
         feed_dict = {}
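
Note (not part of the patch above): the nstep_q_return change feeds the bootstrap observation as an explicit batch of one instead of a rank-1 array, so the Q-network output qpredict has shape [batch_size, action_dimension] and max(qpredict[0]) is the greedy value used in the n-step target. Below is a minimal standalone sketch of that pattern, assuming TensorFlow 1.x as in the examples above; obs_dim, action_dim and the toy two-layer network are illustrative placeholders, not tianshou APIs.

    import numpy as np
    import tensorflow as tf

    obs_dim, action_dim = 4, 2

    # Toy Q-network with the same placeholder convention as the examples above:
    # the observation placeholder expects a batch of shape (None, obs_dim).
    observation_ph = tf.placeholder(tf.float32, shape=(None, obs_dim))
    net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
    qvalue = tf.layers.dense(net, action_dim, activation=None)  # [batch_size, action_dim]

    state = np.zeros(obs_dim, dtype=np.float32)  # a single frame, rank-1

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Reshape the single state into a batch of one before feeding it, as the
        # patch does; qpredict then has shape [1, action_dim], and max(qpredict[0])
        # is the greedy value used to bootstrap the n-step Q target.
        qpredict = sess.run(qvalue, feed_dict={observation_ph: state.reshape(1, obs_dim)})
        bootstrap_value = max(qpredict[0])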