diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index fdc2908..8d56bf6 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -4,7 +4,7 @@ + [ ] documentation request (i.e. "X is missing from the documentation.") + [ ] new feature request - [ ] I have visited the [source website], and in particular read the [known issues] -- [ ] I have searched through the [issue categories] for duplicates +- [ ] I have searched through the [issue tracker] and [issue categories] for duplicates - [ ] I have mentioned version numbers, operating system and environment, where applicable: ```python import tianshou, torch, sys @@ -14,3 +14,4 @@ [source website]: https://github.com/thu-ml/tianshou/ [known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues [issue categories]: https://github.com/thu-ml/tianshou/projects/2 + [issue tracker]: https://github.com/thu-ml/tianshou/issues?q= diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 3fef4b1..38b9f83 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -8,7 +8,7 @@ Less important but also useful: - [ ] I have visited the [source website], and in particular read the [known issues] -- [ ] I have searched through the [issue categories] for duplicates +- [ ] I have searched through the [issue tracker] and [issue categories] for duplicates - [ ] I have mentioned version numbers, operating system and environment, where applicable: ```python import tianshou, torch, sys @@ -18,3 +18,4 @@ Less important but also useful: [source website]: https://github.com/thu-ml/tianshou [known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues [issue categories]: https://github.com/thu-ml/tianshou/projects/2 + [issue tracker]: https://github.com/thu-ml/tianshou/issues?q= diff --git a/examples/ant_v2_sac.py b/examples/ant_v2_sac.py index 9c6e71f..8632e92 100644 --- a/examples/ant_v2_sac.py +++ b/examples/ant_v2_sac.py @@ -1,3 +1,4 @@ +import os import gym import torch import pprint @@ -32,6 +33,7 @@ def get_args(): parser.add_argument('--test-num', type=int, default=100) parser.add_argument('--logdir', type=str, default='log') parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument('--rew-norm', type=bool, default=True) parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -73,14 +75,18 @@ def test_sac(args=get_args()): actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, args.tau, args.gamma, args.alpha, [env.action_space.low[0], env.action_space.high[0]], - reward_normalization=True, ignore_done=True) + reward_normalization=args.rew_norm, ignore_done=True) # collector train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) test_collector = Collector(policy, test_envs) # train_collector.collect(n_step=args.buffer_size) # log - writer = SummaryWriter(args.logdir + '/' + 'sac') + log_path = os.path.join(args.logdir, args.task, 'sac') + writer = SummaryWriter(log_path) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) def stop_fn(x): return x >= env.spec.reward_threshold @@ -89,7 +95,7 @@ def test_sac(args=get_args()): result = offpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.collect_per_step, args.test_num, - args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task) + args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) train_collector.close() test_collector.close() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index de599b1..f8352ec 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -1,6 +1,7 @@ import torch import pprint import numpy as np +from typing import Any, List, Union, Iterator, Optional class Batch(object): @@ -19,8 +20,8 @@ class Batch(object): >>> print(data) Batch( a: 4, - b: [3 4 5], - c: 2312312, + b: array([3, 4, 5]), + c: '2312312', ) In short, you can define a :class:`Batch` with any key-value pair. The @@ -65,7 +66,7 @@ class Batch(object): [11 22] [6 6] """ - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__() self._meta = {} for k, v in kwargs.items(): @@ -85,7 +86,7 @@ class Batch(object): else: self.__dict__[k] = kwargs[k] - def __getitem__(self, index): + def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]: """Return self[index].""" if isinstance(index, str): return self.__getattr__(index) @@ -96,7 +97,7 @@ class Batch(object): b._meta = self._meta return b - def __getattr__(self, key): + def __getattr__(self, key: str) -> Union['Batch', Any]: """Return self.key""" if key not in self._meta: if key not in self.__dict__: @@ -108,7 +109,7 @@ class Batch(object): d[k_] = self.__dict__[k__] return Batch(**d) - def __repr__(self): + def __repr__(self) -> str: """Return str(self).""" s = self.__class__.__name__ + '(\n' flag = False @@ -125,18 +126,18 @@ class Batch(object): s = self.__class__.__name__ + '()' return s - def keys(self): + def keys(self) -> List[str]: """Return self.keys().""" - return sorted([i for i in self.__dict__ if i[0] != '_'] + - list(self._meta)) + return sorted([ + i for i in self.__dict__ if i[0] != '_'] + list(self._meta)) - def get(self, k, d=None): + def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]: """Return self[k] if k in self else d. d defaults to None.""" if k in self.__dict__ or k in self._meta: return self.__getattr__(k) return d - def to_numpy(self): + def to_numpy(self) -> np.ndarray: """Change all torch.Tensor to numpy.ndarray. This is an inplace operation. 
""" @@ -144,7 +145,7 @@ class Batch(object): if isinstance(self.__dict__[k], torch.Tensor): self.__dict__[k] = self.__dict__[k].cpu().numpy() - def append(self, batch): + def append(self, batch: 'Batch') -> None: """Append a :class:`~tianshou.data.Batch` object to current batch.""" assert isinstance(batch, Batch), 'Only append Batch is allowed!' for k in batch.__dict__: @@ -169,13 +170,14 @@ class Batch(object): + 'in class Batch.' raise TypeError(s) - def __len__(self): + def __len__(self) -> int: """Return len(self).""" return min([ len(self.__dict__[k]) for k in self.__dict__ if k != '_meta' and self.__dict__[k] is not None]) - def split(self, size=None, shuffle=True): + def split(self, size: Optional[int] = None, + shuffle: Optional[bool] = True) -> Iterator['Batch']: """Split whole data into multiple small batch. :param int size: if it is ``None``, it does not split the data batch; diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py index 984c0a4..028c189 100644 --- a/tianshou/data/buffer.py +++ b/tianshou/data/buffer.py @@ -1,5 +1,8 @@ import pprint import numpy as np +from copy import deepcopy +from typing import Tuple, Union, Optional + from tianshou.data.batch import Batch @@ -92,7 +95,8 @@ class ReplayBuffer(object): [ 7. 7. 8. 9.]] """ - def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs): + def __init__(self, size: int, stack_num: Optional[int] = 0, + ignore_obs_next: Optional[bool] = False, **kwargs) -> None: super().__init__() self._maxsize = size self._stack = stack_num @@ -100,11 +104,11 @@ class ReplayBuffer(object): self._meta = {} self.reset() - def __len__(self): + def __len__(self) -> int: """Return len(self).""" return self._size - def __repr__(self): + def __repr__(self) -> str: """Return str(self).""" s = self.__class__.__name__ + '(\n' flag = False @@ -121,7 +125,7 @@ class ReplayBuffer(object): s = self.__class__.__name__ + '()' return s - def __getattr__(self, key): + def __getattr__(self, key: str) -> Union[Batch, np.ndarray]: """Return self.key""" if key not in self._meta: if key not in self.__dict__: @@ -133,7 +137,9 @@ class ReplayBuffer(object): d[k_] = self.__dict__[k__] return Batch(**d) - def _add_to_buffer(self, name, inst): + def _add_to_buffer( + self, name: str, + inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None: if inst is None: if getattr(self, name, None) is None: self.__dict__[name] = None @@ -164,9 +170,11 @@ class ReplayBuffer(object): f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, " f"given shape: {inst.shape}.") if name not in self._meta: + if name == 'info': + inst = deepcopy(inst) self.__dict__[name][self._index] = inst - def update(self, buffer): + def update(self, buffer: 'ReplayBuffer') -> None: """Move the data from the given buffer to self.""" i = begin = buffer._index % len(buffer) while True: @@ -178,8 +186,15 @@ class ReplayBuffer(object): if i == begin: break - def add(self, obs, act, rew, done, obs_next=None, info={}, policy={}, - **kwargs): + def add(self, + obs: Union[dict, np.ndarray], + act: Union[np.ndarray, float], + rew: float, + done: bool, + obs_next: Optional[Union[dict, np.ndarray]] = None, + info: Optional[dict] = {}, + policy: Optional[Union[dict, Batch]] = {}, + **kwargs) -> None: """Add a batch of data into replay buffer.""" assert isinstance(info, dict), \ 'You should return a dict in the last argument of env.step().' 
@@ -197,11 +212,11 @@ class ReplayBuffer(object): else: self._size = self._index = self._index + 1 - def reset(self): + def reset(self) -> None: """Clear all the data in replay buffer.""" self._index = self._size = 0 - def sample(self, batch_size): + def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]: """Get a random sample from buffer with size equal to batch_size. \ Return all the data in the buffer if batch_size is ``0``. @@ -216,7 +231,8 @@ class ReplayBuffer(object): ]) return self[indice], indice - def get(self, indice, key, stack_num=None): + def get(self, indice: Union[slice, np.ndarray], key: str, + stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]: """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is indice. The stack_num (here equals to 4) is given from buffer initialization procedure. @@ -275,7 +291,7 @@ class ReplayBuffer(object): stack = np.stack(stack, axis=1) return stack - def __getitem__(self, index): + def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch: """Return a data batch: self[index]. If stack_num is set to be > 0, return the stacked obs and obs_next with shape [batch, len, ...]. """ @@ -302,17 +318,21 @@ class ListReplayBuffer(ReplayBuffer): detailed explanation. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(size=0, ignore_obs_next=False, **kwargs) - def _add_to_buffer(self, name, inst): + def _add_to_buffer( + self, name: str, + inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None: if inst is None: return if self.__dict__.get(name, None) is None: self.__dict__[name] = [] + if name == 'info': + inst = deepcopy(inst) self.__dict__[name].append(inst) - def reset(self): + def reset(self) -> None: self._index = self._size = 0 for k in list(self.__dict__): if isinstance(self.__dict__[k], list): @@ -332,8 +352,8 @@ class PrioritizedReplayBuffer(ReplayBuffer): detailed explanation. """ - def __init__(self, size, alpha: float, beta: float, - mode: str = 'weight', **kwargs): + def __init__(self, size: int, alpha: float, beta: float, + mode: Optional[str] = 'weight', **kwargs) -> None: if mode != 'weight': raise NotImplementedError super().__init__(size, **kwargs) @@ -344,17 +364,27 @@ class PrioritizedReplayBuffer(ReplayBuffer): self._amortization_freq = 50 self._amortization_counter = 0 - def add(self, obs, act, rew, done, obs_next=0, info={}, policy={}, - weight=1.0): + def add(self, + obs: Union[dict, np.ndarray], + act: Union[np.ndarray, float], + rew: float, + done: bool, + obs_next: Optional[Union[dict, np.ndarray]] = None, + info: Optional[dict] = {}, + policy: Optional[Union[dict, Batch]] = {}, + weight: Optional[float] = 1.0, + **kwargs) -> None: """Add a batch of data into replay buffer.""" - self._weight_sum += np.abs(weight)**self._alpha - \ + self._weight_sum += np.abs(weight) ** self._alpha - \ self.weight[self._index] # we have to sacrifice some convenience for speed :( self._add_to_buffer('weight', np.abs(weight) ** self._alpha) super().add(obs, act, rew, done, obs_next, info, policy) self._check_weight_sum() - def sample(self, batch_size: int = 0, importance_sample: bool = True): + def sample(self, batch_size: Optional[int] = 0, + importance_sample: Optional[bool] = True + ) -> Tuple[Batch, np.ndarray]: """Get a random sample from buffer with priority probability. \ Return all the data in the buffer if batch_size is ``0``. 
@@ -388,11 +418,12 @@ class PrioritizedReplayBuffer(ReplayBuffer): self._check_weight_sum() return batch, indice - def reset(self): + def reset(self) -> None: self._amortization_counter = 0 super().reset() - def update_weight(self, indice, new_weight: np.ndarray): + def update_weight(self, indice: Union[slice, np.ndarray], + new_weight: np.ndarray) -> None: """Update priority weight by indice in this buffer. :param np.ndarray indice: indice you want to update weight @@ -402,7 +433,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): - self.weight[indice].sum() self.weight[indice] = np.power(np.abs(new_weight), self._alpha) - def __getitem__(self, index): + def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch: return Batch( obs=self.get(index, 'obs'), act=self.act[index], @@ -415,7 +446,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): policy=self.get(index, 'policy'), ) - def _check_weight_sum(self): + def _check_weight_sum(self) -> None: # keep an accurate _weight_sum self._amortization_counter += 1 if self._amortization_counter % self._amortization_freq == 0: diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index bcb920f..ce6a301 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1,10 +1,13 @@ +import gym import time import torch import warnings import numpy as np +from typing import Any, Dict, List, Union, Optional, Callable from tianshou.utils import MovAvg from tianshou.env import BaseVectorEnv +from tianshou.policy import BasePolicy from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer @@ -77,8 +80,14 @@ class Collector(object): Please make sure the given environment has a time limitation. """ - def __init__(self, policy, env, buffer=None, preprocess_fn=None, - stat_size=100, **kwargs): + def __init__(self, + policy: BasePolicy, + env: Union[gym.Env, BaseVectorEnv], + buffer: Optional[Union[ReplayBuffer, List[ReplayBuffer]]] + = None, + preprocess_fn: Callable[[Any], Union[dict, Batch]] = None, + stat_size: Optional[int] = 100, + **kwargs) -> None: super().__init__() self.env = env self.env_num = 1 @@ -112,7 +121,7 @@ class Collector(object): self.stat_size = stat_size self.reset() - def reset(self): + def reset(self) -> None: """Reset all related variables in the collector.""" self.reset_env() self.reset_buffer() @@ -124,7 +133,7 @@ class Collector(object): self.collect_episode = 0 self.collect_time = 0 - def reset_buffer(self): + def reset_buffer(self) -> None: """Reset the main data buffer.""" if self._multi_buf: for b in self.buffer: @@ -133,11 +142,11 @@ class Collector(object): if self.buffer is not None: self.buffer.reset() - def get_env_num(self): + def get_env_num(self) -> int: """Return the number of environments the collector have.""" return self.env_num - def reset_env(self): + def reset_env(self) -> None: """Reset all of the environment(s)' states and reset all of the cache buffers (if need). 
""" @@ -155,36 +164,36 @@ class Collector(object): for b in self._cached_buf: b.reset() - def seed(self, seed=None): + def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: """Reset all the seed(s) of the given environment(s).""" if hasattr(self.env, 'seed'): return self.env.seed(seed) - def render(self, **kwargs): + def render(self, **kwargs) -> None: """Render all the environment(s).""" if hasattr(self.env, 'render'): return self.env.render(**kwargs) - def close(self): + def close(self) -> None: """Close the environment(s).""" if hasattr(self.env, 'close'): self.env.close() - def _make_batch(self, data): + def _make_batch(self, data: Any) -> Union[Any, np.ndarray]: """Return [data].""" if isinstance(data, np.ndarray): return data[None] else: return np.array([data]) - def _reset_state(self, id): + def _reset_state(self, id: Union[int, List[int]]) -> None: """Reset self.state[id].""" if self.state is None: return if isinstance(self.state, list): self.state[id] = None - elif isinstance(self.state, dict): - for k in self.state: + elif isinstance(self.state, dict) or isinstance(self.state, Batch): + for k in self.state.keys(): if isinstance(self.state[k], list): self.state[k][id] = None elif isinstance(self.state[k], torch.Tensor) or \ @@ -194,7 +203,8 @@ class Collector(object): isinstance(self.state, np.ndarray): self.state[id] = 0 - def _to_numpy(self, x): + def _to_numpy(self, x: Union[ + torch.Tensor, dict, Batch, np.ndarray]) -> None: """Return an object without torch.Tensor.""" if isinstance(x, torch.Tensor): return x.cpu().numpy() @@ -208,7 +218,12 @@ class Collector(object): return x return x - def collect(self, n_step=0, n_episode=0, render=None, log_fn=None): + def collect(self, + n_step: Optional[int] = 0, + n_episode: Optional[Union[int, List[int]]] = 0, + render: Optional[float] = None, + log_fn: Optional[Callable[[dict], None]] = None + ) -> Dict[str, float]: """Collect a specified number of step or episode. :param int n_step: how many steps you want to collect. @@ -375,7 +390,7 @@ class Collector(object): 'len': length_sum / n_episode, } - def sample(self, batch_size): + def sample(self, batch_size: int) -> Batch: """Sample a data batch from the internal replay buffer. It will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data. diff --git a/tianshou/env/vecenv.py b/tianshou/env/vecenv.py index 5f7225a..6145cc7 100644 --- a/tianshou/env/vecenv.py +++ b/tianshou/env/vecenv.py @@ -2,6 +2,7 @@ import gym import numpy as np from abc import ABC, abstractmethod from multiprocessing import Process, Pipe +from typing import List, Tuple, Union, Optional, Callable try: import ray @@ -36,16 +37,16 @@ class BaseVectorEnv(ABC, gym.Wrapper): envs.close() # close all environments """ - def __init__(self, env_fns): + def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None: self._env_fns = env_fns self.env_num = len(env_fns) - def __len__(self): + def __len__(self) -> int: """Return len(self), which is the number of environments.""" return self.env_num @abstractmethod - def reset(self, id=None): + def reset(self, id: Optional[Union[int, List[int]]] = None): """Reset the state of all the environments and return initial observations if id is ``None``, otherwise reset the specific environments with given id, either an int or a list. 
@@ -53,7 +54,8 @@ class BaseVectorEnv(ABC, gym.Wrapper): pass @abstractmethod - def step(self, action): + def step(self, action: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Run one timestep of all the environments’ dynamics. When the end of episode is reached, you are responsible for calling reset(id) to reset this environment’s state. @@ -76,19 +78,19 @@ class BaseVectorEnv(ABC, gym.Wrapper): pass @abstractmethod - def seed(self, seed=None): + def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: """Set the seed for all environments. Accept ``None``, an int (which will extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list. """ pass @abstractmethod - def render(self, **kwargs): + def render(self, **kwargs) -> None: """Render all of the environments.""" pass @abstractmethod - def close(self): + def close(self) -> None: """Close all of the environments.""" pass @@ -102,11 +104,11 @@ class VectorEnv(BaseVectorEnv): explanation. """ - def __init__(self, env_fns): + def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None: super().__init__(env_fns) self.envs = [_() for _ in env_fns] - def reset(self, id=None): + def reset(self, id: Optional[Union[int, List[int]]] = None) -> None: if id is None: self._obs = np.stack([e.reset() for e in self.envs]) else: @@ -116,7 +118,8 @@ class VectorEnv(BaseVectorEnv): self._obs[i] = self.envs[i].reset() return self._obs - def step(self, action): + def step(self, action: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: assert len(action) == self.env_num result = [e.step(a) for e, a in zip(self.envs, action)] self._obs, self._rew, self._done, self._info = zip(*result) @@ -126,7 +129,7 @@ class VectorEnv(BaseVectorEnv): self._info = np.stack(self._info) return self._obs, self._rew, self._done, self._info - def seed(self, seed=None): + def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: if np.isscalar(seed): seed = [seed + _ for _ in range(self.env_num)] elif seed is None: @@ -137,14 +140,14 @@ class VectorEnv(BaseVectorEnv): result.append(e.seed(s)) return result - def render(self, **kwargs): + def render(self, **kwargs) -> None: result = [] for e in self.envs: if hasattr(e, 'render'): result.append(e.render(**kwargs)) return result - def close(self): + def close(self) -> None: return [e.close() for e in self.envs] @@ -182,7 +185,7 @@ class SubprocVectorEnv(BaseVectorEnv): explanation. 
""" - def __init__(self, env_fns): + def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None: super().__init__(env_fns) self.closed = False self.parent_remote, self.child_remote = \ @@ -198,7 +201,8 @@ class SubprocVectorEnv(BaseVectorEnv): for c in self.child_remote: c.close() - def step(self, action): + def step(self, action: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: assert len(action) == self.env_num for p, a in zip(self.parent_remote, action): p.send(['step', a]) @@ -210,7 +214,7 @@ class SubprocVectorEnv(BaseVectorEnv): self._info = np.stack(self._info) return self._obs, self._rew, self._done, self._info - def reset(self, id=None): + def reset(self, id: Optional[Union[int, List[int]]] = None) -> None: if id is None: for p in self.parent_remote: p.send(['reset', None]) @@ -225,7 +229,7 @@ class SubprocVectorEnv(BaseVectorEnv): self._obs[i] = self.parent_remote[i].recv() return self._obs - def seed(self, seed=None): + def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: if np.isscalar(seed): seed = [seed + _ for _ in range(self.env_num)] elif seed is None: @@ -234,12 +238,12 @@ class SubprocVectorEnv(BaseVectorEnv): p.send(['seed', s]) return [p.recv() for p in self.parent_remote] - def render(self, **kwargs): + def render(self, **kwargs) -> None: for p in self.parent_remote: p.send(['render', kwargs]) return [p.recv() for p in self.parent_remote] - def close(self): + def close(self) -> None: if self.closed: return for p in self.parent_remote: @@ -263,7 +267,7 @@ class RayVectorEnv(BaseVectorEnv): explanation. """ - def __init__(self, env_fns): + def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None: super().__init__(env_fns) try: if not ray.is_initialized(): @@ -275,7 +279,8 @@ class RayVectorEnv(BaseVectorEnv): ray.remote(gym.Wrapper).options(num_cpus=0).remote(e()) for e in env_fns] - def step(self, action): + def step(self, action: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: assert len(action) == self.env_num result = ray.get([e.step.remote(a) for e, a in zip(self.envs, action)]) self._obs, self._rew, self._done, self._info = zip(*result) @@ -285,7 +290,7 @@ class RayVectorEnv(BaseVectorEnv): self._info = np.stack(self._info) return self._obs, self._rew, self._done, self._info - def reset(self, id=None): + def reset(self, id: Optional[Union[int, List[int]]] = None) -> None: if id is None: result_obj = [e.reset.remote() for e in self.envs] self._obs = np.stack(ray.get(result_obj)) @@ -299,7 +304,7 @@ class RayVectorEnv(BaseVectorEnv): self._obs[i] = ray.get(result_obj[_]) return self._obs - def seed(self, seed=None): + def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: if not hasattr(self.envs[0], 'seed'): return if np.isscalar(seed): @@ -308,10 +313,10 @@ class RayVectorEnv(BaseVectorEnv): seed = [seed] * self.env_num return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)]) - def render(self, **kwargs): + def render(self, **kwargs) -> None: if not hasattr(self.envs[0], 'render'): return return ray.get([e.render.remote(**kwargs) for e in self.envs]) - def close(self): + def close(self) -> None: return ray.get([e.close.remote() for e in self.envs]) diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py index 25fee39..385388d 100644 --- a/tianshou/exploration/random.py +++ b/tianshou/exploration/random.py @@ -1,4 +1,5 @@ import numpy as np +from typing import Union, Optional class OUNoise(object): @@ -17,13 +18,18 @@ class OUNoise(object): 
Ornstein-Uhlenbeck process. """ - def __init__(self, sigma=0.3, theta=0.15, dt=1e-2, x0=None): + def __init__(self, + sigma: Optional[float] = 0.3, + theta: Optional[float] = 0.15, + dt: Optional[float] = 1e-2, + x0: Optional[Union[float, np.ndarray]] = None + ) -> None: self.alpha = theta * dt self.beta = sigma * np.sqrt(dt) self.x0 = x0 self.reset() - def __call__(self, size, mu=.1): + def __call__(self, size: tuple, mu: Optional[float] = .1) -> np.ndarray: """Generate new noise. Return a ``numpy.ndarray`` which size is equal to ``size``. """ @@ -33,6 +39,6 @@ class OUNoise(object): self.x = self.x + self.alpha * (mu - self.x) + r return self.x - def reset(self): + def reset(self) -> None: """Reset to the initial state.""" self.x = None diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index ff0b599..481f462 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -1,6 +1,10 @@ +import torch import numpy as np from torch import nn from abc import ABC, abstractmethod +from typing import Dict, List, Union, Optional + +from tianshou.data import Batch, ReplayBuffer class BasePolicy(ABC, nn.Module): @@ -39,17 +43,20 @@ class BasePolicy(ABC, nn.Module): policy.load_state_dict(torch.load('policy.pth')) """ - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__() - def process_fn(self, batch, buffer, indice): + def process_fn(self, batch: Batch, buffer: ReplayBuffer, + indice: np.ndarray) -> Batch: """Pre-process the data from the provided replay buffer. Check out :ref:`policy_concept` for more information. """ return batch @abstractmethod - def forward(self, batch, state=None, **kwargs): + def forward(self, batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs) -> Batch: """Compute action over the given batch data. :return: A :class:`~tianshou.data.Batch` which MUST have the following\ @@ -80,7 +87,8 @@ class BasePolicy(ABC, nn.Module): pass @abstractmethod - def learn(self, batch, **kwargs): + def learn(self, batch: Batch, **kwargs + ) -> Dict[str, Union[float, List[float]]]: """Update policy with a given batch of data. :return: A dict which includes loss and its corresponding label. @@ -88,8 +96,11 @@ class BasePolicy(ABC, nn.Module): pass @staticmethod - def compute_episodic_return(batch, v_s_=None, - gamma=0.99, gae_lambda=0.95): + def compute_episodic_return( + batch: Batch, + v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None, + gamma: Optional[float] = 0.99, + gae_lambda: Optional[float] = 0.95) -> Batch: """Compute returns over given full-length episodes, including the implementation of Generalized Advantage Estimation (arXiv:1506.02438). diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 10bcef6..9e4ec97 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -1,5 +1,7 @@ import torch +import numpy as np import torch.nn.functional as F +from typing import Dict, Union, Optional from tianshou.data import Batch from tianshou.policy import BasePolicy @@ -19,7 +21,9 @@ class ImitationPolicy(BasePolicy): Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed explanation. 
""" - def __init__(self, model, optim, mode='continuous'): + + def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer, + mode: Optional[str] = 'continuous', **kwargs) -> None: super().__init__() self.model = model self.optim = optim @@ -27,7 +31,10 @@ class ImitationPolicy(BasePolicy): f'Mode {mode} is not in ["continuous", "discrete"]' self.mode = mode - def forward(self, batch, state=None): + def forward(self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs) -> Batch: logits, h = self.model(batch.obs, state=state, info=batch.info) if self.mode == 'discrete': a = logits.max(dim=1)[1] @@ -35,7 +42,7 @@ class ImitationPolicy(BasePolicy): a = logits return Batch(logits=logits, act=a, state=h) - def learn(self, batch, **kwargs): + def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: self.optim.zero_grad() if self.mode == 'continuous': a = self(batch).act diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 0ed57eb..3a821b0 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -2,9 +2,10 @@ import torch import numpy as np from torch import nn import torch.nn.functional as F +from typing import Dict, List, Union, Optional -from tianshou.data import Batch from tianshou.policy import PGPolicy +from tianshou.data import Batch, ReplayBuffer class A2CPolicy(PGPolicy): @@ -31,11 +32,19 @@ class A2CPolicy(PGPolicy): explanation. """ - def __init__(self, actor, critic, optim, - dist_fn=torch.distributions.Categorical, - discount_factor=0.99, vf_coef=.5, ent_coef=.01, - max_grad_norm=None, gae_lambda=0.95, - reward_normalization=False, **kwargs): + def __init__(self, + actor: torch.nn.Module, + critic: torch.nn.Module, + optim: torch.optim.Optimizer, + dist_fn: Optional[torch.distributions.Distribution] + = torch.distributions.Categorical, + discount_factor: Optional[float] = 0.99, + vf_coef: Optional[float] = .5, + ent_coef: Optional[float] = .01, + max_grad_norm: Optional[float] = None, + gae_lambda: Optional[float] = 0.95, + reward_normalization: Optional[bool] = False, + **kwargs) -> None: super().__init__(None, optim, dist_fn, discount_factor, **kwargs) self.actor = actor self.critic = critic @@ -48,7 +57,8 @@ class A2CPolicy(PGPolicy): self._rew_norm = reward_normalization self.__eps = np.finfo(np.float32).eps.item() - def process_fn(self, batch, buffer, indice): + def process_fn(self, batch: Batch, buffer: ReplayBuffer, + indice: np.ndarray) -> Batch: if self._lambda in [0, 1]: return self.compute_episodic_return( batch, None, gamma=self._gamma, gae_lambda=self._lambda) @@ -60,7 +70,9 @@ class A2CPolicy(PGPolicy): return self.compute_episodic_return( batch, v_, gamma=self._gamma, gae_lambda=self._lambda) - def forward(self, batch, state=None, **kwargs): + def forward(self, batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs) -> Batch: """Compute action over the given batch data. 
:return: A :class:`~tianshou.data.Batch` which has 4 keys: @@ -83,7 +95,8 @@ class A2CPolicy(PGPolicy): act = dist.sample() return Batch(logits=logits, act=act, state=h, dist=dist) - def learn(self, batch, batch_size=None, repeat=1, **kwargs): + def learn(self, batch: Batch, batch_size: int, repeat: int, + **kwargs) -> Dict[str, List[float]]: self._batch = batch_size r = batch.returns if self._rew_norm and r.std() > self.__eps: diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 503b39d..60c147e 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -2,10 +2,11 @@ import torch import numpy as np from copy import deepcopy import torch.nn.functional as F +from typing import Dict, Tuple, Union, Optional -from tianshou.data import Batch from tianshou.policy import BasePolicy # from tianshou.exploration import OUNoise +from tianshou.data import Batch, ReplayBuffer class DDPGPolicy(BasePolicy): @@ -23,7 +24,7 @@ class DDPGPolicy(BasePolicy): :param float exploration_noise: the noise intensity, add to the action, defaults to 0.1. :param action_range: the action range (minimum, maximum). - :type action_range: [float, float] + :type action_range: (float, float) :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to ``False``. :param bool ignore_done: ignore the done flag while training the policy, @@ -35,10 +36,18 @@ class DDPGPolicy(BasePolicy): explanation. """ - def __init__(self, actor, actor_optim, critic, critic_optim, - tau=0.005, gamma=0.99, exploration_noise=0.1, - action_range=None, reward_normalization=False, - ignore_done=False, **kwargs): + def __init__(self, + actor: torch.nn.Module, + actor_optim: torch.optim.Optimizer, + critic: torch.nn.Module, + critic_optim: torch.optim.Optimizer, + tau: Optional[float] = 0.005, + gamma: Optional[float] = 0.99, + exploration_noise: Optional[float] = 0.1, + action_range: Optional[Tuple[float, float]] = None, + reward_normalization: Optional[bool] = False, + ignore_done: Optional[bool] = False, + **kwargs) -> None: super().__init__(**kwargs) if actor is not None: self.actor, self.actor_old = actor, deepcopy(actor) @@ -64,23 +73,23 @@ class DDPGPolicy(BasePolicy): self._rew_norm = reward_normalization self.__eps = np.finfo(np.float32).eps.item() - def set_eps(self, eps): + def set_eps(self, eps: float) -> None: """Set the eps for exploration.""" self._eps = eps - def train(self): + def train(self) -> None: """Set the module in training mode, except for the target network.""" self.training = True self.actor.train() self.critic.train() - def eval(self): + def eval(self) -> None: """Set the module in evaluation mode, except for the target network.""" self.training = False self.actor.eval() self.critic.eval() - def sync_weight(self): + def sync_weight(self) -> None: """Soft-update the weight for the target network.""" for o, n in zip(self.actor_old.parameters(), self.actor.parameters()): o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) @@ -88,7 +97,8 @@ class DDPGPolicy(BasePolicy): self.critic_old.parameters(), self.critic.parameters()): o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) - def process_fn(self, batch, buffer, indice): + def process_fn(self, batch: Batch, buffer: ReplayBuffer, + indice: np.ndarray) -> Batch: if self._rew_norm: bfr = buffer.rew[:min(len(buffer), 1000)] # avoid large buffer mean, std = bfr.mean(), bfr.std() @@ -98,8 +108,12 @@ class DDPGPolicy(BasePolicy): batch.done = batch.done * 0. 
return batch - def forward(self, batch, state=None, - model='actor', input='obs', eps=None, **kwargs): + def forward(self, batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + model: Optional[str] = 'actor', + input: Optional[str] = 'obs', + eps: Optional[float] = None, + **kwargs) -> Batch: """Compute action over the given batch data. :param float eps: in [0, 1], for exploration use. @@ -129,7 +143,7 @@ class DDPGPolicy(BasePolicy): logits = logits.clamp(self._range[0], self._range[1]) return Batch(act=logits, state=h) - def learn(self, batch, **kwargs): + def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: with torch.no_grad(): target_q = self.critic_old(batch.obs_next, self( batch, model='actor_old', input='obs_next', eps=0).act) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 61424fc..611d9fa 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -2,9 +2,10 @@ import torch import numpy as np from copy import deepcopy import torch.nn.functional as F +from typing import Dict, Union, Optional from tianshou.policy import BasePolicy -from tianshou.data import Batch, PrioritizedReplayBuffer +from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer class DQNPolicy(BasePolicy): @@ -25,8 +26,13 @@ class DQNPolicy(BasePolicy): explanation. """ - def __init__(self, model, optim, discount_factor=0.99, - estimation_step=1, target_update_freq=0, **kwargs): + def __init__(self, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + discount_factor: Optional[float] = 0.99, + estimation_step: Optional[int] = 1, + target_update_freq: Optional[int] = 0, + **kwargs) -> None: super().__init__(**kwargs) self.model = model self.optim = optim @@ -42,25 +48,26 @@ class DQNPolicy(BasePolicy): self.model_old = deepcopy(self.model) self.model_old.eval() - def set_eps(self, eps): + def set_eps(self, eps: float) -> None: """Set the eps for epsilon-greedy exploration.""" self.eps = eps - def train(self): + def train(self) -> None: """Set the module in training mode, except for the target network.""" self.training = True self.model.train() - def eval(self): + def eval(self) -> None: """Set the module in evaluation mode, except for the target network.""" self.training = False self.model.eval() - def sync_weight(self): + def sync_weight(self) -> None: """Synchronize the weight for the target network.""" self.model_old.load_state_dict(self.model.state_dict()) - def process_fn(self, batch, buffer, indice): + def process_fn(self, batch: Batch, buffer: ReplayBuffer, + indice: np.ndarray) -> Batch: r"""Compute the n-step return for Q-learning targets: .. math:: @@ -115,8 +122,12 @@ class DQNPolicy(BasePolicy): batch.loss += loss return batch - def forward(self, batch, state=None, - model='model', input='obs', eps=None, **kwargs): + def forward(self, batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + model: Optional[str] = 'model', + input: Optional[str] = 'obs', + eps: Optional[float] = None, + **kwargs) -> Batch: """Compute action over the given batch data. :param float eps: in [0, 1], for epsilon-greedy exploration method. 
@@ -144,7 +155,7 @@ class DQNPolicy(BasePolicy): act[i] = np.random.randint(q.shape[1]) return Batch(logits=q, act=act, state=h) - def learn(self, batch, **kwargs): + def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: if self._target and self._cnt % self._freq == 0: self.sync_weight() self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 4c06e39..1df44bf 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -1,8 +1,9 @@ import torch import numpy as np +from typing import Dict, List, Union, Optional -from tianshou.data import Batch from tianshou.policy import BasePolicy +from tianshou.data import Batch, ReplayBuffer class PGPolicy(BasePolicy): @@ -20,8 +21,14 @@ class PGPolicy(BasePolicy): explanation. """ - def __init__(self, model, optim, dist_fn=torch.distributions.Categorical, - discount_factor=0.99, reward_normalization=False, **kwargs): + def __init__(self, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + dist_fn: Optional[torch.distributions.Distribution] + = torch.distributions.Categorical, + discount_factor: Optional[float] = 0.99, + reward_normalization: Optional[bool] = False, + **kwargs) -> None: super().__init__(**kwargs) self.model = model self.optim = optim @@ -31,7 +38,8 @@ class PGPolicy(BasePolicy): self._rew_norm = reward_normalization self.__eps = np.finfo(np.float32).eps.item() - def process_fn(self, batch, buffer, indice): + def process_fn(self, batch: Batch, buffer: ReplayBuffer, + indice: np.ndarray) -> Batch: r"""Compute the discounted returns for each frame: .. math:: @@ -46,7 +54,9 @@ class PGPolicy(BasePolicy): return self.compute_episodic_return( batch, gamma=self._gamma, gae_lambda=1.) - def forward(self, batch, state=None, **kwargs): + def forward(self, batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs) -> Batch: """Compute action over the given batch data. :return: A :class:`~tianshou.data.Batch` which has 4 keys: @@ -69,7 +79,8 @@ class PGPolicy(BasePolicy): act = dist.sample() return Batch(logits=logits, act=act, state=h, dist=dist) - def learn(self, batch, batch_size=None, repeat=1, **kwargs): + def learn(self, batch: Batch, batch_size: int, repeat: int, + **kwargs) -> Dict[str, List[float]]: losses = [] r = batch.returns if self._rew_norm and r.std() > self.__eps: diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 25cfb28..9f0a4ab 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -1,9 +1,10 @@ import torch import numpy as np from torch import nn +from typing import Dict, List, Tuple, Union, Optional -from tianshou.data import Batch from tianshou.policy import PGPolicy +from tianshou.data import Batch, ReplayBuffer class PPOPolicy(PGPolicy): @@ -23,7 +24,7 @@ class PPOPolicy(PGPolicy): :param float vf_coef: weight for value loss, defaults to 0.5. :param float ent_coef: weight for entropy loss, defaults to 0.01. :param action_range: the action range (minimum, maximum). - :type action_range: [float, float] + :type action_range: (float, float) :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation, defaults to 0.95. :param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5, @@ -40,11 +41,22 @@ class PPOPolicy(PGPolicy): explanation. 
""" - def __init__(self, actor, critic, optim, dist_fn, - discount_factor=0.99, max_grad_norm=.5, eps_clip=.2, - vf_coef=.5, ent_coef=.01, action_range=None, gae_lambda=0.95, - dual_clip=5., value_clip=True, reward_normalization=True, - **kwargs): + def __init__(self, + actor: torch.nn.Module, + critic: torch.nn.Module, + optim: torch.optim.Optimizer, + dist_fn: torch.distributions.Distribution, + discount_factor: Optional[float] = 0.99, + max_grad_norm: Optional[float] = None, + eps_clip: Optional[float] = .2, + vf_coef: Optional[float] = .5, + ent_coef: Optional[float] = .01, + action_range: Optional[Tuple[float, float]] = None, + gae_lambda: Optional[float] = 0.95, + dual_clip: Optional[float] = 5., + value_clip: Optional[bool] = True, + reward_normalization: Optional[bool] = True, + **kwargs) -> None: super().__init__(None, None, dist_fn, discount_factor, **kwargs) self._max_grad_norm = max_grad_norm self._eps_clip = eps_clip @@ -64,7 +76,8 @@ class PPOPolicy(PGPolicy): self._rew_norm = reward_normalization self.__eps = np.finfo(np.float32).eps.item() - def process_fn(self, batch, buffer, indice): + def process_fn(self, batch: Batch, buffer: ReplayBuffer, + indice: np.ndarray) -> Batch: if self._rew_norm: mean, std = batch.rew.mean(), batch.rew.std() if std > self.__eps: @@ -80,7 +93,9 @@ class PPOPolicy(PGPolicy): return self.compute_episodic_return( batch, v_, gamma=self._gamma, gae_lambda=self._lambda) - def forward(self, batch, state=None, **kwargs): + def forward(self, batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + **kwargs) -> Batch: """Compute action over the given batch data. :return: A :class:`~tianshou.data.Batch` which has 4 keys: @@ -105,7 +120,8 @@ class PPOPolicy(PGPolicy): act = act.clamp(self._range[0], self._range[1]) return Batch(logits=logits, act=act, state=h, dist=dist) - def learn(self, batch, batch_size=None, repeat=1, **kwargs): + def learn(self, batch: Batch, batch_size: int, repeat: int, + **kwargs) -> Dict[str, List[float]]: self._batch = batch_size losses, clip_losses, vf_losses, ent_losses = [], [], [], [] v = [] diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 3a55fba..0bb78d4 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -2,6 +2,7 @@ import torch import numpy as np from copy import deepcopy import torch.nn.functional as F +from typing import Dict, Tuple, Union, Optional from tianshou.data import Batch from tianshou.policy import DDPGPolicy @@ -28,7 +29,7 @@ class SACPolicy(DDPGPolicy): defaults to 0.1. :param float alpha: entropy regularization coefficient, default to 0.2. :param action_range: the action range (minimum, maximum). - :type action_range: [float, float] + :type action_range: (float, float) :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to ``False``. :param bool ignore_done: ignore the done flag while training the policy, @@ -40,10 +41,20 @@ class SACPolicy(DDPGPolicy): explanation. 
""" - def __init__(self, actor, actor_optim, critic1, critic1_optim, - critic2, critic2_optim, tau=0.005, gamma=0.99, - alpha=0.2, action_range=None, reward_normalization=False, - ignore_done=False, **kwargs): + def __init__(self, + actor: torch.nn.Module, + actor_optim: torch.optim.Optimizer, + critic1: torch.nn.Module, + critic1_optim: torch.optim.Optimizer, + critic2: torch.nn.Module, + critic2_optim: torch.optim.Optimizer, + tau: Optional[float] = 0.005, + gamma: Optional[float] = 0.99, + alpha: Optional[float] = 0.2, + action_range: Optional[Tuple[float, float]] = None, + reward_normalization: Optional[bool] = False, + ignore_done: Optional[bool] = False, + **kwargs) -> None: super().__init__(None, None, None, None, tau, gamma, 0, action_range, reward_normalization, ignore_done, **kwargs) @@ -57,19 +68,19 @@ class SACPolicy(DDPGPolicy): self._alpha = alpha self.__eps = np.finfo(np.float32).eps.item() - def train(self): + def train(self) -> None: self.training = True self.actor.train() self.critic1.train() self.critic2.train() - def eval(self): + def eval(self) -> None: self.training = False self.actor.eval() self.critic1.eval() self.critic2.eval() - def sync_weight(self): + def sync_weight(self) -> None: for o, n in zip( self.critic1_old.parameters(), self.critic1.parameters()): o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) @@ -77,7 +88,9 @@ class SACPolicy(DDPGPolicy): self.critic2_old.parameters(), self.critic2.parameters()): o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) - def forward(self, batch, state=None, input='obs', **kwargs): + def forward(self, batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + input: Optional[str] = 'obs', **kwargs) -> Batch: obs = getattr(batch, input) logits, h = self.actor(obs, state=state, info=batch.info) assert isinstance(logits, tuple) @@ -92,7 +105,7 @@ class SACPolicy(DDPGPolicy): return Batch( logits=logits, act=act, state=h, dist=dist, log_prob=log_prob) - def learn(self, batch, **kwargs): + def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: with torch.no_grad(): obs_next_result = self(batch, input='obs_next') a_ = obs_next_result.act diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 349940d..807c661 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -1,7 +1,9 @@ import torch from copy import deepcopy import torch.nn.functional as F +from typing import Dict, Tuple, Optional +from tianshou.data import Batch from tianshou.policy import DDPGPolicy @@ -32,7 +34,7 @@ class TD3Policy(DDPGPolicy): :param float noise_clip: the clipping range used in updating policy network, default to 0.5. :param action_range: the action range (minimum, maximum). - :type action_range: [float, float] + :type action_range: (float, float) :param bool reward_normalization: normalize the reward to Normal(0, 1), defaults to ``False``. :param bool ignore_done: ignore the done flag while training the policy, @@ -44,11 +46,23 @@ class TD3Policy(DDPGPolicy): explanation. 
""" - def __init__(self, actor, actor_optim, critic1, critic1_optim, - critic2, critic2_optim, tau=0.005, gamma=0.99, - exploration_noise=0.1, policy_noise=0.2, update_actor_freq=2, - noise_clip=0.5, action_range=None, - reward_normalization=False, ignore_done=False, **kwargs): + def __init__(self, + actor: torch.nn.Module, + actor_optim: torch.optim.Optimizer, + critic1: torch.nn.Module, + critic1_optim: torch.optim.Optimizer, + critic2: torch.nn.Module, + critic2_optim: torch.optim.Optimizer, + tau: Optional[float] = 0.005, + gamma: Optional[float] = 0.99, + exploration_noise: Optional[float] = 0.1, + policy_noise: Optional[float] = 0.2, + update_actor_freq: Optional[int] = 2, + noise_clip: Optional[float] = 0.5, + action_range: Optional[Tuple[float, float]] = None, + reward_normalization: Optional[bool] = False, + ignore_done: Optional[bool] = False, + **kwargs) -> None: super().__init__(actor, actor_optim, None, None, tau, gamma, exploration_noise, action_range, reward_normalization, ignore_done, **kwargs) @@ -64,19 +78,19 @@ class TD3Policy(DDPGPolicy): self._cnt = 0 self._last = 0 - def train(self): + def train(self) -> None: self.training = True self.actor.train() self.critic1.train() self.critic2.train() - def eval(self): + def eval(self) -> None: self.training = False self.actor.eval() self.critic1.eval() self.critic2.eval() - def sync_weight(self): + def sync_weight(self) -> None: for o, n in zip(self.actor_old.parameters(), self.actor.parameters()): o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) for o, n in zip( @@ -86,7 +100,7 @@ class TD3Policy(DDPGPolicy): self.critic2_old.parameters(), self.critic2.parameters()): o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau) - def learn(self, batch, **kwargs): + def learn(self, batch: Batch, **kwargs) -> Dict[str, float]: with torch.no_grad(): a_ = self(batch, model='actor_old', input='obs_next').act dev = a_.device diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index dfc0159..5dbdf41 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,16 +1,33 @@ import time import tqdm +from torch.utils.tensorboard import SummaryWriter +from typing import Dict, List, Union, Callable, Optional +from tianshou.data import Collector +from tianshou.policy import BasePolicy from tianshou.utils import tqdm_config, MovAvg from tianshou.trainer import test_episode, gather_info -def offpolicy_trainer(policy, train_collector, test_collector, max_epoch, - step_per_epoch, collect_per_step, episode_per_test, - batch_size, - train_fn=None, test_fn=None, stop_fn=None, save_fn=None, - log_fn=None, writer=None, log_interval=1, verbose=True, - **kwargs): +def offpolicy_trainer( + policy: BasePolicy, + train_collector: Collector, + test_collector: Collector, + max_epoch: int, + step_per_epoch: int, + collect_per_step: int, + episode_per_test: Union[int, List[int]], + batch_size: int, + train_fn: Optional[Callable[[int], None]] = None, + test_fn: Optional[Callable[[int], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + log_fn: Optional[Callable[[dict], None]] = None, + writer: Optional[SummaryWriter] = None, + log_interval: Optional[int] = 1, + verbose: Optional[bool] = True, + **kwargs +) -> Dict[str, Union[float, str]]: """A wrapper for off-policy trainer procedure. 
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 44218ec..072ce05 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,16 +1,34 @@ import time import tqdm +from torch.utils.tensorboard import SummaryWriter +from typing import Dict, List, Union, Callable, Optional +from tianshou.data import Collector +from tianshou.policy import BasePolicy from tianshou.utils import tqdm_config, MovAvg from tianshou.trainer import test_episode, gather_info -def onpolicy_trainer(policy, train_collector, test_collector, max_epoch, - step_per_epoch, collect_per_step, repeat_per_collect, - episode_per_test, batch_size, - train_fn=None, test_fn=None, stop_fn=None, save_fn=None, - log_fn=None, writer=None, log_interval=1, verbose=True, - **kwargs): +def onpolicy_trainer( + policy: BasePolicy, + train_collector: Collector, + test_collector: Collector, + max_epoch: int, + step_per_epoch: int, + collect_per_step: int, + repeat_per_collect: int, + episode_per_test: Union[int, List[int]], + batch_size: int, + train_fn: Optional[Callable[[int], None]] = None, + test_fn: Optional[Callable[[int], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + log_fn: Optional[Callable[[dict], None]] = None, + writer: Optional[SummaryWriter] = None, + log_interval: Optional[int] = 1, + verbose: Optional[bool] = True, + **kwargs +) -> Dict[str, Union[float, str]]: """A wrapper for on-policy trainer procedure. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 0e3b996..eb9bd32 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -1,8 +1,17 @@ import time import numpy as np +from typing import Dict, List, Union, Callable + +from tianshou.data import Collector +from tianshou.policy import BasePolicy -def test_episode(policy, collector, test_fn, epoch, n_episode): +def test_episode( + policy: BasePolicy, + collector: Collector, + test_fn: Callable[[int], None], + epoch: int, + n_episode: Union[int, List[int]]) -> Dict[str, float]: """A simple wrapper of testing policy in collector.""" collector.reset_env() collector.reset_buffer() @@ -17,7 +26,11 @@ def test_episode(policy, collector, test_fn, epoch, n_episode): return collector.collect(n_episode=n_episode) -def gather_info(start_time, train_c, test_c, best_reward): +def gather_info(start_time: float, + train_c: Collector, + test_c: Collector, + best_reward: float + ) -> Dict[str, Union[float, str]]: """A simple wrapper of gathering information from collectors. :return: A dictionary with the following keys: diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py index 3aefb23..c700f1a 100644 --- a/tianshou/utils/moving_average.py +++ b/tianshou/utils/moving_average.py @@ -1,5 +1,6 @@ import torch import numpy as np +from typing import Union, Optional class MovAvg(object): @@ -19,19 +20,20 @@ class MovAvg(object): >>> print(f'{stat.mean():.2f}±{stat.std():.2f}') 6.50±1.12 """ - def __init__(self, size=100): + + def __init__(self, size: Optional[int] = 100) -> None: super().__init__() self.size = size self.cache = [] self.banned = [np.inf, np.nan, -np.inf] - def add(self, x): + def add(self, x: Union[float, list, np.ndarray, torch.Tensor]) -> float: """Add a scalar into :class:`MovAvg`. 
You can add ``torch.Tensor`` with only one element, a python scalar, or a list of python scalar. """ if isinstance(x, torch.Tensor): x = x.item() - if isinstance(x, list): + if isinstance(x, list) or isinstance(x, np.ndarray): for _ in x: if _ not in self.banned: self.cache.append(_) @@ -41,17 +43,17 @@ class MovAvg(object): self.cache = self.cache[-self.size:] return self.get() - def get(self): + def get(self) -> float: """Get the average.""" if len(self.cache) == 0: return 0 return np.mean(self.cache) - def mean(self): + def mean(self) -> float: """Get the average. Same as :meth:`get`.""" return self.get() - def std(self): + def std(self) -> float: """Get the standard deviation.""" if len(self.cache) == 0: return 0
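Beyond the type annotations, the functional core of this patch is the new `save_fn` hook threaded through the trainers and the per-task log directory used in the example scripts. A minimal usage sketch, assuming `policy`, `train_collector`, `test_collector`, `stop_fn`, and `args` are constructed as in `examples/ant_v2_sac.py`:

```python
import os
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.trainer import offpolicy_trainer

# Logs and checkpoints now share one per-task directory, e.g. log/Ant-v2/sac.
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)

def save_fn(policy):
    # Invoked by the trainer to persist the current policy weights.
    torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

result = offpolicy_trainer(
    policy, train_collector, test_collector, args.epoch,
    args.step_per_epoch, args.collect_per_step, args.test_num,
    args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
```

Reloading later is the usual `policy.load_state_dict(torch.load(os.path.join(log_path, 'policy.pth')))`, as in the `BasePolicy` docstring; note that the old `task=args.task` argument to the trainer is dropped from the same call, as shown in the example hunk above.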