From 0bc1b63e389f3a3213bf8469a0a2ff96289634cf Mon Sep 17 00:00:00 2001
From: Dong Yan
Date: Sun, 25 Feb 2018 16:31:35 +0800
Subject: [PATCH] add epsilon-greedy for dqn

---
 examples/dqn_example.py     |  6 ++++--
 tianshou/core/policy/dqn.py |  5 ++---
 tianshou/data/batch.py      | 14 +++++++++-----
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/examples/dqn_example.py b/examples/dqn_example.py
index 70c9e4b..ee18863 100644
--- a/examples/dqn_example.py
+++ b/examples/dqn_example.py
@@ -66,9 +66,11 @@ if __name__ == '__main__':
     pi.sync_weights()  # TODO: automate this for policies with target network
 
     start_time = time.time()
-    for i in range(100):
+    # TODO: repeat_num should be defined in a configuration file
+    repeat_num = 100
+    for i in range(repeat_num):
         # collect data
-        data_collector.collect(num_episodes=50)
+        data_collector.collect(num_episodes=50, epsilon_greedy=(repeat_num - i + 0.0) / repeat_num)
 
         # print current return
         print('Epoch {}:'.format(i))
diff --git a/tianshou/core/policy/dqn.py b/tianshou/core/policy/dqn.py
index bc5db67..5cef57a 100644
--- a/tianshou/core/policy/dqn.py
+++ b/tianshou/core/policy/dqn.py
@@ -18,7 +18,7 @@ class DQN(PolicyBase):
         else:
             self.interaction_count = -1
 
-    def act(self, observation, exploration=None):
+    def act(self, observation, my_feed_dict):
         sess = tf.get_default_session()
         if self.weight_update > 1:
             if self.interaction_count % self.weight_update == 0:
@@ -30,8 +30,7 @@
         if self.weight_update > 0:
             self.interaction_count += 1
 
-        if not exploration:
-            return np.squeeze(action)
+        return np.squeeze(action)
 
     @property
     def q_net(self):
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 9c7405d..d559ded 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -34,7 +34,7 @@ class Batch(object):
         self._is_first_collect = True
 
     def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={},
-                process_reward=True):  # specify how many data to collect here, or fix it in __init__()
+                process_reward=True, epsilon_greedy=0):  # specify how many data to collect here, or fix it in __init__()
         assert sum(
             [num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!"
 
@@ -106,7 +106,11 @@
             episode_start_flags.append(True)
 
             while True:
-                ac = self._pi.act(ob, my_feed_dict)
+                # a simple implementation of epsilon-greedy exploration
+                if epsilon_greedy > 0 and np.random.random() < epsilon_greedy:
+                    ac = np.random.randint(low=0, high=self._env.action_space.n)
+                else:
+                    ac = self._pi.act(ob, my_feed_dict)
                 actions.append(ac)
 
                 if self.render:
@@ -114,9 +118,9 @@
                 ob, reward, done, _ = self._env.step(ac)
                 rewards.append(reward)
-                t_count += 1
-                if t_count >= 100:  # force episode stop, just to test if memory still grows
-                    break
+                # t_count += 1
+                # if t_count >= 100:  # force episode stop, just to test if memory still grows
+                #     break
                 if done:  # end of episode, discard s_T
                     # TODO: for num_timesteps collection, has to store terminal flag instead of start flag!
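
For reference, the exploration branch added to Batch.collect() amounts to this minimal
standalone sketch; select_action, pi, env, and ob are illustrative stand-ins for the
collector's policy, environment, and current observation, not part of the tianshou API:

import numpy as np

def select_action(pi, env, ob, my_feed_dict, epsilon_greedy=0.0):
    # with probability epsilon_greedy, explore with a uniform random action
    # drawn from the discrete action space; otherwise exploit the current policy
    if epsilon_greedy > 0 and np.random.random() < epsilon_greedy:
        return np.random.randint(low=0, high=env.action_space.n)
    return pi.act(ob, my_feed_dict)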
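
The schedule passed in from dqn_example.py decays epsilon linearly from 1.0 in the first
epoch to 1/repeat_num in the last, so the collector never becomes fully greedy during
training; the "+ 0.0" keeps the division in floating point under Python 2. A quick check
of the schedule:

repeat_num = 100
schedule = [(repeat_num - i + 0.0) / repeat_num for i in range(repeat_num)]
assert schedule[0] == 1.0
assert schedule[-1] == 1.0 / repeat_num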