From 1a914336f780d68a20afd8c056b07f21c261fc72 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Thu, 11 Jun 2020 08:57:37 +0800
Subject: [PATCH] add random action in collector (fix #78)

---
 test/base/env.py            |  2 +-
 test/base/test_collector.py |  6 +++++-
 tianshou/data/collector.py  | 14 ++++++++++++--
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/test/base/env.py b/test/base/env.py
index 73ed9a5..1aa409f 100644
--- a/test/base/env.py
+++ b/test/base/env.py
@@ -8,7 +8,7 @@ class MyTestEnv(gym.Env):
         self.size = size
         self.sleep = sleep
         self.dict_state = dict_state
-        self.action_space = Discrete(1)
+        self.action_space = Discrete(2)
         self.reset()
 
     def reset(self, state=0):
diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index 183397a..05973f4 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -56,6 +56,7 @@ def test_collector():
     env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
 
     venv = SubprocVectorEnv(env_fns)
+    dum = VectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
     c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
@@ -66,6 +67,7 @@ def test_collector():
     c0.collect(n_episode=3, log_fn=logger.log)
     assert np.allclose(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
     assert np.allclose(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
+    c0.collect(n_step=3, random=True)
     c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
                    preprocess_fn)
     c1.collect(n_step=6)
@@ -76,7 +78,8 @@ def test_collector():
     assert np.allclose(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
     assert np.allclose(c1.buffer[11:21].obs_next,
                        [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
-    c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
+    c1.collect(n_episode=3, random=True)
+    c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False),
                    preprocess_fn)
     c2.collect(n_episode=[1, 2, 2, 2])
     assert np.allclose(c2.buffer.obs_next[:26], [
@@ -87,6 +90,7 @@ def test_collector():
     assert np.allclose(c2.buffer.obs_next[26:54], [
         1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
         1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
+    c2.collect(n_episode=[1, 1, 1, 1], random=True)
 
 
 def test_collector_with_dict_state():
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 92d39e9..ed75b4e 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -204,6 +204,7 @@ class Collector(object):
     def collect(self,
                 n_step: int = 0,
                 n_episode: Union[int, List[int]] = 0,
+                random: bool = False,
                 render: Optional[float] = None,
                 log_fn: Optional[Callable[[dict], None]] = None
                 ) -> Dict[str, float]:
@@ -213,6 +214,8 @@ class Collector(object):
         :param n_episode: how many episodes you want to collect (in each
             environment).
         :type n_episode: int or list
+        :param bool random: whether to use random policy for collecting data,
+            defaults to ``False``.
         :param float render: the sleep time between rendering consecutive
             frames, defaults to ``None`` (no rendering).
         :param function log_fn: a function which receives env info, typically
@@ -252,8 +255,15 @@ class Collector(object):
                 obs=self._obs, act=self._act, rew=self._rew,
                 done=self._done, obs_next=None, info=self._info,
                 policy=None)
-            with torch.no_grad():
-                result = self.policy(batch, self.state)
+            if random:
+                action_space = self.env.action_space
+                if isinstance(action_space, list):
+                    result = Batch(act=[a.sample() for a in action_space])
+                else:
+                    result = Batch(act=self._make_batch(action_space.sample()))
+            else:
+                with torch.no_grad():
+                    result = self.policy(batch, self.state)
             self.state = result.get('state', None)
             self._policy = to_numpy(result.policy) \
                 if hasattr(result, 'policy') else [{}] * self.env_num
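
Note (not part of the patch): a minimal usage sketch of the new `random` flag.
The environment choice, buffer size, and the DummyPolicy class below are
illustrative assumptions, not code from this change; DummyPolicy only mirrors
the MyPolicy stub used in the tests. With random=True the policy network is
never queried, and actions come from env.action_space.sample() instead.

    import gym
    import numpy as np
    from tianshou.data import Batch, Collector, ReplayBuffer
    from tianshou.policy import BasePolicy

    class DummyPolicy(BasePolicy):
        """Placeholder policy; bypassed entirely when collect(..., random=True)."""
        def forward(self, batch, state=None, **kwargs):
            # always act 0 (only reached when random=False)
            return Batch(act=np.zeros(len(batch.obs), dtype=int))

        def learn(self, batch, **kwargs):
            pass

    env = gym.make('CartPole-v0')
    collector = Collector(DummyPolicy(), env, ReplayBuffer(size=1000))

    # warm up the buffer with uniformly sampled actions instead of policy output
    collector.collect(n_step=100, random=True)

    # afterwards, collect with the policy as usual
    collector.collect(n_step=100)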