Vector env enable select worker (#132)

* Enable selecting worker for vector env step method.

* Update collector to match new vecenv selective worker behavior.

* Bug fix.

* Fix rebase

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-07-13 16:38:42 +02:00 committed by Trinkle23897
parent 26fb87433d
commit 09e10e384f
3 changed files with 77 additions and 75 deletions

View File

@ -353,6 +353,8 @@ class Batch:
str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None:
"""Assign value to self[index]."""
if isinstance(value, (list, tuple)):
value = np.asanyarray(value)
if isinstance(value, np.ndarray):
if not issubclass(value.dtype.type, (np.bool_, np.number)):
value = value.astype(np.object)
@ -360,8 +362,11 @@ class Batch:
self.__dict__[index] = value
return
if not isinstance(value, (dict, Batch)):
raise TypeError("Batch does not supported value type "
f"{type(value)} for item assignment.")
if _is_batch_set(value):
value = Batch(value)
else:
raise TypeError("Batch does not supported value type "
f"{type(value)} for item assignment.")
if not set(value.keys()).issubset(self.__dict__.keys()):
raise KeyError(
"Creating keys is not supported by item assignment.")

View File

@ -320,10 +320,13 @@ class Collector(object):
self._reset_state(i)
obs_next = self.data.obs_next
if sum(self.data.done):
obs_next = self.env.reset(np.where(self.data.done)[0])
env_ind = np.where(self.data.done)[0]
obs_reset = self.env.reset(env_ind)
if self.preprocess_fn:
obs_next = self.preprocess_fn(obs=obs_next).get(
'obs', obs_next)
obs_next[env_ind] = self.preprocess_fn(
obs=obs_reset).get('obs', obs_reset)
else:
obs_next[env_ind] = obs_reset
self.data.obs_next = obs_next
if n_episode != 0:
if isinstance(n_episode, list) and \

134
tianshou/env/vecenv.py vendored
View File

@ -40,16 +40,12 @@ class BaseVectorEnv(ABC, gym.Env):
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
self._env_fns = env_fns
self.env_num = len(env_fns)
self._obs = None
self._rew = None
self._done = None
self._info = None
def __len__(self) -> int:
"""Return len(self), which is the number of environments."""
return self.env_num
def __getattribute__(self, key):
def __getattribute__(self, key: str):
"""Switch between the default attribute getter or one
looking at wrapped environment level depending on the key."""
if key not in ('observation_space', 'action_space'):
@ -58,7 +54,7 @@ class BaseVectorEnv(ABC, gym.Env):
return self.__getattr__(key)
@abstractmethod
def __getattr__(self, key):
def __getattr__(self, key: str):
"""Try to retrieve an attribute from each individual wrapped
environment, if it does not belong to the wrapping vector
environment class."""
@ -73,11 +69,15 @@ class BaseVectorEnv(ABC, gym.Env):
pass
@abstractmethod
def step(self, action: np.ndarray
def step(self,
action: np.ndarray,
id: Optional[Union[int, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Run one timestep of all the environments dynamics. When the end of
episode is reached, you are responsible for calling reset(id) to reset
this environments state.
"""Run one timestep of all the environments dynamics if id is
``None``, otherwise run one timestep for some environments
with given id, either an int or a list. When the end of
episode is reached, you are responsible for calling reset(id)
to reset this environments state.
Accept a batch of action and return a tuple (obs, rew, done, info).
@ -143,24 +143,24 @@ class VectorEnv(BaseVectorEnv):
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None:
self._obs = np.stack([e.reset() for e in self.envs])
else:
if np.isscalar(id):
id = [id]
for i in id:
self._obs[i] = self.envs[i].reset()
return self._obs
id = range(self.env_num)
elif np.isscalar(id):
id = [id]
obs = np.stack([self.envs[i].reset() for i in id])
return obs
def step(self, action: np.ndarray
def step(self,
action: np.ndarray,
id: Optional[Union[int, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num
result = [e.step(a) for e, a in zip(self.envs, action)]
self._obs, self._rew, self._done, self._info = zip(*result)
self._obs = np.stack(self._obs)
self._rew = np.stack(self._rew)
self._done = np.stack(self._done)
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info
if id is None:
id = range(self.env_num)
elif np.isscalar(id):
id = [id]
assert len(action) == len(id)
result = [self.envs[i].step(action[i]) for i in id]
obs, rew, done, info = map(np.stack, zip(*result))
return obs, rew, done, info
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
if np.isscalar(seed):
@ -241,33 +241,30 @@ class SubprocVectorEnv(BaseVectorEnv):
p.send(['getattr', key])
return [p.recv() for p in self.parent_remote]
def step(self, action: np.ndarray
def step(self,
action: np.ndarray,
id: Optional[Union[int, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num
for p, a in zip(self.parent_remote, action):
p.send(['step', a])
result = [p.recv() for p in self.parent_remote]
self._obs, self._rew, self._done, self._info = zip(*result)
self._obs = np.stack(self._obs)
self._rew = np.stack(self._rew)
self._done = np.stack(self._done)
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info
if id is None:
id = range(self.env_num)
elif np.isscalar(id):
id = [id]
assert len(action) == len(id)
for i, j in enumerate(id):
self.parent_remote[j].send(['step', action[i]])
result = [self.parent_remote[i].recv() for i in id]
obs, rew, done, info = map(np.stack, zip(*result))
return obs, rew, done, info
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None:
for p in self.parent_remote:
p.send(['reset', None])
self._obs = np.stack([p.recv() for p in self.parent_remote])
return self._obs
else:
if np.isscalar(id):
id = [id]
for i in id:
self.parent_remote[i].send(['reset', None])
for i in id:
self._obs[i] = self.parent_remote[i].recv()
return self._obs
id = range(self.env_num)
elif np.isscalar(id):
id = [id]
for i in id:
self.parent_remote[i].send(['reset', None])
obs = np.stack([self.parent_remote[i].recv() for i in id])
return obs
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
if np.isscalar(seed):
@ -320,32 +317,29 @@ class RayVectorEnv(BaseVectorEnv):
for e in env_fns]
def __getattr__(self, key):
return ray.get([e.getattr.remote(key) for e in self.envs])
return ray.get([e.__getattr__.remote(key) for e in self.envs])
def step(self, action: np.ndarray
def step(self,
action: np.ndarray,
id: Optional[Union[int, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num
result = ray.get([e.step.remote(a) for e, a in zip(self.envs, action)])
self._obs, self._rew, self._done, self._info = zip(*result)
self._obs = np.stack(self._obs)
self._rew = np.stack(self._rew)
self._done = np.stack(self._done)
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info
if id is None:
id = range(self.env_num)
elif np.isscalar(id):
id = [id]
assert len(action) == len(id)
result = ray.get([self.envs[j].step.remote(action[i])
for i, j in enumerate(id)])
obs, rew, done, info = map(np.stack, zip(*result))
return obs, rew, done, info
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None:
result_obj = [e.reset.remote() for e in self.envs]
self._obs = np.stack(ray.get(result_obj))
else:
result_obj = []
if np.isscalar(id):
id = [id]
for i in id:
result_obj.append(self.envs[i].reset.remote())
for _, i in enumerate(id):
self._obs[i] = ray.get(result_obj[_])
return self._obs
id = range(self.env_num)
elif np.isscalar(id):
id = [id]
obs = np.stack(ray.get([self.envs[i].reset.remote() for i in id]))
return obs
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
if not hasattr(self.envs[0], 'seed'):