Vector env enable select worker (#132)

* Enable selecting worker for vector env step method.

* Update collector to match new vecenv selective worker behavior.

* Bug fix.

* Fix rebase

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-07-13 16:38:42 +02:00 committed by Trinkle23897
parent 26fb87433d
commit 09e10e384f
3 changed files with 77 additions and 75 deletions

View File

@ -353,6 +353,8 @@ class Batch:
str, slice, int, np.integer, np.ndarray, List[int]], str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None: value: Any) -> None:
"""Assign value to self[index].""" """Assign value to self[index]."""
if isinstance(value, (list, tuple)):
value = np.asanyarray(value)
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
if not issubclass(value.dtype.type, (np.bool_, np.number)): if not issubclass(value.dtype.type, (np.bool_, np.number)):
value = value.astype(np.object) value = value.astype(np.object)
@ -360,8 +362,11 @@ class Batch:
self.__dict__[index] = value self.__dict__[index] = value
return return
if not isinstance(value, (dict, Batch)): if not isinstance(value, (dict, Batch)):
raise TypeError("Batch does not supported value type " if _is_batch_set(value):
f"{type(value)} for item assignment.") value = Batch(value)
else:
raise TypeError("Batch does not supported value type "
f"{type(value)} for item assignment.")
if not set(value.keys()).issubset(self.__dict__.keys()): if not set(value.keys()).issubset(self.__dict__.keys()):
raise KeyError( raise KeyError(
"Creating keys is not supported by item assignment.") "Creating keys is not supported by item assignment.")

View File

@ -320,10 +320,13 @@ class Collector(object):
self._reset_state(i) self._reset_state(i)
obs_next = self.data.obs_next obs_next = self.data.obs_next
if sum(self.data.done): if sum(self.data.done):
obs_next = self.env.reset(np.where(self.data.done)[0]) env_ind = np.where(self.data.done)[0]
obs_reset = self.env.reset(env_ind)
if self.preprocess_fn: if self.preprocess_fn:
obs_next = self.preprocess_fn(obs=obs_next).get( obs_next[env_ind] = self.preprocess_fn(
'obs', obs_next) obs=obs_reset).get('obs', obs_reset)
else:
obs_next[env_ind] = obs_reset
self.data.obs_next = obs_next self.data.obs_next = obs_next
if n_episode != 0: if n_episode != 0:
if isinstance(n_episode, list) and \ if isinstance(n_episode, list) and \

134
tianshou/env/vecenv.py vendored
View File

