diff --git a/test/base/test_utils.py b/test/base/test_utils.py index aa72272..d0e1cef 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -1,8 +1,8 @@ import torch import numpy as np -from tianshou.utils import MovAvg from tianshou.utils.net.common import MLP, Net +from tianshou.utils import MovAvg, RunningMeanStd from tianshou.exploration import GaussianNoise, OUNoise from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic @@ -30,6 +30,16 @@ def test_moving_average(): assert np.allclose(stat.std() ** 2, 2) +def test_rms(): + rms = RunningMeanStd() + assert np.allclose(rms.mean, 0) + assert np.allclose(rms.var, 1) + rms.update(np.array([[[1, 2], [3, 5]]])) + rms.update(np.array([[[1, 2], [3, 4]], [[1, 2], [0, 0]]])) + assert np.allclose(rms.mean, np.array([[1, 2], [2, 3]]), atol=1e-3) + assert np.allclose(rms.var, np.array([[0, 0], [2, 14 / 3.]]), atol=1e-3) + + def test_net(): # here test the networks that does not appear in the other script bsz = 64 @@ -79,4 +89,5 @@ def test_net(): if __name__ == '__main__': test_noise() test_moving_average() + test_rms() test_net() diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 36c613e..adefcf0 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -2,6 +2,7 @@ import gym import numpy as np from typing import Any, List, Union, Optional, Callable +from tianshou.utils import RunningMeanStd from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \ RayEnvWorker @@ -55,6 +56,13 @@ class BaseVectorEnv(gym.Env): :param float timeout: use in asynchronous simulation same as above, in each vectorized step it only deal with those environments spending time within ``timeout`` seconds. + :param bool norm_obs: Whether to track mean/std of data and normalise observation + on return. For now, observation normalization only support observation of + type np.ndarray. + :param obs_rms: class to track mean&std of observation. If not given, it will + initialize a new one. Usually in envs that is used to evaluate algorithm, + obs_rms should be passed in. Default to None. + :param bool update_obs_rms: Whether to update obs_rms. Default to True. """ def __init__( @@ -63,6 +71,9 @@ class BaseVectorEnv(gym.Env): worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker], wait_num: Optional[int] = None, timeout: Optional[float] = None, + norm_obs: bool = False, + obs_rms: Optional[RunningMeanStd] = None, + update_obs_rms: bool = True, ) -> None: self._env_fns = env_fns # A VectorEnv contains a pool of EnvWorkers, which corresponds to @@ -90,6 +101,12 @@ class BaseVectorEnv(gym.Env): self.ready_id = list(range(self.env_num)) self.is_closed = False + # initialize observation running mean/std + self.norm_obs = norm_obs + self.update_obs_rms = update_obs_rms + self.obs_rms = RunningMeanStd() if obs_rms is None and norm_obs else obs_rms + self.__eps = np.finfo(np.float32).eps.item() + def _assert_is_not_closed(self) -> None: assert not self.is_closed, \ f"Methods of {self.__class__.__name__} cannot be called after close." @@ -149,7 +166,9 @@ class BaseVectorEnv(gym.Env): if self.is_async: self._assert_id(id) obs = np.stack([self.workers[i].reset() for i in id]) - return obs + if self.obs_rms and self.update_obs_rms: + self.obs_rms.update(obs) + return self.normalize_obs(obs) def step( self, @@ -219,7 +238,10 @@ class BaseVectorEnv(gym.Env): info["env_id"] = env_id result.append((obs, rew, done, info)) self.ready_id.append(env_id) - return list(map(np.stack, zip(*result))) + obs_stack, rew_stack, done_stack, info_stack = map(np.stack, zip(*result)) + if self.obs_rms and self.update_obs_rms: + self.obs_rms.update(obs_stack) + return [self.normalize_obs(obs_stack), rew_stack, done_stack, info_stack] def seed( self, seed: Optional[Union[int, List[int]]] = None @@ -255,15 +277,23 @@ class BaseVectorEnv(gym.Env): def close(self) -> None: """Close all of the environments. - This function will be called only once (if not, it will be called - during garbage collected). This way, ``close`` of all workers can be - assured. + This function will be called only once (if not, it will be called during + garbage collected). This way, ``close`` of all workers can be assured. """ self._assert_is_not_closed() for w in self.workers: w.close() self.is_closed = True + def normalize_obs(self, obs: np.ndarray) -> np.ndarray: + """Normalize observations by statistics in obs_rms.""" + clip_max = 10.0 # this magic number is from openai baselines + # see baselines/common/vec_env/vec_normalize.py#L10 + if self.obs_rms and self.norm_obs: + obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps) + obs = np.clip(obs, -clip_max, clip_max) + return obs + def __del__(self) -> None: """Redirect to self.close().""" if not self.is_closed: @@ -275,17 +305,11 @@ class DummyVectorEnv(BaseVectorEnv): .. seealso:: - Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed - explanation. + Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ - def __init__( - self, - env_fns: List[Callable[[], gym.Env]], - wait_num: Optional[int] = None, - timeout: Optional[float] = None, - ) -> None: - super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout) + def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: + super().__init__(env_fns, DummyEnvWorker, **kwargs) class SubprocVectorEnv(BaseVectorEnv): @@ -293,20 +317,14 @@ class SubprocVectorEnv(BaseVectorEnv): .. seealso:: - Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed - explanation. + Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ - def __init__( - self, - env_fns: List[Callable[[], gym.Env]], - wait_num: Optional[int] = None, - timeout: Optional[float] = None, - ) -> None: + def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=False) - super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout) + super().__init__(env_fns, worker_fn, **kwargs) class ShmemVectorEnv(BaseVectorEnv): @@ -316,20 +334,14 @@ class ShmemVectorEnv(BaseVectorEnv): .. seealso:: - Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more - detailed explanation. + Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ - def __init__( - self, - env_fns: List[Callable[[], gym.Env]], - wait_num: Optional[int] = None, - timeout: Optional[float] = None, - ) -> None: + def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=True) - super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout) + super().__init__(env_fns, worker_fn, **kwargs) class RayVectorEnv(BaseVectorEnv): @@ -339,16 +351,10 @@ class RayVectorEnv(BaseVectorEnv): .. seealso:: - Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed - explanation. + Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage. """ - def __init__( - self, - env_fns: List[Callable[[], gym.Env]], - wait_num: Optional[int] = None, - timeout: Optional[float] = None, - ) -> None: + def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None: try: import ray except ImportError as e: @@ -357,4 +363,4 @@ class RayVectorEnv(BaseVectorEnv): ) from e if not ray.is_initialized(): ray.init() - super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout) + super().__init__(env_fns, RayEnvWorker, **kwargs) diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index b8cfa23..ccd8732 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -1,9 +1,10 @@ from tianshou.utils.config import tqdm_config -from tianshou.utils.moving_average import MovAvg +from tianshou.utils.statistics import MovAvg, RunningMeanStd from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger __all__ = [ "MovAvg", + "RunningMeanStd", "tqdm_config", "BaseLogger", "BasicLogger", diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/statistics.py similarity index 67% rename from tianshou/utils/moving_average.py rename to tianshou/utils/statistics.py index 58e8860..009ad4d 100644 --- a/tianshou/utils/moving_average.py +++ b/tianshou/utils/statistics.py @@ -66,3 +66,31 @@ class MovAvg(object): if len(self.cache) == 0: return 0 return np.std(self.cache) + + +class RunningMeanStd(object): + """Calulates the running mean and std of a data stream. + + https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + """ + + def __init__(self) -> None: + self.mean, self.var = 0.0, 1.0 + self.count = 0 + + def update(self, x: np.ndarray) -> None: + """Add a batch of item into RMS with the same shape, modify mean/var/count.""" + batch_mean, batch_var = np.mean(x, axis=0), np.var(x, axis=0) + batch_count = len(x) + + delta = batch_mean - self.mean + total_count = self.count + batch_count + + new_mean = self.mean + delta * batch_count / total_count + m_a = self.var * self.count + m_b = batch_var * batch_count + m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count + new_var = m_2 / total_count + + self.mean, self.var = new_mean, new_var + self.count = total_count