Merge branch 'master' of github.com:sproblvem/tianshou
commit 2eb056a721
@@ -15,11 +15,8 @@ import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.dqn as policy  # TODO: fix imports as zhusuan so that only need to import to policy
import tianshou.core.value_function.action_value as value_function

import tianshou.data.replay as replay
import tianshou.data.data_collector as data_collector

# TODO: why this solves cartpole even without training?
from tianshou.data.replay_buffer.vanilla import VanillaReplayBuffer
from tianshou.data.data_collector import DataCollector


if __name__ == '__main__':
@@ -57,9 +54,18 @@ if __name__ == '__main__':
    train_op = optimizer.minimize(total_loss, var_list=dqn.trainable_variables)

    ### 3. define data collection
    replay_buffer = replay()
    replay_buffer = VanillaReplayBuffer(capacity=1e5, nstep=1)

    data_collector = data_collector(env, pi, [advantage_estimation.nstep_q_return(1, dqn)], [dqn], replay_buffer)
    process_functions = [advantage_estimation.nstep_q_return(1, dqn)]
    managed_networks = [dqn]

    data_collector = DataCollector(
        env=env,
        policy=pi,
        data_buffer=replay_buffer,
        process_functions=process_functions,
        managed_networks=managed_networks
    )

    ### 4. start training
    config = tf.ConfigProto()
@@ -73,7 +79,7 @@ if __name__ == '__main__':
    start_time = time.time()
    for i in range(100):
        # collect data
        data_collector.collect(num_episodes=50)
        data_collector.collect()

        # print current return
        print('Epoch {}:'.format(i))
@@ -25,7 +25,7 @@ def full_return(buffer, index=None):
        if index_this:
            episode = raw_data[i_episode]
            if not episode[-1][DONE]:
                logging.warning('Computing full return on episode {} with no terminal state.'.format(i_episode))
                logging.warning('Computing full return on episode {} which is not terminated.'.format(i_episode))

            episode_length = len(episode)
            returns_episode = [0.] * episode_length
67 tianshou/data/data_collector.py Normal file
@@ -0,0 +1,67 @@
from .replay_buffer.base import ReplayBufferBase


class DataCollector(object):
    """
    A utility class to manage the interaction between the data buffer and advantage_estimation.
    """
    def __init__(self, env, policy, data_buffer, process_functions, managed_networks):
        self.env = env
        self.policy = policy
        self.data_buffer = data_buffer
        self.process_functions = process_functions
        self.managed_networks = managed_networks

        self.required_placeholders = {}
        for net in self.managed_networks:
            self.required_placeholders.update(net.managed_placeholders)
        self.require_advantage = 'advantage' in self.required_placeholders.keys()

        if isinstance(self.data_buffer, ReplayBufferBase):  # process when sampling minibatch
            self.process_mode = 'minibatch'
        else:
            self.process_mode = 'batch'

        self.current_observation = self.env.reset()

    def collect(self, num_timesteps=1, num_episodes=0, exploration=None, my_feed_dict={}):
        assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
            "One and only one collection number specification permitted!"

        if num_timesteps > 0:
            for _ in range(num_timesteps):
                action_vanilla = self.policy.act(self.current_observation, my_feed_dict=my_feed_dict)
                if exploration:
                    action = exploration(action_vanilla)
                else:
                    action = action_vanilla

                next_observation, reward, done, _ = self.env.step(action)
                self.data_buffer.add((self.current_observation, action, reward, done))
                self.current_observation = next_observation

        if num_episodes > 0:
            for _ in range(num_episodes):
                observation = self.env.reset()
                done = False
                while not done:
                    action_vanilla = self.policy.act(observation, my_feed_dict=my_feed_dict)
                    if exploration:
                        action = exploration(action_vanilla)
                    else:
                        action = action_vanilla

                    next_observation, reward, done, _ = self.env.step(action)
                    self.data_buffer.add((observation, action, reward, done))
                    observation = next_observation

    def next_batch(self, batch_size):
        sampled_index = self.data_buffer.sample(batch_size)
        if self.process_mode == 'minibatch':
            pass

        # flatten rank-2 list to numpy array

        return

    def statistics(self):
        pass
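A minimal usage sketch of the new DataCollector follows. It is an illustration, not part of the commit: `pi` and `dqn` stand for the policy and Q-value network constructed in the example script above and are not defined here, and the gym environment name is only an assumption.

import gym
import tianshou.data.advantage_estimation as advantage_estimation
from tianshou.data.data_collector import DataCollector
from tianshou.data.replay_buffer.vanilla import VanillaReplayBuffer

env = gym.make('CartPole-v0')
buffer = VanillaReplayBuffer(capacity=1e4, nstep=1)

collector = DataCollector(
    env=env,
    policy=pi,    # assumed: policy object with .act(), as built in the example script above
    data_buffer=buffer,
    process_functions=[advantage_estimation.nstep_q_return(1, dqn)],  # assumed: dqn value network
    managed_networks=[dqn],
)

collector.collect(num_timesteps=100)   # exactly one of num_timesteps / num_episodes must be > 0
collector.collect(num_episodes=5)
feed_dict = collector.next_batch(32)   # note: still a stub in this commit, returns None for now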
@@ -1,164 +0,0 @@
import tianshou.data.replay_buffer.naive as naive
import tianshou.data.replay_buffer.rank_based as rank_based
import tianshou.data.replay_buffer.proportional as proportional
import numpy as np
import tensorflow as tf
from tianshou.data import utils
import logging


class Replay(object):
    def __init__(self, replay_memory, env, pi, reward_processors, networks):
        self._replay_memory = replay_memory
        self._env = env
        self._pi = pi
        self._reward_processors = reward_processors
        self._networks = networks

        self._required_placeholders = {}
        for net in self._networks:
            self._required_placeholders.update(net.managed_placeholders)
        self._require_advantage = 'advantage' in self._required_placeholders.keys()
        self._collected_data = list()

        self._is_first_collect = True

    def _begin_act(self, exploration):
        while self._is_first_collect:
            self._observation = self._env.reset()
            self._action = self._pi.act(self._observation, exploration)
            self._observation, reward, done, _ = self._env.step(self._action)
            if not done:
                self._is_first_collect = False

    def collect(self, nums, exploration=None):
        """
        collect data for replay memory and update the priority according to the given data.
        store the previous action, previous observation, reward, action, observation in the replay memory.
        """
        sess = tf.get_default_session()
        self._collected_data = list()

        for _ in range(0, nums):
            if self._is_first_collect:
                self._begin_act(exploration)

            current_data = dict()
            current_data['previous_action'] = self._action
            current_data['previous_observation'] = self._observation
            self._action = self._pi.act(self._observation, exploration)
            self._observation, reward, done, _ = self._env.step(self._action)
            current_data['action'] = self._action
            current_data['observation'] = self._observation
            current_data['reward'] = reward
            current_data['end_flag'] = done
            self._replay_memory.add(current_data)
            self._collected_data.append(current_data)
            if done:
                self._begin_act(exploration)

    # I don't know what statistics should replay memory provide, for replay memory only saves discrete data
    def statistics(self):
        """
        compute the statistics of the current sampled paths
        :return:
        """
        raw_data = dict(zip(self._collected_data[0], zip(*[d.values() for d in self._collected_data])))
        rewards = np.array(raw_data['reward'])
        episode_start_flags = np.array(raw_data['end_flag'])
        num_timesteps = rewards.shape[0]

        returns = []
        episode_lengths = []
        max_return = 0
        num_episodes = 1
        episode_start_idx = 0
        for i in range(1, num_timesteps):
            if episode_start_flags[i] or (
                    i == num_timesteps - 1):  # found the start of next episode or the end of all episodes
                if episode_start_flags[i]:
                    num_episodes += 1
                if i < rewards.shape[0] - 1:
                    t = i - 1
                else:
                    t = i
                Gt = 0
                episode_lengths.append(t - episode_start_idx)
                while t >= episode_start_idx:
                    Gt += rewards[t]
                    t -= 1

                returns.append(Gt)
                if Gt > max_return:
                    max_return = Gt
                episode_start_idx = i

        print('AverageReturn: {}'.format(np.mean(returns)))
        print('StdReturn    : {}'.format(np.std(returns)))
        print('NumEpisodes  : {}'.format(num_episodes))
        print('MinMaxReturns: {}..., {}'.format(np.sort(returns)[:3], np.sort(returns)[-3:]))
        print('AverageLength: {}'.format(np.mean(episode_lengths)))
        print('MinMaxLengths: {}..., {}'.format(np.sort(episode_lengths)[:3], np.sort(episode_lengths)[-3:]))

    def next_batch(self, batch_size, global_step=0, standardize_advantage=True):
        """
        collect a batch of data from replay buffer, update the priority and calculate the necessary statistics for
        updating q value network.
        :param batch_size: int batch size.
        :param global_step: int training global step.
        :return: a batch of data, with target storing the target q value and wi, rewards storing the coefficient
        for gradient of q value network.
        """

        feed_dict = {}
        is_first = True

        for _ in range(0, batch_size):
            current_datas, current_wis, current_indexs = \
                self._replay_memory.sample(
                    {'batch_size': 1, 'global_step': global_step})
            current_data = current_datas[0]
            current_wi = current_wis[0]
            current_index = current_indexs[0]
            current_processed_data = {}
            for processors in self._reward_processors:
                current_processed_data.update(processors(current_data))

            for key, placeholder in self._required_placeholders.items():
                found, data_key = utils.internal_key_match(key, current_data.keys())
                if found:
                    if is_first:
                        feed_dict[placeholder] = np.array([current_data[data_key]])
                    else:
                        feed_dict[placeholder] = np.append(feed_dict[placeholder], np.array([current_data[data_key]]), 0)
                else:
                    found, data_key = utils.internal_key_match(key, current_processed_data.keys())
                    if found:
                        if is_first:
                            feed_dict[placeholder] = np.array(current_processed_data[data_key])
                        else:
                            feed_dict[placeholder] = np.append(feed_dict[placeholder],
                                                               np.array(current_processed_data[data_key]), 0)
                    else:
                        raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))
            next_max_qvalue = np.max(self._networks[-1].eval_value_all_actions(
                current_data['observation'].reshape((1,) + current_data['observation'].shape)))
            current_qvalue = self._networks[-1].eval_value_all_actions(
                current_data['previous_observation']
                .reshape((1,) + current_data['previous_observation'].shape))[0, current_data['previous_action']]
            reward = current_data['reward'] + next_max_qvalue - current_qvalue
            import math
            self._replay_memory.update_priority([current_index], [math.fabs(reward)])
            if is_first:
                is_first = False

        if standardize_advantage:
            if self._require_advantage:
                advantage_value = feed_dict[self._required_placeholders['advantage']]
                advantage_mean = np.mean(advantage_value)
                advantage_std = np.std(advantage_value)
                if advantage_std < 1e-3:
                    logging.warning(
                        'advantage_std too small (< 1e-3) for advantage standardization. may cause numerical issues')
                feed_dict[self._required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std
        return feed_dict
14 tianshou/data/replay_buffer/base.py Normal file
@@ -0,0 +1,14 @@

class ReplayBufferBase(object):
    """
    base class for replay buffer.
    """
    def add(self, frame):
        raise NotImplementedError()

    def remove(self):
        raise NotImplementedError()

    def sample(self, batch_size):
        raise NotImplementedError()
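As a sketch of how this interface is meant to be used: a buffer only has to implement add/remove/sample, and DataCollector above switches to per-minibatch processing whenever isinstance(buffer, ReplayBufferBase) holds. The subclass below is hypothetical and not part of this commit; it only illustrates the contract.

import random

from tianshou.data.replay_buffer.base import ReplayBufferBase


class UniformBuffer(ReplayBufferBase):
    """Hypothetical minimal subclass: a flat FIFO buffer with uniform index sampling."""
    def __init__(self, capacity):
        self.capacity = int(capacity)
        self.frames = []

    def add(self, frame):  # frame = (observation, action, reward, done)
        self.frames.append(frame)
        if len(self.frames) > self.capacity:
            self.remove()

    def remove(self):
        self.frames.pop(0)  # drop the oldest frame

    def sample(self, batch_size):
        # return indices into the stored frames, sampled with replacement
        return [random.randrange(len(self.frames)) for _ in range(batch_size)]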
@@ -1,222 +0,0 @@
#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: stmayue@gmail.com
# description:

import sys
import math

from . import utility


class BinaryHeap(object):

    def __init__(self, priority_size=100, priority_init=None, replace=True):
        self.e2p = {}
        self.p2e = {}
        self.replace = replace

        if priority_init is None:
            self.priority_queue = {}
            self.size = 0
            self.max_size = priority_size
        else:
            # not yet test
            self.priority_queue = priority_init
            self.size = len(self.priority_queue)
            self.max_size = None or self.size

            experience_list = list(map(lambda x: self.priority_queue[x], self.priority_queue))
            self.p2e = utility.list_to_dict(experience_list)
            self.e2p = utility.exchange_key_value(self.p2e)
            for i in range(int(self.size / 2), -1, -1):
                self.down_heap(i)

    def __repr__(self):
        """
        :return: string of the priority queue, with level info
        """
        if self.size == 0:
            return 'No element in heap!'
        to_string = ''
        level = -1
        max_level = int(math.floor(math.log(self.size, 2)))

        for i in range(1, self.size + 1):
            now_level = int(math.floor(math.log(i, 2)))
            if level != now_level:
                to_string = to_string + ('\n' if level != -1 else '') \
                            + ' ' * (max_level - now_level)
                level = now_level

            to_string = to_string + '%.2f ' % self.priority_queue[i][1] + ' ' * (max_level - now_level)

        return to_string

    def check_full(self):
        return self.size > self.max_size

    def _insert(self, priority, e_id):
        """
        insert new experience id with priority
        (maybe don't need get_max_priority and implement it in this function)
        :param priority: priority value
        :param e_id: experience id
        :return: bool
        """
        self.size += 1

        if self.check_full() and not self.replace:
            sys.stderr.write('Error: no space left to add experience id %d with priority value %f\n' % (e_id, priority))
            return False
        else:
            self.size = min(self.size, self.max_size)

        self.priority_queue[self.size] = (priority, e_id)
        self.p2e[self.size] = e_id
        self.e2p[e_id] = self.size

        self.up_heap(self.size)
        return True

    def update(self, priority, e_id):
        """
        update priority value according to its experience id
        :param priority: new priority value
        :param e_id: experience id
        :return: bool
        """
        if e_id in self.e2p:
            p_id = self.e2p[e_id]
            self.priority_queue[p_id] = (priority, e_id)
            self.p2e[p_id] = e_id

            self.down_heap(p_id)
            self.up_heap(p_id)
            return True
        else:
            # this e id is new, do insert
            return self._insert(priority, e_id)

    def get_max_priority(self):
        """
        get max priority, if no experience, return 1
        :return: max priority if size > 0 else 1
        """
        if self.size > 0:
            return self.priority_queue[1][0]
        else:
            return 1

    def pop(self):
        """
        pop out the max priority value with its experience id
        :return: priority value & experience id
        """
        if self.size == 0:
            sys.stderr.write('Error: no value in heap, pop failed\n')
            return False, False

        pop_priority, pop_e_id = self.priority_queue[1]
        self.e2p[pop_e_id] = -1
        # replace first
        last_priority, last_e_id = self.priority_queue[self.size]
        self.priority_queue[1] = (last_priority, last_e_id)
        self.size -= 1
        self.e2p[last_e_id] = 1
        self.p2e[1] = last_e_id

        self.down_heap(1)

        return pop_priority, pop_e_id

    def up_heap(self, i):
        """
        upward balance
        :param i: tree node i
        :return: None
        """
        if i > 1:
            parent = math.floor(i / 2)
            if self.priority_queue[parent][0] < self.priority_queue[i][0]:
                tmp = self.priority_queue[i]
                self.priority_queue[i] = self.priority_queue[parent]
                self.priority_queue[parent] = tmp
                # change e2p & p2e
                self.e2p[self.priority_queue[i][1]] = i
                self.e2p[self.priority_queue[parent][1]] = parent
                self.p2e[i] = self.priority_queue[i][1]
                self.p2e[parent] = self.priority_queue[parent][1]
                # up heap parent
                self.up_heap(parent)

    def down_heap(self, i):
        """
        downward balance
        :param i: tree node i
        :return: None
        """
        if i < self.size:
            greatest = i
            left, right = i * 2, i * 2 + 1
            if left < self.size and self.priority_queue[left][0] > self.priority_queue[greatest][0]:
                greatest = left
            if right < self.size and self.priority_queue[right][0] > self.priority_queue[greatest][0]:
                greatest = right

            if greatest != i:
                tmp = self.priority_queue[i]
                self.priority_queue[i] = self.priority_queue[greatest]
                self.priority_queue[greatest] = tmp
                # change e2p & p2e
                self.e2p[self.priority_queue[i][1]] = i
                self.e2p[self.priority_queue[greatest][1]] = greatest
                self.p2e[i] = self.priority_queue[i][1]
                self.p2e[greatest] = self.priority_queue[greatest][1]
                # down heap greatest
                self.down_heap(greatest)

    def get_priority(self):
        """
        get all priority value
        :return: list of priority
        """
        return list(map(lambda x: x[0], self.priority_queue.values()))[0:self.size]

    def get_e_id(self):
        """
        get all experience id in priority queue
        :return: list of experience ids ordered by their priority
        """
        return list(map(lambda x: x[1], self.priority_queue.values()))[0:self.size]

    def balance_tree(self):
        """
        rebalance priority queue
        :return: None
        """
        sort_array = sorted(self.priority_queue.values(), key=lambda x: x[0], reverse=True)
        # reconstruct priority_queue
        self.priority_queue.clear()
        self.p2e.clear()
        self.e2p.clear()
        cnt = 1
        while cnt <= self.size:
            priority, e_id = sort_array[cnt - 1]
            self.priority_queue[cnt] = (priority, e_id)
            self.p2e[cnt] = e_id
            self.e2p[e_id] = cnt
            cnt += 1
        # sort the heap
        for i in range(int(math.floor(self.size / 2)), 1, -1):
            self.down_heap(i)

    def priority_to_experience(self, priority_ids):
        """
        retrieve experience ids by priority ids
        :param priority_ids: list of priority id
        :return: list of experience id
        """
        # print(priority_ids)
        return [self.p2e[i] for i in priority_ids]
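A small illustration of the max-heap bookkeeping used by the rank-based buffer (this module is being deleted in this commit, so the import path refers to the pre-commit tree; the numbers are made up):

from tianshou.data.replay_buffer.binary_heap import BinaryHeap  # path before this commit

heap = BinaryHeap(priority_size=8)
# update() inserts when the experience id is new, otherwise re-prioritizes it
heap.update(0.5, 1)
heap.update(2.0, 2)
heap.update(1.2, 3)

print(heap.get_max_priority())           # 2.0 -- the root of the max-heap
print(heap.priority_to_experience([1]))  # [2] -- experience id sitting at the root
priority, e_id = heap.pop()              # removes (2.0, 2) and re-balances the heap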
@@ -1,51 +0,0 @@
class ReplayBuffer(object):
    def __init__(self, env, policy, qnet, target_qnet, conf):
        """
        Initialize a replay buffer with parameters in conf.
        """
        pass

    def add(self, data, priority):
        """
        Add a data with priority = priority to replay buffer.
        """
        pass

    def collect(self):
        """
        Collect data from current environment and policy.
        """
        pass

    def next_batch(self, batch_size):
        """
        get batch of data from the replay buffer.
        """
        pass

    def update_priority(self, indices, priorities):
        """
        Update the data's priority whose indices = indices.
        For proportional replay buffer, the priority is the priority.
        For rank based replay buffer, the priorities parameter will be the delta used to update the priority.
        """
        pass

    def reset_alpha(self, alpha):
        """
        This function only works for proportional replay buffer.
        This function resets alpha.
        """
        pass

    def sample(self, conf):
        """
        Sample from replay buffer with parameters in conf.
        """
        pass

    def rebalance(self):
        """
        This is for rank based priority replay buffer, which is used to rebalance the sum tree of the priority queue.
        """
        pass
@@ -1,110 +0,0 @@
import numpy as np
import tensorflow as tf
from collections import deque
from math import fabs

from .buffer import ReplayBuffer


class NaiveExperience(ReplayBuffer):
    # def __init__(self, env, policy, qnet, target_qnet, conf):
    def __init__(self, conf):
        self.max_size = conf['size']
        self._name = 'naive'
        # self._env = env
        # self._policy = policy
        # self._qnet = qnet
        # self._target_qnet = target_qnet
        # self._begin_act()
        self.n_entries = 0
        self.memory = deque(maxlen=self.max_size)

    def add(self, data, priority=0):
        self.memory.append(data)
        if self.n_entries < self.max_size:
            self.n_entries += 1

    def _begin_act(self):
        """
        if the previous interaction has ended or the interaction hasn't started,
        then begin acting from the state of env.reset()
        """
        self.observation = self._env.reset()
        self.action = self._env.action_space.sample()
        done = False
        while not done:
            if done:
                self.observation = self._env.reset()
                self.action = self._env.action_space.sample()
            self.observation, _, done, _ = self._env.step(self.action)

    def collect(self):
        """
        collect data for replay memory and update the priority according to the given data.
        store the previous action, previous observation, reward, action, observation in the replay memory.
        """
        sess = tf.get_default_session()
        current_data = dict()
        current_data['previous_action'] = self.action
        current_data['previous_observation'] = self.observation
        self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
        self.observation, reward, done, _ = self._env.step(self.action)
        current_data['action'] = self.action
        current_data['observation'] = self.observation
        current_data['reward'] = reward
        self.add(current_data)
        if done:
            self._begin_act()

    def update_priority(self, indices, priorities=0):
        pass

    def reset_alpha(self, alpha):
        pass

    def sample(self, conf):
        batch_size = conf['batch_size']
        batch_size = min(len(self.memory), batch_size)
        idxs = np.random.choice(len(self.memory), batch_size)
        return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs

    def next_batch(self, batch_size):
        """
        collect a batch of data from replay buffer, update the priority and calculate the necessary statistics for
        updating q value network.
        :param batch_size: int batch size.
        :return: a batch of data, with target storing the target q value and wi, rewards storing the coefficient
        for gradient of q value network.
        """
        data = dict()
        observations = list()
        actions = list()
        rewards = list()
        wi = list()
        target = list()

        for i in range(0, batch_size):
            current_datas, current_wis, current_indexs = self.sample({'batch_size': 1})
            current_data = current_datas[0]
            current_wi = current_wis[0]
            current_index = current_indexs[0]
            observations.append(current_data['observation'])
            actions.append(current_data['action'])
            next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
            current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
            reward = current_data['reward'] + next_max_qvalue - current_qvalue
            rewards.append(reward)
            target.append(current_data['reward'] + next_max_qvalue)
            self.update_priority(current_index, [fabs(reward)])
            wi.append(current_wi)

        data['observations'] = np.array(observations)
        data['actions'] = np.array(actions)
        data['rewards'] = np.array(rewards)
        data['wi'] = np.array(wi)
        data['target'] = np.array(target)

        return data

    def rebalance(self):
        pass
@@ -1,198 +0,0 @@
import numpy as np
import random
import tensorflow as tf
import math

from tianshou.data.replay_buffer import sum_tree
from tianshou.data.replay_buffer.buffer import ReplayBuffer


class PropotionalExperience(ReplayBuffer):
    """ The class represents prioritized experience replay buffer.

    The class has functions: store samples, pick samples with
    probability in proportion to sample's priority, update
    each sample's priority, reset alpha.

    see https://arxiv.org/pdf/1511.05952.pdf .

    """

    def __init__(self, conf):
        """ Prioritized experience replay buffer initialization.

        Parameters
        ----------
        memory_size : int
            sample size to be stored
        batch_size : int
            batch size to be selected by `select` method
        alpha: float
            exponent determining how much prioritization is used:
            Prob_i \sim priority_i**alpha / sum(priority**alpha)
        """
        memory_size = conf['size']
        batch_size = conf['batch_size']
        alpha = conf['alpha'] if 'alpha' in conf else 0.6
        self.tree = sum_tree.SumTree(memory_size)
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.alpha = alpha
        # self._env = env
        # self._policy = policy
        # self._qnet = qnet
        # self._target_qnet = target_qnet
        # self._begin_act()
        self._name = 'proportional'

    def _begin_act(self):
        """
        if the previous interaction has ended or the interaction hasn't started,
        then begin acting from the state of env.reset()
        """
        self.observation = self._env.reset()
        self.action = self._env.action_space.sample()
        done = False
        while not done:
            if done:
                self.observation = self._env.reset()
                self.action = self._env.action_space.sample()
            self.observation, _, done, _ = self._env.step(self.action)

    def add(self, data, priority=1):
        """ Add new sample.

        Parameters
        ----------
        data : object
            new sample
        priority : float
            sample's priority
        """
        self.tree.add(data, priority**self.alpha)

    def sample(self, conf):
        """ The method returns samples randomly.

        Parameters
        ----------
        beta : float

        Returns
        -------
        out :
            list of samples
        weights:
            list of weight
        indices:
            list of sample indices
            The indices indicate sample positions in a sum tree.
        :param conf: giving beta
        """
        beta = conf['beta'] if 'beta' in conf else 0.4
        if self.tree.filled_size() < self.batch_size:
            return None, None, None

        out = []
        indices = []
        weights = []
        priorities = []
        for _ in range(self.batch_size):
            r = random.random()
            data, priority, index = self.tree.find(r)
            priorities.append(priority)
            weights.append((1. / self.memory_size / priority)**beta if priority > 1e-16 else 0)
            indices.append(index)
            out.append(data)
            self.update_priority([index], [0])  # To avoid duplicating

        self.update_priority(indices, priorities)  # Revert priorities

        max_weights = max(weights)

        weights[:] = [x / max_weights for x in weights]  # Normalize for stability

        return out, weights, indices

    def collect(self):
        """
        collect data for replay memory and update the priority according to the given data.
        store the previous action, previous observation, reward, action, observation in the replay memory.
        """
        sess = tf.get_default_session()
        current_data = dict()
        current_data['previous_action'] = self.action
        current_data['previous_observation'] = self.observation
        # TODO: change the name of the feed_dict
        self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
        self.observation, reward, done, _ = self._env.step(self.action)
        current_data['action'] = self.action
        current_data['observation'] = self.observation
        current_data['reward'] = reward
        priorities = np.array([self.tree.get_val(i) ** -self.alpha for i in range(self.tree.filled_size())])
        priority = np.max(priorities) if len(priorities) > 0 else 1
        self.add(current_data, priority)
        if done:
            self._begin_act()

    def next_batch(self, batch_size):
        """
        collect a batch of data from replay buffer, update the priority and calculate the necessary statistics for
        updating q value network.
        :param batch_size: int batch size.
        :return: a batch of data, with target storing the target q value and wi, rewards storing the coefficient
        for gradient of q value network.
        """
        data = dict()
        observations = list()
        actions = list()
        rewards = list()
        wi = list()
        target = list()

        for i in range(0, batch_size):
            current_datas, current_wis, current_indexs = self.sample({'batch_size': 1})
            current_data = current_datas[0]
            current_wi = current_wis[0]
            current_index = current_indexs[0]
            observations.append(current_data['observation'])
            actions.append(current_data['action'])
            next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
            current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
            reward = current_data['reward'] + next_max_qvalue - current_qvalue
            rewards.append(reward)
            target.append(current_data['reward'] + next_max_qvalue)
            self.update_priority([current_index], [math.fabs(reward)])
            wi.append(current_wi)

        data['observations'] = np.array(observations)
        data['actions'] = np.array(actions)
        data['rewards'] = np.array(rewards)
        data['wi'] = np.array(wi)
        data['target'] = np.array(target)

        return data

    def update_priority(self, indices, priorities):
        """ The method updates samples' priorities.

        Parameters
        ----------
        indices :
            list of sample indices
        """
        for i, p in zip(indices, priorities):
            self.tree.val_update(i, p**self.alpha)

    def reset_alpha(self, alpha):
        """ Reset the exponent alpha.

        Parameters
        ----------
        alpha : float
        """
        self.alpha, old_alpha = alpha, self.alpha
        priorities = [self.tree.get_val(i)**-old_alpha for i in range(self.tree.filled_size())]
        self.update_priority(range(self.tree.filled_size()), priorities)
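The sampling rule documented in the class above, P(i) = priority_i^alpha / sum_j priority_j^alpha with importance-sampling weights w_i = (1 / (N * P(i)))^beta, can be checked with a few lines of NumPy. This is a standalone illustration with made-up priorities, not tied to the buffer class:

import numpy as np

priorities = np.array([1.0, 2.0, 4.0])   # e.g. TD-error magnitudes
alpha, beta = 0.6, 0.4

probs = priorities**alpha / np.sum(priorities**alpha)   # P(i) = p_i^alpha / sum_j p_j^alpha
weights = (1. / len(priorities) / probs)**beta           # importance-sampling correction
weights /= weights.max()                                 # normalized for stability, as in sample()

print(probs)    # larger priority -> sampled more often
print(weights)  # larger priority -> smaller IS weight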
@@ -1,262 +0,0 @@
import sys
import math
import random
import numpy as np
import tensorflow as tf

from tianshou.data.replay_buffer.binary_heap import BinaryHeap
from tianshou.data.replay_buffer.buffer import ReplayBuffer


class RankBasedExperience(ReplayBuffer):

    def __init__(self, conf):
        self.size = conf['size']
        self.replace_flag = conf['replace_old'] if 'replace_old' in conf else True
        self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size
        self._name = 'rank_based'

        self.alpha = conf['alpha'] if 'alpha' in conf else 0.7
        self.beta_zero = conf['beta_zero'] if 'beta_zero' in conf else 0.5
        self.batch_size = conf['batch_size'] if 'batch_size' in conf else 32
        self.learn_start = conf['learn_start'] if 'learn_start' in conf else 10
        self.total_steps = conf['steps'] if 'steps' in conf else 100000
        # partition number N, split total size to N part
        self.partition_num = conf['partition_num'] if 'partition_num' in conf else 10

        self.index = 0
        self.record_size = 0
        self.isFull = False

        # self._env = env
        # self._policy = policy
        # self._qnet = qnet
        # self._target_qnet = target_qnet
        # self._begin_act()

        self._experience = {}
        self.priority_queue = BinaryHeap(self.priority_size)
        self.distributions = self.build_distributions()

        self.beta_grad = (1 - self.beta_zero) / (self.total_steps - self.learn_start)

    def build_distributions(self):
        """
        preprocess pow of rank
        (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
        :return: distributions, dict
        """
        res = {}
        n_partitions = self.partition_num
        partition_num = 1
        # each part size
        partition_size = int(math.floor(self.size / n_partitions))

        for n in range(partition_size, self.size + 1, partition_size):
            if self.learn_start <= n <= self.priority_size:
                distribution = {}
                # P(i) = (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
                pdf = list(
                    map(lambda x: math.pow(x, -self.alpha), range(1, n + 1))
                )
                pdf_sum = math.fsum(pdf)
                distribution['pdf'] = list(map(lambda x: x / pdf_sum, pdf))
                # split to k segments, and then uniformly sample in each segment
                # set k = batch_size, so each segment has total probability 1 / batch_size
                # strata_ends keeps each segment's start pos and end pos
                cdf = np.cumsum(distribution['pdf'])
                strata_ends = {1: 0, self.batch_size + 1: n}
                step = 1. / self.batch_size
                index = 1
                for s in range(2, self.batch_size + 1):
                    while cdf[index] < step:
                        index += 1
                    strata_ends[s] = index
                    step += 1. / self.batch_size

                distribution['strata_ends'] = strata_ends

                res[partition_num] = distribution

            partition_num += 1

        return res

    def fix_index(self):
        """
        get next insert index
        :return: index, int
        """
        if self.record_size <= self.size:
            self.record_size += 1
        if self.index % self.size == 0:
            self.isFull = True if len(self._experience) == self.size else False
            if self.replace_flag:
                self.index = 1
                return self.index
            else:
                sys.stderr.write('Experience replay buff is full and replace is set to FALSE!\n')
                return -1
        else:
            self.index += 1
            return self.index

    def _begin_act(self):
        """
        if the previous interaction has ended or the interaction hasn't started,
        then begin acting from the state of env.reset()
        """
        self.observation = self._env.reset()
        self.action = self._env.action_space.sample()
        done = False
        while not done:
            if done:
                self.observation = self._env.reset()
                self.action = self._env.action_space.sample()
            self.observation, _, done, _ = self._env.step(self.action)

    def collect(self):
        """
        collect data for replay memory and update the priority according to the given data.
        store the previous action, previous observation, reward, action, observation in the replay memory.
        """
        sess = tf.get_default_session()
        current_data = dict()
        current_data['previous_action'] = self.action
        current_data['previous_observation'] = self.observation
        self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
        self.observation, reward, done, _ = self._env.step(self.action)
        current_data['action'] = self.action
        current_data['observation'] = self.observation
        current_data['reward'] = reward
        self.add(current_data)
        if done:
            self._begin_act()

    def next_batch(self, batch_size):
        """
        collect a batch of data from replay buffer, update the priority and calculate the necessary statistics for
        updating q value network.
        :param batch_size: int batch size.
        :return: a batch of data, with target storing the target q value and wi, rewards storing the coefficient
        for gradient of q value network.
        """
        data = dict()
        observations = list()
        actions = list()
        rewards = list()
        wi = list()
        target = list()

        sess = tf.get_default_session()
        # TODO: pre-build the thing in sess.run
        current_datas, current_wis, current_indexs = self.sample({'global_step': sess.run(tf.train.get_global_step())})

        for i in range(0, batch_size):
            current_data = current_datas[i]
            current_wi = current_wis[i]
            current_index = current_indexs[i]
            observations.append(current_data['observation'])
            actions.append(current_data['action'])
            next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
            current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
            reward = current_data['reward'] + next_max_qvalue - current_qvalue
            rewards.append(reward)
            target.append(current_data['reward'] + next_max_qvalue)
            self.update_priority([current_index], [math.fabs(reward)])
            wi.append(current_wi)

        data['observations'] = np.array(observations)
        data['actions'] = np.array(actions)
        data['rewards'] = np.array(rewards)
        data['wi'] = np.array(wi)
        data['target'] = np.array(target)

        return data

    def add(self, data, priority=1):
        """
        store experience; it is suggested that experience is a tuple of (s1, a, r, s2, t)
        so that each experience is valid
        :param experience: maybe a tuple, or list
        :return: bool, indicate insert status
        """
        insert_index = self.fix_index()
        if insert_index > 0:
            if insert_index in self._experience:
                del self._experience[insert_index]
            self._experience[insert_index] = data
            # add to priority queue
            priority = self.priority_queue.get_max_priority()
            self.priority_queue.update(priority, insert_index)
            return True
        else:
            sys.stderr.write('Insert failed\n')
            return False

    def retrieve(self, indices):
        """
        get experience from indices
        :param indices: list of experience id
        :return: experience replay sample
        """
        return [self._experience[v] for v in indices]

    def rebalance(self):
        """
        rebalance priority queue
        :return: None
        """
        self.priority_queue.balance_tree()

    def update_priority(self, indices, delta):
        """
        update priority according to indices and deltas
        :param indices: list of experience id
        :param delta: list of delta, order corresponding to indices
        :return: None
        """
        for i in range(0, len(indices)):
            self.priority_queue.update(math.fabs(delta[i]), indices[i])

    def sample(self, conf):
        """
        sample a mini batch from experience replay
        :param global_step: current training step
        :return: experience, list, samples
        :return: w, list, weights
        :return: rank_e_id, list, sample ids, used for updating priority
        """
        global_step = conf['global_step']
        if self.record_size < self.learn_start:
            sys.stderr.write('Record size less than learn start! Sample failed\n')
            return False, False, False

        dist_index = math.floor(self.record_size * 1. / self.size * self.partition_num)
        # issue 1 by @camigord
        partition_size = math.floor(self.size * 1. / self.partition_num)
        partition_max = dist_index * partition_size
        # print(self.record_size, self.partition_num, partition_max, partition_size, dist_index)
        # print(self.distributions.keys())
        distribution = self.distributions[dist_index]
        rank_list = []
        # sample from k segments
        for n in range(1, self.batch_size + 1):
            index = max(random.randint(distribution['strata_ends'][n],
                                       distribution['strata_ends'][n + 1]), 1)
            rank_list.append(index)

        # beta, increase by global_step, max 1
        beta = min(self.beta_zero + (global_step - self.learn_start - 1) * self.beta_grad, 1)
        # find all alpha pow, notice that pdf is a list, start from 0
        alpha_pow = [distribution['pdf'][v - 1] for v in rank_list]
        # w = (N * P(i)) ^ (-beta) / max w
        w = np.power(np.array(alpha_pow) * partition_max, -beta)
        w_max = max(w)
        w = np.divide(w, w_max)
        # rank list is priority id
        # convert to experience id
        rank_e_id = self.priority_queue.priority_to_experience(rank_list)
        # get experience id according rank_e_id
        experience = self.retrieve(rank_e_id)
        return experience, w, rank_e_id
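The rank-based distribution built above, P(i) = rank_i^(-alpha) / sum_j rank_j^(-alpha) cut into batch_size strata of equal probability mass, can be reproduced in a few lines of NumPy. This is a standalone illustration with made-up sizes, independent of the class being removed here:

import numpy as np

n, alpha, batch_size = 100, 0.7, 4
pdf = np.arange(1, n + 1) ** (-alpha)
pdf /= pdf.sum()                 # P(i) for ranks 1..n, highest priority first
cdf = np.cumsum(pdf)

# stratum boundaries: each segment holds roughly 1/batch_size of the probability mass
bounds = np.searchsorted(cdf, np.arange(1, batch_size) / batch_size)
print(bounds)   # earlier (higher-priority) ranks get narrower strata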
@@ -1,131 +0,0 @@
from functions import *

from tianshou.data.replay_buffer.utils import get_replay_buffer


def test_rank_based():
    conf = {'size': 50,
            'learn_start': 10,
            'partition_num': 5,
            'total_step': 100,
            'batch_size': 4}
    experience = get_replay_buffer('rank_based', conf)

    # insert to experience
    print 'test insert experience'
    for i in range(1, 51):
        # tuple, like(state_t, a, r, state_t_1, t)
        to_insert = (i, 1, 1, i, 1)
        experience.add(to_insert)
    print experience.priority_queue
    print experience._experience[1]
    print experience._experience[2]
    print 'test replace'
    to_insert = (51, 1, 1, 51, 1)
    experience.add(to_insert)
    print experience.priority_queue
    print experience._experience[1]
    print experience._experience[2]

    # sample
    print 'test sample'
    global_step = {'global_step': 51}
    sample, w, e_id = experience.sample(global_step)
    print sample
    print w
    print e_id

    # update delta to priority
    print 'test update delta'
    delta = [v for v in range(1, 5)]
    experience.update_priority(e_id, delta)
    print experience.priority_queue
    sample, w, e_id = experience.sample(global_step)
    print sample
    print w
    print e_id

    # rebalance
    print 'test rebalance'
    experience.rebalance()
    print experience.priority_queue


def test_proportional():
    conf = {'size': 50,
            'alpha': 0.7,
            'batch_size': 4}
    experience = get_replay_buffer('proportional', conf)

    # insert to experience
    print 'test insert experience'
    for i in range(1, 51):
        # tuple, like(state_t, a, r, state_t_1, t)
        to_insert = (i, 1, 1, i, 1)
        experience.add(to_insert, i)
    print experience.tree
    print experience.tree.get_val(1)
    print experience.tree.get_val(2)
    print 'test replace'
    to_insert = (51, 1, 1, 51, 1)
    experience.add(to_insert, 51)
    print experience.tree
    print experience.tree.get_val(1)
    print experience.tree.get_val(2)

    # sample
    print 'test sample'
    beta = {'beta': 0.005}
    sample, w, e_id = experience.sample(beta)
    print sample
    print w
    print e_id

    # update delta to priority
    print 'test update delta'
    delta = [v for v in range(1, 5)]
    experience.update_priority(e_id, delta)
    print experience.tree
    sample, w, e_id = experience.sample(beta)
    print sample
    print w
    print e_id


def test_naive():
    conf = {'size': 50}
    experience = get_replay_buffer('naive', conf)

    # insert to experience
    print 'test insert experience'
    for i in range(1, 51):
        # tuple, like(state_t, a, r, state_t_1, t)
        to_insert = (i, 1, 1, i, 1)
        experience.add(to_insert)
    print experience.memory
    print 'test replace'
    to_insert = (51, 1, 1, 51, 1)
    experience.add(to_insert)
    print experience.memory

    # sample
    print 'test sample'
    batch_size = {'batch_size': 5}
    sample, w, e_id = experience.sample(batch_size)
    print sample
    print w
    print e_id

    # update delta to priority
    print 'test update delta'
    delta = [v for v in range(1, 5)]
    experience.update_priority(e_id, delta)
    print experience.memory
    sample, w, e_id = experience.sample(batch_size)
    print sample
    print w
    print e_id


if __name__ == '__main__':
    test_rank_based()
    test_proportional()
    test_naive()
@@ -1,64 +0,0 @@
#! -*- coding:utf-8 -*-
from __future__ import print_function

import sys
import os
import math

class SumTree(object):
    def __init__(self, max_size):
        self.max_size = max_size
        self.tree_level = int(math.ceil(math.log(max_size+1, 2))+1)
        self.tree_size = 2**self.tree_level-1
        self.tree = [0 for i in range(self.tree_size)]
        self.data = [None for i in range(self.max_size)]
        self.size = 0
        self.cursor = 0

    def add(self, contents, value):
        index = self.cursor
        self.cursor = (self.cursor+1)%self.max_size
        self.size = min(self.size+1, self.max_size)

        self.data[index] = contents
        self.val_update(index, value)

    def get_val(self, index):
        tree_index = 2**(self.tree_level-1)-1+index
        return self.tree[tree_index]

    def val_update(self, index, value):
        tree_index = 2**(self.tree_level-1)-1+index
        diff = value-self.tree[tree_index]
        self.reconstruct(tree_index, diff)

    def reconstruct(self, tindex, diff):
        self.tree[tindex] += diff
        if not tindex == 0:
            tindex = int((tindex-1)/2)
            self.reconstruct(tindex, diff)

    def find(self, value, norm=True):
        if norm:
            value *= self.tree[0]
        return self._find(value, 0)

    def _find(self, value, index):
        if 2**(self.tree_level-1)-1 <= index:
            return self.data[index-(2**(self.tree_level-1)-1)], self.tree[index], index-(2**(self.tree_level-1)-1)

        left = self.tree[2*index+1]

        if value <= left:
            return self._find(value, 2*index+1)
        else:
            return self._find(value-left, 2*(index+1))

    def print_tree(self):
        for k in range(1, self.tree_level+1):
            for j in range(2**(k-1)-1, 2**k-1):
                print(self.tree[j], end=' ')
            print()

    def filled_size(self):
        return self.size
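A short, self-contained illustration of how this sum tree supports proportional sampling: find() maps a uniform number in [0, 1) to a stored item with probability proportional to its value. The import path refers to the pre-commit tree (this module is removed here), and the priorities are arbitrary:

from __future__ import print_function
import random

from tianshou.data.replay_buffer.sum_tree import SumTree  # path before this commit

tree = SumTree(max_size=4)
for name, priority in [('a', 1.0), ('b', 2.0), ('c', 4.0)]:
    tree.add(name, priority)

# 'c' holds 4/7 of the total mass, so it should come back roughly 4/7 of the time
counts = {}
for _ in range(7000):
    data, value, index = tree.find(random.random())
    counts[data] = counts.get(data, 0) + 1
print(counts)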
@@ -1,13 +0,0 @@
#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: stmayue@gmail.com
# description:


def list_to_dict(in_list):
    return dict((i, in_list[i]) for i in range(0, len(in_list)))


def exchange_key_value(in_dict):
    return dict((in_dict[i], i) for i in in_dict)
@@ -1,20 +0,0 @@
import sys

from .naive import NaiveExperience
from .proportional import PropotionalExperience
from .rank_based import RankBasedExperience


def get_replay_buffer(name, env, policy, qnet, target_qnet, conf):
    """
    Get replay buffer according to the given name.
    """
    if name == 'rank_based':
        return RankBasedExperience(env, policy, qnet, target_qnet, conf)
    elif name == 'proportional':
        return PropotionalExperience(env, policy, qnet, target_qnet, conf)
    elif name == 'naive':
        return NaiveExperience(env, policy, qnet, target_qnet, conf)
    else:
        sys.stderr.write('no such replay buffer')
119 tianshou/data/replay_buffer/vanilla.py Normal file
@@ -0,0 +1,119 @@
import logging
import numpy as np

from .base import ReplayBufferBase

STATE = 0
ACTION = 1
REWARD = 2
DONE = 3

class VanillaReplayBuffer(ReplayBufferBase):
    """
    vanilla replay buffer as used in (Mnih, et al., 2015).
    Frames are always continuous in temporal order. They are only removed from the beginning. This continuity
    in `self.data` could be exploited, but only in vanilla replay buffer.
    """
    def __init__(self, capacity, nstep=1):
        """
        :param capacity: int. capacity of the buffer.
        :param nstep: int. number of timesteps to lookahead for temporal difference
        """
        assert capacity > 0
        self.capacity = int(capacity)
        self.nstep = nstep

        self.data = [[]]
        self.index = [[]]
        self.candidate_index = 0

        self.size = 0  # number of valid data points (not frames)

        self.index_lengths = [0]  # for sampling

    def add(self, frame):
        """
        add one frame to the buffer.
        :param frame: tuple, (observation, action, reward, done_flag).
        """
        self.data[-1].append(frame)

        has_enough_frames = len(self.data[-1]) > self.nstep
        if frame[DONE]:  # episode terminates, all trailing frames become valid data points
            trailing_index = list(range(self.candidate_index, len(self.data[-1])))
            self.index[-1] += trailing_index
            self.size += len(trailing_index)
            self.index_lengths[-1] += len(trailing_index)

            # prepare for the next episode
            self.data.append([])
            self.index.append([])
            self.candidate_index = 0

            self.index_lengths.append(0)

        elif has_enough_frames:  # add one valid data point
            self.index[-1].append(self.candidate_index)
            self.candidate_index += 1
            self.size += 1
            self.index_lengths[-1] += 1

        # automated removal to capacity
        if self.size > self.capacity:
            self.remove()

    def remove(self):
        """
        remove data until `self.size` <= `self.capacity`
        """
        if self.size:
            while self.size > self.capacity:
                self.remove_oldest()
        else:
            logging.warning('Attempting to remove from empty buffer!')

    def remove_oldest(self):
        """
        remove the oldest data point, in this case, just the oldest frame. Empty episodes are also removed
        if they result from the removal.
        """
        self.index[0].pop()  # note that all indexes of frames in the first episode are shifted forward by 1

        if self.index[0]:  # first episode still has data points
            self.data[0].pop(0)
            if len(self.data) == 1:  # otherwise self.candidate_index is for another episode
                self.candidate_index -= 1
            self.index_lengths[0] -= 1

        else:  # first episode becomes empty
            self.data.pop(0)
            self.index.pop(0)
            if len(self.data) == 0:  # otherwise self.candidate_index is for another episode
                self.candidate_index = 0

            self.index_lengths.pop(0)

        self.size -= 1

    def sample(self, batch_size):
        """
        uniform random sampling on `self.index`. For simplicity, we do random sampling with replacement
        for now with time O(`batch_size`). Fastest sampling without replacement seems to have to be of time
        O(`batch_size` * log(num_episodes)).
        :param batch_size: int.
        :return: sampled index, same structure as `self.index`. Episodes without sampled data points
        correspond to empty sub-lists.
        """
        prob_episode = np.array(self.index_lengths) * 1. / self.size
        num_episodes = len(self.index)
        sampled_index = [[] for _ in range(num_episodes)]

        for _ in range(batch_size):
            # sample which episode
            sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))

            # sample which data point within the sampled episode
            sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
            sampled_index[sampled_episode_i].append(sampled_frame_i)

        return sampled_index
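A behavioral sketch of the buffer above (made-up frames, complementary to the test file that follows): within an episode a frame only becomes a valid data point once `nstep` later frames exist, while a terminal frame validates all trailing frames at once, and sample() returns per-episode lists of frame indices rather than flat data.

import numpy as np
from tianshou.data.replay_buffer.vanilla import VanillaReplayBuffer

buffer = VanillaReplayBuffer(capacity=100, nstep=2)

for t in range(4):
    buffer.add((np.array([t]), 0, 1.0, False))   # (observation, action, reward, done)
buffer.add((np.array([4]), 0, 1.0, True))        # done=True closes the episode

print(buffer.size)        # 5: every frame of the finished episode is now a valid data point
print(buffer.sample(3))   # e.g. [[0, 2, 4], []] -- per-episode lists of sampled frame indices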
40 tianshou/data/test_replay_buffer.py Normal file
@@ -0,0 +1,40 @@
import numpy as np

from replay_buffer.vanilla import VanillaReplayBuffer

capacity = 12
nstep = 3
buffer = VanillaReplayBuffer(capacity=capacity, nstep=nstep)

for i in range(capacity):
    s = np.random.randint(10)
    a = np.random.randint(3)
    r = np.random.randint(5)
    done = np.random.rand() > 0.6

    buffer.add((s, a, r, done))

    if i % 5 == 0:
        print('i = {}:'.format(i))
        print(buffer.index)
        print(buffer.data)

print('Now buffer with size {}:'.format(buffer.size))
print(buffer.index)
print(buffer.data)

for i in range(5):
    s = np.random.randint(10)
    a = np.random.randint(3)
    r = np.random.randint(5)
    done = np.random.rand() > 0.6

    buffer.add((s, a, r, done))
    print('added frame {}, {}:'.format(i, (s, a, r, done)))
    print(buffer.index)
    print(buffer.data)

print('sampling from buffer:')
print(buffer.index)
print(buffer.data)
print(buffer.sample(8))