Fix 2 bugs and refactor RunningMeanStd to support dict obs norm (#695)
* fix #689 * fix #672 * refactor RMS class * fix #688
This commit is contained in:
parent
65054847ef
commit
99c99bb09a
@ -1,9 +1,13 @@
|
||||
Cheat Sheet
|
||||
===========
|
||||
|
||||
This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios.
|
||||
This page shows some code snippets of how to use Tianshou to develop new
|
||||
algorithms / apply algorithms to new scenarios.
|
||||
|
||||
By the way, some of these issues can be resolved by using a ``gym.Wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`.
|
||||
By the way, some of these issues can be resolved by using a ``gym.Wrapper``.
|
||||
It could be a universal solution in the policy-environment interaction. But
|
||||
you can also use the batch processor :ref:`preprocess_fn` or vectorized
|
||||
environment wrapper :class:`~tianshou.env.VectorEnvWrapper`.
|
||||
|
||||
|
||||
.. _network_api:
|
||||
@ -22,6 +26,18 @@ Build New Policy
|
||||
See :class:`~tianshou.policy.BasePolicy`.
|
||||
|
||||
|
||||
.. _eval_policy:
|
||||
|
||||
Manually Evaluate Policy
|
||||
------------------------
|
||||
|
||||
If you'd like to manually see the action generated by a well-trained agent:
|
||||
::
|
||||
|
||||
# assume obs is a single environment observation
|
||||
action = policy(Batch(obs=np.array([obs]))).act[0]
|
||||
|
||||
|
||||
.. _customize_training:
|
||||
|
||||
Customize Training Process
|
||||
|
@ -256,6 +256,12 @@ Watch the Agent's Performance
|
||||
collector = ts.data.Collector(policy, env, exploration_noise=True)
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
|
||||
If you'd like to manually see the action generated by a well-trained agent:
|
||||
::
|
||||
|
||||
# assume obs is a single environment observation
|
||||
action = policy(Batch(obs=np.array([obs]))).act[0]
|
||||
|
||||
|
||||
.. _customized_trainer:
|
||||
|
||||
|
@ -211,6 +211,14 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
||||
v.close()
|
||||
|
||||
|
||||
def test_attr_unwrapped():
|
||||
train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
|
||||
train_envs.set_env_attr("test_attribute", 1337)
|
||||
assert train_envs.get_env_attr("test_attribute") == [1337]
|
||||
assert hasattr(train_envs.workers[0].env, "test_attribute")
|
||||
assert hasattr(train_envs.workers[0].env.unwrapped, "test_attribute")
|
||||
|
||||
|
||||
def test_env_obs_dtype():
|
||||
for obs_type in ["array", "object"]:
|
||||
envs = SubprocVectorEnv(
|
||||
@ -349,6 +357,7 @@ if __name__ == '__main__':
|
||||
test_venv_wrapper_envpool()
|
||||
test_env_obs_dtype()
|
||||
test_vecenv()
|
||||
test_attr_unwrapped()
|
||||
test_async_env()
|
||||
test_async_check_id()
|
||||
test_env_reset_optional_kwargs()
|
||||
|
@ -145,7 +145,7 @@ class Collector(object):
|
||||
)
|
||||
obs = processed_data.get("obs", obs)
|
||||
info = processed_data.get("info", info)
|
||||
self.data.info = info
|
||||
self.data.info = info
|
||||
else:
|
||||
obs = rval
|
||||
if self.preprocess_fn:
|
||||
|
10
tianshou/env/venv_wrappers.py
vendored
10
tianshou/env/venv_wrappers.py
vendored
@ -68,24 +68,17 @@ class VectorEnvNormObs(VectorEnvWrapper):
|
||||
"""An observation normalization wrapper for vectorized environments.
|
||||
|
||||
:param bool update_obs_rms: whether to update obs_rms. Default to True.
|
||||
:param float clip_obs: the maximum absolute value for observation. Default to
|
||||
10.0.
|
||||
:param float epsilon: To avoid division by zero.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
venv: BaseVectorEnv,
|
||||
update_obs_rms: bool = True,
|
||||
clip_obs: float = 10.0,
|
||||
epsilon: float = np.finfo(np.float32).eps.item(),
|
||||
) -> None:
|
||||
super().__init__(venv)
|
||||
# initialize observation running mean/std
|
||||
self.update_obs_rms = update_obs_rms
|
||||
self.obs_rms = RunningMeanStd()
|
||||
self.clip_max = clip_obs
|
||||
self.eps = epsilon
|
||||
|
||||
def reset(
|
||||
self,
|
||||
@ -127,8 +120,7 @@ class VectorEnvNormObs(VectorEnvWrapper):
|
||||
|
||||
def _norm_obs(self, obs: np.ndarray) -> np.ndarray:
|
||||
if self.obs_rms:
|
||||
obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.eps)
|
||||
obs = np.clip(obs, -self.clip_max, self.clip_max)
|
||||
return self.obs_rms.norm(obs) # type: ignore
|
||||
return obs
|
||||
|
||||
def set_obs_rms(self, obs_rms: RunningMeanStd) -> None:
|
||||
|
2
tianshou/env/worker/dummy.py
vendored
2
tianshou/env/worker/dummy.py
vendored
@ -17,7 +17,7 @@ class DummyEnvWorker(EnvWorker):
|
||||
return getattr(self.env, key)
|
||||
|
||||
def set_env_attr(self, key: str, value: Any) -> None:
|
||||
setattr(self.env, key, value)
|
||||
setattr(self.env.unwrapped, key, value)
|
||||
|
||||
def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
|
||||
if "seed" in kwargs:
|
||||
|
2
tianshou/env/worker/ray.py
vendored
2
tianshou/env/worker/ray.py
vendored
@ -14,7 +14,7 @@ except ImportError:
|
||||
class _SetAttrWrapper(gym.Wrapper):
|
||||
|
||||
def set_env_attr(self, key: str, value: Any) -> None:
|
||||
setattr(self.env, key, value)
|
||||
setattr(self.env.unwrapped, key, value)
|
||||
|
||||
def get_env_attr(self, key: str) -> Any:
|
||||
return getattr(self.env, key)
|
||||
|
8
tianshou/env/worker/subproc.py
vendored
8
tianshou/env/worker/subproc.py
vendored
@ -49,9 +49,9 @@ def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
|
||||
if isinstance(space, gym.spaces.Dict):
|
||||
assert isinstance(space.spaces, OrderedDict)
|
||||
return {k: _setup_buf(v) for k, v in space.spaces.items()}
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
assert isinstance(space.spaces, tuple)
|
||||
return tuple([_setup_buf(t) for t in space.spaces])
|
||||
elif isinstance(space, gym.spaces.Tuple): # type: ignore
|
||||
assert isinstance(space.spaces, tuple) # type: ignore
|
||||
return tuple([_setup_buf(t) for t in space.spaces]) # type: ignore
|
||||
else:
|
||||
return ShArray(space.dtype, space.shape) # type: ignore
|
||||
|
||||
@ -122,7 +122,7 @@ def _worker(
|
||||
elif cmd == "getattr":
|
||||
p.send(getattr(env, data) if hasattr(env, data) else None)
|
||||
elif cmd == "setattr":
|
||||
setattr(env, data["key"], data["value"])
|
||||
setattr(env.unwrapped, data["key"], data["value"])
|
||||
else:
|
||||
p.close()
|
||||
raise NotImplementedError
|
||||
|
@ -1,5 +1,5 @@
|
||||
from numbers import Number
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -70,15 +70,31 @@ class RunningMeanStd(object):
|
||||
"""Calculates the running mean and std of a data stream.
|
||||
|
||||
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
||||
|
||||
:param mean: the initial mean estimation for data array. Default to 0.
|
||||
:param std: the initial standard error estimation for data array. Default to 1.
|
||||
:param float clip_max: the maximum absolute value for data array. Default to
|
||||
10.0.
|
||||
:param float epsilon: To avoid division by zero.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mean: Union[float, np.ndarray] = 0.0,
|
||||
std: Union[float, np.ndarray] = 1.0
|
||||
std: Union[float, np.ndarray] = 1.0,
|
||||
clip_max: Optional[float] = 10.0,
|
||||
epsilon: float = np.finfo(np.float32).eps.item(),
|
||||
) -> None:
|
||||
self.mean, self.var = mean, std
|
||||
self.clip_max = clip_max
|
||||
self.count = 0
|
||||
self.eps = epsilon
|
||||
|
||||
def norm(self, data_array: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
|
||||
data_array = (data_array - self.mean) / np.sqrt(self.var + self.eps)
|
||||
if self.clip_max:
|
||||
data_array = np.clip(data_array, -self.clip_max, self.clip_max)
|
||||
return data_array
|
||||
|
||||
def update(self, data_array: np.ndarray) -> None:
|
||||
"""Add a batch of item into RMS with the same shape, modify mean/var/count."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user