enable null buffer in test collector

This commit is contained in:
Trinkle23897 2020-04-20 11:50:18 +08:00
parent 815f3522bb
commit 4fd826761c
2 changed files with 12 additions and 11 deletions

View File

@ -75,9 +75,6 @@ class Collector(object):
self.collect_step = 0
self.collect_episode = 0
self.collect_time = 0
if buffer is None:
self.buffer = ReplayBuffer(100)
else:
self.buffer = buffer
self.policy = policy
self.process_fn = policy.process_fn
@ -92,7 +89,7 @@ class Collector(object):
'The number of data buffer does not match the number of ' \
'input env.'
self._multi_buf = True
elif isinstance(self.buffer, ReplayBuffer):
elif isinstance(self.buffer, ReplayBuffer) or self.buffer is None:
self._cached_buf = [
ListReplayBuffer() for _ in range(self.env_num)]
else:
@ -118,6 +115,7 @@ class Collector(object):
for b in self.buffer:
b.reset()
else:
if self.buffer is not None:
self.buffer.reset()
def get_env_num(self):
@ -265,6 +263,7 @@ class Collector(object):
cur_step += 1
else:
warning_count += 1
if self.buffer is not None:
self.buffer.add(**data)
cur_step += 1
if self._done[i]:
@ -275,6 +274,7 @@ class Collector(object):
length_sum += self.length[i]
if self._cached_buf:
cur_step += len(self._cached_buf[i])
if self.buffer is not None:
self.buffer.update(self._cached_buf[i])
self.reward[i], self.length[i] = 0, 0
if self._cached_buf:
@ -289,6 +289,7 @@ class Collector(object):
cur_episode.sum() >= n_episode:
break
else:
if self.buffer is not None:
self.buffer.add(
self._obs, self._act[0], self._rew,
self._done, obs_next, self._info)

View File