Fix 2 bugs and refactor RunningMeanStd to support dict obs norm (#695)

* fix #689

* fix #672

* refactor RMS class

* fix #688
Jiayi Weng 2022-07-14 22:52:56 -07:00 committed by GitHub
parent 65054847ef
commit 99c99bb09a
9 changed files with 59 additions and 20 deletions

View File

@@ -1,9 +1,13 @@
 Cheat Sheet
 ===========
 
-This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios.
+This page shows some code snippets of how to use Tianshou to develop new
+algorithms / apply algorithms to new scenarios.
 
-By the way, some of these issues can be resolved by using a ``gym.Wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`.
+By the way, some of these issues can be resolved by using a ``gym.Wrapper``.
+It could be a universal solution in the policy-environment interaction. But
+you can also use the batch processor :ref:`preprocess_fn` or vectorized
+environment wrapper :class:`~tianshou.env.VectorEnvWrapper`.
 
 .. _network_api:
@@ -22,6 +26,18 @@ Build New Policy
 
 See :class:`~tianshou.policy.BasePolicy`.
 
+.. _eval_policy:
+
+Manually Evaluate Policy
+------------------------
+
+If you'd like to manually see the action generated by a well-trained agent:
+::
+
+    # assume obs is a single environment observation
+    action = policy(Batch(obs=np.array([obs]))).act[0]
+
 .. _customize_training:
 
 Customize Training Process
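For readers skimming the diff, the snippet added above is not self-contained; here is a minimal runnable sketch, assuming a trained `policy` object and an illustrative environment name:
::

    import gym
    import numpy as np
    from tianshou.data import Batch

    env = gym.make("CartPole-v1")
    obs = env.reset()  # a single environment observation
    # add a batch dimension before calling the policy, then take the first action
    action = policy(Batch(obs=np.array([obs]))).act[0]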

View File

@@ -256,6 +256,12 @@ Watch the Agent's Performance
 
     collector = ts.data.Collector(policy, env, exploration_noise=True)
     collector.collect(n_episode=1, render=1 / 35)
 
+If you'd like to manually see the action generated by a well-trained agent:
+::
+
+    # assume obs is a single environment observation
+    action = policy(Batch(obs=np.array([obs]))).act[0]
+
 .. _customized_trainer:
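A side note on the snippet above (my assumption, not part of the diff): Tianshou policies are `torch.nn.Module` subclasses, so for deterministic evaluation you would typically switch the policy to eval mode first:
::

    policy.eval()  # turn off training-time behavior before watching the agent
    action = policy(Batch(obs=np.array([obs]))).act[0]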

View File

@@ -211,6 +211,14 @@ def test_vecenv(size=10, num=8, sleep=0.001):
     v.close()
 
 
+def test_attr_unwrapped():
+    train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
+    train_envs.set_env_attr("test_attribute", 1337)
+    assert train_envs.get_env_attr("test_attribute") == [1337]
+    assert hasattr(train_envs.workers[0].env, "test_attribute")
+    assert hasattr(train_envs.workers[0].env.unwrapped, "test_attribute")
+
+
 def test_env_obs_dtype():
     for obs_type in ["array", "object"]:
         envs = SubprocVectorEnv(
@@ -349,6 +357,7 @@ if __name__ == '__main__':
     test_venv_wrapper_envpool()
     test_env_obs_dtype()
     test_vecenv()
+    test_attr_unwrapped()
     test_async_env()
     test_async_check_id()
     test_env_reset_optional_kwargs()
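The new test pins down the behavior fixed in the worker files below: `gym.make` returns a wrapped environment (e.g. `TimeLimit`), and `setattr` on a wrapper stores the attribute on the wrapper object only. A minimal sketch of the underlying gym behavior:
::

    import gym

    env = gym.make("CartPole-v1")             # wrapped, e.g. in TimeLimit
    setattr(env, "test_attribute", 1337)
    hasattr(env.unwrapped, "test_attribute")  # False: lives on the wrapper only
    setattr(env.unwrapped, "test_attribute", 1337)
    hasattr(env.unwrapped, "test_attribute")  # True
    hasattr(env, "test_attribute")            # True: wrappers forward getattr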

View File

@@ -145,7 +145,7 @@ class Collector(object):
                 )
                 obs = processed_data.get("obs", obs)
                 info = processed_data.get("info", info)
-                self.data.info = info
+            self.data.info = info
         else:
             obs = rval
             if self.preprocess_fn:
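Assuming the indentation reconstructed above is the actual change, this one-line dedent is the fix for #689: `self.data.info` was previously assigned only inside the `if self.preprocess_fn:` branch, so a collector without a preprocess function never stored the `info` returned by `env.reset()`. Simplified control flow (illustrative, not the verbatim source):
::

    obs, info = env.reset()
    if preprocess_fn:
        processed = preprocess_fn(obs=obs, info=info, env_id=env_ids)
        obs = processed.get("obs", obs)
        info = processed.get("info", info)
    data.info = info  # must run whether or not a preprocess_fn is set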

View File

@@ -68,24 +68,17 @@ class VectorEnvNormObs(VectorEnvWrapper):
     """An observation normalization wrapper for vectorized environments.
 
     :param bool update_obs_rms: whether to update obs_rms. Default to True.
-    :param float clip_obs: the maximum absolute value for observation. Default to
-        10.0.
-    :param float epsilon: To avoid division by zero.
     """
 
     def __init__(
         self,
         venv: BaseVectorEnv,
         update_obs_rms: bool = True,
-        clip_obs: float = 10.0,
-        epsilon: float = np.finfo(np.float32).eps.item(),
     ) -> None:
         super().__init__(venv)
         # initialize observation running mean/std
         self.update_obs_rms = update_obs_rms
         self.obs_rms = RunningMeanStd()
-        self.clip_max = clip_obs
-        self.eps = epsilon
 
     def reset(
         self,
@@ -127,8 +120,7 @@ class VectorEnvNormObs(VectorEnvWrapper):
 
     def _norm_obs(self, obs: np.ndarray) -> np.ndarray:
         if self.obs_rms:
-            obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.eps)
-            obs = np.clip(obs, -self.clip_max, self.clip_max)
+            return self.obs_rms.norm(obs)  # type: ignore
         return obs
 
     def set_obs_rms(self, obs_rms: RunningMeanStd) -> None:
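After this change the wrapper keeps no clipping or epsilon state of its own; both now live inside `RunningMeanStd` (see the last file in this diff). A usage sketch under the post-commit API, with an illustrative environment:
::

    import gym
    from tianshou.env import DummyVectorEnv, VectorEnvNormObs

    venv = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
    venv = VectorEnvNormObs(venv)  # clip/eps defaults come from RunningMeanStd
    obs = venv.reset()             # observations go through obs_rms.norm(...)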

View File

@@ -17,7 +17,7 @@ class DummyEnvWorker(EnvWorker):
         return getattr(self.env, key)
 
     def set_env_attr(self, key: str, value: Any) -> None:
-        setattr(self.env, key, value)
+        setattr(self.env.unwrapped, key, value)
 
     def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
         if "seed" in kwargs:

View File

@@ -14,7 +14,7 @@ except ImportError:
 
 class _SetAttrWrapper(gym.Wrapper):
 
     def set_env_attr(self, key: str, value: Any) -> None:
-        setattr(self.env, key, value)
+        setattr(self.env.unwrapped, key, value)
 
     def get_env_attr(self, key: str) -> Any:
         return getattr(self.env, key)

View File

@@ -49,9 +49,9 @@ def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
     if isinstance(space, gym.spaces.Dict):
         assert isinstance(space.spaces, OrderedDict)
         return {k: _setup_buf(v) for k, v in space.spaces.items()}
-    elif isinstance(space, gym.spaces.Tuple):
-        assert isinstance(space.spaces, tuple)
-        return tuple([_setup_buf(t) for t in space.spaces])
+    elif isinstance(space, gym.spaces.Tuple):  # type: ignore
+        assert isinstance(space.spaces, tuple)  # type: ignore
+        return tuple([_setup_buf(t) for t in space.spaces])  # type: ignore
     else:
         return ShArray(space.dtype, space.shape)  # type: ignore
@@ -122,7 +122,7 @@ def _worker(
     elif cmd == "getattr":
         p.send(getattr(env, data) if hasattr(env, data) else None)
     elif cmd == "setattr":
-        setattr(env, data["key"], data["value"])
+        setattr(env.unwrapped, data["key"], data["value"])
     else:
         p.close()
         raise NotImplementedError

View File

@@ -1,5 +1,5 @@
 from numbers import Number
-from typing import List, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import torch
@@ -70,15 +70,31 @@ class RunningMeanStd(object):
     """Calculates the running mean and std of a data stream.
 
     https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+
+    :param mean: the initial mean estimation for data array. Default to 0.
+    :param std: the initial standard error estimation for data array. Default to 1.
+    :param float clip_max: the maximum absolute value for data array. Default to
+        10.0.
+    :param float epsilon: To avoid division by zero.
     """
 
     def __init__(
         self,
         mean: Union[float, np.ndarray] = 0.0,
-        std: Union[float, np.ndarray] = 1.0
+        std: Union[float, np.ndarray] = 1.0,
+        clip_max: Optional[float] = 10.0,
+        epsilon: float = np.finfo(np.float32).eps.item(),
     ) -> None:
         self.mean, self.var = mean, std
+        self.clip_max = clip_max
         self.count = 0
+        self.eps = epsilon
+
+    def norm(self, data_array: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
+        data_array = (data_array - self.mean) / np.sqrt(self.var + self.eps)
+        if self.clip_max:
+            data_array = np.clip(data_array, -self.clip_max, self.clip_max)
+        return data_array
 
     def update(self, data_array: np.ndarray) -> None:
         """Add a batch of item into RMS with the same shape, modify mean/var/count."""