enable null buffer in test collector

This commit is contained in:
Trinkle23897 2020-04-20 11:50:18 +08:00
parent 815f3522bb
commit 4fd826761c
2 changed files with 12 additions and 11 deletions

View File

@ -75,9 +75,6 @@ class Collector(object):
self.collect_step = 0 self.collect_step = 0
self.collect_episode = 0 self.collect_episode = 0
self.collect_time = 0 self.collect_time = 0
if buffer is None:
self.buffer = ReplayBuffer(100)
else:
self.buffer = buffer self.buffer = buffer
self.policy = policy self.policy = policy
self.process_fn = policy.process_fn self.process_fn = policy.process_fn
@ -92,7 +89,7 @@ class Collector(object):
'The number of data buffer does not match the number of ' \ 'The number of data buffer does not match the number of ' \
'input env.' 'input env.'
self._multi_buf = True self._multi_buf = True
elif isinstance(self.buffer, ReplayBuffer): elif isinstance(self.buffer, ReplayBuffer) or self.buffer is None:
self._cached_buf = [ self._cached_buf = [
ListReplayBuffer() for _ in range(self.env_num)] ListReplayBuffer() for _ in range(self.env_num)]
else: else:
@ -118,6 +115,7 @@ class Collector(object):
for b in self.buffer: for b in self.buffer:
b.reset() b.reset()
else: else:
if self.buffer is not None:
self.buffer.reset() self.buffer.reset()
def get_env_num(self): def get_env_num(self):
@ -265,6 +263,7 @@ class Collector(object):
cur_step += 1 cur_step += 1
else: else:
warning_count += 1 warning_count += 1
if self.buffer is not None:
self.buffer.add(**data) self.buffer.add(**data)
cur_step += 1 cur_step += 1
if self._done[i]: if self._done[i]:
@ -275,6 +274,7 @@ class Collector(object):
length_sum += self.length[i] length_sum += self.length[i]
if self._cached_buf: if self._cached_buf:
cur_step += len(self._cached_buf[i]) cur_step += len(self._cached_buf[i])
if self.buffer is not None:
self.buffer.update(self._cached_buf[i]) self.buffer.update(self._cached_buf[i])
self.reward[i], self.length[i] = 0, 0 self.reward[i], self.length[i] = 0, 0
if self._cached_buf: if self._cached_buf:
@ -289,6 +289,7 @@ class Collector(object):
cur_episode.sum() >= n_episode: cur_episode.sum() >= n_episode:
break break
else: else:
if self.buffer is not None:
self.buffer.add( self.buffer.add(
self._obs, self._act[0], self._rew, self._obs, self._act[0], self._rew,
self._done, obs_next, self._info) self._done, obs_next, self._info)

View File