From 67d0e78ab981b720301a1486f042851475d4691f Mon Sep 17 00:00:00 2001
From: songshshshsh <644240545@qq.com>
Date: Tue, 27 Feb 2018 13:10:47 +0800
Subject: [PATCH] First pass at the replay buffer: make all three replay
 buffers work; refactoring and testing still pending

---
 examples/dqn_example.py                     |  18 ++-
 tianshou/data/advantage_estimation.py       |  28 ++++
 tianshou/data/replay_buffer/Replay.py       | 164 ++++++++++++++++++++
 tianshou/data/replay_buffer/binary_heap.py  |   1 +
 tianshou/data/replay_buffer/naive.py        |  14 +-
 tianshou/data/replay_buffer/proportional.py |  19 +--
 tianshou/data/replay_buffer/rank_based.py   |  21 +--
 7 files changed, 235 insertions(+), 30 deletions(-)
 create mode 100644 tianshou/data/replay_buffer/Replay.py

diff --git a/examples/dqn_example.py b/examples/dqn_example.py
index 70c9e4b..31fdfee 100644
--- a/examples/dqn_example.py
+++ b/examples/dqn_example.py
@@ -14,6 +14,10 @@ from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy.dqn as policy  # TODO: fix imports as zhusuan so that only need to import to policy
 import tianshou.core.value_function.action_value as value_function
+import tianshou.data.replay_buffer.proportional as proportional
+import tianshou.data.replay_buffer.rank_based as rank_based
+import tianshou.data.replay_buffer.naive as naive
+import tianshou.data.replay_buffer.Replay as Replay
 
 
 # TODO: why this solves cartpole even without training?
@@ -50,11 +54,17 @@ if __name__ == '__main__':
     dqn_loss = losses.qlearning(dqn)
 
     total_loss = dqn_loss
+    global_step = tf.Variable(0, name='global_step', trainable=False)
     optimizer = tf.train.AdamOptimizer(1e-4)
-    train_op = optimizer.minimize(total_loss, var_list=dqn.trainable_variables)
+    train_op = optimizer.minimize(total_loss, var_list=dqn.trainable_variables, global_step=tf.train.get_global_step())
+
+    # replay_memory = naive.NaiveExperience({'size': 1000})
+    replay_memory = rank_based.RankBasedExperience({'size': 30})
+    # replay_memory = proportional.PropotionalExperience({'size': 100, 'batch_size': 10})
+    data_collector = Replay.Replay(replay_memory, env, pi, [advantage_estimation.ReplayMemoryQReturn(1, dqn)], [dqn])
 
     ### 3. define data collection
-    data_collector = Batch(env, pi, [advantage_estimation.nstep_q_return(1, dqn)], [dqn])
+    # data_collector = Batch(env, pi, [advantage_estimation.nstep_q_return(1, dqn)], [dqn])
 
     ### 4. start training
     config = tf.ConfigProto()
@@ -68,7 +78,7 @@ if __name__ == '__main__':
         start_time = time.time()
         for i in range(100):
             # collect data
-            data_collector.collect(num_episodes=50)
+            data_collector.collect(nums=50)
 
             # print current return
             print('Epoch {}:'.format(i))
@@ -76,7 +86,7 @@ if __name__ == '__main__':
 
             # update network
             for _ in range(num_batches):
-                feed_dict = data_collector.next_batch(batch_size)
+                feed_dict = data_collector.next_batch(batch_size, tf.train.global_step(sess, global_step))
                 sess.run(train_op, feed_dict=feed_dict)
 
         print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
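
Note on the optimizer wiring above: the example creates global_step as a plain tf.Variable but hands tf.train.get_global_step() to minimize(). Depending on the TF 1.x version, that helper resolves the step through the GLOBAL_STEP collection or by the tensor name 'global_step:0', so the dependency is easy to break if the variable is renamed or created differently; passing the variable directly, or using tf.train.get_or_create_global_step(), keeps it unambiguous. The step value read back with tf.train.global_step(sess, global_step) is what next_batch forwards to the prioritized buffers. A minimal sketch of the explicit wiring (the scalar loss below is a stand-in, not the DQN loss from the example):

    import tensorflow as tf

    # stand-in scalar loss; any loss works for the wiring shown here
    w = tf.Variable(1.0)
    total_loss = tf.square(w)

    global_step = tf.train.get_or_create_global_step()  # registered in the GLOBAL_STEP collection
    optimizer = tf.train.AdamOptimizer(1e-4)
    train_op = optimizer.minimize(total_loss, global_step=global_step)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op)
        print(tf.train.global_step(sess, global_step))  # -> 1, the step advances with each update
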
diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py
index b9bf0e3..8d9c25f 100644
--- a/tianshou/data/advantage_estimation.py
+++ b/tianshou/data/advantage_estimation.py
@@ -159,3 +159,31 @@ class QLearningTarget:
 
         return data
 
+
+class ReplayMemoryQReturn:
+    """
+    compute the n-step return for Q-learning targets
+    """
+    def __init__(self, n, action_value, use_target_network=True):
+        self.n = n
+        self._action_value = action_value
+        self._use_target_network = use_target_network
+
+    def __call__(self, raw_data):
+        reward = raw_data['reward']
+        observation = raw_data['observation']
+
+        if self._use_target_network:
+            # print(observation.shape)
+            # print((observation.reshape((1,) + observation.shape)))
+            action_value_all_actions = self._action_value.eval_value_all_actions_old(observation.reshape((1,) + observation.shape))
+        else:
+            # print(observation.shape)
+            # print((observation.reshape((1,) + observation.shape)))
+            action_value_all_actions = self._action_value.eval_value_all_actions(observation.reshape((1,) + observation.shape))
+
+        action_value_max = np.max(action_value_all_actions, axis=1)
+
+        return_ = reward + action_value_max
+
+        return {'return': return_}
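
ReplayMemoryQReturn above bootstraps its target as reward + max_a Q(s', a): self.n is stored but not used yet, and there is no discount factor or terminal masking. A small illustrative sketch of a discounted, terminal-aware one-step variant; the helper name and the gamma default are assumptions for illustration, not part of the patch:

    import numpy as np

    def one_step_q_target(reward, done, next_q_values, gamma=0.99):
        """Illustrative target: r + gamma * max_a Q(s', a), cut off at episode ends."""
        next_max = np.max(next_q_values, axis=-1)
        return reward + gamma * (0.0 if done else 1.0) * next_max

    # a terminal transition contributes only its immediate reward
    print(one_step_q_target(1.0, True, np.array([[0.3, 0.7]])))   # [1.]
    print(one_step_q_target(1.0, False, np.array([[0.3, 0.7]])))  # [1.693]
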
+ """ + sess = tf.get_default_session() + self._collected_data = list() + + for _ in range(0, nums): + if self._is_first_collect: + self._begin_act(exploration) + + current_data = dict() + current_data['previous_action'] = self._action + current_data['previous_observation'] = self._observation + self._action = self._pi.act(self._observation, exploration) + self._observation, reward, done, _ = self._env.step(self._action) + current_data['action'] = self._action + current_data['observation'] = self._observation + current_data['reward'] = reward + current_data['end_flag'] = done + self._replay_memory.add(current_data) + self._collected_data.append(current_data) + if done: + self._begin_act(exploration) + + # I don't know what statistics should replay memory provide, for replay memory only saves discrete data + def statistics(self): + """ + compute the statistics of the current sampled paths + :return: + """ + raw_data = dict(zip(self._collected_data[0], zip(*[d.values() for d in self._collected_data]))) + rewards = np.array(raw_data['reward']) + episode_start_flags = np.array(raw_data['end_flag']) + num_timesteps = rewards.shape[0] + + returns = [] + episode_lengths = [] + max_return = 0 + num_episodes = 1 + episode_start_idx = 0 + for i in range(1, num_timesteps): + if episode_start_flags[i] or ( + i == num_timesteps - 1): # found the start of next episode or the end of all episodes + if episode_start_flags[i]: + num_episodes += 1 + if i < rewards.shape[0] - 1: + t = i - 1 + else: + t = i + Gt = 0 + episode_lengths.append(t - episode_start_idx) + while t >= episode_start_idx: + Gt += rewards[t] + t -= 1 + + returns.append(Gt) + if Gt > max_return: + max_return = Gt + episode_start_idx = i + + print('AverageReturn: {}'.format(np.mean(returns))) + print('StdReturn : {}'.format(np.std(returns))) + print('NumEpisodes : {}'.format(num_episodes)) + print('MinMaxReturns: {}..., {}'.format(np.sort(returns)[:3], np.sort(returns)[-3:])) + print('AverageLength: {}'.format(np.mean(episode_lengths))) + print('MinMaxLengths: {}..., {}'.format(np.sort(episode_lengths)[:3], np.sort(episode_lengths)[-3:])) + + def next_batch(self, batch_size, global_step=0, standardize_advantage=True): + """ + collect a batch of data from replay buffer, update the priority and calculate the necessary statistics for + updating q value network. + :param batch_size: int batch size. + :param global_step: int training global step. + :return: a batch of data, with target storing the target q value and wi, rewards storing the coefficient + for gradient of q value network. 
+ """ + + feed_dict = {} + is_first = True + + for _ in range(0, batch_size): + current_datas, current_wis, current_indexs = \ + self._replay_memory.sample( + {'batch_size': 1, 'global_step': global_step}) + current_data = current_datas[0] + current_wi = current_wis[0] + current_index = current_indexs[0] + current_processed_data = {} + for processors in self._reward_processors: + current_processed_data.update(processors(current_data)) + + for key, placeholder in self._required_placeholders.items(): + found, data_key = utils.internal_key_match(key, current_data.keys()) + if found: + if is_first: + feed_dict[placeholder] = np.array([current_data[data_key]]) + else: + feed_dict[placeholder] = np.append(feed_dict[placeholder], np.array([current_data[data_key]]), 0) + else: + found, data_key = utils.internal_key_match(key, current_processed_data.keys()) + if found: + if is_first: + feed_dict[placeholder] = np.array(current_processed_data[data_key]) + else: + feed_dict[placeholder] = np.append(feed_dict[placeholder], + np.array(current_processed_data[data_key]), 0) + else: + raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name))) + next_max_qvalue = np.max(self._networks[-1].eval_value_all_actions( + current_data['observation'].reshape((1,) + current_data['observation'].shape))) + current_qvalue = self._networks[-1].eval_value_all_actions( + current_data['previous_observation'] + .reshape((1,) + current_data['previous_observation'].shape))[0, current_data['previous_action']] + reward = current_data['reward'] + next_max_qvalue - current_qvalue + import math + self._replay_memory.update_priority([current_index], [math.fabs(reward)]) + if is_first: + is_first = False + + if standardize_advantage: + if self._require_advantage: + advantage_value = feed_dict[self._required_placeholders['advantage']] + advantage_mean = np.mean(advantage_value) + advantage_std = np.std(advantage_value) + if advantage_std < 1e-3: + logging.warning( + 'advantage_std too small (< 1e-3) for advantage standardization. 
diff --git a/tianshou/data/replay_buffer/binary_heap.py b/tianshou/data/replay_buffer/binary_heap.py
index 2deac14..800604d 100644
--- a/tianshou/data/replay_buffer/binary_heap.py
+++ b/tianshou/data/replay_buffer/binary_heap.py
@@ -218,4 +218,5 @@ class BinaryHeap(object):
         :param priority_ids: list of priority id
         :return: list of experience id
         """
+        # print(priority_ids)
         return [self.p2e[i] for i in priority_ids]
diff --git a/tianshou/data/replay_buffer/naive.py b/tianshou/data/replay_buffer/naive.py
index 5eb4dd7..dc08f77 100644
--- a/tianshou/data/replay_buffer/naive.py
+++ b/tianshou/data/replay_buffer/naive.py
@@ -7,13 +7,15 @@ from tianshou.data.replay_buffer.buffer import ReplayBuffer
 
 
 class NaiveExperience(ReplayBuffer):
-    def __init__(self, env, policy, qnet, target_qnet, conf):
+    # def __init__(self, env, policy, qnet, target_qnet, conf):
+    def __init__(self, conf):
         self.max_size = conf['size']
-        self._env = env
-        self._policy = policy
-        self._qnet = qnet
-        self._target_qnet = target_qnet
-        self._begin_act()
+        self._name = 'naive'
+        # self._env = env
+        # self._policy = policy
+        # self._qnet = qnet
+        # self._target_qnet = target_qnet
+        # self._begin_act()
         self.n_entries = 0
         self.memory = deque(maxlen=self.max_size)
 
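
With this patch all three buffers are constructed from a single conf dict and carry a _name attribute, so they can be swapped behind Replay without the env/policy/qnet plumbing they used to receive. A hypothetical convenience factory on top of the constructors shown in this patch; make_replay_memory itself is not part of the patch:

    import tianshou.data.replay_buffer.naive as naive
    import tianshou.data.replay_buffer.rank_based as rank_based
    import tianshou.data.replay_buffer.proportional as proportional

    _BUFFERS = {
        'naive': naive.NaiveExperience,
        'rank_based': rank_based.RankBasedExperience,
        'proportional': proportional.PropotionalExperience,
    }

    def make_replay_memory(name, conf):
        """Map the _name strings set in this patch to their buffer classes."""
        return _BUFFERS[name](conf)

    # the same conf dict used in examples/dqn_example.py
    replay_memory = make_replay_memory('rank_based', {'size': 30})
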
diff --git a/tianshou/data/replay_buffer/proportional.py b/tianshou/data/replay_buffer/proportional.py
index 52a231d..41154c3 100644
--- a/tianshou/data/replay_buffer/proportional.py
+++ b/tianshou/data/replay_buffer/proportional.py
@@ -18,7 +18,7 @@ class PropotionalExperience(ReplayBuffer):
 
     """
 
-    def __init__(self, env, policy, qnet, target_qnet, conf):
+    def __init__(self, conf):
         """ Prioritized experience replay buffer initialization.
 
         Parameters
@@ -38,11 +38,12 @@ class PropotionalExperience(ReplayBuffer):
         self.memory_size = memory_size
         self.batch_size = batch_size
         self.alpha = alpha
-        self._env = env
-        self._policy = policy
-        self._qnet = qnet
-        self._target_qnet = target_qnet
-        self._begin_act()
+        # self._env = env
+        # self._policy = policy
+        # self._qnet = qnet
+        # self._target_qnet = target_qnet
+        # self._begin_act()
+        self._name = 'proportional'
 
     def _begin_act(self):
         """
@@ -58,7 +59,7 @@ class PropotionalExperience(ReplayBuffer):
             self.action = self._env.action_space.sample()
            self.observation, _, done, _ = self._env.step(self.action)
 
-    def add(self, data, priority):
+    def add(self, data, priority=1):
         """ Add new sample.
 
         Parameters
@@ -195,7 +196,3 @@ class PropotionalExperience(ReplayBuffer):
             priorities = [self.tree.get_val(i)**-old_alpha for i in range(self.tree.filled_size())]
             self.update_priority(range(self.tree.filled_size()), priorities)
 
-
-
-
-
diff --git a/tianshou/data/replay_buffer/rank_based.py b/tianshou/data/replay_buffer/rank_based.py
index 0a6641f..0abb0d8 100644
--- a/tianshou/data/replay_buffer/rank_based.py
+++ b/tianshou/data/replay_buffer/rank_based.py
@@ -16,15 +16,16 @@ from tianshou.data.replay_buffer.buffer import ReplayBuffer
 
 
 class RankBasedExperience(ReplayBuffer):
-    def __init__(self, env, policy, qnet, target_qnet, conf):
+    def __init__(self, conf):
         self.size = conf['size']
         self.replace_flag = conf['replace_old'] if 'replace_old' in conf else True
         self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size
+        self._name = 'rank_based'
 
         self.alpha = conf['alpha'] if 'alpha' in conf else 0.7
         self.beta_zero = conf['beta_zero'] if 'beta_zero' in conf else 0.5
         self.batch_size = conf['batch_size'] if 'batch_size' in conf else 32
-        self.learn_start = conf['learn_start'] if 'learn_start' in conf else 1000
+        self.learn_start = conf['learn_start'] if 'learn_start' in conf else 10
         self.total_steps = conf['steps'] if 'steps' in conf else 100000
         # partition number N, split total size to N part
         self.partition_num = conf['partition_num'] if 'partition_num' in conf else 10
@@ -33,11 +34,11 @@ class RankBasedExperience(ReplayBuffer):
         self.record_size = 0
         self.isFull = False
 
-        self._env = env
-        self._policy = policy
-        self._qnet = qnet
-        self._target_qnet = target_qnet
-        self._begin_act()
+        # self._env = env
+        # self._policy = policy
+        # self._qnet = qnet
+        # self._target_qnet = target_qnet
+        # self._begin_act()
 
         self._experience = {}
         self.priority_queue = BinaryHeap(self.priority_size)
@@ -241,12 +242,14 @@ class RankBasedExperience(ReplayBuffer):
         # issue 1 by @camigord
         partition_size = math.floor(self.size * 1. / self.partition_num)
         partition_max = dist_index * partition_size
+        # print(self.record_size, self.partition_num, partition_max, partition_size, dist_index)
+        # print(self.distributions.keys())
         distribution = self.distributions[dist_index]
         rank_list = []
         # sample from k segments
         for n in range(1, self.batch_size + 1):
-            index = random.randint(distribution['strata_ends'][n],
-                                   distribution['strata_ends'][n + 1])
+            index = max(random.randint(distribution['strata_ends'][n],
+                                       distribution['strata_ends'][n + 1]), 1)
             rank_list.append(index)
 
         # beta, increase by global_step, max 1
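
Putting the pieces together, the data path this patch sets up in examples/dqn_example.py looks roughly as follows. This is a condensed usage sketch, not code from the patch; env, pi, dqn, train_op, global_step, num_batches and batch_size are the objects already defined in that example:

    import tensorflow as tf
    import tianshou.data.advantage_estimation as advantage_estimation
    import tianshou.data.replay_buffer.rank_based as rank_based
    import tianshou.data.replay_buffer.Replay as Replay

    replay_memory = rank_based.RankBasedExperience({'size': 30})
    data_collector = Replay.Replay(replay_memory, env, pi,
                                   [advantage_estimation.ReplayMemoryQReturn(1, dqn)],
                                   [dqn])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(100):
            data_collector.collect(nums=50)        # roll out and store transitions
            for _ in range(num_batches):
                step = tf.train.global_step(sess, global_step)
                feed_dict = data_collector.next_batch(batch_size, step)
                sess.run(train_op, feed_dict=feed_dict)
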