import gym
import time
import torch
import warnings
import numpy as np
from copy import deepcopy
from numbers import Number
from typing import Dict, List, Union, Optional, Callable

from tianshou.policy import BasePolicy
from tianshou.exploration import BaseNoise
from tianshou.data.batch import _create_value
from tianshou.env import BaseVectorEnv, DummyVectorEnv
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy


class Collector(object):
    """Collector enables the policy to interact with different types of envs.

    :param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
        class.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer`
        class. If set to ``None`` (testing phase), it will not store the data.
    :param function preprocess_fn: a function called before the data is added
        to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults to
        None.
    :param BaseNoise action_noise: add noise to the continuous action.
        Normally a policy already has a noise parameter for exploration during
        training, so this option is mainly intended for a test collector.
    :param function reward_metric: to be used in multi-agent RL. The reward to
        report is of shape [agent_num], but a single scalar is needed to
        monitor training. This function specifies the desired metric, e.g.,
        the reward of agent 1 or the average reward over all agents. By
        default, the reward of agent 1 is selected.
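
    For multi-agent RL, a minimal sketch of a ``reward_metric`` (purely
    illustrative; it assumes the per-step reward has shape [agent_num]) is:
    ::

        def reward_metric(rew: np.ndarray) -> float:
            # report the average reward over all agents
            return float(np.mean(rew))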

    The ``preprocess_fn`` is a function called, in batch format, before the
    data is added to the buffer. It receives up to 7 keys, as listed in
    :class:`~tianshou.data.Batch`, and receives only ``obs`` when the
    collector resets the environment. It returns either a dict or a
    :class:`~tianshou.data.Batch` with the modified keys and values. Examples
    are in "test/base/test_collector.py".
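
    For instance, a minimal sketch of a reward-clipping ``preprocess_fn``
    (hypothetical, written only for illustration) could be:
    ::

        def preprocess_fn(**kwargs):
            # on environment reset only ``obs`` is passed in; return an
            # empty Batch to keep the observation unchanged
            if "rew" not in kwargs:
                return Batch()
            # otherwise clip the reward before it is stored in the buffer
            return Batch(rew=np.clip(kwargs["rew"], -1.0, 1.0))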

    Here is an example:
    ::

        policy = PGPolicy(...)  # or other policies if you wish
        env = gym.make('CartPole-v0')
        replay_buffer = ReplayBuffer(size=10000)
        # here we set up a collector with a single environment
        collector = Collector(policy, env, buffer=replay_buffer)

        # the collector supports vectorized environments as well
        envs = DummyVectorEnv([lambda: gym.make('CartPole-v0')
                               for _ in range(3)])
        collector = Collector(policy, envs, buffer=replay_buffer)

        # collect 3 episodes
        collector.collect(n_episode=3)
        # collect 1 episode for the first env, 3 for the third env
        collector.collect(n_episode=[1, 0, 3])
        # collect at least 2 steps
        collector.collect(n_step=2)
        # collect episodes with visual rendering (the render argument is the
        # sleep time between rendering consecutive frames)
        collector.collect(n_episode=1, render=0.03)

    Collected data always consists of full episodes. So if only the ``n_step``
    argument is given, the collector may return more data than the ``n_step``
    limitation. The same holds for ``n_episode`` in the multiple-environment
    case.

    .. note::

        Please make sure the given environment has a time limitation, for
        example by wrapping it as shown below.
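
        One option is gym's ``TimeLimit`` wrapper (the step limit here is
        only an illustrative value):
        ::

            from gym.wrappers import TimeLimit
            env = TimeLimit(gym.make('CartPole-v0'), max_episode_steps=200)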
"""
|
2020-03-13 17:49:22 +08:00
|
|
|
|
2020-09-12 15:39:01 +08:00
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
policy: BasePolicy,
|
|
|
|
env: Union[gym.Env, BaseVectorEnv],
|
|
|
|
buffer: Optional[ReplayBuffer] = None,
|
|
|
|
preprocess_fn: Optional[Callable[..., Batch]] = None,
|
|
|
|
action_noise: Optional[BaseNoise] = None,
|
|
|
|
reward_metric: Optional[Callable[[np.ndarray], float]] = None,
|
|
|
|
) -> None:
        super().__init__()
        if not isinstance(env, BaseVectorEnv):
            env = DummyVectorEnv([lambda: env])
        self.env = env
        self.env_num = len(env)
        # environments that are available in step()
        # this means all environments in synchronous simulation
        # but only a subset of environments in asynchronous simulation
        self._ready_env_ids = np.arange(self.env_num)
        # self.is_async is a flag to indicate whether this collector works
        # with asynchronous simulation
        self.is_async = env.is_async
        # need cache buffers before storing in the main buffer
        self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
        self.buffer = buffer
        self.policy = policy
        self.preprocess_fn = preprocess_fn
        self.process_fn = policy.process_fn
        self._action_space = env.action_space
        self._action_noise = action_noise
        self._rew_metric = reward_metric or Collector._default_rew_metric
        # avoid creating attributes outside __init__
        self.reset()

    @staticmethod
    def _default_rew_metric(
        x: Union[Number, np.number]
    ) -> Union[Number, np.number]:
        # this internal function is designed for single-agent RL
        # for multi-agent RL, a reward_metric must be provided
        assert np.asanyarray(x).size == 1, (
            "Please specify the reward_metric "
            "since the reward is not a scalar."
        )
        return x

    def reset(self) -> None:
        """Reset all related variables in the collector."""
        # use empty Batch for ``state`` so that ``self.data`` supports slicing
        # convert empty Batch to None when passing data to policy
        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
                          obs_next={}, policy={})
        self.reset_env()
        self.reset_buffer()
        self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
        if self._action_noise is not None:
            self._action_noise.reset()

    def reset_buffer(self) -> None:
        """Reset the main data buffer."""
        if self.buffer is not None:
            self.buffer.reset()

    def get_env_num(self) -> int:
        """Return the number of environments the collector has."""
        return self.env_num

    def reset_env(self) -> None:
        """Reset all of the environments' states and the cache buffers."""
        self._ready_env_ids = np.arange(self.env_num)
        obs = self.env.reset()
        if self.preprocess_fn:
            obs = self.preprocess_fn(obs=obs).get("obs", obs)
        self.data.obs = obs
        for b in self._cached_buf:
            b.reset()

    def _reset_state(self, id: Union[int, List[int]]) -> None:
        """Reset the hidden state: self.data.state[id]."""
        state = self.data.state  # it is a reference
        if isinstance(state, torch.Tensor):
            state[id].zero_()
        elif isinstance(state, np.ndarray):
            state[id] = None if state.dtype == object else 0
        elif isinstance(state, Batch):
            state.empty_(id)

    def collect(
        self,
        n_step: Optional[int] = None,
        n_episode: Optional[Union[int, List[int]]] = None,
        random: bool = False,
        render: Optional[float] = None,
        no_grad: bool = True,
    ) -> Dict[str, float]:
|
2020-04-05 18:34:45 +08:00
|
|
|
"""Collect a specified number of step or episode.
|
|
|
|
|
2020-04-06 19:36:59 +08:00
|
|
|
:param int n_step: how many steps you want to collect.
|
2020-07-23 16:40:53 +08:00
|
|
|
:param n_episode: how many episodes you want to collect. If it is an
|
|
|
|
int, it means to collect at lease ``n_episode`` episodes; if it is
|
|
|
|
a list, it means to collect exactly ``n_episode[i]`` episodes in
|
|
|
|
the i-th environment
|
2020-06-11 08:57:37 +08:00
|
|
|
:param bool random: whether to use random policy for collecting data,
|
2020-09-11 07:55:37 +08:00
|
|
|
defaults to False.
|
2020-04-06 19:36:59 +08:00
|
|
|
:param float render: the sleep time between rendering consecutive
|
2020-09-11 07:55:37 +08:00
|
|
|
frames, defaults to None (no rendering).
|
2020-09-06 16:20:16 +08:00
|
|
|
:param bool no_grad: whether to retain gradient in policy.forward,
|
2020-09-11 07:55:37 +08:00
|
|
|
defaults to True (no gradient retaining).
|
2020-04-05 18:34:45 +08:00
|
|
|
|
|
|
|
.. note::
|
|
|
|
|
|
|
|
One and only one collection number specification is permitted,
|
|
|
|
either ``n_step`` or ``n_episode``.
|
|
|
|
|
|
|
|
:return: A dict including the following keys
|
|
|
|
|
|
|
|
* ``n/ep`` the collected number of episodes.
|
|
|
|
* ``n/st`` the collected number of steps.
|
|
|
|
* ``v/st`` the speed of steps per second.
|
|
|
|
* ``v/ep`` the speed of episode per second.
|
|
|
|
* ``rew`` the mean reward over collected episodes.
|
|
|
|
* ``len`` the mean length over collected episodes.
|
|
|
|
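
        For instance, a small usage sketch (for illustration only) that reads
        these statistics:
        ::

            result = collector.collect(n_step=1000)
            print(result["n/ep"], result["n/st"], result["rew"])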
"""
|
2020-08-27 12:15:18 +08:00
|
|
|
assert (n_step is not None and n_episode is None and n_step > 0) or (
|
|
|
|
n_step is None and n_episode is not None and np.sum(n_episode) > 0
|
|
|
|
), "Only one of n_step or n_episode is allowed in Collector.collect, "
|
|
|
|
f"got n_step = {n_step}, n_episode = {n_episode}."
|
        start_time = time.time()
        step_count = 0
        # episode count for each environment
        episode_count = np.zeros(self.env_num)
        # If n_episode is a list, and some envs have collected the required
        # number of episodes, these envs will be recorded in this list, and
        # they will not be stepped.
        finished_env_ids = []
        reward_total = 0.0
        whole_data = Batch()
        if isinstance(n_episode, list):
            assert len(n_episode) == self.get_env_num()
            finished_env_ids = [
                i for i in self._ready_env_ids if n_episode[i] <= 0]
            self._ready_env_ids = np.array(
                [x for x in self._ready_env_ids if x not in finished_env_ids])
        while True:
            if step_count >= 100000 and episode_count.sum() == 0:
                warnings.warn(
                    "There are already many steps in an episode. "
                    "You should add a time limitation to your environment!",
                    Warning)

            is_async = self.is_async or len(finished_env_ids) > 0
            if is_async:
                # self.data holds the data of all environments; in async
                # simulation, or when some envs have already finished, only a
                # subset of the data is processed in this iteration. So we
                # store the whole data in ``whole_data``, let self.data be the
                # data of the ready environments, and finally write it back.
                whole_data = self.data
                self.data = self.data[self._ready_env_ids]

            # restore the state and the input data
            last_state = self.data.state
            if isinstance(last_state, Batch) and last_state.is_empty():
                last_state = None
            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())

            # calculate the next action
            if random:
                spaces = self._action_space
                result = Batch(
                    act=[spaces[i].sample() for i in self._ready_env_ids])
            else:
                if no_grad:
                    with torch.no_grad():  # faster than retain_grad version
                        result = self.policy(self.data, last_state)
                else:
                    result = self.policy(self.data, last_state)

            state = result.get("state", Batch())
            # convert None to Batch(), since None is reserved for 0-init
            if state is None:
                state = Batch()
            self.data.update(state=state, policy=result.get("policy", Batch()))
            # save hidden state to policy._state, in order to save into buffer
            if not (isinstance(state, Batch) and state.is_empty()):
                self.data.policy._state = self.data.state

            self.data.act = to_numpy(result.act)
            if self._action_noise is not None:
                assert isinstance(self.data.act, np.ndarray)
                self.data.act += self._action_noise(self.data.act.shape)

            # step in env
            if not is_async:
                obs_next, rew, done, info = self.env.step(self.data.act)
            else:
                # store computed actions, states, etc
                _batch_set_item(
                    whole_data, self._ready_env_ids, self.data, self.env_num)
                # fetch finished data
                obs_next, rew, done, info = self.env.step(
                    self.data.act, id=self._ready_env_ids)
                self._ready_env_ids = np.array([i["env_id"] for i in info])
                # get the stepped data
                self.data = whole_data[self._ready_env_ids]
            # move data to self.data
            self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)

            if render:
                self.env.render()
                time.sleep(render)

            # add data into the buffer
            if self.preprocess_fn:
                result = self.preprocess_fn(**self.data)  # type: ignore
                self.data.update(result)

            for j, i in enumerate(self._ready_env_ids):
                # j is the index in current ready_env_ids
                # i is the index in all environments
                if self.buffer is None:
                    # users do not want to store data, so we store
                    # small fake data here to make the code clean
                    self._cached_buf[i].add(obs=0, act=0, rew=rew[j], done=0)
                else:
                    self._cached_buf[i].add(**self.data[j])

                if done[j]:
                    if not (isinstance(n_episode, list)
                            and episode_count[i] >= n_episode[i]):
                        episode_count[i] += 1
                        reward_total += np.sum(self._cached_buf[i].rew, axis=0)
                        step_count += len(self._cached_buf[i])
                        if self.buffer is not None:
                            self.buffer.update(self._cached_buf[i])
                        if isinstance(n_episode, list) and \
                                episode_count[i] >= n_episode[i]:
                            # env i has collected enough data, it has finished
                            finished_env_ids.append(i)
                    self._cached_buf[i].reset()
                    self._reset_state(j)
            obs_next = self.data.obs_next
            if sum(done):
                env_ind_local = np.where(done)[0]
                env_ind_global = self._ready_env_ids[env_ind_local]
                obs_reset = self.env.reset(env_ind_global)
                if self.preprocess_fn:
                    obs_reset = self.preprocess_fn(
                        obs=obs_reset).get("obs", obs_reset)
                obs_next[env_ind_local] = obs_reset
            self.data.obs = obs_next
            if is_async:
                # set data back
                whole_data = deepcopy(whole_data)  # avoid reference in ListBuf
                _batch_set_item(
                    whole_data, self._ready_env_ids, self.data, self.env_num)
                # let self.data be the data in all environments again
                self.data = whole_data
            self._ready_env_ids = np.array(
                [x for x in self._ready_env_ids if x not in finished_env_ids])
            if n_step:
                if step_count >= n_step:
                    break
            else:
                if isinstance(n_episode, int) and \
                        episode_count.sum() >= n_episode:
                    break
                if isinstance(n_episode, list) and \
                        (episode_count >= n_episode).all():
                    break

        # finished envs are ready, and can be used for the next collection
        self._ready_env_ids = np.array(
            self._ready_env_ids.tolist() + finished_env_ids)

        # generate the statistics
        episode_count = sum(episode_count)
        duration = max(time.time() - start_time, 1e-9)
        self.collect_step += step_count
        self.collect_episode += episode_count
        self.collect_time += duration
        # average reward across the number of episodes
        reward_avg = reward_total / episode_count
        if np.asanyarray(reward_avg).size > 1:  # non-scalar reward_avg
            reward_avg = self._rew_metric(reward_avg)  # type: ignore
        return {
            "n/ep": episode_count,
            "n/st": step_count,
            "v/st": step_count / duration,
            "v/ep": episode_count / duration,
            "rew": reward_avg,
            "len": step_count / episode_count,
        }


def _batch_set_item(
    source: Batch, indices: np.ndarray, target: Batch, size: int
) -> None:
    # for any key chain k, there are four cases
    # 1. source[k] is non-reserved, but target[k] does not exist or is reserved
    # 2. source[k] does not exist or is reserved, but target[k] is non-reserved
    # 3. both source[k] and target[k] are non-reserved
    # 4. both source[k] and target[k] do not exist or are reserved, do nothing.
    # A special case of case 4: if target[k] is reserved but source[k] does
    # not exist, make source[k] reserved, too.
    for k, vt in target.items():
        if not isinstance(vt, Batch) or not vt.is_empty():
            # target[k] is non-reserved
            vs = source.get(k, Batch())
            if isinstance(vs, Batch):
                if vs.is_empty():
                    # case 2, use __dict__ to avoid many type checks
                    source.__dict__[k] = _create_value(vt[0], size)
                else:
                    assert isinstance(vt, Batch)
                    _batch_set_item(source.__dict__[k], indices, vt, size)
        else:
            # target[k] is reserved
            # case 1 or special case of case 4
            if k not in source.__dict__:
                source.__dict__[k] = Batch()
            continue
        source.__dict__[k][indices] = vt