support observation normalization in BaseVectorEnv (#308)
add RunningMeanStd
parent 5c53f8c1f8
commit 243ab43b3c
test/base/test_utils.py

@@ -1,8 +1,8 @@
 import torch
 import numpy as np
 
-from tianshou.utils import MovAvg
 from tianshou.utils.net.common import MLP, Net
+from tianshou.utils import MovAvg, RunningMeanStd
 from tianshou.exploration import GaussianNoise, OUNoise
 from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
 
@@ -30,6 +30,16 @@ def test_moving_average():
     assert np.allclose(stat.std() ** 2, 2)
 
 
+def test_rms():
+    rms = RunningMeanStd()
+    assert np.allclose(rms.mean, 0)
+    assert np.allclose(rms.var, 1)
+    rms.update(np.array([[[1, 2], [3, 5]]]))
+    rms.update(np.array([[[1, 2], [3, 4]], [[1, 2], [0, 0]]]))
+    assert np.allclose(rms.mean, np.array([[1, 2], [2, 3]]), atol=1e-3)
+    assert np.allclose(rms.var, np.array([[0, 0], [2, 14 / 3.]]), atol=1e-3)
+
+
 def test_net():
     # here test the networks that does not appear in the other script
     bsz = 64
@@ -79,4 +89,5 @@ def test_net():
 if __name__ == '__main__':
     test_noise()
     test_moving_average()
+    test_rms()
     test_net()
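The expected mean and variance in test_rms are simply the pooled statistics of the three (2, 2) samples fed through the two update calls: because the running estimator starts from a count of zero, the incremental result coincides with the pooled one. A quick sanity check with plain NumPy, independent of the RunningMeanStd implementation:

import numpy as np

# the two batches passed to rms.update() in test_rms
batch_1 = np.array([[[1, 2], [3, 5]]])
batch_2 = np.array([[[1, 2], [3, 4]], [[1, 2], [0, 0]]])

# pool all three samples, then take the element-wise mean and population variance
pooled = np.concatenate([batch_1, batch_2], axis=0)
assert np.allclose(pooled.mean(axis=0), np.array([[1, 2], [2, 3]]))
assert np.allclose(pooled.var(axis=0), np.array([[0, 0], [2, 14 / 3.]]))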
tianshou/env/venvs.py
@@ -2,6 +2,7 @@ import gym
 import numpy as np
 from typing import Any, List, Union, Optional, Callable
 
+from tianshou.utils import RunningMeanStd
 from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \
     RayEnvWorker
 
@@ -55,6 +56,13 @@ class BaseVectorEnv(gym.Env):
     :param float timeout: use in asynchronous simulation same as above, in each
         vectorized step it only deal with those environments spending time
         within ``timeout`` seconds.
+    :param bool norm_obs: Whether to track mean/std of data and normalize observation
+        on return. For now, observation normalization only supports observations of
+        type np.ndarray.
+    :param obs_rms: class to track mean and std of observations. If not given, it
+        will initialize a new one. Usually the obs_rms of the training envs should
+        be passed in when the vector env is used for evaluation. Default to None.
+    :param bool update_obs_rms: Whether to update obs_rms. Default to True.
     """
 
     def __init__(
@@ -63,6 +71,9 @@ class BaseVectorEnv(gym.Env):
         worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
         wait_num: Optional[int] = None,
         timeout: Optional[float] = None,
+        norm_obs: bool = False,
+        obs_rms: Optional[RunningMeanStd] = None,
+        update_obs_rms: bool = True,
     ) -> None:
         self._env_fns = env_fns
         # A VectorEnv contains a pool of EnvWorkers, which corresponds to
@@ -90,6 +101,12 @@ class BaseVectorEnv(gym.Env):
         self.ready_id = list(range(self.env_num))
         self.is_closed = False
 
+        # initialize observation running mean/std
+        self.norm_obs = norm_obs
+        self.update_obs_rms = update_obs_rms
+        self.obs_rms = RunningMeanStd() if obs_rms is None and norm_obs else obs_rms
+        self.__eps = np.finfo(np.float32).eps.item()
+
     def _assert_is_not_closed(self) -> None:
         assert not self.is_closed, \
             f"Methods of {self.__class__.__name__} cannot be called after close."
@@ -149,7 +166,9 @@ class BaseVectorEnv(gym.Env):
         if self.is_async:
             self._assert_id(id)
         obs = np.stack([self.workers[i].reset() for i in id])
-        return obs
+        if self.obs_rms and self.update_obs_rms:
+            self.obs_rms.update(obs)
+        return self.normalize_obs(obs)
 
     def step(
         self,
@@ -219,7 +238,10 @@ class BaseVectorEnv(gym.Env):
                 info["env_id"] = env_id
                 result.append((obs, rew, done, info))
                 self.ready_id.append(env_id)
-        return list(map(np.stack, zip(*result)))
+        obs_stack, rew_stack, done_stack, info_stack = map(np.stack, zip(*result))
+        if self.obs_rms and self.update_obs_rms:
+            self.obs_rms.update(obs_stack)
+        return [self.normalize_obs(obs_stack), rew_stack, done_stack, info_stack]
 
     def seed(
         self, seed: Optional[Union[int, List[int]]] = None
@@ -255,15 +277,23 @@ class BaseVectorEnv(gym.Env):
     def close(self) -> None:
         """Close all of the environments.
 
-        This function will be called only once (if not, it will be called
-        during garbage collected). This way, ``close`` of all workers can be
-        assured.
+        This function will be called only once (if not, it will be called during
+        garbage collection). This way, ``close`` of all workers can be assured.
         """
         self._assert_is_not_closed()
         for w in self.workers:
             w.close()
         self.is_closed = True
 
+    def normalize_obs(self, obs: np.ndarray) -> np.ndarray:
+        """Normalize observations by statistics in obs_rms."""
+        clip_max = 10.0  # this magic number is from openai baselines
+        # see baselines/common/vec_env/vec_normalize.py#L10
+        if self.obs_rms and self.norm_obs:
+            obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
+            obs = np.clip(obs, -clip_max, clip_max)
+        return obs
+
     def __del__(self) -> None:
         """Redirect to self.close()."""
         if not self.is_closed:
@@ -275,17 +305,11 @@ class DummyVectorEnv(BaseVectorEnv):
 
     .. seealso::
 
-        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
-        explanation.
+        Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """
 
-    def __init__(
-        self,
-        env_fns: List[Callable[[], gym.Env]],
-        wait_num: Optional[int] = None,
-        timeout: Optional[float] = None,
-    ) -> None:
-        super().__init__(env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
+    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
+        super().__init__(env_fns, DummyEnvWorker, **kwargs)
 
 
 class SubprocVectorEnv(BaseVectorEnv):
@@ -293,20 +317,14 @@ class SubprocVectorEnv(BaseVectorEnv):
 
     .. seealso::
 
-        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
-        explanation.
+        Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """
 
-    def __init__(
-        self,
-        env_fns: List[Callable[[], gym.Env]],
-        wait_num: Optional[int] = None,
-        timeout: Optional[float] = None,
-    ) -> None:
+    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
         def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
             return SubprocEnvWorker(fn, share_memory=False)
 
-        super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
+        super().__init__(env_fns, worker_fn, **kwargs)
 
 
 class ShmemVectorEnv(BaseVectorEnv):
@@ -316,20 +334,14 @@ class ShmemVectorEnv(BaseVectorEnv):
 
     .. seealso::
 
-        Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more
-        detailed explanation.
+        Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """
 
-    def __init__(
-        self,
-        env_fns: List[Callable[[], gym.Env]],
-        wait_num: Optional[int] = None,
-        timeout: Optional[float] = None,
-    ) -> None:
+    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
         def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
             return SubprocEnvWorker(fn, share_memory=True)
 
-        super().__init__(env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
+        super().__init__(env_fns, worker_fn, **kwargs)
 
 
 class RayVectorEnv(BaseVectorEnv):
@@ -339,16 +351,10 @@ class RayVectorEnv(BaseVectorEnv):
 
     .. seealso::
 
-        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
-        explanation.
+        Please refer to :class:`~tianshou.env.BaseVectorEnv` for other APIs' usage.
     """
 
-    def __init__(
-        self,
-        env_fns: List[Callable[[], gym.Env]],
-        wait_num: Optional[int] = None,
-        timeout: Optional[float] = None,
-    ) -> None:
+    def __init__(self, env_fns: List[Callable[[], gym.Env]], **kwargs: Any) -> None:
        try:
            import ray
        except ImportError as e:
@@ -357,4 +363,4 @@ class RayVectorEnv(BaseVectorEnv):
             ) from e
         if not ray.is_initialized():
             ray.init()
-        super().__init__(env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
+        super().__init__(env_fns, RayEnvWorker, **kwargs)
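With the subclasses now forwarding **kwargs to BaseVectorEnv, the three new keyword arguments can be passed straight to DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv, or RayVectorEnv. A minimal usage sketch, following the docstring's advice to hand a frozen obs_rms to evaluation envs (the "CartPole-v0" id is only an illustrative choice of environment):

import gym

from tianshou.env import DummyVectorEnv

# training envs: accumulate running statistics and return normalized observations
train_envs = DummyVectorEnv(
    [lambda: gym.make("CartPole-v0") for _ in range(4)],
    norm_obs=True,
)

# evaluation envs: reuse the training statistics but never update them
test_envs = DummyVectorEnv(
    [lambda: gym.make("CartPole-v0") for _ in range(2)],
    norm_obs=True,
    obs_rms=train_envs.obs_rms,
    update_obs_rms=False,
)

obs = train_envs.reset()  # normalized with obs_rms and clipped to [-10, 10]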
tianshou/utils/__init__.py

@@ -1,9 +1,10 @@
 from tianshou.utils.config import tqdm_config
-from tianshou.utils.moving_average import MovAvg
+from tianshou.utils.statistics import MovAvg, RunningMeanStd
 from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger
 
 __all__ = [
     "MovAvg",
+    "RunningMeanStd",
     "tqdm_config",
     "BaseLogger",
     "BasicLogger",
tianshou/utils/moving_average.py → tianshou/utils/statistics.py

@@ -66,3 +66,31 @@ class MovAvg(object):
         if len(self.cache) == 0:
             return 0
         return np.std(self.cache)
+
+
+class RunningMeanStd(object):
+    """Calculates the running mean and std of a data stream.
+
+    https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+    """
+
+    def __init__(self) -> None:
+        self.mean, self.var = 0.0, 1.0
+        self.count = 0
+
+    def update(self, x: np.ndarray) -> None:
+        """Add a batch of items into RMS with the same shape, modify mean/var/count."""
+        batch_mean, batch_var = np.mean(x, axis=0), np.var(x, axis=0)
+        batch_count = len(x)
+
+        delta = batch_mean - self.mean
+        total_count = self.count + batch_count
+
+        new_mean = self.mean + delta * batch_count / total_count
+        m_a = self.var * self.count
+        m_b = batch_var * batch_count
+        m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
+        new_var = m_2 / total_count
+
+        self.mean, self.var = new_mean, new_var
+        self.count = total_count
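RunningMeanStd.update implements the parallel-combine rule from the Wikipedia article linked in its docstring. Writing the stored statistics as (mu_a, sigma_a^2, n_a) and the batch statistics as (mu_b, sigma_b^2, n_b), the code computes, in LaTeX notation:

\delta = \mu_b - \mu_a, \qquad n = n_a + n_b

\mu_{\text{new}} = \mu_a + \delta \cdot \frac{n_b}{n}

M_2 = n_a \sigma_a^2 + n_b \sigma_b^2 + \delta^2 \cdot \frac{n_a n_b}{n}, \qquad \sigma^2_{\text{new}} = \frac{M_2}{n}

Since the initial count is zero, the first call reduces to the plain batch mean and variance, which is why the pooled NumPy check against test_rms above matches exactly.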