import numpy as np
import logging
import itertools

from .data_buffer.replay_buffer_base import ReplayBufferBase
from .data_buffer.batch_set import BatchSet
from .utils import internal_key_match
from ..core.policy.deterministic import Deterministic

__all__ = [
    'DataCollector',
]


class DataCollector(object):
    """
    A utility class to manage the data flow during the interaction between the policy and the
    environment. It stores data into ``data_buffer``, processes the reward signals, and returns the
    feed_dict for running the TensorFlow graph.

    :param env: An environment.
    :param policy: A :class:`tianshou.core.policy`.
    :param data_buffer: A :class:`tianshou.data.data_buffer`.
    :param process_functions: A list of callables in :mod:`tianshou.data.advantage_estimation`
        to process rewards.
    :param managed_networks: A list of networks of :class:`tianshou.core.policy` and/or
        :class:`tianshou.core.value_function`. The networks you want this class to manage. This class
        will automatically generate the feed_dict for all the placeholders in the
        ``managed_placeholders`` of all networks in this list.
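
    Example (a minimal sketch; ``my_env``, ``my_policy``, ``my_value_net``, ``my_data_buffer`` and
    ``my_advantage_estimator`` are illustrative placeholders, not names defined by this module)::

        collector = DataCollector(
            env=my_env,                                  # a gym-style environment
            policy=my_policy,                            # a tianshou.core.policy instance
            data_buffer=my_data_buffer,                  # a tianshou.data.data_buffer instance
            process_functions=[my_advantage_estimator],  # turns raw rewards into returns/advantages
            managed_networks=[my_policy, my_value_net],  # networks whose placeholders get fed
        )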
    """
    def __init__(self, env, policy, data_buffer, process_functions, managed_networks):
        self.env = env
        self.policy = policy
        self.data_buffer = data_buffer
        self.process_functions = process_functions
        self.managed_networks = managed_networks

        self.data = {}
        self.data_batch = {}

        self.required_placeholders = {}
        for net in self.managed_networks:
            self.required_placeholders.update(net.managed_placeholders)
        self.require_advantage = 'advantage' in self.required_placeholders.keys()

        if isinstance(self.data_buffer, ReplayBufferBase):  # process when sampling minibatch
            self.process_mode = 'sample'
        else:
            self.process_mode = 'full'

        self.current_observation = self.env.reset()
        self.step_count_this_episode = 0

    def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={}, auto_clear=True, episode_cutoff=None):
        """
        Collect data in the environment using ``self.policy``.

        :param num_timesteps: An int specifying the number of timesteps to act. It defaults to 0, and
            exactly one of ``num_timesteps`` and ``num_episodes`` must be set.
        :param num_episodes: An int specifying the number of episodes to act. It defaults to 0, and
            exactly one of ``num_timesteps`` and ``num_episodes`` must be set.
        :param my_feed_dict: Optional. A dict defaulting to empty.
            Specifies placeholders such as dropout and batch_norm except observation and action.
        :param auto_clear: Optional. A bool defaulting to ``True``. If ``True``, this method clears
            ``self.data_buffer`` when it is an instance of
            :class:`tianshou.data.data_buffer.BatchSet`, and does nothing otherwise.
            If ``False``, this auto-clearing behavior is disabled.
        :param episode_cutoff: Optional. An int. The maximum number of timesteps in one episode. This is
            useful when the environment has no terminal states or a single episode could be prohibitively
            long. If set, all episodes are forced to stop after this number of timesteps.
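
        Example (a minimal sketch; the numbers are arbitrary)::

            # on-policy style with a BatchSet: clear it and gather a fresh batch of episodes
            collector.collect(num_episodes=8, episode_cutoff=1000)

            # off-policy style with a replay buffer: append one timestep per call, no clearing
            collector.collect(num_timesteps=1)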
        """
        assert sum([num_timesteps > 0, num_episodes > 0]) == 1, \
            "One and only one collection number specification permitted!"

        if isinstance(self.data_buffer, BatchSet) and auto_clear:
            self.data_buffer.clear()

        if num_timesteps > 0:
            num_timesteps_ = int(num_timesteps)
            for _ in range(num_timesteps_):
                action = self.policy.act(self.current_observation, my_feed_dict=my_feed_dict)
                next_observation, reward, done, _ = self.env.step(action)
                self.step_count_this_episode += 1
                if episode_cutoff and self.step_count_this_episode >= episode_cutoff:
                    done = True
                self.data_buffer.add((self.current_observation, action, reward, done))

                if done:
                    self.current_observation = self.env.reset()
                    self.policy.reset()
                    self.step_count_this_episode = 0
                else:
                    self.current_observation = next_observation

        if num_episodes > 0:
            num_episodes_ = int(num_episodes)
            for _ in range(num_episodes_):
                observation = self.env.reset()
                self.policy.reset()
                done = False
                step_count = 0
                while not done:
                    action = self.policy.act(observation, my_feed_dict=my_feed_dict)
                    next_observation, reward, done, _ = self.env.step(action)
                    step_count += 1

                    if episode_cutoff and step_count >= episode_cutoff:
                        done = True

                    self.data_buffer.add((observation, action, reward, done))
                    observation = next_observation

            self.current_observation = self.env.reset()

        if self.process_mode == 'full':
            for processor in self.process_functions:
                self.data.update(processor(self.data_buffer))

        return

    def next_batch(self, batch_size, standardize_advantage=True, my_feed_dict={}):
        """
        Constructs and returns the feed_dict of data to be used with ``sess.run``.

        :param batch_size: An int. The size of one minibatch.
        :param standardize_advantage: Optional. A bool defaulting to ``True``.
            If ``True``, this method standardizes advantages when advantage is required by the networks.
            If ``False``, advantages are never standardized.
        :param my_feed_dict: Optional. A dict defaulting to empty. Passed through to the
            reward-processing functions when ``self.process_mode`` is ``'sample'``.

        :return: A dict in the format of a conventional TensorFlow feed_dict, with placeholders as keys
            and numpy arrays as values.
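
        Example (a minimal sketch; ``sess`` and ``train_op`` are assumed to come from the
        surrounding training script)::

            feed_dict = collector.next_batch(batch_size=64)
            sess.run(train_op, feed_dict=feed_dict)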
        """
        sampled_index = self.data_buffer.sample(batch_size)
        if self.process_mode == 'sample':
            for processor in self.process_functions:
                self.data_batch.update(processor(self.data_buffer, indexes=sampled_index, my_feed_dict=my_feed_dict))

        # flatten rank-2 list to numpy array, construct feed_dict
        feed_dict = {}
        frame_key_map = {'observation': 0, 'action': 1, 'reward': 2, 'done_flag': 3}
        for key, placeholder in self.required_placeholders.items():
            # check raw data first
            found, matched_key = internal_key_match(key, frame_key_map.keys())
            if found:
                frame_index = frame_key_map[matched_key]
                flattened = []
                for index_episode, data_episode in zip(sampled_index, self.data_buffer.data):
                    for i in index_episode:
                        flattened.append(data_episode[i][frame_index])
                feed_dict[placeholder] = np.array(flattened)
            else:
                # then check processed minibatch data
                found, matched_key = internal_key_match(key, self.data_batch.keys())
                if found:
                    flattened = list(itertools.chain.from_iterable(self.data_batch[matched_key]))
                    feed_dict[placeholder] = np.array(flattened)
                else:
                    # finally check processed full data
                    found, matched_key = internal_key_match(key, self.data.keys())
                    if found:
                        flattened = [0.] * batch_size  # float
                        i_in_batch = 0
                        for index_episode, data_episode in zip(sampled_index, self.data[matched_key]):
                            for i in index_episode:
                                flattened[i_in_batch] = data_episode[i]
                                i_in_batch += 1
                        feed_dict[placeholder] = np.array(flattened)
                    else:
                        raise TypeError('Placeholder {} has no value to feed!'.format(str(placeholder.name)))

        if standardize_advantage:
            if self.require_advantage:
                # standardize advantages to zero mean and unit standard deviation over the minibatch
                advantage_value = feed_dict[self.required_placeholders['advantage']]
                advantage_mean = np.mean(advantage_value)
                advantage_std = np.std(advantage_value)
                if advantage_std < 1e-3:
                    logging.warning('advantage_std too small (< 1e-3) for advantage standardization. may cause numerical issues')
                feed_dict[self.required_placeholders['advantage']] = (advantage_value - advantage_mean) / advantage_std

        return feed_dict

    def denoise_action(self, feed_dict, my_feed_dict={}):
        """
        Recompute the actions of deterministic policies without exploration noise, hence denoising.
        It modifies ``feed_dict`` **in place** and has no return value.

        This is useful in, e.g., DDPG, since the action stored in ``self.data_buffer`` is the sampled
        action with additional exploration noise.

        :param feed_dict: A dict. It has to be the dict returned by :func:`next_batch` of this class.
        :param my_feed_dict: Optional. A dict defaulting to empty.
            Specifies placeholders such as dropout and batch_norm except observation and action.
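
        Example (a minimal sketch; ``sess`` and ``my_train_op`` are assumed to come from the
        surrounding training script)::

            feed_dict = collector.next_batch(batch_size=64)
            collector.denoise_action(feed_dict)   # replaces the noisy stored actions in place
            sess.run(my_train_op, feed_dict=feed_dict)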
        """
        assert isinstance(self.policy, Deterministic), 'denoise_action() could only be called ' \
                                                       'with deterministic policies'
        observation = feed_dict[self.required_placeholders['observation']]
        action_mean = self.policy.eval_action(observation, my_feed_dict)
        feed_dict[self.required_placeholders['action']] = action_mean

        return