support observation normalization in BaseVectorEnv (#308)

add RunningMeanStd
This commit is contained in:
ChenDRAG 2021-03-11 20:50:20 +08:00 committed by GitHub
parent 5c53f8c1f8
commit 243ab43b3c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 43 deletions

View File

@ -1,8 +1,8 @@
import torch import torch
import numpy as np import numpy as np
from tianshou.utils import MovAvg
from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.common import MLP, Net
from tianshou.utils import MovAvg, RunningMeanStd
from tianshou.exploration import GaussianNoise, OUNoise from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
@ -30,6 +30,16 @@ def test_moving_average():
assert np.allclose(stat.std() ** 2, 2) assert np.allclose(stat.std() ** 2, 2)
def test_rms():
rms = RunningMeanStd()
assert np.allclose(rms.mean, 0)
assert np.allclose(rms.var, 1)
rms.update(np.array([[[1, 2], [3, 5]]]))
rms.update(np.array([[[1, 2], [3, 4]], [[1, 2], [0, 0]]]))
assert np.allclose(rms.mean, np.array([[1, 2], [2, 3]]), atol=1e-3)
assert np.allclose(rms.var, np.array([[0, 0], [2, 14 / 3.]]), atol=1e-3)
def test_net(): def test_net():
# here test the networks that does not appear in the other script # here test the networks that does not appear in the other script
bsz = 64 bsz = 64
@ -79,4 +89,5 @@ def test_net():
if __name__ == '__main__': if __name__ == '__main__':
test_noise() test_noise()
test_moving_average() test_moving_average()
test_rms()
test_net() test_net()

88
tianshou/env/venvs.py vendored
View File

@ -2,6 +2,7 @@ import gym
import numpy as np import numpy as np
from typing import Any, List, Union, Optional, Callable from typing import Any, List, Union, Optional, Callable
from tianshou.utils import RunningMeanStd
from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \ from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \
RayEnvWorker RayEnvWorker
@ -55,6 +56,13 @@ class BaseVectorEnv(gym.Env):
:param float timeout: use in asynchronous simulation same as above, in each :param float timeout: use in asynchronous simulation same as above, in each
vectorized step it only deal with those environments spending time vectorized step it only deal with those environments spending time
within ``timeout`` seconds. within ``timeout`` seconds.
:param bool norm_obs: Whether to track mean/std of data and normalise observation
on return. For now, observation normalization only support observation of
type np.ndarray.
:param obs_rms: class to track mean&std of observation. If not given, it will
initialize a new one. Usually in envs that is used to evaluate algorithm,
obs_rms should be passed in. Default to None.
:param bool update_obs_rms: Whether to update obs_rms. Default to True.
""" """
def __init__( def __init__(
@ -63,6 +71,9 @@ class BaseVectorEnv(gym.Env):
worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker], worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
wait_num: Optional[int] = None, wait_num: Optional[int] = None,
timeout: Optional[float] = None, timeout: Optional[float] = None,
norm_obs: bool = False,
obs_rms: Optional[RunningMeanStd] = None,
update_obs_rms: bool = True,
) -> None: ) -> None:
self._env_fns = env_fns self._env_fns = env_fns
# A VectorEnv contains a pool of EnvWorkers, which corresponds to # A VectorEnv contains a pool of EnvWorkers, which corresponds to
@ -90,6 +101,12 @@ class BaseVectorEnv(gym.Env):
self.ready_id = list(range(self.env_num)) self.ready_id = list(range(self.env_num))
self.is_closed = False self.is_closed = False
# initialize observation running mean/std
self.norm_obs = norm_obs
self.update_obs_rms = update_obs_rms
self.obs_rms = RunningMeanStd() if obs_rms is None and norm_obs else obs_rms
self.__eps = np.finfo(np.float32).eps.item()
def _assert_is_not_closed(self) -> None: def _assert_is_not_closed(self) -> None:
assert not self.is_closed, \ assert not self.is_closed, \
f"Methods of {self.__class__.__name__} cannot be called after close." f"Methods of {self.__class__.__name__} cannot be called after close."
@ -149,7 +166,9 @@ class BaseVectorEnv(gym.Env):
if self.is_async: if self.is_async:
self._assert_id(id) self._assert_id(id)
obs = np.stack([self.workers[i].reset() for i in id]) obs = np.stack([self.workers[i].reset() for i in id])
return obs if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(obs)
return self.normalize_obs(obs)
def step( def step(
self, self,
@ -219,7 +238,10 @@ class BaseVectorEnv(gym.Env):
info["env_id"] = env_id info["env_id"] = env_id
result.append((obs, rew, done, info)) result.append((obs, rew, done, info))
self.ready_id.append(env_id) self.ready_id.append(env_id)
return list(map(np.stack, zip(*result))) obs_stack, rew_stack, done_stack, info_stack = map(np.stack, zip(*result))
if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(obs_stack)
return [self.normalize_obs(obs_stack), rew_stack, done_stack, info_stack]
def seed( def seed(
self, seed: Optional[Union[int, List[int]]] = None self, seed: Optional[Union[int, List[int]]] = None
@ -255,15 +277,23 @@ class BaseVectorEnv(gym.Env):
def close(self) -> None: def close(self) -> None:
"""Close all of the environments. """Close all of the environments.
This function will be called only once (if not, it will be called This function will be called only once (if not, it will be called during
during garbage collected). This way, ``close`` of all workers can be garbage collected). This way, ``close`` of all workers can be assured.
assured.
""" """
self._assert_is_not_closed() self._assert_is_not_closed()
for w in self.workers: for w in self.workers:
w.close() w.close()
self.is_closed = True self.is_closed = True
def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
"""Normalize observations by statistics in obs_rms."""
clip_max = 10.0 # this magic number is from openai baselines
# see baselines/common/vec_env/vec_normalize.py#L10
if self.obs_rms and self.norm_obs:
obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
obs = np.clip(obs, -clip_max, clip_max)
return obs
def __del__(self) -> None: def __del__(self) -> None:
"""Redirect to self.close().""" """Redirect to self.close()."""
if not self.is_closed: if not self.is_closed:
@ -275,17 +305,11 @@ class DummyVectorEnv(BaseVectorEnv):
.. seealso:: .. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
explanation.
""" """
def __init__( def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
self, super().__init__(env_fns, DummyEnvWorker, **kwargs)
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
class SubprocVectorEnv(BaseVectorEnv): class SubprocVectorEnv(BaseVectorEnv):
@ -293,20 +317,14 @@ class SubprocVectorEnv(BaseVectorEnv):
.. seealso:: .. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
explanation.
""" """
def __init__( def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=False) return SubprocEnvWorker(fn, share_memory=False)
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout) super().__init__(env_fns, worker_fn, **kwargs)
class ShmemVectorEnv(BaseVectorEnv): class ShmemVectorEnv(BaseVectorEnv):
@ -316,20 +334,14 @@ class ShmemVectorEnv(BaseVectorEnv):
.. seealso:: .. seealso::
Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
detailed explanation.
""" """
def __init__( def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=True) return SubprocEnvWorker(fn, share_memory=True)
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout) super().__init__(env_fns, worker_fn, **kwargs)
class RayVectorEnv(BaseVectorEnv): class RayVectorEnv(BaseVectorEnv):
@ -339,16 +351,10 @@ class RayVectorEnv(BaseVectorEnv):
.. seealso:: .. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
explanation.
""" """
def __init__( def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
try: try:
import ray import ray
except ImportError as e: except ImportError as e:
@ -357,4 +363,4 @@ class RayVectorEnv(BaseVectorEnv):
) from e ) from e
if not ray.is_initialized(): if not ray.is_initialized():
ray.init() ray.init()
super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout) super().__init__(env_fns, RayEnvWorker, **kwargs)

View File

@ -1,9 +1,10 @@
from tianshou.utils.config import tqdm_config from tianshou.utils.config import tqdm_config
from tianshou.utils.moving_average import MovAvg from tianshou.utils.statistics import MovAvg, RunningMeanStd
from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger
__all__ = [ __all__ = [
"MovAvg", "MovAvg",
"RunningMeanStd",
"tqdm_config", "tqdm_config",
"BaseLogger", "BaseLogger",
"BasicLogger", "BasicLogger",

View File

@ -66,3 +66,31 @@ class MovAvg(object):
if len(self.cache) == 0: if len(self.cache) == 0:
return 0 return 0
return np.std(self.cache) return np.std(self.cache)
class RunningMeanStd(object):
"""Calulates the running mean and std of a data stream.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
"""
def __init__(self) -> None:
self.mean, self.var = 0.0, 1.0
self.count = 0
def update(self, x: np.ndarray) -> None:
"""Add a batch of item into RMS with the same shape, modify mean/var/count."""
batch_mean, batch_var = np.mean(x, axis=0), np.var(x, axis=0)
batch_count = len(x)
delta = batch_mean - self.mean
total_count = self.count + batch_count
new_mean = self.mean + delta * batch_count / total_count
m_a = self.var * self.count
m_b = batch_var * batch_count
m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
new_var = m_2 / total_count
self.mean, self.var = new_mean, new_var
self.count = total_count