Enable a null (None) buffer in the test collector

This commit is contained in:
Trinkle23897 2020-04-20 11:50:18 +08:00
parent 815f3522bb
commit 4fd826761c
2 changed files with 12 additions and 11 deletions

View File

@ -75,10 +75,7 @@ class Collector(object):
self.collect_step = 0 self.collect_step = 0
self.collect_episode = 0 self.collect_episode = 0
self.collect_time = 0 self.collect_time = 0
if buffer is None: self.buffer = buffer
self.buffer = ReplayBuffer(100)
else:
self.buffer = buffer
self.policy = policy self.policy = policy
self.process_fn = policy.process_fn self.process_fn = policy.process_fn
self._multi_env = isinstance(env, BaseVectorEnv) self._multi_env = isinstance(env, BaseVectorEnv)
@ -92,7 +89,7 @@ class Collector(object):
'The number of data buffer does not match the number of ' \ 'The number of data buffer does not match the number of ' \
'input env.' 'input env.'
self._multi_buf = True self._multi_buf = True
elif isinstance(self.buffer, ReplayBuffer): elif isinstance(self.buffer, ReplayBuffer) or self.buffer is None:
self._cached_buf = [ self._cached_buf = [
ListReplayBuffer() for _ in range(self.env_num)] ListReplayBuffer() for _ in range(self.env_num)]
else: else:
@ -118,7 +115,8 @@ class Collector(object):
for b in self.buffer: for b in self.buffer:
b.reset() b.reset()
else: else:
self.buffer.reset() if self.buffer is not None:
self.buffer.reset()
def get_env_num(self): def get_env_num(self):
"""Return the number of environments the collector has.""" """Return the number of environments the collector has."""
@ -265,7 +263,8 @@ class Collector(object):
cur_step += 1 cur_step += 1
else: else:
warning_count += 1 warning_count += 1
self.buffer.add(**data) if self.buffer is not None:
self.buffer.add(**data)
cur_step += 1 cur_step += 1
if self._done[i]: if self._done[i]:
if n_step != 0 or np.isscalar(n_episode) or \ if n_step != 0 or np.isscalar(n_episode) or \
@ -275,7 +274,8 @@ class Collector(object):
length_sum += self.length[i] length_sum += self.length[i]
if self._cached_buf: if self._cached_buf:
cur_step += len(self._cached_buf[i]) cur_step += len(self._cached_buf[i])
self.buffer.update(self._cached_buf[i]) if self.buffer is not None:
self.buffer.update(self._cached_buf[i])
self.reward[i], self.length[i] = 0, 0 self.reward[i], self.length[i] = 0, 0
if self._cached_buf: if self._cached_buf:
self._cached_buf[i].reset() self._cached_buf[i].reset()
@ -289,9 +289,10 @@ class Collector(object):
cur_episode.sum() >= n_episode: cur_episode.sum() >= n_episode:
break break
else: else:
self.buffer.add( if self.buffer is not None:
self._obs, self._act[0], self._rew, self.buffer.add(
self._done, obs_next, self._info) self._obs, self._act[0], self._rew,
self._done, obs_next, self._info)
cur_step += 1 cur_step += 1
if self._done: if self._done:
cur_episode += 1 cur_episode += 1

View File