add random action in collector (fix #78)

Trinkle23897 2020-06-11 08:57:37 +08:00
parent 397e92b0fc
commit 1a914336f7
3 changed files with 18 additions and 4 deletions
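
The commit adds a `random` flag to `Collector.collect` so that warm-up data can be gathered with actions sampled uniformly from the action space instead of from the policy network. A minimal usage sketch, assuming the Tianshou 0.2-era API (the collector, policy, and env objects here are illustrative, not from the diff):

    collector = Collector(policy, env, ReplayBuffer(size=100))
    collector.collect(n_step=10, random=True)   # actions sampled from env.action_space
    collector.collect(n_episode=3)              # default: actions chosen by the policy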


@@ -8,7 +8,7 @@ class MyTestEnv(gym.Env):
         self.size = size
         self.sleep = sleep
         self.dict_state = dict_state
-        self.action_space = Discrete(1)
+        self.action_space = Discrete(2)
         self.reset()

     def reset(self, state=0):
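
The test env's action space is widened from Discrete(1) to Discrete(2) because `Discrete(1).sample()` can only ever return 0, which would make random collection degenerate. A quick illustration using the standard gym API:

    from gym.spaces import Discrete
    assert Discrete(1).sample() == 0         # only one possible action
    assert Discrete(2).sample() in (0, 1)    # random actions can actually vary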


@@ -56,6 +56,7 @@ def test_collector():
     env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
     venv = SubprocVectorEnv(env_fns)
+    dum = VectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
     c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
@@ -66,6 +67,7 @@ def test_collector():
     c0.collect(n_episode=3, log_fn=logger.log)
     assert np.allclose(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
     assert np.allclose(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
+    c0.collect(n_step=3, random=True)
     c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
                    preprocess_fn)
     c1.collect(n_step=6)
@@ -76,7 +78,8 @@ def test_collector():
     assert np.allclose(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
     assert np.allclose(c1.buffer[11:21].obs_next,
                        [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
-    c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
+    c1.collect(n_episode=3, random=True)
+    c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False),
                    preprocess_fn)
     c2.collect(n_episode=[1, 2, 2, 2])
     assert np.allclose(c2.buffer.obs_next[:26], [
@@ -87,6 +90,7 @@ def test_collector():
     assert np.allclose(c2.buffer.obs_next[26:54], [
         1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
         1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
+    c2.collect(n_episode=[1, 1, 1, 1], random=True)


 def test_collector_with_dict_state():
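
The new calls smoke-test random collection on a single env (c0), a subprocess vector env (c1), and an in-process VectorEnv (c2). A condensed sketch of the same pattern, with sizes and counts chosen purely for illustration:

    envs = VectorEnv([lambda: MyTestEnv(size=4) for _ in range(4)])
    c = Collector(MyPolicy(), envs, ReplayBuffer(size=100))
    c.collect(n_episode=[1, 1, 1, 1], random=True)  # one random episode per sub-env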


@@ -204,6 +204,7 @@ class Collector(object):
     def collect(self,
                 n_step: int = 0,
                 n_episode: Union[int, List[int]] = 0,
+                random: bool = False,
                 render: Optional[float] = None,
                 log_fn: Optional[Callable[[dict], None]] = None
                 ) -> Dict[str, float]:
@@ -213,6 +214,8 @@ class Collector(object):
         :param n_episode: how many episodes you want to collect (in each
             environment).
         :type n_episode: int or list
+        :param bool random: whether to use random policy for collecting data,
+            defaults to ``False``.
         :param float render: the sleep time between rendering consecutive
             frames, defaults to ``None`` (no rendering).
         :param function log_fn: a function which receives env info, typically
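
Putting the documented parameters together, a hedged example call (the logging callback is illustrative, not part of this commit):

    def log_env_info(info: dict) -> None:
        print(info)  # receives the env's info dict at each step

    collector.collect(n_step=20, random=True, log_fn=log_env_info)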
@@ -252,8 +255,15 @@ class Collector(object):
                 obs=self._obs, act=self._act, rew=self._rew,
                 done=self._done, obs_next=None, info=self._info,
                 policy=None)
-            with torch.no_grad():
-                result = self.policy(batch, self.state)
+            if random:
+                action_space = self.env.action_space
+                if isinstance(action_space, list):
+                    result = Batch(act=[a.sample() for a in action_space])
+                else:
+                    result = Batch(act=self._make_batch(action_space.sample()))
+            else:
+                with torch.no_grad():
+                    result = self.policy(batch, self.state)
             self.state = result.get('state', None)
             self._policy = to_numpy(result.policy) \
                 if hasattr(result, 'policy') else [{}] * self.env_num
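
The branch distinguishes the two shapes `self.env.action_space` takes here: a vectorized env exposes one space per sub-env (a list), while a single env exposes one space whose scalar sample must be wrapped into a length-1 batch (which, as far as this diff shows, is what `_make_batch` does). A standalone sketch of the same dispatch, independent of Collector internals:

    import numpy as np
    from gym.spaces import Discrete

    def sample_actions(action_space):
        # Vectorized case: one space per sub-env -> one sample per sub-env.
        if isinstance(action_space, list):
            return np.array([a.sample() for a in action_space])
        # Single-env case: wrap the scalar sample into a length-1 batch.
        return np.array([action_space.sample()])

    sample_actions([Discrete(2), Discrete(2)])  # e.g. array([1, 0])
    sample_actions(Discrete(2))                 # e.g. array([0])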