add random action in collector (fix #78)
This commit is contained in:
parent
397e92b0fc
commit
1a914336f7
@ -8,7 +8,7 @@ class MyTestEnv(gym.Env):
|
|||||||
self.size = size
|
self.size = size
|
||||||
self.sleep = sleep
|
self.sleep = sleep
|
||||||
self.dict_state = dict_state
|
self.dict_state = dict_state
|
||||||
self.action_space = Discrete(1)
|
self.action_space = Discrete(2)
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def reset(self, state=0):
|
def reset(self, state=0):
|
||||||
|
@ -56,6 +56,7 @@ def test_collector():
|
|||||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
|
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
|
||||||
|
|
||||||
venv = SubprocVectorEnv(env_fns)
|
venv = SubprocVectorEnv(env_fns)
|
||||||
|
dum = VectorEnv(env_fns)
|
||||||
policy = MyPolicy()
|
policy = MyPolicy()
|
||||||
env = env_fns[0]()
|
env = env_fns[0]()
|
||||||
c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
|
c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
|
||||||
@ -66,6 +67,7 @@ def test_collector():
|
|||||||
c0.collect(n_episode=3, log_fn=logger.log)
|
c0.collect(n_episode=3, log_fn=logger.log)
|
||||||
assert np.allclose(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
|
assert np.allclose(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
|
||||||
assert np.allclose(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
|
assert np.allclose(c0.buffer[:8].obs_next, [1, 2, 1, 2, 1, 2, 1, 2])
|
||||||
|
c0.collect(n_step=3, random=True)
|
||||||
c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
|
c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
|
||||||
preprocess_fn)
|
preprocess_fn)
|
||||||
c1.collect(n_step=6)
|
c1.collect(n_step=6)
|
||||||
@ -76,7 +78,8 @@ def test_collector():
|
|||||||
assert np.allclose(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
|
assert np.allclose(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
|
||||||
assert np.allclose(c1.buffer[11:21].obs_next,
|
assert np.allclose(c1.buffer[11:21].obs_next,
|
||||||
[1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
|
[1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
|
||||||
c2 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
|
c1.collect(n_episode=3, random=True)
|
||||||
|
c2 = Collector(policy, dum, ReplayBuffer(size=100, ignore_obs_next=False),
|
||||||
preprocess_fn)
|
preprocess_fn)
|
||||||
c2.collect(n_episode=[1, 2, 2, 2])
|
c2.collect(n_episode=[1, 2, 2, 2])
|
||||||
assert np.allclose(c2.buffer.obs_next[:26], [
|
assert np.allclose(c2.buffer.obs_next[:26], [
|
||||||
@ -87,6 +90,7 @@ def test_collector():
|
|||||||
assert np.allclose(c2.buffer.obs_next[26:54], [
|
assert np.allclose(c2.buffer.obs_next[26:54], [
|
||||||
1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
|
1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
|
||||||
1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
|
1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
|
||||||
|
c2.collect(n_episode=[1, 1, 1, 1], random=True)
|
||||||
|
|
||||||
|
|
||||||
def test_collector_with_dict_state():
|
def test_collector_with_dict_state():
|
||||||
|
@ -204,6 +204,7 @@ class Collector(object):
|
|||||||
def collect(self,
|
def collect(self,
|
||||||
n_step: int = 0,
|
n_step: int = 0,
|
||||||
n_episode: Union[int, List[int]] = 0,
|
n_episode: Union[int, List[int]] = 0,
|
||||||
|
random: bool = False,
|
||||||
render: Optional[float] = None,
|
render: Optional[float] = None,
|
||||||
log_fn: Optional[Callable[[dict], None]] = None
|
log_fn: Optional[Callable[[dict], None]] = None
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
@ -213,6 +214,8 @@ class Collector(object):
|
|||||||
:param n_episode: how many episodes you want to collect (in each
|
:param n_episode: how many episodes you want to collect (in each
|
||||||
environment).
|
environment).
|
||||||
:type n_episode: int or list
|
:type n_episode: int or list
|
||||||
|
:param bool random: whether to use random policy for collecting data,
|
||||||
|
defaults to ``False``.
|
||||||
:param float render: the sleep time between rendering consecutive
|
:param float render: the sleep time between rendering consecutive
|
||||||
frames, defaults to ``None`` (no rendering).
|
frames, defaults to ``None`` (no rendering).
|
||||||
:param function log_fn: a function which receives env info, typically
|
:param function log_fn: a function which receives env info, typically
|
||||||
@ -252,8 +255,15 @@ class Collector(object):
|
|||||||
obs=self._obs, act=self._act, rew=self._rew,
|
obs=self._obs, act=self._act, rew=self._rew,
|
||||||
done=self._done, obs_next=None, info=self._info,
|
done=self._done, obs_next=None, info=self._info,
|
||||||
policy=None)
|
policy=None)
|
||||||
with torch.no_grad():
|
if random:
|
||||||
result = self.policy(batch, self.state)
|
action_space = self.env.action_space
|
||||||
|
if isinstance(action_space, list):
|
||||||
|
result = Batch(act=[a.sample() for a in action_space])
|
||||||
|
else:
|
||||||
|
result = Batch(act=self._make_batch(action_space.sample()))
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
result = self.policy(batch, self.state)
|
||||||
self.state = result.get('state', None)
|
self.state = result.get('state', None)
|
||||||
self._policy = to_numpy(result.policy) \
|
self._policy = to_numpy(result.policy) \
|
||||||
if hasattr(result, 'policy') else [{}] * self.env_num
|
if hasattr(result, 'policy') else [{}] * self.env_num
|
||||||
|
Loading…
x
Reference in New Issue
Block a user