diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 496712e..92a5d99 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -75,10 +75,7 @@ class Collector(object): self.collect_step = 0 self.collect_episode = 0 self.collect_time = 0 - if buffer is None: - self.buffer = ReplayBuffer(100) - else: - self.buffer = buffer + self.buffer = buffer self.policy = policy self.process_fn = policy.process_fn self._multi_env = isinstance(env, BaseVectorEnv) @@ -92,7 +89,7 @@ class Collector(object): 'The number of data buffer does not match the number of ' \ 'input env.' self._multi_buf = True - elif isinstance(self.buffer, ReplayBuffer): + elif isinstance(self.buffer, ReplayBuffer) or self.buffer is None: self._cached_buf = [ ListReplayBuffer() for _ in range(self.env_num)] else: @@ -118,7 +115,8 @@ class Collector(object): for b in self.buffer: b.reset() else: - self.buffer.reset() + if self.buffer is not None: + self.buffer.reset() def get_env_num(self): """Return the number of environments the collector has.""" @@ -265,7 +263,8 @@ class Collector(object): cur_step += 1 else: warning_count += 1 - self.buffer.add(**data) + if self.buffer is not None: + self.buffer.add(**data) cur_step += 1 if self._done[i]: if n_step != 0 or np.isscalar(n_episode) or \ @@ -275,7 +274,8 @@ class Collector(object): length_sum += self.length[i] if self._cached_buf: cur_step += len(self._cached_buf[i]) - self.buffer.update(self._cached_buf[i]) + if self.buffer is not None: + self.buffer.update(self._cached_buf[i]) self.reward[i], self.length[i] = 0, 0 if self._cached_buf: self._cached_buf[i].reset() @@ -289,9 +289,10 @@ class Collector(object): cur_episode.sum() >= n_episode: break else: - self.buffer.add( - self._obs, self._act[0], self._rew, - self._done, obs_next, self._info) + if self.buffer is not None: + self.buffer.add( + self._obs, self._act[0], self._rew, + self._done, obs_next, self._info) cur_step += 1 if self._done: cur_episode += 1 diff --git a/tianshou/policy/imitation/__init__.py b/tianshou/policy/imitation/__init__.py new file mode 100644 index 0000000..e69de29