@ -40,16 +40,12 @@ class BaseVectorEnv(ABC, gym.Env):
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None: def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
self._env_fns = env_fns self._env_fns = env_fns
self.env_num = len(env_fns) self.env_num = len(env_fns)
self._obs = None
self._rew = None
self._done = None
self._info = None
def __len__(self) -> int: def __len__(self) -> int:
"""Return len(self), which is the number of environments.""" """Return len(self), which is the number of environments."""
return self.env_num return self.env_num
def __getattribute__(self, key): def __getattribute__(self, key: str):
"""Switch between the default attribute getter or one """Switch between the default attribute getter or one
looking at wrapped environment level depending on the key.""" looking at wrapped environment level depending on the key."""
if key not in ('observation_space', 'action_space'): if key not in ('observation_space', 'action_space'):
@ -58,7 +54,7 @@ class BaseVectorEnv(ABC, gym.Env):
return self.__getattr__(key) return self.__getattr__(key)
@abstractmethod @abstractmethod
def __getattr__(self, key): def __getattr__(self, key: str):
"""Try to retrieve an attribute from each individual wrapped """Try to retrieve an attribute from each individual wrapped
environment, if it does not belong to the wrapping vector environment, if it does not belong to the wrapping vector
environment class.""" environment class."""
@ -73,11 +69,15 @@ class BaseVectorEnv(ABC, gym.Env):
pass pass
@abstractmethod @abstractmethod
def step(self, action: np.ndarray def step(self,
action: np.ndarray,
id: Optional[Union[int, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Run one timestep of all the environments dynamics. When the end of """Run one timestep of all the environments dynamics if id is
episode is reached, you are responsible for calling reset(id) to reset ``None``, otherwise run one timestep for some environments
this environments state. with given id, either an int or a list. When the end of
episode is reached, you are responsible for calling reset(id)
to reset this environments state.
Accept a batch of action and return a tuple (obs, rew, done, info). Accept a batch of action and return a tuple (obs, rew, done, info).
@ -143,24 +143,24 @@ class VectorEnv(BaseVectorEnv):
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None: if id is None:
self._obs = np.stack([e.reset() for e in self.envs]) id = range(self.env_num)
else: elif np.isscalar(id):
if np.isscalar(id): id = [id]
id = [id] obs = np.stack([self.envs[i].reset() for i in id])
for i in id: return obs
self._obs[i] = self.envs[i].reset()
return self._obs
def step(self, action: np.ndarray def step(self,
action: np.ndarray,
id: Optional[Union[int, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num if id is None:
result = [e.step(a) for e, a in zip(self.envs, action)] id = range(self.env_num)
self._obs, self._rew, self._done, self._info = zip(*result) elif np.isscalar(id):
self._obs = np.stack(self._obs) id = [id]
self._rew = np.stack(self._rew) assert len(action) == len(id)
self._done = np.stack(self._done) result = [self.envs[i].step(action[i]) for i in id]
self._info = np.stack(self._info) obs, rew, done, info = map(np.stack, zip(*result))
return self._obs, self._rew, self._done, self._info return obs, rew, done, info
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]: def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
if np.isscalar(seed): if np.isscalar(seed):
@ -241,33 +241,30 @@ class SubprocVectorEnv(BaseVectorEnv):
p.send(['getattr', key]) p.send(['getattr', key])
return [p.recv() for p in self.parent_remote] return [p.recv() for p in self.parent_remote]
def step(self, action: np.ndarray def step(self,
action: np.ndarray,
id: Optional[Union[int, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num if id is None:
for p, a in zip(self.parent_remote, action): id = range(self.env_num)
p.send(['step', a]) elif np.isscalar(id):
result = [p.recv() for p in self.parent_remote] id = [id]
self._obs, self._rew, self._done, self._info = zip(*result) assert len(action) == len(id)
self._obs = np.stack(self._obs) for i, j in enumerate(id):
self._rew = np.stack(self._rew) self.parent_remote[j].send(['step', action[i]])
self._done = np.stack(self._done) result = [self.parent_remote[i].recv() for i in id]
self._info = np.stack(self._info) obs, rew, done, info = map(np.stack, zip(*result))
return self._obs, self._rew, self._done, self._info return obs, rew, done, info
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None: if id is None:
for p in self.parent_remote: id = range(self.env_num)
p.send(['reset', None]) elif np.isscalar(id):
self._obs = np.stack([p.recv() for p in self.parent_remote]) id = [id]
return self._obs for i in id:
else: self.parent_remote[i].send(['reset', None])
if np.isscalar(id): obs = np.stack([self.parent_remote[i].recv() for i in id])
id = [id] return obs
for i in id:
self.parent_remote[i].send(['reset', None])
for i in id:
self._obs[i] = self.parent_remote[i].recv()
return self._obs
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]: def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
if np.isscalar(seed): if np.isscalar(seed):
@ -320,32 +317,29 @@ class RayVectorEnv(BaseVectorEnv):
for e in env_fns] for e in env_fns]
def __getattr__(self, key): def __getattr__(self, key):
return ray.get([e.getattr.remote(key) for e in self.envs]) return ray.get([e.__getattr__.remote(key) for e in self.envs])
def step(self, action: np.ndarray def step(self,
action: np.ndarray,
id: Optional[Union[int, List[int]]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num if id is None:
result = ray.get([e.step.remote(a) for e, a in zip(self.envs, action)]) id = range(self.env_num)
self._obs, self._rew, self._done, self._info = zip(*result) elif np.isscalar(id):
self._obs = np.stack(self._obs) id = [id]
self._rew = np.stack(self._rew) assert len(action) == len(id)
self._done = np.stack(self._done) result = ray.get([self.envs[j].step.remote(action[i])
self._info = np.stack(self._info) for i, j in enumerate(id)])
return self._obs, self._rew, self._done, self._info obs, rew, done, info = map(np.stack, zip(*result))
return obs, rew, done, info
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
if id is None: if id is None:
result_obj = [e.reset.remote() for e in self.envs] id = range(self.env_num)
self._obs = np.stack(ray.get(result_obj)) elif np.isscalar(id):
else: id = [id]
result_obj = [] obs = np.stack(ray.get([self.envs[i].reset.remote() for i in id]))
if np.isscalar(id): return obs
id = [id]
for i in id:
result_obj.append(self.envs[i].reset.remote())
for _, i in enumerate(id):
self._obs[i] = ray.get(result_obj[_])
return self._obs
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]: def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
if not hasattr(self.envs[0], 'seed'): if not hasattr(self.envs[0], 'seed'):