finished a very naive DQN: changed the interface of the replay buffer by adding collect and next_batch, though it still needs refactoring; added an implementation of dqn.py, though the interface still needs work to make it more extensible; slightly refactored the code style of the codebase; more comments and TODOs will come in the next commit
This commit is contained in:
parent e10acf5130
commit 62e2c6582d
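As context for the interface change, the intended training loop from the example diff below looks roughly like this (a minimal sketch in the spirit of the example script; `collect_freq`, `batch_size`, `sess`, the placeholders and `train_op` are assumed to be set up as in that script, and the API is explicitly not final):

    # hypothetical usage sketch of the collect/next_batch interface added in this commit
    replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
                                      {'size': 1000, 'batch_size': 64, 'learn_start': 20})
    while True:
        # each collect() steps the env with the current policy and stores one transition
        for _ in range(collect_freq):
            replay_memory.collect()
        # next_batch() returns a dict with 'observations', 'actions', 'target', ...
        data = replay_memory.next_batch(batch_size)
        sess.run(train_op, feed_dict={observation: data['observations'],
                                      action: data['actions'],
                                      target: data['target']})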
@@ -9,8 +9,7 @@ import gym
import sys
sys.path.append('..')
import tianshou.core.losses as losses
from tianshou.data.replay import Replay
import tianshou.data.advantage_estimation as advantage_estimation
from tianshou.data.replay_buffer.utils import get_replay_buffer
import tianshou.core.policy as policy


@@ -38,11 +37,10 @@ if __name__ == '__main__':
    action_dim = env.action_space.n

    # 1. build network with pure tf
    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input

    with tf.variable_scope('q_net'):
        q_values = policy_net(observation, action_dim)
        train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
    with tf.variable_scope('target_net'):
        q_values_target = policy_net(observation, action_dim)

@@ -54,13 +52,15 @@ if __name__ == '__main__':
    target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN

    dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen

    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_var_list = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
    total_loss = dqn_loss
    optimizer = tf.train.AdamOptimizer(1e-3)
    train_op = optimizer.minimize(total_loss, var_list=train_var_list)

    train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=tf.train.get_global_step())
    # 3. define data collection
    training_data = Replay(env, q_net, advantage_estimation.qlearning_target(target_net)) #
    replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
                                      {'size': 1000, 'batch_size': 64, 'learn_start': 20})
    # ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
    # maybe a dict to manage the elements to be collected

@@ -70,14 +70,16 @@ if __name__ == '__main__':

    minibatch_count = 0
    collection_count = 0
    collect_freq = 100
    while True: # until some stopping criterion met...
        # collect data
        training_data.collect() # ShihongSong
        collection_count += 1
        print('Collected {} times.'.format(collection_count))
        for i in range(0, collect_freq):
            replay_memory.collect() # ShihongSong
            collection_count += 1
            print('Collected {} times.'.format(collection_count))

        # update network
        data = training_data.next_batch(64) # YouQiaoben, ShihongSong
        data = replay_memory.next_batch(10) # YouQiaoben, ShihongSong
        # TODO: auto managing of the placeholders? or add this to params of data.Batch
        sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], target: data['target']})
        minibatch_count += 1
@@ -32,7 +32,7 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
    """
    log_pi_act = pi.log_prob(sampled_action)
    vanilla_policy_gradient_loss = tf.reduce_mean(reward * log_pi_act)
    # TODO: Different baseline methods like REINFORCE, etc.
    # TODO: Different baseline methods like REINFORCE, etc.
    return vanilla_policy_gradient_loss

def dqn_loss(sampled_action, sampled_target, q_net):

@@ -44,8 +44,8 @@ def dqn_loss(sampled_action, sampled_target, q_net):
    :param q_net: current `policy` to be optimized
    :return:
    """
    action_num = q_net.get_values().shape()[1]
    sampled_q = tf.reduce_sum(q_net.get_values() * tf.one_hot(sampled_action, action_num), axis=1)
    action_num = q_net.values_tensor().get_shape()[1]
    sampled_q = tf.reduce_sum(q_net.values_tensor() * tf.one_hot(sampled_action, action_num), axis=1)
    return tf.reduce_mean(tf.square(sampled_target - sampled_q))

def deterministic_policy_gradient(sampled_state, critic):
@@ -2,4 +2,5 @@
# -*- coding: utf-8 -*-

from .base import *
from .stochastic import *
from .stochastic import *
from .dqn import *
@@ -12,23 +12,28 @@ import tensorflow as tf

__all__ = [
    'StochasticPolicy',
    'QValuePolicy',
]

#TODO: separate actor and critic, we should focus on it once we finish the basic module.
# TODO: separate actor and critic, we should focus on it once we finish the basic module.


class QValuePolicy(object):
    """
    The policy as in DQN
    """
    def __init__(self, observation_placeholder):
        self.observation_placeholder = observation_placeholder
        self._observation_placeholder = observation_placeholder

    def act(self, observation, exploration=None): # first implement no exploration
        """
        return the action (int) to be executed.
        no exploration when exploration=None.
        """
        pass
        self._act(observation, exploration)

    def _act(self, observation, exploration = None):
        raise NotImplementedError()

    def values(self, observation):
        """

@@ -36,7 +41,7 @@ class QValuePolicy(object):
        """
        pass

    def values_tensor(self, observation):
    def values_tensor(self):
        """
        returns the tensor of the values for all actions a at observation s
        """
@@ -1,7 +1,54 @@


from .base import QValuePolicy
from tianshou.core.policy.base import QValuePolicy
import tensorflow as tf


class DQN(QValuePolicy):
    pass
    """
    The policy as in DQN
    """

    def __init__(self, logits, observation_placeholder, dtype=None, **kwargs):
        self._logits = tf.convert_to_tensor(logits)
        if dtype is None:
            dtype = tf.int32
        self._n_categories = self._logits.get_shape()[-1].value

        super(DQN, self).__init__(observation_placeholder)

        net = tf.layers.conv2d(self._observation_placeholder, 16, 8, 4, 'valid', activation=tf.nn.relu)
        net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
        net = tf.layers.flatten(net)
        net = tf.layers.dense(net, 256, activation=tf.nn.relu, use_bias=True)
        self._value = tf.layers.dense(net, self._n_categories)

    def _act(self, observation, exploration=None): # first implement no exploration
        """
        return the action (int) to be executed.
        no exploration when exploration=None.
        """
        sess = tf.get_default_session()
        sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1),
                                  feed_dict={self._observation_placeholder: observation[None]})
        return sampled_action

    @property
    def logits(self):
        return self._logits

    @property
    def n_categories(self):
        return self._n_categories

    def values(self, observation):
        """
        returns the Q(s, a) values (float) for all actions a at observation s
        """
        sess = tf.get_default_session()
        value = sess.run(self._value, feed_dict={self._observation_placeholder: observation[None]})
        return value

    def values_tensor(self):
        """
        returns the tensor of the values for all actions a at observation s
        """
        return self._value
@@ -19,7 +19,8 @@ def full_return(raw_data):
    returns = rewards.copy()
    episode_start_idx = 0
    for i in range(1, num_timesteps):
        if episode_start_flags[i] or (i == num_timesteps - 1): # found the start of next episode or the end of all episodes
        if episode_start_flags[i] or (
                i == num_timesteps - 1): # found the start of next episode or the end of all episodes
            if i < rewards.shape[0] - 1:
                t = i - 1
            else:

@@ -34,4 +35,36 @@ def full_return(raw_data):

    data['returns'] = returns

    return data
    return data


class QLearningTarget:
    def __init__(self, policy, gamma):
        self._policy = policy
        self._gamma = gamma

    def __call__(self, raw_data):
        data = dict()
        observations = list()
        actions = list()
        rewards = list()
        wi = list()
        all_data, data_wi, data_index = raw_data

        for i in range(0, all_data.shape[0]):
            current_data = all_data[i]
            current_wi = data_wi[i]
            current_index = data_index[i]
            observations.append(current_data['observation'])
            actions.append(current_data['action'])
            next_max_qvalue = np.max(self._policy.values(current_data['observation']))
            current_qvalue = self._policy.values(current_data['previous_observation'])[current_data['previous_action']]
            reward = current_data['reward'] + next_max_qvalue - current_qvalue
            rewards.append(reward)
            wi.append(current_wi)

        data['observations'] = np.array(observations)
        data['actions'] = np.array(actions)
        data['rewards'] = np.array(rewards)

        return data
@@ -1,39 +1,51 @@
class ReplayBuffer(object):
    def __init__(self, conf):
        '''
    def __init__(self, env, policy, qnet, target_qnet, conf):
        """
        Initialize a replay buffer with parameters in conf.
        '''
        pass
        """
        pass

    def add(self, data, priority):
        '''
    def add(self, data, priority):
        """
        Add a data with priority = priority to replay buffer.
        '''
        pass
        """
        pass

    def update_priority(self, indices, priorities):
        '''
    def collect(self):
        """
        Collect data from current environment and policy.
        """
        pass

    def next_batch(self, batch_size):
        """
        get batch of data from the replay buffer.
        """
        pass

    def update_priority(self, indices, priorities):
        """
        Update the data's priority whose indices = indices.
        For proportional replay buffer, the priority is the priority.
        For rank based replay buffer, the priorities parameter will be the delta used to update the priority.
        '''
        pass
        """
        pass

    def reset_alpha(self, alpha):
        '''
    def reset_alpha(self, alpha):
        """
        This function only works for proportional replay buffer.
        This function resets alpha.
        '''
        pass
        """
        pass

    def sample(self, conf):
        '''
    def sample(self, conf):
        """
        Sample from replay buffer with parameters in conf.
        '''
        pass
        """
        pass

    def rebalance(self):
        '''
    def rebalance(self):
        """
        This is for rank based priority replay buffer, which is used to rebalance the sum tree of the priority queue.
        '''
        pass
        """
        pass
@@ -1,29 +1,93 @@
from buffer import ReplayBuffer
import numpy as np
import tensorflow as tf
from collections import deque
from math import fabs

from tianshou.data.replay_buffer.buffer import ReplayBuffer


class NaiveExperience(ReplayBuffer):
    def __init__(self, conf):
        self.max_size = conf['size']
        self.n_entries = 0
        self.memory = deque(maxlen = self.max_size)
    def __init__(self, env, policy, qnet, target_qnet, conf):
        self.max_size = conf['size']
        self._env = env
        self._policy = policy
        self._qnet = qnet
        self._target_qnet = target_qnet
        self._begin_act()
        self.n_entries = 0
        self.memory = deque(maxlen=self.max_size)

    def add(self, data, priority = 0):
        self.memory.append(data)
        if self.n_entries < self.max_size:
            self.n_entries += 1
    def add(self, data, priority=0):
        self.memory.append(data)
        if self.n_entries < self.max_size:
            self.n_entries += 1

    def update_priority(self, indices, priorities = 0):
        pass
    def _begin_act(self):
        self.observation = self._env.reset()
        self.action = self._env.action_space.sample()
        done = False
        while not done:
            if done:
                self.observation = self._env.reset()
                self.action = self._env.action_space.sample()
            self.observation, _, done, _ = self._env.step(self.action)

    def reset_alpha(self, alpha):
        pass
    def collect(self):
        sess = tf.get_default_session()
        current_data = dict()
        current_data['previous_action'] = self.action
        current_data['previous_observation'] = self.observation
        self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
        self.observation, reward, done, _ = self._env.step(self.action)
        current_data['action'] = self.action
        current_data['observation'] = self.observation
        current_data['reward'] = reward
        self.add(current_data)
        if done:
            self._begin_act()

    def sample(self, conf):
        batch_size = conf['batch_size']
        batch_size = min(len(self.memory), batch_size)
        idxs = np.random.choice(len(self.memory), batch_size)
        return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs
    def update_priority(self, indices, priorities=0):
        pass

    def rebalance(self):
        pass
    def reset_alpha(self, alpha):
        pass

    def sample(self, conf):
        batch_size = conf['batch_size']
        batch_size = min(len(self.memory), batch_size)
        idxs = np.random.choice(len(self.memory), batch_size)
        return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs

    def next_batch(self, batch_size):
        data = dict()
        observations = list()
        actions = list()
        rewards = list()
        wi = list()
        target = list()

        for i in range(0, batch_size):
            current_datas, current_wis, current_indexs = self.sample({'batch_size': 1})
            current_data = current_datas[0]
            current_wi = current_wis[0]
            current_index = current_indexs[0]
            observations.append(current_data['observation'])
            actions.append(current_data['action'])
            next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
            current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
            reward = current_data['reward'] + next_max_qvalue - current_qvalue
            rewards.append(reward)
            target.append(current_data['reward'] + next_max_qvalue)
            self.update_priority(current_index, [fabs(reward)])
            wi.append(current_wi)

        data['observations'] = np.array(observations)
        data['actions'] = np.array(actions)
        data['rewards'] = np.array(rewards)
        data['wi'] = np.array(wi)
        data['target'] = np.array(target)

        return data

    def rebalance(self):
        pass
@@ -1,7 +1,10 @@
import numpy
import numpy as np
import random
import sum_tree
from buffer import ReplayBuffer
import tensorflow as tf
import math

from tianshou.data.replay_buffer import sum_tree
from tianshou.data.replay_buffer.buffer import ReplayBuffer


class PropotionalExperience(ReplayBuffer):

@@ -15,7 +18,7 @@ class PropotionalExperience(ReplayBuffer):

    """

    def __init__(self, conf):
    def __init__(self, env, policy, qnet, target_qnet, conf):
        """ Prioritized experience replay buffer initialization.

        Parameters

@@ -30,11 +33,26 @@ class PropotionalExperience(ReplayBuffer):
        """
        memory_size = conf['size']
        batch_size = conf['batch_size']
        alpha = conf['alpha']
        alpha = conf['alpha'] if 'alpha' in conf else 0.6
        self.tree = sum_tree.SumTree(memory_size)
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.alpha = alpha
        self._env = env
        self._policy = policy
        self._qnet = qnet
        self._target_qnet = target_qnet
        self._begin_act()

    def _begin_act(self):
        self.observation = self._env.reset()
        self.action = self._env.action_space.sample()
        done = False
        while not done:
            if done:
                self.observation = self._env.reset()
                self.action = self._env.action_space.sample()
            self.observation, _, done, _ = self._env.step(self.action)

    def add(self, data, priority):
        """ Add new sample.

@@ -48,6 +66,12 @@ class PropotionalExperience(ReplayBuffer):
        """
        self.tree.add(data, priority**self.alpha)

    def collect(self):
        pass

    def next_batch(self, batch_size):
        pass

    def sample(self, conf):
        """ The method return samples randomly.

@@ -64,8 +88,9 @@ class PropotionalExperience(ReplayBuffer):
        indices:
            list of sample indices
            The indices indicate sample positions in a sum tree.
        :param conf: giving beta
        """
        beta = conf['beta']
        beta = conf['beta'] if 'beta' in conf else 0.4
        if self.tree.filled_size() < self.batch_size:
            return None, None, None

@@ -91,6 +116,54 @@ class PropotionalExperience(ReplayBuffer):

        return out, weights, indices

    def collect(self):
        sess = tf.get_default_session()
        current_data = dict()
        current_data['previous_action'] = self.action
        current_data['previous_observation'] = self.observation
        # TODO: change the name of the feed_dict
        self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
        self.observation, reward, done, _ = self._env.step(self.action)
        current_data['action'] = self.action
        current_data['observation'] = self.observation
        current_data['reward'] = reward
        priorities = np.array([self.tree.get_val(i) ** -self.alpha for i in range(self.tree.filled_size())])
        priority = np.max(priorities) if len(priorities) > 0 else 1
        self.add(current_data, priority)
        if done:
            self._begin_act()

    def next_batch(self, batch_size):
        data = dict()
        observations = list()
        actions = list()
        rewards = list()
        wi = list()
        target = list()

        for i in range(0, batch_size):
            current_datas, current_wis, current_indexs = self.sample({'batch_size': 1})
            current_data = current_datas[0]
            current_wi = current_wis[0]
            current_index = current_indexs[0]
            observations.append(current_data['observation'])
            actions.append(current_data['action'])
            next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
            current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
            reward = current_data['reward'] + next_max_qvalue - current_qvalue
            rewards.append(reward)
            target.append(current_data['reward'] + next_max_qvalue)
            self.update_priority([current_index], [math.fabs(reward)])
            wi.append(current_wi)

        data['observations'] = np.array(observations)
        data['actions'] = np.array(actions)
        data['rewards'] = np.array(rewards)
        data['wi'] = np.array(wi)
        data['target'] = np.array(target)

        return data

    def update_priority(self, indices, priorities):
        """ The methods update samples's priority.

@@ -8,13 +8,15 @@ import sys
import math
import random
import numpy as np
import tensorflow as tf

from tianshou.data.replay_buffer.binary_heap import BinaryHeap
from tianshou.data.replay_buffer.buffer import ReplayBuffer

from binary_heap import BinaryHeap
from buffer import ReplayBuffer

class RankBasedExperience(ReplayBuffer):

    def __init__(self, conf):
    def __init__(self, env, policy, qnet, target_qnet, conf):
        self.size = conf['size']
        self.replace_flag = conf['replace_old'] if 'replace_old' in conf else True
        self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size

@@ -25,12 +27,18 @@ class RankBasedExperience(ReplayBuffer):
        self.learn_start = conf['learn_start'] if 'learn_start' in conf else 1000
        self.total_steps = conf['steps'] if 'steps' in conf else 100000
        # partition number N, split total size to N part
        self.partition_num = conf['partition_num'] if 'partition_num' in conf else 100
        self.partition_num = conf['partition_num'] if 'partition_num' in conf else 10

        self.index = 0
        self.record_size = 0
        self.isFull = False

        self._env = env
        self._policy = policy
        self._qnet = qnet
        self._target_qnet = target_qnet
        self._begin_act()

        self._experience = {}
        self.priority_queue = BinaryHeap(self.priority_size)
        self.distributions = self.build_distributions()

@@ -98,7 +106,64 @@ class RankBasedExperience(ReplayBuffer):
        self.index += 1
        return self.index

    def add(self, data, priority = 0):
    def _begin_act(self):
        self.observation = self._env.reset()
        self.action = self._env.action_space.sample()
        done = False
        while not done:
            if done:
                self.observation = self._env.reset()
                self.action = self._env.action_space.sample()
            self.observation, _, done, _ = self._env.step(self.action)

    def collect(self):
        sess = tf.get_default_session()
        current_data = dict()
        current_data['previous_action'] = self.action
        current_data['previous_observation'] = self.observation
        self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
        self.observation, reward, done, _ = self._env.step(self.action)
        current_data['action'] = self.action
        current_data['observation'] = self.observation
        current_data['reward'] = reward
        self.add(current_data)
        if done:
            self._begin_act()

    def next_batch(self, batch_size):
        data = dict()
        observations = list()
        actions = list()
        rewards = list()
        wi = list()
        target = list()

        sess = tf.get_default_session()
        current_datas, current_wis, current_indexs = self.sample({'global_step': sess.run(tf.train.get_global_step())})

        for i in range(0, batch_size):
            current_data = current_datas[i]
            current_wi = current_wis[i]
            current_index = current_indexs[i]
            observations.append(current_data['observation'])
            actions.append(current_data['action'])
            next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
            current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
            reward = current_data['reward'] + next_max_qvalue - current_qvalue
            rewards.append(reward)
            target.append(current_data['reward'] + next_max_qvalue)
            self.update_priority([current_index], [math.fabs(reward)])
            wi.append(current_wi)

        data['observations'] = np.array(observations)
        data['actions'] = np.array(actions)
        data['rewards'] = np.array(rewards)
        data['wi'] = np.array(wi)
        data['target'] = np.array(target)

        return data

    def add(self, data, priority = 1):
        """
        store experience, suggest that experience is a tuple of (s1, a, r, s2, t)
        so each experience is valid

@@ -156,16 +221,16 @@ class RankBasedExperience(ReplayBuffer):
            sys.stderr.write('Record size less than learn start! Sample failed\n')
            return False, False, False

        dist_index = math.floor(self.record_size / self.size * self.partition_num)
        dist_index = math.floor(self.record_size * 1. / self.size * self.partition_num)
        # issue 1 by @camigord
        partition_size = math.floor(self.size / self.partition_num)
        partition_size = math.floor(self.size * 1. / self.partition_num)
        partition_max = dist_index * partition_size
        distribution = self.distributions[dist_index]
        rank_list = []
        # sample from k segments
        for n in range(1, self.batch_size + 1):
            index = random.randint(distribution['strata_ends'][n] + 1,
                                   distribution['strata_ends'][n + 1])
            index = random.randint(distribution['strata_ends'][n],
                                   distribution['strata_ends'][n + 1])
            rank_list.append(index)

        # beta, increase by global_step, max 1
@@ -1,13 +1,15 @@
from utils import *
from functions import *

from tianshou.data.replay_buffer.utils import get_replay_buffer


def test_rank_based():
    conf = {'size': 50,
            'learn_start': 10,
            'partition_num': 5,
            'total_step': 100,
            'batch_size': 4}
    experience = getReplayBuffer('rank_based', conf)
    experience = get_replay_buffer('rank_based', conf)

    # insert to experience
    print 'test insert experience'

@@ -52,7 +54,7 @@ def test_proportional():
    conf = {'size': 50,
            'alpha': 0.7,
            'batch_size': 4}
    experience = getReplayBuffer('proportional', conf)
    experience = get_replay_buffer('proportional', conf)

    # insert to experience
    print 'test insert experience'

@@ -90,7 +92,7 @@ def test_proportional():

def test_naive():
    conf = {'size': 50}
    experience = getReplayBuffer('naive', conf)
    experience = get_replay_buffer('naive', conf)

    # insert to experience
    print 'test insert experience'
@@ -1,17 +1,20 @@
from rank_based import *
from proportional import *
from naive import *
import sys

def getReplayBuffer(name, conf):
    '''
    Get replay buffer according to the given name.
    '''
    if (name == 'rank_based'):
        return RankBasedExperience(conf)
    elif (name == 'proportional'):
        return PropotionalExperience(conf)
    elif (name == 'naive'):
        return NaiveExperience(conf)
    else:
        sys.stderr.write('no such replay buffer')
from tianshou.data.replay_buffer.naive import NaiveExperience
from tianshou.data.replay_buffer.proportional import PropotionalExperience
from tianshou.data.replay_buffer.rank_based import RankBasedExperience


def get_replay_buffer(name, env, policy, qnet, target_qnet, conf):
    """
    Get replay buffer according to the given name.
    """

    if name == 'rank_based':
        return RankBasedExperience(env, policy, qnet, target_qnet, conf)
    elif name == 'proportional':
        return PropotionalExperience(env, policy, qnet, target_qnet, conf)
    elif name == 'naive':
        return NaiveExperience(env, policy, qnet, target_qnet, conf)
    else:
        sys.stderr.write('no such replay buffer')