interfaces for advantage_estimation. full_return finished and tested.

parent 25b25ce7d8
commit 675057c6b9
@@ -59,7 +59,7 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)

     ### 3. define data collection
-    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render = args.render)
+    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)

     ### 4. start training
     config = tf.ConfigProto()
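For orientation, here is a rough sketch of how the example script above presumably continues once the Batch collector is constructed. It is not part of this diff: the `collect()` call, the loop sizes, and the assumption that `next_batch()` returns the feed_dict it assembles (see the Batch.next_batch hunk further down) are all guesses; `env`, `pi`, `train_op`, `training_data`, and `args` come from the script itself.

import tensorflow as tf  # the example script already imports TF; repeated here for completeness

# Rough sketch only, not code from this commit.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    sess.run(tf.global_variables_initializer())
    for iteration in range(100):
        training_data.collect(num_episodes=20)        # hypothetical collection call
        for _ in range(10):
            feed_dict = training_data.next_batch(64)  # assumed to return the feed_dict it builds
            sess.run(train_op, feed_dict=feed_dict)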
@@ -10,6 +10,6 @@ data_collector.data.keys()

 ['reward']

-['start_flag']
+['done_flag']

 ['advantage'] > ['return'] # they may appear simultaneously
@@ -1,39 +1,48 @@
-import numpy as np
+import logging
+
+
+STATE = 0
+ACTION = 1
+REWARD = 2
+DONE = 3


-def full_return(raw_data):
+# modified for new interfaces
+def full_return(buffer, index=None):
     """
     naively compute full return
-    :param raw_data: dict of specified keys and values.
+    :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+    :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+        each episode.
+    :return: dict with key 'return' and value the computed returns corresponding to `index`.
     """
-    observations = raw_data['observation']
-    actions = raw_data['action']
-    rewards = raw_data['reward']
-    episode_start_flags = raw_data['end_flag']
-    num_timesteps = rewards.shape[0]
-
-    data = {}
-
-    returns = rewards.copy()
-    episode_start_idx = 0
-    for i in range(1, num_timesteps):
-        if episode_start_flags[i] or (
-                i == num_timesteps - 1):  # found the start of next episode or the end of all episodes
-            if i < rewards.shape[0] - 1:
-                t = i - 1
-            else:
-                t = i
-            Gt = 0
-            while t >= episode_start_idx:
-                Gt += rewards[t]
-                returns[t] = Gt
-                t -= 1
-
-            episode_start_idx = i
-
-    data['return'] = returns
-
-    return data
+    index = index or buffer.index
+    raw_data = buffer.data
+
+    returns = []
+    for i_episode in range(len(index)):
+        index_this = index[i_episode]
+        if index_this:
+            episode = raw_data[i_episode]
+            if not episode[-1][DONE]:
+                logging.warning('Computing full return on episode {} with no terminal state.'.format(i_episode))
+
+            episode_length = len(episode)
+            returns_episode = [0.] * episode_length
+            returns_this = [0.] * len(index_this)
+            return_ = 0.
+            index_min = min(index_this)
+            for i, frame in zip(range(episode_length - 1, index_min - 1, -1), reversed(episode[index_min:])):
+                return_ += frame[REWARD]
+                returns_episode[i] = return_
+
+            for i in range(len(index_this)):
+                returns_this[i] = returns_episode[index_this[i]]
+
+            returns.append(returns_this)
+        else:
+            returns.append([])
+
+    return {'return': returns}


 class gae_lambda:
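A quick sanity check of what the new `full_return` above computes: the return of frame t is the undiscounted sum of rewards from t to the end of its episode, accumulated backwards exactly as in the inner loop. A tiny standalone illustration:

# Standalone illustration of the backward accumulation used by full_return.
rewards = [1., 2., 3.]            # one episode's rewards
returns = [0.] * len(rewards)
g = 0.
for t in reversed(range(len(rewards))):
    g += rewards[t]
    returns[t] = g
assert returns == [6., 5., 3.]    # G_t = r_t + r_{t+1} + ... + r_{T-1}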
@@ -44,16 +53,14 @@ class gae_lambda:
         self.T = T
         self.value_function = value_function

-    def __call__(self, raw_data):
-        reward = raw_data['reward']
-        observation = raw_data['observation']
-
-        state_value = self.value_function.eval_value(observation)
-
-        # wrong version of advantage just to run
-        advantage = reward + state_value
-
-        return {'advantage': advantage}
+    def __call__(self, buffer, index=None):
+        """
+        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+            each episode.
+        :return: dict with key 'advantage' and value the computed advantages corresponding to `index`.
+        """
+        pass


 class nstep_return:
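The new `gae_lambda.__call__` is left as a stub in this commit. For reference only, here is a sketch of what a GAE(lambda) estimator could look like under the same buffer interface. It assumes the (STATE, ACTION, REWARD, DONE) frame layout used by `full_return`, assumes `value_function.eval_value` accepts a list of states and returns one value per state (as the removed code suggests), and introduces hypothetical `gamma`/`lambda_` parameters that the class does not yet define:

STATE, ACTION, REWARD, DONE = 0, 1, 2, 3  # same frame layout as full_return

def gae_lambda_sketch(buffer, value_function, gamma=0.99, lambda_=0.95, index=None):
    """Sketch only, not the committed implementation."""
    index = index or buffer.index
    advantages = []
    for i_episode, index_this in enumerate(index):
        if not index_this:
            advantages.append([])
            continue
        episode = buffer.data[i_episode]
        rewards = [frame[REWARD] for frame in episode]
        values = list(value_function.eval_value([frame[STATE] for frame in episode]))
        values.append(0.)  # value after the last frame; 0 if the episode terminated
        adv = [0.] * len(episode)
        gae = 0.
        for t in reversed(range(len(episode))):
            delta = rewards[t] + gamma * values[t + 1] - values[t]  # one-step TD error
            gae = delta + gamma * lambda_ * gae                     # GAE recursion
            adv[t] = gae
        advantages.append([adv[i] for i in index_this])
    return {'advantage': advantages}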
@@ -64,16 +71,15 @@ class nstep_return:
         self.n = n
         self.value_function = value_function

-    def __call__(self, raw_data):
-        reward = raw_data['reward']
-        observation = raw_data['observation']
-
-        state_value = self.value_function.eval_value(observation)
-
-        # wrong version of return just to run
-        return_ = reward + state_value
-
-        return {'return': return_}
+    def __call__(self, buffer, index=None):
+        """
+        naively compute full return
+        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+            each episode.
+        :return: dict with key 'return' and value the computed returns corresponding to `index`.
+        """
+        pass


 class ddpg_return:
@@ -85,20 +91,15 @@ class ddpg_return:
         self.critic = critic
         self.use_target_network = use_target_network

-    def __call__(self, raw_data):
-        observation = raw_data['observation']
-        reward = raw_data['reward']
-
-        if self.use_target_network:
-            action_target = self.actor.eval_action_old(observation)
-            value_target = self.critic.eval_value_old(observation, action_target)
-        else:
-            action_target = self.actor.eval_action(observation)
-            value_target = self.critic.eval_value(observation, action_target)
-
-        return_ = reward + value_target
-
-        return {'return': return_}
+    def __call__(self, buffer, index=None):
+        """
+        naively compute full return
+        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+            each episode.
+        :return: dict with key 'return' and value the computed returns corresponding to `index`.
+        """
+        pass


 class nstep_q_return:
@@ -110,80 +111,12 @@ class nstep_q_return:
         self.action_value = action_value
         self.use_target_network = use_target_network

-    def __call__(self, raw_data):
-        # raw_data should contain 'next_observation' from replay memory...?
-        # maybe the main difference between Batch and Replay is the stored data format?
-        reward = raw_data['reward']
-        observation = raw_data['observation']
-
-        if self.use_target_network:
-            action_value_all_actions = self.action_value.eval_value_all_actions_old(observation)
-        else:
-            action_value_all_actions = self.action_value.eval_value_all_actions(observation)
-
-        action_value_max = np.max(action_value_all_actions, axis=1)
-
-        return_ = reward + action_value_max
-
-        return {'return': return_}
+    def __call__(self, buffer, index=None):
+        """
+        naively compute full return
+        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+            each episode.
+        :return: dict with key 'return' and value the computed returns corresponding to `index`.
+        """
+        pass
-
-
-class QLearningTarget:
-    def __init__(self, policy, gamma):
-        self._policy = policy
-        self._gamma = gamma
-
-    def __call__(self, raw_data):
-        data = dict()
-        observations = list()
-        actions = list()
-        rewards = list()
-        wi = list()
-        all_data, data_wi, data_index = raw_data
-
-        for i in range(0, all_data.shape[0]):
-            current_data = all_data[i]
-            current_wi = data_wi[i]
-            current_index = data_index[i]
-            observations.append(current_data['observation'])
-            actions.append(current_data['action'])
-            next_max_qvalue = np.max(self._policy.values(current_data['observation']))
-            current_qvalue = self._policy.values(current_data['previous_observation'])[current_data['previous_action']]
-            reward = current_data['reward'] + next_max_qvalue - current_qvalue
-            rewards.append(reward)
-            wi.append(current_wi)
-
-        data['observations'] = np.array(observations)
-        data['actions'] = np.array(actions)
-        data['rewards'] = np.array(rewards)
-
-        return data
-
-
-class ReplayMemoryQReturn:
-    """
-    compute the n-step return for Q-learning targets
-    """
-    def __init__(self, n, action_value, use_target_network=True):
-        self.n = n
-        self._action_value = action_value
-        self._use_target_network = use_target_network
-
-    def __call__(self, raw_data):
-        reward = raw_data['reward']
-        observation = raw_data['observation']
-
-        if self._use_target_network:
-            # print(observation.shape)
-            # print((observation.reshape((1,) + observation.shape)))
-            action_value_all_actions = self._action_value.eval_value_all_actions_old(observation.reshape((1,) + observation.shape))
-        else:
-            # print(observation.shape)
-            # print((observation.reshape((1,) + observation.shape)))
-            action_value_all_actions = self._action_value.eval_value_all_actions(observation.reshape((1,) + observation.shape))
-
-        action_value_max = np.max(action_value_all_actions, axis=1)
-
-        return_ = reward + action_value_max
-
-        return {'return': return_}
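`nstep_q_return.__call__` is likewise only a stub here, while the previously working code above it is removed. As reference only, here is one possible shape for an n-step Q-learning target under the new interface; it assumes the `full_return` frame layout, the `eval_value_all_actions[_old]` methods that appear in the removed code, and a hypothetical `gamma` discount parameter:

import numpy as np

STATE, ACTION, REWARD, DONE = 0, 1, 2, 3  # same frame layout as full_return

def nstep_q_return_sketch(buffer, action_value, n=1, gamma=0.99,
                          use_target_network=True, index=None):
    """Sketch only, not the committed implementation."""
    index = index or buffer.index
    returns = []
    for i_episode, index_this in enumerate(index):
        if not index_this:
            returns.append([])
            continue
        episode = buffer.data[i_episode]
        rewards = [frame[REWARD] for frame in episode]
        states = np.array([frame[STATE] for frame in episode])
        if use_target_network:
            q_all = action_value.eval_value_all_actions_old(states)
        else:
            q_all = action_value.eval_value_all_actions(states)
        q_max = np.max(q_all, axis=1)   # max_a Q(s, a) for every frame
        returns_this = []
        for t in index_this:
            end = min(t + n, len(episode))
            g = sum(gamma ** (k - t) * rewards[k] for k in range(t, end))
            if end < len(episode):      # bootstrap only if frames remain
                g += gamma ** (end - t) * q_max[end]
            returns_this.append(g)
        returns.append(returns_this)
    return {'return': returns}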
@@ -156,11 +156,19 @@ class Batch(object):
     def next_batch(self, batch_size, standardize_advantage=True):
         rand_idx = np.random.choice(self.raw_data['observation'].shape[0], batch_size)

+        # maybe re-compute advantage here, but only on rand_idx
+        # but how to construct the feed_dict?
+        if self.online:
+            self.data_batch.update(self.apply_advantage_estimation_function(rand_idx))
+
         feed_dict = {}
         for key, placeholder in self.required_placeholders.items():
+            feed_dict[placeholder] = utils.get_batch(self, key, rand_idx)
+
             found, data_key = utils.internal_key_match(key, self.raw_data.keys())
             if found:
-                feed_dict[placeholder] = self.raw_data[data_key][rand_idx]
+                feed_dict[placeholder] = utils.get_batch(self.raw_data[data_key], rand_idx)  # self.raw_data[data_key][rand_idx]
             else:
                 found, data_key = utils.internal_key_match(key, self.data.keys())
                 if found:
tianshou/data/test_advantage_estimation.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+from advantage_estimation import *
+
+
+class ReplayBuffer(object):
+    def __init__(self):
+        self.index = [
+            [0, 1, 2],
+            [0, 1, 2, 3],
+            [0, 1],
+        ]
+        self.data = [
+            [(0, 0, 10, False), (0, 0, 1, False), (0, 0, -100, True)],
+            [(0, 0, 1, False), (0, 0, 1, False), (0, 0, 1, False), (0, 0, 5, False)],
+            [(0, 0, 9, False), (0, 0, -2, True)],
+        ]
+
+
+buffer = ReplayBuffer()
+sample_index = [
+    [0, 2, 0],
+    [1, 2, 1, 3],
+    [],
+]
+
+data = full_return(buffer)
+print(data['return'])
+data = full_return(buffer, sample_index)
+print(data['return'])
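For reference, running this test against the `full_return` above should print values close to the following (hand-computed from the rewards in `self.data`; episode 1 has no terminal frame, so a warning is also logged for it):

# full_return(buffer)               -- all indices of every episode
[[-89.0, -99.0, -100.0], [8.0, 7.0, 6.0, 5.0], [7.0, -2.0]]
# full_return(buffer, sample_index) -- only the sampled indices; episode 2 is skipped
[[-89.0, -100.0, -89.0], [7.0, 6.0, 7.0, 5.0], []]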