interfaces for advantage_estimation. full_return finished and tested.
This commit is contained in:
parent
25b25ce7d8
commit
675057c6b9
@@ -59,7 +59,7 @@ if __name__ == '__main__':
     train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
 
     ### 3. define data collection
-    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render = args.render)
+    training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render=args.render)
 
     ### 4. start training
     config = tf.ConfigProto()
@@ -10,6 +10,6 @@ data_collector.data.keys()
 
 ['reward']
 
-['start_flag']
+['done_flag']
 
 ['advantage'] > ['return']  # they may appear simultaneously
@@ -1,39 +1,48 @@
 import numpy as np
 import logging
 
 
-def full_return(raw_data):
+STATE = 0
+ACTION = 1
+REWARD = 2
+DONE = 3
+
+
+# modified for new interfaces
+def full_return(buffer, index=None):
     """
     naively compute full return
-    :param raw_data: dict of specified keys and values.
+    :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+    :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+        each episode.
+    :return: dict with key 'return' and value the computed returns corresponding to `index`.
     """
-    observations = raw_data['observation']
-    actions = raw_data['action']
-    rewards = raw_data['reward']
-    episode_start_flags = raw_data['end_flag']
-    num_timesteps = rewards.shape[0]
+    index = index or buffer.index
+    raw_data = buffer.data
 
-    data = {}
+    returns = []
+    for i_episode in range(len(index)):
+        index_this = index[i_episode]
+        if index_this:
+            episode = raw_data[i_episode]
+            if not episode[-1][DONE]:
+                logging.warning('Computing full return on episode {} with no terminal state.'.format(i_episode))
 
-    returns = rewards.copy()
-    episode_start_idx = 0
-    for i in range(1, num_timesteps):
-        if episode_start_flags[i] or (
-                i == num_timesteps - 1):  # found the start of next episode or the end of all episodes
-            if i < rewards.shape[0] - 1:
-                t = i - 1
-            else:
-                t = i
-            Gt = 0
-            while t >= episode_start_idx:
-                Gt += rewards[t]
-                returns[t] = Gt
-                t -= 1
+            episode_length = len(episode)
+            returns_episode = [0.] * episode_length
+            returns_this = [0.] * len(index_this)
+            return_ = 0.
+            index_min = min(index_this)
+            for i, frame in zip(range(episode_length - 1, index_min - 1, -1), reversed(episode[index_min:])):
+                return_ += frame[REWARD]
+                returns_episode[i] = return_
 
-            episode_start_idx = i
+            for i in range(len(index_this)):
+                returns_this[i] = returns_episode[index_this[i]]
 
-    data['return'] = returns
+            returns.append(returns_this)
+        else:
+            returns.append([])
 
-    return data
+    return {'return': returns}
 
 
 class gae_lambda:
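For a concrete view of the backward accumulation above, here is a small hand-check (not part of the diff), using the first toy episode from the new test file and the (state, action, reward, done) frame layout implied by the STATE/ACTION/REWARD/DONE constants:

REWARD = 2                                     # reward field of each frame tuple
episode = [(0, 0, 10, False), (0, 0, 1, False), (0, 0, -100, True)]

returns_episode = [0.] * len(episode)
return_ = 0.
# same zip trick as full_return with index_min = 0: walk the episode backwards
for i, frame in zip(range(len(episode) - 1, -1, -1), reversed(episode)):
    return_ += frame[REWARD]                   # undiscounted G_t = r_t + G_{t+1}
    returns_episode[i] = return_

print(returns_episode)                         # [-89.0, -99.0, -100.0]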
@@ -44,16 +53,14 @@ class gae_lambda:
         self.T = T
         self.value_function = value_function
 
-    def __call__(self, raw_data):
-        reward = raw_data['reward']
-        observation = raw_data['observation']
-
-        state_value = self.value_function.eval_value(observation)
-
-        # wrong version of advantage just to run
-        advantage = reward + state_value
-
-        return {'advantage': advantage}
+    def __call__(self, buffer, index=None):
+        """
+        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+            each episode.
+        :return: dict with key 'advantage' and value the computed advantages corresponding to `index`.
+        """
+        pass
 
 
 class nstep_return:
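The new __call__ is left as a stub in this commit. Purely as a sketch of what GAE(λ) over this episode-structured buffer could look like (not the author's implementation): it reuses the module-level numpy import, the STATE/REWARD/DONE frame-index constants, and the value_function.eval_value interface from the removed code, while gamma and lambda_ are assumed hyperparameters that the commit does not define.

def gae_lambda_sketch(buffer, value_function, gamma=0.99, lambda_=0.95, index=None):
    # A_t = sum_l (gamma * lambda)^l * delta_{t+l}, with delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    index = index or buffer.index
    advantages = []
    for i_episode, index_this in enumerate(index):
        if not index_this:
            advantages.append([])
            continue
        episode = buffer.data[i_episode]
        rewards = np.array([frame[REWARD] for frame in episode], dtype=float)
        values = np.asarray(value_function.eval_value([frame[STATE] for frame in episode]), dtype=float)
        # treat the value after a terminal frame as 0; otherwise bootstrap with the last value
        next_values = np.append(values[1:], 0. if episode[-1][DONE] else values[-1])
        deltas = rewards + gamma * next_values - values
        adv, gae = np.zeros_like(deltas), 0.
        for t in reversed(range(len(deltas))):          # backward recursion over the episode
            gae = deltas[t] + gamma * lambda_ * gae
            adv[t] = gae
        advantages.append([float(adv[i]) for i in index_this])
    return {'advantage': advantages}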
@@ -64,16 +71,15 @@ class nstep_return:
         self.n = n
         self.value_function = value_function
 
-    def __call__(self, raw_data):
-        reward = raw_data['reward']
-        observation = raw_data['observation']
-
-        state_value = self.value_function.eval_value(observation)
-
-        # wrong version of return just to run
-        return_ = reward + state_value
-
-        return {'return': return_}
+    def __call__(self, buffer, index=None):
+        """
+        compute the n-step return
+        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+            each episode.
+        :return: dict with key 'return' and value the computed returns corresponding to `index`.
+        """
+        pass
 
 
 class ddpg_return:
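This one is also a stub at this point. A minimal sketch of an n-step return under the same buffer interface (again not the commit's code; gamma is an assumed discount factor, and the sketch reuses the frame-index constants and value_function.eval_value from above):

def nstep_return_sketch(buffer, value_function, n=1, gamma=0.99, index=None):
    # R_t^(n) = sum_{k=0..n-1} gamma^k * r_{t+k} + gamma^n * V(s_{t+n})
    index = index or buffer.index
    returns = []
    for i_episode, index_this in enumerate(index):
        if not index_this:
            returns.append([])
            continue
        episode = buffer.data[i_episode]
        rewards = [frame[REWARD] for frame in episode]
        values = value_function.eval_value([frame[STATE] for frame in episode])
        returns_this = []
        for t in index_this:
            end = min(t + n, len(episode))
            ret = sum(gamma ** (k - t) * rewards[k] for k in range(t, end))
            if end < len(episode):                      # bootstrap with V(s_{t+n}) when available
                ret += gamma ** n * float(values[end])
            returns_this.append(ret)
        returns.append(returns_this)
    return {'return': returns}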
@@ -85,20 +91,15 @@ class ddpg_return:
         self.critic = critic
         self.use_target_network = use_target_network
 
-    def __call__(self, raw_data):
-        observation = raw_data['observation']
-        reward = raw_data['reward']
-
-        if self.use_target_network:
-            action_target = self.actor.eval_action_old(observation)
-            value_target = self.critic.eval_value_old(observation, action_target)
-        else:
-            action_target = self.actor.eval_action(observation)
-            value_target = self.critic.eval_value(observation, action_target)
-
-        return_ = reward + value_target
-
-        return {'return': return_}
+    def __call__(self, buffer, index=None):
+        """
+        compute the DDPG-style one-step bootstrapped return
+        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+            each episode.
+        :return: dict with key 'return' and value the computed returns corresponding to `index`.
+        """
+        pass
 
 
 class nstep_q_return:
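The DDPG return is likewise stubbed. A sketch of the intended one-step target r_t + gamma * Q'(s_{t+1}, mu'(s_{t+1})) under the buffer interface, reusing the actor/critic method names from the removed code; gamma is an assumed discount factor that this commit does not define:

def ddpg_return_sketch(buffer, actor, critic, gamma=0.99, use_target_network=True, index=None):
    index = index or buffer.index
    returns = []
    for i_episode, index_this in enumerate(index):
        if not index_this:
            returns.append([])
            continue
        episode = buffer.data[i_episode]
        returns_this = []
        for t in index_this:
            reward = episode[t][REWARD]
            if episode[t][DONE] or t + 1 >= len(episode):
                returns_this.append(float(reward))      # no next state to bootstrap from
                continue
            next_state = [episode[t + 1][STATE]]        # batch of one
            if use_target_network:
                action = actor.eval_action_old(next_state)
                value = critic.eval_value_old(next_state, action)
            else:
                action = actor.eval_action(next_state)
                value = critic.eval_value(next_state, action)
            returns_this.append(reward + gamma * float(np.asarray(value).squeeze()))
        returns.append(returns_this)
    return {'return': returns}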
@@ -110,80 +111,12 @@ class nstep_q_return:
         self.action_value = action_value
         self.use_target_network = use_target_network
 
-    def __call__(self, raw_data):
-        # raw_data should contain 'next_observation' from replay memory...?
-        # maybe the main difference between Batch and Replay is the stored data format?
-        reward = raw_data['reward']
-        observation = raw_data['observation']
-
-        if self.use_target_network:
-            action_value_all_actions = self.action_value.eval_value_all_actions_old(observation)
-        else:
-            action_value_all_actions = self.action_value.eval_value_all_actions(observation)
-
-        action_value_max = np.max(action_value_all_actions, axis=1)
-
-        return_ = reward + action_value_max
-
-        return {'return': return_}
-
-
-class QLearningTarget:
-    def __init__(self, policy, gamma):
-        self._policy = policy
-        self._gamma = gamma
-
-    def __call__(self, raw_data):
-        data = dict()
-        observations = list()
-        actions = list()
-        rewards = list()
-        wi = list()
-        all_data, data_wi, data_index = raw_data
-
-        for i in range(0, all_data.shape[0]):
-            current_data = all_data[i]
-            current_wi = data_wi[i]
-            current_index = data_index[i]
-            observations.append(current_data['observation'])
-            actions.append(current_data['action'])
-            next_max_qvalue = np.max(self._policy.values(current_data['observation']))
-            current_qvalue = self._policy.values(current_data['previous_observation'])[current_data['previous_action']]
-            reward = current_data['reward'] + next_max_qvalue - current_qvalue
-            rewards.append(reward)
-            wi.append(current_wi)
-
-        data['observations'] = np.array(observations)
-        data['actions'] = np.array(actions)
-        data['rewards'] = np.array(rewards)
-
-        return data
-
-
-class ReplayMemoryQReturn:
-    """
-    compute the n-step return for Q-learning targets
-    """
-    def __init__(self, n, action_value, use_target_network=True):
-        self.n = n
-        self._action_value = action_value
-        self._use_target_network = use_target_network
-
-    def __call__(self, raw_data):
-        reward = raw_data['reward']
-        observation = raw_data['observation']
-
-        if self._use_target_network:
-            # print(observation.shape)
-            # print((observation.reshape((1,) + observation.shape)))
-            action_value_all_actions = self._action_value.eval_value_all_actions_old(observation.reshape((1,) + observation.shape))
-        else:
-            # print(observation.shape)
-            # print((observation.reshape((1,) + observation.shape)))
-            action_value_all_actions = self._action_value.eval_value_all_actions(observation.reshape((1,) + observation.shape))
-
-        action_value_max = np.max(action_value_all_actions, axis=1)
-
-        return_ = reward + action_value_max
-
-        return {'return': return_}
+    def __call__(self, buffer, index=None):
+        """
+        compute the n-step Q-learning return
+        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
+        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
+            each episode.
+        :return: dict with key 'return' and value the computed returns corresponding to `index`.
+        """
+        pass
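With the old QLearningTarget and ReplayMemoryQReturn classes dropped, nstep_q_return is the remaining stub for DQN-style targets. A sketch of the n-step max-Q target under the buffer interface (not the commit's code; gamma is an assumed discount factor, and the eval_value_all_actions(_old) names follow the removed implementation):

def nstep_q_return_sketch(buffer, action_value, n=1, gamma=0.99, use_target_network=True, index=None):
    # R_t^(n) = sum_{k=0..n-1} gamma^k * r_{t+k} + gamma^n * max_a Q(s_{t+n}, a)
    index = index or buffer.index
    returns = []
    for i_episode, index_this in enumerate(index):
        if not index_this:
            returns.append([])
            continue
        episode = buffer.data[i_episode]
        returns_this = []
        for t in index_this:
            end = min(t + n, len(episode))
            ret = sum(gamma ** (k - t) * episode[k][REWARD] for k in range(t, end))
            if end < len(episode) and not episode[end - 1][DONE]:
                state = np.asarray([episode[end][STATE]])   # batch of one
                q = (action_value.eval_value_all_actions_old(state) if use_target_network
                     else action_value.eval_value_all_actions(state))
                ret += gamma ** n * float(np.max(q))        # bootstrap with max_a Q
            returns_this.append(ret)
        returns.append(returns_this)
    return {'return': returns}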
@@ -156,11 +156,19 @@ class Batch(object):
     def next_batch(self, batch_size, standardize_advantage=True):
         rand_idx = np.random.choice(self.raw_data['observation'].shape[0], batch_size)
 
+        # maybe re-compute advantage here, but only on rand_idx
+        # but how to construct the feed_dict?
+        if self.online:
+            self.data_batch.update(self.apply_advantage_estimation_function(rand_idx))
+
         feed_dict = {}
         for key, placeholder in self.required_placeholders.items():
+            feed_dict[placeholder] = utils.get_batch(self, key, rand_idx)
+
             found, data_key = utils.internal_key_match(key, self.raw_data.keys())
             if found:
-                feed_dict[placeholder] = self.raw_data[data_key][rand_idx]
+                feed_dict[placeholder] = utils.get_batch(self.raw_data[data_key], rand_idx)  # self.raw_data[data_key][rand_idx]
             else:
                 found, data_key = utils.internal_key_match(key, self.data.keys())
                 if found:
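The change from direct fancy indexing to utils.get_batch suggests that values in raw_data may no longer all be plain numpy arrays. Purely as a hypothetical illustration of that kind of helper (this is not the project's actual utils.get_batch):

import numpy as np

def get_batch_sketch(data, indexes):
    """Hypothetical stand-in for utils.get_batch: gather rows whether `data`
    is an ndarray (fancy indexing) or a plain Python list."""
    if isinstance(data, np.ndarray):
        return data[indexes]
    return np.asarray([data[i] for i in indexes])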
tianshou/data/test_advantage_estimation.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+from advantage_estimation import *
+
+
+class ReplayBuffer(object):
+    def __init__(self):
+        self.index = [
+            [0, 1, 2],
+            [0, 1, 2, 3],
+            [0, 1],
+        ]
+        self.data = [
+            [(0, 0, 10, False), (0, 0, 1, False), (0, 0, -100, True)],
+            [(0, 0, 1, False), (0, 0, 1, False), (0, 0, 1, False), (0, 0, 5, False)],
+            [(0, 0, 9, False), (0, 0, -2, True)],
+        ]
+
+
+buffer = ReplayBuffer()
+
+sample_index = [
+    [0, 2, 0],
+    [1, 2, 1, 3],
+    [],
+]
+
+data = full_return(buffer)
+print(data['return'])
+
+data = full_return(buffer, sample_index)
+print(data['return'])
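Working the toy buffer through full_return by hand (undiscounted sums of the reward field, index 2 of each frame tuple), the two prints should produce roughly the following, along with a logged warning for episode 1, whose last frame is not terminal:

[[-89.0, -99.0, -100.0], [8.0, 7.0, 6.0, 5.0], [7.0, -2.0]]
[[-89.0, -100.0, -89.0], [7.0, 6.0, 7.0, 5.0], []]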