diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 415c8f3..23d8599 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -353,6 +353,8 @@ class Batch: str, slice, int, np.integer, np.ndarray, List[int]], value: Any) -> None: """Assign value to self[index].""" + if isinstance(value, (list, tuple)): + value = np.asanyarray(value) if isinstance(value, np.ndarray): if not issubclass(value.dtype.type, (np.bool_, np.number)): value = value.astype(np.object) @@ -360,8 +362,11 @@ class Batch: self.__dict__[index] = value return if not isinstance(value, (dict, Batch)): - raise TypeError("Batch does not supported value type " - f"{type(value)} for item assignment.") + if _is_batch_set(value): + value = Batch(value) + else: + raise TypeError("Batch does not supported value type " + f"{type(value)} for item assignment.") if not set(value.keys()).issubset(self.__dict__.keys()): raise KeyError( "Creating keys is not supported by item assignment.") diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 40cd739..2bacf14 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -320,10 +320,13 @@ class Collector(object): self._reset_state(i) obs_next = self.data.obs_next if sum(self.data.done): - obs_next = self.env.reset(np.where(self.data.done)[0]) + env_ind = np.where(self.data.done)[0] + obs_reset = self.env.reset(env_ind) if self.preprocess_fn: - obs_next = self.preprocess_fn(obs=obs_next).get( - 'obs', obs_next) + obs_next[env_ind] = self.preprocess_fn( + obs=obs_reset).get('obs', obs_reset) + else: + obs_next[env_ind] = obs_reset self.data.obs_next = obs_next if n_episode != 0: if isinstance(n_episode, list) and \ diff --git a/tianshou/env/vecenv.py b/tianshou/env/vecenv.py index 26e9835..93c3882 100644 --- a/tianshou/env/vecenv.py +++ b/tianshou/env/vecenv.py @@ -40,16 +40,12 @@ class BaseVectorEnv(ABC, gym.Env): def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None: self._env_fns = env_fns self.env_num = len(env_fns) - self._obs = None - self._rew = None - self._done = None - self._info = None def __len__(self) -> int: """Return len(self), which is the number of environments.""" return self.env_num - def __getattribute__(self, key): + def __getattribute__(self, key: str): """Switch between the default attribute getter or one looking at wrapped environment level depending on the key.""" if key not in ('observation_space', 'action_space'): @@ -58,7 +54,7 @@ class BaseVectorEnv(ABC, gym.Env): return self.__getattr__(key) @abstractmethod - def __getattr__(self, key): + def __getattr__(self, key: str): """Try to retrieve an attribute from each individual wrapped environment, if it does not belong to the wrapping vector environment class.""" @@ -73,11 +69,15 @@ class BaseVectorEnv(ABC, gym.Env): pass @abstractmethod - def step(self, action: np.ndarray + def step(self, + action: np.ndarray, + id: Optional[Union[int, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Run one timestep of all the environments’ dynamics. When the end of - episode is reached, you are responsible for calling reset(id) to reset - this environment’s state. + """Run one timestep of all the environments’ dynamics if id is + ``None``, otherwise run one timestep for some environments + with given id, either an int or a list. When the end of + episode is reached, you are responsible for calling reset(id) + to reset this environment’s state. Accept a batch of action and return a tuple (obs, rew, done, info). @@ -143,24 +143,24 @@ class VectorEnv(BaseVectorEnv): def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: if id is None: - self._obs = np.stack([e.reset() for e in self.envs]) - else: - if np.isscalar(id): - id = [id] - for i in id: - self._obs[i] = self.envs[i].reset() - return self._obs + id = range(self.env_num) + elif np.isscalar(id): + id = [id] + obs = np.stack([self.envs[i].reset() for i in id]) + return obs - def step(self, action: np.ndarray + def step(self, + action: np.ndarray, + id: Optional[Union[int, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - assert len(action) == self.env_num - result = [e.step(a) for e, a in zip(self.envs, action)] - self._obs, self._rew, self._done, self._info = zip(*result) - self._obs = np.stack(self._obs) - self._rew = np.stack(self._rew) - self._done = np.stack(self._done) - self._info = np.stack(self._info) - return self._obs, self._rew, self._done, self._info + if id is None: + id = range(self.env_num) + elif np.isscalar(id): + id = [id] + assert len(action) == len(id) + result = [self.envs[i].step(action[i]) for i in id] + obs, rew, done, info = map(np.stack, zip(*result)) + return obs, rew, done, info def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]: if np.isscalar(seed): @@ -241,33 +241,30 @@ class SubprocVectorEnv(BaseVectorEnv): p.send(['getattr', key]) return [p.recv() for p in self.parent_remote] - def step(self, action: np.ndarray + def step(self, + action: np.ndarray, + id: Optional[Union[int, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - assert len(action) == self.env_num - for p, a in zip(self.parent_remote, action): - p.send(['step', a]) - result = [p.recv() for p in self.parent_remote] - self._obs, self._rew, self._done, self._info = zip(*result) - self._obs = np.stack(self._obs) - self._rew = np.stack(self._rew) - self._done = np.stack(self._done) - self._info = np.stack(self._info) - return self._obs, self._rew, self._done, self._info + if id is None: + id = range(self.env_num) + elif np.isscalar(id): + id = [id] + assert len(action) == len(id) + for i, j in enumerate(id): + self.parent_remote[j].send(['step', action[i]]) + result = [self.parent_remote[i].recv() for i in id] + obs, rew, done, info = map(np.stack, zip(*result)) + return obs, rew, done, info def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: if id is None: - for p in self.parent_remote: - p.send(['reset', None]) - self._obs = np.stack([p.recv() for p in self.parent_remote]) - return self._obs - else: - if np.isscalar(id): - id = [id] - for i in id: - self.parent_remote[i].send(['reset', None]) - for i in id: - self._obs[i] = self.parent_remote[i].recv() - return self._obs + id = range(self.env_num) + elif np.isscalar(id): + id = [id] + for i in id: + self.parent_remote[i].send(['reset', None]) + obs = np.stack([self.parent_remote[i].recv() for i in id]) + return obs def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]: if np.isscalar(seed): @@ -320,32 +317,29 @@ class RayVectorEnv(BaseVectorEnv): for e in env_fns] def __getattr__(self, key): - return ray.get([e.getattr.remote(key) for e in self.envs]) + return ray.get([e.__getattr__.remote(key) for e in self.envs]) - def step(self, action: np.ndarray + def step(self, + action: np.ndarray, + id: Optional[Union[int, List[int]]] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - assert len(action) == self.env_num - result = ray.get([e.step.remote(a) for e, a in zip(self.envs, action)]) - self._obs, self._rew, self._done, self._info = zip(*result) - self._obs = np.stack(self._obs) - self._rew = np.stack(self._rew) - self._done = np.stack(self._done) - self._info = np.stack(self._info) - return self._obs, self._rew, self._done, self._info + if id is None: + id = range(self.env_num) + elif np.isscalar(id): + id = [id] + assert len(action) == len(id) + result = ray.get([self.envs[j].step.remote(action[i]) + for i, j in enumerate(id)]) + obs, rew, done, info = map(np.stack, zip(*result)) + return obs, rew, done, info def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray: if id is None: - result_obj = [e.reset.remote() for e in self.envs] - self._obs = np.stack(ray.get(result_obj)) - else: - result_obj = [] - if np.isscalar(id): - id = [id] - for i in id: - result_obj.append(self.envs[i].reset.remote()) - for _, i in enumerate(id): - self._obs[i] = ray.get(result_obj[_]) - return self._obs + id = range(self.env_num) + elif np.isscalar(id): + id = [id] + obs = np.stack(ray.get([self.envs[i].reset.remote() for i in id])) + return obs def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]: if not hasattr(self.envs[0], 'seed'):