add type annotation

Trinkle23897 2020-05-12 11:31:47 +08:00
parent 075825325e
commit 9b26137cd2
21 changed files with 414 additions and 187 deletions

View File

@ -4,7 +4,7 @@
+ [ ] documentation request (i.e. "X is missing from the documentation.")
+ [ ] new feature request
- [ ] I have visited the [source website], and in particular read the [known issues]
- [ ] I have searched through the [issue categories] for duplicates
- [ ] I have searched through the [issue tracker] and [issue categories] for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
import tianshou, torch, sys
@ -14,3 +14,4 @@
[source website]: https://github.com/thu-ml/tianshou/
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
[issue categories]: https://github.com/thu-ml/tianshou/projects/2
[issue tracker]: https://github.com/thu-ml/tianshou/issues?q=

View File

@ -8,7 +8,7 @@
Less important but also useful:
- [ ] I have visited the [source website], and in particular read the [known issues]
- [ ] I have searched through the [issue categories] for duplicates
- [ ] I have searched through the [issue tracker] and [issue categories] for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
import tianshou, torch, sys
@ -18,3 +18,4 @@ Less important but also useful:
[source website]: https://github.com/thu-ml/tianshou
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
[issue categories]: https://github.com/thu-ml/tianshou/projects/2
[issue tracker]: https://github.com/thu-ml/tianshou/issues?q=

View File

@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@ -32,6 +33,7 @@ def get_args():
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', type=bool, default=True)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@ -73,14 +75,18 @@ def test_sac(args=get_args()):
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.alpha,
[env.action_space.low[0], env.action_space.high[0]],
reward_normalization=True, ignore_done=True)
reward_normalization=args.rew_norm, ignore_done=True)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
writer = SummaryWriter(args.logdir + '/' + 'sac')
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(x):
return x >= env.spec.reward_threshold
@ -89,7 +95,7 @@ def test_sac(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()
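
The new `save_fn`/`log_path` pair above writes the best policy to disk during training. As a minimal sketch (not part of the script), the checkpoint could later be restored like this, assuming `args` and `policy` are the objects already defined in the script:

```python
import os
import torch

# mirrors the log_path construction above: args.logdir / args.task / 'sac'
log_path = os.path.join(args.logdir, args.task, 'sac')
# save_fn stored policy.state_dict() as 'policy.pth'; load it back the same way
policy.load_state_dict(torch.load(os.path.join(log_path, 'policy.pth')))
policy.eval()  # switch to evaluation mode before running the restored policy
```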

View File

@ -1,6 +1,7 @@
import torch
import pprint
import numpy as np
from typing import Any, List, Union, Iterator, Optional
class Batch(object):
@ -19,8 +20,8 @@ class Batch(object):
>>> print(data)
Batch(
a: 4,
b: [3 4 5],
c: 2312312,
b: array([3, 4, 5]),
c: '2312312',
)
In short, you can define a :class:`Batch` with any key-value pair. The
@ -65,7 +66,7 @@ class Batch(object):
[11 22] [6 6]
"""
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__()
self._meta = {}
for k, v in kwargs.items():
@ -85,7 +86,7 @@ class Batch(object):
else:
self.__dict__[k] = kwargs[k]
def __getitem__(self, index):
def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]:
"""Return self[index]."""
if isinstance(index, str):
return self.__getattr__(index)
@ -96,7 +97,7 @@ class Batch(object):
b._meta = self._meta
return b
def __getattr__(self, key):
def __getattr__(self, key: str) -> Union['Batch', Any]:
"""Return self.key"""
if key not in self._meta:
if key not in self.__dict__:
@ -108,7 +109,7 @@ class Batch(object):
d[k_] = self.__dict__[k__]
return Batch(**d)
def __repr__(self):
def __repr__(self) -> str:
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
flag = False
@ -125,18 +126,18 @@ class Batch(object):
s = self.__class__.__name__ + '()'
return s
def keys(self):
def keys(self) -> List[str]:
"""Return self.keys()."""
return sorted([i for i in self.__dict__ if i[0] != '_'] +
list(self._meta))
return sorted([
i for i in self.__dict__ if i[0] != '_'] + list(self._meta))
def get(self, k, d=None):
def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
"""Return self[k] if k in self else d. d defaults to None."""
if k in self.__dict__ or k in self._meta:
return self.__getattr__(k)
return d
def to_numpy(self):
def to_numpy(self) -> np.ndarray:
"""Change all torch.Tensor to numpy.ndarray. This is an inplace
operation.
"""
@ -144,7 +145,7 @@ class Batch(object):
if isinstance(self.__dict__[k], torch.Tensor):
self.__dict__[k] = self.__dict__[k].cpu().numpy()
def append(self, batch):
def append(self, batch: 'Batch') -> None:
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
for k in batch.__dict__:
@ -169,13 +170,14 @@ class Batch(object):
+ 'in class Batch.'
raise TypeError(s)
def __len__(self):
def __len__(self) -> int:
"""Return len(self)."""
return min([
len(self.__dict__[k]) for k in self.__dict__
if k != '_meta' and self.__dict__[k] is not None])
def split(self, size=None, shuffle=True):
def split(self, size: Optional[int] = None,
shuffle: Optional[bool] = True) -> Iterator['Batch']:
"""Split whole data into multiple small batch.
:param int size: if it is ``None``, it does not split the data batch;
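
To make the annotated `Batch` interface concrete, here is a small usage sketch based on the docstring example above; the output comments are approximate:

```python
import numpy as np
from tianshou.data import Batch

# arbitrary key-value pairs, as in the docstring example
data = Batch(a=4, b=np.array([3, 4, 5]), c='2312312')
print(data.keys())          # ['a', 'b', 'c'] (sorted, List[str])
print(data.get('missing'))  # None, mirroring dict.get

# append() concatenates array-valued keys; split() yields mini-Batches
traj = Batch(obs=np.zeros((4, 3)), rew=np.arange(4.))
traj.append(Batch(obs=np.ones((2, 3)), rew=np.array([4., 5.])))
print(len(traj))            # 6
for minibatch in traj.split(size=2, shuffle=False):
    print(len(minibatch))   # 2, printed three times
```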

View File

@ -1,5 +1,8 @@
import pprint
import numpy as np
from copy import deepcopy
from typing import Tuple, Union, Optional
from tianshou.data.batch import Batch
@ -92,7 +95,8 @@ class ReplayBuffer(object):
[ 7. 7. 8. 9.]]
"""
def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs):
def __init__(self, size: int, stack_num: Optional[int] = 0,
ignore_obs_next: Optional[bool] = False, **kwargs) -> None:
super().__init__()
self._maxsize = size
self._stack = stack_num
@ -100,11 +104,11 @@ class ReplayBuffer(object):
self._meta = {}
self.reset()
def __len__(self):
def __len__(self) -> int:
"""Return len(self)."""
return self._size
def __repr__(self):
def __repr__(self) -> str:
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
flag = False
@ -121,7 +125,7 @@ class ReplayBuffer(object):
s = self.__class__.__name__ + '()'
return s
def __getattr__(self, key):
def __getattr__(self, key: str) -> Union[Batch, np.ndarray]:
"""Return self.key"""
if key not in self._meta:
if key not in self.__dict__:
@ -133,7 +137,9 @@ class ReplayBuffer(object):
d[k_] = self.__dict__[k__]
return Batch(**d)
def _add_to_buffer(self, name, inst):
def _add_to_buffer(
self, name: str,
inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
if inst is None:
if getattr(self, name, None) is None:
self.__dict__[name] = None
@ -164,9 +170,11 @@ class ReplayBuffer(object):
f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, "
f"given shape: {inst.shape}.")
if name not in self._meta:
if name == 'info':
inst = deepcopy(inst)
self.__dict__[name][self._index] = inst
def update(self, buffer):
def update(self, buffer: 'ReplayBuffer') -> None:
"""Move the data from the given buffer to self."""
i = begin = buffer._index % len(buffer)
while True:
@ -178,8 +186,15 @@ class ReplayBuffer(object):
if i == begin:
break
def add(self, obs, act, rew, done, obs_next=None, info={}, policy={},
**kwargs):
def add(self,
obs: Union[dict, np.ndarray],
act: Union[np.ndarray, float],
rew: float,
done: bool,
obs_next: Optional[Union[dict, np.ndarray]] = None,
info: Optional[dict] = {},
policy: Optional[Union[dict, Batch]] = {},
**kwargs) -> None:
"""Add a batch of data into replay buffer."""
assert isinstance(info, dict), \
'You should return a dict in the last argument of env.step().'
@ -197,11 +212,11 @@ class ReplayBuffer(object):
else:
self._size = self._index = self._index + 1
def reset(self):
def reset(self) -> None:
"""Clear all the data in replay buffer."""
self._index = self._size = 0
def sample(self, batch_size):
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with size equal to batch_size. \
Return all the data in the buffer if batch_size is ``0``.
@ -216,7 +231,8 @@ class ReplayBuffer(object):
])
return self[indice], indice
def get(self, indice, key, stack_num=None):
def get(self, indice: Union[slice, np.ndarray], key: str,
stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]:
"""Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
where s is self.key, t is indice. The stack_num (here equals to 4) is
given from buffer initialization procedure.
@ -275,7 +291,7 @@ class ReplayBuffer(object):
stack = np.stack(stack, axis=1)
return stack
def __getitem__(self, index):
def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
"""Return a data batch: self[index]. If stack_num is set to be > 0,
return the stacked obs and obs_next with shape [batch, len, ...].
"""
@ -302,17 +318,21 @@ class ListReplayBuffer(ReplayBuffer):
detailed explanation.
"""
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(size=0, ignore_obs_next=False, **kwargs)
def _add_to_buffer(self, name, inst):
def _add_to_buffer(
self, name: str,
inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
if inst is None:
return
if self.__dict__.get(name, None) is None:
self.__dict__[name] = []
if name == 'info':
inst = deepcopy(inst)
self.__dict__[name].append(inst)
def reset(self):
def reset(self) -> None:
self._index = self._size = 0
for k in list(self.__dict__):
if isinstance(self.__dict__[k], list):
@ -332,8 +352,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
detailed explanation.
"""
def __init__(self, size, alpha: float, beta: float,
mode: str = 'weight', **kwargs):
def __init__(self, size: int, alpha: float, beta: float,
mode: Optional[str] = 'weight', **kwargs) -> None:
if mode != 'weight':
raise NotImplementedError
super().__init__(size, **kwargs)
@ -344,17 +364,27 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._amortization_freq = 50
self._amortization_counter = 0
def add(self, obs, act, rew, done, obs_next=0, info={}, policy={},
weight=1.0):
def add(self,
obs: Union[dict, np.ndarray],
act: Union[np.ndarray, float],
rew: float,
done: bool,
obs_next: Optional[Union[dict, np.ndarray]] = None,
info: Optional[dict] = {},
policy: Optional[Union[dict, Batch]] = {},
weight: Optional[float] = 1.0,
**kwargs) -> None:
"""Add a batch of data into replay buffer."""
self._weight_sum += np.abs(weight)**self._alpha - \
self._weight_sum += np.abs(weight) ** self._alpha - \
self.weight[self._index]
# we have to sacrifice some convenience for speed :(
self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
super().add(obs, act, rew, done, obs_next, info, policy)
self._check_weight_sum()
def sample(self, batch_size: int = 0, importance_sample: bool = True):
def sample(self, batch_size: Optional[int] = 0,
importance_sample: Optional[bool] = True
) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with priority probability. \
Return all the data in the buffer if batch_size is ``0``.
@ -388,11 +418,12 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._check_weight_sum()
return batch, indice
def reset(self):
def reset(self) -> None:
self._amortization_counter = 0
super().reset()
def update_weight(self, indice, new_weight: np.ndarray):
def update_weight(self, indice: Union[slice, np.ndarray],
new_weight: np.ndarray) -> None:
"""Update priority weight by indice in this buffer.
:param np.ndarray indice: indice you want to update weight
@ -402,7 +433,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
- self.weight[indice].sum()
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
def __getitem__(self, index):
def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
return Batch(
obs=self.get(index, 'obs'),
act=self.act[index],
@ -415,7 +446,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
policy=self.get(index, 'policy'),
)
def _check_weight_sum(self):
def _check_weight_sum(self) -> None:
# keep an accurate _weight_sum
self._amortization_counter += 1
if self._amortization_counter % self._amortization_freq == 0:
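
A brief sketch of the annotated `ReplayBuffer` interface; the stacked shape in the comment assumes `stack_num=4` and 1-D observations:

```python
import numpy as np
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=20, stack_num=4)   # stack_num > 0 enables frame stacking
for i in range(5):
    buf.add(obs=np.array([i]), act=i, rew=float(i), done=(i == 4),
            obs_next=np.array([i + 1]), info={})
print(len(buf))                  # 5
batch, indice = buf.sample(batch_size=2)
print(batch.obs.shape)           # roughly (2, 4, 1): [s_{t-3}, ..., s_t] per sample
```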

View File

@ -1,10 +1,13 @@
import gym
import time
import torch
import warnings
import numpy as np
from typing import Any, Dict, List, Union, Optional, Callable
from tianshou.utils import MovAvg
from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer
@ -77,8 +80,14 @@ class Collector(object):
Please make sure the given environment has a time limitation.
"""
def __init__(self, policy, env, buffer=None, preprocess_fn=None,
stat_size=100, **kwargs):
def __init__(self,
policy: BasePolicy,
env: Union[gym.Env, BaseVectorEnv],
buffer: Optional[Union[ReplayBuffer, List[ReplayBuffer]]]
= None,
preprocess_fn: Callable[[Any], Union[dict, Batch]] = None,
stat_size: Optional[int] = 100,
**kwargs) -> None:
super().__init__()
self.env = env
self.env_num = 1
@ -112,7 +121,7 @@ class Collector(object):
self.stat_size = stat_size
self.reset()
def reset(self):
def reset(self) -> None:
"""Reset all related variables in the collector."""
self.reset_env()
self.reset_buffer()
@ -124,7 +133,7 @@ class Collector(object):
self.collect_episode = 0
self.collect_time = 0
def reset_buffer(self):
def reset_buffer(self) -> None:
"""Reset the main data buffer."""
if self._multi_buf:
for b in self.buffer:
@ -133,11 +142,11 @@ class Collector(object):
if self.buffer is not None:
self.buffer.reset()
def get_env_num(self):
def get_env_num(self) -> int:
"""Return the number of environments the collector have."""
return self.env_num
def reset_env(self):
def reset_env(self) -> None:
"""Reset all of the environment(s)' states and reset all of the cache
buffers (if needed).
"""
@ -155,36 +164,36 @@ class Collector(object):
for b in self._cached_buf:
b.reset()
def seed(self, seed=None):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
"""Reset all the seed(s) of the given environment(s)."""
if hasattr(self.env, 'seed'):
return self.env.seed(seed)
def render(self, **kwargs):
def render(self, **kwargs) -> None:
"""Render all the environment(s)."""
if hasattr(self.env, 'render'):
return self.env.render(**kwargs)
def close(self):
def close(self) -> None:
"""Close the environment(s)."""
if hasattr(self.env, 'close'):
self.env.close()
def _make_batch(self, data):
def _make_batch(self, data: Any) -> Union[Any, np.ndarray]:
"""Return [data]."""
if isinstance(data, np.ndarray):
return data[None]
else:
return np.array([data])
def _reset_state(self, id):
def _reset_state(self, id: Union[int, List[int]]) -> None:
"""Reset self.state[id]."""
if self.state is None:
return
if isinstance(self.state, list):
self.state[id] = None
elif isinstance(self.state, dict):
for k in self.state:
elif isinstance(self.state, dict) or isinstance(self.state, Batch):
for k in self.state.keys():
if isinstance(self.state[k], list):
self.state[k][id] = None
elif isinstance(self.state[k], torch.Tensor) or \
@ -194,7 +203,8 @@ class Collector(object):
isinstance(self.state, np.ndarray):
self.state[id] = 0
def _to_numpy(self, x):
def _to_numpy(self, x: Union[
torch.Tensor, dict, Batch, np.ndarray]) -> None:
"""Return an object without torch.Tensor."""
if isinstance(x, torch.Tensor):
return x.cpu().numpy()
@ -208,7 +218,12 @@ class Collector(object):
return x
return x
def collect(self, n_step=0, n_episode=0, render=None, log_fn=None):
def collect(self,
n_step: Optional[int] = 0,
n_episode: Optional[Union[int, List[int]]] = 0,
render: Optional[float] = None,
log_fn: Optional[Callable[[dict], None]] = None
) -> Dict[str, float]:
"""Collect a specified number of step or episode.
:param int n_step: how many steps you want to collect.
@ -375,7 +390,7 @@ class Collector(object):
'len': length_sum / n_episode,
}
def sample(self, batch_size):
def sample(self, batch_size: int) -> Batch:
"""Sample a data batch from the internal replay buffer. It will call
:meth:`~tianshou.policy.BasePolicy.process_fn` before returning
the final batch data.
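
As a rough, self-contained illustration of the annotated `Collector` API, the sketch below uses a toy random policy; `RandomPolicy` is purely illustrative and not part of tianshou:

```python
import gym
import numpy as np
from tianshou.data import Batch, Collector, ReplayBuffer
from tianshou.env import VectorEnv
from tianshou.policy import BasePolicy

class RandomPolicy(BasePolicy):
    """Toy policy that picks random CartPole actions (illustration only)."""
    def forward(self, batch, state=None, **kwargs):
        return Batch(act=np.random.randint(0, 2, size=len(batch.obs)))
    def learn(self, batch, **kwargs):
        return {}

envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
collector = Collector(RandomPolicy(), envs, ReplayBuffer(size=2000))
result = collector.collect(n_step=200)   # or n_episode=...
print(result['rew'], result['len'])      # mean reward / mean episode length
batch = collector.sample(batch_size=64)  # goes through policy.process_fn first
collector.close()
```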

View File

@ -2,6 +2,7 @@ import gym
import numpy as np
from abc import ABC, abstractmethod
from multiprocessing import Process, Pipe
from typing import List, Tuple, Union, Optional, Callable
try:
import ray
@ -36,16 +37,16 @@ class BaseVectorEnv(ABC, gym.Wrapper):
envs.close() # close all environments
"""
def __init__(self, env_fns):
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
self._env_fns = env_fns
self.env_num = len(env_fns)
def __len__(self):
def __len__(self) -> int:
"""Return len(self), which is the number of environments."""
return self.env_num
@abstractmethod
def reset(self, id=None):
def reset(self, id: Optional[Union[int, List[int]]] = None):
"""Reset the state of all the environments and return initial
observations if id is ``None``, otherwise reset the specific
environments with given id, either an int or a list.
@ -53,7 +54,8 @@ class BaseVectorEnv(ABC, gym.Wrapper):
pass
@abstractmethod
def step(self, action):
def step(self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Run one timestep of all the environments dynamics. When the end of
episode is reached, you are responsible for calling reset(id) to reset
this environments state.
@ -76,19 +78,19 @@ class BaseVectorEnv(ABC, gym.Wrapper):
pass
@abstractmethod
def seed(self, seed=None):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
"""Set the seed for all environments. Accept ``None``, an int (which
will extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list.
"""
pass
@abstractmethod
def render(self, **kwargs):
def render(self, **kwargs) -> None:
"""Render all of the environments."""
pass
@abstractmethod
def close(self):
def close(self) -> None:
"""Close all of the environments."""
pass
@ -102,11 +104,11 @@ class VectorEnv(BaseVectorEnv):
explanation.
"""
def __init__(self, env_fns):
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
super().__init__(env_fns)
self.envs = [_() for _ in env_fns]
def reset(self, id=None):
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
if id is None:
self._obs = np.stack([e.reset() for e in self.envs])
else:
@ -116,7 +118,8 @@ class VectorEnv(BaseVectorEnv):
self._obs[i] = self.envs[i].reset()
return self._obs
def step(self, action):
def step(self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num
result = [e.step(a) for e, a in zip(self.envs, action)]
self._obs, self._rew, self._done, self._info = zip(*result)
@ -126,7 +129,7 @@ class VectorEnv(BaseVectorEnv):
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info
def seed(self, seed=None):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
if np.isscalar(seed):
seed = [seed + _ for _ in range(self.env_num)]
elif seed is None:
@ -137,14 +140,14 @@ class VectorEnv(BaseVectorEnv):
result.append(e.seed(s))
return result
def render(self, **kwargs):
def render(self, **kwargs) -> None:
result = []
for e in self.envs:
if hasattr(e, 'render'):
result.append(e.render(**kwargs))
return result
def close(self):
def close(self) -> None:
return [e.close() for e in self.envs]
@ -182,7 +185,7 @@ class SubprocVectorEnv(BaseVectorEnv):
explanation.
"""
def __init__(self, env_fns):
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
super().__init__(env_fns)
self.closed = False
self.parent_remote, self.child_remote = \
@ -198,7 +201,8 @@ class SubprocVectorEnv(BaseVectorEnv):
for c in self.child_remote:
c.close()
def step(self, action):
def step(self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num
for p, a in zip(self.parent_remote, action):
p.send(['step', a])
@ -210,7 +214,7 @@ class SubprocVectorEnv(BaseVectorEnv):
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info
def reset(self, id=None):
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
if id is None:
for p in self.parent_remote:
p.send(['reset', None])
@ -225,7 +229,7 @@ class SubprocVectorEnv(BaseVectorEnv):
self._obs[i] = self.parent_remote[i].recv()
return self._obs
def seed(self, seed=None):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
if np.isscalar(seed):
seed = [seed + _ for _ in range(self.env_num)]
elif seed is None:
@ -234,12 +238,12 @@ class SubprocVectorEnv(BaseVectorEnv):
p.send(['seed', s])
return [p.recv() for p in self.parent_remote]
def render(self, **kwargs):
def render(self, **kwargs) -> None:
for p in self.parent_remote:
p.send(['render', kwargs])
return [p.recv() for p in self.parent_remote]
def close(self):
def close(self) -> None:
if self.closed:
return
for p in self.parent_remote:
@ -263,7 +267,7 @@ class RayVectorEnv(BaseVectorEnv):
explanation.
"""
def __init__(self, env_fns):
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
super().__init__(env_fns)
try:
if not ray.is_initialized():
@ -275,7 +279,8 @@ class RayVectorEnv(BaseVectorEnv):
ray.remote(gym.Wrapper).options(num_cpus=0).remote(e())
for e in env_fns]
def step(self, action):
def step(self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num
result = ray.get([e.step.remote(a) for e, a in zip(self.envs, action)])
self._obs, self._rew, self._done, self._info = zip(*result)
@ -285,7 +290,7 @@ class RayVectorEnv(BaseVectorEnv):
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info
def reset(self, id=None):
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
if id is None:
result_obj = [e.reset.remote() for e in self.envs]
self._obs = np.stack(ray.get(result_obj))
@ -299,7 +304,7 @@ class RayVectorEnv(BaseVectorEnv):
self._obs[i] = ray.get(result_obj[_])
return self._obs
def seed(self, seed=None):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
if not hasattr(self.envs[0], 'seed'):
return
if np.isscalar(seed):
@ -308,10 +313,10 @@ class RayVectorEnv(BaseVectorEnv):
seed = [seed] * self.env_num
return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)])
def render(self, **kwargs):
def render(self, **kwargs) -> None:
if not hasattr(self.envs[0], 'render'):
return
return ray.get([e.render.remote(**kwargs) for e in self.envs])
def close(self):
def close(self) -> None:
return ray.get([e.close.remote() for e in self.envs])
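
A short sketch of the vectorized environment interface annotated above; `SubprocVectorEnv` and `RayVectorEnv` expose the same methods:

```python
import gym
import numpy as np
from tianshou.env import VectorEnv

envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
envs.seed(0)                            # an int seed is extended to [0, 1, 2, 3]
obs = envs.reset()                      # stacked observations, shape (4, obs_dim)
act = np.array([envs.envs[i].action_space.sample() for i in range(len(envs))])
obs, rew, done, info = envs.step(act)   # four stacked arrays of length 4
if done.any():
    envs.reset(id=np.where(done)[0])    # reset only the finished environments
envs.close()
```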

View File

@ -1,4 +1,5 @@
import numpy as np
from typing import Union, Optional
class OUNoise(object):
@ -17,13 +18,18 @@ class OUNoise(object):
Ornstein-Uhlenbeck process.
"""
def __init__(self, sigma=0.3, theta=0.15, dt=1e-2, x0=None):
def __init__(self,
sigma: Optional[float] = 0.3,
theta: Optional[float] = 0.15,
dt: Optional[float] = 1e-2,
x0: Optional[Union[float, np.ndarray]] = None
) -> None:
self.alpha = theta * dt
self.beta = sigma * np.sqrt(dt)
self.x0 = x0
self.reset()
def __call__(self, size, mu=.1):
def __call__(self, size: tuple, mu: Optional[float] = .1) -> np.ndarray:
"""Generate new noise. Return a ``numpy.ndarray`` which size is equal
to ``size``.
"""
@ -33,6 +39,6 @@ class OUNoise(object):
self.x = self.x + self.alpha * (mu - self.x) + r
return self.x
def reset(self):
def reset(self) -> None:
"""Reset to the initial state."""
self.x = None
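
A minimal sketch of how the annotated `OUNoise` class might drive exploration for a continuous action; the zero action below is just a placeholder:

```python
import numpy as np
from tianshou.exploration import OUNoise

noise = OUNoise(sigma=0.3, theta=0.15, dt=1e-2)
action = np.zeros(4)                          # placeholder deterministic action
for _ in range(3):
    noisy = action + noise(size=(4,), mu=0.)  # temporally correlated noise
noise.reset()                                 # clear the state between episodes
```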

View File

@ -1,6 +1,10 @@
import torch
import numpy as np
from torch import nn
from abc import ABC, abstractmethod
from typing import Dict, List, Union, Optional
from tianshou.data import Batch, ReplayBuffer
class BasePolicy(ABC, nn.Module):
@ -39,17 +43,20 @@ class BasePolicy(ABC, nn.Module):
policy.load_state_dict(torch.load('policy.pth'))
"""
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__()
def process_fn(self, batch, buffer, indice):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
"""Pre-process the data from the provided replay buffer. Check out
:ref:`policy_concept` for more information.
"""
return batch
@abstractmethod
def forward(self, batch, state=None, **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which MUST have the following\
@ -80,7 +87,8 @@ class BasePolicy(ABC, nn.Module):
pass
@abstractmethod
def learn(self, batch, **kwargs):
def learn(self, batch: Batch, **kwargs
) -> Dict[str, Union[float, List[float]]]:
"""Update policy with a given batch of data.
:return: A dict which includes loss and its corresponding label.
@ -88,8 +96,11 @@ class BasePolicy(ABC, nn.Module):
pass
@staticmethod
def compute_episodic_return(batch, v_s_=None,
gamma=0.99, gae_lambda=0.95):
def compute_episodic_return(
batch: Batch,
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
gamma: Optional[float] = 0.99,
gae_lambda: Optional[float] = 0.95) -> Batch:
"""Compute returns over given full-length episodes, including the
implementation of Generalized Advantage Estimation (arXiv:1506.02438).
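
To illustrate the newly annotated static helper, a tiny worked example; with `gae_lambda=1.` and no bootstrap value the result reduces to plain discounted returns (numbers in the comment are approximate):

```python
import numpy as np
from tianshou.data import Batch
from tianshou.policy import BasePolicy

# a toy 3-step episode; in practice v_s_ would come from a critic on obs_next
batch = Batch(rew=np.array([1., 1., 1.]), done=np.array([0., 0., 1.]))
batch = BasePolicy.compute_episodic_return(batch, v_s_=None,
                                           gamma=0.99, gae_lambda=1.)
print(batch.returns)   # approximately [2.9701, 1.99, 1.0]
```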

View File

@ -1,5 +1,7 @@
import torch
import numpy as np
import torch.nn.functional as F
from typing import Dict, Union, Optional
from tianshou.data import Batch
from tianshou.policy import BasePolicy
@ -19,7 +21,9 @@ class ImitationPolicy(BasePolicy):
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(self, model, optim, mode='continuous'):
def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer,
mode: Optional[str] = 'continuous', **kwargs) -> None:
super().__init__()
self.model = model
self.optim = optim
@ -27,7 +31,10 @@ class ImitationPolicy(BasePolicy):
f'Mode {mode} is not in ["continuous", "discrete"]'
self.mode = mode
def forward(self, batch, state=None):
def forward(self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
logits, h = self.model(batch.obs, state=state, info=batch.info)
if self.mode == 'discrete':
a = logits.max(dim=1)[1]
@ -35,7 +42,7 @@ class ImitationPolicy(BasePolicy):
a = logits
return Batch(logits=logits, act=a, state=h)
def learn(self, batch, **kwargs):
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
self.optim.zero_grad()
if self.mode == 'continuous':
a = self(batch).act

View File

@ -2,9 +2,10 @@ import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Dict, List, Union, Optional
from tianshou.data import Batch
from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer
class A2CPolicy(PGPolicy):
@ -31,11 +32,19 @@ class A2CPolicy(PGPolicy):
explanation.
"""
def __init__(self, actor, critic, optim,
dist_fn=torch.distributions.Categorical,
discount_factor=0.99, vf_coef=.5, ent_coef=.01,
max_grad_norm=None, gae_lambda=0.95,
reward_normalization=False, **kwargs):
def __init__(self,
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Optional[torch.distributions.Distribution]
= torch.distributions.Categorical,
discount_factor: Optional[float] = 0.99,
vf_coef: Optional[float] = .5,
ent_coef: Optional[float] = .01,
max_grad_norm: Optional[float] = None,
gae_lambda: Optional[float] = 0.95,
reward_normalization: Optional[bool] = False,
**kwargs) -> None:
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
self.actor = actor
self.critic = critic
@ -48,7 +57,8 @@ class A2CPolicy(PGPolicy):
self._rew_norm = reward_normalization
self.__eps = np.finfo(np.float32).eps.item()
def process_fn(self, batch, buffer, indice):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
if self._lambda in [0, 1]:
return self.compute_episodic_return(
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
@ -60,7 +70,9 @@ class A2CPolicy(PGPolicy):
return self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
def forward(self, batch, state=None, **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
@ -83,7 +95,8 @@ class A2CPolicy(PGPolicy):
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
self._batch = batch_size
r = batch.returns
if self._rew_norm and r.std() > self.__eps:

View File

@ -2,10 +2,11 @@ import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Tuple, Union, Optional
from tianshou.data import Batch
from tianshou.policy import BasePolicy
# from tianshou.exploration import OUNoise
from tianshou.data import Batch, ReplayBuffer
class DDPGPolicy(BasePolicy):
@ -23,7 +24,7 @@ class DDPGPolicy(BasePolicy):
:param float exploration_noise: the noise intensity, added to the action,
defaults to 0.1.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
@ -35,10 +36,18 @@ class DDPGPolicy(BasePolicy):
explanation.
"""
def __init__(self, actor, actor_optim, critic, critic_optim,
tau=0.005, gamma=0.99, exploration_noise=0.1,
action_range=None, reward_normalization=False,
ignore_done=False, **kwargs):
def __init__(self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer,
tau: Optional[float] = 0.005,
gamma: Optional[float] = 0.99,
exploration_noise: Optional[float] = 0.1,
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: Optional[bool] = False,
ignore_done: Optional[bool] = False,
**kwargs) -> None:
super().__init__(**kwargs)
if actor is not None:
self.actor, self.actor_old = actor, deepcopy(actor)
@ -64,23 +73,23 @@ class DDPGPolicy(BasePolicy):
self._rew_norm = reward_normalization
self.__eps = np.finfo(np.float32).eps.item()
def set_eps(self, eps):
def set_eps(self, eps: float) -> None:
"""Set the eps for exploration."""
self._eps = eps
def train(self):
def train(self) -> None:
"""Set the module in training mode, except for the target network."""
self.training = True
self.actor.train()
self.critic.train()
def eval(self):
def eval(self) -> None:
"""Set the module in evaluation mode, except for the target network."""
self.training = False
self.actor.eval()
self.critic.eval()
def sync_weight(self):
def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
@ -88,7 +97,8 @@ class DDPGPolicy(BasePolicy):
self.critic_old.parameters(), self.critic.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
def process_fn(self, batch, buffer, indice):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
if self._rew_norm:
bfr = buffer.rew[:min(len(buffer), 1000)] # avoid large buffer
mean, std = bfr.mean(), bfr.std()
@ -98,8 +108,12 @@ class DDPGPolicy(BasePolicy):
batch.done = batch.done * 0.
return batch
def forward(self, batch, state=None,
model='actor', input='obs', eps=None, **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: Optional[str] = 'actor',
input: Optional[str] = 'obs',
eps: Optional[float] = None,
**kwargs) -> Batch:
"""Compute action over the given batch data.
:param float eps: in [0, 1], for exploration use.
@ -129,7 +143,7 @@ class DDPGPolicy(BasePolicy):
logits = logits.clamp(self._range[0], self._range[1])
return Batch(act=logits, state=h)
def learn(self, batch, **kwargs):
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
with torch.no_grad():
target_q = self.critic_old(batch.obs_next, self(
batch, model='actor_old', input='obs_next', eps=0).act)

View File

@ -2,9 +2,10 @@ import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Union, Optional
from tianshou.policy import BasePolicy
from tianshou.data import Batch, PrioritizedReplayBuffer
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
class DQNPolicy(BasePolicy):
@ -25,8 +26,13 @@ class DQNPolicy(BasePolicy):
explanation.
"""
def __init__(self, model, optim, discount_factor=0.99,
estimation_step=1, target_update_freq=0, **kwargs):
def __init__(self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
discount_factor: Optional[float] = 0.99,
estimation_step: Optional[int] = 1,
target_update_freq: Optional[int] = 0,
**kwargs) -> None:
super().__init__(**kwargs)
self.model = model
self.optim = optim
@ -42,25 +48,26 @@ class DQNPolicy(BasePolicy):
self.model_old = deepcopy(self.model)
self.model_old.eval()
def set_eps(self, eps):
def set_eps(self, eps: float) -> None:
"""Set the eps for epsilon-greedy exploration."""
self.eps = eps
def train(self):
def train(self) -> None:
"""Set the module in training mode, except for the target network."""
self.training = True
self.model.train()
def eval(self):
def eval(self) -> None:
"""Set the module in evaluation mode, except for the target network."""
self.training = False
self.model.eval()
def sync_weight(self):
def sync_weight(self) -> None:
"""Synchronize the weight for the target network."""
self.model_old.load_state_dict(self.model.state_dict())
def process_fn(self, batch, buffer, indice):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
r"""Compute the n-step return for Q-learning targets:
.. math::
@ -115,8 +122,12 @@ class DQNPolicy(BasePolicy):
batch.loss += loss
return batch
def forward(self, batch, state=None,
model='model', input='obs', eps=None, **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: Optional[str] = 'model',
input: Optional[str] = 'obs',
eps: Optional[float] = None,
**kwargs) -> Batch:
"""Compute action over the given batch data.
:param float eps: in [0, 1], for epsilon-greedy exploration method.
@ -144,7 +155,7 @@ class DQNPolicy(BasePolicy):
act[i] = np.random.randint(q.shape[1])
return Batch(logits=q, act=act, state=h)
def learn(self, batch, **kwargs):
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
if self._target and self._cnt % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
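
For reference, a hedged sketch of constructing the annotated `DQNPolicy`; the small Q-network is an ad-hoc stand-in (tianshou models are expected to return a `(logits, state)` pair):

```python
import torch
from torch import nn
from tianshou.policy import DQNPolicy

class Net(nn.Module):
    """Tiny Q-network for illustration: observation -> Q-values."""
    def __init__(self, state_dim: int = 4, action_dim: int = 2) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(), nn.Linear(128, action_dim))

    def forward(self, obs, state=None, info={}):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        return self.model(obs), state

net = Net()
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, discount_factor=0.99,
                   estimation_step=3, target_update_freq=100)
policy.set_eps(0.1)   # epsilon for epsilon-greedy exploration while collecting
```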

View File

@ -1,8 +1,9 @@
import torch
import numpy as np
from typing import Dict, List, Union, Optional
from tianshou.data import Batch
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer
class PGPolicy(BasePolicy):
@ -20,8 +21,14 @@ class PGPolicy(BasePolicy):
explanation.
"""
def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
discount_factor=0.99, reward_normalization=False, **kwargs):
def __init__(self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Optional[torch.distributions.Distribution]
= torch.distributions.Categorical,
discount_factor: Optional[float] = 0.99,
reward_normalization: Optional[bool] = False,
**kwargs) -> None:
super().__init__(**kwargs)
self.model = model
self.optim = optim
@ -31,7 +38,8 @@ class PGPolicy(BasePolicy):
self._rew_norm = reward_normalization
self.__eps = np.finfo(np.float32).eps.item()
def process_fn(self, batch, buffer, indice):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
r"""Compute the discounted returns for each frame:
.. math::
@ -46,7 +54,9 @@ class PGPolicy(BasePolicy):
return self.compute_episodic_return(
batch, gamma=self._gamma, gae_lambda=1.)
def forward(self, batch, state=None, **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
@ -69,7 +79,8 @@ class PGPolicy(BasePolicy):
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
losses = []
r = batch.returns
if self._rew_norm and r.std() > self.__eps:

View File

@ -1,9 +1,10 @@
import torch
import numpy as np
from torch import nn
from typing import Dict, List, Tuple, Union, Optional
from tianshou.data import Batch
from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer
class PPOPolicy(PGPolicy):
@ -23,7 +24,7 @@ class PPOPolicy(PGPolicy):
:param float vf_coef: weight for value loss, defaults to 0.5.
:param float ent_coef: weight for entropy loss, defaults to 0.01.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
:type action_range: (float, float)
:param float gae_lambda: in [0, 1], param for Generalized Advantage
Estimation, defaults to 0.95.
:param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
@ -40,11 +41,22 @@ class PPOPolicy(PGPolicy):
explanation.
"""
def __init__(self, actor, critic, optim, dist_fn,
discount_factor=0.99, max_grad_norm=.5, eps_clip=.2,
vf_coef=.5, ent_coef=.01, action_range=None, gae_lambda=0.95,
dual_clip=5., value_clip=True, reward_normalization=True,
**kwargs):
def __init__(self,
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: torch.distributions.Distribution,
discount_factor: Optional[float] = 0.99,
max_grad_norm: Optional[float] = None,
eps_clip: Optional[float] = .2,
vf_coef: Optional[float] = .5,
ent_coef: Optional[float] = .01,
action_range: Optional[Tuple[float, float]] = None,
gae_lambda: Optional[float] = 0.95,
dual_clip: Optional[float] = 5.,
value_clip: Optional[bool] = True,
reward_normalization: Optional[bool] = True,
**kwargs) -> None:
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
self._max_grad_norm = max_grad_norm
self._eps_clip = eps_clip
@ -64,7 +76,8 @@ class PPOPolicy(PGPolicy):
self._rew_norm = reward_normalization
self.__eps = np.finfo(np.float32).eps.item()
def process_fn(self, batch, buffer, indice):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
if self._rew_norm:
mean, std = batch.rew.mean(), batch.rew.std()
if std > self.__eps:
@ -80,7 +93,9 @@ class PPOPolicy(PGPolicy):
return self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
def forward(self, batch, state=None, **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
@ -105,7 +120,8 @@ class PPOPolicy(PGPolicy):
act = act.clamp(self._range[0], self._range[1])
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn(self, batch, batch_size=None, repeat=1, **kwargs):
def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
self._batch = batch_size
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
v = []

View File

@ -2,6 +2,7 @@ import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Tuple, Union, Optional
from tianshou.data import Batch
from tianshou.policy import DDPGPolicy
@ -28,7 +29,7 @@ class SACPolicy(DDPGPolicy):
defaults to 0.1.
:param float alpha: entropy regularization coefficient, default to 0.2.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
@ -40,10 +41,20 @@ class SACPolicy(DDPGPolicy):
explanation.
"""
def __init__(self, actor, actor_optim, critic1, critic1_optim,
critic2, critic2_optim, tau=0.005, gamma=0.99,
alpha=0.2, action_range=None, reward_normalization=False,
ignore_done=False, **kwargs):
def __init__(self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic1: torch.nn.Module,
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
tau: Optional[float] = 0.005,
gamma: Optional[float] = 0.99,
alpha: Optional[float] = 0.2,
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: Optional[bool] = False,
ignore_done: Optional[bool] = False,
**kwargs) -> None:
super().__init__(None, None, None, None, tau, gamma, 0,
action_range, reward_normalization, ignore_done,
**kwargs)
@ -57,19 +68,19 @@ class SACPolicy(DDPGPolicy):
self._alpha = alpha
self.__eps = np.finfo(np.float32).eps.item()
def train(self):
def train(self) -> None:
self.training = True
self.actor.train()
self.critic1.train()
self.critic2.train()
def eval(self):
def eval(self) -> None:
self.training = False
self.actor.eval()
self.critic1.eval()
self.critic2.eval()
def sync_weight(self):
def sync_weight(self) -> None:
for o, n in zip(
self.critic1_old.parameters(), self.critic1.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
@ -77,7 +88,9 @@ class SACPolicy(DDPGPolicy):
self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
def forward(self, batch, state=None, input='obs', **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
input: Optional[str] = 'obs', **kwargs) -> Batch:
obs = getattr(batch, input)
logits, h = self.actor(obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
@ -92,7 +105,7 @@ class SACPolicy(DDPGPolicy):
return Batch(
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
def learn(self, batch, **kwargs):
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
with torch.no_grad():
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act

View File

@ -1,7 +1,9 @@
import torch
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Tuple, Optional
from tianshou.data import Batch
from tianshou.policy import DDPGPolicy
@ -32,7 +34,7 @@ class TD3Policy(DDPGPolicy):
:param float noise_clip: the clipping range used in updating policy
network, default to 0.5.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
@ -44,11 +46,23 @@ class TD3Policy(DDPGPolicy):
explanation.
"""
def __init__(self, actor, actor_optim, critic1, critic1_optim,
critic2, critic2_optim, tau=0.005, gamma=0.99,
exploration_noise=0.1, policy_noise=0.2, update_actor_freq=2,
noise_clip=0.5, action_range=None,
reward_normalization=False, ignore_done=False, **kwargs):
def __init__(self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic1: torch.nn.Module,
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
tau: Optional[float] = 0.005,
gamma: Optional[float] = 0.99,
exploration_noise: Optional[float] = 0.1,
policy_noise: Optional[float] = 0.2,
update_actor_freq: Optional[int] = 2,
noise_clip: Optional[float] = 0.5,
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: Optional[bool] = False,
ignore_done: Optional[bool] = False,
**kwargs) -> None:
super().__init__(actor, actor_optim, None, None, tau, gamma,
exploration_noise, action_range, reward_normalization,
ignore_done, **kwargs)
@ -64,19 +78,19 @@ class TD3Policy(DDPGPolicy):
self._cnt = 0
self._last = 0
def train(self):
def train(self) -> None:
self.training = True
self.actor.train()
self.critic1.train()
self.critic2.train()
def eval(self):
def eval(self) -> None:
self.training = False
self.actor.eval()
self.critic1.eval()
self.critic2.eval()
def sync_weight(self):
def sync_weight(self) -> None:
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
for o, n in zip(
@ -86,7 +100,7 @@ class TD3Policy(DDPGPolicy):
self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
def learn(self, batch, **kwargs):
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
with torch.no_grad():
a_ = self(batch, model='actor_old', input='obs_next').act
dev = a_.device

View File

@ -1,16 +1,33 @@
import time
import tqdm
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info
def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
step_per_epoch, collect_per_step, episode_per_test,
batch_size,
train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
log_fn=None, writer=None, log_interval=1, verbose=True,
**kwargs):
def offpolicy_trainer(
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
max_epoch: int,
step_per_epoch: int,
collect_per_step: int,
episode_per_test: Union[int, List[int]],
batch_size: int,
train_fn: Optional[Callable[[int], None]] = None,
test_fn: Optional[Callable[[int], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
log_fn: Optional[Callable[[dict], None]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: Optional[int] = 1,
verbose: Optional[bool] = True,
**kwargs
) -> Dict[str, Union[float, str]]:
"""A wrapper for off-policy trainer procedure.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
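
A sketch of calling the annotated trainer; `policy`, `train_collector` and `test_collector` are assumed to be built beforehand (for instance as in the SAC example near the top of this commit), and the numeric arguments are placeholders:

```python
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.trainer import offpolicy_trainer

def stop_fn(mean_reward: float) -> bool:       # Callable[[float], bool]
    return mean_reward >= 195

def save_fn(policy) -> None:                   # Callable[[BasePolicy], None]
    torch.save(policy.state_dict(), 'policy.pth')

result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
    episode_per_test=100, batch_size=64,
    stop_fn=stop_fn, save_fn=save_fn,
    writer=SummaryWriter('log/offpolicy'))
print(result['best_reward'])   # gather_info() reports the best test reward
```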

View File

@ -1,16 +1,34 @@
import time
import tqdm
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info
def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
step_per_epoch, collect_per_step, repeat_per_collect,
episode_per_test, batch_size,
train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
log_fn=None, writer=None, log_interval=1, verbose=True,
**kwargs):
def onpolicy_trainer(
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
max_epoch: int,
step_per_epoch: int,
collect_per_step: int,
repeat_per_collect: int,
episode_per_test: Union[int, List[int]],
batch_size: int,
train_fn: Optional[Callable[[int], None]] = None,
test_fn: Optional[Callable[[int], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
log_fn: Optional[Callable[[dict], None]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: Optional[int] = 1,
verbose: Optional[bool] = True,
**kwargs
) -> Dict[str, Union[float, str]]:
"""A wrapper for on-policy trainer procedure.
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`

View File

@ -1,8 +1,17 @@
import time
import numpy as np
from typing import Dict, List, Union, Callable
from tianshou.data import Collector
from tianshou.policy import BasePolicy
def test_episode(policy, collector, test_fn, epoch, n_episode):
def test_episode(
policy: BasePolicy,
collector: Collector,
test_fn: Callable[[int], None],
epoch: int,
n_episode: Union[int, List[int]]) -> Dict[str, float]:
"""A simple wrapper of testing policy in collector."""
collector.reset_env()
collector.reset_buffer()
@ -17,7 +26,11 @@ def test_episode(policy, collector, test_fn, epoch, n_episode):
return collector.collect(n_episode=n_episode)
def gather_info(start_time, train_c, test_c, best_reward):
def gather_info(start_time: float,
train_c: Collector,
test_c: Collector,
best_reward: float
) -> Dict[str, Union[float, str]]:
"""A simple wrapper of gathering information from collectors.
:return: A dictionary with the following keys:

View File

@ -1,5 +1,6 @@
import torch
import numpy as np
from typing import Union, Optional
class MovAvg(object):
@ -19,19 +20,20 @@ class MovAvg(object):
>>> print(f'{stat.mean():.2f}±{stat.std():.2f}')
6.50±1.12
"""
def __init__(self, size=100):
def __init__(self, size: Optional[int] = 100) -> None:
super().__init__()
self.size = size
self.cache = []
self.banned = [np.inf, np.nan, -np.inf]
def add(self, x):
def add(self, x: Union[float, list, np.ndarray, torch.Tensor]) -> float:
"""Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
only one element, a python scalar, a list of python scalars, or a ``numpy.ndarray``.
"""
if isinstance(x, torch.Tensor):
x = x.item()
if isinstance(x, list):
if isinstance(x, list) or isinstance(x, np.ndarray):
for _ in x:
if _ not in self.banned:
self.cache.append(_)
@ -41,17 +43,17 @@ class MovAvg(object):
self.cache = self.cache[-self.size:]
return self.get()
def get(self):
def get(self) -> float:
"""Get the average."""
if len(self.cache) == 0:
return 0
return np.mean(self.cache)
def mean(self):
def mean(self) -> float:
"""Get the average. Same as :meth:`get`."""
return self.get()
def std(self):
def std(self) -> float:
"""Get the standard deviation."""
if len(self.cache) == 0:
return 0
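
Finally, a quick sketch of `MovAvg` exercising the newly supported `np.ndarray` input; infinite values are filtered out, matching the docstring example:

```python
import numpy as np
from tianshou.utils import MovAvg

stat = MovAvg(size=100)
stat.add(5.)                      # a plain python scalar
stat.add([6, 7])                  # a list of scalars
stat.add(np.array([8., np.inf]))  # np.ndarray now works; inf entries are ignored
print(f'{stat.mean():.2f}±{stat.std():.2f}')   # 6.50±1.12, as in the docstring
```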