import gym
import time
import torch
import warnings
import numpy as np
from typing import Any, Dict, List, Union, Optional, Callable

from tianshou.utils import MovAvg
from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy
from tianshou.exploration import BaseNoise
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy


class Collector(object):
    """The :class:`~tianshou.data.Collector` enables the policy to interact
    with different types of environments conveniently.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class, or a list of :class:`~tianshou.data.ReplayBuffer`. If set to
        ``None``, it will automatically assign a small-size
        :class:`~tianshou.data.ReplayBuffer`.
    :param function preprocess_fn: a function called before the data is added
        to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults to
        ``None``.
    :param int stat_size: for the moving average of recording speed, defaults
        to 100.
    :param BaseNoise action_noise: add noise to the continuous action.
        Normally a policy already has its own noise parameter for exploration
        during training, so this option is mainly intended for a test
        collector.
    :param function reward_metric: to be used in multi-agent RL. The reward to
        report is of shape [agent_num], but a single scalar is needed to
        monitor training. This function specifies the desired metric, e.g.,
        the reward of agent 1 or the average reward over all agents. By
        default, the reward of agent 1 is selected.
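
    For example, in a multi-agent setting one may report only the first
    agent's reward; a minimal sketch (the function name is illustrative)::

        def first_agent_reward(rewards: np.ndarray) -> float:
            # ``rewards`` has shape [agent_num]; monitor the first agent only
            return float(rewards[0])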

    The ``preprocess_fn`` is a function called before the data is added to
    the buffer with batch format, which receives up to 7 keys as listed in
    :class:`~tianshou.data.Batch`. It receives only ``obs`` when the collector
    resets the environment. It returns either a dict or a
    :class:`~tianshou.data.Batch` with the modified keys and values. Examples
    are in "test/base/test_collector.py"; a minimal sketch is shown below.
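
    For instance, a ``preprocess_fn`` that rescales the reward might look
    like the following minimal sketch (the function name is illustrative)::

        def preprocess_fn(**kwargs):
            # only ``obs`` is passed on reset, so there is nothing to change
            if 'rew' in kwargs:
                return Batch(rew=kwargs['rew'] / 10.0)
            return Batch()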

    Example:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)
        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
        buffers = [ReplayBuffer(size=5000) for _ in range(3)]
        # you can also pass a list of replay buffers to the collector,
        # one per environment
        # collector = Collector(policy, envs, buffer=buffers)
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect at least 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        # sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

        # sample data with a given batch size
        batch_data = collector.sample(batch_size=64)
        # policy.learn(batch_data)  # btw, vanilla policy gradient only
        # supports on-policy training, so here we pick all data in the buffer
        batch_data = collector.sample(batch_size=0)
        policy.learn(batch_data)
        # on-policy algorithms use the collected data only once, so here we
        # clear the buffer
        collector.reset_buffer()

    When collecting data from multiple environments into a single buffer, the
    cache buffers are enabled automatically, and the collector may return more
    data than the requested amount.

    .. note::

        Please make sure the given environment has a time limit.
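
        For instance, a raw ``gym`` environment can be wrapped with
        ``gym.wrappers.TimeLimit`` (a minimal sketch; ``MyEnv`` is
        illustrative)::

            env = gym.wrappers.TimeLimit(MyEnv(), max_episode_steps=1000)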
    """

    def __init__(self,
                 policy: BasePolicy,
                 env: Union[gym.Env, BaseVectorEnv],
                 buffer: Optional[ReplayBuffer] = None,
                 preprocess_fn: Callable[[Any], Union[dict, Batch]] = None,
                 stat_size: Optional[int] = 100,
                 action_noise: Optional[BaseNoise] = None,
                 reward_metric: Optional[Callable[[np.ndarray], float]] = None,
                 ) -> None:
        super().__init__()
        self.env = env
        self.env_num = 1
        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
        self.buffer = buffer
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            self._cached_buf = [ListReplayBuffer()
                                for _ in range(self.env_num)]
        self.stat_size = stat_size
        self._action_noise = action_noise

        self._rew_metric = reward_metric or Collector._default_rew_metric
        self.reset()

    @staticmethod
    def _default_rew_metric(x):
        # this internal function is designed for single-agent RL
        # for multi-agent RL, a reward_metric must be provided
        assert np.asanyarray(x).size == 1, \
            'Please specify the reward_metric ' \
            'since the reward is not a scalar.'
        return x

    def reset(self) -> None:
        """Reset all related variables in the collector."""
        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
                          obs_next={}, policy={})
        self.reset_env()
        self.reset_buffer()
        self.step_speed = MovAvg(self.stat_size)
        self.episode_speed = MovAvg(self.stat_size)
        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
        if self._action_noise is not None:
            self._action_noise.reset()

    def reset_buffer(self) -> None:
        """Reset the main data buffer."""
        if self.buffer is not None:
            self.buffer.reset()

    def get_env_num(self) -> int:
        """Return the number of environments the collector has."""
        return self.env_num

    def reset_env(self) -> None:
        """Reset all of the environment(s)' states and reset all of the
        cache buffers (if any).
        """
        obs = self.env.reset()
        if not self._multi_env:
            obs = self._make_batch(obs)
        if self.preprocess_fn:
            obs = self.preprocess_fn(obs=obs).get('obs', obs)
        self.data.obs = obs
        self.reward = 0.  # will be specified when the first data is ready
        self.length = np.zeros(self.env_num)
        for b in self._cached_buf:
            b.reset()

    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
        """Reset all the seed(s) of the given environment(s)."""
        return self.env.seed(seed)

    def render(self, **kwargs) -> None:
        """Render all the environment(s)."""
        return self.env.render(**kwargs)

    def close(self) -> None:
        """Close the environment(s)."""
        self.env.close()

    def _make_batch(self, data: Any) -> np.ndarray:
        """Return [data]."""
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return np.array([data])

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset self.data.state[id]."""
        state = self.data.state  # it is a reference
        if isinstance(state, torch.Tensor):
            state[id].zero_()
        elif isinstance(state, np.ndarray):
            state[id] = None if state.dtype == np.object else 0
        elif isinstance(state, Batch):
            state.empty_(id)

    def collect(self,
                n_step: int = 0,
                n_episode: Union[int, List[int]] = 0,
                random: bool = False,
                render: Optional[float] = None,
                log_fn: Optional[Callable[[dict], None]] = None
                ) -> Dict[str, float]:
        """Collect a specified number of steps or episodes.

        :param int n_step: how many steps you want to collect.
        :param n_episode: how many episodes you want to collect (in each
            environment).
        :type n_episode: int or list
        :param bool random: whether to use random policy for collecting data,
            defaults to ``False``.
        :param float render: the sleep time between rendering consecutive
            frames, defaults to ``None`` (no rendering).
        :param function log_fn: a function which receives env info, typically
            for tensorboard logging.

        .. note::

            One and only one collection number specification is permitted,
            either ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` the collected number of episodes.
            * ``n/st`` the collected number of steps.
            * ``v/st`` the speed of steps per second.
            * ``v/ep`` the speed of episodes per second.
            * ``rew`` the mean reward over collected episodes.
            * ``len`` the mean length over collected episodes.
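
        For example, a collect call in a training loop might be used as
        follows (a minimal sketch; ``collector`` is assumed to exist)::

            result = collector.collect(n_step=100)
            assert result['n/st'] >= 100  # at least 100 transitions collected
            mean_reward = result['rew']   # mean reward of finished episodes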
        """
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
            "One and only one collection number specification is permitted!"
        cur_step, cur_episode = 0, np.zeros(self.env_num)
        reward_sum, length_sum = 0., 0
        while True:
            if cur_step >= 100000 and cur_episode.sum() == 0:
                warnings.warn(
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)

            # restore the state and the input data
            last_state = self.data.state
            if last_state.is_empty():
                last_state = None
            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())

            # calculate the next action
            if random:
                action_space = self.env.action_space
                if isinstance(action_space, list):
                    result = Batch(act=[a.sample() for a in action_space])
                else:
                    result = Batch(
                        act=self._make_batch(action_space.sample()))
            else:
                with torch.no_grad():
                    result = self.policy(self.data, last_state)

            # convert None to Batch(), since None is reserved for 0-init
            state = result.get('state', Batch())
            if state is None:
                state = Batch()
            self.data.state = state
            if hasattr(result, 'policy'):
                self.data.policy = to_numpy(result.policy)
            # save hidden state to policy._state, in order to save into buffer
            self.data.policy._state = self.data.state

            self.data.act = to_numpy(result.act)
            if self._action_noise is not None:
                self.data.act += self._action_noise(self.data.act.shape)

            # step in env
            obs_next, rew, done, info = self.env.step(
                self.data.act if self._multi_env else self.data.act[0])

            # move data to self.data
            if not self._multi_env:
                obs_next = self._make_batch(obs_next)
                rew = self._make_batch(rew)
                done = self._make_batch(done)
                info = self._make_batch(info)
            self.data.obs_next = obs_next
            self.data.rew = rew
            self.data.done = done
            self.data.info = info

            if log_fn:
                log_fn(info if self._multi_env else info[0])
            if render:
                self.render()
                if render > 0:
                    time.sleep(render)

            # add data into the buffer
            self.length += 1
            self.reward += self.data.rew
            if self.preprocess_fn:
                result = self.preprocess_fn(**self.data)
                self.data.update(result)
            if self._multi_env:  # cache_buffer branch
                for i in range(self.env_num):
                    self._cached_buf[i].add(**self.data[i])
                    if self.data.done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                if self.buffer is not None:
                                    self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0., 0
                        if self._cached_buf:
                            self._cached_buf[i].reset()
                        self._reset_state(i)
                obs_next = self.data.obs_next
                if sum(self.data.done):
                    env_ind = np.where(self.data.done)[0]
                    obs_reset = self.env.reset(env_ind)
                    if self.preprocess_fn:
                        obs_next[env_ind] = self.preprocess_fn(
                            obs=obs_reset).get('obs', obs_reset)
                    else:
                        obs_next[env_ind] = obs_reset
                self.data.obs_next = obs_next
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:  # single buffer, without cache_buffer
                if self.buffer is not None:
                    self.buffer.add(**self.data[0])
                cur_step += 1
                if self.data.done[0]:
                    cur_episode += 1
                    reward_sum += self.reward[0]
                    length_sum += self.length[0]
                    self.reward, self.length = 0., np.zeros(self.env_num)
                    self.data.state = Batch()
                    obs_next = self._make_batch(self.env.reset())
                    if self.preprocess_fn:
                        obs_next = self.preprocess_fn(obs=obs_next).get(
                            'obs', obs_next)
                    self.data.obs_next = obs_next
                if n_episode != 0 and cur_episode >= n_episode:
                    break
            if n_step != 0 and cur_step >= n_step:
                break
            self.data.obs = self.data.obs_next
        self.data.obs = self.data.obs_next

        # generate the statistics
        cur_episode = sum(cur_episode)
        duration = max(time.time() - start_time, 1e-9)
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        reward_sum /= n_episode
        if np.asanyarray(reward_sum).size > 1:  # non-scalar reward_sum
            reward_sum = self._rew_metric(reward_sum)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size: int) -> Batch:
        """Sample a data batch from the internal replay buffer. It will call
        :meth:`~tianshou.policy.BasePolicy.process_fn` before returning
        the final batch data.

        :param int batch_size: ``0`` means it will extract all the data from
            the buffer, otherwise it will extract the data with the given
            batch_size.
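
        Example (a minimal sketch, assuming the collector already holds some
        collected data)::

            batch = collector.sample(batch_size=64)  # 64 transitions
            batch_all = collector.sample(batch_size=0)  # the whole buffer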
        """
        batch_data, indice = self.buffer.sample(batch_size)
        batch_data = self.process_fn(batch_data, self.buffer, indice)
        return batch_data