Vector env enable select worker (#132)
* Enable selecting worker for vector env step method. * Update collector to match new vecenv selective worker behavior. * Bug fix. * Fix rebase Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
26fb87433d
commit
09e10e384f
@ -353,6 +353,8 @@ class Batch:
|
||||
str, slice, int, np.integer, np.ndarray, List[int]],
|
||||
value: Any) -> None:
|
||||
"""Assign value to self[index]."""
|
||||
if isinstance(value, (list, tuple)):
|
||||
value = np.asanyarray(value)
|
||||
if isinstance(value, np.ndarray):
|
||||
if not issubclass(value.dtype.type, (np.bool_, np.number)):
|
||||
value = value.astype(np.object)
|
||||
@ -360,8 +362,11 @@ class Batch:
|
||||
self.__dict__[index] = value
|
||||
return
|
||||
if not isinstance(value, (dict, Batch)):
|
||||
raise TypeError("Batch does not supported value type "
|
||||
f"{type(value)} for item assignment.")
|
||||
if _is_batch_set(value):
|
||||
value = Batch(value)
|
||||
else:
|
||||
raise TypeError("Batch does not supported value type "
|
||||
f"{type(value)} for item assignment.")
|
||||
if not set(value.keys()).issubset(self.__dict__.keys()):
|
||||
raise KeyError(
|
||||
"Creating keys is not supported by item assignment.")
|
||||
|
@ -320,10 +320,13 @@ class Collector(object):
|
||||
self._reset_state(i)
|
||||
obs_next = self.data.obs_next
|
||||
if sum(self.data.done):
|
||||
obs_next = self.env.reset(np.where(self.data.done)[0])
|
||||
env_ind = np.where(self.data.done)[0]
|
||||
obs_reset = self.env.reset(env_ind)
|
||||
if self.preprocess_fn:
|
||||
obs_next = self.preprocess_fn(obs=obs_next).get(
|
||||
'obs', obs_next)
|
||||
obs_next[env_ind] = self.preprocess_fn(
|
||||
obs=obs_reset).get('obs', obs_reset)
|
||||
else:
|
||||
obs_next[env_ind] = obs_reset
|
||||
self.data.obs_next = obs_next
|
||||
if n_episode != 0:
|
||||
if isinstance(n_episode, list) and \
|
||||
|
134
tianshou/env/vecenv.py
vendored
134
tianshou/env/vecenv.py
vendored
@ -40,16 +40,12 @@ class BaseVectorEnv(ABC, gym.Env):
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
|
||||
self._env_fns = env_fns
|
||||
self.env_num = len(env_fns)
|
||||
self._obs = None
|
||||
self._rew = None
|
||||
self._done = None
|
||||
self._info = None
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self), which is the number of environments."""
|
||||
return self.env_num
|
||||
|
||||
def __getattribute__(self, key):
|
||||
def __getattribute__(self, key: str):
|
||||
"""Switch between the default attribute getter or one
|
||||
looking at wrapped environment level depending on the key."""
|
||||
if key not in ('observation_space', 'action_space'):
|
||||
@ -58,7 +54,7 @@ class BaseVectorEnv(ABC, gym.Env):
|
||||
return self.__getattr__(key)
|
||||
|
||||
@abstractmethod
|
||||
def __getattr__(self, key):
|
||||
def __getattr__(self, key: str):
|
||||
"""Try to retrieve an attribute from each individual wrapped
|
||||
environment, if it does not belong to the wrapping vector
|
||||
environment class."""
|
||||
@ -73,11 +69,15 @@ class BaseVectorEnv(ABC, gym.Env):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def step(self, action: np.ndarray
|
||||
def step(self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Run one timestep of all the environments’ dynamics. When the end of
|
||||
episode is reached, you are responsible for calling reset(id) to reset
|
||||
this environment’s state.
|
||||
"""Run one timestep of all the environments’ dynamics if id is
|
||||
``None``, otherwise run one timestep for some environments
|
||||
with given id, either an int or a list. When the end of
|
||||
episode is reached, you are responsible for calling reset(id)
|
||||
to reset this environment’s state.
|
||||
|
||||
Accept a batch of action and return a tuple (obs, rew, done, info).
|
||||
|
||||
@ -143,24 +143,24 @@ class VectorEnv(BaseVectorEnv):
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||
if id is None:
|
||||
self._obs = np.stack([e.reset() for e in self.envs])
|
||||
else:
|
||||
if np.isscalar(id):
|
||||
id = [id]
|
||||
for i in id:
|
||||
self._obs[i] = self.envs[i].reset()
|
||||
return self._obs
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
obs = np.stack([self.envs[i].reset() for i in id])
|
||||
return obs
|
||||
|
||||
def step(self, action: np.ndarray
|
||||
def step(self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
assert len(action) == self.env_num
|
||||
result = [e.step(a) for e, a in zip(self.envs, action)]
|
||||
self._obs, self._rew, self._done, self._info = zip(*result)
|
||||
self._obs = np.stack(self._obs)
|
||||
self._rew = np.stack(self._rew)
|
||||
self._done = np.stack(self._done)
|
||||
self._info = np.stack(self._info)
|
||||
return self._obs, self._rew, self._done, self._info
|
||||
if id is None:
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
assert len(action) == len(id)
|
||||
result = [self.envs[i].step(action[i]) for i in id]
|
||||
obs, rew, done, info = map(np.stack, zip(*result))
|
||||
return obs, rew, done, info
|
||||
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
|
||||
if np.isscalar(seed):
|
||||
@ -241,33 +241,30 @@ class SubprocVectorEnv(BaseVectorEnv):
|
||||
p.send(['getattr', key])
|
||||
return [p.recv() for p in self.parent_remote]
|
||||
|
||||
def step(self, action: np.ndarray
|
||||
def step(self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
assert len(action) == self.env_num
|
||||
for p, a in zip(self.parent_remote, action):
|
||||
p.send(['step', a])
|
||||
result = [p.recv() for p in self.parent_remote]
|
||||
self._obs, self._rew, self._done, self._info = zip(*result)
|
||||
self._obs = np.stack(self._obs)
|
||||
self._rew = np.stack(self._rew)
|
||||
self._done = np.stack(self._done)
|
||||
self._info = np.stack(self._info)
|
||||
return self._obs, self._rew, self._done, self._info
|
||||
if id is None:
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
assert len(action) == len(id)
|
||||
for i, j in enumerate(id):
|
||||
self.parent_remote[j].send(['step', action[i]])
|
||||
result = [self.parent_remote[i].recv() for i in id]
|
||||
obs, rew, done, info = map(np.stack, zip(*result))
|
||||
return obs, rew, done, info
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||
if id is None:
|
||||
for p in self.parent_remote:
|
||||
p.send(['reset', None])
|
||||
self._obs = np.stack([p.recv() for p in self.parent_remote])
|
||||
return self._obs
|
||||
else:
|
||||
if np.isscalar(id):
|
||||
id = [id]
|
||||
for i in id:
|
||||
self.parent_remote[i].send(['reset', None])
|
||||
for i in id:
|
||||
self._obs[i] = self.parent_remote[i].recv()
|
||||
return self._obs
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
for i in id:
|
||||
self.parent_remote[i].send(['reset', None])
|
||||
obs = np.stack([self.parent_remote[i].recv() for i in id])
|
||||
return obs
|
||||
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
|
||||
if np.isscalar(seed):
|
||||
@ -320,32 +317,29 @@ class RayVectorEnv(BaseVectorEnv):
|
||||
for e in env_fns]
|
||||
|
||||
def __getattr__(self, key):
|
||||
return ray.get([e.getattr.remote(key) for e in self.envs])
|
||||
return ray.get([e.__getattr__.remote(key) for e in self.envs])
|
||||
|
||||
def step(self, action: np.ndarray
|
||||
def step(self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
assert len(action) == self.env_num
|
||||
result = ray.get([e.step.remote(a) for e, a in zip(self.envs, action)])
|
||||
self._obs, self._rew, self._done, self._info = zip(*result)
|
||||
self._obs = np.stack(self._obs)
|
||||
self._rew = np.stack(self._rew)
|
||||
self._done = np.stack(self._done)
|
||||
self._info = np.stack(self._info)
|
||||
return self._obs, self._rew, self._done, self._info
|
||||
if id is None:
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
assert len(action) == len(id)
|
||||
result = ray.get([self.envs[j].step.remote(action[i])
|
||||
for i, j in enumerate(id)])
|
||||
obs, rew, done, info = map(np.stack, zip(*result))
|
||||
return obs, rew, done, info
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||
if id is None:
|
||||
result_obj = [e.reset.remote() for e in self.envs]
|
||||
self._obs = np.stack(ray.get(result_obj))
|
||||
else:
|
||||
result_obj = []
|
||||
if np.isscalar(id):
|
||||
id = [id]
|
||||
for i in id:
|
||||
result_obj.append(self.envs[i].reset.remote())
|
||||
for _, i in enumerate(id):
|
||||
self._obs[i] = ray.get(result_obj[_])
|
||||
return self._obs
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
obs = np.stack(ray.get([self.envs[i].reset.remote() for i in id]))
|
||||
return obs
|
||||
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
|
||||
if not hasattr(self.envs[0], 'seed'):
|
||||
|
Loading…
x
Reference in New Issue
Block a user