From 2a2274aeea2873d1ededd37016f2061fb517c195 Mon Sep 17 00:00:00 2001
From: haoshengzou
Date: Sun, 4 Mar 2018 21:29:58 +0800
Subject: [PATCH] initial data_collector. working on examples/dqn_replay.py to run

---
 tianshou/core/policy/base.py          |  2 +-
 tianshou/core/policy/dqn.py           | 21 ++++++--
 tianshou/data/advantage_estimation.py | 14 +++---
 tianshou/data/data_collector.py       | 72 ++++++++++++++++++++-------
 4 files changed, 78 insertions(+), 31 deletions(-)

diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index 23cd45d..6a060ce 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -16,7 +16,7 @@ class PolicyBase(object):
     """
     base class for policy. only provides `act` method with exploration
     """
-    def act(self, observation):
+    def act(self, observation, my_feed_dict):
         raise NotImplementedError()
 
 
diff --git a/tianshou/core/policy/dqn.py b/tianshou/core/policy/dqn.py
index 5cef57a..b93f1af 100644
--- a/tianshou/core/policy/dqn.py
+++ b/tianshou/core/policy/dqn.py
@@ -9,7 +9,7 @@ class DQN(PolicyBase):
     """
     use DQN from value_function as a member
     """
-    def __init__(self, dqn):
+    def __init__(self, dqn, epsilon_train=0.1, epsilon_test=0.05):
         self.action_value = dqn
         self._argmax_action = tf.argmax(dqn.value_tensor_all_actions, axis=1)
         self.weight_update = dqn.weight_update
@@ -18,20 +18,29 @@ class DQN(PolicyBase):
         else:
             self.interaction_count = -1
 
-    def act(self, observation, my_feed_dict):
+        self.epsilon_train = epsilon_train
+        self.epsilon_test = epsilon_test
+
+    def act(self, observation, my_feed_dict={}):
         sess = tf.get_default_session()
         if self.weight_update > 1:
             if self.interaction_count % self.weight_update == 0:
                 self.update_weights()
 
         feed_dict = {self.action_value._observation_placeholder: observation[None]}
+        feed_dict.update(my_feed_dict)
         action = sess.run(self._argmax_action, feed_dict=feed_dict)
 
+        if np.random.rand() < self.epsilon_train:
+            pass
         if self.weight_update > 0:
             self.interaction_count += 1
 
         return np.squeeze(action)
 
+    def act_test(self, observation, my_feed_dict={}):
+        pass
+
     @property
     def q_net(self):
         return self.action_value
@@ -50,4 +59,10 @@
         :return:
         """
         if self.action_value.weight_update_ops is not None:
-            self.action_value.update_weights()
\ No newline at end of file
+            self.action_value.update_weights()
+
+    def set_epsilon_train(self, epsilon):
+        self.epsilon_train = epsilon
+
+    def set_epsilon_test(self, epsilon):
+        self.epsilon_test = epsilon
diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py
index cd956a2..5ffa544 100644
--- a/tianshou/data/advantage_estimation.py
+++ b/tianshou/data/advantage_estimation.py
@@ -101,7 +101,7 @@ class ddpg_return:
     pass
 
 
-class ReplayMemoryQReturn:
+class nstep_q_return:
     """
     compute the n-step return for Q-learning targets
     """
@@ -111,7 +111,7 @@
         self.use_target_network = use_target_network
 
     # TODO : we should transfer the tf -> numpy/python -> tf into a monolithic compute graph in tf
-    def __call__(self, buffer, indexes =None):
+    def __call__(self, buffer, index=None):
         """
         :param buffer: buffer with property index and data. index determines the current content in `buffer`.
         :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
         :return: dict with key 'return' and value the computed returns corresponding to `index`.
         """
         qvalue = self.action_value._value_tensor_all_actions
-        indexes = indexes or buffer.index
+        index = index or buffer.index
         episodes = buffer.data
         discount_factor = 0.99
         returns = []
@@ -128,13 +128,11 @@
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
-           for episode_index in range(len(indexes)):
-               index = indexes[episode_index]
+           for episode_index in range(len(index)):
+               index = index[episode_index]
                if index:
                    episode = episodes[episode_index]
                    episode_q = []
-                   if not episode[-1][DONE]:
-                       logging.warning('Computing Q return on episode {} with no terminal state.'.format(episode_index))
 
                    for i in index:
                        current_discount_factor = 1
@@ -155,4 +153,4 @@
                    returns.append(episode_q)
                else:
                    returns.append([])
-        return {'TD-lambda': returns}
+        return {'return': returns}
diff --git a/tianshou/data/data_collector.py b/tianshou/data/data_collector.py
index 5ad5484..aa0eda1 100644
--- a/tianshou/data/data_collector.py
+++ b/tianshou/data/data_collector.py
@@ -1,3 +1,7 @@
+import numpy as np
+import logging
+import itertools
+
 from .replay_buffer.base import ReplayBufferBase
 
 class DataCollector(object):
@@ -11,30 +15,28 @@ class DataCollector(object):
         self.process_functions = process_functions
         self.managed_networks = managed_networks
 
+        self.data = {}
+        self.data_batch = {}
+
         self.required_placeholders = {}
         for net in self.managed_networks:
             self.required_placeholders.update(net.managed_placeholders)
         self.require_advantage = 'advantage' in self.required_placeholders.keys()
 
         if isinstance(self.data_buffer, ReplayBufferBase):  # process when sampling minibatch
-            self.process_mode = 'minibatch'
+            self.process_mode = 'sample'
         else:
-            self.process_mode = 'batch'
+            self.process_mode = 'full'
 
         self.current_observation = self.env.reset()
 
-    def collect(self, num_timesteps=1, num_episodes=0, exploration=None, my_feed_dict={}):
+    def collect(self, num_timesteps=1, num_episodes=0, my_feed_dict={}):
         assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
             "One and only one collection number specification permitted!"
 
         if num_timesteps > 0:
             for _ in range(num_timesteps):
-                action_vanilla = self.policy.act(self.current_observation, my_feed_dict=my_feed_dict)
-                if exploration:
-                    action = exploration(action_vanilla)
-                else:
-                    action = action_vanilla
-
+                action = self.policy.act(self.current_observation, my_feed_dict=my_feed_dict)
                 next_observation, reward, done, _ = self.env.step(action)
                 self.data_buffer.add((self.current_observation, action, reward, done))
                 self.current_observation = next_observation
@@ -44,24 +46,56 @@ class DataCollector(object):
                 observation = self.env.reset()
                 done = False
                 while not done:
-                    action_vanilla = self.policy.act(observation, my_feed_dict=my_feed_dict)
-                    if exploration:
-                        action = exploration(action_vanilla)
-                    else:
-                        action = action_vanilla
-
+                    action = self.policy.act(observation, my_feed_dict=my_feed_dict)
                     next_observation, reward, done, _ = self.env.step(action)
                     self.data_buffer.add((observation, action, reward, done))
                     observation = next_observation
 
-    def next_batch(self, batch_size):
+        if self.process_mode == 'full':
+            for processor in self.process_functions:
+                self.data.update(processor(self.data_buffer))
+
+    def next_batch(self, batch_size, standardize_advantage=True):
         sampled_index = self.data_buffer.sample(batch_size)
-        if self.process_mode == 'minibatch':
-            pass
+        if self.process_mode == 'sample':
+            for processor in self.process_functions:
+                self.data_batch.update(processor(self.data_buffer, index=sampled_index))
 
         # flatten rank-2 list to numpy array, construct feed_dict
+        feed_dict = {}
+        frame_key_map = {'observation': 0, 'action': 1, 'reward': 2, 'done_flag': 3}
+        for key, placeholder in self.required_placeholders.items():
+            if key in frame_key_map.keys():  # access raw_data
+                frame_index = frame_key_map[key]
+                flattened = []
+                for index_episode, data_episode in zip(sampled_index, self.data_buffer.data):
+                    for i in index_episode:
+                        flattened.append(data_episode[i][frame_index])
+                feed_dict[placeholder] = np.array(flattened)
+            elif key in self.data_batch.keys():  # access processed minibatch data
+                flattened = list(itertools.chain.from_iterable(self.data_batch[key]))
+                feed_dict[placeholder] = np.array(flattened)
+            elif key in self.data.keys():  # access processed full data
+                flattened = [0.] * batch_size  # float
+                i_in_batch = 0
+                for index_episode, data_episode in zip(sampled_index, self.data[key]):
+                    for i in index_episode:
+                        flattened[i_in_batch] = data_episode[i]
+                        i_in_batch += 1
+                feed_dict[placeholder] = np.array(flattened)
+            else:
+                raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))
 
-        return
+        if standardize_advantage:
+            if self.require_advantage:
+                advantage_value = feed_dict[self.required_placeholders['advantage']]
+                advantage_mean = np.mean(advantage_value)
+                advantage_std = np.std(advantage_value)
+                if advantage_std < 1e-3:
+                    logging.warning('advantage_std too small (< 1e-3) for advantage standardization. may cause numerical issues')
+                feed_dict[self.required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std
+
+        return feed_dict
 
     def statistics(self):
         pass
\ No newline at end of file
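
A self-contained sketch of the flattening step that the new DataCollector.next_batch performs for raw frames: data_buffer.data is treated as a rank-2 list (episodes of (observation, action, reward, done_flag) frames) and sampled_index as a rank-2 list of per-episode frame indexes. The buffer contents, sampled indexes, and the flatten_frames helper below are made up for illustration; only frame_key_map and the nested-loop flattening mirror the patch.

# Sketch (not part of the patch): mimic the raw-data branch of
# DataCollector.next_batch on hand-made data.
import numpy as np

frame_key_map = {'observation': 0, 'action': 1, 'reward': 2, 'done_flag': 3}

# two made-up episodes of (observation, action, reward, done_flag) frames
buffer_data = [
    [(np.zeros(4), 0, 1.0, False), (np.ones(4), 1, 0.0, True)],
    [(np.full(4, 2.), 1, 1.0, False), (np.full(4, 3.), 0, 1.0, True)],
]
sampled_index = [[0, 1], [1]]  # frames 0 and 1 of episode 0, frame 1 of episode 1

def flatten_frames(key):
    """Flatten the sampled frames of one key into the array fed to its placeholder."""
    frame_index = frame_key_map[key]
    flattened = []
    for index_episode, data_episode in zip(sampled_index, buffer_data):
        for i in index_episode:
            flattened.append(data_episode[i][frame_index])
    return np.array(flattened)

print(flatten_frames('observation').shape)  # (3, 4): three sampled frames
print(flatten_frames('reward'))             # [1. 0. 1.]

Values produced by the process_functions (for example the 'return' key emitted by nstep_q_return) follow the same per-episode list convention, which is why the data_batch branch of next_batch can flatten them with itertools.chain.from_iterable.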