support observation normalization in BaseVectorEnv (#308)

add RunningMeanStd
ChenDRAG 2021-03-11 20:50:20 +08:00 committed by GitHub
parent 5c53f8c1f8
commit 243ab43b3c
4 changed files with 89 additions and 43 deletions

test/base/test_utils.py

@@ -1,8 +1,8 @@
import torch
import numpy as np
from tianshou.utils import MovAvg
from tianshou.utils.net.common import MLP, Net
from tianshou.utils import MovAvg, RunningMeanStd
from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
@@ -30,6 +30,16 @@ def test_moving_average():
assert np.allclose(stat.std() ** 2, 2)
def test_rms():
rms = RunningMeanStd()
assert np.allclose(rms.mean, 0)
assert np.allclose(rms.var, 1)
rms.update(np.array([[[1, 2], [3, 5]]]))
rms.update(np.array([[[1, 2], [3, 4]], [[1, 2], [0, 0]]]))
assert np.allclose(rms.mean, np.array([[1, 2], [2, 3]]), atol=1e-3)
assert np.allclose(rms.var, np.array([[0, 0], [2, 14 / 3.]]), atol=1e-3)
def test_net():
# here test the networks that do not appear in the other scripts
bsz = 64
@@ -79,4 +89,5 @@ def test_net():
if __name__ == '__main__':
test_noise()
test_moving_average()
test_rms()
test_net()
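
As a quick sanity check (not part of the commit), the expected values in test_rms are just the pooled statistics of the three observations the RMS has seen across the two update() calls, stacked along axis 0:

import numpy as np
# the three observations fed to rms in test_rms
samples = np.array([[[1, 2], [3, 5]], [[1, 2], [3, 4]], [[1, 2], [0, 0]]])
print(samples.mean(axis=0))  # [[1. 2.] [2. 3.]]
print(samples.var(axis=0))   # [[0. 0.] [2. 4.667]], i.e. 14 / 3 ≈ 4.667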

tianshou/env/venvs.py

@@ -2,6 +2,7 @@ import gym
import numpy as np
from typing import Any, List, Union, Optional, Callable
from tianshou.utils import RunningMeanStd
from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \
RayEnvWorker
@@ -55,6 +56,13 @@ class BaseVectorEnv(gym.Env):
:param float timeout: used in asynchronous simulation, same as above; in each
vectorized step it only deals with those environments that finish within
``timeout`` seconds.
:param bool norm_obs: Whether to track the running mean/std of observations and
normalize them on return. For now, observation normalization only supports
observations of type np.ndarray.
:param obs_rms: a RunningMeanStd instance used to track the mean/std of
observations. If not given, a new one will be initialized. For envs used to
evaluate the algorithm, obs_rms should usually be passed in. Default to None.
:param bool update_obs_rms: Whether to update obs_rms. Default to True.
"""
def __init__(
@@ -63,6 +71,9 @@ class BaseVectorEnv(gym.Env):
worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
norm_obs: bool = False,
obs_rms: Optional[RunningMeanStd] = None,
update_obs_rms: bool = True,
) -> None:
self._env_fns = env_fns
# A VectorEnv contains a pool of EnvWorkers, which corresponds to
@@ -90,6 +101,12 @@ class BaseVectorEnv(gym.Env):
self.ready_id = list(range(self.env_num))
self.is_closed = False
# initialize observation running mean/std
self.norm_obs = norm_obs
self.update_obs_rms = update_obs_rms
self.obs_rms = RunningMeanStd() if obs_rms is None and norm_obs else obs_rms
self.__eps = np.finfo(np.float32).eps.item()
def _assert_is_not_closed(self) -> None:
assert not self.is_closed, \
f"Methods of {self.__class__.__name__} cannot be called after close."
@@ -149,7 +166,9 @@ class BaseVectorEnv(gym.Env):
if self.is_async:
self._assert_id(id)
obs = np.stack([self.workers[i].reset() for i in id])
return obs
if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(obs)
return self.normalize_obs(obs)
def step(
self,
@@ -219,7 +238,10 @@ class BaseVectorEnv(gym.Env):
info["env_id"] = env_id
result.append((obs, rew, done, info))
self.ready_id.append(env_id)
return list(map(np.stack, zip(*result)))
obs_stack, rew_stack, done_stack, info_stack = map(np.stack, zip(*result))
if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(obs_stack)
return [self.normalize_obs(obs_stack), rew_stack, done_stack, info_stack]
def seed(
self, seed: Optional[Union[int, List[int]]] = None
@@ -255,15 +277,23 @@ class BaseVectorEnv(gym.Env):
def close(self) -> None:
"""Close all of the environments.
This function will be called only once (if not, it will be called
during garbage collected). This way, ``close`` of all workers can be
assured.
This function will be called only once (if not, it will be called during
garbage collection). This way, ``close`` of all workers can be assured.
"""
self._assert_is_not_closed()
for w in self.workers:
w.close()
self.is_closed = True
def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
"""Normalize observations by statistics in obs_rms."""
clip_max = 10.0 # this magic number is from openai baselines
# see baselines/common/vec_env/vec_normalize.py#L10
if self.obs_rms and self.norm_obs:
obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
obs = np.clip(obs, -clip_max, clip_max)
return obs
def __del__(self) -> None:
"""Redirect to self.close()."""
if not self.is_closed:
@@ -275,17 +305,11 @@ class DummyVectorEnv(BaseVectorEnv):
.. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
explanation.
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""
def __init__(
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
super().__init__(env_fns, DummyEnvWorker, **kwargs)
class SubprocVectorEnv(BaseVectorEnv):
@@ -293,20 +317,14 @@ class SubprocVectorEnv(BaseVectorEnv):
.. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
explanation.
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""
def __init__(
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=False)
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, worker_fn, **kwargs)
class ShmemVectorEnv(BaseVectorEnv):
@@ -316,20 +334,14 @@ class ShmemVectorEnv(BaseVectorEnv):
.. seealso::
Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more
detailed explanation.
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""
def __init__(
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=True)
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, worker_fn, **kwargs)
class RayVectorEnv(BaseVectorEnv):
@@ -339,16 +351,10 @@ class RayVectorEnv(BaseVectorEnv):
.. seealso::
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
explanation.
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
"""
def __init__(
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
try:
import ray
except ImportError as e:
@@ -357,4 +363,4 @@ class RayVectorEnv(BaseVectorEnv):
) from e
if not ray.is_initialized():
ray.init()
super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
super().__init__(env_fns, RayEnvWorker, **kwargs)
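
A minimal usage sketch of the new arguments (an assumption for illustration, not part of the commit; the Pendulum-v0 environment and env counts are placeholders): the training envs track and normalize observations, while the evaluation envs reuse the training statistics without updating them, as suggested by the obs_rms docstring.

import gym
from tianshou.env import DummyVectorEnv

train_envs = DummyVectorEnv(
    [lambda: gym.make("Pendulum-v0") for _ in range(8)],
    norm_obs=True,  # track running mean/std and normalize returned observations
)
test_envs = DummyVectorEnv(
    [lambda: gym.make("Pendulum-v0") for _ in range(4)],
    norm_obs=True,
    obs_rms=train_envs.obs_rms,  # reuse statistics collected during training
    update_obs_rms=False,        # freeze them for evaluation
)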

tianshou/utils/__init__.py

@@ -1,9 +1,10 @@
from tianshou.utils.config import tqdm_config
from tianshou.utils.moving_average import MovAvg
from tianshou.utils.statistics import MovAvg, RunningMeanStd
from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger
__all__ = [
"MovAvg",
"RunningMeanStd",
"tqdm_config",
"BaseLogger",
"BasicLogger",

tianshou/utils/statistics.py

@@ -66,3 +66,31 @@ class MovAvg(object):
if len(self.cache) == 0:
return 0
return np.std(self.cache)
class RunningMeanStd(object):
"""Calulates the running mean and std of a data stream.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
"""
def __init__(self) -> None:
self.mean, self.var = 0.0, 1.0
self.count = 0
def update(self, x: np.ndarray) -> None:
"""Add a batch of item into RMS with the same shape, modify mean/var/count."""
batch_mean, batch_var = np.mean(x, axis=0), np.var(x, axis=0)
batch_count = len(x)
delta = batch_mean - self.mean
total_count = self.count + batch_count
new_mean = self.mean + delta * batch_count / total_count
m_a = self.var * self.count
m_b = batch_var * batch_count
m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
new_var = m_2 / total_count
self.mean, self.var = new_mean, new_var
self.count = total_count
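
Since the update rule is the parallel mean/variance combination from the Wikipedia article cited above, feeding a stream in chunks should reproduce the pooled numpy statistics. A small sanity-check sketch (not part of the commit):

import numpy as np
from tianshou.utils import RunningMeanStd

data = np.random.randn(100, 4)
rms = RunningMeanStd()
for chunk in np.array_split(data, 7):  # arbitrary chunking of the stream
    rms.update(chunk)
assert np.allclose(rms.mean, data.mean(axis=0))
assert np.allclose(rms.var, data.var(axis=0))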