initial data_collector. working on examples/dqn_replay.py to run
commit 2a2274aeea
parent 54a7b1343d
@@ -16,7 +16,7 @@ class PolicyBase(object):
     """
     base class for policy. only provides `act` method with exploration
     """
-    def act(self, observation):
+    def act(self, observation, my_feed_dict):
         raise NotImplementedError()


@@ -9,7 +9,7 @@ class DQN(PolicyBase):
     """
     use DQN from value_function as a member
     """
-    def __init__(self, dqn):
+    def __init__(self, dqn, epsilon_train=0.1, epsilon_test=0.05):
         self.action_value = dqn
         self._argmax_action = tf.argmax(dqn.value_tensor_all_actions, axis=1)
         self.weight_update = dqn.weight_update
@@ -18,20 +18,29 @@ class DQN(PolicyBase):
         else:
             self.interaction_count = -1

-    def act(self, observation, my_feed_dict):
+        self.epsilon_train = epsilon_train
+        self.epsilon_test = epsilon_test
+
+    def act(self, observation, my_feed_dict={}):
         sess = tf.get_default_session()
         if self.weight_update > 1:
             if self.interaction_count % self.weight_update == 0:
                 self.update_weights()

         feed_dict = {self.action_value._observation_placeholder: observation[None]}
+        feed_dict.update(my_feed_dict)
         action = sess.run(self._argmax_action, feed_dict=feed_dict)
+        if np.random.rand() < self.epsilon_train:
+            pass

         if self.weight_update > 0:
             self.interaction_count += 1

         return np.squeeze(action)

+    def act_test(self, observation, my_feed_dict={}):
+        pass
+
     @property
     def q_net(self):
         return self.action_value
@@ -51,3 +60,9 @@ class DQN(PolicyBase):
         """
         if self.action_value.weight_update_ops is not None:
             self.action_value.update_weights()
+
+    def set_epsilon_train(self, epsilon):
+        self.epsilon_train = epsilon
+
+    def set_epsilon_test(self, epsilon):
+        self.epsilon_test = epsilon
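The exploration branch in `act` and the body of `act_test` are left as `pass` in this commit. As a rough sketch of where the epsilon parameters appear to be heading, the branch could fall back to a uniformly random action; note that `num_actions` below is an assumed attribute of the wrapped value network, not something this diff defines:

    # Hedged sketch only: `num_actions` is assumed to exist on the DQN value
    # network; the diff itself leaves this branch as `pass`.
    def act(self, observation, my_feed_dict={}):
        sess = tf.get_default_session()
        feed_dict = {self.action_value._observation_placeholder: observation[None]}
        feed_dict.update(my_feed_dict)
        action = sess.run(self._argmax_action, feed_dict=feed_dict)
        if np.random.rand() < self.epsilon_train:
            # replace the greedy action with a random one for exploration
            action = np.random.randint(self.action_value.num_actions)
        return np.squeeze(action)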
@@ -101,7 +101,7 @@ class ddpg_return:
     pass


-class ReplayMemoryQReturn:
+class nstep_q_return:
     """
     compute the n-step return for Q-learning targets
     """
@@ -111,7 +111,7 @@ class ReplayMemoryQReturn:
         self.use_target_network = use_target_network

     # TODO : we should transfer the tf -> numpy/python -> tf into a monolithic compute graph in tf
-    def __call__(self, buffer, indexes =None):
+    def __call__(self, buffer, index=None):
         """
         :param buffer: buffer with property index and data. index determines the current content in `buffer`.
         :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
@@ -119,7 +119,7 @@ class ReplayMemoryQReturn:
         :return: dict with key 'return' and value the computed returns corresponding to `index`.
         """
         qvalue = self.action_value._value_tensor_all_actions
-        indexes = indexes or buffer.index
+        index = index or buffer.index
         episodes = buffer.data
         discount_factor = 0.99
         returns = []
@@ -128,13 +128,11 @@ class ReplayMemoryQReturn:
         config.gpu_options.allow_growth = True
         with tf.Session(config=config) as sess:
             sess.run(tf.global_variables_initializer())
-            for episode_index in range(len(indexes)):
-                index = indexes[episode_index]
+            for episode_index in range(len(index)):
+                index = index[episode_index]
                 if index:
                     episode = episodes[episode_index]
                     episode_q = []
-                    if not episode[-1][DONE]:
-                        logging.warning('Computing Q return on episode {} with no terminal state.'.format(episode_index))

                     for i in index:
                         current_discount_factor = 1
@@ -155,4 +153,4 @@ class ReplayMemoryQReturn:
                     returns.append(episode_q)
                 else:
                     returns.append([])
-        return {'TD-lambda': returns}
+        return {'return': returns}
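For orientation, the quantity `nstep_q_return` assembles per frame is the usual bootstrapped Q-learning target. A minimal 1-step version in plain numpy, with the target-network Q-values passed in as a precomputed array (the function name and signature are illustrative, not part of this commit), looks like:

    import numpy as np

    def one_step_q_target(rewards, dones, q_next, discount_factor=0.99):
        """rewards, dones: per-frame lists/arrays for one episode.
        q_next: array of shape [T, num_actions] holding Q(s_{t+1}, .)."""
        targets = []
        for t in range(len(rewards)):
            # no bootstrap past a terminal frame
            bootstrap = 0. if dones[t] else discount_factor * np.max(q_next[t])
            targets.append(rewards[t] + bootstrap)
        return targets

For example, one_step_q_target([1., 1.], [False, True], np.zeros((2, 4))) evaluates to [1.0, 1.0].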
@@ -1,3 +1,7 @@
+import numpy as np
+import logging
+import itertools
+
 from .replay_buffer.base import ReplayBufferBase

 class DataCollector(object):
@@ -11,30 +15,28 @@ class DataCollector(object):
         self.process_functions = process_functions
         self.managed_networks = managed_networks

+        self.data = {}
+        self.data_batch = {}
+
         self.required_placeholders = {}
         for net in self.managed_networks:
             self.required_placeholders.update(net.managed_placeholders)
         self.require_advantage = 'advantage' in self.required_placeholders.keys()

         if isinstance(self.data_buffer, ReplayBufferBase): # process when sampling minibatch
-            self.process_mode = 'minibatch'
+            self.process_mode = 'sample'
         else:
-            self.process_mode = 'batch'
+            self.process_mode = 'full'

         self.current_observation = self.env.reset()

-    def collect(self, num_timesteps=1, num_episodes=0, exploration=None, my_feed_dict={}):
+    def collect(self, num_timesteps=1, num_episodes=0, my_feed_dict={}):
         assert sum([num_timesteps > 0, num_episodes > 0]) == 1,\
             "One and only one collection number specification permitted!"

         if num_timesteps > 0:
             for _ in range(num_timesteps):
-                action_vanilla = self.policy.act(self.current_observation, my_feed_dict=my_feed_dict)
-                if exploration:
-                    action = exploration(action_vanilla)
-                else:
-                    action = action_vanilla
-
+                action = self.policy.act(self.current_observation, my_feed_dict=my_feed_dict)
                 next_observation, reward, done, _ = self.env.step(action)
                 self.data_buffer.add((self.current_observation, action, reward, done))
                 self.current_observation = next_observation
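Both `nstep_q_return` and `DataCollector` only touch a handful of attributes on the buffer, and what they appear to assume is rank-2, episode-major storage of (observation, action, reward, done) frames. The stand-in below is purely illustrative (the class name and the `sample` strategy are not part of the repository):

    class IllustrativeReplayBuffer(object):
        """Stand-in showing the interface DataCollector appears to rely on."""
        def __init__(self):
            # rank-2, episode-major: one inner list of
            # (observation, action, reward, done) frames per episode
            self.data = [[]]
            # rank-2 list of valid frame indexes, aligned with self.data
            self.index = [[]]

        def add(self, frame):
            self.data[-1].append(frame)
            self.index[-1].append(len(self.data[-1]) - 1)
            if frame[3]:  # done flag: start a new episode
                self.data.append([])
                self.index.append([])

        def sample(self, batch_size):
            # should return a rank-2 list of sampled frame indexes, aligned
            # with self.data; the concrete strategy lives in the real buffers
            raise NotImplementedError()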
@@ -44,24 +46,56 @@ class DataCollector(object):
             observation = self.env.reset()
             done = False
             while not done:
-                action_vanilla = self.policy.act(observation, my_feed_dict=my_feed_dict)
-                if exploration:
-                    action = exploration(action_vanilla)
-                else:
-                    action = action_vanilla
-
+                action = self.policy.act(observation, my_feed_dict=my_feed_dict)
                 next_observation, reward, done, _ = self.env.step(action)
                 self.data_buffer.add((observation, action, reward, done))
                 observation = next_observation

-    def next_batch(self, batch_size):
+        if self.process_mode == 'full':
+            for processor in self.process_functions:
+                self.data.update(processor(self.data_buffer))
+
+    def next_batch(self, batch_size, standardize_advantage=True):
         sampled_index = self.data_buffer.sample(batch_size)
-        if self.process_mode == 'minibatch':
-            pass
+        if self.process_mode == 'sample':
+            for processor in self.process_functions:
+                self.data_batch.update(processor(self.data_buffer, index=sampled_index))

         # flatten rank-2 list to numpy array, construct feed_dict
+        feed_dict = {}
+        frame_key_map = {'observation': 0, 'action': 1, 'reward': 2, 'done_flag': 3}
+        for key, placeholder in self.required_placeholders.items():
+            if key in frame_key_map.keys(): # access raw_data
+                frame_index = frame_key_map[key]
+                flattened = []
+                for index_episode, data_episode in zip(sampled_index, self.data_buffer.data):
+                    for i in index_episode:
+                        flattened.append(data_episode[i][frame_index])
+                feed_dict[placeholder] = np.array(flattened)
+            elif key in self.data_batch.keys(): # access processed minibatch data
+                flattened = list(itertools.chain.from_iterable(self.data_batch[key]))
+                feed_dict[placeholder] = np.array(flattened)
+            elif key in self.data.keys(): # access processed full data
+                flattened = [0.] * batch_size # float
+                i_in_batch = 0
+                for index_episode, data_episode in zip(sampled_index, self.data[key]):
+                    for i in index_episode:
+                        flattened[i_in_batch] = data_episode[i]
+                        i_in_batch += 1
+                feed_dict[placeholder] = np.array(flattened)
+            else:
+                raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))

-        return
+        if standardize_advantage:
+            if self.require_advantage:
+                advantage_value = feed_dict[self.required_placeholders['advantage']]
+                advantage_mean = np.mean(advantage_value)
+                advantage_std = np.std(advantage_value)
+                if advantage_std < 1e-3:
+                    logging.warning('advantage_std too small (< 1e-3) for advantage standardization. may cause numerical issues')
+                feed_dict[self.required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std
+
+        return feed_dict

     def statistics(self):
         pass
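Tying the pieces together, the loop that examples/dqn_replay.py is presumably working toward would look roughly like the sketch below. Everything outside the classes shown in this diff (the Gym environment name, the value-network and buffer constructors, the `nstep_q_return` arguments, the `DataCollector` argument order, and `train_op`) is an assumption for illustration only:

    import gym
    import tensorflow as tf

    env = gym.make('CartPole-v0')                    # assumed environment
    dqn_value_net = build_dqn_network(env)           # placeholder constructor
    pi = DQN(dqn_value_net, epsilon_train=0.1, epsilon_test=0.05)
    buffer = SomeReplayBuffer()                      # placeholder ReplayBufferBase subclass
    processors = [nstep_q_return(1, dqn_value_net)]  # constructor arguments assumed
    collector = DataCollector(env, pi, buffer, processors, [dqn_value_net])  # argument order assumed

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(10000):
            collector.collect(num_timesteps=1)
            feed_dict = collector.next_batch(batch_size=32)
            sess.run(train_op, feed_dict=feed_dict)  # train_op is a placeholder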