First modification of the replay buffer: make all three replay buffers work; refactoring and testing still pending.
parent a40e5aec54
commit 67d0e78ab9
@@ -14,6 +14,10 @@ from tianshou.data.batch import Batch
 import tianshou.data.advantage_estimation as advantage_estimation
 import tianshou.core.policy.dqn as policy # TODO: fix imports as zhusuan so that only need to import to policy
 import tianshou.core.value_function.action_value as value_function
+import tianshou.data.replay_buffer.proportional as proportional
+import tianshou.data.replay_buffer.rank_based as rank_based
+import tianshou.data.replay_buffer.naive as naive
+import tianshou.data.replay_buffer.Replay as Replay


 # TODO: why this solves cartpole even without training?
@@ -50,11 +54,17 @@ if __name__ == '__main__':
     dqn_loss = losses.qlearning(dqn)

     total_loss = dqn_loss
+    global_step = tf.Variable(0, name='global_step', trainable=False)
     optimizer = tf.train.AdamOptimizer(1e-4)
-    train_op = optimizer.minimize(total_loss, var_list=dqn.trainable_variables)
+    train_op = optimizer.minimize(total_loss, var_list=dqn.trainable_variables, global_step=tf.train.get_global_step())

+    # replay_memory = naive.NaiveExperience({'size': 1000})
+    replay_memory = rank_based.RankBasedExperience({'size': 30})
+    # replay_memory = proportional.PropotionalExperience({'size': 100, 'batch_size': 10})
+    data_collector = Replay.Replay(replay_memory, env, pi, [advantage_estimation.ReplayMemoryQReturn(1, dqn)], [dqn])
+
     ### 3. define data collection
-    data_collector = Batch(env, pi, [advantage_estimation.nstep_q_return(1, dqn)], [dqn])
+    # data_collector = Batch(env, pi, [advantage_estimation.nstep_q_return(1, dqn)], [dqn])

     ### 4. start training
     config = tf.ConfigProto()
@@ -68,7 +78,7 @@ if __name__ == '__main__':
         start_time = time.time()
         for i in range(100):
             # collect data
-            data_collector.collect(num_episodes=50)
+            data_collector.collect(nums=50)

             # print current return
             print('Epoch {}:'.format(i))
@@ -76,7 +86,7 @@ if __name__ == '__main__':

             # update network
             for _ in range(num_batches):
-                feed_dict = data_collector.next_batch(batch_size)
+                feed_dict = data_collector.next_batch(batch_size, tf.train.global_step(sess, global_step))
                 sess.run(train_op, feed_dict=feed_dict)

             print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
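The example now drives training from a replay memory instead of the on-policy Batch collector: collect transitions into the buffer, sample a mini-batch, take a gradient step, and repeat. A minimal, framework-free sketch of that loop shape, using purely illustrative names (ToyBuffer, q_values) rather than Tianshou APIs:

import numpy as np

class ToyBuffer:
    """Tiny FIFO replay memory: stores transitions and samples uniformly."""
    def __init__(self, size):
        self.size = size
        self.data = []

    def add(self, transition):
        self.data.append(transition)
        if len(self.data) > self.size:
            self.data.pop(0)

    def sample(self, batch_size):
        idx = np.random.randint(len(self.data), size=batch_size)
        return [self.data[i] for i in idx]

rng = np.random.default_rng(0)
buffer = ToyBuffer(1000)
for _ in range(200):   # fake (s, a, r, s') transitions: 4-dim observations, 2 actions
    buffer.add({'obs': rng.normal(size=4), 'act': int(rng.integers(2)),
                'rew': float(rng.normal()), 'obs_next': rng.normal(size=4)})

W = rng.normal(size=(4, 2)) * 0.01           # stand-in linear Q-function

def q_values(obs):
    return obs @ W                           # Q(s, a) for both actions

for epoch in range(100):                     # mirrors the 100 "epochs" above
    for tr in buffer.sample(32):
        td_target = tr['rew'] + 0.99 * np.max(q_values(tr['obs_next']))
        td_error = td_target - q_values(tr['obs'])[tr['act']]
        W[:, tr['act']] += 1e-3 * td_error * tr['obs']   # one SGD step on the squared TD error

In the commit itself, sampling and priority bookkeeping are delegated to Replay.next_batch in the new file below.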
@@ -159,3 +159,31 @@ class QLearningTarget:

         return data

+
+class ReplayMemoryQReturn:
+    """
+    compute the n-step return for Q-learning targets
+    """
+    def __init__(self, n, action_value, use_target_network=True):
+        self.n = n
+        self._action_value = action_value
+        self._use_target_network = use_target_network
+
+    def __call__(self, raw_data):
+        reward = raw_data['reward']
+        observation = raw_data['observation']
+
+        if self._use_target_network:
+            # print(observation.shape)
+            # print((observation.reshape((1,) + observation.shape)))
+            action_value_all_actions = self._action_value.eval_value_all_actions_old(observation.reshape((1,) + observation.shape))
+        else:
+            # print(observation.shape)
+            # print((observation.reshape((1,) + observation.shape)))
+            action_value_all_actions = self._action_value.eval_value_all_actions(observation.reshape((1,) + observation.shape))
+
+        action_value_max = np.max(action_value_all_actions, axis=1)
+
+        return_ = reward + action_value_max
+
+        return {'return': return_}
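ReplayMemoryQReturn builds the 1-step Q-learning target reward + max_a Q(s', a) for a single stored transition (note that no discount factor is applied here). A self-contained sketch of that computation, where StubActionValue is an assumed stand-in for the action_value network passed to the class:

import numpy as np

class StubActionValue:
    """Stand-in for the DQN action-value head: returns Q(s, a) for all actions."""
    def __init__(self, weights):
        self.weights = weights                       # shape (obs_dim, num_actions)

    def eval_value_all_actions(self, observation):   # observation: (1, obs_dim)
        return observation @ self.weights            # shape (1, num_actions)

q = StubActionValue(np.array([[0.1, -0.2],
                              [0.3, 0.05]]))

raw_data = {'observation': np.array([1.0, 2.0]), 'reward': 0.5}
obs = raw_data['observation'].reshape((1,) + raw_data['observation'].shape)

q_all = q.eval_value_all_actions(obs)                # Q(s', a) for every action: [[0.7, -0.1]]
target = raw_data['reward'] + np.max(q_all, axis=1)  # r + max_a Q(s', a) = 0.5 + 0.7
print({'return': target})                            # roughly {'return': array([1.2])}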
tianshou/data/replay_buffer/Replay.py (new file, 164 lines)
@@ -0,0 +1,164 @@
+import tianshou.data.replay_buffer.naive as naive
+import tianshou.data.replay_buffer.rank_based as rank_based
+import tianshou.data.replay_buffer.proportional as proportional
+import numpy as np
+import tensorflow as tf
+from tianshou.data import utils
+import logging
+
+
+class Replay(object):
+    def __init__(self, replay_memory, env, pi, reward_processors, networks):
+        self._replay_memory = replay_memory
+        self._env = env
+        self._pi = pi
+        self._reward_processors = reward_processors
+        self._networks = networks
+
+        self._required_placeholders = {}
+        for net in self._networks:
+            self._required_placeholders.update(net.managed_placeholders)
+        self._require_advantage = 'advantage' in self._required_placeholders.keys()
+        self._collected_data = list()
+
+        self._is_first_collect = True
+
+    def _begin_act(self, exploration):
+        while self._is_first_collect:
+            self._observation = self._env.reset()
+            self._action = self._pi.act(self._observation, exploration)
+            self._observation, reward, done, _ = self._env.step(self._action)
+            if not done:
+                self._is_first_collect = False
+
+    def collect(self, nums, exploration=None):
+        """
+        collect data for replay memory and update the priority according to the given data.
+        store the previous action, previous observation, reward, action, observation in the replay memory.
+        """
+        sess = tf.get_default_session()
+        self._collected_data = list()
+
+        for _ in range(0, nums):
+            if self._is_first_collect:
+                self._begin_act(exploration)
+
+            current_data = dict()
+            current_data['previous_action'] = self._action
+            current_data['previous_observation'] = self._observation
+            self._action = self._pi.act(self._observation, exploration)
+            self._observation, reward, done, _ = self._env.step(self._action)
+            current_data['action'] = self._action
+            current_data['observation'] = self._observation
+            current_data['reward'] = reward
+            current_data['end_flag'] = done
+            self._replay_memory.add(current_data)
+            self._collected_data.append(current_data)
+            if done:
+                self._begin_act(exploration)
+
+    # I don't know what statistics should replay memory provide, for replay memory only saves discrete data
+    def statistics(self):
+        """
+        compute the statistics of the current sampled paths
+        :return:
+        """
+        raw_data = dict(zip(self._collected_data[0], zip(*[d.values() for d in self._collected_data])))
+        rewards = np.array(raw_data['reward'])
+        episode_start_flags = np.array(raw_data['end_flag'])
+        num_timesteps = rewards.shape[0]
+
+        returns = []
+        episode_lengths = []
+        max_return = 0
+        num_episodes = 1
+        episode_start_idx = 0
+        for i in range(1, num_timesteps):
+            if episode_start_flags[i] or (
+                    i == num_timesteps - 1):  # found the start of next episode or the end of all episodes
+                if episode_start_flags[i]:
+                    num_episodes += 1
+                if i < rewards.shape[0] - 1:
+                    t = i - 1
+                else:
+                    t = i
+                Gt = 0
+                episode_lengths.append(t - episode_start_idx)
+                while t >= episode_start_idx:
+                    Gt += rewards[t]
+                    t -= 1
+
+                returns.append(Gt)
+                if Gt > max_return:
+                    max_return = Gt
+                episode_start_idx = i
+
+        print('AverageReturn: {}'.format(np.mean(returns)))
+        print('StdReturn : {}'.format(np.std(returns)))
+        print('NumEpisodes : {}'.format(num_episodes))
+        print('MinMaxReturns: {}..., {}'.format(np.sort(returns)[:3], np.sort(returns)[-3:]))
+        print('AverageLength: {}'.format(np.mean(episode_lengths)))
+        print('MinMaxLengths: {}..., {}'.format(np.sort(episode_lengths)[:3], np.sort(episode_lengths)[-3:]))
+
+    def next_batch(self, batch_size, global_step=0, standardize_advantage=True):
+        """
+        collect a batch of data from replay buffer, update the priority and calculate the necessary statistics for
+        updating q value network.
+        :param batch_size: int batch size.
+        :param global_step: int training global step.
+        :return: a batch of data, with target storing the target q value and wi, rewards storing the coefficient
+        for gradient of q value network.
+        """
+
+        feed_dict = {}
+        is_first = True
+
+        for _ in range(0, batch_size):
+            current_datas, current_wis, current_indexs = \
+                self._replay_memory.sample(
+                    {'batch_size': 1, 'global_step': global_step})
+            current_data = current_datas[0]
+            current_wi = current_wis[0]
+            current_index = current_indexs[0]
+            current_processed_data = {}
+            for processors in self._reward_processors:
+                current_processed_data.update(processors(current_data))
+
+            for key, placeholder in self._required_placeholders.items():
+                found, data_key = utils.internal_key_match(key, current_data.keys())
+                if found:
+                    if is_first:
+                        feed_dict[placeholder] = np.array([current_data[data_key]])
+                    else:
+                        feed_dict[placeholder] = np.append(feed_dict[placeholder], np.array([current_data[data_key]]), 0)
+                else:
+                    found, data_key = utils.internal_key_match(key, current_processed_data.keys())
+                    if found:
+                        if is_first:
+                            feed_dict[placeholder] = np.array(current_processed_data[data_key])
+                        else:
+                            feed_dict[placeholder] = np.append(feed_dict[placeholder],
+                                                               np.array(current_processed_data[data_key]), 0)
+                    else:
+                        raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))
+            next_max_qvalue = np.max(self._networks[-1].eval_value_all_actions(
+                current_data['observation'].reshape((1,) + current_data['observation'].shape)))
+            current_qvalue = self._networks[-1].eval_value_all_actions(
+                current_data['previous_observation']
+                .reshape((1,) + current_data['previous_observation'].shape))[0, current_data['previous_action']]
+            reward = current_data['reward'] + next_max_qvalue - current_qvalue
+            import math
+            self._replay_memory.update_priority([current_index], [math.fabs(reward)])
+            if is_first:
+                is_first = False
+
+        if standardize_advantage:
+            if self._require_advantage:
+                advantage_value = feed_dict[self._required_placeholders['advantage']]
+                advantage_mean = np.mean(advantage_value)
+                advantage_std = np.std(advantage_value)
+                if advantage_std < 1e-3:
+                    logging.warning(
+                        'advantage_std too small (< 1e-3) for advantage standardization. may cause numerical issues')
+                feed_dict[self._required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std
+        return feed_dict
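next_batch feeds each sampled transition to the required placeholders and then re-prioritizes the sampled index with the absolute TD error |r + max_a Q(s', a) - Q(s, a)|. The core of that priority update, stripped down to a toy tabular Q-function (the priorities dict only stands in for replay_memory.update_priority):

import math
import numpy as np

def td_error(transition, q_table):
    """1-step TD error used as the new priority: r + max_a Q(s', a) - Q(s, a)."""
    s = transition['previous_observation']
    a = transition['previous_action']
    r = transition['reward']
    s_next = transition['observation']
    return r + np.max(q_table[s_next]) - q_table[s][a]

q_table = np.array([[0.0, 0.5],    # toy Q-values over 3 states and 2 actions
                    [0.2, 0.1],
                    [0.4, 0.3]])
priorities = {}                    # stands in for replay_memory.update_priority

sampled = [
    {'previous_observation': 0, 'previous_action': 1, 'reward': 1.0, 'observation': 2},
    {'previous_observation': 1, 'previous_action': 0, 'reward': 0.0, 'observation': 0},
]
for index, tr in enumerate(sampled):
    priorities[index] = math.fabs(td_error(tr, q_table))   # |TD error| as the new priority

print(priorities)   # roughly {0: 0.9, 1: 0.3} -- larger errors get sampled more often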
@@ -218,4 +218,5 @@ class BinaryHeap(object):
         :param priority_ids: list of priority id
         :return: list of experience id
         """
+        # print(priority_ids)
         return [self.p2e[i] for i in priority_ids]
@@ -7,13 +7,15 @@ from tianshou.data.replay_buffer.buffer import ReplayBuffer


 class NaiveExperience(ReplayBuffer):
-    def __init__(self, env, policy, qnet, target_qnet, conf):
+    # def __init__(self, env, policy, qnet, target_qnet, conf):
+    def __init__(self, conf):
         self.max_size = conf['size']
-        self._env = env
-        self._policy = policy
-        self._qnet = qnet
-        self._target_qnet = target_qnet
-        self._begin_act()
+        self._name = 'naive'
+        # self._env = env
+        # self._policy = policy
+        # self._qnet = qnet
+        # self._target_qnet = target_qnet
+        # self._begin_act()
         self.n_entries = 0
         self.memory = deque(maxlen=self.max_size)

@@ -18,7 +18,7 @@ class PropotionalExperience(ReplayBuffer):

     """

-    def __init__(self, env, policy, qnet, target_qnet, conf):
+    def __init__(self, conf):
         """ Prioritized experience replay buffer initialization.

         Parameters
@@ -38,11 +38,12 @@ class PropotionalExperience(ReplayBuffer):
         self.memory_size = memory_size
         self.batch_size = batch_size
         self.alpha = alpha
-        self._env = env
-        self._policy = policy
-        self._qnet = qnet
-        self._target_qnet = target_qnet
-        self._begin_act()
+        # self._env = env
+        # self._policy = policy
+        # self._qnet = qnet
+        # self._target_qnet = target_qnet
+        # self._begin_act()
+        self._name = 'proportional'

     def _begin_act(self):
         """
@@ -58,7 +59,7 @@ class PropotionalExperience(ReplayBuffer):
             self.action = self._env.action_space.sample()
             self.observation, _, done, _ = self._env.step(self.action)

-    def add(self, data, priority):
+    def add(self, data, priority=1):
         """ Add new sample.

         Parameters
@@ -195,7 +196,3 @@ class PropotionalExperience(ReplayBuffer):
         priorities = [self.tree.get_val(i)**-old_alpha for i in range(self.tree.filled_size())]
         self.update_priority(range(self.tree.filled_size()), priorities)

-
-
-
-
@@ -16,15 +16,16 @@ from tianshou.data.replay_buffer.buffer import ReplayBuffer

 class RankBasedExperience(ReplayBuffer):

-    def __init__(self, env, policy, qnet, target_qnet, conf):
+    def __init__(self, conf):
         self.size = conf['size']
         self.replace_flag = conf['replace_old'] if 'replace_old' in conf else True
         self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size
+        self._name = 'rank_based'

         self.alpha = conf['alpha'] if 'alpha' in conf else 0.7
         self.beta_zero = conf['beta_zero'] if 'beta_zero' in conf else 0.5
         self.batch_size = conf['batch_size'] if 'batch_size' in conf else 32
-        self.learn_start = conf['learn_start'] if 'learn_start' in conf else 1000
+        self.learn_start = conf['learn_start'] if 'learn_start' in conf else 10
         self.total_steps = conf['steps'] if 'steps' in conf else 100000
         # partition number N, split total size to N part
         self.partition_num = conf['partition_num'] if 'partition_num' in conf else 10
@@ -33,11 +34,11 @@ class RankBasedExperience(ReplayBuffer):
         self.record_size = 0
         self.isFull = False

-        self._env = env
-        self._policy = policy
-        self._qnet = qnet
-        self._target_qnet = target_qnet
-        self._begin_act()
+        # self._env = env
+        # self._policy = policy
+        # self._qnet = qnet
+        # self._target_qnet = target_qnet
+        # self._begin_act()

         self._experience = {}
         self.priority_queue = BinaryHeap(self.priority_size)
@@ -241,12 +242,14 @@ class RankBasedExperience(ReplayBuffer):
         # issue 1 by @camigord
         partition_size = math.floor(self.size * 1. / self.partition_num)
         partition_max = dist_index * partition_size
+        # print(self.record_size, self.partition_num, partition_max, partition_size, dist_index)
+        # print(self.distributions.keys())
         distribution = self.distributions[dist_index]
         rank_list = []
         # sample from k segments
         for n in range(1, self.batch_size + 1):
-            index = random.randint(distribution['strata_ends'][n],
-                                   distribution['strata_ends'][n + 1])
+            index = max(random.randint(distribution['strata_ends'][n],
+                                       distribution['strata_ends'][n + 1]), 1)
             rank_list.append(index)

         # beta, increase by global_step, max 1
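The max(..., 1) clamp in the last hunk guards the stratified draw: with a small buffer a segment's boundaries can start at 0, and rank 0 presumably has no entry in the 1-based priority queue. A quick illustration with made-up strata_ends values:

import random

# hypothetical segment boundaries for batch_size = 4 over a small buffer;
# strata_ends[n] .. strata_ends[n + 1] is the rank range of segment n
strata_ends = {1: 0, 2: 2, 3: 5, 4: 9, 5: 15}
batch_size = 4

rank_list = []
for n in range(1, batch_size + 1):
    # clamp to 1 so the drawn rank is always a valid (1-based) queue index
    index = max(random.randint(strata_ends[n], strata_ends[n + 1]), 1)
    rank_list.append(index)

print(rank_list)   # e.g. [1, 3, 7, 12] -- never 0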