Fix 2 bugs and refactor RunningMeanStd to support dict obs norm (#695)

* fix #689
* fix #672
* refactor RMS class
* fix #688

parent 65054847ef
commit 99c99bb09a
docs/cheatsheet.rst
@@ -1,9 +1,13 @@
 Cheat Sheet
 ===========
 
-This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios.
+This page shows some code snippets of how to use Tianshou to develop new
+algorithms / apply algorithms to new scenarios.
 
-By the way, some of these issues can be resolved by using a ``gym.Wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`.
+By the way, some of these issues can be resolved by using a ``gym.Wrapper``.
+It could be a universal solution in the policy-environment interaction. But
+you can also use the batch processor :ref:`preprocess_fn` or vectorized
+environment wrapper :class:`~tianshou.env.VectorEnvWrapper`.
 
 
 .. _network_api:
@@ -22,6 +26,18 @@ Build New Policy
 
 See :class:`~tianshou.policy.BasePolicy`.
 
+.. _eval_policy:
+
+Manually Evaluate Policy
+------------------------
+
+If you'd like to manually see the action generated by a well-trained agent:
+::
+
+    # assume obs is a single environment observation
+    action = policy(Batch(obs=np.array([obs]))).act[0]
+
+
 .. _customize_training:
 
 Customize Training Process
docs/tutorials/dqn.rst
@@ -256,6 +256,12 @@ Watch the Agent's Performance
     collector = ts.data.Collector(policy, env, exploration_noise=True)
     collector.collect(n_episode=1, render=1 / 35)
 
+If you'd like to manually see the action generated by a well-trained agent:
+::
+
+    # assume obs is a single environment observation
+    action = policy(Batch(obs=np.array([obs]))).act[0]
+
 
 .. _customized_trainer:
 
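For context, here is the snippet above as an end-to-end sketch. The environment id, the pre-0.26 Gym reset/step API, and the already-trained ``policy`` are illustrative assumptions, not part of this commit:
::

    import gym
    import numpy as np

    from tianshou.data import Batch

    env = gym.make("CartPole-v1")
    # `policy` is assumed to be a trained tianshou policy, e.g. restored via
    # policy.load_state_dict(torch.load("dqn.pth"))
    obs = env.reset()
    done = False
    while not done:
        # wrap the single observation into a batch of size one
        action = policy(Batch(obs=np.array([obs]))).act[0]
        obs, rew, done, info = env.step(action)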
test/base/test_env.py
@@ -211,6 +211,14 @@ def test_vecenv(size=10, num=8, sleep=0.001):
     v.close()
 
 
+def test_attr_unwrapped():
+    train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
+    train_envs.set_env_attr("test_attribute", 1337)
+    assert train_envs.get_env_attr("test_attribute") == [1337]
+    assert hasattr(train_envs.workers[0].env, "test_attribute")
+    assert hasattr(train_envs.workers[0].env.unwrapped, "test_attribute")
+
+
 def test_env_obs_dtype():
     for obs_type in ["array", "object"]:
         envs = SubprocVectorEnv(
@@ -349,6 +357,7 @@ if __name__ == '__main__':
     test_venv_wrapper_envpool()
     test_env_obs_dtype()
     test_vecenv()
+    test_attr_unwrapped()
     test_async_env()
     test_async_check_id()
     test_env_reset_optional_kwargs()
tianshou/data/collector.py
@@ -145,7 +145,7 @@ class Collector(object):
                 )
                 obs = processed_data.get("obs", obs)
                 info = processed_data.get("info", info)
-                self.data.info = info
+            self.data.info = info
         else:
             obs = rval
             if self.preprocess_fn:
tianshou/env/venv_wrappers.py
@@ -68,24 +68,17 @@ class VectorEnvNormObs(VectorEnvWrapper):
     """An observation normalization wrapper for vectorized environments.
 
     :param bool update_obs_rms: whether to update obs_rms. Default to True.
-    :param float clip_obs: the maximum absolute value for observation. Default to
-        10.0.
-    :param float epsilon: To avoid division by zero.
     """
 
     def __init__(
         self,
         venv: BaseVectorEnv,
         update_obs_rms: bool = True,
-        clip_obs: float = 10.0,
-        epsilon: float = np.finfo(np.float32).eps.item(),
     ) -> None:
         super().__init__(venv)
         # initialize observation running mean/std
         self.update_obs_rms = update_obs_rms
         self.obs_rms = RunningMeanStd()
-        self.clip_max = clip_obs
-        self.eps = epsilon
 
     def reset(
         self,
@@ -127,8 +120,7 @@ class VectorEnvNormObs(VectorEnvWrapper):
 
     def _norm_obs(self, obs: np.ndarray) -> np.ndarray:
         if self.obs_rms:
-            obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.eps)
-            obs = np.clip(obs, -self.clip_max, self.clip_max)
+            return self.obs_rms.norm(obs)  # type: ignore
         return obs
 
     def set_obs_rms(self, obs_rms: RunningMeanStd) -> None:
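With clipping and the division epsilon moved into ``RunningMeanStd``, the wrapper now only tracks statistics and delegates normalization. A minimal usage sketch under this refactor; the env id and worker count are illustrative:
::

    import gym

    from tianshou.env import DummyVectorEnv, VectorEnvNormObs

    # training envs: running statistics are updated online
    train_envs = VectorEnvNormObs(
        DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(4)])
    )
    obs = train_envs.reset()  # observations come back normalized and clipped

    # test envs: freeze the statistics and reuse the training ones
    test_envs = VectorEnvNormObs(
        DummyVectorEnv([lambda: gym.make("Pendulum-v1")]), update_obs_rms=False
    )
    test_envs.set_obs_rms(train_envs.get_obs_rms())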
tianshou/env/worker/dummy.py
@@ -17,7 +17,7 @@ class DummyEnvWorker(EnvWorker):
         return getattr(self.env, key)
 
     def set_env_attr(self, key: str, value: Any) -> None:
-        setattr(self.env, key, value)
+        setattr(self.env.unwrapped, key, value)
 
     def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
         if "seed" in kwargs:
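This one-line change is what the new ``test_attr_unwrapped`` above exercises: ``gym.make`` returns the env wrapped in e.g. ``TimeLimit``, and a plain ``setattr`` on the wrapper stores the attribute on the wrapper object itself, so it never reaches the base environment. A standalone sketch of both behaviors; the attribute names are arbitrary:
::

    import gym

    env = gym.make("CartPole-v1")  # actually a wrapper chain around CartPoleEnv
    setattr(env, "foo", 1)         # old behavior: lands on the outermost wrapper
    assert hasattr(env, "foo")
    assert not hasattr(env.unwrapped, "foo")  # the base env never sees it

    setattr(env.unwrapped, "bar", 2)  # fixed behavior: set on the base env
    assert hasattr(env.unwrapped, "bar")
    assert hasattr(env, "bar")  # reads still work via gym.Wrapper.__getattr__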
tianshou/env/worker/ray.py
@@ -14,7 +14,7 @@ except ImportError:
     class _SetAttrWrapper(gym.Wrapper):
 
         def set_env_attr(self, key: str, value: Any) -> None:
-            setattr(self.env, key, value)
+            setattr(self.env.unwrapped, key, value)
 
         def get_env_attr(self, key: str) -> Any:
             return getattr(self.env, key)
tianshou/env/worker/subproc.py
@@ -49,9 +49,9 @@ def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
     if isinstance(space, gym.spaces.Dict):
         assert isinstance(space.spaces, OrderedDict)
         return {k: _setup_buf(v) for k, v in space.spaces.items()}
-    elif isinstance(space, gym.spaces.Tuple):
-        assert isinstance(space.spaces, tuple)
-        return tuple([_setup_buf(t) for t in space.spaces])
+    elif isinstance(space, gym.spaces.Tuple):  # type: ignore
+        assert isinstance(space.spaces, tuple)  # type: ignore
+        return tuple([_setup_buf(t) for t in space.spaces])  # type: ignore
     else:
         return ShArray(space.dtype, space.shape)  # type: ignore
 
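The recursion in ``_setup_buf`` mirrors the space structure: ``Dict`` spaces become dicts of buffers, ``Tuple`` spaces become tuples, and leaf spaces become shared-memory ``ShArray`` buffers. A quick illustration of the resulting layout, with plain numpy arrays standing in for ``ShArray``:
::

    import gym
    import numpy as np

    def setup_buf_like(space):
        # same shape recursion as _setup_buf, minus the shared memory
        if isinstance(space, gym.spaces.Dict):
            return {k: setup_buf_like(v) for k, v in space.spaces.items()}
        elif isinstance(space, gym.spaces.Tuple):
            return tuple(setup_buf_like(t) for t in space.spaces)
        return np.zeros(space.shape, dtype=space.dtype)

    space = gym.spaces.Dict(
        obs=gym.spaces.Box(low=-1.0, high=1.0, shape=(4,)),
        mask=gym.spaces.MultiBinary(3),
    )
    buf = setup_buf_like(space)
    # {'mask': int8 array of shape (3,), 'obs': float32 array of shape (4,)}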
@@ -122,7 +122,7 @@ def _worker(
             elif cmd == "getattr":
                 p.send(getattr(env, data) if hasattr(env, data) else None)
             elif cmd == "setattr":
-                setattr(env, data["key"], data["value"])
+                setattr(env.unwrapped, data["key"], data["value"])
             else:
                 p.close()
                 raise NotImplementedError
tianshou/utils/statistics.py
@@ -1,5 +1,5 @@
 from numbers import Number
-from typing import List, Union
+from typing import List, Optional, Union
 
 import numpy as np
 import torch
@@ -70,15 +70,31 @@ class RunningMeanStd(object):
     """Calculates the running mean and std of a data stream.
 
     https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
 
+    :param mean: the initial mean estimation for data array. Default to 0.
+    :param std: the initial standard error estimation for data array. Default to 1.
+    :param float clip_max: the maximum absolute value for data array. Default to
+        10.0.
+    :param float epsilon: To avoid division by zero.
     """
 
     def __init__(
         self,
         mean: Union[float, np.ndarray] = 0.0,
-        std: Union[float, np.ndarray] = 1.0
+        std: Union[float, np.ndarray] = 1.0,
+        clip_max: Optional[float] = 10.0,
+        epsilon: float = np.finfo(np.float32).eps.item(),
     ) -> None:
         self.mean, self.var = mean, std
+        self.clip_max = clip_max
         self.count = 0
+        self.eps = epsilon
+
+    def norm(self, data_array: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
+        data_array = (data_array - self.mean) / np.sqrt(self.var + self.eps)
+        if self.clip_max:
+            data_array = np.clip(data_array, -self.clip_max, self.clip_max)
+        return data_array
 
     def update(self, data_array: np.ndarray) -> None:
         """Add a batch of item into RMS with the same shape, modify mean/var/count."""