enable null buffer in test collector
This commit is contained in:
parent
815f3522bb
commit
4fd826761c
@ -75,9 +75,6 @@ class Collector(object):
|
||||
self.collect_step = 0
|
||||
self.collect_episode = 0
|
||||
self.collect_time = 0
|
||||
if buffer is None:
|
||||
self.buffer = ReplayBuffer(100)
|
||||
else:
|
||||
self.buffer = buffer
|
||||
self.policy = policy
|
||||
self.process_fn = policy.process_fn
|
||||
@ -92,7 +89,7 @@ class Collector(object):
|
||||
'The number of data buffer does not match the number of ' \
|
||||
'input env.'
|
||||
self._multi_buf = True
|
||||
elif isinstance(self.buffer, ReplayBuffer):
|
||||
elif isinstance(self.buffer, ReplayBuffer) or self.buffer is None:
|
||||
self._cached_buf = [
|
||||
ListReplayBuffer() for _ in range(self.env_num)]
|
||||
else:
|
||||
@ -118,6 +115,7 @@ class Collector(object):
|
||||
for b in self.buffer:
|
||||
b.reset()
|
||||
else:
|
||||
if self.buffer is not None:
|
||||
self.buffer.reset()
|
||||
|
||||
def get_env_num(self):
|
||||
@ -265,6 +263,7 @@ class Collector(object):
|
||||
cur_step += 1
|
||||
else:
|
||||
warning_count += 1
|
||||
if self.buffer is not None:
|
||||
self.buffer.add(**data)
|
||||
cur_step += 1
|
||||
if self._done[i]:
|
||||
@ -275,6 +274,7 @@ class Collector(object):
|
||||
length_sum += self.length[i]
|
||||
if self._cached_buf:
|
||||
cur_step += len(self._cached_buf[i])
|
||||
if self.buffer is not None:
|
||||
self.buffer.update(self._cached_buf[i])
|
||||
self.reward[i], self.length[i] = 0, 0
|
||||
if self._cached_buf:
|
||||
@ -289,6 +289,7 @@ class Collector(object):
|
||||
cur_episode.sum() >= n_episode:
|
||||
break
|
||||
else:
|
||||
if self.buffer is not None:
|
||||
self.buffer.add(
|
||||
self._obs, self._act[0], self._rew,
|
||||
self._done, obs_next, self._info)
|
||||
|
0
tianshou/policy/imitation/__init__.py
Normal file
0
tianshou/policy/imitation/__init__.py
Normal file
Loading…
x
Reference in New Issue
Block a user