enable null buffer in test collector
This commit is contained in:
parent
815f3522bb
commit
4fd826761c
@ -75,10 +75,7 @@ class Collector(object):
|
|||||||
self.collect_step = 0
|
self.collect_step = 0
|
||||||
self.collect_episode = 0
|
self.collect_episode = 0
|
||||||
self.collect_time = 0
|
self.collect_time = 0
|
||||||
if buffer is None:
|
self.buffer = buffer
|
||||||
self.buffer = ReplayBuffer(100)
|
|
||||||
else:
|
|
||||||
self.buffer = buffer
|
|
||||||
self.policy = policy
|
self.policy = policy
|
||||||
self.process_fn = policy.process_fn
|
self.process_fn = policy.process_fn
|
||||||
self._multi_env = isinstance(env, BaseVectorEnv)
|
self._multi_env = isinstance(env, BaseVectorEnv)
|
||||||
@ -92,7 +89,7 @@ class Collector(object):
|
|||||||
'The number of data buffer does not match the number of ' \
|
'The number of data buffer does not match the number of ' \
|
||||||
'input env.'
|
'input env.'
|
||||||
self._multi_buf = True
|
self._multi_buf = True
|
||||||
elif isinstance(self.buffer, ReplayBuffer):
|
elif isinstance(self.buffer, ReplayBuffer) or self.buffer is None:
|
||||||
self._cached_buf = [
|
self._cached_buf = [
|
||||||
ListReplayBuffer() for _ in range(self.env_num)]
|
ListReplayBuffer() for _ in range(self.env_num)]
|
||||||
else:
|
else:
|
||||||
@ -118,7 +115,8 @@ class Collector(object):
|
|||||||
for b in self.buffer:
|
for b in self.buffer:
|
||||||
b.reset()
|
b.reset()
|
||||||
else:
|
else:
|
||||||
self.buffer.reset()
|
if self.buffer is not None:
|
||||||
|
self.buffer.reset()
|
||||||
|
|
||||||
def get_env_num(self):
|
def get_env_num(self):
|
||||||
"""Return the number of environments the collector has."""
|
"""Return the number of environments the collector has."""
|
||||||
@ -265,7 +263,8 @@ class Collector(object):
|
|||||||
cur_step += 1
|
cur_step += 1
|
||||||
else:
|
else:
|
||||||
warning_count += 1
|
warning_count += 1
|
||||||
self.buffer.add(**data)
|
if self.buffer is not None:
|
||||||
|
self.buffer.add(**data)
|
||||||
cur_step += 1
|
cur_step += 1
|
||||||
if self._done[i]:
|
if self._done[i]:
|
||||||
if n_step != 0 or np.isscalar(n_episode) or \
|
if n_step != 0 or np.isscalar(n_episode) or \
|
||||||
@ -275,7 +274,8 @@ class Collector(object):
|
|||||||
length_sum += self.length[i]
|
length_sum += self.length[i]
|
||||||
if self._cached_buf:
|
if self._cached_buf:
|
||||||
cur_step += len(self._cached_buf[i])
|
cur_step += len(self._cached_buf[i])
|
||||||
self.buffer.update(self._cached_buf[i])
|
if self.buffer is not None:
|
||||||
|
self.buffer.update(self._cached_buf[i])
|
||||||
self.reward[i], self.length[i] = 0, 0
|
self.reward[i], self.length[i] = 0, 0
|
||||||
if self._cached_buf:
|
if self._cached_buf:
|
||||||
self._cached_buf[i].reset()
|
self._cached_buf[i].reset()
|
||||||
@ -289,9 +289,10 @@ class Collector(object):
|
|||||||
cur_episode.sum() >= n_episode:
|
cur_episode.sum() >= n_episode:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
self.buffer.add(
|
if self.buffer is not None:
|
||||||
self._obs, self._act[0], self._rew,
|
self.buffer.add(
|
||||||
self._done, obs_next, self._info)
|
self._obs, self._act[0], self._rew,
|
||||||
|
self._done, obs_next, self._info)
|
||||||
cur_step += 1
|
cur_step += 1
|
||||||
if self._done:
|
if self._done:
|
||||||
cur_episode += 1
|
cur_episode += 1
|
||||||
|
0
tianshou/policy/imitation/__init__.py
Normal file
0
tianshou/policy/imitation/__init__.py
Normal file
Loading…
x
Reference in New Issue
Block a user