support observation normalization in BaseVectorEnv (#308)
add RunningMeanStd
This commit is contained in:
parent
5c53f8c1f8
commit
243ab43b3c
@ -1,8 +1,8 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from tianshou.utils import MovAvg
|
||||
from tianshou.utils.net.common import MLP, Net
|
||||
from tianshou.utils import MovAvg, RunningMeanStd
|
||||
from tianshou.exploration import GaussianNoise, OUNoise
|
||||
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
|
||||
|
||||
@ -30,6 +30,16 @@ def test_moving_average():
|
||||
assert np.allclose(stat.std() ** 2, 2)
|
||||
|
||||
|
||||
def test_rms():
|
||||
rms = RunningMeanStd()
|
||||
assert np.allclose(rms.mean, 0)
|
||||
assert np.allclose(rms.var, 1)
|
||||
rms.update(np.array([[[1, 2], [3, 5]]]))
|
||||
rms.update(np.array([[[1, 2], [3, 4]], [[1, 2], [0, 0]]]))
|
||||
assert np.allclose(rms.mean, np.array([[1, 2], [2, 3]]), atol=1e-3)
|
||||
assert np.allclose(rms.var, np.array([[0, 0], [2, 14 / 3.]]), atol=1e-3)
|
||||
|
||||
|
||||
def test_net():
|
||||
# here test the networks that does not appear in the other script
|
||||
bsz = 64
|
||||
@ -79,4 +89,5 @@ def test_net():
|
||||
if __name__ == '__main__':
|
||||
test_noise()
|
||||
test_moving_average()
|
||||
test_rms()
|
||||
test_net()
|
||||
|
88
tianshou/env/venvs.py
vendored
88
tianshou/env/venvs.py
vendored
@ -2,6 +2,7 @@ import gym
|
||||
import numpy as np
|
||||
from typing import Any, List, Union, Optional, Callable
|
||||
|
||||
from tianshou.utils import RunningMeanStd
|
||||
from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \
|
||||
RayEnvWorker
|
||||
|
||||
@ -55,6 +56,13 @@ class BaseVectorEnv(gym.Env):
|
||||
:param float timeout: use in asynchronous simulation same as above, in each
|
||||
vectorized step it only deal with those environments spending time
|
||||
within ``timeout`` seconds.
|
||||
:param bool norm_obs: Whether to track mean/std of data and normalise observation
|
||||
on return. For now, observation normalization only support observation of
|
||||
type np.ndarray.
|
||||
:param obs_rms: class to track mean&std of observation. If not given, it will
|
||||
initialize a new one. Usually in envs that is used to evaluate algorithm,
|
||||
obs_rms should be passed in. Default to None.
|
||||
:param bool update_obs_rms: Whether to update obs_rms. Default to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -63,6 +71,9 @@ class BaseVectorEnv(gym.Env):
|
||||
worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
norm_obs: bool = False,
|
||||
obs_rms: Optional[RunningMeanStd] = None,
|
||||
update_obs_rms: bool = True,
|
||||
) -> None:
|
||||
self._env_fns = env_fns
|
||||
# A VectorEnv contains a pool of EnvWorkers, which corresponds to
|
||||
@ -90,6 +101,12 @@ class BaseVectorEnv(gym.Env):
|
||||
self.ready_id = list(range(self.env_num))
|
||||
self.is_closed = False
|
||||
|
||||
# initialize observation running mean/std
|
||||
self.norm_obs = norm_obs
|
||||
self.update_obs_rms = update_obs_rms
|
||||
self.obs_rms = RunningMeanStd() if obs_rms is None and norm_obs else obs_rms
|
||||
self.__eps = np.finfo(np.float32).eps.item()
|
||||
|
||||
def _assert_is_not_closed(self) -> None:
|
||||
assert not self.is_closed, \
|
||||
f"Methods of {self.__class__.__name__} cannot be called after close."
|
||||
@ -149,7 +166,9 @@ class BaseVectorEnv(gym.Env):
|
||||
if self.is_async:
|
||||
self._assert_id(id)
|
||||
obs = np.stack([self.workers[i].reset() for i in id])
|
||||
return obs
|
||||
if self.obs_rms and self.update_obs_rms:
|
||||
self.obs_rms.update(obs)
|
||||
return self.normalize_obs(obs)
|
||||
|
||||
def step(
|
||||
self,
|
||||
@ -219,7 +238,10 @@ class BaseVectorEnv(gym.Env):
|
||||
info["env_id"] = env_id
|
||||
result.append((obs, rew, done, info))
|
||||
self.ready_id.append(env_id)
|
||||
return list(map(np.stack, zip(*result)))
|
||||
obs_stack, rew_stack, done_stack, info_stack = map(np.stack, zip(*result))
|
||||
if self.obs_rms and self.update_obs_rms:
|
||||
self.obs_rms.update(obs_stack)
|
||||
return [self.normalize_obs(obs_stack), rew_stack, done_stack, info_stack]
|
||||
|
||||
def seed(
|
||||
self, seed: Optional[Union[int, List[int]]] = None
|
||||
@ -255,15 +277,23 @@ class BaseVectorEnv(gym.Env):
|
||||
def close(self) -> None:
|
||||
"""Close all of the environments.
|
||||
|
||||
This function will be called only once (if not, it will be called
|
||||
during garbage collected). This way, ``close`` of all workers can be
|
||||
assured.
|
||||
This function will be called only once (if not, it will be called during
|
||||
garbage collected). This way, ``close`` of all workers can be assured.
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
for w in self.workers:
|
||||
w.close()
|
||||
self.is_closed = True
|
||||
|
||||
def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
|
||||
"""Normalize observations by statistics in obs_rms."""
|
||||
clip_max = 10.0 # this magic number is from openai baselines
|
||||
# see baselines/common/vec_env/vec_normalize.py#L10
|
||||
if self.obs_rms and self.norm_obs:
|
||||
obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
|
||||
obs = np.clip(obs, -clip_max, clip_max)
|
||||
return obs
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Redirect to self.close()."""
|
||||
if not self.is_closed:
|
||||
@ -275,17 +305,11 @@ class DummyVectorEnv(BaseVectorEnv):
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
|
||||
explanation.
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
|
||||
super().__init__(env_fns, DummyEnvWorker, **kwargs)
|
||||
|
||||
|
||||
class SubprocVectorEnv(BaseVectorEnv):
|
||||
@ -293,20 +317,14 @@ class SubprocVectorEnv(BaseVectorEnv):
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
|
||||
explanation.
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
|
||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||
return SubprocEnvWorker(fn, share_memory=False)
|
||||
|
||||
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
|
||||
super().__init__(env_fns, worker_fn, **kwargs)
|
||||
|
||||
|
||||
class ShmemVectorEnv(BaseVectorEnv):
|
||||
@ -316,20 +334,14 @@ class ShmemVectorEnv(BaseVectorEnv):
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more
|
||||
detailed explanation.
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
|
||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||
return SubprocEnvWorker(fn, share_memory=True)
|
||||
|
||||
super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
|
||||
super().__init__(env_fns, worker_fn, **kwargs)
|
||||
|
||||
|
||||
class RayVectorEnv(BaseVectorEnv):
|
||||
@ -339,16 +351,10 @@ class RayVectorEnv(BaseVectorEnv):
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
|
||||
explanation.
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
|
||||
try:
|
||||
import ray
|
||||
except ImportError as e:
|
||||
@ -357,4 +363,4 @@ class RayVectorEnv(BaseVectorEnv):
|
||||
) from e
|
||||
if not ray.is_initialized():
|
||||
ray.init()
|
||||
super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
|
||||
super().__init__(env_fns, RayEnvWorker, **kwargs)
|
||||
|
@ -1,9 +1,10 @@
|
||||
from tianshou.utils.config import tqdm_config
|
||||
from tianshou.utils.moving_average import MovAvg
|
||||
from tianshou.utils.statistics import MovAvg, RunningMeanStd
|
||||
from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger
|
||||
|
||||
__all__ = [
|
||||
"MovAvg",
|
||||
"RunningMeanStd",
|
||||
"tqdm_config",
|
||||
"BaseLogger",
|
||||
"BasicLogger",
|
||||
|
@ -66,3 +66,31 @@ class MovAvg(object):
|
||||
if len(self.cache) == 0:
|
||||
return 0
|
||||
return np.std(self.cache)
|
||||
|
||||
|
||||
class RunningMeanStd(object):
|
||||
"""Calulates the running mean and std of a data stream.
|
||||
|
||||
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.mean, self.var = 0.0, 1.0
|
||||
self.count = 0
|
||||
|
||||
def update(self, x: np.ndarray) -> None:
|
||||
"""Add a batch of item into RMS with the same shape, modify mean/var/count."""
|
||||
batch_mean, batch_var = np.mean(x, axis=0), np.var(x, axis=0)
|
||||
batch_count = len(x)
|
||||
|
||||
delta = batch_mean - self.mean
|
||||
total_count = self.count + batch_count
|
||||
|
||||
new_mean = self.mean + delta * batch_count / total_count
|
||||
m_a = self.var * self.count
|
||||
m_b = batch_var * batch_count
|
||||
m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
|
||||
new_var = m_2 / total_count
|
||||
|
||||
self.mean, self.var = new_mean, new_var
|
||||
self.count = total_count
|
Loading…
x
Reference in New Issue
Block a user