From c91def6cbcf572d3d2e645d7fec9f580bdb42e60 Mon Sep 17 00:00:00 2001 From: n+e <463003665@qq.com> Date: Sat, 12 Sep 2020 15:39:01 +0800 Subject: [PATCH] code format and update function signatures (#213) Cherry-pick from #200 - update the function signature - format code-style - move _compile into separate functions - fix a bug in to_torch and to_numpy (Batch) - remove None in action_range In short, the code-format only contains function-signature style and `'` -> `"`. (pick up from [black](https://github.com/psf/black)) --- examples/box2d/bipedal_hardcore_sac.py | 4 +- examples/box2d/mcc_sac.py | 8 +- examples/mujoco/ant_v2_ddpg.py | 5 +- examples/mujoco/ant_v2_sac.py | 4 +- examples/mujoco/ant_v2_td3.py | 10 +- examples/mujoco/halfcheetahBullet_v0_sac.py | 4 +- examples/mujoco/point_maze_td3.py | 10 +- test/continuous/test_ddpg.py | 5 +- test/continuous/test_sac_with_il.py | 4 +- test/continuous/test_td3.py | 9 +- tianshou/__init__.py | 18 +- tianshou/data/__init__.py | 21 +-- tianshou/data/batch.py | 195 +++++++++++--------- tianshou/data/buffer.py | 173 +++++++++-------- tianshou/data/collector.py | 101 +++++----- tianshou/data/utils/converter.py | 65 ++++--- tianshou/data/utils/segtree.py | 35 +++- tianshou/env/__init__.py | 14 +- tianshou/env/maenv.py | 11 +- tianshou/env/utils.py | 7 +- tianshou/env/venvs.py | 138 ++++++++------ tianshou/env/worker/__init__.py | 8 +- tianshou/env/worker/base.py | 26 +-- tianshou/env/worker/dummy.py | 21 ++- tianshou/env/worker/ray.py | 29 +-- tianshou/env/worker/subproc.py | 111 ++++++----- tianshou/exploration/__init__.py | 6 +- tianshou/exploration/random.py | 39 ++-- tianshou/policy/__init__.py | 22 +-- tianshou/policy/base.py | 90 +++++---- tianshou/policy/imitation/base.py | 38 ++-- tianshou/policy/modelfree/a2c.py | 72 ++++---- tianshou/policy/modelfree/ddpg.py | 92 ++++----- tianshou/policy/modelfree/dqn.py | 74 ++++---- tianshou/policy/modelfree/pg.py | 50 ++--- tianshou/policy/modelfree/ppo.py | 98 +++++----- tianshou/policy/modelfree/sac.py | 65 ++++--- tianshou/policy/modelfree/td3.py | 86 +++++---- tianshou/policy/multiagent/mapolicy.py | 45 +++-- tianshou/policy/random.py | 14 +- tianshou/trainer/__init__.py | 8 +- tianshou/trainer/offpolicy.py | 65 +++---- tianshou/trainer/onpolicy.py | 66 +++---- tianshou/trainer/utils.py | 50 ++--- tianshou/utils/__init__.py | 2 - tianshou/utils/compile.py | 27 --- tianshou/utils/config.py | 4 +- tianshou/utils/moving_average.py | 17 +- tianshou/utils/net/common.py | 74 +++++--- tianshou/utils/net/continuous.py | 114 +++++++++--- tianshou/utils/net/discrete.py | 62 +++++-- 51 files changed, 1325 insertions(+), 991 deletions(-) delete mode 100644 tianshou/utils/compile.py diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index e6e7d73..ffd3e8f 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -119,8 +119,8 @@ def test_sac_bipedal(args=get_args()): policy = SACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - args.tau, args.gamma, args.alpha, - [env.action_space.low[0], env.action_space.high[0]], + action_range=[env.action_space.low[0], env.action_space.high[0]], + tau=args.tau, gamma=args.gamma, alpha=args.alpha, reward_normalization=args.rew_norm, ignore_done=args.ignore_done, estimation_step=args.n_step) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 6e09e6c..d73e1f5 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -78,14 +78,12 
@@ def test_sac(args=get_args()): target_entropy = -np.prod(env.action_space.shape) log_alpha = torch.zeros(1, requires_grad=True, device=args.device) alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) - alpha = (target_entropy, log_alpha, alpha_optim) - else: - alpha = args.alpha + args.alpha = (target_entropy, log_alpha, alpha_optim) policy = SACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - args.tau, args.gamma, alpha, - [env.action_space.low[0], env.action_space.high[0]], + action_range=[env.action_space.low[0], env.action_space.high[0]], + tau=args.tau, gamma=args.gamma, alpha=args.alpha, reward_normalization=args.rew_norm, ignore_done=True, exploration_noise=OUNoise(0.0, args.noise_std)) # collector diff --git a/examples/mujoco/ant_v2_ddpg.py b/examples/mujoco/ant_v2_ddpg.py index db65f58..948ceee 100644 --- a/examples/mujoco/ant_v2_ddpg.py +++ b/examples/mujoco/ant_v2_ddpg.py @@ -66,8 +66,9 @@ def test_ddpg(args=get_args()): critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( actor, actor_optim, critic, critic_optim, - args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise), - [env.action_space.low[0], env.action_space.high[0]], + action_range=[env.action_space.low[0], env.action_space.high[0]], + tau=args.tau, gamma=args.gamma, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), reward_normalization=True, ignore_done=True) # collector train_collector = Collector( diff --git a/examples/mujoco/ant_v2_sac.py b/examples/mujoco/ant_v2_sac.py index 108be79..a86bcff 100644 --- a/examples/mujoco/ant_v2_sac.py +++ b/examples/mujoco/ant_v2_sac.py @@ -71,8 +71,8 @@ def test_sac(args=get_args()): critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - args.tau, args.gamma, args.alpha, - [env.action_space.low[0], env.action_space.high[0]], + action_range=[env.action_space.low[0], env.action_space.high[0]], + tau=args.tau, gamma=args.gamma, alpha=args.alpha, reward_normalization=args.rew_norm, ignore_done=True) # collector train_collector = Collector( diff --git a/examples/mujoco/ant_v2_td3.py b/examples/mujoco/ant_v2_td3.py index db59e18..7165315 100644 --- a/examples/mujoco/ant_v2_td3.py +++ b/examples/mujoco/ant_v2_td3.py @@ -73,10 +73,12 @@ def test_td3(args=get_args()): critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - args.tau, args.gamma, - GaussianNoise(sigma=args.exploration_noise), args.policy_noise, - args.update_actor_freq, args.noise_clip, - [env.action_space.low[0], env.action_space.high[0]], + action_range=[env.action_space.low[0], env.action_space.high[0]], + tau=args.tau, gamma=args.gamma, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + policy_noise=args.policy_noise, + update_actor_freq=args.update_actor_freq, + noise_clip=args.noise_clip, reward_normalization=True, ignore_done=True) # collector train_collector = Collector( diff --git a/examples/mujoco/halfcheetahBullet_v0_sac.py b/examples/mujoco/halfcheetahBullet_v0_sac.py index 3aec4f8..97b3bc7 100644 --- a/examples/mujoco/halfcheetahBullet_v0_sac.py +++ b/examples/mujoco/halfcheetahBullet_v0_sac.py @@ -79,8 +79,8 @@ def test_sac(args=get_args()): critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( actor, actor_optim, critic1, critic1_optim, 
critic2, critic2_optim, - args.tau, args.gamma, args.alpha, - [env.action_space.low[0], env.action_space.high[0]], + action_range=[env.action_space.low[0], env.action_space.high[0]], + tau=args.tau, gamma=args.gamma, alpha=args.alpha, reward_normalization=True, ignore_done=True) # collector train_collector = Collector( diff --git a/examples/mujoco/point_maze_td3.py b/examples/mujoco/point_maze_td3.py index 1f4a217..6de2b20 100644 --- a/examples/mujoco/point_maze_td3.py +++ b/examples/mujoco/point_maze_td3.py @@ -76,10 +76,12 @@ def test_td3(args=get_args()): critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - args.tau, args.gamma, - GaussianNoise(sigma=args.exploration_noise), args.policy_noise, - args.update_actor_freq, args.noise_clip, - [env.action_space.low[0], env.action_space.high[0]], + action_range=[env.action_space.low[0], env.action_space.high[0]], + tau=args.tau, gamma=args.gamma, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + policy_noise=args.policy_noise, + update_actor_freq=args.update_actor_freq, + noise_clip=args.noise_clip, reward_normalization=True, ignore_done=True) # collector train_collector = Collector( diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 224fb0d..979444f 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -78,8 +78,9 @@ def test_ddpg(args=get_args()): critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( actor, actor_optim, critic, critic_optim, - args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise), - [env.action_space.low[0], env.action_space.high[0]], + action_range=[env.action_space.low[0], env.action_space.high[0]], + tau=args.tau, gamma=args.gamma, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), reward_normalization=args.rew_norm, ignore_done=args.ignore_done, estimation_step=args.n_step) diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 9384f43..2067973 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -79,8 +79,8 @@ def test_sac_with_il(args=get_args()): critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - args.tau, args.gamma, args.alpha, - [env.action_space.low[0], env.action_space.high[0]], + action_range=[env.action_space.low[0], env.action_space.high[0]], + tau=args.tau, gamma=args.gamma, alpha=args.alpha, reward_normalization=args.rew_norm, ignore_done=args.ignore_done, estimation_step=args.n_step) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index d8b31aa..a6215e0 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -82,9 +82,12 @@ def test_td3(args=get_args()): critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, - args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise), - args.policy_noise, args.update_actor_freq, args.noise_clip, - [env.action_space.low[0], env.action_space.high[0]], + action_range=[env.action_space.low[0], env.action_space.high[0]], + tau=args.tau, gamma=args.gamma, + exploration_noise=GaussianNoise(sigma=args.exploration_noise), + policy_noise=args.policy_noise, + 
update_actor_freq=args.update_actor_freq, + noise_clip=args.noise_clip, reward_normalization=args.rew_norm, ignore_done=args.ignore_done, estimation_step=args.n_step) diff --git a/tianshou/__init__.py b/tianshou/__init__.py index e03d864..5ad1c9f 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,17 +1,13 @@ from tianshou import data, env, utils, policy, trainer, exploration -# pre-compile some common-type function-call to produce the correct benchmark -# result: https://github.com/thu-ml/tianshou/pull/193#discussion_r480536371 -utils.pre_compile() - -__version__ = '0.2.7' +__version__ = "0.2.7" __all__ = [ - 'env', - 'data', - 'utils', - 'policy', - 'trainer', - 'exploration', + "env", + "data", + "utils", + "policy", + "trainer", + "exploration", ] diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index f5f68e9..e51b8d1 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,19 +1,18 @@ from tianshou.data.batch import Batch -from tianshou.data.utils.converter import to_numpy, to_torch, \ - to_torch_as +from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as from tianshou.data.utils.segtree import SegmentTree from tianshou.data.buffer import ReplayBuffer, \ ListReplayBuffer, PrioritizedReplayBuffer from tianshou.data.collector import Collector __all__ = [ - 'Batch', - 'to_numpy', - 'to_torch', - 'to_torch_as', - 'SegmentTree', - 'ReplayBuffer', - 'ListReplayBuffer', - 'PrioritizedReplayBuffer', - 'Collector', + "Batch", + "to_numpy", + "to_torch", + "to_torch_as", + "SegmentTree", + "ReplayBuffer", + "ListReplayBuffer", + "PrioritizedReplayBuffer", + "Collector", ] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 7f7f10b..a06c0df 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -5,8 +5,8 @@ import numpy as np from copy import deepcopy from numbers import Number from collections.abc import Collection -from typing import Any, List, Tuple, Union, Iterator, KeysView, ValuesView, \ - ItemsView, Optional +from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \ + Sequence, KeysView, ValuesView, ItemsView # Disable pickle warning related to torch, since it has been removed # on torch master branch. 
See Pull Request #39003 for details: @@ -23,8 +23,8 @@ def _is_batch_set(data: Any) -> bool: # "for e in data" will just unpack the first dimension, # but data.tolist() will flatten ndarray of objects # so do not use data.tolist() - return data.dtype == np.object and \ - all(isinstance(e, (dict, Batch)) for e in data) + return data.dtype == np.object and all( + isinstance(e, (dict, Batch)) for e in data) elif isinstance(data, (list, tuple)): if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data): return True @@ -54,8 +54,9 @@ def _is_number(value: Any) -> bool: def _to_array_with_correct_type(v: Any) -> np.ndarray: - if isinstance(v, np.ndarray) and \ - issubclass(v.dtype.type, (np.bool_, np.number)): # most often case + if isinstance(v, np.ndarray) and issubclass( + v.dtype.type, (np.bool_, np.number) + ): # most often case return v # convert the value to np.ndarray # convert to np.object data type if neither bool nor number @@ -71,14 +72,16 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray: # array([{}, array({}, dtype=object)], dtype=object) if not v.shape: v = v.item(0) - elif any(isinstance(e, (np.ndarray, torch.Tensor)) - for e in v.reshape(-1)): + elif any( + isinstance(e, (np.ndarray, torch.Tensor)) for e in v.reshape(-1) + ): raise ValueError("Numpy arrays of tensors are not supported yet.") return v -def _create_value(inst: Any, size: int, stack=True) -> Union[ - 'Batch', np.ndarray, torch.Tensor]: +def _create_value( + inst: Any, size: int, stack: bool = True +) -> Union["Batch", np.ndarray, torch.Tensor]: """Create empty place-holders accroding to inst's shape. :param bool stack: whether to stack or to concatenate. E.g. if inst has @@ -100,12 +103,15 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[ target_type = inst.dtype.type else: target_type = np.object - return np.full(shape, - fill_value=None if target_type == np.object else 0, - dtype=target_type) + return np.full( + shape, + fill_value=None if target_type == np.object else 0, + dtype=target_type + ) elif isinstance(inst, torch.Tensor): - return torch.full(shape, - fill_value=0, device=inst.device, dtype=inst.dtype) + return torch.full( + shape, fill_value=0, device=inst.device, dtype=inst.dtype + ) elif isinstance(inst, (dict, Batch)): zero_batch = Batch() for key, val in inst.items(): @@ -117,12 +123,13 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[ return np.array([None for _ in range(size)]) -def _assert_type_keys(keys) -> None: - assert all(isinstance(e, str) for e in keys), \ - f"keys should all be string, but got {keys}" +def _assert_type_keys(keys: Iterable[str]) -> None: + assert all( + isinstance(e, str) for e in keys + ), f"keys should all be string, but got {keys}" -def _parse_value(v: Any): +def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]: if isinstance(v, Batch): # most often case return v elif (isinstance(v, np.ndarray) and @@ -166,12 +173,14 @@ class Batch: For a detailed description, please refer to :ref:`batch_concept`. 
""" - def __init__(self, - batch_dict: Optional[Union[ - dict, 'Batch', Tuple[Union[dict, 'Batch']], - List[Union[dict, 'Batch']], np.ndarray]] = None, - copy: bool = False, - **kwargs) -> None: + def __init__( + self, + batch_dict: Optional[ + Union[dict, "Batch", Sequence[Union[dict, "Batch"]], np.ndarray] + ] = None, + copy: bool = False, + **kwargs: Any, + ) -> None: if copy: batch_dict = deepcopy(batch_dict) if batch_dict is not None: @@ -188,7 +197,7 @@ class Batch: """Set self.key = value.""" self.__dict__[key] = _parse_value(value) - def __getstate__(self) -> dict: + def __getstate__(self) -> Dict[str, Any]: """Pickling interface. Only the actual data are serialized for both efficiency and simplicity. @@ -200,7 +209,7 @@ class Batch: state[k] = v return state - def __setstate__(self, state) -> None: + def __setstate__(self, state: Dict[str, Any]) -> None: """Unpickling interface. At this point, self is an empty Batch instance that has not been @@ -208,8 +217,9 @@ class Batch: """ self.__init__(**state) - def __getitem__(self, index: Union[ - str, slice, int, np.integer, np.ndarray, List[int]]) -> 'Batch': + def __getitem__( + self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]] + ) -> Union["Batch", np.ndarray, torch.Tensor]: """Return self[index].""" if isinstance(index, str): return self.__dict__[index] @@ -225,9 +235,11 @@ class Batch: else: raise IndexError("Cannot access item from empty Batch object.") - def __setitem__(self, index: Union[ - str, slice, int, np.integer, np.ndarray, List[int]], - value: Any) -> None: + def __setitem__( + self, + index: Union[str, slice, int, np.integer, np.ndarray, List[int]], + value: Any, + ) -> None: """Assign value to self[index].""" value = _parse_value(value) if isinstance(index, str): @@ -252,12 +264,12 @@ class Batch: else: self.__dict__[key][index] = None - def __iadd__(self, other: Union['Batch', Number, np.number]): + def __iadd__(self, other: Union["Batch", Number, np.number]) -> "Batch": """Algebraic addition with another Batch instance in-place.""" if isinstance(other, Batch): - for (k, r), v in zip(self.__dict__.items(), - other.__dict__.values()): - # TODO are keys consistent? + for (k, r), v in zip( + self.__dict__.items(), other.__dict__.values() + ): # TODO are keys consistent? if isinstance(r, Batch) and r.is_empty(): continue else: @@ -273,11 +285,11 @@ class Batch: else: raise TypeError("Only addition of Batch or number is supported.") - def __add__(self, other: Union['Batch', Number, np.number]): + def __add__(self, other: Union["Batch", Number, np.number]) -> "Batch": """Algebraic addition with another Batch instance out-of-place.""" return deepcopy(self).__iadd__(other) - def __imul__(self, val: Union[Number, np.number]): + def __imul__(self, val: Union[Number, np.number]) -> "Batch": """Algebraic multiplication with a scalar value in-place.""" assert _is_number(val), "Only multiplication by a number is supported." for k, r in self.__dict__.items(): @@ -286,11 +298,11 @@ class Batch: self.__dict__[k] *= val return self - def __mul__(self, val: Union[Number, np.number]): + def __mul__(self, val: Union[Number, np.number]) -> "Batch": """Algebraic multiplication with a scalar value out-of-place.""" return deepcopy(self).__imul__(val) - def __itruediv__(self, val: Union[Number, np.number]): + def __itruediv__(self, val: Union[Number, np.number]) -> "Batch": """Algebraic division with a scalar value in-place.""" assert _is_number(val), "Only division by a number is supported." 
for k, r in self.__dict__.items(): @@ -299,23 +311,23 @@ class Batch: self.__dict__[k] /= val return self - def __truediv__(self, val: Union[Number, np.number]): + def __truediv__(self, val: Union[Number, np.number]) -> "Batch": """Algebraic division with a scalar value out-of-place.""" return deepcopy(self).__itruediv__(val) def __repr__(self) -> str: """Return str(self).""" - s = self.__class__.__name__ + '(\n' + s = self.__class__.__name__ + "(\n" flag = False for k, v in self.__dict__.items(): - rpl = '\n' + ' ' * (6 + len(k)) - obj = pprint.pformat(v).replace('\n', rpl) - s += f' {k}: {obj},\n' + rpl = "\n" + " " * (6 + len(k)) + obj = pprint.pformat(v).replace("\n", rpl) + s += f" {k}: {obj},\n" flag = True if flag: - s += ')' + s += ")" else: - s = self.__class__.__name__ + '()' + s = self.__class__.__name__ + "()" return s def __contains__(self, key: str) -> bool: @@ -350,8 +362,11 @@ class Batch: elif isinstance(v, Batch): v.to_numpy() - def to_torch(self, dtype: Optional[torch.dtype] = None, - device: Union[str, int, torch.device] = 'cpu') -> None: + def to_torch( + self, + dtype: Optional[torch.dtype] = None, + device: Union[str, int, torch.device] = "cpu", + ) -> None: """Change all numpy.ndarray to torch.Tensor in-place.""" if not isinstance(device, torch.device): device = torch.device(device) @@ -376,9 +391,9 @@ class Batch: v = v.type(dtype) self.__dict__[k] = v - def __cat(self, - batches: List[Union[dict, 'Batch']], - lens: List[int]) -> None: + def __cat( + self, batches: Sequence[Union[dict, "Batch"]], lens: List[int] + ) -> None: """Private method for Batch.cat_. :: @@ -445,8 +460,9 @@ class Batch: val, sum_lens[-1], stack=False) self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val - def cat_(self, - batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None: + def cat_( + self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]] + ) -> None: """Concatenate a list of (or one) Batch objects into current batch.""" if isinstance(batches, Batch): batches = [batches] @@ -460,20 +476,19 @@ class Batch: # x.is_empty(recurse=True) here means x is a nested empty batch # like Batch(a=Batch), and we have to treat it as length zero and # keep it. - lens = [0 if x.is_empty(recurse=True) else len(x) - for x in batches] + lens = [0 if x.is_empty(recurse=True) else len(x) for x in batches] except TypeError as e: raise ValueError( - f'Batch.cat_ meets an exception. Maybe because there is any ' - f'scalar in {batches} but Batch.cat_ does not support the ' - f'concatenation of scalar.') from e + "Batch.cat_ meets an exception. Maybe because there is any " + f"scalar in {batches} but Batch.cat_ does not support the " + "concatenation of scalar.") from e if not self.is_empty(): batches = [self] + list(batches) lens = [0 if self.is_empty(recurse=True) else len(self)] + lens - return self.__cat(batches, lens) + self.__cat(batches, lens) @staticmethod - def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch': + def cat(batches: Sequence[Union[dict, "Batch"]]) -> "Batch": """Concatenate a list of Batch object into a single new batch. 
For keys that are not shared across all batches, batches that do not @@ -494,9 +509,9 @@ class Batch: batch.cat_(batches) return batch - def stack_(self, - batches: List[Union[dict, 'Batch']], - axis: int = 0) -> None: + def stack_( + self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0 + ) -> None: """Stack a list of Batch object into current batch.""" if len(batches) == 0: return @@ -528,8 +543,8 @@ class Batch: keys_partial = keys_reserve_or_partial.difference(keys_reserve) if keys_partial and axis != 0: raise ValueError( - f"Stack of Batch with non-shared keys {keys_partial} " - f"is only supported with axis=0, but got axis={axis}!") + f"Stack of Batch with non-shared keys {keys_partial} is only " + f"supported with axis=0, but got axis={axis}!") for k in keys_reserve: # reserved keys self.__dict__[k] = Batch() @@ -547,7 +562,9 @@ class Batch: self.__dict__[k][i] = val @staticmethod - def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch': + def stack( + batches: Sequence[Union[dict, "Batch"]], axis: int = 0 + ) -> "Batch": """Stack a list of Batch object into a single new batch. For keys that are not shared across all batches, batches that do not @@ -573,9 +590,12 @@ class Batch: batch.stack_(batches, axis) return batch - def empty_(self, index: Union[ - str, slice, int, np.integer, np.ndarray, List[int]] = None - ) -> 'Batch': + def empty_( + self, + index: Union[ + str, slice, int, np.integer, np.ndarray, List[int] + ] = None, + ) -> "Batch": """Return an empty Batch object with 0 or None filled. If "index" is specified, it will only reset the specific indexed-data. @@ -613,8 +633,8 @@ class Batch: elif isinstance(v, Batch): self.__dict__[k].empty_(index=index) else: # scalar value - warnings.warn('You are calling Batch.empty on a NumPy scalar, ' - 'which may cause undefined behaviors.') + warnings.warn("You are calling Batch.empty on a NumPy scalar, " + "which may cause undefined behaviors.") if _is_number(v): self.__dict__[k] = v.__class__(0) else: @@ -622,17 +642,21 @@ class Batch: return self @staticmethod - def empty(batch: 'Batch', index: Union[ - str, slice, int, np.integer, np.ndarray, List[int]] = None - ) -> 'Batch': + def empty( + batch: "Batch", + index: Union[ + str, slice, int, np.integer, np.ndarray, List[int] + ] = None, + ) -> "Batch": """Return an empty Batch object with 0 or None filled. The shape is the same as the given Batch. """ return deepcopy(batch).empty_(index) - def update(self, batch: Optional[Union[dict, 'Batch']] = None, - **kwargs) -> None: + def update( + self, batch: Optional[Union[dict, "Batch"]] = None, **kwargs: Any + ) -> None: """Update this batch from another dict/Batch.""" if batch is None: self.update(kwargs) @@ -648,8 +672,9 @@ class Batch: for v in self.__dict__.values(): if isinstance(v, Batch) and v.is_empty(recurse=True): continue - elif hasattr(v, '__len__') and (not isinstance( - v, (np.ndarray, torch.Tensor)) or v.ndim > 0): + elif hasattr(v, "__len__") and (not isinstance( + v, (np.ndarray, torch.Tensor)) or v.ndim > 0 + ): r.append(len(v)) else: raise TypeError(f"Object {v} in {self} has no len()") @@ -659,7 +684,7 @@ class Batch: raise TypeError(f"Object {self} has no len()") return min(r) - def is_empty(self, recurse: bool = False): + def is_empty(self, recurse: bool = False) -> bool: """Test if a Batch is empty. 
If ``recurse=True``, it further tests the values of the object; else @@ -689,8 +714,9 @@ class Batch: return True if not recurse: return False - return all(False if not isinstance(x, Batch) - else x.is_empty(recurse=True) for x in self.values()) + return all( + False if not isinstance(x, Batch) else x.is_empty(recurse=True) + for x in self.values()) @property def shape(self) -> List[int]: @@ -707,8 +733,9 @@ class Batch: return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \ else data_shape[0] - def split(self, size: int, shuffle: bool = True, - merge_last: bool = False) -> Iterator['Batch']: + def split( + self, size: int, shuffle: bool = True, merge_last: bool = False + ) -> Iterator["Batch"]: """Split whole data into multiple small batches. :param int size: divide the data batch with the given size, but one diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index e9f47b9..ff16085 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,6 +1,7 @@ import torch import numpy as np -from typing import Any, Tuple, Union, Optional +from numbers import Number +from typing import Any, Dict, Tuple, Union, Optional from tianshou.data import Batch, SegmentTree, to_numpy from tianshou.data.batch import _create_value @@ -11,7 +12,7 @@ class ReplayBuffer: interaction between the policy and environment. The current implementation of Tianshou typically use 7 reserved keys in - :class:`~tianshou.data.Batch` + :class:`~tianshou.data.Batch`: * ``obs`` the observation of step :math:`t` ; * ``act`` the action of step :math:`t` ; @@ -124,14 +125,17 @@ class ReplayBuffer: This feature is not supported in Prioritized Replay Buffer currently. """ - def __init__(self, size: int, stack_num: int = 1, - ignore_obs_next: bool = False, - save_only_last_obs: bool = False, - sample_avail: bool = False) -> None: + def __init__( + self, + size: int, + stack_num: int = 1, + ignore_obs_next: bool = False, + save_only_last_obs: bool = False, + sample_avail: bool = False, + ) -> None: super().__init__() self._maxsize = size self._indices = np.arange(size) - self._stack = None self.stack_num = stack_num self._avail = sample_avail and stack_num > 1 self._avail_index = [] @@ -157,7 +161,7 @@ class ReplayBuffer: except KeyError as e: raise AttributeError from e - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: """Unpickling interface. We need it because pickling buffer does not work out-of-the-box @@ -171,11 +175,12 @@ class ReplayBuffer: except KeyError: self._meta.__dict__[name] = _create_value(inst, self._maxsize) value = self._meta.__dict__[name] - if isinstance(inst, (np.ndarray, torch.Tensor)) \ - and value.shape[1:] != inst.shape: + if isinstance(inst, (torch.Tensor, np.ndarray)) \ + and inst.shape != value.shape[1:]: raise ValueError( "Cannot add data to a buffer with different shape, with key " - f"{name}, expect {value.shape[1:]}, given {inst.shape}.") + f"{name}, expect {value.shape[1:]}, given {inst.shape}." 
+ ) try: value[self._index] = inst except KeyError: @@ -184,15 +189,15 @@ class ReplayBuffer: value[self._index] = inst @property - def stack_num(self): + def stack_num(self) -> int: return self._stack @stack_num.setter - def stack_num(self, num): - assert num > 0, 'stack_num should greater than 0' + def stack_num(self, num: int) -> None: + assert num > 0, "stack_num should greater than 0" self._stack = num - def update(self, buffer: 'ReplayBuffer') -> None: + def update(self, buffer: "ReplayBuffer") -> None: """Move the data from the given buffer to self.""" if len(buffer) == 0: return @@ -206,32 +211,35 @@ class ReplayBuffer: break buffer.stack_num = stack_num_orig - def add(self, - obs: Union[dict, Batch, np.ndarray, float], - act: Union[dict, Batch, np.ndarray, float], - rew: Union[int, float], - done: Union[bool, int], - obs_next: Optional[Union[dict, Batch, np.ndarray, float]] = None, - info: Optional[Union[dict, Batch]] = {}, - policy: Optional[Union[dict, Batch]] = {}, - **kwargs) -> None: + def add( + self, + obs: Any, + act: Any, + rew: Union[Number, np.number, np.ndarray], + done: Union[Number, np.number, np.bool_], + obs_next: Any = None, + info: Optional[Union[dict, Batch]] = {}, + policy: Optional[Union[dict, Batch]] = {}, + **kwargs: Any, + ) -> None: """Add a batch of data into replay buffer.""" - assert isinstance(info, (dict, Batch)), \ - 'You should return a dict in the last argument of env.step().' + assert isinstance( + info, (dict, Batch) + ), "You should return a dict in the last argument of env.step()." if self._last_obs: obs = obs[-1] - self._add_to_buffer('obs', obs) - self._add_to_buffer('act', act) - self._add_to_buffer('rew', rew) - self._add_to_buffer('done', done) + self._add_to_buffer("obs", obs) + self._add_to_buffer("act", act) + self._add_to_buffer("rew", rew) + self._add_to_buffer("done", done) if self._save_s_: if obs_next is None: obs_next = Batch() elif self._last_obs: obs_next = obs_next[-1] - self._add_to_buffer('obs_next', obs_next) - self._add_to_buffer('info', info) - self._add_to_buffer('policy', policy) + self._add_to_buffer("obs_next", obs_next) + self._add_to_buffer("info", info) + self._add_to_buffer("policy", policy) # maintain available index for frame-stack sampling if self._avail: @@ -262,7 +270,8 @@ class ReplayBuffer: self._avail_index = [] def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: - """Get a random sample from buffer with size equal to batch_size. \ + """Get a random sample from buffer with size equal to batch_size. + Return all the data in the buffer if batch_size is 0. :return: Sample data and its corresponding index inside the buffer. @@ -278,11 +287,15 @@ class ReplayBuffer: np.arange(self._index, self._size), np.arange(0, self._index), ]) - assert len(indice) > 0, 'No available indice can be sampled.' + assert len(indice) > 0, "No available indice can be sampled." return self[indice], indice - def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str, - stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]: + def get( + self, + indice: Union[slice, int, np.integer, np.ndarray], + key: str, + stack_num: Optional[int] = None, + ) -> Union[Batch, np.ndarray]: """Return the stacked result. E.g. 
[s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the @@ -292,7 +305,7 @@ class ReplayBuffer: if stack_num is None: stack_num = self.stack_num if stack_num == 1: # the most often case - if key != 'obs_next' or self._save_s_: + if key != "obs_next" or self._save_s_: val = self._meta.__dict__[key] try: return val[indice] @@ -301,11 +314,11 @@ class ReplayBuffer: raise e # val != Batch() return Batch() indice = self._indices[:self._size][indice] - done = self._meta.__dict__['done'] - if key == 'obs_next' and not self._save_s_: + done = self._meta.__dict__["done"] + if key == "obs_next" and not self._save_s_: indice += 1 - done[indice].astype(np.int) indice[indice == self._size] = 0 - key = 'obs' + key = "obs" val = self._meta.__dict__[key] try: if stack_num == 1: @@ -319,30 +332,30 @@ class ReplayBuffer: pre_indice + done[pre_indice].astype(np.int)) indice[indice == self._size] = 0 if isinstance(val, Batch): - stack = Batch.stack(stack, axis=indice.ndim) + return Batch.stack(stack, axis=indice.ndim) else: - stack = np.stack(stack, axis=indice.ndim) - return stack + return np.stack(stack, axis=indice.ndim) except IndexError as e: if not (isinstance(val, Batch) and val.is_empty()): raise e # val != Batch() return Batch() - def __getitem__(self, index: Union[ - slice, int, np.integer, np.ndarray]) -> Batch: + def __getitem__( + self, index: Union[slice, int, np.integer, np.ndarray] + ) -> Batch: """Return a data batch: self[index]. - If stack_num is larger than 1, return the stacked obs and obs_next - with shape (batch, len, ...). + If stack_num is larger than 1, return the stacked obs and obs_next with + shape (batch, len, ...). """ return Batch( - obs=self.get(index, 'obs'), + obs=self.get(index, "obs"), act=self.act[index], rew=self.rew[index], done=self.done[index], - obs_next=self.get(index, 'obs_next'), - info=self.get(index, 'info'), - policy=self.get(index, 'policy'), + obs_next=self.get(index, "obs_next"), + info=self.get(index, "info"), + policy=self.get(index, "policy"), ) @@ -361,15 +374,15 @@ class ListReplayBuffer(ReplayBuffer): explanation. """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__(size=0, ignore_obs_next=False, **kwargs) def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: raise NotImplementedError("ListReplayBuffer cannot be sampled!") def _add_to_buffer( - self, name: str, - inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None: + self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool] + ) -> None: if self._meta.__dict__.get(name) is None: self._meta.__dict__[name] = [] self._meta.__dict__[name].append(inst) @@ -393,25 +406,29 @@ class PrioritizedReplayBuffer(ReplayBuffer): explanation. """ - def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None: + def __init__( + self, size: int, alpha: float, beta: float, **kwargs: Any + ) -> None: super().__init__(size, **kwargs) - assert alpha > 0. and beta >= 0. 
+ assert alpha > 0.0 and beta >= 0.0 self._alpha, self._beta = alpha, beta self._max_prio = self._min_prio = 1.0 # save weight directly in this class instead of self._meta self.weight = SegmentTree(size) self.__eps = np.finfo(np.float32).eps.item() - def add(self, - obs: Union[dict, Batch, np.ndarray, float], - act: Union[dict, Batch, np.ndarray, float], - rew: Union[int, float], - done: Union[bool, int], - obs_next: Optional[Union[dict, Batch, np.ndarray, float]] = None, - info: Optional[Union[dict, Batch]] = {}, - policy: Optional[Union[dict, Batch]] = {}, - weight: Optional[float] = None, - **kwargs) -> None: + def add( + self, + obs: Any, + act: Any, + rew: Union[Number, np.number, np.ndarray], + done: Union[Number, np.number, np.bool_], + obs_next: Any = None, + info: Optional[Union[dict, Batch]] = {}, + policy: Optional[Union[dict, Batch]] = {}, + weight: Optional[Union[Number, np.number]] = None, + **kwargs: Any, + ) -> None: """Add a batch of data into replay buffer.""" if weight is None: weight = self._max_prio @@ -433,7 +450,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): to de-bias the sampling process (some transition tuples are sampled more often so their losses are weighted less). """ - assert self._size > 0, 'Cannot sample a buffer with 0 size!' + assert self._size > 0, "Cannot sample a buffer with 0 size!" if batch_size == 0: indice = np.concatenate([ np.arange(self._index, self._size), @@ -449,8 +466,11 @@ class PrioritizedReplayBuffer(ReplayBuffer): batch.weight = (batch.weight / self._min_prio) ** (-self._beta) return batch, indice - def update_weight(self, indice: Union[np.ndarray], - new_weight: Union[np.ndarray, torch.Tensor]) -> None: + def update_weight( + self, + indice: Union[np.ndarray], + new_weight: Union[np.ndarray, torch.Tensor] + ) -> None: """Update priority weight by indice in this buffer. :param np.ndarray indice: indice you want to update weight. 
@@ -461,15 +481,16 @@ class PrioritizedReplayBuffer(ReplayBuffer): self._max_prio = max(self._max_prio, weight.max()) self._min_prio = min(self._min_prio, weight.min()) - def __getitem__(self, index: Union[ - slice, int, np.integer, np.ndarray]) -> Batch: + def __getitem__( + self, index: Union[slice, int, np.integer, np.ndarray] + ) -> Batch: return Batch( - obs=self.get(index, 'obs'), + obs=self.get(index, "obs"), act=self.act[index], rew=self.rew[index], done=self.done[index], - obs_next=self.get(index, 'obs_next'), - info=self.get(index, 'info'), - policy=self.get(index, 'policy'), + obs_next=self.get(index, "obs_next"), + info=self.get(index, "info"), + policy=self.get(index, "policy"), weight=self.weight[index], ) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 65f869a..aacaa9b 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -4,13 +4,14 @@ import torch import warnings import numpy as np from copy import deepcopy -from typing import Any, Dict, List, Union, Optional, Callable +from numbers import Number +from typing import Dict, List, Union, Optional, Callable -from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import BasePolicy from tianshou.exploration import BaseNoise -from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy from tianshou.data.batch import _create_value +from tianshou.env import BaseVectorEnv, DummyVectorEnv +from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy class Collector(object): @@ -75,14 +76,15 @@ class Collector(object): Please make sure the given environment has a time limitation. """ - def __init__(self, - policy: BasePolicy, - env: Union[gym.Env, BaseVectorEnv], - buffer: Optional[ReplayBuffer] = None, - preprocess_fn: Callable[[Any], Batch] = None, - action_noise: Optional[BaseNoise] = None, - reward_metric: Optional[Callable[[np.ndarray], float]] = None, - ) -> None: + def __init__( + self, + policy: BasePolicy, + env: Union[gym.Env, BaseVectorEnv], + buffer: Optional[ReplayBuffer] = None, + preprocess_fn: Optional[Callable[..., Batch]] = None, + action_noise: Optional[BaseNoise] = None, + reward_metric: Optional[Callable[[np.ndarray], float]] = None, + ) -> None: super().__init__() if not isinstance(env, BaseVectorEnv): env = DummyVectorEnv([lambda: env]) @@ -108,12 +110,15 @@ class Collector(object): self.reset() @staticmethod - def _default_rew_metric(x): + def _default_rew_metric( + x: Union[Number, np.number] + ) -> Union[Number, np.number]: # this internal function is designed for single-agent RL # for multi-agent RL, a reward_metric must be provided - assert np.asanyarray(x).size == 1, \ - 'Please specify the reward_metric ' \ - 'since the reward is not a scalar.' + assert np.asanyarray(x).size == 1, ( + "Please specify the reward_metric " + "since the reward is not a scalar." 
+ ) return x def reset(self) -> None: @@ -124,7 +129,7 @@ class Collector(object): obs_next={}, policy={}) self.reset_env() self.reset_buffer() - self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0 + self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0 if self._action_noise is not None: self._action_noise.reset() @@ -142,7 +147,7 @@ class Collector(object): self._ready_env_ids = np.arange(self.env_num) obs = self.env.reset() if self.preprocess_fn: - obs = self.preprocess_fn(obs=obs).get('obs', obs) + obs = self.preprocess_fn(obs=obs).get("obs", obs) self.data.obs = obs for b in self._cached_buf: b.reset() @@ -157,13 +162,14 @@ class Collector(object): elif isinstance(state, Batch): state.empty_(id) - def collect(self, - n_step: Optional[int] = None, - n_episode: Optional[Union[int, List[int]]] = None, - random: bool = False, - render: Optional[float] = None, - no_grad: bool = True, - ) -> Dict[str, float]: + def collect( + self, + n_step: Optional[int] = None, + n_episode: Optional[Union[int, List[int]]] = None, + random: bool = False, + render: Optional[float] = None, + no_grad: bool = True, + ) -> Dict[str, float]: """Collect a specified number of step or episode. :param int n_step: how many steps you want to collect. @@ -217,8 +223,8 @@ class Collector(object): while True: if step_count >= 100000 and episode_count.sum() == 0: warnings.warn( - 'There are already many steps in an episode. ' - 'You should add a time limitation to your environment!', + "There are already many steps in an episode. " + "You should add a time limitation to your environment!", Warning) is_async = self.is_async or len(finished_env_ids) > 0 @@ -250,11 +256,11 @@ class Collector(object): else: result = self.policy(self.data, last_state) - state = result.get('state', Batch()) + state = result.get("state", Batch()) # convert None to Batch(), since None is reserved for 0-init if state is None: state = Batch() - self.data.update(state=state, policy=result.get('policy', Batch())) + self.data.update(state=state, policy=result.get("policy", Batch())) # save hidden state to policy._state, in order to save into buffer if not (isinstance(state, Batch) and state.is_empty()): self.data.policy._state = self.data.state @@ -268,12 +274,12 @@ class Collector(object): obs_next, rew, done, info = self.env.step(self.data.act) else: # store computed actions, states, etc - _batch_set_item(whole_data, self._ready_env_ids, - self.data, self.env_num) + _batch_set_item( + whole_data, self._ready_env_ids, self.data, self.env_num) # fetch finished data obs_next, rew, done, info = self.env.step( self.data.act, id=self._ready_env_ids) - self._ready_env_ids = np.array([i['env_id'] for i in info]) + self._ready_env_ids = np.array([i["env_id"] for i in info]) # get the stepped data self.data = whole_data[self._ready_env_ids] # move data to self.data @@ -319,15 +325,15 @@ class Collector(object): obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: obs_next[env_ind_local] = self.preprocess_fn( - obs=obs_reset).get('obs', obs_reset) + obs=obs_reset).get("obs", obs_reset) else: obs_next[env_ind_local] = obs_reset self.data.obs = obs_next if is_async: # set data back whole_data = deepcopy(whole_data) # avoid reference in ListBuf - _batch_set_item(whole_data, self._ready_env_ids, - self.data, self.env_num) + _batch_set_item( + whole_data, self._ready_env_ids, self.data, self.env_num) # let self.data be the data in all environments again self.data = whole_data self._ready_env_ids = np.array( @@ -358,12 
+364,12 @@ class Collector(object): if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg reward_avg = self._rew_metric(reward_avg) return { - 'n/ep': episode_count, - 'n/st': step_count, - 'v/st': step_count / duration, - 'v/ep': episode_count / duration, - 'rew': reward_avg, - 'len': step_count / episode_count, + "n/ep": episode_count, + "n/st": step_count, + "v/st": step_count / duration, + "v/ep": episode_count / duration, + "rew": reward_avg, + "len": step_count / episode_count, } def sample(self, batch_size: int) -> Batch: @@ -377,9 +383,9 @@ class Collector(object): batch_size. """ warnings.warn( - 'Collector.sample is deprecated and will cause error if you use ' - 'prioritized experience replay! Collector.sample will be removed ' - 'upon version 0.3. Use policy.update instead!', Warning) + "Collector.sample is deprecated and will cause error if you use " + "prioritized experience replay! Collector.sample will be removed " + "upon version 0.3. Use policy.update instead!", Warning) assert self.buffer is not None, "Cannot get sample from empty buffer!" batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice) @@ -387,12 +393,13 @@ class Collector(object): def close(self) -> None: warnings.warn( - 'Collector.close is deprecated and will be removed upon version ' - '0.3.', Warning) + "Collector.close is deprecated and will be removed upon version " + "0.3.", Warning) -def _batch_set_item(source: Batch, indices: np.ndarray, - target: Batch, size: int): +def _batch_set_item( + source: Batch, indices: np.ndarray, target: Batch, size: int +) -> None: # for any key chain k, there are four cases # 1. source[k] is non-reserved, but target[k] does not exist or is reserved # 2. source[k] does not exist or is reserved, but target[k] is non-reserved diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index dd36c73..6a5d43e 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -1,72 +1,79 @@ import torch import numpy as np +from copy import deepcopy from numbers import Number from typing import Union, Optional from tianshou.data.batch import _parse_value, Batch -def to_numpy(x: Union[ - Batch, dict, list, tuple, np.ndarray, torch.Tensor]) -> Union[ - Batch, dict, list, tuple, np.ndarray, torch.Tensor]: +def to_numpy( + x: Optional[Union[Batch, dict, list, tuple, np.number, np.bool_, Number, + np.ndarray, torch.Tensor]] +) -> Union[Batch, dict, list, tuple, np.ndarray]: """Return an object without torch.Tensor.""" if isinstance(x, torch.Tensor): # most often case - x = x.detach().cpu().numpy() + return x.detach().cpu().numpy() elif isinstance(x, np.ndarray): # second often case - pass + return x elif isinstance(x, (np.number, np.bool_, Number)): - x = np.asanyarray(x) + return np.asanyarray(x) elif x is None: - x = np.array(None, dtype=np.object) + return np.array(None, dtype=np.object) elif isinstance(x, Batch): + x = deepcopy(x) x.to_numpy() + return x elif isinstance(x, dict): - for k, v in x.items(): - x[k] = to_numpy(v) + return {k: to_numpy(v) for k, v in x.items()} elif isinstance(x, (list, tuple)): try: - x = to_numpy(_parse_value(x)) + return to_numpy(_parse_value(x)) except TypeError: - x = [to_numpy(e) for e in x] + return [to_numpy(e) for e in x] else: # fallback - x = np.asanyarray(x) - return x + return np.asanyarray(x) -def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], - dtype: Optional[torch.dtype] = None, - device: Union[str, int, 
torch.device] = 'cpu' - ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]: +def to_torch( + x: Union[Batch, dict, list, tuple, np.number, np.bool_, Number, np.ndarray, + torch.Tensor], + dtype: Optional[torch.dtype] = None, + device: Union[str, int, torch.device] = "cpu", +) -> Union[Batch, dict, list, tuple, torch.Tensor]: """Return an object without np.ndarray.""" - if isinstance(x, np.ndarray) and \ - issubclass(x.dtype.type, (np.bool_, np.number)): # most often case + if isinstance(x, np.ndarray) and issubclass( + x.dtype.type, (np.bool_, np.number) + ): # most often case x = torch.from_numpy(x).to(device) if dtype is not None: x = x.type(dtype) + return x elif isinstance(x, torch.Tensor): # second often case if dtype is not None: x = x.type(dtype) - x = x.to(device) + return x.to(device) elif isinstance(x, (np.number, np.bool_, Number)): - x = to_torch(np.asanyarray(x), dtype, device) + return to_torch(np.asanyarray(x), dtype, device) elif isinstance(x, dict): - for k, v in x.items(): - x[k] = to_torch(v, dtype, device) + return {k: to_torch(v, dtype, device) for k, v in x.items()} elif isinstance(x, Batch): + x = deepcopy(x) x.to_torch(dtype, device) + return x elif isinstance(x, (list, tuple)): try: - x = to_torch(_parse_value(x), dtype, device) + return to_torch(_parse_value(x), dtype, device) except TypeError: - x = [to_torch(e, dtype, device) for e in x] + return [to_torch(e, dtype, device) for e in x] else: # fallback raise TypeError(f"object {x} cannot be converted to torch.") - return x -def to_torch_as(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], - y: torch.Tensor - ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]: +def to_torch_as( + x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor], + y: torch.Tensor, +) -> Union[Batch, dict, list, tuple, torch.Tensor]: """Return an object without np.ndarray. Same as ``to_torch(x, dtype=y.dtype, device=y.device)``. diff --git a/tianshou/data/utils/segtree.py b/tianshou/data/utils/segtree.py index c4b4871..5bb6fcc 100644 --- a/tianshou/data/utils/segtree.py +++ b/tianshou/data/utils/segtree.py @@ -24,17 +24,20 @@ class SegmentTree: self._size = size self._bound = bound self._value = np.zeros([bound * 2]) + self._compile() - def __len__(self): + def __len__(self) -> int: return self._size - def __getitem__(self, index: Union[int, np.ndarray] - ) -> Union[float, np.ndarray]: + def __getitem__( + self, index: Union[int, np.ndarray] + ) -> Union[float, np.ndarray]: """Return self[index].""" return self._value[index + self._bound] - def __setitem__(self, index: Union[int, np.ndarray], - value: Union[float, np.ndarray]) -> None: + def __setitem__( + self, index: Union[int, np.ndarray], value: Union[float, np.ndarray] + ) -> None: """Update values in segment tree. Duplicate values in ``index`` are handled by numpy: later index @@ -62,7 +65,8 @@ class SegmentTree: return _reduce(self._value, start + self._bound - 1, end + self._bound) def get_prefix_sum_idx( - self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]: + self, value: Union[float, np.ndarray] + ) -> Union[int, np.ndarray]: r"""Find the index with given value. Return the minimum index for each ``v`` in ``value`` so that @@ -74,7 +78,7 @@ class SegmentTree: Please make sure all of the values inside the segment tree are non-negative when using this function. """ - assert np.all(value >= 0.) 
and np.all(value < self._value[1]) + assert np.all(value >= 0.0) and np.all(value < self._value[1]) single = False if not isinstance(value, np.ndarray): value = np.array([value]) @@ -82,6 +86,16 @@ class SegmentTree: index = _get_prefix_sum_idx(value, self._bound, self._value) return index.item() if single else index + def _compile(self) -> None: + f64 = np.array([0, 1], dtype=np.float64) + f32 = np.array([0, 1], dtype=np.float32) + i64 = np.array([0, 1], dtype=np.int64) + _setitem(f64, i64, f64) + _setitem(f64, i64, f32) + _reduce(f64, 0, 1) + _get_prefix_sum_idx(f64, 1, f64) + _get_prefix_sum_idx(f32, 1, f64) + @njit def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None: @@ -96,7 +110,7 @@ def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None: def _reduce(tree: np.ndarray, start: int, end: int) -> float: """Numba version, 2x faster: 0.009 -> 0.005.""" # nodes in (start, end) should be aggregated - result = 0. + result = 0.0 while end - start > 1: # (start, end) interval is not empty if start % 2 == 0: result += tree[start + 1] @@ -108,8 +122,9 @@ def _reduce(tree: np.ndarray, start: int, end: int) -> float: @njit -def _get_prefix_sum_idx(value: np.ndarray, bound: int, - sums: np.ndarray) -> np.ndarray: +def _get_prefix_sum_idx( + value: np.ndarray, bound: int, sums: np.ndarray +) -> np.ndarray: """Numba version (v0.51), 5x speed up with size=100000 and bsz=64. vectorized np: 0.0923 (numpy best) -> 0.024 (now) diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index 0fa4d15..d3a49b7 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -3,11 +3,11 @@ from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, VectorEnv, \ from tianshou.env.maenv import MultiAgentEnv __all__ = [ - 'BaseVectorEnv', - 'DummyVectorEnv', - 'VectorEnv', # TODO: remove in later version - 'SubprocVectorEnv', - 'ShmemVectorEnv', - 'RayVectorEnv', - 'MultiAgentEnv', + "BaseVectorEnv", + "DummyVectorEnv", + "VectorEnv", # TODO: remove in later version + "SubprocVectorEnv", + "ShmemVectorEnv", + "RayVectorEnv", + "MultiAgentEnv", ] diff --git a/tianshou/env/maenv.py b/tianshou/env/maenv.py index 9153cf5..f6a454c 100644 --- a/tianshou/env/maenv.py +++ b/tianshou/env/maenv.py @@ -1,6 +1,6 @@ import gym import numpy as np -from typing import Tuple +from typing import Any, Dict, Tuple from abc import ABC, abstractmethod @@ -22,7 +22,7 @@ class MultiAgentEnv(ABC, gym.Env): usage can be found at :ref:`marl_example`. """ - def __init__(self, **kwargs) -> None: + def __init__(self) -> None: pass @abstractmethod @@ -30,13 +30,14 @@ class MultiAgentEnv(ABC, gym.Env): """Reset the state. Return the initial state, first agent_id, and the initial action set, - for example, ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}`` + for example, ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}``. """ pass @abstractmethod - def step(self, action: np.ndarray - ) -> Tuple[dict, np.ndarray, np.ndarray, np.ndarray]: + def step( + self, action: np.ndarray + ) -> Tuple[Dict[str, Any], np.ndarray, np.ndarray, np.ndarray]: """Run one timestep of the environment’s dynamics. 
When the end of episode is reached, you are responsible for calling diff --git a/tianshou/env/utils.py b/tianshou/env/utils.py index f7d8c58..5c873ce 100644 --- a/tianshou/env/utils.py +++ b/tianshou/env/utils.py @@ -1,14 +1,15 @@ import cloudpickle +from typing import Any class CloudpickleWrapper(object): """A cloudpickle wrapper used in SubprocVectorEnv.""" - def __init__(self, data): + def __init__(self, data: Any) -> None: self.data = data - def __getstate__(self): + def __getstate__(self) -> str: return cloudpickle.dumps(self.data) - def __setstate__(self, data): + def __setstate__(self, data: str) -> None: self.data = cloudpickle.loads(data) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 72c5b9e..6b82001 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -1,7 +1,7 @@ import gym import warnings import numpy as np -from typing import List, Union, Optional, Callable, Any +from typing import Any, List, Union, Optional, Callable from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \ RayEnvWorker @@ -59,12 +59,13 @@ class BaseVectorEnv(gym.Env): within ``timeout`` seconds. """ - def __init__(self, - env_fns: List[Callable[[], gym.Env]], - worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker], - wait_num: Optional[int] = None, - timeout: Optional[float] = None, - ) -> None: + def __init__( + self, + env_fns: List[Callable[[], gym.Env]], + worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker], + wait_num: Optional[int] = None, + timeout: Optional[float] = None, + ) -> None: self._env_fns = env_fns # A VectorEnv contains a pool of EnvWorkers, which corresponds to # interact with the given envs (one worker <-> one env). @@ -75,11 +76,13 @@ class BaseVectorEnv(gym.Env): self.env_num = len(env_fns) self.wait_num = wait_num or len(env_fns) - assert 1 <= self.wait_num <= len(env_fns), \ - f'wait_num should be in [1, {len(env_fns)}], but got {wait_num}' + assert ( + 1 <= self.wait_num <= len(env_fns) + ), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}" self.timeout = timeout - assert self.timeout is None or self.timeout > 0, \ - f'timeout is {timeout}, it should be positive if provided!' + assert ( + self.timeout is None or self.timeout > 0 + ), f"timeout is {timeout}, it should be positive if provided!" self.is_async = self.wait_num != len(env_fns) or timeout is not None self.waiting_conn = [] # environments in self.ready_id is actually ready @@ -92,8 +95,9 @@ class BaseVectorEnv(gym.Env): self.is_closed = False def _assert_is_not_closed(self) -> None: - assert not self.is_closed, f"Methods of {self.__class__.__name__} "\ - "should not be called after close." + assert not self.is_closed, ( + f"Methods of {self.__class__.__name__} cannot be called after " + "close.") def __len__(self) -> int: """Return len(self), which is the number of environments.""" @@ -113,7 +117,7 @@ class BaseVectorEnv(gym.Env): else: return super().__getattribute__(key) - def __getattr__(self, key: str) -> Any: + def __getattr__(self, key: str) -> List[Any]: """Fetch a list of env attributes. 
This function tries to retrieve an attribute from each individual @@ -122,8 +126,9 @@ class BaseVectorEnv(gym.Env): """ return [getattr(worker, key) for worker in self.workers] - def _wrap_id(self, id: Optional[Union[int, List[int], np.ndarray]] = None - ) -> List[int]: + def _wrap_id( + self, id: Optional[Union[int, List[int], np.ndarray]] = None + ) -> Union[List[int], np.ndarray]: if id is None: id = list(range(self.env_num)) elif np.isscalar(id): @@ -132,13 +137,16 @@ class BaseVectorEnv(gym.Env): def _assert_id(self, id: List[int]) -> None: for i in id: - assert i not in self.waiting_id, \ - f'Cannot interact with environment {i} which is stepping now.' - assert i in self.ready_id, \ - f'Can only interact with ready environments {self.ready_id}.' + assert ( + i not in self.waiting_id + ), f"Cannot interact with environment {i} which is stepping now." + assert ( + i in self.ready_id + ), f"Can only interact with ready environments {self.ready_id}." - def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None - ) -> np.ndarray: + def reset( + self, id: Optional[Union[int, List[int], np.ndarray]] = None + ) -> np.ndarray: """Reset the state of some envs and return initial observations. If id is None, reset the state of all the environments and return @@ -152,10 +160,11 @@ class BaseVectorEnv(gym.Env): obs = np.stack([self.workers[i].reset() for i in id]) return obs - def step(self, - action: np.ndarray, - id: Optional[Union[int, List[int], np.ndarray]] = None - ) -> List[np.ndarray]: + def step( + self, + action: np.ndarray, + id: Optional[Union[int, List[int], np.ndarray]] = None + ) -> List[np.ndarray]: """Run one timestep of some environments' dynamics. If id is None, run one timestep of all the environments’ dynamics; @@ -221,8 +230,9 @@ class BaseVectorEnv(gym.Env): self.ready_id.append(env_id) return list(map(np.stack, zip(*result))) - def seed(self, - seed: Optional[Union[int, List[int]]] = None) -> List[List[int]]: + def seed( + self, seed: Optional[Union[int, List[int]]] = None + ) -> List[Optional[List[int]]]: """Set the seed for all environments. Accept ``None``, an int (which will extend ``i`` to @@ -239,13 +249,13 @@ class BaseVectorEnv(gym.Env): seed = [seed + i for i in range(self.env_num)] return [w.seed(s) for w, s in zip(self.workers, seed)] - def render(self, **kwargs) -> List[Any]: + def render(self, **kwargs: Any) -> List[Any]: """Render all of the environments.""" self._assert_is_not_closed() if self.is_async and len(self.waiting_id) > 0: raise RuntimeError( - f"Environments {self.waiting_id} are still " - f"stepping, cannot render them now.") + f"Environments {self.waiting_id} are still stepping, cannot " + "render them now.") return [w.render(**kwargs) for w in self.workers] def close(self) -> None: @@ -275,20 +285,23 @@ class DummyVectorEnv(BaseVectorEnv): explanation. 
""" - def __init__(self, env_fns: List[Callable[[], gym.Env]], - wait_num: Optional[int] = None, - timeout: Optional[float] = None) -> None: - super().__init__(env_fns, DummyEnvWorker, - wait_num=wait_num, timeout=timeout) + def __init__( + self, + env_fns: List[Callable[[], gym.Env]], + wait_num: Optional[int] = None, + timeout: Optional[float] = None, + ) -> None: + super().__init__( + env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout) class VectorEnv(DummyVectorEnv): """VectorEnv is renamed to DummyVectorEnv.""" - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: warnings.warn( - 'VectorEnv is renamed to DummyVectorEnv, and will be removed in ' - '0.3. Use DummyVectorEnv instead!', Warning) + "VectorEnv is renamed to DummyVectorEnv, and will be removed in " + "0.3. Use DummyVectorEnv instead!", Warning) super().__init__(*args, **kwargs) @@ -301,13 +314,17 @@ class SubprocVectorEnv(BaseVectorEnv): explanation. """ - def __init__(self, env_fns: List[Callable[[], gym.Env]], - wait_num: Optional[int] = None, - timeout: Optional[float] = None) -> None: - def worker_fn(fn): + def __init__( + self, + env_fns: List[Callable[[], gym.Env]], + wait_num: Optional[int] = None, + timeout: Optional[float] = None, + ) -> None: + def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=False) - super().__init__(env_fns, worker_fn, - wait_num=wait_num, timeout=timeout) + + super().__init__( + env_fns, worker_fn, wait_num=wait_num, timeout=timeout) class ShmemVectorEnv(BaseVectorEnv): @@ -321,13 +338,17 @@ class ShmemVectorEnv(BaseVectorEnv): detailed explanation. """ - def __init__(self, env_fns: List[Callable[[], gym.Env]], - wait_num: Optional[int] = None, - timeout: Optional[float] = None) -> None: - def worker_fn(fn): + def __init__( + self, + env_fns: List[Callable[[], gym.Env]], + wait_num: Optional[int] = None, + timeout: Optional[float] = None, + ) -> None: + def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker: return SubprocEnvWorker(fn, share_memory=True) - super().__init__(env_fns, worker_fn, - wait_num=wait_num, timeout=timeout) + + super().__init__( + env_fns, worker_fn, wait_num=wait_num, timeout=timeout) class RayVectorEnv(BaseVectorEnv): @@ -341,16 +362,19 @@ class RayVectorEnv(BaseVectorEnv): explanation. 
""" - def __init__(self, env_fns: List[Callable[[], gym.Env]], - wait_num: Optional[int] = None, - timeout: Optional[float] = None) -> None: + def __init__( + self, + env_fns: List[Callable[[], gym.Env]], + wait_num: Optional[int] = None, + timeout: Optional[float] = None, + ) -> None: try: import ray except ImportError as e: raise ImportError( - 'Please install ray to support RayVectorEnv: pip install ray' + "Please install ray to support RayVectorEnv: pip install ray" ) from e if not ray.is_initialized(): ray.init() - super().__init__(env_fns, RayEnvWorker, - wait_num=wait_num, timeout=timeout) + super().__init__( + env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout) diff --git a/tianshou/env/worker/__init__.py b/tianshou/env/worker/__init__.py index a3d2ea5..b9a20cf 100644 --- a/tianshou/env/worker/__init__.py +++ b/tianshou/env/worker/__init__.py @@ -4,8 +4,8 @@ from tianshou.env.worker.subproc import SubprocEnvWorker from tianshou.env.worker.ray import RayEnvWorker __all__ = [ - 'EnvWorker', - 'DummyEnvWorker', - 'SubprocEnvWorker', - 'RayEnvWorker', + "EnvWorker", + "DummyEnvWorker", + "SubprocEnvWorker", + "RayEnvWorker", ] diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index c3600fa..d097260 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -1,7 +1,7 @@ import gym import numpy as np from abc import ABC, abstractmethod -from typing import List, Tuple, Optional, Callable, Any +from typing import Any, List, Tuple, Optional, Callable class EnvWorker(ABC): @@ -24,12 +24,14 @@ class EnvWorker(ABC): def send_action(self, action: np.ndarray) -> None: pass - def get_result(self) -> Tuple[ - np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def get_result( + self, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: return self.result - def step(self, action: np.ndarray - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def step( + self, action: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Perform one timestep of the environment's dynamic. 
"send_action" and "get_result" are coupled in sync simulation, so @@ -41,19 +43,21 @@ class EnvWorker(ABC): return self.get_result() @staticmethod - def wait(workers: List['EnvWorker'], - wait_num: int, - timeout: Optional[float] = None) -> List['EnvWorker']: + def wait( + workers: List["EnvWorker"], + wait_num: int, + timeout: Optional[float] = None, + ) -> List["EnvWorker"]: """Given a list of workers, return those ready ones.""" raise NotImplementedError @abstractmethod - def seed(self, seed: Optional[int] = None) -> List[int]: + def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: pass @abstractmethod - def render(self, **kwargs) -> Any: - """Renders the environment.""" + def render(self, **kwargs: Any) -> Any: + """Render the environment.""" pass @abstractmethod diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index b705913..9e88840 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -1,6 +1,6 @@ import gym import numpy as np -from typing import List, Callable, Optional, Any +from typing import Any, List, Callable, Optional from tianshou.env.worker import EnvWorker @@ -19,21 +19,24 @@ class DummyEnvWorker(EnvWorker): return self.env.reset() @staticmethod - def wait(workers: List['DummyEnvWorker'], - wait_num: int, - timeout: Optional[float] = None) -> List['DummyEnvWorker']: + def wait( + workers: List["DummyEnvWorker"], + wait_num: int, + timeout: Optional[float] = None, + ) -> List["DummyEnvWorker"]: # Sequential EnvWorker objects are always ready return workers def send_action(self, action: np.ndarray) -> None: self.result = self.env.step(action) - def seed(self, seed: Optional[int] = None) -> List[int]: - return self.env.seed(seed) if hasattr(self.env, 'seed') else None + def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: + return self.env.seed(seed) if hasattr(self.env, "seed") else None - def render(self, **kwargs) -> Any: - return self.env.render(**kwargs) \ - if hasattr(self.env, 'render') else None + def render(self, **kwargs: Any) -> Any: + return ( + self.env.render(**kwargs) if hasattr(self.env, "render") else None + ) def close_env(self) -> None: self.env.close() diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index 3f71d82..5517388 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -1,6 +1,6 @@ import gym import numpy as np -from typing import List, Callable, Tuple, Optional, Any +from typing import Any, List, Callable, Tuple, Optional from tianshou.env.worker import EnvWorker @@ -24,31 +24,34 @@ class RayEnvWorker(EnvWorker): return ray.get(self.env.reset.remote()) @staticmethod - def wait(workers: List['RayEnvWorker'], - wait_num: int, - timeout: Optional[float] = None) -> List['RayEnvWorker']: + def wait( + workers: List["RayEnvWorker"], + wait_num: int, + timeout: Optional[float] = None, + ) -> List["RayEnvWorker"]: results = [x.result for x in workers] - ready_results, _ = ray.wait(results, - num_returns=wait_num, timeout=timeout) + ready_results, _ = ray.wait( + results, num_returns=wait_num, timeout=timeout + ) return [workers[results.index(result)] for result in ready_results] def send_action(self, action: np.ndarray) -> None: # self.action is actually a handle self.result = self.env.step.remote(action) - def get_result(self) -> Tuple[ - np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def get_result( + self, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: return ray.get(self.result) - def seed(self, seed: Optional[int] = 
None) -> List[int]: - if hasattr(self.env, 'seed'): + def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: + if hasattr(self.env, "seed"): return ray.get(self.env.seed.remote(seed)) return None - def render(self, **kwargs) -> Any: - if hasattr(self.env, 'render'): + def render(self, **kwargs: Any) -> Any: + if hasattr(self.env, "render"): return ray.get(self.env.render.remote(**kwargs)) - return None def close_env(self) -> None: ray.get(self.env.close.remote()) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 857d148..4b280c2 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -5,14 +5,22 @@ import numpy as np from collections import OrderedDict from multiprocessing.context import Process from multiprocessing import Array, Pipe, connection -from typing import Callable, Any, List, Tuple, Optional +from typing import Any, List, Tuple, Union, Callable, Optional from tianshou.env.worker import EnvWorker from tianshou.env.utils import CloudpickleWrapper -def _worker(parent, p, env_fn_wrapper, obs_bufs=None): - def _encode_obs(obs, buffer): +def _worker( + parent: connection.Connection, + p: connection.Connection, + env_fn_wrapper: CloudpickleWrapper, + obs_bufs: Optional[Union[dict, tuple, "ShArray"]] = None, +) -> None: + def _encode_obs( + obs: Union[dict, tuple, np.ndarray], + buffer: Union[dict, tuple, ShArray], + ) -> None: if isinstance(obs, np.ndarray): buffer.save(obs) elif isinstance(obs, tuple): @@ -32,25 +40,27 @@ def _worker(parent, p, env_fn_wrapper, obs_bufs=None): except EOFError: # the pipe has been closed p.close() break - if cmd == 'step': + if cmd == "step": obs, reward, done, info = env.step(data) if obs_bufs is not None: - obs = _encode_obs(obs, obs_bufs) + _encode_obs(obs, obs_bufs) + obs = None p.send((obs, reward, done, info)) - elif cmd == 'reset': + elif cmd == "reset": obs = env.reset() if obs_bufs is not None: - obs = _encode_obs(obs, obs_bufs) + _encode_obs(obs, obs_bufs) + obs = None p.send(obs) - elif cmd == 'close': + elif cmd == "close": p.send(env.close()) p.close() break - elif cmd == 'render': - p.send(env.render(**data) if hasattr(env, 'render') else None) - elif cmd == 'seed': - p.send(env.seed(data) if hasattr(env, 'seed') else None) - elif cmd == 'getattr': + elif cmd == "render": + p.send(env.render(**data) if hasattr(env, "render") else None) + elif cmd == "seed": + p.send(env.seed(data) if hasattr(env, "seed") else None) + elif cmd == "getattr": p.send(getattr(env, data) if hasattr(env, data) else None) else: p.close() @@ -78,39 +88,39 @@ _NP_TO_CT = { class ShArray: """Wrapper of multiprocessing Array.""" - def __init__(self, dtype, shape): + def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None: self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape))) self.dtype = dtype self.shape = shape - def save(self, ndarray): + def save(self, ndarray: np.ndarray) -> None: assert isinstance(ndarray, np.ndarray) dst = self.arr.get_obj() dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape) np.copyto(dst_np, ndarray) - def get(self): - return np.frombuffer(self.arr.get_obj(), - dtype=self.dtype).reshape(self.shape) + def get(self) -> np.ndarray: + obj = self.arr.get_obj() + return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape) -def _setup_buf(space): +def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]: if isinstance(space, gym.spaces.Dict): assert isinstance(space.spaces, OrderedDict) - buffer = {k: _setup_buf(v) for k, v in 
space.spaces.items()} + return {k: _setup_buf(v) for k, v in space.spaces.items()} elif isinstance(space, gym.spaces.Tuple): assert isinstance(space.spaces, tuple) - buffer = tuple([_setup_buf(t) for t in space.spaces]) + return tuple([_setup_buf(t) for t in space.spaces]) else: - buffer = ShArray(space.dtype, space.shape) - return buffer + return ShArray(space.dtype, space.shape) class SubprocEnvWorker(EnvWorker): """Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv.""" - def __init__(self, env_fn: Callable[[], gym.Env], - share_memory=False) -> None: + def __init__( + self, env_fn: Callable[[], gym.Env], share_memory: bool = False + ) -> None: super().__init__(env_fn) self.parent_remote, self.child_remote = Pipe() self.share_memory = share_memory @@ -121,18 +131,24 @@ class SubprocEnvWorker(EnvWorker): dummy.close() del dummy self.buffer = _setup_buf(obs_space) - args = (self.parent_remote, self.child_remote, - CloudpickleWrapper(env_fn), self.buffer) + args = ( + self.parent_remote, + self.child_remote, + CloudpickleWrapper(env_fn), + self.buffer, + ) self.process = Process(target=_worker, args=args, daemon=True) self.process.start() self.child_remote.close() def __getattr__(self, key: str) -> Any: - self.parent_remote.send(['getattr', key]) + self.parent_remote.send(["getattr", key]) return self.parent_remote.recv() - def _decode_obs(self, isNone): - def decode_obs(buffer): + def _decode_obs(self) -> Union[dict, tuple, np.ndarray]: + def decode_obs( + buffer: Optional[Union[dict, tuple, ShArray]] + ) -> Union[dict, tuple, np.ndarray]: if isinstance(buffer, ShArray): return buffer.get() elif isinstance(buffer, tuple): @@ -145,16 +161,18 @@ class SubprocEnvWorker(EnvWorker): return decode_obs(self.buffer) def reset(self) -> Any: - self.parent_remote.send(['reset', None]) + self.parent_remote.send(["reset", None]) obs = self.parent_remote.recv() if self.share_memory: - obs = self._decode_obs(obs) + obs = self._decode_obs() return obs @staticmethod - def wait(workers: List['SubprocEnvWorker'], - wait_num: int, - timeout: Optional[float] = None) -> List['SubprocEnvWorker']: + def wait( + workers: List["SubprocEnvWorker"], + wait_num: int, + timeout: Optional[float] = None, + ) -> List["SubprocEnvWorker"]: conns, ready_conns = [x.parent_remote for x in workers], [] remain_conns = conns t1 = time.time() @@ -169,31 +187,32 @@ class SubprocEnvWorker(EnvWorker): new_ready_conns = connection.wait( remain_conns, timeout=remain_time) ready_conns.extend(new_ready_conns) - remain_conns = [conn for conn in remain_conns - if conn not in ready_conns] + remain_conns = [ + conn for conn in remain_conns if conn not in ready_conns] return [workers[conns.index(con)] for con in ready_conns] def send_action(self, action: np.ndarray) -> None: - self.parent_remote.send(['step', action]) + self.parent_remote.send(["step", action]) - def get_result(self) -> Tuple[ - np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + def get_result( + self, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: obs, rew, done, info = self.parent_remote.recv() if self.share_memory: - obs = self._decode_obs(obs) + obs = self._decode_obs() return obs, rew, done, info - def seed(self, seed: Optional[int] = None) -> List[int]: - self.parent_remote.send(['seed', seed]) + def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: + self.parent_remote.send(["seed", seed]) return self.parent_remote.recv() - def render(self, **kwargs) -> Any: - self.parent_remote.send(['render', kwargs]) + def render(self, 
**kwargs: Any) -> Any: + self.parent_remote.send(["render", kwargs]) return self.parent_remote.recv() def close_env(self) -> None: try: - self.parent_remote.send(['close', None]) + self.parent_remote.send(["close", None]) # mp may be deleted so it may raise AttributeError self.parent_remote.recv() self.process.join() diff --git a/tianshou/exploration/__init__.py b/tianshou/exploration/__init__.py index abe6f38..0878d23 100644 --- a/tianshou/exploration/__init__.py +++ b/tianshou/exploration/__init__.py @@ -1,7 +1,7 @@ from tianshou.exploration.random import BaseNoise, GaussianNoise, OUNoise __all__ = [ - 'BaseNoise', - 'GaussianNoise', - 'OUNoise', + "BaseNoise", + "GaussianNoise", + "OUNoise", ] diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py index d9ef006..f06b7b5 100644 --- a/tianshou/exploration/random.py +++ b/tianshou/exploration/random.py @@ -1,16 +1,16 @@ import numpy as np -from typing import Union, Optional from abc import ABC, abstractmethod +from typing import Union, Optional, Sequence class BaseNoise(ABC, object): """The action noise base class.""" - def __init__(self, **kwargs) -> None: + def __init__(self) -> None: super().__init__() @abstractmethod - def __call__(self, **kwargs) -> np.ndarray: + def __call__(self, size: Sequence[int]) -> np.ndarray: """Generate new noise.""" raise NotImplementedError @@ -22,15 +22,13 @@ class BaseNoise(ABC, object): class GaussianNoise(BaseNoise): """The vanilla gaussian process, for exploration in DDPG by default.""" - def __init__(self, - mu: float = 0.0, - sigma: float = 1.0): + def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None: super().__init__() self._mu = mu - assert 0 <= sigma, 'noise std should not be negative' + assert 0 <= sigma, "Noise std should not be negative." self._sigma = sigma - def __call__(self, size: tuple) -> np.ndarray: + def __call__(self, size: Sequence[int]) -> np.ndarray: return np.random.normal(self._mu, self._sigma, size) @@ -51,27 +49,30 @@ class OUNoise(BaseNoise): Ornstein-Uhlenbeck process. """ - def __init__(self, - mu: float = 0.0, - sigma: float = 0.3, - theta: float = 0.15, - dt: float = 1e-2, - x0: Optional[Union[float, np.ndarray]] = None - ) -> None: - super(BaseNoise, self).__init__() + def __init__( + self, + mu: float = 0.0, + sigma: float = 0.3, + theta: float = 0.15, + dt: float = 1e-2, + x0: Optional[Union[float, np.ndarray]] = None, + ) -> None: + super().__init__() self._mu = mu self._alpha = theta * dt self._beta = sigma * np.sqrt(dt) self._x0 = x0 self.reset() - def __call__(self, size: tuple, mu: Optional[float] = None) -> np.ndarray: + def __call__( + self, size: Sequence[int], mu: Optional[float] = None + ) -> np.ndarray: """Generate new noise. - Return a ``numpy.ndarray`` which size is equal to ``size``. + Return an numpy array which size is equal to ``size``. 
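For example, with the ``Sequence[int]`` signature above, both noise classes are called with an explicit output shape (the batch shape below is made up)::

    import numpy as np
    from tianshou.exploration import GaussianNoise, OUNoise

    gauss = GaussianNoise(sigma=0.1)       # DDPG's default exploration noise
    ou = OUNoise(sigma=0.3)                # temporally correlated alternative
    act = np.zeros((16, 4))                # a fake batch of 16 four-dim actions
    noisy_act = act + gauss(act.shape)     # i.i.d. Gaussian perturbation
    noisy_act = act + ou(act.shape)        # correlated across successive calls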
""" if self._x is None or self._x.shape != size: - self._x = 0 + self._x = 0.0 if mu is None: mu = self._mu r = self._beta * np.random.normal(size=size) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 95b7f0e..7383390 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -12,15 +12,15 @@ from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager __all__ = [ - 'BasePolicy', - 'RandomPolicy', - 'ImitationPolicy', - 'DQNPolicy', - 'PGPolicy', - 'A2CPolicy', - 'DDPGPolicy', - 'PPOPolicy', - 'TD3Policy', - 'SACPolicy', - 'MultiAgentPolicyManager', + "BasePolicy", + "RandomPolicy", + "ImitationPolicy", + "DQNPolicy", + "PGPolicy", + "A2CPolicy", + "DDPGPolicy", + "PPOPolicy", + "TD3Policy", + "SACPolicy", + "MultiAgentPolicyManager", ] diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index a380670..0b00585 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -4,7 +4,7 @@ import numpy as np from torch import nn from numba import njit from abc import ABC, abstractmethod -from typing import Dict, List, Union, Optional, Callable +from typing import Any, List, Union, Mapping, Optional, Callable from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \ to_torch_as, to_numpy @@ -52,23 +52,28 @@ class BasePolicy(ABC, nn.Module): policy.load_state_dict(torch.load("policy.pth")) """ - def __init__(self, - observation_space: gym.Space = None, - action_space: gym.Space = None - ) -> None: + def __init__( + self, + observation_space: gym.Space = None, + action_space: gym.Space = None + ) -> None: super().__init__() self.observation_space = observation_space self.action_space = action_space self.agent_id = 0 + self._compile() def set_agent_id(self, agent_id: int) -> None: """Set self.agent_id = agent_id, for MARL.""" self.agent_id = agent_id @abstractmethod - def forward(self, batch: Batch, - state: Optional[Union[dict, Batch, np.ndarray]] = None, - **kwargs) -> Batch: + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs: Any, + ) -> Batch: """Compute action over the given batch data. :return: A :class:`~tianshou.data.Batch` which MUST have the following\ @@ -96,8 +101,9 @@ class BasePolicy(ABC, nn.Module): """ pass - def process_fn(self, batch: Batch, buffer: ReplayBuffer, - indice: np.ndarray) -> Batch: + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray + ) -> Batch: """Pre-process the data from the provided replay buffer. Used in :meth:`update`. Check out :ref:`process_fn` for more @@ -106,8 +112,9 @@ class BasePolicy(ABC, nn.Module): return batch @abstractmethod - def learn(self, batch: Batch, **kwargs - ) -> Dict[str, Union[float, List[float]]]: + def learn( + self, batch: Batch, **kwargs: Any + ) -> Mapping[str, Union[float, List[float]]]: """Update policy with a given batch of data. :return: A dict which includes loss and its corresponding label. @@ -123,19 +130,22 @@ class BasePolicy(ABC, nn.Module): """ pass - def post_process_fn(self, batch: Batch, - buffer: ReplayBuffer, indice: np.ndarray) -> None: + def post_process_fn( + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray + ) -> None: """Post-process the data from the provided replay buffer. Typical usage is to update the sampling weight in prioritized experience replay. Used in :meth:`update`. 
""" - if isinstance(buffer, PrioritizedReplayBuffer) \ - and hasattr(batch, 'weight'): + if isinstance(buffer, PrioritizedReplayBuffer) and hasattr( + batch, "weight" + ): buffer.update_weight(indice, batch.weight) - def update(self, sample_size: int, buffer: Optional[ReplayBuffer], - *args, **kwargs) -> Dict[str, Union[float, List[float]]]: + def update( + self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any + ) -> Mapping[str, Union[float, List[float]]]: """Update the policy network and replay buffer. It includes 3 function steps: process_fn, learn, and post_process_fn. @@ -148,7 +158,7 @@ class BasePolicy(ABC, nn.Module): return {} batch, indice = buffer.sample(sample_size) batch = self.process_fn(batch, buffer, indice) - result = self.learn(batch, *args, **kwargs) + result = self.learn(batch, **kwargs) self.post_process_fn(batch, buffer, indice) return result @@ -182,7 +192,7 @@ class BasePolicy(ABC, nn.Module): rew = batch.rew v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_).flatten() returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda) - if rew_norm and not np.isclose(returns.std(), 0, 1e-2): + if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2): returns = (returns - returns.mean()) / returns.std() batch.returns = returns return batch @@ -231,9 +241,9 @@ class BasePolicy(ABC, nn.Module): bfr = rew[:min(len(buffer), 1000)] # avoid large buffer mean, std = bfr.mean(), bfr.std() if np.isclose(std, 0, 1e-2): - mean, std = 0., 1. + mean, std = 0.0, 1.0 else: - mean, std = 0., 1. + mean, std = 0.0, 1.0 buf_len = len(buffer) terminal = (indice + n_step - 1) % buf_len target_q_torch = target_q_fn(buffer, terminal).flatten() # (bsz, ) @@ -248,18 +258,30 @@ class BasePolicy(ABC, nn.Module): batch.weight = to_torch_as(batch.weight, target_q_torch) return batch + def _compile(self) -> None: + f64 = np.array([0, 1], dtype=np.float64) + f32 = np.array([0, 1], dtype=np.float32) + b = np.array([False, True], dtype=np.bool_) + i64 = np.array([0, 1], dtype=np.int64) + _episodic_return(f64, f64, b, 0.1, 0.1) + _episodic_return(f32, f64, b, 0.1, 0.1) + _nstep_return(f64, b, f32, i64, 0.1, 1, 4, 1.0, 0.0) + @njit def _episodic_return( - v_s_: np.ndarray, rew: np.ndarray, done: np.ndarray, - gamma: float, gae_lambda: float, + v_s_: np.ndarray, + rew: np.ndarray, + done: np.ndarray, + gamma: float, + gae_lambda: float, ) -> np.ndarray: """Numba speedup: 4.1s -> 0.057s.""" returns = np.roll(v_s_, 1) - m = (1. - done) * gamma + m = (1.0 - done) * gamma delta = rew + v_s_ * m - returns m *= gae_lambda - gae = 0. + gae = 0.0 for i in range(len(rew) - 1, -1, -1): gae = delta[i] + m[i] * gae returns[i] += gae @@ -268,9 +290,15 @@ def _episodic_return( @njit def _nstep_return( - rew: np.ndarray, done: np.ndarray, target_q: np.ndarray, - indice: np.ndarray, gamma: float, n_step: int, buf_len: int, - mean: float, std: float + rew: np.ndarray, + done: np.ndarray, + target_q: np.ndarray, + indice: np.ndarray, + gamma: float, + n_step: int, + buf_len: int, + mean: float, + std: float, ) -> np.ndarray: """Numba speedup: 0.3s -> 0.15s.""" returns = np.zeros(indice.shape) @@ -278,8 +306,8 @@ def _nstep_return( for n in range(n_step - 1, -1, -1): now = (indice + n) % buf_len gammas[done[now] > 0] = n - returns[done[now] > 0] = 0. 
+ returns[done[now] > 0] = 0.0 returns = (rew[now] - mean) / std + gamma * returns - target_q[gammas != n_step] = 0 + target_q[gammas != n_step] = 0.0 target_q = target_q * (gamma ** gammas) + returns return target_q diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 674de04..13dda4e 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -1,7 +1,7 @@ import torch import numpy as np import torch.nn.functional as F -from typing import Dict, Union, Optional +from typing import Any, Dict, Union, Optional from tianshou.data import Batch, to_torch from tianshou.policy import BasePolicy @@ -22,36 +22,44 @@ class ImitationPolicy(BasePolicy): explanation. """ - def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer, - mode: str = 'continuous') -> None: - super().__init__() + def __init__( + self, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + mode: str = "continuous", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) self.model = model self.optim = optim - assert mode in ['continuous', 'discrete'], \ - f'Mode {mode} is not in ["continuous", "discrete"]' + assert ( + mode in ["continuous", "discrete"] + ), f"Mode {mode} is not in ['continuous', 'discrete']." self.mode = mode - def forward(self, - batch: Batch, - state: Optional[Union[dict, Batch, np.ndarray]] = None, - **kwargs) -> Batch: + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs: Any, + ) -> Batch: logits, h = self.model(batch.obs, state=state, info=batch.info) - if self.mode == 'discrete': + if self.mode == "discrete": a = logits.max(dim=1)[1] else: a = logits return Batch(logits=logits, act=a, state=h) - def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: self.optim.zero_grad() - if self.mode == 'continuous': + if self.mode == "continuous": # regression a = self(batch).act a_ = to_torch(batch.act, dtype=torch.float32, device=a.device) loss = F.mse_loss(a, a_) - elif self.mode == 'discrete': # classification + elif self.mode == "discrete": # classification a = self(batch).logits a_ = to_torch(batch.act, dtype=torch.long, device=a.device) loss = F.nll_loss(a, a_) loss.backward() self.optim.step() - return {'loss': loss.item()} + return {"loss": loss.item()} diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index dfdc02d..ae7cfa7 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -2,7 +2,7 @@ import torch import numpy as np from torch import nn import torch.nn.functional as F -from typing import Dict, List, Union, Optional +from typing import Any, Dict, List, Union, Optional, Callable from tianshou.policy import PGPolicy from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy @@ -17,6 +17,7 @@ class A2CPolicy(PGPolicy): :param torch.optim.Optimizer optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. + :type dist_fn: Callable[[], torch.distributions.Distribution] :param float discount_factor: in [0, 1], defaults to 0.99. :param float vf_coef: weight for value loss, defaults to 0.5. :param float ent_coef: weight for entropy loss, defaults to 0.01. @@ -37,23 +38,25 @@ class A2CPolicy(PGPolicy): explanation. 
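Concretely, ``dist_fn`` turns the actor output into a ``torch.distributions.Distribution``; a minimal sketch of the contract that ``forward`` and ``learn`` rely on (the probabilities below are made up)::

    import torch

    dist_fn = torch.distributions.Categorical          # a common discrete-action choice
    probs = torch.softmax(torch.randn(8, 3), dim=-1)   # fake actor output: 8 states, 3 actions
    dist = dist_fn(probs)
    act = dist.sample()                                # what forward() returns as "act"
    log_prob = dist.log_prob(act)                      # enters the actor loss
    entropy = dist.entropy()                           # weighted by ent_coef in learn()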
""" - def __init__(self, - actor: torch.nn.Module, - critic: torch.nn.Module, - optim: torch.optim.Optimizer, - dist_fn: torch.distributions.Distribution, - discount_factor: float = 0.99, - vf_coef: float = .5, - ent_coef: float = .01, - max_grad_norm: Optional[float] = None, - gae_lambda: float = 0.95, - reward_normalization: bool = False, - max_batchsize: int = 256, - **kwargs) -> None: + def __init__( + self, + actor: torch.nn.Module, + critic: torch.nn.Module, + optim: torch.optim.Optimizer, + dist_fn: Callable[[], torch.distributions.Distribution], + discount_factor: float = 0.99, + vf_coef: float = 0.5, + ent_coef: float = 0.01, + max_grad_norm: Optional[float] = None, + gae_lambda: float = 0.95, + reward_normalization: bool = False, + max_batchsize: int = 256, + **kwargs: Any + ) -> None: super().__init__(None, optim, dist_fn, discount_factor, **kwargs) self.actor = actor self.critic = critic - assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].' + assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]." self._lambda = gae_lambda self._w_vf = vf_coef self._w_ent = ent_coef @@ -61,9 +64,10 @@ class A2CPolicy(PGPolicy): self._batch = max_batchsize self._rew_norm = reward_normalization - def process_fn(self, batch: Batch, buffer: ReplayBuffer, - indice: np.ndarray) -> Batch: - if self._lambda in [0, 1]: + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray + ) -> Batch: + if self._lambda in [0.0, 1.0]: return self.compute_episodic_return( batch, None, gamma=self._gamma, gae_lambda=self._lambda) v_ = [] @@ -75,9 +79,12 @@ class A2CPolicy(PGPolicy): batch, v_, gamma=self._gamma, gae_lambda=self._lambda, rew_norm=self._rew_norm) - def forward(self, batch: Batch, - state: Optional[Union[dict, Batch, np.ndarray]] = None, - **kwargs) -> Batch: + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs: Any + ) -> Batch: """Compute action over the given batch data. 
:return: A :class:`~tianshou.data.Batch` which has 4 keys: @@ -100,8 +107,9 @@ class A2CPolicy(PGPolicy): act = dist.sample() return Batch(logits=logits, act=act, state=h, dist=dist) - def learn(self, batch: Batch, batch_size: int, repeat: int, - **kwargs) -> Dict[str, List[float]]: + def learn( + self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any + ) -> Dict[str, List[float]]: losses, actor_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): for b in batch.split(batch_size, merge_last=True): @@ -110,8 +118,7 @@ class A2CPolicy(PGPolicy): v = self.critic(b.obs).flatten() a = to_torch_as(b.act, v) r = to_torch_as(b.returns, v) - log_prob = dist.log_prob(a).reshape( - r.shape[0], -1).transpose(0, 1) + log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1) a_loss = -(log_prob * (r - v).detach()).mean() vf_loss = F.mse_loss(r, v) ent_loss = dist.entropy().mean() @@ -119,17 +126,18 @@ class A2CPolicy(PGPolicy): loss.backward() if self._grad_norm is not None: nn.utils.clip_grad_norm_( - list(self.actor.parameters()) + - list(self.critic.parameters()), - max_norm=self._grad_norm) + list(self.actor.parameters()) + + list(self.critic.parameters()), + max_norm=self._grad_norm, + ) self.optim.step() actor_losses.append(a_loss.item()) vf_losses.append(vf_loss.item()) ent_losses.append(ent_loss.item()) losses.append(loss.item()) return { - 'loss': losses, - 'loss/actor': actor_losses, - 'loss/vf': vf_losses, - 'loss/ent': ent_losses, + "loss": losses, + "loss/actor": actor_losses, + "loss/vf": vf_losses, + "loss/ent": ent_losses, } diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index f8cf60a..3770c64 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -1,7 +1,7 @@ import torch import numpy as np from copy import deepcopy -from typing import Dict, Tuple, Union, Optional +from typing import Any, Dict, Tuple, Union, Optional from tianshou.policy import BasePolicy from tianshou.exploration import BaseNoise, GaussianNoise @@ -17,13 +17,13 @@ class DDPGPolicy(BasePolicy): :param torch.nn.Module critic: the critic network. (s, a -> Q(s, a)) :param torch.optim.Optimizer critic_optim: the optimizer for critic network. + :param action_range: the action range (minimum, maximum). + :type action_range: Tuple[float, float] :param float tau: param for soft update of the target network, defaults to 0.005. :param float gamma: discount factor, in [0, 1], defaults to 0.99. :param BaseNoise exploration_noise: the exploration noise, add to the action, defaults to ``GaussianNoise(sigma=0.1)``. - :param action_range: the action range (minimum, maximum). - :type action_range: (float, float) :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to False. :param bool ignore_done: ignore the done flag while training the policy, @@ -37,20 +37,21 @@ class DDPGPolicy(BasePolicy): explanation. 
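With ``action_range`` now a required argument placed right after the critic optimizer, construction looks roughly as follows (``Pendulum-v0`` and the two tiny networks are illustrative stand-ins, not the library's network classes)::

    import gym
    import torch
    from torch import nn
    from tianshou.policy import DDPGPolicy
    from tianshou.exploration import GaussianNoise

    class Actor(nn.Module):  # stand-in: obs -> action, returns (act, state)
        def __init__(self, s_dim, a_dim, max_action):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(s_dim, 64), nn.ReLU(), nn.Linear(64, a_dim))
            self.max_action = max_action

        def forward(self, obs, state=None, info={}):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            return self.max_action * torch.tanh(self.net(obs)), state

    class Critic(nn.Module):  # stand-in: (obs, act) -> Q(s, a)
        def __init__(self, s_dim, a_dim):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(s_dim + a_dim, 64), nn.ReLU(), nn.Linear(64, 1))

        def forward(self, obs, act):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            act = torch.as_tensor(act, dtype=torch.float32)
            return self.net(torch.cat([obs, act], dim=-1))

    env = gym.make("Pendulum-v0")
    s_dim = env.observation_space.shape[0]
    a_dim = env.action_space.shape[0]
    actor = Actor(s_dim, a_dim, env.action_space.high[0])
    critic = Critic(s_dim, a_dim)
    policy = DDPGPolicy(
        actor, torch.optim.Adam(actor.parameters(), lr=1e-3),
        critic, torch.optim.Adam(critic.parameters(), lr=1e-3),
        action_range=(env.action_space.low[0], env.action_space.high[0]),
        tau=0.005, gamma=0.99,
        exploration_noise=GaussianNoise(sigma=0.1))

``SACPolicy`` and ``TD3Policy`` below take ``action_range`` in the same position, so the keyword form used here carries over unchanged.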
""" - def __init__(self, - actor: torch.nn.Module, - actor_optim: torch.optim.Optimizer, - critic: torch.nn.Module, - critic_optim: torch.optim.Optimizer, - tau: float = 0.005, - gamma: float = 0.99, - exploration_noise: Optional[BaseNoise] - = GaussianNoise(sigma=0.1), - action_range: Optional[Tuple[float, float]] = None, - reward_normalization: bool = False, - ignore_done: bool = False, - estimation_step: int = 1, - **kwargs) -> None: + def __init__( + self, + actor: Optional[torch.nn.Module], + actor_optim: Optional[torch.optim.Optimizer], + critic: Optional[torch.nn.Module], + critic_optim: Optional[torch.optim.Optimizer], + action_range: Tuple[float, float], + tau: float = 0.005, + gamma: float = 0.99, + exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1), + reward_normalization: bool = False, + ignore_done: bool = False, + estimation_step: int = 1, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) if actor is not None: self.actor, self.actor_old = actor, deepcopy(actor) @@ -60,27 +61,26 @@ class DDPGPolicy(BasePolicy): self.critic, self.critic_old = critic, deepcopy(critic) self.critic_old.eval() self.critic_optim = critic_optim - assert 0 <= tau <= 1, 'tau should in [0, 1]' + assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]" self._tau = tau - assert 0 <= gamma <= 1, 'gamma should in [0, 1]' + assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]" self._gamma = gamma self._noise = exploration_noise - assert action_range is not None self._range = action_range - self._action_bias = (action_range[0] + action_range[1]) / 2 - self._action_scale = (action_range[1] - action_range[0]) / 2 - # it is only a little difference to use rand_normal + self._action_bias = (action_range[0] + action_range[1]) / 2.0 + self._action_scale = (action_range[1] - action_range[0]) / 2.0 + # it is only a little difference to use GaussianNoise # self.noise = OUNoise() self._rm_done = ignore_done self._rew_norm = reward_normalization - assert estimation_step > 0, 'estimation_step should greater than 0' + assert estimation_step > 0, "estimation_step should be greater than 0" self._n_step = estimation_step def set_exp_noise(self, noise: Optional[BaseNoise]) -> None: """Set the exploration noise.""" self._noise = noise - def train(self, mode=True) -> torch.nn.Module: + def train(self, mode: bool = True) -> "DDPGPolicy": """Set the module in training mode, except for the target network.""" self.training = mode self.actor.train(mode) @@ -90,13 +90,15 @@ class DDPGPolicy(BasePolicy): def sync_weight(self) -> None: """Soft-update the weight for the target network.""" for o, n in zip(self.actor_old.parameters(), self.actor.parameters()): - o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) + o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) for o, n in zip( - self.critic_old.parameters(), self.critic.parameters()): - o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) + self.critic_old.parameters(), self.critic.parameters() + ): + o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) - def _target_q(self, buffer: ReplayBuffer, - indice: np.ndarray) -> torch.Tensor: + def _target_q( + self, buffer: ReplayBuffer, indice: np.ndarray + ) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} with torch.no_grad(): target_q = self.critic_old(batch.obs_next, self( @@ -104,21 +106,25 @@ class DDPGPolicy(BasePolicy): explorating=False).act) return target_q - def process_fn(self, batch: Batch, buffer: ReplayBuffer, - indice: np.ndarray) -> Batch: + def 
process_fn( + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray + ) -> Batch: if self._rm_done: - batch.done = batch.done * 0. + batch.done = batch.done * 0.0 batch = self.compute_nstep_return( batch, buffer, indice, self._target_q, self._gamma, self._n_step, self._rew_norm) return batch - def forward(self, batch: Batch, - state: Optional[Union[dict, Batch, np.ndarray]] = None, - model: str = 'actor', - input: str = 'obs', - explorating: bool = True, - **kwargs) -> Batch: + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + model: str = "actor", + input: str = "obs", + explorating: bool = True, + **kwargs: Any, + ) -> Batch: """Compute action over the given batch data. :return: A :class:`~tianshou.data.Batch` which has 2 keys: @@ -140,8 +146,8 @@ class DDPGPolicy(BasePolicy): actions = actions.clamp(self._range[0], self._range[1]) return Batch(act=actions, state=h) - def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: - weight = batch.pop('weight', 1.) + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + weight = batch.pop("weight", 1.0) current_q = self.critic(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() td = current_q - target_q @@ -157,6 +163,6 @@ class DDPGPolicy(BasePolicy): self.actor_optim.step() self.sync_weight() return { - 'loss/actor': actor_loss.item(), - 'loss/critic': critic_loss.item(), + "loss/actor": actor_loss.item(), + "loss/critic": critic_loss.item(), } diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 4018750..c6ba4fa 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -1,7 +1,7 @@ import torch import numpy as np from copy import deepcopy -from typing import Dict, Union, Optional +from typing import Any, Dict, Union, Optional from tianshou.policy import BasePolicy from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy @@ -32,21 +32,25 @@ class DQNPolicy(BasePolicy): explanation. 
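The ``model`` only has to map observations to Q-values and return an ``(output, state)`` pair, which is how ``forward`` below unpacks it; a minimal stand-in (layer sizes and hyper-parameters are arbitrary)::

    import torch
    from torch import nn
    from tianshou.policy import DQNPolicy

    class QNet(nn.Module):  # stand-in: obs -> Q(s, *), returns (q, state)
        def __init__(self, s_dim, n_action):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(s_dim, 128), nn.ReLU(), nn.Linear(128, n_action))

        def forward(self, obs, state=None, info={}):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            return self.net(obs), state

    net = QNet(s_dim=4, n_action=2)          # CartPole-sized shapes, for example
    optim = torch.optim.Adam(net.parameters(), lr=1e-3)
    policy = DQNPolicy(net, optim, discount_factor=0.99,
                       estimation_step=3, target_update_freq=320)
    policy.set_eps(0.1)                      # epsilon-greedy during collection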
""" - def __init__(self, - model: torch.nn.Module, - optim: torch.optim.Optimizer, - discount_factor: float = 0.99, - estimation_step: int = 1, - target_update_freq: int = 0, - reward_normalization: bool = False, - **kwargs) -> None: + def __init__( + self, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.model = model self.optim = optim - self.eps = 0 - assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]' + self.eps = 0.0 + assert ( + 0.0 <= discount_factor <= 1.0 + ), "discount factor should be in [0, 1]" self._gamma = discount_factor - assert estimation_step > 0, 'estimation_step should greater than 0' + assert estimation_step > 0, "estimation_step should be greater than 0" self._n_step = estimation_step self._target = target_update_freq > 0 self._freq = target_update_freq @@ -60,7 +64,7 @@ class DQNPolicy(BasePolicy): """Set the eps for epsilon-greedy exploration.""" self.eps = eps - def train(self, mode=True) -> torch.nn.Module: + def train(self, mode: bool = True) -> "DQNPolicy": """Set the module in training mode, except for the target network.""" self.training = mode self.model.train(mode) @@ -70,23 +74,26 @@ class DQNPolicy(BasePolicy): """Synchronize the weight for the target network.""" self.model_old.load_state_dict(self.model.state_dict()) - def _target_q(self, buffer: ReplayBuffer, - indice: np.ndarray) -> torch.Tensor: + def _target_q( + self, buffer: ReplayBuffer, indice: np.ndarray + ) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} if self._target: # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - a = self(batch, input='obs_next', eps=0).act + a = self(batch, input="obs_next", eps=0).act with torch.no_grad(): target_q = self( - batch, model='model_old', input='obs_next').logits + batch, model="model_old", input="obs_next" + ).logits target_q = target_q[np.arange(len(a)), a] else: with torch.no_grad(): - target_q = self(batch, input='obs_next').logits.max(dim=1)[0] + target_q = self(batch, input="obs_next").logits.max(dim=1)[0] return target_q - def process_fn(self, batch: Batch, buffer: ReplayBuffer, - indice: np.ndarray) -> Batch: + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray + ) -> Batch: """Compute the n-step return for Q-learning targets. More details can be found at @@ -97,12 +104,15 @@ class DQNPolicy(BasePolicy): self._gamma, self._n_step, self._rew_norm) return batch - def forward(self, batch: Batch, - state: Optional[Union[dict, Batch, np.ndarray]] = None, - model: str = 'model', - input: str = 'obs', - eps: Optional[float] = None, - **kwargs) -> Batch: + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + model: str = "model", + input: str = "obs", + eps: Optional[float] = None, + **kwargs: Any, + ) -> Batch: """Compute action over the given batch data. 
If you need to mask the action, please add a "mask" into batch.obs, for @@ -134,7 +144,7 @@ class DQNPolicy(BasePolicy): """ model = getattr(self, model) obs = getattr(batch, input) - obs_ = obs.obs if hasattr(obs, 'obs') else obs + obs_ = obs.obs if hasattr(obs, "obs") else obs q, h = model(obs_, state=state, info=batch.info) act = to_numpy(q.max(dim=1)[1]) has_mask = hasattr(obs, 'mask') @@ -146,7 +156,7 @@ class DQNPolicy(BasePolicy): # add eps to act if eps is None: eps = self.eps - if not np.isclose(eps, 0): + if not np.isclose(eps, 0.0): for i in range(len(q)): if np.random.rand() < eps: q_ = np.random.rand(*q[i].shape) @@ -155,12 +165,12 @@ class DQNPolicy(BasePolicy): act[i] = q_.argmax() return Batch(logits=q, act=act, state=h) - def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: if self._target and self._cnt % self._freq == 0: self.sync_weight() self.optim.zero_grad() - weight = batch.pop('weight', 1.) - q = self(batch, eps=0.).logits + weight = batch.pop("weight", 1.0) + q = self(batch, eps=0.0).logits q = q[np.arange(len(q)), batch.act] r = to_torch_as(batch.returns, q).flatten() td = r - q @@ -169,4 +179,4 @@ class DQNPolicy(BasePolicy): loss.backward() self.optim.step() self._cnt += 1 - return {'loss': loss.item()} + return {"loss": loss.item()} diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index e0a5212..e93b4f9 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -1,6 +1,6 @@ import torch import numpy as np -from typing import Dict, List, Union, Optional +from typing import Any, Dict, List, Union, Optional, Callable from tianshou.policy import BasePolicy from tianshou.data import Batch, ReplayBuffer, to_torch_as @@ -13,6 +13,7 @@ class PGPolicy(BasePolicy): :class:`~tianshou.policy.BasePolicy`. (s -> logits) :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. :param dist_fn: distribution class for computing the action. + :type dist_fn: Callable[[], torch.distributions.Distribution] :param float discount_factor: in [0, 1]. .. seealso:: @@ -21,23 +22,28 @@ class PGPolicy(BasePolicy): explanation. """ - def __init__(self, - model: torch.nn.Module, - optim: torch.optim.Optimizer, - dist_fn: torch.distributions.Distribution, - discount_factor: float = 0.99, - reward_normalization: bool = False, - **kwargs) -> None: + def __init__( + self, + model: Optional[torch.nn.Module], + optim: torch.optim.Optimizer, + dist_fn: Callable[[], torch.distributions.Distribution], + discount_factor: float = 0.99, + reward_normalization: bool = False, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.model = model self.optim = optim self.dist_fn = dist_fn - assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]' + assert ( + 0.0 <= discount_factor <= 1.0 + ), "discount factor should be in [0, 1]" self._gamma = discount_factor self._rew_norm = reward_normalization - def process_fn(self, batch: Batch, buffer: ReplayBuffer, - indice: np.ndarray) -> Batch: + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray + ) -> Batch: r"""Compute the discounted returns for each frame. .. 
math:: @@ -48,13 +54,15 @@ class PGPolicy(BasePolicy): """ # batch.returns = self._vanilla_returns(batch) # batch.returns = self._vectorized_returns(batch) - # return batch return self.compute_episodic_return( - batch, gamma=self._gamma, gae_lambda=1., rew_norm=self._rew_norm) + batch, gamma=self._gamma, gae_lambda=1.0, rew_norm=self._rew_norm) - def forward(self, batch: Batch, - state: Optional[Union[dict, Batch, np.ndarray]] = None, - **kwargs) -> Batch: + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs: Any, + ) -> Batch: """Compute action over the given batch data. :return: A :class:`~tianshou.data.Batch` which has 4 keys: @@ -77,8 +85,9 @@ class PGPolicy(BasePolicy): act = dist.sample() return Batch(logits=logits, act=act, state=h, dist=dist) - def learn(self, batch: Batch, batch_size: int, repeat: int, - **kwargs) -> Dict[str, List[float]]: + def learn( + self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any + ) -> Dict[str, List[float]]: losses = [] for _ in range(repeat): for b in batch.split(batch_size, merge_last=True): @@ -86,13 +95,12 @@ class PGPolicy(BasePolicy): dist = self(b).dist a = to_torch_as(b.act, dist.logits) r = to_torch_as(b.returns, dist.logits) - log_prob = dist.log_prob(a).reshape( - r.shape[0], -1).transpose(0, 1) + log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1) loss = -(log_prob * r).mean() loss.backward() self.optim.step() losses.append(loss.item()) - return {'loss': losses} + return {"loss": losses} # def _vanilla_returns(self, batch): # returns = batch.rew[:] diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index df84eb9..349ab9b 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -1,7 +1,7 @@ import torch import numpy as np from torch import nn -from typing import Dict, List, Tuple, Union, Optional +from typing import Any, Dict, List, Tuple, Union, Optional, Callable from tianshou.policy import PGPolicy from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as @@ -16,6 +16,7 @@ class PPOPolicy(PGPolicy): :param torch.optim.Optimizer optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. + :type dist_fn: Callable[[], torch.distributions.Distribution] :param float discount_factor: in [0, 1], defaults to 0.99. :param float max_grad_norm: clipping gradients in back propagation, defaults to None. @@ -45,24 +46,26 @@ class PPOPolicy(PGPolicy): explanation. 
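The heart of ``learn`` is the clipped surrogate objective; stripped of batching it reduces to a few tensor operations (the numbers are made up)::

    import torch

    eps_clip = 0.2
    adv = torch.tensor([1.0, -0.5, 2.0])           # advantage estimates A_t
    log_prob = torch.tensor([-0.9, -1.2, -0.3])    # log prob under the current policy
    logp_old = torch.tensor([-1.0, -1.0, -1.0])    # log prob under the old policy

    ratio = (log_prob - logp_old).exp()
    surr1 = ratio * adv
    surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
    clip_loss = -torch.min(surr1, surr2).mean()    # maximize the clipped surrogate

When ``dual_clip`` is set, the clipped surrogate is additionally lower-bounded by ``dual_clip * adv``, as in the ``learn`` method below.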
""" - def __init__(self, - actor: torch.nn.Module, - critic: torch.nn.Module, - optim: torch.optim.Optimizer, - dist_fn: torch.distributions.Distribution, - discount_factor: float = 0.99, - max_grad_norm: Optional[float] = None, - eps_clip: float = .2, - vf_coef: float = .5, - ent_coef: float = .01, - action_range: Optional[Tuple[float, float]] = None, - gae_lambda: float = 0.95, - dual_clip: Optional[float] = None, - value_clip: bool = True, - reward_normalization: bool = True, - max_batchsize: int = 256, - **kwargs) -> None: - super().__init__(None, None, dist_fn, discount_factor, **kwargs) + def __init__( + self, + actor: torch.nn.Module, + critic: torch.nn.Module, + optim: torch.optim.Optimizer, + dist_fn: Callable[[], torch.distributions.Distribution], + discount_factor: float = 0.99, + max_grad_norm: Optional[float] = None, + eps_clip: float = 0.2, + vf_coef: float = 0.5, + ent_coef: float = 0.01, + action_range: Optional[Tuple[float, float]] = None, + gae_lambda: float = 0.95, + dual_clip: Optional[float] = None, + value_clip: bool = True, + reward_normalization: bool = True, + max_batchsize: int = 256, + **kwargs: Any, + ) -> None: + super().__init__(None, optim, dist_fn, discount_factor, **kwargs) self._max_grad_norm = max_grad_norm self._eps_clip = eps_clip self._w_vf = vf_coef @@ -70,29 +73,31 @@ class PPOPolicy(PGPolicy): self._range = action_range self.actor = actor self.critic = critic - self.optim = optim self._batch = max_batchsize - assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].' + assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]." self._lambda = gae_lambda - assert dual_clip is None or dual_clip > 1, \ - 'Dual-clip PPO parameter should greater than 1.' + assert ( + dual_clip is None or dual_clip > 1.0 + ), "Dual-clip PPO parameter should greater than 1.0." self._dual_clip = dual_clip self._value_clip = value_clip self._rew_norm = reward_normalization - def process_fn(self, batch: Batch, buffer: ReplayBuffer, - indice: np.ndarray) -> Batch: + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray + ) -> Batch: if self._rew_norm: mean, std = batch.rew.mean(), batch.rew.std() - if not np.isclose(std, 0, 1e-2): + if not np.isclose(std, 0.0, 1e-2): batch.rew = (batch.rew - mean) / std v, v_, old_log_prob = [], [], [] with torch.no_grad(): for b in batch.split(self._batch, shuffle=False, merge_last=True): v_.append(self.critic(b.obs_next)) v.append(self.critic(b.obs)) - old_log_prob.append(self(b).dist.log_prob( - to_torch_as(b.act, v[0]))) + old_log_prob.append( + self(b).dist.log_prob(to_torch_as(b.act, v[0])) + ) v_ = to_numpy(torch.cat(v_, dim=0)) batch = self.compute_episodic_return( batch, v_, gamma=self._gamma, gae_lambda=self._lambda, @@ -104,13 +109,16 @@ class PPOPolicy(PGPolicy): batch.adv = batch.returns - batch.v if self._rew_norm: mean, std = batch.adv.mean(), batch.adv.std() - if not np.isclose(std.item(), 0, 1e-2): + if not np.isclose(std.item(), 0.0, 1e-2): batch.adv = (batch.adv - mean) / std return batch - def forward(self, batch: Batch, - state: Optional[Union[dict, Batch, np.ndarray]] = None, - **kwargs) -> Batch: + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs: Any, + ) -> Batch: """Compute action over the given batch data. 
:return: A :class:`~tianshou.data.Batch` which has 4 keys: @@ -135,8 +143,9 @@ class PPOPolicy(PGPolicy): act = act.clamp(self._range[0], self._range[1]) return Batch(logits=logits, act=act, state=h, dist=dist) - def learn(self, batch: Batch, batch_size: int, repeat: int, - **kwargs) -> Dict[str, List[float]]: + def learn( + self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any + ) -> Dict[str, List[float]]: losses, clip_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): for b in batch.split(batch_size, merge_last=True): @@ -145,8 +154,8 @@ class PPOPolicy(PGPolicy): ratio = (dist.log_prob(b.act) - b.logp_old).exp().float() ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1) surr1 = ratio * b.adv - surr2 = ratio.clamp(1. - self._eps_clip, - 1. + self._eps_clip) * b.adv + surr2 = ratio.clamp(1.0 - self._eps_clip, + 1.0 + self._eps_clip) * b.adv if self._dual_clip: clip_loss = -torch.max(torch.min(surr1, surr2), self._dual_clip * b.adv).mean() @@ -158,9 +167,9 @@ class PPOPolicy(PGPolicy): -self._eps_clip, self._eps_clip) vf1 = (b.returns - value).pow(2) vf2 = (b.returns - v_clip).pow(2) - vf_loss = .5 * torch.max(vf1, vf2).mean() + vf_loss = 0.5 * torch.max(vf1, vf2).mean() else: - vf_loss = .5 * (b.returns - value).pow(2).mean() + vf_loss = 0.5 * (b.returns - value).pow(2).mean() vf_losses.append(vf_loss.item()) e_loss = dist.entropy().mean() ent_losses.append(e_loss.item()) @@ -168,13 +177,14 @@ class PPOPolicy(PGPolicy): losses.append(loss.item()) self.optim.zero_grad() loss.backward() - nn.utils.clip_grad_norm_(list( - self.actor.parameters()) + list(self.critic.parameters()), + nn.utils.clip_grad_norm_( + list(self.actor.parameters()) + + list(self.critic.parameters()), self._max_grad_norm) self.optim.step() return { - 'loss': losses, - 'loss/clip': clip_losses, - 'loss/vf': vf_losses, - 'loss/ent': ent_losses, + "loss": losses, + "loss/clip": clip_losses, + "loss/vf": vf_losses, + "loss/ent": ent_losses, } diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 47823c6..b596eb5 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -1,12 +1,12 @@ import torch import numpy as np from copy import deepcopy -from typing import Dict, Tuple, Union, Optional -from torch.distributions import Normal, Independent +from torch.distributions import Independent, Normal +from typing import Any, Dict, Tuple, Union, Optional from tianshou.policy import DDPGPolicy -from tianshou.data import Batch, to_torch_as, ReplayBuffer from tianshou.exploration import BaseNoise +from tianshou.data import Batch, ReplayBuffer, to_torch_as class SACPolicy(DDPGPolicy): @@ -23,6 +23,8 @@ class SACPolicy(DDPGPolicy): a)) :param torch.optim.Optimizer critic2_optim: the optimizer for the second critic network. + :param action_range: the action range (minimum, maximum). + :type action_range: Tuple[float, float] :param float tau: param for soft update of the target network, defaults to 0.005. :param float gamma: discount factor, in [0, 1], defaults to 0.99. @@ -32,8 +34,6 @@ class SACPolicy(DDPGPolicy): regularization coefficient, default to 0.2. If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then alpha is automatatically tuned. - :param action_range: the action range (minimum, maximum). - :type action_range: (float, float) :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to False. 
:param bool ignore_done: ignore the done flag while training the policy, @@ -55,20 +55,20 @@ class SACPolicy(DDPGPolicy): critic1_optim: torch.optim.Optimizer, critic2: torch.nn.Module, critic2_optim: torch.optim.Optimizer, + action_range: Tuple[float, float], tau: float = 0.005, gamma: float = 0.99, alpha: Union[ float, Tuple[float, torch.Tensor, torch.optim.Optimizer] ] = 0.2, - action_range: Optional[Tuple[float, float]] = None, reward_normalization: bool = False, ignore_done: bool = False, estimation_step: int = 1, exploration_noise: Optional[BaseNoise] = None, - **kwargs + **kwargs: Any, ) -> None: - super().__init__(None, None, None, None, tau, gamma, exploration_noise, - action_range, reward_normalization, ignore_done, + super().__init__(None, None, None, None, action_range, tau, gamma, + exploration_noise, reward_normalization, ignore_done, estimation_step, **kwargs) self.actor, self.actor_optim = actor, actor_optim self.critic1, self.critic1_old = critic1, deepcopy(critic1) @@ -79,6 +79,7 @@ class SACPolicy(DDPGPolicy): self.critic2_optim = critic2_optim self._is_auto_alpha = False + self._alpha: Union[float, torch.Tensor] if isinstance(alpha, tuple): self._is_auto_alpha = True self._target_entropy, self._log_alpha, self._alpha_optim = alpha @@ -89,7 +90,7 @@ class SACPolicy(DDPGPolicy): self.__eps = np.finfo(np.float32).eps.item() - def train(self, mode=True) -> torch.nn.Module: + def train(self, mode: bool = True) -> "SACPolicy": self.training = mode self.actor.train(mode) self.critic1.train(mode) @@ -98,17 +99,22 @@ class SACPolicy(DDPGPolicy): def sync_weight(self) -> None: for o, n in zip( - self.critic1_old.parameters(), self.critic1.parameters()): - o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) + self.critic1_old.parameters(), self.critic1.parameters() + ): + o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) for o, n in zip( - self.critic2_old.parameters(), self.critic2.parameters()): - o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) + self.critic2_old.parameters(), self.critic2.parameters() + ): + o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) - def forward(self, batch: Batch, - state: Optional[Union[dict, Batch, np.ndarray]] = None, - input: str = 'obs', - explorating: bool = True, - **kwargs) -> Batch: + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + input: str = "obs", + explorating: bool = True, + **kwargs: Any, + ) -> Batch: obs = getattr(batch, input) logits, h = self.actor(obs, state=state, info=batch.info) assert isinstance(logits, tuple) @@ -125,8 +131,9 @@ class SACPolicy(DDPGPolicy): return Batch( logits=logits, act=act, state=h, dist=dist, log_prob=log_prob) - def _target_q(self, buffer: ReplayBuffer, - indice: np.ndarray) -> torch.Tensor: + def _target_q( + self, buffer: ReplayBuffer, indice: np.ndarray + ) -> torch.Tensor: batch = buffer[indice] # batch.obs: s_{t+n} with torch.no_grad(): obs_next_result = self(batch, input='obs_next', explorating=False) @@ -138,8 +145,8 @@ class SACPolicy(DDPGPolicy): ) - self._alpha * obs_next_result.log_prob return target_q - def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: - weight = batch.pop('weight', 1.) 
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + weight = batch.pop("weight", 1.0) # critic 1 current_q1 = self.critic1(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() @@ -157,7 +164,7 @@ class SACPolicy(DDPGPolicy): self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() - batch.weight = (td1 + td2) / 2. # prio-buffer + batch.weight = (td1 + td2) / 2.0 # prio-buffer # actor obs_result = self(batch, explorating=False) a = obs_result.act @@ -180,11 +187,11 @@ class SACPolicy(DDPGPolicy): self.sync_weight() result = { - 'loss/actor': actor_loss.item(), - 'loss/critic1': critic1_loss.item(), - 'loss/critic2': critic2_loss.item(), + "loss/actor": actor_loss.item(), + "loss/critic1": critic1_loss.item(), + "loss/critic2": critic2_loss.item(), } if self._is_auto_alpha: - result['loss/alpha'] = alpha_loss.item() - result['v/alpha'] = self._alpha.item() + result["loss/alpha"] = alpha_loss.item() + result["v/alpha"] = self._alpha.item() return result diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 384d4b9..d7cf9c0 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -1,7 +1,7 @@ import torch import numpy as np from copy import deepcopy -from typing import Dict, Tuple, Optional +from typing import Any, Dict, Tuple, Optional from tianshou.policy import DDPGPolicy from tianshou.data import Batch, ReplayBuffer @@ -22,6 +22,8 @@ class TD3Policy(DDPGPolicy): a)) :param torch.optim.Optimizer critic2_optim: the optimizer for the second critic network. + :param action_range: the action range (minimum, maximum). + :type action_range: Tuple[float, float] :param float tau: param for soft update of the target network, defaults to 0.005. :param float gamma: discount factor, in [0, 1], defaults to 0.99. @@ -33,8 +35,6 @@ class TD3Policy(DDPGPolicy): default to 2. :param float noise_clip: the clipping range used in updating policy network, default to 0.5. - :param action_range: the action range (minimum, maximum). - :type action_range: (float, float) :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to False. :param bool ignore_done: ignore the done flag while training the policy, @@ -46,27 +46,28 @@ class TD3Policy(DDPGPolicy): explanation. 
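The `sync_weight` hunks in both SAC above and TD3 below reformat the same polyak averaging of target-network parameters. A standalone sketch of that update (the helper name `soft_update` is ours, not Tianshou's):

```python
import torch.nn as nn

def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """Copy tau * source + (1 - tau) * target into the target parameters."""
    for o, n in zip(target.parameters(), source.parameters()):
        o.data.copy_(o.data * (1.0 - tau) + n.data * tau)

critic = nn.Linear(4, 1)
critic_old = nn.Linear(4, 1)
critic_old.load_state_dict(critic.state_dict())  # start from identical weights
# ... after an optimizer step on `critic` ...
soft_update(critic_old, critic, tau=0.005)
```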
""" - def __init__(self, - actor: torch.nn.Module, - actor_optim: torch.optim.Optimizer, - critic1: torch.nn.Module, - critic1_optim: torch.optim.Optimizer, - critic2: torch.nn.Module, - critic2_optim: torch.optim.Optimizer, - tau: float = 0.005, - gamma: float = 0.99, - exploration_noise: Optional[BaseNoise] - = GaussianNoise(sigma=0.1), - policy_noise: float = 0.2, - update_actor_freq: int = 2, - noise_clip: float = 0.5, - action_range: Optional[Tuple[float, float]] = None, - reward_normalization: bool = False, - ignore_done: bool = False, - estimation_step: int = 1, - **kwargs) -> None: - super().__init__(actor, actor_optim, None, None, tau, gamma, - exploration_noise, action_range, reward_normalization, + def __init__( + self, + actor: torch.nn.Module, + actor_optim: torch.optim.Optimizer, + critic1: torch.nn.Module, + critic1_optim: torch.optim.Optimizer, + critic2: torch.nn.Module, + critic2_optim: torch.optim.Optimizer, + action_range: Tuple[float, float], + tau: float = 0.005, + gamma: float = 0.99, + exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1), + policy_noise: float = 0.2, + update_actor_freq: int = 2, + noise_clip: float = 0.5, + reward_normalization: bool = False, + ignore_done: bool = False, + estimation_step: int = 1, + **kwargs: Any, + ) -> None: + super().__init__(actor, actor_optim, None, None, action_range, + tau, gamma, exploration_noise, reward_normalization, ignore_done, estimation_step, **kwargs) self.critic1, self.critic1_old = critic1, deepcopy(critic1) self.critic1_old.eval() @@ -80,7 +81,7 @@ class TD3Policy(DDPGPolicy): self._cnt = 0 self._last = 0 - def train(self, mode=True) -> torch.nn.Module: + def train(self, mode: bool = True) -> "TD3Policy": self.training = mode self.actor.train(mode) self.critic1.train(mode) @@ -89,22 +90,25 @@ class TD3Policy(DDPGPolicy): def sync_weight(self) -> None: for o, n in zip(self.actor_old.parameters(), self.actor.parameters()): - o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) + o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) for o, n in zip( - self.critic1_old.parameters(), self.critic1.parameters()): - o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) + self.critic1_old.parameters(), self.critic1.parameters() + ): + o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) for o, n in zip( - self.critic2_old.parameters(), self.critic2.parameters()): - o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) + self.critic2_old.parameters(), self.critic2.parameters() + ): + o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau) - def _target_q(self, buffer: ReplayBuffer, - indice: np.ndarray) -> torch.Tensor: + def _target_q( + self, buffer: ReplayBuffer, indice: np.ndarray + ) -> torch.Tensor: batch = buffer[indice] # batch.obs: s_{t+n} with torch.no_grad(): - a_ = self(batch, model='actor_old', input='obs_next').act + a_ = self(batch, model="actor_old", input="obs_next").act dev = a_.device noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise - if self._noise_clip > 0: + if self._noise_clip > 0.0: noise = noise.clamp(-self._noise_clip, self._noise_clip) a_ += noise a_ = a_.clamp(self._range[0], self._range[1]) @@ -113,8 +117,8 @@ class TD3Policy(DDPGPolicy): self.critic2_old(batch.obs_next, a_)) return target_q - def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: - weight = batch.pop('weight', 1.) 
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + weight = batch.pop("weight", 1.0) # critic 1 current_q1 = self.critic1(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() @@ -132,10 +136,10 @@ class TD3Policy(DDPGPolicy): self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() - batch.weight = (td1 + td2) / 2. # prio-buffer + batch.weight = (td1 + td2) / 2.0 # prio-buffer if self._cnt % self._freq == 0: actor_loss = -self.critic1( - batch.obs, self(batch, eps=0).act).mean() + batch.obs, self(batch, eps=0.0).act).mean() self.actor_optim.zero_grad() actor_loss.backward() self._last = actor_loss.item() @@ -143,7 +147,7 @@ class TD3Policy(DDPGPolicy): self.sync_weight() self._cnt += 1 return { - 'loss/actor': self._last, - 'loss/critic1': critic1_loss.item(), - 'loss/critic2': critic2_loss.item(), + "loss/actor": self._last, + "loss/critic1": critic1_loss.item(), + "loss/critic2": critic2_loss.item(), } diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index 541481e..5bfd5e9 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -1,5 +1,5 @@ import numpy as np -from typing import Union, Optional, Dict, List +from typing import Any, Dict, List, Union, Optional from tianshou.policy import BasePolicy from tianshou.data import Batch, ReplayBuffer @@ -15,21 +15,22 @@ class MultiAgentPolicyManager(BasePolicy): :ref:`marl_example` can help you better understand this procedure. """ - def __init__(self, policies: List[BasePolicy]): - super().__init__() + def __init__(self, policies: List[BasePolicy], **kwargs: Any) -> None: + super().__init__(**kwargs) self.policies = policies for i, policy in enumerate(policies): # agent_id 0 is reserved for the environment proxy # (this MultiAgentPolicyManager) policy.set_agent_id(i + 1) - def replace_policy(self, policy, agent_id): + def replace_policy(self, policy: BasePolicy, agent_id: int) -> None: """Replace the "agent_id"th policy in this manager.""" self.policies[agent_id - 1] = policy policy.set_agent_id(agent_id) - def process_fn(self, batch: Batch, buffer: ReplayBuffer, - indice: np.ndarray) -> Batch: + def process_fn( + self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray + ) -> Batch: """Dispatch batch data from obs.agent_id to every policy's process_fn. Save original multi-dimensional rew in "save_rew", set rew to the @@ -46,21 +47,24 @@ class MultiAgentPolicyManager(BasePolicy): for policy in self.policies: agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0] if len(agent_index) == 0: - results[f'agent_{policy.agent_id}'] = Batch() + results[f"agent_{policy.agent_id}"] = Batch() continue tmp_batch, tmp_indice = batch[agent_index], indice[agent_index] if has_rew: tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1] buffer._meta.rew = save_rew[:, policy.agent_id - 1] - results[f'agent_{policy.agent_id}'] = \ - policy.process_fn(tmp_batch, buffer, tmp_indice) + results[f"agent_{policy.agent_id}"] = policy.process_fn( + tmp_batch, buffer, tmp_indice) if has_rew: # restore from save_rew buffer._meta.rew = save_rew return Batch(results) - def forward(self, batch: Batch, - state: Optional[Union[dict, Batch]] = None, - **kwargs) -> Batch: + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch]] = None, + **kwargs: Any, + ) -> Batch: """Dispatch batch data from obs.agent_id to every policy's forward. :param state: if None, it means all agents have no state. 
If not @@ -107,15 +111,15 @@ class MultiAgentPolicyManager(BasePolicy): **kwargs) act = out.act each_state = out.state \ - if (hasattr(out, 'state') and out.state is not None) \ + if (hasattr(out, "state") and out.state is not None) \ else Batch() results.append((True, agent_index, out, act, each_state)) - holder = Batch.cat([{'act': act} for + holder = Batch.cat([{"act": act} for (has_data, agent_index, out, act, each_state) in results if has_data]) state_dict, out_dict = {}, {} - for policy, (has_data, agent_index, out, act, state) in \ - zip(self.policies, results): + for policy, (has_data, agent_index, out, act, state) in zip( + self.policies, results): if has_data: holder.act[agent_index] = act state_dict["agent_" + str(policy.agent_id)] = state @@ -124,8 +128,9 @@ class MultiAgentPolicyManager(BasePolicy): holder["state"] = state_dict return holder - def learn(self, batch: Batch, **kwargs - ) -> Dict[str, Union[float, List[float]]]: + def learn( + self, batch: Batch, **kwargs: Any + ) -> Dict[str, Union[float, List[float]]]: """Dispatch the data to all policies for learning. :return: a dict with the following contents: @@ -142,9 +147,9 @@ class MultiAgentPolicyManager(BasePolicy): """ results = {} for policy in self.policies: - data = batch[f'agent_{policy.agent_id}'] + data = batch[f"agent_{policy.agent_id}"] if not data.is_empty(): out = policy.learn(batch=data, **kwargs) for k, v in out.items(): - results["agent_" + str(policy.agent_id) + '/' + k] = v + results["agent_" + str(policy.agent_id) + "/" + k] = v return results diff --git a/tianshou/policy/random.py b/tianshou/policy/random.py index baac742..13f9159 100644 --- a/tianshou/policy/random.py +++ b/tianshou/policy/random.py @@ -1,5 +1,5 @@ import numpy as np -from typing import Union, Optional, Dict, List +from typing import Any, Dict, Union, Optional from tianshou.data import Batch from tianshou.policy import BasePolicy @@ -11,9 +11,12 @@ class RandomPolicy(BasePolicy): It randomly chooses an action from the legal action. """ - def forward(self, batch: Batch, - state: Optional[Union[dict, Batch, np.ndarray]] = None, - **kwargs) -> Batch: + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs: Any, + ) -> Batch: """Compute the random action over the given batch data. 
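The MultiAgentPolicyManager hunks above route each agent's slice of a batch by `obs.agent_id` and prefix the returned keys with `agent_{id}`. A simplified, self-contained sketch of that dispatch pattern on plain arrays (the toy `per_agent_learn` callables are ours, standing in for the per-agent policies):

```python
import numpy as np

def dispatch_learn(agent_ids, per_agent_learn):
    """Split indices by agent id and collect results under prefixed keys."""
    results = {}
    for agent_id, learn_fn in per_agent_learn.items():
        agent_index = np.nonzero(agent_ids == agent_id)[0]
        if len(agent_index) == 0:
            continue  # this agent has no data in the batch
        for k, v in learn_fn(agent_index).items():
            results[f"agent_{agent_id}/{k}"] = v
    return results

agent_ids = np.array([1, 2, 1, 2, 1])
per_agent_learn = {
    1: lambda idx: {"loss": float(len(idx))},
    2: lambda idx: {"loss": float(len(idx)) * 2},
}
print(dispatch_learn(agent_ids, per_agent_learn))
# {'agent_1/loss': 3.0, 'agent_2/loss': 4.0}
```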
The input should contain a mask in batch.obs, with "True" to be @@ -34,7 +37,6 @@ class RandomPolicy(BasePolicy): logits[~mask] = -np.inf return Batch(act=logits.argmax(axis=-1)) - def learn(self, batch: Batch, **kwargs - ) -> Dict[str, Union[float, List[float]]]: + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: """Since a random agent learn nothing, it returns an empty dict.""" return {} diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index c8c7e05..36a8ed4 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -3,8 +3,8 @@ from tianshou.trainer.onpolicy import onpolicy_trainer from tianshou.trainer.offpolicy import offpolicy_trainer __all__ = [ - 'gather_info', - 'test_episode', - 'onpolicy_trainer', - 'offpolicy_trainer', + "gather_info", + "test_episode", + "onpolicy_trainer", + "offpolicy_trainer", ] diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index be52c54..bbf5233 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -10,23 +10,23 @@ from tianshou.trainer import test_episode, gather_info def offpolicy_trainer( - policy: BasePolicy, - train_collector: Collector, - test_collector: Collector, - max_epoch: int, - step_per_epoch: int, - collect_per_step: int, - episode_per_test: Union[int, List[int]], - batch_size: int, - update_per_step: int = 1, - train_fn: Optional[Callable[[int], None]] = None, - test_fn: Optional[Callable[[int], None]] = None, - stop_fn: Optional[Callable[[float], bool]] = None, - save_fn: Optional[Callable[[BasePolicy], None]] = None, - writer: Optional[SummaryWriter] = None, - log_interval: int = 1, - verbose: bool = True, - test_in_train: bool = True, + policy: BasePolicy, + train_collector: Collector, + test_collector: Collector, + max_epoch: int, + step_per_epoch: int, + collect_per_step: int, + episode_per_test: Union[int, List[int]], + batch_size: int, + update_per_step: int = 1, + train_fn: Optional[Callable[[int], None]] = None, + test_fn: Optional[Callable[[int], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + writer: Optional[SummaryWriter] = None, + log_interval: int = 1, + verbose: bool = True, + test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: """A wrapper for off-policy trainer procedure. @@ -72,7 +72,7 @@ def offpolicy_trainer( :return: See :func:`~tianshou.trainer.gather_info`. """ global_step = 0 - best_epoch, best_reward = -1, -1. 
+ best_epoch, best_reward = -1, -1.0 stat = {} start_time = time.time() test_in_train = test_in_train and train_collector.policy == policy @@ -81,42 +81,43 @@ def offpolicy_trainer( policy.train() if train_fn: train_fn(epoch) - with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', - **tqdm_config) as t: + with tqdm.tqdm( + total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config + ) as t: while t.n < t.total: result = train_collector.collect(n_step=collect_per_step) data = {} - if test_in_train and stop_fn and stop_fn(result['rew']): + if test_in_train and stop_fn and stop_fn(result["rew"]): test_result = test_episode( policy, test_collector, test_fn, epoch, episode_per_test, writer, global_step) - if stop_fn and stop_fn(test_result['rew']): + if stop_fn and stop_fn(test_result["rew"]): if save_fn: save_fn(policy) for k in result.keys(): - data[k] = f'{result[k]:.2f}' + data[k] = f"{result[k]:.2f}" t.set_postfix(**data) return gather_info( start_time, train_collector, test_collector, - test_result['rew']) + test_result["rew"]) else: policy.train() if train_fn: train_fn(epoch) for i in range(update_per_step * min( - result['n/st'] // collect_per_step, t.total - t.n)): + result["n/st"] // collect_per_step, t.total - t.n)): global_step += collect_per_step losses = policy.update(batch_size, train_collector.buffer) for k in result.keys(): - data[k] = f'{result[k]:.2f}' + data[k] = f"{result[k]:.2f}" if writer and global_step % log_interval == 0: - writer.add_scalar('train/' + k, result[k], + writer.add_scalar("train/" + k, result[k], global_step=global_step) for k in losses.keys(): if stat.get(k) is None: stat[k] = MovAvg() stat[k].add(losses[k]) - data[k] = f'{stat[k].get():.6f}' + data[k] = f"{stat[k].get():.6f}" if writer and global_step % log_interval == 0: writer.add_scalar( k, stat[k].get(), global_step=global_step) @@ -127,14 +128,14 @@ def offpolicy_trainer( # test result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, writer, global_step) - if best_epoch == -1 or best_reward < result['rew']: - best_reward = result['rew'] + if best_epoch == -1 or best_reward < result["rew"]: + best_reward = result["rew"] best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, ' - f'best_reward: {best_reward:.6f} in #{best_epoch}') + print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, " + f"best_reward: {best_reward:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break return gather_info( diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index db13d06..ac97ba7 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -10,23 +10,23 @@ from tianshou.trainer import test_episode, gather_info def onpolicy_trainer( - policy: BasePolicy, - train_collector: Collector, - test_collector: Collector, - max_epoch: int, - step_per_epoch: int, - collect_per_step: int, - repeat_per_collect: int, - episode_per_test: Union[int, List[int]], - batch_size: int, - train_fn: Optional[Callable[[int], None]] = None, - test_fn: Optional[Callable[[int], None]] = None, - stop_fn: Optional[Callable[[float], bool]] = None, - save_fn: Optional[Callable[[BasePolicy], None]] = None, - writer: Optional[SummaryWriter] = None, - log_interval: int = 1, - verbose: bool = True, - test_in_train: bool = True, + policy: BasePolicy, + train_collector: Collector, + test_collector: Collector, + max_epoch: int, + step_per_epoch: int, + collect_per_step: int, + repeat_per_collect: int, + 
episode_per_test: Union[int, List[int]], + batch_size: int, + train_fn: Optional[Callable[[int], None]] = None, + test_fn: Optional[Callable[[int], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + writer: Optional[SummaryWriter] = None, + log_interval: int = 1, + verbose: bool = True, + test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. @@ -72,7 +72,7 @@ def onpolicy_trainer( :return: See :func:`~tianshou.trainer.gather_info`. """ global_step = 0 - best_epoch, best_reward = -1, -1. + best_epoch, best_reward = -1, -1.0 stat = {} start_time = time.time() test_in_train = test_in_train and train_collector.policy == policy @@ -81,30 +81,32 @@ def onpolicy_trainer( policy.train() if train_fn: train_fn(epoch) - with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', - **tqdm_config) as t: + with tqdm.tqdm( + total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config + ) as t: while t.n < t.total: result = train_collector.collect(n_episode=collect_per_step) data = {} - if test_in_train and stop_fn and stop_fn(result['rew']): + if test_in_train and stop_fn and stop_fn(result["rew"]): test_result = test_episode( policy, test_collector, test_fn, epoch, episode_per_test, writer, global_step) - if stop_fn and stop_fn(test_result['rew']): + if stop_fn and stop_fn(test_result["rew"]): if save_fn: save_fn(policy) for k in result.keys(): - data[k] = f'{result[k]:.2f}' + data[k] = f"{result[k]:.2f}" t.set_postfix(**data) return gather_info( start_time, train_collector, test_collector, - test_result['rew']) + test_result["rew"]) else: policy.train() if train_fn: train_fn(epoch) losses = policy.update( - 0, train_collector.buffer, batch_size, repeat_per_collect) + 0, train_collector.buffer, + batch_size=batch_size, repeat=repeat_per_collect) train_collector.reset_buffer() step = 1 for k in losses.keys(): @@ -112,15 +114,15 @@ def onpolicy_trainer( step = max(step, len(losses[k])) global_step += step * collect_per_step for k in result.keys(): - data[k] = f'{result[k]:.2f}' + data[k] = f"{result[k]:.2f}" if writer and global_step % log_interval == 0: writer.add_scalar( - 'train/' + k, result[k], global_step=global_step) + "train/" + k, result[k], global_step=global_step) for k in losses.keys(): if stat.get(k) is None: stat[k] = MovAvg() stat[k].add(losses[k]) - data[k] = f'{stat[k].get():.6f}' + data[k] = f"{stat[k].get():.6f}" if writer and global_step % log_interval == 0: writer.add_scalar( k, stat[k].get(), global_step=global_step) @@ -131,14 +133,14 @@ def onpolicy_trainer( # test result = test_episode(policy, test_collector, test_fn, epoch, episode_per_test, writer, global_step) - if best_epoch == -1 or best_reward < result['rew']: - best_reward = result['rew'] + if best_epoch == -1 or best_reward < result["rew"]: + best_reward = result["rew"] best_epoch = epoch if save_fn: save_fn(policy) if verbose: - print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, ' - f'best_reward: {best_reward:.6f} in #{best_epoch}') + print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, " + f"best_reward: {best_reward:.6f} in #{best_epoch}") if stop_fn and stop_fn(best_reward): break return gather_info( diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index ba914d8..5f6698e 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -8,13 +8,14 @@ from tianshou.policy import BasePolicy def test_episode( - policy: BasePolicy, - collector: 
Collector, - test_fn: Optional[Callable[[int], None]], - epoch: int, - n_episode: Union[int, List[int]], - writer: SummaryWriter = None, - global_step: int = None) -> Dict[str, float]: + policy: BasePolicy, + collector: Collector, + test_fn: Optional[Callable[[int], None]], + epoch: int, + n_episode: Union[int, List[int]], + writer: Optional[SummaryWriter] = None, + global_step: Optional[int] = None, +) -> Dict[str, float]: """A simple wrapper of testing policy in collector.""" collector.reset_env() collector.reset_buffer() @@ -29,15 +30,16 @@ def test_episode( result = collector.collect(n_episode=n_episode) if writer is not None and global_step is not None: for k in result.keys(): - writer.add_scalar('test/' + k, result[k], global_step=global_step) + writer.add_scalar("test/" + k, result[k], global_step=global_step) return result -def gather_info(start_time: float, - train_c: Collector, - test_c: Collector, - best_reward: float - ) -> Dict[str, Union[float, str]]: +def gather_info( + start_time: float, + train_c: Collector, + test_c: Collector, + best_reward: float, +) -> Dict[str, Union[float, str]]: """A simple wrapper of gathering information from collectors. :return: A dictionary with the following keys: @@ -60,15 +62,15 @@ def gather_info(start_time: float, train_speed = train_c.collect_step / (duration - test_c.collect_time) test_speed = test_c.collect_step / test_c.collect_time return { - 'train_step': train_c.collect_step, - 'train_episode': train_c.collect_episode, - 'train_time/collector': f'{train_c.collect_time:.2f}s', - 'train_time/model': f'{model_time:.2f}s', - 'train_speed': f'{train_speed:.2f} step/s', - 'test_step': test_c.collect_step, - 'test_episode': test_c.collect_episode, - 'test_time': f'{test_c.collect_time:.2f}s', - 'test_speed': f'{test_speed:.2f} step/s', - 'best_reward': best_reward, - 'duration': f'{duration:.2f}s', + "train_step": train_c.collect_step, + "train_episode": train_c.collect_episode, + "train_time/collector": f"{train_c.collect_time:.2f}s", + "train_time/model": f"{model_time:.2f}s", + "train_speed": f"{train_speed:.2f} step/s", + "test_step": test_c.collect_step, + "test_episode": test_c.collect_episode, + "test_time": f"{test_c.collect_time:.2f}s", + "test_speed": f"{test_speed:.2f} step/s", + "best_reward": best_reward, + "duration": f"{duration:.2f}s", } diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index aeb34ae..e5827fd 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -1,9 +1,7 @@ from tianshou.utils.config import tqdm_config -from tianshou.utils.compile import pre_compile from tianshou.utils.moving_average import MovAvg __all__ = [ "MovAvg", - "pre_compile", "tqdm_config", ] diff --git a/tianshou/utils/compile.py b/tianshou/utils/compile.py deleted file mode 100644 index bf051bd..0000000 --- a/tianshou/utils/compile.py +++ /dev/null @@ -1,27 +0,0 @@ -import numpy as np - -from tianshou.policy.base import _episodic_return, _nstep_return -from tianshou.data.utils.segtree import _reduce, _setitem, _get_prefix_sum_idx - - -def pre_compile(): - """Functions that need to pre-compile for producing benchmark result. - - Since Numba acceleration needs to compile the function in the first run, - here we use some fake data for the common-type function-call compilation. - Otherwise, the current training speed cannot compare with the previous. 
- """ - f64 = np.array([0, 1], dtype=np.float64) - f32 = np.array([0, 1], dtype=np.float32) - b = np.array([False, True], dtype=np.bool_) - i64 = np.array([0, 1], dtype=np.int64) - # returns - _episodic_return(f64, f64, b, .1, .1) - _episodic_return(f32, f64, b, .1, .1) - _nstep_return(f64, b, f32, i64, .1, 1, 4, 1., 0.) - # segtree - _setitem(f64, i64, f64) - _setitem(f64, i64, f32) - _reduce(f64, 0, 1) - _get_prefix_sum_idx(f64, 1, f64) - _get_prefix_sum_idx(f32, 1, f64) diff --git a/tianshou/utils/config.py b/tianshou/utils/config.py index 4cf8503..ca6e91f 100644 --- a/tianshou/utils/config.py +++ b/tianshou/utils/config.py @@ -1,4 +1,4 @@ tqdm_config = { - 'dynamic_ncols': True, - 'ascii': True, + "dynamic_ncols": True, + "ascii": True, } diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py index a138b1c..58c6792 100644 --- a/tianshou/utils/moving_average.py +++ b/tianshou/utils/moving_average.py @@ -1,5 +1,6 @@ import torch import numpy as np +from numbers import Number from typing import Union from tianshou.data import to_numpy @@ -30,7 +31,9 @@ class MovAvg(object): self.cache = [] self.banned = [np.inf, np.nan, -np.inf] - def add(self, x: Union[float, list, np.ndarray, torch.Tensor]) -> float: + def add( + self, x: Union[Number, np.number, list, np.ndarray, torch.Tensor] + ) -> np.number: """Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with only one element, a python scalar, or @@ -39,26 +42,26 @@ class MovAvg(object): if isinstance(x, torch.Tensor): x = to_numpy(x.flatten()) if isinstance(x, list) or isinstance(x, np.ndarray): - for _ in x: - if _ not in self.banned: - self.cache.append(_) + for i in x: + if i not in self.banned: + self.cache.append(i) elif x not in self.banned: self.cache.append(x) if self.size > 0 and len(self.cache) > self.size: self.cache = self.cache[-self.size:] return self.get() - def get(self) -> float: + def get(self) -> np.number: """Get the average.""" if len(self.cache) == 0: return 0 return np.mean(self.cache) - def mean(self) -> float: + def mean(self) -> np.number: """Get the average. Same as :meth:`get`.""" return self.get() - def std(self) -> float: + def std(self) -> np.number: """Get the standard deviation.""" if len(self.cache) == 0: return 0 diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 8c4fcc5..96023fe 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -1,13 +1,16 @@ import torch import numpy as np from torch import nn -from typing import List, Tuple, Union, Optional +from typing import Any, Dict, List, Tuple, Union, Callable, Optional, Sequence from tianshou.data import to_torch -def miniblock(inp: int, oup: int, - norm_layer: nn.modules.Module) -> List[nn.modules.Module]: +def miniblock( + inp: int, + oup: int, + norm_layer: Optional[Callable[[int], nn.modules.Module]], +) -> List[nn.modules.Module]: """Construct a miniblock with given input/output-size and norm layer.""" ret = [nn.Linear(inp, oup)] if norm_layer is not None: @@ -27,18 +30,22 @@ class Net(nn.Module): shape, but affects the input shape. :param bool dueling: whether to use dueling network to calculate Q values (for Dueling DQN), defaults to False. - :param nn.modules.Module norm_layer: use which normalization before ReLU, - e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None. + :param norm_layer: use which normalization before ReLU, e.g., + ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None. 
""" - def __init__(self, layer_num: int, state_shape: tuple, - action_shape: Optional[Union[tuple, int]] = 0, - device: Union[str, torch.device] = 'cpu', - softmax: bool = False, - concat: bool = False, - hidden_layer_size: int = 128, - dueling: Optional[Tuple[int, int]] = None, - norm_layer: Optional[nn.modules.Module] = None): + def __init__( + self, + layer_num: int, + state_shape: tuple, + action_shape: Optional[Union[tuple, int]] = 0, + device: Union[str, int, torch.device] = "cpu", + softmax: bool = False, + concat: bool = False, + hidden_layer_size: int = 128, + dueling: Optional[Tuple[int, int]] = None, + norm_layer: Optional[Callable[[int], nn.modules.Module]] = None, + ) -> None: super().__init__() self.device = device self.dueling = dueling @@ -78,7 +85,12 @@ class Net(nn.Module): self.V = nn.Sequential(*self.V) self.model = nn.Sequential(*self.model) - def forward(self, s, state=None, info={}): + def forward( + self, + s: Union[np.ndarray, torch.Tensor], + state: Optional[Any] = None, + info: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Any]: """Mapping: s -> flatten -> logits.""" s = to_torch(s, device=self.device, dtype=torch.float32) s = s.reshape(s.size(0), -1) @@ -98,19 +110,33 @@ class Recurrent(nn.Module): :ref:`build_the_network`. """ - def __init__(self, layer_num, state_shape, action_shape, - device='cpu', hidden_layer_size=128): + def __init__( + self, + layer_num: int, + state_shape: Sequence[int], + action_shape: Sequence[int], + device: Union[str, int, torch.device] = "cpu", + hidden_layer_size: int = 128, + ) -> None: super().__init__() self.state_shape = state_shape self.action_shape = action_shape self.device = device - self.nn = nn.LSTM(input_size=hidden_layer_size, - hidden_size=hidden_layer_size, - num_layers=layer_num, batch_first=True) + self.nn = nn.LSTM( + input_size=hidden_layer_size, + hidden_size=hidden_layer_size, + num_layers=layer_num, + batch_first=True, + ) self.fc1 = nn.Linear(np.prod(state_shape), hidden_layer_size) self.fc2 = nn.Linear(hidden_layer_size, np.prod(action_shape)) - def forward(self, s, state=None, info={}): + def forward( + self, + s: Union[np.ndarray, torch.Tensor], + state: Optional[Dict[str, torch.Tensor]] = None, + info: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """Mapping: s -> flatten -> logits. In the evaluation mode, s should be with shape ``[bsz, dim]``; in the @@ -130,9 +156,9 @@ class Recurrent(nn.Module): else: # we store the stack data in [bsz, len, ...] format # but pytorch rnn needs [len, bsz, ...] - s, (h, c) = self.nn(s, (state['h'].transpose(0, 1).contiguous(), - state['c'].transpose(0, 1).contiguous())) + s, (h, c) = self.nn(s, (state["h"].transpose(0, 1).contiguous(), + state["c"].transpose(0, 1).contiguous())) s = self.fc2(s[:, -1]) # please ensure the first dim is batch size: [bsz, len, ...] - return s, {'h': h.transpose(0, 1).detach(), - 'c': c.transpose(0, 1).detach()} + return s, {"h": h.transpose(0, 1).detach(), + "c": c.transpose(0, 1).detach()} diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index 19586bb..85f03d2 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -1,6 +1,7 @@ import torch import numpy as np from torch import nn +from typing import Any, Dict, Tuple, Union, Optional, Sequence from tianshou.data import to_torch, to_torch_as @@ -12,14 +13,25 @@ class Actor(nn.Module): :ref:`build_the_network`. 
""" - def __init__(self, preprocess_net, action_shape, max_action=1., - device='cpu', hidden_layer_size=128): + def __init__( + self, + preprocess_net: nn.Module, + action_shape: Sequence[int], + max_action: float = 1.0, + device: Union[str, int, torch.device] = "cpu", + hidden_layer_size: int = 128, + ) -> None: super().__init__() self.preprocess = preprocess_net self.last = nn.Linear(hidden_layer_size, np.prod(action_shape)) self._max = max_action - def forward(self, s, state=None, info={}): + def forward( + self, + s: Union[np.ndarray, torch.Tensor], + state: Optional[Any] = None, + info: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Any]: """Mapping: s -> logits -> action.""" logits, h = self.preprocess(s, state) logits = self._max * torch.tanh(self.last(logits)) @@ -33,13 +45,23 @@ class Critic(nn.Module): :ref:`build_the_network`. """ - def __init__(self, preprocess_net, device='cpu', hidden_layer_size=128): + def __init__( + self, + preprocess_net: nn.Module, + device: Union[str, int, torch.device] = "cpu", + hidden_layer_size: int = 128, + ) -> None: super().__init__() self.device = device self.preprocess = preprocess_net self.last = nn.Linear(hidden_layer_size, 1) - def forward(self, s, a=None, info={}): + def forward( + self, + s: Union[np.ndarray, torch.Tensor], + a: Optional[Union[np.ndarray, torch.Tensor]] = None, + info: Dict[str, Any] = {}, + ) -> torch.Tensor: """Mapping: (s, a) -> logits -> Q(s, a).""" s = to_torch(s, device=self.device, dtype=torch.float32) s = s.flatten(1) @@ -59,8 +81,15 @@ class ActorProb(nn.Module): :ref:`build_the_network`. """ - def __init__(self, preprocess_net, action_shape, max_action=1., - device='cpu', unbounded=False, hidden_layer_size=128): + def __init__( + self, + preprocess_net: nn.Module, + action_shape: Sequence[int], + max_action: float = 1.0, + device: Union[str, int, torch.device] = "cpu", + unbounded: bool = False, + hidden_layer_size: int = 128, + ) -> None: super().__init__() self.preprocess = preprocess_net self.device = device @@ -69,7 +98,12 @@ class ActorProb(nn.Module): self._max = max_action self._unbounded = unbounded - def forward(self, s, state=None, info={}): + def forward( + self, + s: Union[np.ndarray, torch.Tensor], + state: Optional[Any] = None, + info: Dict[str, Any] = {}, + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]: """Mapping: s -> logits -> (mu, sigma).""" logits, h = self.preprocess(s, state) mu = self.mu(logits) @@ -78,7 +112,7 @@ class ActorProb(nn.Module): shape = [1] * len(mu.shape) shape[1] = -1 sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp() - return (mu, sigma), None + return (mu, sigma), state class RecurrentActorProb(nn.Module): @@ -88,19 +122,35 @@ class RecurrentActorProb(nn.Module): :ref:`build_the_network`. 
""" - def __init__(self, layer_num, state_shape, action_shape, max_action=1., - device='cpu', unbounded=False, hidden_layer_size=128): + def __init__( + self, + layer_num: int, + state_shape: Sequence[int], + action_shape: Sequence[int], + max_action: float = 1.0, + device: Union[str, int, torch.device] = "cpu", + unbounded: bool = False, + hidden_layer_size: int = 128, + ) -> None: super().__init__() self.device = device - self.nn = nn.LSTM(input_size=np.prod(state_shape), - hidden_size=hidden_layer_size, - num_layers=layer_num, batch_first=True) + self.nn = nn.LSTM( + input_size=np.prod(state_shape), + hidden_size=hidden_layer_size, + num_layers=layer_num, + batch_first=True, + ) self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape)) self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1)) self._max = max_action self._unbounded = unbounded - def forward(self, s, state=None, info={}): + def forward( + self, + s: Union[np.ndarray, torch.Tensor], + state: Optional[Dict[str, torch.Tensor]] = None, + info: Dict[str, Any] = {}, + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]: """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" s = to_torch(s, device=self.device, dtype=torch.float32) # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) @@ -114,8 +164,8 @@ class RecurrentActorProb(nn.Module): else: # we store the stack data in [bsz, len, ...] format # but pytorch rnn needs [len, bsz, ...] - s, (h, c) = self.nn(s, (state['h'].transpose(0, 1).contiguous(), - state['c'].transpose(0, 1).contiguous())) + s, (h, c) = self.nn(s, (state["h"].transpose(0, 1).contiguous(), + state["c"].transpose(0, 1).contiguous())) logits = s[:, -1] mu = self.mu(logits) if not self._unbounded: @@ -124,8 +174,8 @@ class RecurrentActorProb(nn.Module): shape[1] = -1 sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp() # please ensure the first dim is batch size: [bsz, len, ...] - return (mu, sigma), {'h': h.transpose(0, 1).detach(), - 'c': c.transpose(0, 1).detach()} + return (mu, sigma), {"h": h.transpose(0, 1).detach(), + "c": c.transpose(0, 1).detach()} class RecurrentCritic(nn.Module): @@ -135,18 +185,32 @@ class RecurrentCritic(nn.Module): :ref:`build_the_network`. 
""" - def __init__(self, layer_num, state_shape, - action_shape=0, device='cpu', hidden_layer_size=128): + def __init__( + self, + layer_num: int, + state_shape: Sequence[int], + action_shape: Sequence[int] = [0], + device: Union[str, int, torch.device] = "cpu", + hidden_layer_size: int = 128, + ) -> None: super().__init__() self.state_shape = state_shape self.action_shape = action_shape self.device = device - self.nn = nn.LSTM(input_size=np.prod(state_shape), - hidden_size=hidden_layer_size, - num_layers=layer_num, batch_first=True) + self.nn = nn.LSTM( + input_size=np.prod(state_shape), + hidden_size=hidden_layer_size, + num_layers=layer_num, + batch_first=True, + ) self.fc2 = nn.Linear(hidden_layer_size + np.prod(action_shape), 1) - def forward(self, s, a=None): + def forward( + self, + s: Union[np.ndarray, torch.Tensor], + a: Optional[Union[np.ndarray, torch.Tensor]] = None, + info: Dict[str, Any] = {}, + ) -> torch.Tensor: """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`.""" s = to_torch(s, device=self.device, dtype=torch.float32) # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 03f4583..6734316 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -2,6 +2,7 @@ import torch import numpy as np from torch import nn import torch.nn.functional as F +from typing import Any, Dict, Tuple, Union, Optional, Sequence class Actor(nn.Module): @@ -11,12 +12,22 @@ class Actor(nn.Module): :ref:`build_the_network`. """ - def __init__(self, preprocess_net, action_shape, hidden_layer_size=128): + def __init__( + self, + preprocess_net: nn.Module, + action_shape: Sequence[int], + hidden_layer_size: int = 128, + ) -> None: super().__init__() self.preprocess = preprocess_net self.last = nn.Linear(hidden_layer_size, np.prod(action_shape)) - def forward(self, s, state=None, info={}): + def forward( + self, + s: Union[np.ndarray, torch.Tensor], + state: Optional[Any] = None, + info: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Any]: r"""Mapping: s -> Q(s, \*).""" logits, h = self.preprocess(s, state) logits = F.softmax(self.last(logits), dim=-1) @@ -30,14 +41,18 @@ class Critic(nn.Module): :ref:`build_the_network`. """ - def __init__(self, preprocess_net, hidden_layer_size=128): + def __init__( + self, preprocess_net: nn.Module, hidden_layer_size: int = 128 + ) -> None: super().__init__() self.preprocess = preprocess_net self.last = nn.Linear(hidden_layer_size, 1) - def forward(self, s, **kwargs): + def forward( + self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any + ) -> torch.Tensor: """Mapping: s -> V(s).""" - logits, h = self.preprocess(s, state=kwargs.get('state', None)) + logits, h = self.preprocess(s, state=kwargs.get("state", None)) logits = self.last(logits) return logits @@ -49,17 +64,31 @@ class DQN(nn.Module): :ref:`build_the_network`. 
""" - def __init__(self, c, h, w, action_shape, device='cpu'): - super(DQN, self).__init__() + def __init__( + self, + c: int, + h: int, + w: int, + action_shape: Sequence[int], + device: Union[str, int, torch.device] = "cpu", + ) -> None: + super().__init__() self.device = device - def conv2d_size_out(size, kernel_size=5, stride=2): + def conv2d_size_out( + size: int, kernel_size: int = 5, stride: int = 2 + ) -> int: return (size - (kernel_size - 1) - 1) // stride + 1 - def conv2d_layers_size_out(size, - kernel_size_1=8, stride_1=4, - kernel_size_2=4, stride_2=2, - kernel_size_3=3, stride_3=1): + def conv2d_layers_size_out( + size: int, + kernel_size_1: int = 8, + stride_1: int = 4, + kernel_size_2: int = 4, + stride_2: int = 2, + kernel_size_3: int = 3, + stride_3: int = 1, + ) -> int: size = conv2d_size_out(size, kernel_size_1, stride_1) size = conv2d_size_out(size, kernel_size_2, stride_2) size = conv2d_size_out(size, kernel_size_3, stride_3) @@ -78,10 +107,15 @@ class DQN(nn.Module): nn.ReLU(inplace=True), nn.Flatten(), nn.Linear(linear_input_size, 512), - nn.Linear(512, np.prod(action_shape)) + nn.Linear(512, np.prod(action_shape)), ) - def forward(self, x, state=None, info={}): + def forward( + self, + x: Union[np.ndarray, torch.Tensor], + state: Optional[Any] = None, + info: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Any]: r"""Mapping: x -> Q(x, \*).""" if not isinstance(x, torch.Tensor): x = torch.tensor(x, device=self.device, dtype=torch.float32)