add random action in collector (fix #78)
This commit is contained in:
parent
397e92b0fc
commit
1a914336f7
@ -8,7 +8,7 @@ class MyTestEnv(gym.Env):
|
||||
self.size = size
|
||||
self.sleep = sleep
|
||||
self.dict_state = dict_state
|
||||
self.action_space = Discrete(1)
|
||||
self.action_space = Discrete(2)
|
||||
self.reset()
|
||||
|
||||
def reset(self, state=0):
|
||||
|
@ -56,6 +56,7 @@ def test_collector():
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
|
||||
|
||||
venv = SubprocVectorEnv(env_fns)
|
||||
dum = VectorEnv(env_fns)
|
||||
policy = MyPolicy()
|
||||
env = env_fns[0]()
|
||||
c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
|
||||
@ -66,6 +67,7 @@ def test_collector():
|
||||
c0.collect(n_episode=3, log_fn=logger.log)
|
||||
assert np.allclose(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
|
||||
assert np.allclose(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
|
||||
c0.collect(n_step=3, random=True)
|
||||
c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
|
||||
preprocess_fn)
|
||||
c1.collect(n_step=6)
|
||||
@ -76,7 +78,8 @@ def test_collector():
|
||||
assert np.allclose(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
|
||||
assert np.allclose(c1.buffer[11:21].obs_next,
|
||||
[1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
|
||||
c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
|
||||
c1.collect(n_episode=3, random=True)
|
||||
c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False),
|
||||
preprocess_fn)
|
||||
c2.collect(n_episode=[1, 2, 2, 2])
|
||||
assert np.allclose(c2.buffer.obs_next[:26], [
|
||||
@ -87,6 +90,7 @@ def test_collector():
|
||||
assert np.allclose(c2.buffer.obs_next[26:54], [
|
||||
1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
|
||||
1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
|
||||
c2.collect(n_episode=[1, 1, 1, 1], random=True)
|
||||
|
||||
|
||||
def test_collector_with_dict_state():
|
||||
|
@ -204,6 +204,7 @@ class Collector(object):
|
||||
def collect(self,
|
||||
n_step: int = 0,
|
||||
n_episode: Union[int, List[int]] = 0,
|
||||
random: bool = False,
|
||||
render: Optional[float] = None,
|
||||
log_fn: Optional[Callable[[dict], None]] = None
|
||||
) -> Dict[str, float]:
|
||||
@ -213,6 +214,8 @@ class Collector(object):
|
||||
:param n_episode: how many episodes you want to collect (in each
|
||||
environment).
|
||||
:type n_episode: int or list
|
||||
:param bool random: whether to use random policy for collecting data,
|
||||
defaults to ``False``.
|
||||
:param float render: the sleep time between rendering consecutive
|
||||
frames, defaults to ``None`` (no rendering).
|
||||
:param function log_fn: a function which receives env info, typically
|
||||
@ -252,8 +255,15 @@ class Collector(object):
|
||||
obs=self._obs, act=self._act, rew=self._rew,
|
||||
done=self._done, obs_next=None, info=self._info,
|
||||
policy=None)
|
||||
with torch.no_grad():
|
||||
result = self.policy(batch, self.state)
|
||||
if random:
|
||||
action_space = self.env.action_space
|
||||
if isinstance(action_space, list):
|
||||
result = Batch(act=[a.sample() for a in action_space])
|
||||
else:
|
||||
result = Batch(act=self._make_batch(action_space.sample()))
|
||||
else:
|
||||
with torch.no_grad():
|
||||
result = self.policy(batch, self.state)
|
||||
self.state = result.get('state', None)
|
||||
self._policy = to_numpy(result.policy) \
|
||||
if hasattr(result, 'policy') else [{}] * self.env_num
|
||||
|
Loading…
x
Reference in New Issue
Block a user