Improve collector (#125)
* remove multibuf
* reward_metric
* make fields with empty Batch rather than None after reset
* many fixes and refactor

Co-authored-by: Trinkle23897 <463003665@qq.com>

parent 5599a6d1a6
commit 26fb87433d
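The most user-visible addition is the ``reward_metric`` hook on the Collector. A minimal usage sketch, condensed from the ``test_collector_with_ma`` test added further down in this diff (``MyTestEnv`` and ``MyPolicy`` are the toy classes from the test suite, not library code):

import numpy as np
from tianshou.data import Collector, ReplayBuffer

# MyTestEnv(ma_rew=4) emits a length-4 reward vector each step; reward_metric
# maps that vector to the single scalar reported in collect()'s 'rew' field.
env = MyTestEnv(size=5, sleep=0, ma_rew=4)
collector = Collector(MyPolicy(), env, ReplayBuffer(size=100),
                      reward_metric=lambda rews: rews.sum())
result = collector.collect(n_episode=3)
assert np.asanyarray(result['rew']).size == 1  # scalar despite vector rewards

The hunks below touch the toy test environment, the collector tests, and the Collector implementation itself.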
@@ -1,19 +1,34 @@
-import time
 import gym
+import time
 from gym.spaces.discrete import Discrete
 
 
 class MyTestEnv(gym.Env):
-    def __init__(self, size, sleep=0, dict_state=False):
+    """This is a "going right" task. The task is to go right ``size`` steps.
+    """
+
+    def __init__(self, size, sleep=0, dict_state=False, ma_rew=0):
         self.size = size
         self.sleep = sleep
         self.dict_state = dict_state
+        self.ma_rew = ma_rew
         self.action_space = Discrete(2)
         self.reset()
 
     def reset(self, state=0):
         self.done = False
         self.index = state
+        return self._get_dict_state()
+
+    def _get_reward(self):
+        """Generate a non-scalar reward if ma_rew is True."""
+        x = int(self.done)
+        if self.ma_rew > 0:
+            return [x] * self.ma_rew
+        return x
+
+    def _get_dict_state(self):
+        """Generate a dict_state if dict_state is True."""
         return {'index': self.index} if self.dict_state else self.index
 
     def step(self, action):
@@ -23,22 +38,13 @@ class MyTestEnv(gym.Env):
             time.sleep(self.sleep)
         if self.index == self.size:
             self.done = True
-            if self.dict_state:
-                return {'index': self.index}, 0, True, {}
-            else:
-                return self.index, 0, True, {}
+            return self._get_dict_state(), self._get_reward(), self.done, {}
         if action == 0:
             self.index = max(self.index - 1, 0)
-            if self.dict_state:
-                return {'index': self.index}, 0, False, {'key': 1, 'env': self}
-            else:
-                return self.index, 0, False, {}
+            return self._get_dict_state(), self._get_reward(), self.done, \
+                {'key': 1, 'env': self} if self.dict_state else {}
         elif action == 1:
             self.index += 1
             self.done = self.index == self.size
-            if self.dict_state:
-                return {'index': self.index}, int(self.done), self.done, \
-                    {'key': 1, 'env': self}
-            else:
-                return self.index, int(self.done), self.done, \
-                    {'key': 1, 'env': self}
+            return self._get_dict_state(), self._get_reward(), \
+                self.done, {'key': 1, 'env': self}
@@ -27,16 +27,16 @@ class MyPolicy(BasePolicy):
 
 def preprocess_fn(**kwargs):
     # modify info before adding into the buffer
-    if kwargs.get('info', None) is not None:
+    # if info is not provided from env, it will be a ``Batch()``.
+    if not kwargs.get('info', Batch()).is_empty():
         n = len(kwargs['obs'])
         info = kwargs['info']
         for i in range(n):
             info[i].update(rew=kwargs['rew'][i])
         return {'info': info}
-        # or
-        # return Batch(info=info)
+        # or: return Batch(info=info)
     else:
-        return {}
+        return Batch()
 
 
 class Logger(object):
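For reference, the ``preprocess_fn`` contract exercised above after this change: ``info`` now arrives as a (possibly empty) ``Batch`` instead of ``None``, and an empty return value should be ``Batch()``. A minimal sketch (the function name is illustrative, not part of the library):

from tianshou.data import Batch

def my_preprocess_fn(**kwargs):
    # kwargs may hold obs/act/rew/done/obs_next/info/policy in batch format;
    # if the env supplies no info, it shows up as an empty Batch, not None.
    info = kwargs.get('info', Batch())
    if info.is_empty():
        return Batch()              # nothing to change for this call
    info.update(rew=kwargs['rew'])  # e.g. stash the raw reward into info
    return Batch(info=info)         # a dict {'info': info} also works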
@@ -119,6 +119,48 @@ def test_collector_with_dict_state():
     print(batch['obs_next']['index'])
 
 
+def test_collector_with_ma():
+    def reward_metric(x):
+        return x.sum()
+    env = MyTestEnv(size=5, sleep=0, ma_rew=4)
+    policy = MyPolicy()
+    c0 = Collector(policy, env, ReplayBuffer(size=100),
+                   preprocess_fn, reward_metric=reward_metric)
+    r = c0.collect(n_step=3)['rew']
+    assert np.asanyarray(r).size == 1 and r == 0.
+    r = c0.collect(n_episode=3)['rew']
+    assert np.asanyarray(r).size == 1 and r == 4.
+    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4)
+               for i in [2, 3, 4, 5]]
+    envs = VectorEnv(env_fns)
+    c1 = Collector(policy, envs, ReplayBuffer(size=100),
+                   preprocess_fn, reward_metric=reward_metric)
+    r = c1.collect(n_step=10)['rew']
+    assert np.asanyarray(r).size == 1 and r == 4.
+    r = c1.collect(n_episode=[2, 1, 1, 2])['rew']
+    assert np.asanyarray(r).size == 1 and r == 4.
+    batch = c1.sample(10)
+    print(batch)
+    c0.buffer.update(c1.buffer)
+    obs = [
+        0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
+        0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
+        1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]
+    assert np.allclose(c0.buffer[:len(c0.buffer)].obs, obs)
+    rew = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1,
+           0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0,
+           0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1]
+    assert np.allclose(c0.buffer[:len(c0.buffer)].rew,
+                       [[x] * 4 for x in rew])
+    c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
+                   preprocess_fn, reward_metric=reward_metric)
+    r = c2.collect(n_episode=[0, 0, 0, 10])['rew']
+    assert np.asanyarray(r).size == 1 and r == 4.
+    batch = c2.sample(10)
+    print(batch['obs_next'])
+
+
 if __name__ == '__main__':
     test_collector()
     test_collector_with_dict_state()
+    test_collector_with_ma()
@@ -8,8 +8,8 @@ from typing import Any, Dict, List, Union, Optional, Callable
 from tianshou.utils import MovAvg
 from tianshou.env import BaseVectorEnv
 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
 from tianshou.exploration import BaseNoise
+from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
 
 
 class Collector(object):
@@ -25,12 +25,18 @@ class Collector(object):
         ``None``, it will automatically assign a small-size
         :class:`~tianshou.data.ReplayBuffer`.
     :param function preprocess_fn: a function called before the data has been
-        added to the buffer, see issue #42, defaults to ``None``.
+        added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults
+        to ``None``.
     :param int stat_size: for the moving average of recording speed, defaults
         to 100.
     :param BaseNoise action_noise: add a noise to continuous action. Normally
         a policy already has a noise param for exploration in training phase,
         so this is recommended to use in test collector for some purpose.
+    :param function reward_metric: to be used in multi-agent RL. The reward to
+        report is of shape [agent_num], but we need to return a single scalar
+        to monitor training. This function specifies what is the desired
+        metric, e.g., the reward of agent 1 or the average reward over all
+        agents. By default, the behavior is to select the reward of agent 1.
 
     The ``preprocess_fn`` is a function called before the data has been added
     to the buffer with batch format, which receives up to 7 keys as listed in
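The docstring above leaves the metric itself open: any ``Callable[[np.ndarray], float]`` over the ``[agent_num]``-shaped reward works. Two illustrative choices (the function names are ours, not part of the API):

import numpy as np

def first_agent_reward(rewards: np.ndarray) -> float:
    # mirrors what the docstring describes as the default: report agent 1 only
    return float(rewards[0])

def mean_agent_reward(rewards: np.ndarray) -> float:
    # report the average reward over all agents instead
    return float(np.mean(rewards))

# passed as: Collector(policy, env, buffer, reward_metric=mean_agent_reward)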
@@ -87,66 +93,56 @@ class Collector(object):
     def __init__(self,
                  policy: BasePolicy,
                  env: Union[gym.Env, BaseVectorEnv],
-                 buffer: Optional[Union[ReplayBuffer, List[ReplayBuffer]]]
-                 = None,
+                 buffer: Optional[ReplayBuffer] = None,
                  preprocess_fn: Callable[[Any], Union[dict, Batch]] = None,
                  stat_size: Optional[int] = 100,
                  action_noise: Optional[BaseNoise] = None,
+                 reward_metric: Optional[Callable[[np.ndarray], float]] = None,
                  **kwargs) -> None:
         super().__init__()
         self.env = env
         self.env_num = 1
-        self.collect_time = 0
-        self.collect_step = 0
-        self.collect_episode = 0
+        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
         self.buffer = buffer
         self.policy = policy
         self.preprocess_fn = preprocess_fn
-        # if preprocess_fn is None:
-        #     def _prep(**kwargs):
-        #         return kwargs
-        #     self.preprocess_fn = _prep
         self.process_fn = policy.process_fn
         self._multi_env = isinstance(env, BaseVectorEnv)
-        self._multi_buf = False  # True if buf is a list
         # need multiple cache buffers only if storing in one buffer
         self._cached_buf = []
         if self._multi_env:
             self.env_num = len(env)
-            if isinstance(self.buffer, list):
-                assert len(self.buffer) == self.env_num, \
-                    'The number of data buffer does not match the number of ' \
-                    'input env.'
-                self._multi_buf = True
-            elif isinstance(self.buffer, ReplayBuffer) or self.buffer is None:
-                self._cached_buf = [
-                    ListReplayBuffer() for _ in range(self.env_num)]
-            else:
-                raise TypeError('The buffer in data collector is invalid!')
+            self._cached_buf = [ListReplayBuffer()
+                                for _ in range(self.env_num)]
         self.stat_size = stat_size
         self._action_noise = action_noise
+
+        self._rew_metric = reward_metric or Collector._default_rew_metric
         self.reset()
 
+    @staticmethod
+    def _default_rew_metric(x):
+        # this internal function is designed for single-agent RL
+        # for multi-agent RL, a reward_metric must be provided
+        assert np.asanyarray(x).size == 1, \
+            'Please specify the reward_metric ' \
+            'since the reward is not a scalar.'
+        return x
+
     def reset(self) -> None:
         """Reset all related variables in the collector."""
+        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
+                          obs_next={}, policy={})
         self.reset_env()
         self.reset_buffer()
-        # state over batch is either a list, an np.ndarray, or a torch.Tensor
-        self.state = None
         self.step_speed = MovAvg(self.stat_size)
         self.episode_speed = MovAvg(self.stat_size)
-        self.collect_step = 0
-        self.collect_episode = 0
-        self.collect_time = 0
+        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
         if self._action_noise is not None:
             self._action_noise.reset()
 
     def reset_buffer(self) -> None:
         """Reset the main data buffer."""
-        if self._multi_buf:
-            for b in self.buffer:
-                b.reset()
-        else:
-            if self.buffer is not None:
-                self.buffer.reset()
+        if self.buffer is not None:
+            self.buffer.reset()
 
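The ``reset()`` above is where the commit message's "empty Batch rather than None" shows up: every per-step field now lives in one ``self.data`` Batch whose members start out empty. A small sketch of what that buys, assuming the ``Batch`` semantics used elsewhere in this diff:

from tianshou.data import Batch

data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
             obs_next={}, policy={})
assert data.state.is_empty()         # empty Batch, not None -> uniform checks
data.update(obs=[0, 1], act=[1, 1])  # fields get filled as collection proceeds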
@@ -158,33 +154,27 @@ class Collector(object):
         """Reset all of the environment(s)' states and reset all of the cache
         buffers (if need).
         """
-        self._obs = self.env.reset()
+        obs = self.env.reset()
         if not self._multi_env:
-            self._obs = self._make_batch(self._obs)
+            obs = self._make_batch(obs)
         if self.preprocess_fn:
-            self._obs = self.preprocess_fn(obs=self._obs).get('obs', self._obs)
-        self._act = self._rew = self._done = self._info = None
-        if self._multi_env:
-            self.reward = np.zeros(self.env_num)
-            self.length = np.zeros(self.env_num)
-        else:
-            self.reward, self.length = 0, 0
+            obs = self.preprocess_fn(obs=obs).get('obs', obs)
+        self.data.obs = obs
+        self.reward = 0.  # will be specified when the first data is ready
+        self.length = np.zeros(self.env_num)
         for b in self._cached_buf:
             b.reset()
 
     def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
         """Reset all the seed(s) of the given environment(s)."""
-        if hasattr(self.env, 'seed'):
-            return self.env.seed(seed)
+        return self.env.seed(seed)
 
     def render(self, **kwargs) -> None:
         """Render all the environment(s)."""
-        if hasattr(self.env, 'render'):
-            return self.env.render(**kwargs)
+        return self.env.render(**kwargs)
 
     def close(self) -> None:
         """Close the environment(s)."""
-        if hasattr(self.env, 'close'):
-            self.env.close()
+        self.env.close()
 
     def _make_batch(self, data: Any) -> np.ndarray:
@@ -195,20 +185,14 @@ class Collector(object):
             return np.array([data])
 
     def _reset_state(self, id: Union[int, List[int]]) -> None:
-        """Reset self.state[id]."""
-        if self.state is None:
-            return
-        if isinstance(self.state, list):
-            self.state[id] = None
-        elif isinstance(self.state, torch.Tensor):
-            self.state[id].zero_()
-        elif isinstance(self.state, np.ndarray):
-            if isinstance(self.state.dtype == np.object):
-                self.state[id] = None
-            else:
-                self.state[id] = 0
-        elif isinstance(self.state, Batch):
-            self.state.empty_(id)
+        """Reset self.data.state[id]."""
+        state = self.data.state  # it is a reference
+        if isinstance(state, torch.Tensor):
+            state[id].zero_()
+        elif isinstance(state, np.ndarray):
+            state[id] = None if state.dtype == np.object else 0
+        elif isinstance(state, Batch):
+            state.empty_(id)
 
     def collect(self,
                 n_step: int = 0,
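A short sketch of the per-index reset that ``_reset_state`` now performs for each supported hidden-state container (the values are made up for illustration; exact ``Batch.empty_`` behaviour is assumed from its usage above):

import numpy as np
import torch
from tianshou.data import Batch

i = 1  # environment index whose episode just ended

state = torch.ones(3, 4)
state[i].zero_()                   # torch.Tensor: zero env i's hidden state

state = np.array([{'h': 0}, {'h': 1}, {'h': 2}], dtype=object)
state[i] = None                    # object ndarray: drop env i's entry

state = Batch(h=np.ones((3, 4)))
state.empty_(i)                    # Batch: empty the slice for env i in-place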
@@ -244,26 +228,27 @@ class Collector(object):
             * ``rew`` the mean reward over collected episodes.
             * ``len`` the mean length over collected episodes.
         """
-        warning_count = 0
         if not self._multi_env:
             n_episode = np.sum(n_episode)
         start_time = time.time()
         assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
             "One and only one collection number specification is permitted!"
-        cur_step = 0
-        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
-        reward_sum = 0
-        length_sum = 0
+        cur_step, cur_episode = 0, np.zeros(self.env_num)
+        reward_sum, length_sum = 0., 0
         while True:
-            if warning_count >= 100000:
+            if cur_step >= 100000 and cur_episode.sum() == 0:
                 warnings.warn(
                     'There are already many steps in an episode. '
                     'You should add a time limitation to your environment!',
                     Warning)
-            batch = Batch(
-                obs=self._obs, act=self._act, rew=self._rew,
-                done=self._done, obs_next=None, info=self._info,
-                policy=None)
+
+            # restore the state and the input data
+            last_state = self.data.state
+            if last_state.is_empty():
+                last_state = None
+            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())
+
+            # calculate the next action
             if random:
                 action_space = self.env.action_space
                 if isinstance(action_space, list):
@@ -272,69 +257,54 @@ class Collector(object):
                 result = Batch(act=self._make_batch(action_space.sample()))
             else:
                 with torch.no_grad():
-                    result = self.policy(batch, self.state)
+                    result = self.policy(self.data, last_state)
 
-            # save hidden state to policy._state, in order to save into buffer
-            self.state = result.get('state', None)
+            # convert None to Batch(), since None is reserved for 0-init
+            state = result.get('state', Batch())
+            if state is None:
+                state = Batch()
+            self.data.state = state
             if hasattr(result, 'policy'):
-                self._policy = to_numpy(result.policy)
-                if self.state is not None:
-                    self._policy._state = self.state
-            elif self.state is not None:
-                self._policy = Batch(_state=self.state)
-            else:
-                self._policy = [{}] * self.env_num
+                self.data.policy = to_numpy(result.policy)
+                # save hidden state to policy._state, in order to save into buffer
+                self.data.policy._state = self.data.state
 
-            self._act = to_numpy(result.act)
+            self.data.act = to_numpy(result.act)
             if self._action_noise is not None:
-                self._act += self._action_noise(self._act.shape)
-            obs_next, self._rew, self._done, self._info = self.env.step(
-                self._act if self._multi_env else self._act[0])
+                self.data.act += self._action_noise(self.data.act.shape)
+
+            # step in env
+            obs_next, rew, done, info = self.env.step(
+                self.data.act if self._multi_env else self.data.act[0])
+
+            # move data to self.data
             if not self._multi_env:
                 obs_next = self._make_batch(obs_next)
-                self._rew = self._make_batch(self._rew)
-                self._done = self._make_batch(self._done)
-                self._info = self._make_batch(self._info)
+                rew = self._make_batch(rew)
+                done = self._make_batch(done)
+                info = self._make_batch(info)
+            self.data.obs_next = obs_next
+            self.data.rew = rew
+            self.data.done = done
+            self.data.info = info
+
             if log_fn:
-                log_fn(self._info if self._multi_env else self._info[0])
+                log_fn(info if self._multi_env else info[0])
             if render:
-                self.env.render()
+                self.render()
                 if render > 0:
                     time.sleep(render)
+
+            # add data into the buffer
             self.length += 1
-            self.reward += self._rew
+            self.reward += self.data.rew
             if self.preprocess_fn:
-                result = self.preprocess_fn(
-                    obs=self._obs, act=self._act, rew=self._rew,
-                    done=self._done, obs_next=obs_next, info=self._info,
-                    policy=self._policy)
-                self._obs = result.get('obs', self._obs)
-                self._act = result.get('act', self._act)
-                self._rew = result.get('rew', self._rew)
-                self._done = result.get('done', self._done)
-                obs_next = result.get('obs_next', obs_next)
-                self._info = result.get('info', self._info)
-                self._policy = result.get('policy', self._policy)
-            if self._multi_env:
+                result = self.preprocess_fn(**self.data)
+                self.data.update(result)
+            if self._multi_env:  # cache_buffer branch
                 for i in range(self.env_num):
-                    data = {
-                        'obs': self._obs[i], 'act': self._act[i],
-                        'rew': self._rew[i], 'done': self._done[i],
-                        'obs_next': obs_next[i], 'info': self._info[i],
-                        'policy': self._policy[i]}
-                    if self._cached_buf:
-                        warning_count += 1
-                        self._cached_buf[i].add(**data)
-                    elif self._multi_buf:
-                        warning_count += 1
-                        self.buffer[i].add(**data)
-                        cur_step += 1
-                    else:
-                        warning_count += 1
-                        if self.buffer is not None:
-                            self.buffer.add(**data)
-                        cur_step += 1
-                    if self._done[i]:
+                    self._cached_buf[i].add(**self.data[i])
+                    if self.data.done[i]:
                         if n_step != 0 or np.isscalar(n_episode) or \
                                 cur_episode[i] < n_episode[i]:
                             cur_episode[i] += 1
@@ -344,45 +314,46 @@ class Collector(object):
                                 cur_step += len(self._cached_buf[i])
                                 if self.buffer is not None:
                                     self.buffer.update(self._cached_buf[i])
-                        self.reward[i], self.length[i] = 0, 0
+                        self.reward[i], self.length[i] = 0., 0
                         if self._cached_buf:
                             self._cached_buf[i].reset()
                         self._reset_state(i)
-                if sum(self._done):
-                    obs_next = self.env.reset(np.where(self._done)[0])
+                obs_next = self.data.obs_next
+                if sum(self.data.done):
+                    obs_next = self.env.reset(np.where(self.data.done)[0])
                     if self.preprocess_fn:
                         obs_next = self.preprocess_fn(obs=obs_next).get(
                             'obs', obs_next)
+                    self.data.obs_next = obs_next
                 if n_episode != 0:
                     if isinstance(n_episode, list) and \
                             (cur_episode >= np.array(n_episode)).all() or \
                             np.isscalar(n_episode) and \
                             cur_episode.sum() >= n_episode:
                         break
-            else:
+            else:  # single buffer, without cache_buffer
                 if self.buffer is not None:
-                    self.buffer.add(
-                        self._obs[0], self._act[0], self._rew[0],
-                        self._done[0], obs_next[0], self._info[0],
-                        self._policy[0])
+                    self.buffer.add(**self.data[0])
                 cur_step += 1
-                if self._done:
+                if self.data.done[0]:
                     cur_episode += 1
                     reward_sum += self.reward[0]
-                    length_sum += self.length
-                    self.reward, self.length = 0, 0
-                    self.state = None
+                    length_sum += self.length[0]
+                    self.reward, self.length = 0., np.zeros(self.env_num)
+                    self.data.state = Batch()
                     obs_next = self._make_batch(self.env.reset())
                     if self.preprocess_fn:
                         obs_next = self.preprocess_fn(obs=obs_next).get(
                             'obs', obs_next)
+                    self.data.obs_next = obs_next
                 if n_episode != 0 and cur_episode >= n_episode:
                     break
             if n_step != 0 and cur_step >= n_step:
                 break
-            self._obs = obs_next
-        self._obs = obs_next
-        if self._multi_env:
-            cur_episode = sum(cur_episode)
+            self.data.obs = self.data.obs_next
+        self.data.obs = self.data.obs_next
+
+        # generate the statistics
+        cur_episode = sum(cur_episode)
         duration = max(time.time() - start_time, 1e-9)
         self.step_speed.add(cur_step / duration)
@@ -394,12 +365,15 @@ class Collector(object):
             n_episode = np.sum(n_episode)
         else:
             n_episode = max(cur_episode, 1)
+        reward_sum /= n_episode
+        if np.asanyarray(reward_sum).size > 1:  # non-scalar reward_sum
+            reward_sum = self._rew_metric(reward_sum)
         return {
             'n/ep': cur_episode,
             'n/st': cur_step,
             'v/st': self.step_speed.get(),
             'v/ep': self.episode_speed.get(),
-            'rew': reward_sum / n_episode,
+            'rew': reward_sum,
             'len': length_sum / n_episode,
         }
 
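The three added lines are the whole aggregation story: average the summed episode rewards, then collapse a vector reward with the metric. The same arithmetic in isolation (the helper name is ours, for illustration only):

import numpy as np

def summarize_reward(reward_sum, n_episode, rew_metric=lambda x: x):
    reward_sum = reward_sum / n_episode
    if np.asanyarray(reward_sum).size > 1:  # non-scalar (multi-agent) reward
        reward_sum = rew_metric(reward_sum)
    return reward_sum

# four agents, rewards summed over 4 episodes, reported as the per-agent mean
print(summarize_reward(np.array([12., 8., 4., 0.]), 4, rew_metric=np.mean))  # 1.5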
@@ -412,22 +386,6 @@ class Collector(object):
         the buffer, otherwise it will extract the data with the given
         batch_size.
         """
-        if self._multi_buf:
-            if batch_size > 0:
-                lens = [len(b) for b in self.buffer]
-                total = sum(lens)
-                batch_index = np.random.choice(
-                    len(self.buffer), batch_size, p=np.array(lens) / total)
-            else:
-                batch_index = np.array([])
-            batch_data = Batch()
-            for i, b in enumerate(self.buffer):
-                cur_batch = (batch_index == i).sum()
-                if batch_size and cur_batch or batch_size <= 0:
-                    batch, indice = b.sample(cur_batch)
-                    batch = self.process_fn(batch, b, indice)
-                    batch_data.cat_(batch)
-        else:
-            batch_data, indice = self.buffer.sample(batch_size)
-            batch_data = self.process_fn(batch_data, self.buffer, indice)
+        batch_data, indice = self.buffer.sample(batch_size)
+        batch_data = self.process_fn(batch_data, self.buffer, indice)
         return batch_data