Improve collector (#125)

* remove multibuf

* add reward_metric for multi-agent reward reporting

* make fields empty Batch rather than None after reset

* many fixes and refactor
Co-authored-by: Trinkle23897 <463003665@qq.com>
authored by youkaichao on 2020-07-13 00:24:31 +08:00, committed by n+e
parent 5599a6d1a6
commit 26fb87433d
3 changed files with 183 additions and 177 deletions
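For orientation, a minimal usage sketch of the post-commit API (illustrative only, not part of the diff): one policy, one env, a single ReplayBuffer (the list-of-buffers mode is gone), and an optional reward_metric that reduces a per-agent reward vector to the single scalar reported by collect(). MyTestEnv and MyPolicy are the test helpers changed below; the import paths assume the snippet runs from the test directory where they live.

from tianshou.data import Collector, ReplayBuffer

from env import MyTestEnv            # test helper, see the first file below
from test_collector import MyPolicy  # test helper, see the second file below

# ma_rew=4 makes every step return a length-4 reward vector (one entry per agent).
env = MyTestEnv(size=5, ma_rew=4)

# reward_metric collapses the per-agent reward vector into one scalar for logging;
# without it, the default metric asserts that the reward is already a scalar.
collector = Collector(MyPolicy(), env, ReplayBuffer(size=100),
                      reward_metric=lambda rew: rew.sum())

print(collector.collect(n_episode=1)['rew'])  # a single scalar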

View File

@@ -1,19 +1,34 @@
-import time
 import gym
+import time
 from gym.spaces.discrete import Discrete


 class MyTestEnv(gym.Env):
-    def __init__(self, size, sleep=0, dict_state=False):
+    """This is a "going right" task. The task is to go right ``size`` steps.
+    """
+
+    def __init__(self, size, sleep=0, dict_state=False, ma_rew=0):
         self.size = size
         self.sleep = sleep
         self.dict_state = dict_state
+        self.ma_rew = ma_rew
         self.action_space = Discrete(2)
         self.reset()

     def reset(self, state=0):
         self.done = False
         self.index = state
-        return {'index': self.index} if self.dict_state else self.index
+        return self._get_dict_state()
+
+    def _get_reward(self):
+        """Generate a non-scalar reward if ma_rew is True."""
+        x = int(self.done)
+        if self.ma_rew > 0:
+            return [x] * self.ma_rew
+        return x
+
+    def _get_dict_state(self):
+        """Generate a dict_state if dict_state is True."""
+        return {'index': self.index} if self.dict_state else self.index

     def step(self, action):
@@ -23,22 +38,13 @@ class MyTestEnv(gym.Env):
             time.sleep(self.sleep)
         if self.index == self.size:
             self.done = True
-            if self.dict_state:
-                return {'index': self.index}, 0, True, {}
-            else:
-                return self.index, 0, True, {}
+            return self._get_dict_state(), self._get_reward(), self.done, {}
         if action == 0:
             self.index = max(self.index - 1, 0)
-            if self.dict_state:
-                return {'index': self.index}, 0, False, {'key': 1, 'env': self}
-            else:
-                return self.index, 0, False, {}
+            return self._get_dict_state(), self._get_reward(), self.done, \
+                {'key': 1, 'env': self} if self.dict_state else {}
         elif action == 1:
             self.index += 1
             self.done = self.index == self.size
-            if self.dict_state:
-                return {'index': self.index}, int(self.done), self.done, \
-                    {'key': 1, 'env': self}
-            else:
-                return self.index, int(self.done), self.done, \
-                    {'key': 1, 'env': self}
+            return self._get_dict_state(), self._get_reward(), \
+                self.done, {'key': 1, 'env': self}
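To make the new ma_rew option concrete, a quick illustrative sketch (not part of the commit; the import assumes the test directory): with ma_rew=3 every step returns a length-3 reward list whose entries are all int(done).

from env import MyTestEnv  # the test helper defined in the hunk above

env = MyTestEnv(size=2, ma_rew=3)
env.reset()
obs, rew, done, info = env.step(1)
print(obs, rew, done)  # 1 [0, 0, 0] False
obs, rew, done, info = env.step(1)
print(obs, rew, done)  # 2 [1, 1, 1] True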

View File

@@ -27,16 +27,16 @@ class MyPolicy(BasePolicy):

 def preprocess_fn(**kwargs):
     # modify info before adding into the buffer
-    if kwargs.get('info', None) is not None:
+    # if info is not provided from env, it will be a ``Batch()``.
+    if not kwargs.get('info', Batch()).is_empty():
         n = len(kwargs['obs'])
         info = kwargs['info']
         for i in range(n):
             info[i].update(rew=kwargs['rew'][i])
         return {'info': info}
-        # or
-        # return Batch(info=info)
+        # or: return Batch(info=info)
     else:
-        return {}
+        return Batch()


 class Logger(object):
@@ -119,6 +119,48 @@ def test_collector_with_dict_state():
     print(batch['obs_next']['index'])


+def test_collector_with_ma():
+    def reward_metric(x):
+        return x.sum()
+    env = MyTestEnv(size=5, sleep=0, ma_rew=4)
+    policy = MyPolicy()
+    c0 = Collector(policy, env, ReplayBuffer(size=100),
+                   preprocess_fn, reward_metric=reward_metric)
+    r = c0.collect(n_step=3)['rew']
+    assert np.asanyarray(r).size == 1 and r == 0.
+    r = c0.collect(n_episode=3)['rew']
+    assert np.asanyarray(r).size == 1 and r == 4.
+    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4)
+               for i in [2, 3, 4, 5]]
+    envs = VectorEnv(env_fns)
+    c1 = Collector(policy, envs, ReplayBuffer(size=100),
+                   preprocess_fn, reward_metric=reward_metric)
+    r = c1.collect(n_step=10)['rew']
+    assert np.asanyarray(r).size == 1 and r == 4.
+    r = c1.collect(n_episode=[2, 1, 1, 2])['rew']
+    assert np.asanyarray(r).size == 1 and r == 4.
+    batch = c1.sample(10)
+    print(batch)
+    c0.buffer.update(c1.buffer)
+    obs = [
+        0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
+        0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
+        1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]
+    assert np.allclose(c0.buffer[:len(c0.buffer)].obs, obs)
+    rew = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1,
+           0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0,
+           0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1]
+    assert np.allclose(c0.buffer[:len(c0.buffer)].rew,
+                       [[x] * 4 for x in rew])
+    c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
+                   preprocess_fn, reward_metric=reward_metric)
+    r = c2.collect(n_episode=[0, 0, 0, 10])['rew']
+    assert np.asanyarray(r).size == 1 and r == 4.
+    batch = c2.sample(10)
+    print(batch['obs_next'])
+
+
 if __name__ == '__main__':
     test_collector()
     test_collector_with_dict_state()
+    test_collector_with_ma()

View File

@@ -8,8 +8,8 @@ from typing import Any, Dict, List, Union, Optional, Callable
 from tianshou.utils import MovAvg
 from tianshou.env import BaseVectorEnv
 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
 from tianshou.exploration import BaseNoise
+from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy


 class Collector(object):
@@ -25,12 +25,18 @@ class Collector(object):
         ``None``, it will automatically assign a small-size
         :class:`~tianshou.data.ReplayBuffer`.
     :param function preprocess_fn: a function called before the data has been
-        added to the buffer, see issue #42, defaults to ``None``.
+        added to the buffer, see issue #42 and :ref:`preprocess_fn`, defaults
+        to ``None``.
     :param int stat_size: for the moving average of recording speed, defaults
         to 100.
     :param BaseNoise action_noise: add a noise to continuous action. Normally
         a policy already has a noise param for exploration in training phase,
         so this is recommended to use in test collector for some purpose.
+    :param function reward_metric: to be used in multi-agent RL. The reward to
+        report is of shape [agent_num], but we need to return a single scalar
+        to monitor training. This function specifies what is the desired
+        metric, e.g., the reward of agent 1 or the average reward over all
+        agents. By default, the behavior is to select the reward of agent 1.

     The ``preprocess_fn`` is a function called before the data has been added
     to the buffer with batch format, which receives up to 7 keys as listed in
@@ -87,66 +93,56 @@ class Collector(object):
     def __init__(self,
                  policy: BasePolicy,
                  env: Union[gym.Env, BaseVectorEnv],
-                 buffer: Optional[Union[ReplayBuffer, List[ReplayBuffer]]]
-                 = None,
+                 buffer: Optional[ReplayBuffer] = None,
                  preprocess_fn: Callable[[Any], Union[dict, Batch]] = None,
                  stat_size: Optional[int] = 100,
                  action_noise: Optional[BaseNoise] = None,
+                 reward_metric: Optional[Callable[[np.ndarray], float]] = None,
                  **kwargs) -> None:
         super().__init__()
         self.env = env
         self.env_num = 1
-        self.collect_time = 0
-        self.collect_step = 0
-        self.collect_episode = 0
+        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
         self.buffer = buffer
         self.policy = policy
         self.preprocess_fn = preprocess_fn
-        # if preprocess_fn is None:
-        #     def _prep(**kwargs):
-        #         return kwargs
-        #     self.preprocess_fn = _prep
         self.process_fn = policy.process_fn
         self._multi_env = isinstance(env, BaseVectorEnv)
-        self._multi_buf = False  # True if buf is a list
         # need multiple cache buffers only if storing in one buffer
         self._cached_buf = []
         if self._multi_env:
             self.env_num = len(env)
-            if isinstance(self.buffer, list):
-                assert len(self.buffer) == self.env_num, \
-                    'The number of data buffer does not match the number of ' \
-                    'input env.'
-                self._multi_buf = True
-            elif isinstance(self.buffer, ReplayBuffer) or self.buffer is None:
-                self._cached_buf = [
-                    ListReplayBuffer() for _ in range(self.env_num)]
-            else:
-                raise TypeError('The buffer in data collector is invalid!')
+            self._cached_buf = [ListReplayBuffer()
+                                for _ in range(self.env_num)]
         self.stat_size = stat_size
         self._action_noise = action_noise
+        self._rew_metric = reward_metric or Collector._default_rew_metric
         self.reset()

+    @staticmethod
+    def _default_rew_metric(x):
+        # this internal function is designed for single-agent RL
+        # for multi-agent RL, a reward_metric must be provided
+        assert np.asanyarray(x).size == 1, \
+            'Please specify the reward_metric ' \
+            'since the reward is not a scalar.'
+        return x
+
     def reset(self) -> None:
         """Reset all related variables in the collector."""
+        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
+                          obs_next={}, policy={})
         self.reset_env()
         self.reset_buffer()
-        # state over batch is either a list, an np.ndarray, or a torch.Tensor
-        self.state = None
         self.step_speed = MovAvg(self.stat_size)
         self.episode_speed = MovAvg(self.stat_size)
-        self.collect_step = 0
-        self.collect_episode = 0
-        self.collect_time = 0
+        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
         if self._action_noise is not None:
             self._action_noise.reset()

     def reset_buffer(self) -> None:
         """Reset the main data buffer."""
-        if self._multi_buf:
-            for b in self.buffer:
-                b.reset()
-        else:
-            if self.buffer is not None:
-                self.buffer.reset()
+        if self.buffer is not None:
+            self.buffer.reset()
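As an aside (illustrative, not from the commit): for the multi-agent case described in the reward_metric docstring above, the metric can be as small as one of the functions below. collect() only invokes it when the averaged reward turns out to be non-scalar; otherwise _default_rew_metric simply asserts the reward is already a scalar.

import numpy as np


def mean_over_agents(rew: np.ndarray) -> float:
    """Report the average reward across all agents."""
    return float(np.mean(rew))


def first_agent_only(rew: np.ndarray) -> float:
    """Report only the first agent's reward."""
    return float(rew[0])

# e.g. Collector(policy, env, ReplayBuffer(size=1000),
#                reward_metric=mean_over_agents)
# where `policy` and `env` stand in for any multi-agent setup whose
# reward has shape [agent_num].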
@@ -158,33 +154,27 @@ class Collector(object):
         """Reset all of the environment(s)' states and reset all of the cache
         buffers (if need).
         """
-        self._obs = self.env.reset()
+        obs = self.env.reset()
         if not self._multi_env:
-            self._obs = self._make_batch(self._obs)
+            obs = self._make_batch(obs)
         if self.preprocess_fn:
-            self._obs = self.preprocess_fn(obs=self._obs).get('obs', self._obs)
-        self._act = self._rew = self._done = self._info = None
-        if self._multi_env:
-            self.reward = np.zeros(self.env_num)
-            self.length = np.zeros(self.env_num)
-        else:
-            self.reward, self.length = 0, 0
+            obs = self.preprocess_fn(obs=obs).get('obs', obs)
+        self.data.obs = obs
+        self.reward = 0.  # will be specified when the first data is ready
+        self.length = np.zeros(self.env_num)
         for b in self._cached_buf:
             b.reset()

     def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
         """Reset all the seed(s) of the given environment(s)."""
-        if hasattr(self.env, 'seed'):
-            return self.env.seed(seed)
+        return self.env.seed(seed)

     def render(self, **kwargs) -> None:
         """Render all the environment(s)."""
-        if hasattr(self.env, 'render'):
-            return self.env.render(**kwargs)
+        return self.env.render(**kwargs)

     def close(self) -> None:
         """Close the environment(s)."""
-        if hasattr(self.env, 'close'):
-            self.env.close()
+        self.env.close()

     def _make_batch(self, data: Any) -> np.ndarray:
@@ -195,20 +185,14 @@ class Collector(object):
         return np.array([data])

     def _reset_state(self, id: Union[int, List[int]]) -> None:
-        """Reset self.state[id]."""
-        if self.state is None:
-            return
-        if isinstance(self.state, list):
-            self.state[id] = None
-        elif isinstance(self.state, torch.Tensor):
-            self.state[id].zero_()
-        elif isinstance(self.state, np.ndarray):
-            if isinstance(self.state.dtype == np.object):
-                self.state[id] = None
-            else:
-                self.state[id] = 0
-        elif isinstance(self.state, Batch):
-            self.state.empty_(id)
+        """Reset self.data.state[id]."""
+        state = self.data.state  # it is a reference
+        if isinstance(state, torch.Tensor):
+            state[id].zero_()
+        elif isinstance(state, np.ndarray):
+            state[id] = None if state.dtype == np.object else 0
+        elif isinstance(state, Batch):
+            state.empty_(id)

     def collect(self,
                 n_step: int = 0,
@@ -244,26 +228,27 @@ class Collector(object):
         * ``rew`` the mean reward over collected episodes.
         * ``len`` the mean length over collected episodes.
         """
-        warning_count = 0
         if not self._multi_env:
             n_episode = np.sum(n_episode)
         start_time = time.time()
         assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
             "One and only one collection number specification is permitted!"
-        cur_step = 0
-        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
-        reward_sum = 0
-        length_sum = 0
+        cur_step, cur_episode = 0, np.zeros(self.env_num)
+        reward_sum, length_sum = 0., 0
         while True:
-            if warning_count >= 100000:
+            if cur_step >= 100000 and cur_episode.sum() == 0:
                 warnings.warn(
                     'There are already many steps in an episode. '
                     'You should add a time limitation to your environment!',
                     Warning)
-            batch = Batch(
-                obs=self._obs, act=self._act, rew=self._rew,
-                done=self._done, obs_next=None, info=self._info,
-                policy=None)
+
+            # restore the state and the input data
+            last_state = self.data.state
+            if last_state.is_empty():
+                last_state = None
+            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())
+
+            # calculate the next action
             if random:
                 action_space = self.env.action_space
                 if isinstance(action_space, list):
@@ -272,69 +257,54 @@ class Collector(object):
                 result = Batch(act=self._make_batch(action_space.sample()))
             else:
                 with torch.no_grad():
-                    result = self.policy(batch, self.state)
+                    result = self.policy(self.data, last_state)

-            # save hidden state to policy._state, in order to save into buffer
-            self.state = result.get('state', None)
+            # convert None to Batch(), since None is reserved for 0-init
+            state = result.get('state', Batch())
+            if state is None:
+                state = Batch()
+            self.data.state = state
             if hasattr(result, 'policy'):
-                self._policy = to_numpy(result.policy)
-                if self.state is not None:
-                    self._policy._state = self.state
-            elif self.state is not None:
-                self._policy = Batch(_state=self.state)
-            else:
-                self._policy = [{}] * self.env_num
+                self.data.policy = to_numpy(result.policy)
+            # save hidden state to policy._state, in order to save into buffer
+            self.data.policy._state = self.data.state

-            self._act = to_numpy(result.act)
+            self.data.act = to_numpy(result.act)
             if self._action_noise is not None:
-                self._act += self._action_noise(self._act.shape)
-            obs_next, self._rew, self._done, self._info = self.env.step(
-                self._act if self._multi_env else self._act[0])
+                self.data.act += self._action_noise(self.data.act.shape)
+
+            # step in env
+            obs_next, rew, done, info = self.env.step(
+                self.data.act if self._multi_env else self.data.act[0])
+
+            # move data to self.data
             if not self._multi_env:
                 obs_next = self._make_batch(obs_next)
-                self._rew = self._make_batch(self._rew)
-                self._done = self._make_batch(self._done)
-                self._info = self._make_batch(self._info)
+                rew = self._make_batch(rew)
+                done = self._make_batch(done)
+                info = self._make_batch(info)
+            self.data.obs_next = obs_next
+            self.data.rew = rew
+            self.data.done = done
+            self.data.info = info
+
             if log_fn:
-                log_fn(self._info if self._multi_env else self._info[0])
+                log_fn(info if self._multi_env else info[0])
             if render:
-                self.env.render()
+                self.render()
                 if render > 0:
                     time.sleep(render)
+
+            # add data into the buffer
             self.length += 1
-            self.reward += self._rew
+            self.reward += self.data.rew
             if self.preprocess_fn:
-                result = self.preprocess_fn(
-                    obs=self._obs, act=self._act, rew=self._rew,
-                    done=self._done, obs_next=obs_next, info=self._info,
-                    policy=self._policy)
-                self._obs = result.get('obs', self._obs)
-                self._act = result.get('act', self._act)
-                self._rew = result.get('rew', self._rew)
-                self._done = result.get('done', self._done)
-                obs_next = result.get('obs_next', obs_next)
-                self._info = result.get('info', self._info)
-                self._policy = result.get('policy', self._policy)
-            if self._multi_env:
+                result = self.preprocess_fn(**self.data)
+                self.data.update(result)
+            if self._multi_env:  # cache_buffer branch
                 for i in range(self.env_num):
-                    data = {
-                        'obs': self._obs[i], 'act': self._act[i],
-                        'rew': self._rew[i], 'done': self._done[i],
-                        'obs_next': obs_next[i], 'info': self._info[i],
-                        'policy': self._policy[i]}
-                    if self._cached_buf:
-                        warning_count += 1
-                        self._cached_buf[i].add(**data)
-                    elif self._multi_buf:
-                        warning_count += 1
-                        self.buffer[i].add(**data)
-                        cur_step += 1
-                    else:
-                        warning_count += 1
-                        if self.buffer is not None:
-                            self.buffer.add(**data)
-                        cur_step += 1
-                    if self._done[i]:
+                    self._cached_buf[i].add(**self.data[i])
+                    if self.data.done[i]:
                         if n_step != 0 or np.isscalar(n_episode) or \
                                 cur_episode[i] < n_episode[i]:
                             cur_episode[i] += 1
@@ -344,45 +314,46 @@ class Collector(object):
                                 cur_step += len(self._cached_buf[i])
                                 if self.buffer is not None:
                                     self.buffer.update(self._cached_buf[i])
-                        self.reward[i], self.length[i] = 0, 0
+                        self.reward[i], self.length[i] = 0., 0
                         if self._cached_buf:
                             self._cached_buf[i].reset()
                         self._reset_state(i)
-                if sum(self._done):
-                    obs_next = self.env.reset(np.where(self._done)[0])
+                obs_next = self.data.obs_next
+                if sum(self.data.done):
+                    obs_next = self.env.reset(np.where(self.data.done)[0])
                     if self.preprocess_fn:
                         obs_next = self.preprocess_fn(obs=obs_next).get(
                             'obs', obs_next)
+                    self.data.obs_next = obs_next
                 if n_episode != 0:
                     if isinstance(n_episode, list) and \
                             (cur_episode >= np.array(n_episode)).all() or \
                             np.isscalar(n_episode) and \
                             cur_episode.sum() >= n_episode:
                         break
-            else:
+            else:  # single buffer, without cache_buffer
                 if self.buffer is not None:
-                    self.buffer.add(
-                        self._obs[0], self._act[0], self._rew[0],
-                        self._done[0], obs_next[0], self._info[0],
-                        self._policy[0])
+                    self.buffer.add(**self.data[0])
                 cur_step += 1
-                if self._done:
+                if self.data.done[0]:
                     cur_episode += 1
                     reward_sum += self.reward[0]
-                    length_sum += self.length
-                    self.reward, self.length = 0, 0
-                    self.state = None
+                    length_sum += self.length[0]
+                    self.reward, self.length = 0., np.zeros(self.env_num)
+                    self.data.state = Batch()
                     obs_next = self._make_batch(self.env.reset())
                     if self.preprocess_fn:
                         obs_next = self.preprocess_fn(obs=obs_next).get(
                             'obs', obs_next)
+                    self.data.obs_next = obs_next
                 if n_episode != 0 and cur_episode >= n_episode:
                     break
             if n_step != 0 and cur_step >= n_step:
                 break
-            self._obs = obs_next
-        self._obs = obs_next
-        if self._multi_env:
-            cur_episode = sum(cur_episode)
+            self.data.obs = self.data.obs_next
+        self.data.obs = self.data.obs_next
+
+        # generate the statistics
+        cur_episode = sum(cur_episode)
         duration = max(time.time() - start_time, 1e-9)
         self.step_speed.add(cur_step / duration)
@@ -394,12 +365,15 @@ class Collector(object):
             n_episode = np.sum(n_episode)
         else:
             n_episode = max(cur_episode, 1)
+        reward_sum /= n_episode
+        if np.asanyarray(reward_sum).size > 1:  # non-scalar reward_sum
+            reward_sum = self._rew_metric(reward_sum)
         return {
             'n/ep': cur_episode,
             'n/st': cur_step,
             'v/st': self.step_speed.get(),
             'v/ep': self.episode_speed.get(),
-            'rew': reward_sum / n_episode,
+            'rew': reward_sum,
             'len': length_sum / n_episode,
         }
@@ -412,22 +386,6 @@ class Collector(object):
         the buffer, otherwise it will extract the data with the given
         batch_size.
         """
-        if self._multi_buf:
-            if batch_size > 0:
-                lens = [len(b) for b in self.buffer]
-                total = sum(lens)
-                batch_index = np.random.choice(
-                    len(self.buffer), batch_size, p=np.array(lens) / total)
-            else:
-                batch_index = np.array([])
-            batch_data = Batch()
-            for i, b in enumerate(self.buffer):
-                cur_batch = (batch_index == i).sum()
-                if batch_size and cur_batch or batch_size <= 0:
-                    batch, indice = b.sample(cur_batch)
-                    batch = self.process_fn(batch, b, indice)
-                    batch_data.cat_(batch)
-        else:
-            batch_data, indice = self.buffer.sample(batch_size)
-            batch_data = self.process_fn(batch_data, self.buffer, indice)
+        batch_data, indice = self.buffer.sample(batch_size)
+        batch_data = self.process_fn(batch_data, self.buffer, indice)
         return batch_data
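A usage sketch of the simplified sample() (illustrative, not part of the commit): with the multi-buffer branch removed there is one code path left, namely draw from the single ReplayBuffer, run the result through policy.process_fn, and return the Batch.

# `collector` stands in for any Collector built with a single ReplayBuffer,
# e.g. c1 from test_collector_with_ma above.
batch = collector.sample(64)  # 64 random transitions, passed through process_fn
whole = collector.sample(0)   # batch_size=0 returns everything in the buffer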