Fix 2 bugs and refactor RunningMeanStd to support dict obs norm (#695)

* fix #689

* fix #672

* refactor RMS class

* fix #688
Jiayi Weng 2022-07-14 22:52:56 -07:00 committed by GitHub
parent 65054847ef
commit 99c99bb09a
9 changed files with 59 additions and 20 deletions


@@ -1,9 +1,13 @@
 Cheat Sheet
 ===========

-This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios.
+This page shows some code snippets of how to use Tianshou to develop new
+algorithms / apply algorithms to new scenarios.

-By the way, some of these issues can be resolved by using a ``gym.Wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`.
+By the way, some of these issues can be resolved by using a ``gym.Wrapper``.
+It could be a universal solution in the policy-environment interaction. But
+you can also use the batch processor :ref:`preprocess_fn` or vectorized
+environment wrapper :class:`~tianshou.env.VectorEnvWrapper`.

 .. _network_api:
@@ -22,6 +26,18 @@ Build New Policy
 See :class:`~tianshou.policy.BasePolicy`.

+.. _eval_policy:
+
+Manually Evaluate Policy
+------------------------
+
+If you'd like to manually see the action generated by a well-trained agent:
+::
+
+    # assume obs is a single environment observation
+    action = policy(Batch(obs=np.array([obs]))).act[0]
+
 .. _customize_training:

 Customize Training Process
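To make the snippet just added runnable end to end, here is a minimal sketch of a full evaluation loop. It assumes a trained tianshou ``policy`` object and the pre-gymnasium ``reset``/``step`` API that was current for this commit; the environment name is illustrative, not part of the change::

    import gym
    import numpy as np

    from tianshou.data import Batch
    from tianshou.policy import BasePolicy


    def watch(policy: BasePolicy, env_id: str = "CartPole-v1") -> float:
        """Roll out one rendered episode, querying the policy one obs at a time."""
        env = gym.make(env_id)
        obs, done, total_rew = env.reset(), False, 0.0
        while not done:
            # batch the single observation, query the policy, unbatch the action
            act = policy(Batch(obs=np.array([obs]))).act[0]
            obs, rew, done, info = env.step(act)
            total_rew += rew
            env.render()
        return total_rew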


@@ -256,6 +256,12 @@ Watch the Agent's Performance
     collector = ts.data.Collector(policy, env, exploration_noise=True)
     collector.collect(n_episode=1, render=1 / 35)

+If you'd like to manually see the action generated by a well-trained agent:
+::
+
+    # assume obs is a single environment observation
+    action = policy(Batch(obs=np.array([obs]))).act[0]
+
 .. _customized_trainer:
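For comparison, the two context lines above show the Collector-based way to watch an agent; a slightly fuller hedged sketch of the same idea, with the policy passed in as an argument and the environment name again illustrative::

    import gym

    import tianshou as ts
    from tianshou.policy import BasePolicy


    def watch_with_collector(policy: BasePolicy, env_id: str = "CartPole-v1") -> None:
        """Render one episode using a Collector instead of a manual loop."""
        env = gym.make(env_id)
        policy.eval()  # switch the policy to evaluation mode
        collector = ts.data.Collector(policy, env, exploration_noise=True)
        # render=1/35 sleeps roughly 28.6 ms between frames, i.e. ~35 FPS
        collector.collect(n_episode=1, render=1 / 35)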


@@ -211,6 +211,14 @@ def test_vecenv(size=10, num=8, sleep=0.001):
     v.close()


+def test_attr_unwrapped():
+    train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
+    train_envs.set_env_attr("test_attribute", 1337)
+    assert train_envs.get_env_attr("test_attribute") == [1337]
+    assert hasattr(train_envs.workers[0].env, "test_attribute")
+    assert hasattr(train_envs.workers[0].env.unwrapped, "test_attribute")
+
+
 def test_env_obs_dtype():
     for obs_type in ["array", "object"]:
         envs = SubprocVectorEnv(
@@ -349,6 +357,7 @@ if __name__ == '__main__':
     test_venv_wrapper_envpool()
     test_env_obs_dtype()
     test_vecenv()
+    test_attr_unwrapped()
     test_async_env()
     test_async_check_id()
     test_env_reset_optional_kwargs()
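The new test pins down the bug behind the ``env.unwrapped`` changes in this commit: ``gym.make`` returns the environment wrapped (typically in ``TimeLimit``), and ``gym.Wrapper.__getattr__`` forwards attribute reads to the inner env, but a plain ``setattr`` writes onto the wrapper itself. A minimal stand-alone illustration of the asymmetry, independent of tianshou::

    import gym

    # gym.make returns a wrapped env, e.g. TimeLimit(CartPoleEnv)
    env = gym.make("CartPole-v1")

    # writing on the wrapper leaves the base env untouched...
    setattr(env, "test_attribute", 1337)
    print(hasattr(env.unwrapped, "test_attribute"))  # False: only the wrapper has it

    # ...while writing on env.unwrapped is visible from both sides, because
    # gym.Wrapper.__getattr__ forwards missing-attribute lookups inward
    env = gym.make("CartPole-v1")
    setattr(env.unwrapped, "test_attribute", 1337)
    print(hasattr(env, "test_attribute"))            # True (forwarded)
    print(hasattr(env.unwrapped, "test_attribute"))  # True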


@@ -68,24 +68,17 @@ class VectorEnvNormObs(VectorEnvWrapper):
     """An observation normalization wrapper for vectorized environments.

     :param bool update_obs_rms: whether to update obs_rms. Default to True.
-    :param float clip_obs: the maximum absolute value for observation. Default to
-        10.0.
-    :param float epsilon: To avoid division by zero.
     """

     def __init__(
         self,
         venv: BaseVectorEnv,
         update_obs_rms: bool = True,
-        clip_obs: float = 10.0,
-        epsilon: float = np.finfo(np.float32).eps.item(),
     ) -> None:
         super().__init__(venv)
         # initialize observation running mean/std
         self.update_obs_rms = update_obs_rms
         self.obs_rms = RunningMeanStd()
-        self.clip_max = clip_obs
-        self.eps = epsilon

     def reset(
         self,
@@ -127,8 +120,7 @@ class VectorEnvNormObs(VectorEnvWrapper):

     def _norm_obs(self, obs: np.ndarray) -> np.ndarray:
         if self.obs_rms:
-            obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.eps)
-            obs = np.clip(obs, -self.clip_max, self.clip_max)
+            return self.obs_rms.norm(obs)  # type: ignore
         return obs

     def set_obs_rms(self, obs_rms: RunningMeanStd) -> None:
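With clipping and epsilon now owned by ``RunningMeanStd``, the typical use of this wrapper is to share one statistics object between training and evaluation environments. A hedged sketch; the environment name is illustrative and ``get_obs_rms`` is assumed as the counterpart of the ``set_obs_rms`` shown above::

    import gym

    from tianshou.env import DummyVectorEnv, VectorEnvNormObs

    # training envs update the running statistics on every step
    train_envs = VectorEnvNormObs(
        DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(4)])
    )
    # evaluation envs only consume the statistics, never update them
    test_envs = VectorEnvNormObs(
        DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(2)]),
        update_obs_rms=False,
    )
    test_envs.set_obs_rms(train_envs.get_obs_rms())

    obs = train_envs.reset()  # already normalized and clipped via obs_rms.norm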


@@ -17,7 +17,7 @@ class DummyEnvWorker(EnvWorker):
         return getattr(self.env, key)

     def set_env_attr(self, key: str, value: Any) -> None:
-        setattr(self.env, key, value)
+        setattr(self.env.unwrapped, key, value)

     def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
         if "seed" in kwargs:


@@ -14,7 +14,7 @@ except ImportError:

 class _SetAttrWrapper(gym.Wrapper):

     def set_env_attr(self, key: str, value: Any) -> None:
-        setattr(self.env, key, value)
+        setattr(self.env.unwrapped, key, value)

     def get_env_attr(self, key: str) -> Any:
         return getattr(self.env, key)


@@ -49,9 +49,9 @@ def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
     if isinstance(space, gym.spaces.Dict):
         assert isinstance(space.spaces, OrderedDict)
         return {k: _setup_buf(v) for k, v in space.spaces.items()}
-    elif isinstance(space, gym.spaces.Tuple):
-        assert isinstance(space.spaces, tuple)
-        return tuple([_setup_buf(t) for t in space.spaces])
+    elif isinstance(space, gym.spaces.Tuple):  # type: ignore
+        assert isinstance(space.spaces, tuple)  # type: ignore
+        return tuple([_setup_buf(t) for t in space.spaces])  # type: ignore
     else:
         return ShArray(space.dtype, space.shape)  # type: ignore
@@ -122,7 +122,7 @@ def _worker(
     elif cmd == "getattr":
         p.send(getattr(env, data) if hasattr(env, data) else None)
     elif cmd == "setattr":
-        setattr(env, data["key"], data["value"])
+        setattr(env.unwrapped, data["key"], data["value"])
     else:
         p.close()
         raise NotImplementedError
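``_setup_buf`` mirrors a possibly nested observation space with one shared-memory block per leaf, which is what lets subprocess workers write ``Dict``/``Tuple`` observations in place. A simplified stand-alone sketch of the same recursion, using ``multiprocessing.Array`` directly where tianshou uses its ``ShArray`` wrapper::

    import ctypes
    import multiprocessing
    from typing import Any, Union

    import gym
    import numpy as np


    def setup_buf(space: gym.Space) -> Union[dict, tuple, Any]:
        """Recursively mirror a nested space with shared-memory buffers."""
        if isinstance(space, gym.spaces.Dict):
            return {k: setup_buf(v) for k, v in space.spaces.items()}
        if isinstance(space, gym.spaces.Tuple):
            return tuple(setup_buf(t) for t in space.spaces)
        # leaf space: one flat shared array large enough for the observation
        return multiprocessing.Array(ctypes.c_double, int(np.prod(space.shape)))


    space = gym.spaces.Dict(
        image=gym.spaces.Box(low=0.0, high=1.0, shape=(4, 4)),
        state=gym.spaces.Tuple(
            (gym.spaces.Box(low=-1.0, high=1.0, shape=(2,)),) * 2
        ),
    )
    buf = setup_buf(space)  # {"image": Array, "state": (Array, Array)}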


@@ -1,5 +1,5 @@
 from numbers import Number
-from typing import List, Union
+from typing import List, Optional, Union

 import numpy as np
 import torch
@@ -70,15 +70,31 @@ class RunningMeanStd(object):
     """Calculates the running mean and std of a data stream.

     https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm

     :param mean: the initial mean estimation for data array. Default to 0.
     :param std: the initial standard error estimation for data array. Default to 1.
+    :param float clip_max: the maximum absolute value for data array. Default to
+        10.0.
+    :param float epsilon: To avoid division by zero.
     """

     def __init__(
         self,
         mean: Union[float, np.ndarray] = 0.0,
-        std: Union[float, np.ndarray] = 1.0
+        std: Union[float, np.ndarray] = 1.0,
+        clip_max: Optional[float] = 10.0,
+        epsilon: float = np.finfo(np.float32).eps.item(),
     ) -> None:
         self.mean, self.var = mean, std
+        self.clip_max = clip_max
         self.count = 0
+        self.eps = epsilon

+    def norm(self, data_array: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
+        data_array = (data_array - self.mean) / np.sqrt(self.var + self.eps)
+        if self.clip_max:
+            data_array = np.clip(data_array, -self.clip_max, self.clip_max)
+        return data_array
+
     def update(self, data_array: np.ndarray) -> None:
         """Add a batch of item into RMS with the same shape, modify mean/var/count."""