add type annotation

parent 075825325e, commit 9b26137cd2
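Every file below follows the same pattern: public signatures gain PEP 484 annotations, keyword arguments with a default are typed as `Optional[...]`, and methods that work in place are marked `-> None`. A minimal sketch of the pattern (the `Example` class is illustrative only, not code from this commit):

```python
from typing import List, Optional, Union


class Example:
    # Before this commit the signature read: def seed(self, seed=None):
    # Afterwards the accepted types and the in-place return are explicit.
    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
        """Accept ``None``, a single int, or a list of ints."""
        self._seed = seed
```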

.github/ISSUE_TEMPLATE.md | 3 (vendored)
@@ -4,7 +4,7 @@
+ [ ] documentation request (i.e. "X is missing from the documentation.")
+ [ ] new feature request
- [ ] I have visited the [source website], and in particular read the [known issues]
- [ ] I have searched through the [issue categories] for duplicates
- [ ] I have searched through the [issue tracker] and [issue categories] for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
import tianshou, torch, sys
@@ -14,3 +14,4 @@
[source website]: https://github.com/thu-ml/tianshou/
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
[issue categories]: https://github.com/thu-ml/tianshou/projects/2
[issue tracker]: https://github.com/thu-ml/tianshou/issues?q=

.github/PULL_REQUEST_TEMPLATE.md | 3 (vendored)
@@ -8,7 +8,7 @@
Less important but also useful:

- [ ] I have visited the [source website], and in particular read the [known issues]
- [ ] I have searched through the [issue categories] for duplicates
- [ ] I have searched through the [issue tracker] and [issue categories] for duplicates
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
```python
import tianshou, torch, sys
@@ -18,3 +18,4 @@ Less important but also useful:
[source website]: https://github.com/thu-ml/tianshou
[known issues]: https://github.com/thu-ml/tianshou/#faq-and-known-issues
[issue categories]: https://github.com/thu-ml/tianshou/projects/2
[issue tracker]: https://github.com/thu-ml/tianshou/issues?q=

@@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@@ -32,6 +33,7 @@ def get_args():
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--rew-norm', type=bool, default=True)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -73,14 +75,18 @@ def test_sac(args=get_args()):
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.alpha,
[env.action_space.low[0], env.action_space.high[0]],
reward_normalization=True, ignore_done=True)
reward_normalization=args.rew_norm, ignore_done=True)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
writer = SummaryWriter(args.logdir + '/' + 'sac')
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(x):
return x >= env.spec.reward_threshold
@@ -89,7 +95,7 @@ def test_sac(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

@@ -1,6 +1,7 @@
import torch
import pprint
import numpy as np
from typing import Any, List, Union, Iterator, Optional


class Batch(object):
@@ -19,8 +20,8 @@ class Batch(object):
>>> print(data)
Batch(
a: 4,
b: [3 4 5],
c: 2312312,
b: array([3, 4, 5]),
c: '2312312',
)

In short, you can define a :class:`Batch` with any key-value pair. The
@@ -65,7 +66,7 @@ class Batch(object):
[11 22] [6 6]
"""

def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__()
self._meta = {}
for k, v in kwargs.items():
@@ -85,7 +86,7 @@ class Batch(object):
else:
self.__dict__[k] = kwargs[k]

def __getitem__(self, index):
def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]:
"""Return self[index]."""
if isinstance(index, str):
return self.__getattr__(index)
@@ -96,7 +97,7 @@ class Batch(object):
b._meta = self._meta
return b

def __getattr__(self, key):
def __getattr__(self, key: str) -> Union['Batch', Any]:
"""Return self.key"""
if key not in self._meta:
if key not in self.__dict__:
@@ -108,7 +109,7 @@ class Batch(object):
d[k_] = self.__dict__[k__]
return Batch(**d)

def __repr__(self):
def __repr__(self) -> str:
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
flag = False
@@ -125,18 +126,18 @@ class Batch(object):
s = self.__class__.__name__ + '()'
return s

def keys(self):
def keys(self) -> List[str]:
"""Return self.keys()."""
return sorted([i for i in self.__dict__ if i[0] != '_'] +
list(self._meta))
return sorted([
i for i in self.__dict__ if i[0] != '_'] + list(self._meta))

def get(self, k, d=None):
def get(self, k: str, d: Optional[Any] = None) -> Union['Batch', Any]:
"""Return self[k] if k in self else d. d defaults to None."""
if k in self.__dict__ or k in self._meta:
return self.__getattr__(k)
return d

def to_numpy(self):
def to_numpy(self) -> np.ndarray:
"""Change all torch.Tensor to numpy.ndarray. This is an inplace
operation.
"""
@@ -144,7 +145,7 @@ class Batch(object):
if isinstance(self.__dict__[k], torch.Tensor):
self.__dict__[k] = self.__dict__[k].cpu().numpy()

def append(self, batch):
def append(self, batch: 'Batch') -> None:
"""Append a :class:`~tianshou.data.Batch` object to current batch."""
assert isinstance(batch, Batch), 'Only append Batch is allowed!'
for k in batch.__dict__:
@@ -169,13 +170,14 @@ class Batch(object):
+ 'in class Batch.'
raise TypeError(s)

def __len__(self):
def __len__(self) -> int:
"""Return len(self)."""
return min([
len(self.__dict__[k]) for k in self.__dict__
if k != '_meta' and self.__dict__[k] is not None])

def split(self, size=None, shuffle=True):
def split(self, size: Optional[int] = None,
shuffle: Optional[bool] = True) -> Iterator['Batch']:
"""Split whole data into multiple small batch.

:param int size: if it is ``None``, it does not split the data batch;
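The annotated `Batch` can be exercised as below; a minimal sketch that relies only on the tianshou 0.2.x behaviour documented in the hunks above (the concrete values mirror the docstring example):

```python
import numpy as np
from tianshou.data import Batch

# Arbitrary key-value pairs, as the docstring above describes; the repr now
# prints numpy arrays as array([...]) and strings with quotes.
data = Batch(a=4, b=np.array([3, 4, 5]), c='2312312')
print(data)           # Batch( a: 4, b: array([3, 4, 5]), c: '2312312', )
print(data.keys())    # ['a', 'b', 'c'], matching the new List[str] return type
print(data.get('d'))  # None: get() falls back to its Optional default
```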

@@ -1,5 +1,8 @@
import pprint
import numpy as np
from copy import deepcopy
from typing import Tuple, Union, Optional

from tianshou.data.batch import Batch


@@ -92,7 +95,8 @@ class ReplayBuffer(object):
[ 7. 7. 8. 9.]]
"""

def __init__(self, size, stack_num=0, ignore_obs_next=False, **kwargs):
def __init__(self, size: int, stack_num: Optional[int] = 0,
ignore_obs_next: Optional[bool] = False, **kwargs) -> None:
super().__init__()
self._maxsize = size
self._stack = stack_num
@@ -100,11 +104,11 @@ class ReplayBuffer(object):
self._meta = {}
self.reset()

def __len__(self):
def __len__(self) -> int:
"""Return len(self)."""
return self._size

def __repr__(self):
def __repr__(self) -> str:
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
flag = False
@@ -121,7 +125,7 @@ class ReplayBuffer(object):
s = self.__class__.__name__ + '()'
return s

def __getattr__(self, key):
def __getattr__(self, key: str) -> Union[Batch, np.ndarray]:
"""Return self.key"""
if key not in self._meta:
if key not in self.__dict__:
@@ -133,7 +137,9 @@ class ReplayBuffer(object):
d[k_] = self.__dict__[k__]
return Batch(**d)

def _add_to_buffer(self, name, inst):
def _add_to_buffer(
self, name: str,
inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
if inst is None:
if getattr(self, name, None) is None:
self.__dict__[name] = None
@@ -164,9 +170,11 @@ class ReplayBuffer(object):
f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, "
f"given shape: {inst.shape}.")
if name not in self._meta:
if name == 'info':
inst = deepcopy(inst)
self.__dict__[name][self._index] = inst

def update(self, buffer):
def update(self, buffer: 'ReplayBuffer') -> None:
"""Move the data from the given buffer to self."""
i = begin = buffer._index % len(buffer)
while True:
@@ -178,8 +186,15 @@ class ReplayBuffer(object):
if i == begin:
break

def add(self, obs, act, rew, done, obs_next=None, info={}, policy={},
**kwargs):
def add(self,
obs: Union[dict, np.ndarray],
act: Union[np.ndarray, float],
rew: float,
done: bool,
obs_next: Optional[Union[dict, np.ndarray]] = None,
info: Optional[dict] = {},
policy: Optional[Union[dict, Batch]] = {},
**kwargs) -> None:
"""Add a batch of data into replay buffer."""
assert isinstance(info, dict), \
'You should return a dict in the last argument of env.step().'
@@ -197,11 +212,11 @@ class ReplayBuffer(object):
else:
self._size = self._index = self._index + 1

def reset(self):
def reset(self) -> None:
"""Clear all the data in replay buffer."""
self._index = self._size = 0

def sample(self, batch_size):
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with size equal to batch_size. \
Return all the data in the buffer if batch_size is ``0``.

@@ -216,7 +231,8 @@ class ReplayBuffer(object):
])
return self[indice], indice

def get(self, indice, key, stack_num=None):
def get(self, indice: Union[slice, np.ndarray], key: str,
stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]:
"""Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
where s is self.key, t is indice. The stack_num (here equals to 4) is
given from buffer initialization procedure.
@@ -275,7 +291,7 @@ class ReplayBuffer(object):
stack = np.stack(stack, axis=1)
return stack

def __getitem__(self, index):
def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
"""Return a data batch: self[index]. If stack_num is set to be > 0,
return the stacked obs and obs_next with shape [batch, len, ...].
"""
@@ -302,17 +318,21 @@ class ListReplayBuffer(ReplayBuffer):
detailed explanation.
"""

def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(size=0, ignore_obs_next=False, **kwargs)

def _add_to_buffer(self, name, inst):
def _add_to_buffer(
self, name: str,
inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
if inst is None:
return
if self.__dict__.get(name, None) is None:
self.__dict__[name] = []
if name == 'info':
inst = deepcopy(inst)
self.__dict__[name].append(inst)

def reset(self):
def reset(self) -> None:
self._index = self._size = 0
for k in list(self.__dict__):
if isinstance(self.__dict__[k], list):
@@ -332,8 +352,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
detailed explanation.
"""

def __init__(self, size, alpha: float, beta: float,
mode: str = 'weight', **kwargs):
def __init__(self, size: int, alpha: float, beta: float,
mode: Optional[str] = 'weight', **kwargs) -> None:
if mode != 'weight':
raise NotImplementedError
super().__init__(size, **kwargs)
@@ -344,17 +364,27 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._amortization_freq = 50
self._amortization_counter = 0

def add(self, obs, act, rew, done, obs_next=0, info={}, policy={},
weight=1.0):
def add(self,
obs: Union[dict, np.ndarray],
act: Union[np.ndarray, float],
rew: float,
done: bool,
obs_next: Optional[Union[dict, np.ndarray]] = None,
info: Optional[dict] = {},
policy: Optional[Union[dict, Batch]] = {},
weight: Optional[float] = 1.0,
**kwargs) -> None:
"""Add a batch of data into replay buffer."""
self._weight_sum += np.abs(weight)**self._alpha - \
self._weight_sum += np.abs(weight) ** self._alpha - \
self.weight[self._index]
# we have to sacrifice some convenience for speed :(
self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
super().add(obs, act, rew, done, obs_next, info, policy)
self._check_weight_sum()

def sample(self, batch_size: int = 0, importance_sample: bool = True):
def sample(self, batch_size: Optional[int] = 0,
importance_sample: Optional[bool] = True
) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with priority probability. \
Return all the data in the buffer if batch_size is ``0``.

@@ -388,11 +418,12 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._check_weight_sum()
return batch, indice

def reset(self):
def reset(self) -> None:
self._amortization_counter = 0
super().reset()

def update_weight(self, indice, new_weight: np.ndarray):
def update_weight(self, indice: Union[slice, np.ndarray],
new_weight: np.ndarray) -> None:
"""Update priority weight by indice in this buffer.

:param np.ndarray indice: indice you want to update weight
@@ -402,7 +433,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
- self.weight[indice].sum()
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)

def __getitem__(self, index):
def __getitem__(self, index: Union[slice, np.ndarray]) -> Batch:
return Batch(
obs=self.get(index, 'obs'),
act=self.act[index],
@@ -415,7 +446,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
policy=self.get(index, 'policy'),
)

def _check_weight_sum(self):
def _check_weight_sum(self) -> None:
# keep an accurate _weight_sum
self._amortization_counter += 1
if self._amortization_counter % self._amortization_freq == 0:
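A short usage sketch of the annotated `ReplayBuffer.add` and `sample` (assuming the tianshou 0.2.x API shown in this hunk; the toy transitions are illustrative):

```python
import numpy as np
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=16)
for i in range(4):
    # add(obs, act, rew, done, obs_next, info, ...) -> None, per the annotation
    buf.add(obs=np.array([i, i]), act=0, rew=1.0, done=i == 3,
            obs_next=np.array([i + 1, i + 1]), info={})
batch, indice = buf.sample(batch_size=2)  # Tuple[Batch, np.ndarray]
print(len(buf), batch.obs.shape, indice.shape)
```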

@@ -1,10 +1,13 @@
import gym
import time
import torch
import warnings
import numpy as np
from typing import Any, Dict, List, Union, Optional, Callable

from tianshou.utils import MovAvg
from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer


@@ -77,8 +80,14 @@ class Collector(object):
Please make sure the given environment has a time limitation.
"""

def __init__(self, policy, env, buffer=None, preprocess_fn=None,
stat_size=100, **kwargs):
def __init__(self,
policy: BasePolicy,
env: Union[gym.Env, BaseVectorEnv],
buffer: Optional[Union[ReplayBuffer, List[ReplayBuffer]]]
= None,
preprocess_fn: Callable[[Any], Union[dict, Batch]] = None,
stat_size: Optional[int] = 100,
**kwargs) -> None:
super().__init__()
self.env = env
self.env_num = 1
@@ -112,7 +121,7 @@ class Collector(object):
self.stat_size = stat_size
self.reset()

def reset(self):
def reset(self) -> None:
"""Reset all related variables in the collector."""
self.reset_env()
self.reset_buffer()
@@ -124,7 +133,7 @@ class Collector(object):
self.collect_episode = 0
self.collect_time = 0

def reset_buffer(self):
def reset_buffer(self) -> None:
"""Reset the main data buffer."""
if self._multi_buf:
for b in self.buffer:
@@ -133,11 +142,11 @@ class Collector(object):
if self.buffer is not None:
self.buffer.reset()

def get_env_num(self):
def get_env_num(self) -> int:
"""Return the number of environments the collector have."""
return self.env_num

def reset_env(self):
def reset_env(self) -> None:
"""Reset all of the environment(s)' states and reset all of the cache
buffers (if need).
"""
@@ -155,36 +164,36 @@ class Collector(object):
for b in self._cached_buf:
b.reset()

def seed(self, seed=None):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
"""Reset all the seed(s) of the given environment(s)."""
if hasattr(self.env, 'seed'):
return self.env.seed(seed)

def render(self, **kwargs):
def render(self, **kwargs) -> None:
"""Render all the environment(s)."""
if hasattr(self.env, 'render'):
return self.env.render(**kwargs)

def close(self):
def close(self) -> None:
"""Close the environment(s)."""
if hasattr(self.env, 'close'):
self.env.close()

def _make_batch(self, data):
def _make_batch(self, data: Any) -> Union[Any, np.ndarray]:
"""Return [data]."""
if isinstance(data, np.ndarray):
return data[None]
else:
return np.array([data])

def _reset_state(self, id):
def _reset_state(self, id: Union[int, List[int]]) -> None:
"""Reset self.state[id]."""
if self.state is None:
return
if isinstance(self.state, list):
self.state[id] = None
elif isinstance(self.state, dict):
for k in self.state:
elif isinstance(self.state, dict) or isinstance(self.state, Batch):
for k in self.state.keys():
if isinstance(self.state[k], list):
self.state[k][id] = None
elif isinstance(self.state[k], torch.Tensor) or \
@@ -194,7 +203,8 @@ class Collector(object):
isinstance(self.state, np.ndarray):
self.state[id] = 0

def _to_numpy(self, x):
def _to_numpy(self, x: Union[
torch.Tensor, dict, Batch, np.ndarray]) -> None:
"""Return an object without torch.Tensor."""
if isinstance(x, torch.Tensor):
return x.cpu().numpy()
@@ -208,7 +218,12 @@ class Collector(object):
return x
return x

def collect(self, n_step=0, n_episode=0, render=None, log_fn=None):
def collect(self,
n_step: Optional[int] = 0,
n_episode: Optional[Union[int, List[int]]] = 0,
render: Optional[float] = None,
log_fn: Optional[Callable[[dict], None]] = None
) -> Dict[str, float]:
"""Collect a specified number of step or episode.

:param int n_step: how many steps you want to collect.
@@ -375,7 +390,7 @@ class Collector(object):
'len': length_sum / n_episode,
}

def sample(self, batch_size):
def sample(self, batch_size: int) -> Batch:
"""Sample a data batch from the internal replay buffer. It will call
:meth:`~tianshou.policy.BasePolicy.process_fn` before returning
the final batch data.
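For context, the annotated `Collector` is typically driven as follows. This is a sketch assuming tianshou 0.2.x with gym's CartPole-v0; the tiny `Net` below is illustrative glue, not code from this commit:

```python
import gym
import numpy as np
import torch
from torch import nn
from tianshou.data import Collector, ReplayBuffer
from tianshou.policy import DQNPolicy


class Net(nn.Module):
    """Minimal Q-network returning (logits, state), as tianshou expects."""

    def __init__(self, state_dim: int, action_num: int) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 64), nn.ReLU(),
            nn.Linear(64, action_num))

    def forward(self, obs, state=None, info={}):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        return self.model(obs), state


env = gym.make('CartPole-v0')
net = Net(int(np.prod(env.observation_space.shape)), env.action_space.n)
policy = DQNPolicy(net, torch.optim.Adam(net.parameters(), lr=1e-3))
policy.set_eps(0.1)
collector = Collector(policy, env, ReplayBuffer(size=2000))
stats = collector.collect(n_step=200)     # Dict[str, float], per the annotation
batch = collector.sample(batch_size=64)   # Batch, per the annotation
collector.close()
```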

tianshou/env/vecenv.py | 55 (vendored)

@@ -2,6 +2,7 @@ import gym
import numpy as np
from abc import ABC, abstractmethod
from multiprocessing import Process, Pipe
from typing import List, Tuple, Union, Optional, Callable

try:
import ray
@@ -36,16 +37,16 @@ class BaseVectorEnv(ABC, gym.Wrapper):
envs.close() # close all environments
"""

def __init__(self, env_fns):
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
self._env_fns = env_fns
self.env_num = len(env_fns)

def __len__(self):
def __len__(self) -> int:
"""Return len(self), which is the number of environments."""
return self.env_num

@abstractmethod
def reset(self, id=None):
def reset(self, id: Optional[Union[int, List[int]]] = None):
"""Reset the state of all the environments and return initial
observations if id is ``None``, otherwise reset the specific
environments with given id, either an int or a list.
@@ -53,7 +54,8 @@ class BaseVectorEnv(ABC, gym.Wrapper):
pass

@abstractmethod
def step(self, action):
def step(self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Run one timestep of all the environments' dynamics. When the end of
episode is reached, you are responsible for calling reset(id) to reset
this environment's state.
@@ -76,19 +78,19 @@ class BaseVectorEnv(ABC, gym.Wrapper):
pass

@abstractmethod
def seed(self, seed=None):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
"""Set the seed for all environments. Accept ``None``, an int (which
will extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list.
"""
pass

@abstractmethod
def render(self, **kwargs):
def render(self, **kwargs) -> None:
"""Render all of the environments."""
pass

@abstractmethod
def close(self):
def close(self) -> None:
"""Close all of the environments."""
pass

@@ -102,11 +104,11 @@ class VectorEnv(BaseVectorEnv):
explanation.
"""

def __init__(self, env_fns):
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
super().__init__(env_fns)
self.envs = [_() for _ in env_fns]

def reset(self, id=None):
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
if id is None:
self._obs = np.stack([e.reset() for e in self.envs])
else:
@@ -116,7 +118,8 @@ class VectorEnv(BaseVectorEnv):
self._obs[i] = self.envs[i].reset()
return self._obs

def step(self, action):
def step(self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num
result = [e.step(a) for e, a in zip(self.envs, action)]
self._obs, self._rew, self._done, self._info = zip(*result)
@@ -126,7 +129,7 @@ class VectorEnv(BaseVectorEnv):
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info

def seed(self, seed=None):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
if np.isscalar(seed):
seed = [seed + _ for _ in range(self.env_num)]
elif seed is None:
@@ -137,14 +140,14 @@ class VectorEnv(BaseVectorEnv):
result.append(e.seed(s))
return result

def render(self, **kwargs):
def render(self, **kwargs) -> None:
result = []
for e in self.envs:
if hasattr(e, 'render'):
result.append(e.render(**kwargs))
return result

def close(self):
def close(self) -> None:
return [e.close() for e in self.envs]


@@ -182,7 +185,7 @@ class SubprocVectorEnv(BaseVectorEnv):
explanation.
"""

def __init__(self, env_fns):
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
super().__init__(env_fns)
self.closed = False
self.parent_remote, self.child_remote = \
@@ -198,7 +201,8 @@ class SubprocVectorEnv(BaseVectorEnv):
for c in self.child_remote:
c.close()

def step(self, action):
def step(self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num
for p, a in zip(self.parent_remote, action):
p.send(['step', a])
@@ -210,7 +214,7 @@ class SubprocVectorEnv(BaseVectorEnv):
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info

def reset(self, id=None):
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
if id is None:
for p in self.parent_remote:
p.send(['reset', None])
@@ -225,7 +229,7 @@ class SubprocVectorEnv(BaseVectorEnv):
self._obs[i] = self.parent_remote[i].recv()
return self._obs

def seed(self, seed=None):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
if np.isscalar(seed):
seed = [seed + _ for _ in range(self.env_num)]
elif seed is None:
@@ -234,12 +238,12 @@ class SubprocVectorEnv(BaseVectorEnv):
p.send(['seed', s])
return [p.recv() for p in self.parent_remote]

def render(self, **kwargs):
def render(self, **kwargs) -> None:
for p in self.parent_remote:
p.send(['render', kwargs])
return [p.recv() for p in self.parent_remote]

def close(self):
def close(self) -> None:
if self.closed:
return
for p in self.parent_remote:
@@ -263,7 +267,7 @@ class RayVectorEnv(BaseVectorEnv):
explanation.
"""

def __init__(self, env_fns):
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
super().__init__(env_fns)
try:
if not ray.is_initialized():
@@ -275,7 +279,8 @@ class RayVectorEnv(BaseVectorEnv):
ray.remote(gym.Wrapper).options(num_cpus=0).remote(e())
for e in env_fns]

def step(self, action):
def step(self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
assert len(action) == self.env_num
result = ray.get([e.step.remote(a) for e, a in zip(self.envs, action)])
self._obs, self._rew, self._done, self._info = zip(*result)
@@ -285,7 +290,7 @@ class RayVectorEnv(BaseVectorEnv):
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info

def reset(self, id=None):
def reset(self, id: Optional[Union[int, List[int]]] = None) -> None:
if id is None:
result_obj = [e.reset.remote() for e in self.envs]
self._obs = np.stack(ray.get(result_obj))
@@ -299,7 +304,7 @@ class RayVectorEnv(BaseVectorEnv):
self._obs[i] = ray.get(result_obj[_])
return self._obs

def seed(self, seed=None):
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
if not hasattr(self.envs[0], 'seed'):
return
if np.isscalar(seed):
@@ -308,10 +313,10 @@ class RayVectorEnv(BaseVectorEnv):
seed = [seed] * self.env_num
return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)])

def render(self, **kwargs):
def render(self, **kwargs) -> None:
if not hasattr(self.envs[0], 'render'):
return
return ray.get([e.render.remote(**kwargs) for e in self.envs])

def close(self):
def close(self) -> None:
return ray.get([e.close.remote() for e in self.envs])
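The annotated vectorized-environment API can be used as below (a sketch assuming tianshou 0.2.x and gym's CartPole-v0, whose action space has two discrete actions):

```python
import gym
import numpy as np
from tianshou.env import VectorEnv

# List[Callable[[], gym.Env]], exactly as the new __init__ annotation states
envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
envs.seed(0)                             # a single int is extended to [0, 1, 2, 3]
obs = envs.reset()                       # stacked observations, one row per env
acts = np.random.randint(0, 2, size=4)   # CartPole-v0 has 2 discrete actions
obs, rew, done, info = envs.step(acts)   # four stacked np.ndarray results
envs.close()
```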

@@ -1,4 +1,5 @@
import numpy as np
from typing import Union, Optional


class OUNoise(object):
@@ -17,13 +18,18 @@ class OUNoise(object):
Ornstein-Uhlenbeck process.
"""

def __init__(self, sigma=0.3, theta=0.15, dt=1e-2, x0=None):
def __init__(self,
sigma: Optional[float] = 0.3,
theta: Optional[float] = 0.15,
dt: Optional[float] = 1e-2,
x0: Optional[Union[float, np.ndarray]] = None
) -> None:
self.alpha = theta * dt
self.beta = sigma * np.sqrt(dt)
self.x0 = x0
self.reset()

def __call__(self, size, mu=.1):
def __call__(self, size: tuple, mu: Optional[float] = .1) -> np.ndarray:
"""Generate new noise. Return a ``numpy.ndarray`` which size is equal
to ``size``.
"""
@@ -33,6 +39,6 @@ class OUNoise(object):
self.x = self.x + self.alpha * (mu - self.x) + r
return self.x

def reset(self):
def reset(self) -> None:
"""Reset to the initial state."""
self.x = None
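A quick sketch of the annotated `OUNoise` call signature (assuming the tianshou 0.2.x exploration module; the commented-out import in the DDPG hunk further down confirms the path):

```python
import numpy as np
from tianshou.exploration import OUNoise

noise = OUNoise(sigma=0.3, theta=0.15, dt=1e-2)   # Optional[float] defaults
sample = noise(size=(3,), mu=0.1)                 # np.ndarray of shape (3,)
print(sample.shape)
noise.reset()                                     # back to the initial state
```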

@@ -1,6 +1,10 @@
import torch
import numpy as np
from torch import nn
from abc import ABC, abstractmethod
from typing import Dict, List, Union, Optional

from tianshou.data import Batch, ReplayBuffer


class BasePolicy(ABC, nn.Module):
@@ -39,17 +43,20 @@ class BasePolicy(ABC, nn.Module):
policy.load_state_dict(torch.load('policy.pth'))
"""

def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__()

def process_fn(self, batch, buffer, indice):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
"""Pre-process the data from the provided replay buffer. Check out
:ref:`policy_concept` for more information.
"""
return batch

@abstractmethod
def forward(self, batch, state=None, **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
"""Compute action over the given batch data.

:return: A :class:`~tianshou.data.Batch` which MUST have the following\
@@ -80,7 +87,8 @@ class BasePolicy(ABC, nn.Module):
pass

@abstractmethod
def learn(self, batch, **kwargs):
def learn(self, batch: Batch, **kwargs
) -> Dict[str, Union[float, List[float]]]:
"""Update policy with a given batch of data.

:return: A dict which includes loss and its corresponding label.
@@ -88,8 +96,11 @@ class BasePolicy(ABC, nn.Module):
pass

@staticmethod
def compute_episodic_return(batch, v_s_=None,
gamma=0.99, gae_lambda=0.95):
def compute_episodic_return(
batch: Batch,
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
gamma: Optional[float] = 0.99,
gae_lambda: Optional[float] = 0.95) -> Batch:
"""Compute returns over given full-length episodes, including the
implementation of Generalized Advantage Estimation (arXiv:1506.02438).
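The annotated abstract interface reads as follows when subclassed; `RandomPolicy` is an illustrative stand-in, not part of this commit:

```python
import numpy as np
from typing import Dict, Optional, Union

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class RandomPolicy(BasePolicy):
    """Picks uniformly random discrete actions; for illustration only."""

    def __init__(self, action_num: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.action_num = action_num

    def forward(self, batch: Batch,
                state: Optional[Union[dict, Batch, np.ndarray]] = None,
                **kwargs) -> Batch:
        # The returned Batch must carry at least an ``act`` key.
        act = np.random.randint(self.action_num, size=len(batch.obs))
        return Batch(act=act, state=state)

    def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
        return {'loss': 0.0}
```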

@@ -1,5 +1,7 @@
import torch
import numpy as np
import torch.nn.functional as F
from typing import Dict, Union, Optional

from tianshou.data import Batch
from tianshou.policy import BasePolicy
@@ -19,7 +21,9 @@ class ImitationPolicy(BasePolicy):
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(self, model, optim, mode='continuous'):

def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer,
mode: Optional[str] = 'continuous', **kwargs) -> None:
super().__init__()
self.model = model
self.optim = optim
@@ -27,7 +31,10 @@ class ImitationPolicy(BasePolicy):
f'Mode {mode} is not in ["continuous", "discrete"]'
self.mode = mode

def forward(self, batch, state=None):
def forward(self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
logits, h = self.model(batch.obs, state=state, info=batch.info)
if self.mode == 'discrete':
a = logits.max(dim=1)[1]
@@ -35,7 +42,7 @@ class ImitationPolicy(BasePolicy):
a = logits
return Batch(logits=logits, act=a, state=h)

def learn(self, batch, **kwargs):
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
self.optim.zero_grad()
if self.mode == 'continuous':
a = self(batch).act

@@ -2,9 +2,10 @@ import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Dict, List, Union, Optional

from tianshou.data import Batch
from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer


class A2CPolicy(PGPolicy):
@@ -31,11 +32,19 @@ class A2CPolicy(PGPolicy):
explanation.
"""

def __init__(self, actor, critic, optim,
dist_fn=torch.distributions.Categorical,
discount_factor=0.99, vf_coef=.5, ent_coef=.01,
max_grad_norm=None, gae_lambda=0.95,
reward_normalization=False, **kwargs):
def __init__(self,
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Optional[torch.distributions.Distribution]
= torch.distributions.Categorical,
discount_factor: Optional[float] = 0.99,
vf_coef: Optional[float] = .5,
ent_coef: Optional[float] = .01,
max_grad_norm: Optional[float] = None,
gae_lambda: Optional[float] = 0.95,
reward_normalization: Optional[bool] = False,
**kwargs) -> None:
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
self.actor = actor
self.critic = critic
@@ -48,7 +57,8 @@ class A2CPolicy(PGPolicy):
self._rew_norm = reward_normalization
self.__eps = np.finfo(np.float32).eps.item()

def process_fn(self, batch, buffer, indice):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
if self._lambda in [0, 1]:
return self.compute_episodic_return(
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
@@ -60,7 +70,9 @@ class A2CPolicy(PGPolicy):
return self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)

def forward(self, batch, state=None, **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
"""Compute action over the given batch data.

:return: A :class:`~tianshou.data.Batch` which has 4 keys:
@@ -83,7 +95,8 @@ class A2CPolicy(PGPolicy):
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)

def learn(self, batch, batch_size=None, repeat=1, **kwargs):
def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
self._batch = batch_size
r = batch.returns
if self._rew_norm and r.std() > self.__eps:

@@ -2,10 +2,11 @@ import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Tuple, Union, Optional

from tianshou.data import Batch
from tianshou.policy import BasePolicy
# from tianshou.exploration import OUNoise
from tianshou.data import Batch, ReplayBuffer


class DDPGPolicy(BasePolicy):
@@ -23,7 +24,7 @@ class DDPGPolicy(BasePolicy):
:param float exploration_noise: the noise intensity, add to the action,
defaults to 0.1.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
@@ -35,10 +36,18 @@ class DDPGPolicy(BasePolicy):
explanation.
"""

def __init__(self, actor, actor_optim, critic, critic_optim,
tau=0.005, gamma=0.99, exploration_noise=0.1,
action_range=None, reward_normalization=False,
ignore_done=False, **kwargs):
def __init__(self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer,
tau: Optional[float] = 0.005,
gamma: Optional[float] = 0.99,
exploration_noise: Optional[float] = 0.1,
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: Optional[bool] = False,
ignore_done: Optional[bool] = False,
**kwargs) -> None:
super().__init__(**kwargs)
if actor is not None:
self.actor, self.actor_old = actor, deepcopy(actor)
@@ -64,23 +73,23 @@ class DDPGPolicy(BasePolicy):
self._rew_norm = reward_normalization
self.__eps = np.finfo(np.float32).eps.item()

def set_eps(self, eps):
def set_eps(self, eps: float) -> None:
"""Set the eps for exploration."""
self._eps = eps

def train(self):
def train(self) -> None:
"""Set the module in training mode, except for the target network."""
self.training = True
self.actor.train()
self.critic.train()

def eval(self):
def eval(self) -> None:
"""Set the module in evaluation mode, except for the target network."""
self.training = False
self.actor.eval()
self.critic.eval()

def sync_weight(self):
def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
@@ -88,7 +97,8 @@ class DDPGPolicy(BasePolicy):
self.critic_old.parameters(), self.critic.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)

def process_fn(self, batch, buffer, indice):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
if self._rew_norm:
bfr = buffer.rew[:min(len(buffer), 1000)] # avoid large buffer
mean, std = bfr.mean(), bfr.std()
@@ -98,8 +108,12 @@ class DDPGPolicy(BasePolicy):
batch.done = batch.done * 0.
return batch

def forward(self, batch, state=None,
model='actor', input='obs', eps=None, **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: Optional[str] = 'actor',
input: Optional[str] = 'obs',
eps: Optional[float] = None,
**kwargs) -> Batch:
"""Compute action over the given batch data.

:param float eps: in [0, 1], for exploration use.
@@ -129,7 +143,7 @@ class DDPGPolicy(BasePolicy):
logits = logits.clamp(self._range[0], self._range[1])
return Batch(act=logits, state=h)

def learn(self, batch, **kwargs):
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
with torch.no_grad():
target_q = self.critic_old(batch.obs_next, self(
batch, model='actor_old', input='obs_next', eps=0).act)

@@ -2,9 +2,10 @@ import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Union, Optional

from tianshou.policy import BasePolicy
from tianshou.data import Batch, PrioritizedReplayBuffer
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer


class DQNPolicy(BasePolicy):
@@ -25,8 +26,13 @@ class DQNPolicy(BasePolicy):
explanation.
"""

def __init__(self, model, optim, discount_factor=0.99,
estimation_step=1, target_update_freq=0, **kwargs):
def __init__(self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
discount_factor: Optional[float] = 0.99,
estimation_step: Optional[int] = 1,
target_update_freq: Optional[int] = 0,
**kwargs) -> None:
super().__init__(**kwargs)
self.model = model
self.optim = optim
@@ -42,25 +48,26 @@ class DQNPolicy(BasePolicy):
self.model_old = deepcopy(self.model)
self.model_old.eval()

def set_eps(self, eps):
def set_eps(self, eps: float) -> None:
"""Set the eps for epsilon-greedy exploration."""
self.eps = eps

def train(self):
def train(self) -> None:
"""Set the module in training mode, except for the target network."""
self.training = True
self.model.train()

def eval(self):
def eval(self) -> None:
"""Set the module in evaluation mode, except for the target network."""
self.training = False
self.model.eval()

def sync_weight(self):
def sync_weight(self) -> None:
"""Synchronize the weight for the target network."""
self.model_old.load_state_dict(self.model.state_dict())

def process_fn(self, batch, buffer, indice):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
r"""Compute the n-step return for Q-learning targets:

.. math::
@@ -115,8 +122,12 @@ class DQNPolicy(BasePolicy):
batch.loss += loss
return batch

def forward(self, batch, state=None,
model='model', input='obs', eps=None, **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: Optional[str] = 'model',
input: Optional[str] = 'obs',
eps: Optional[float] = None,
**kwargs) -> Batch:
"""Compute action over the given batch data.

:param float eps: in [0, 1], for epsilon-greedy exploration method.
@@ -144,7 +155,7 @@ class DQNPolicy(BasePolicy):
act[i] = np.random.randint(q.shape[1])
return Batch(logits=q, act=act, state=h)

def learn(self, batch, **kwargs):
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
if self._target and self._cnt % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()

@@ -1,8 +1,9 @@
import torch
import numpy as np
from typing import Dict, List, Union, Optional

from tianshou.data import Batch
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer


class PGPolicy(BasePolicy):
@@ -20,8 +21,14 @@ class PGPolicy(BasePolicy):
explanation.
"""

def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
discount_factor=0.99, reward_normalization=False, **kwargs):
def __init__(self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Optional[torch.distributions.Distribution]
= torch.distributions.Categorical,
discount_factor: Optional[float] = 0.99,
reward_normalization: Optional[bool] = False,
**kwargs) -> None:
super().__init__(**kwargs)
self.model = model
self.optim = optim
@@ -31,7 +38,8 @@ class PGPolicy(BasePolicy):
self._rew_norm = reward_normalization
self.__eps = np.finfo(np.float32).eps.item()

def process_fn(self, batch, buffer, indice):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
r"""Compute the discounted returns for each frame:

.. math::
@@ -46,7 +54,9 @@ class PGPolicy(BasePolicy):
return self.compute_episodic_return(
batch, gamma=self._gamma, gae_lambda=1.)

def forward(self, batch, state=None, **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
"""Compute action over the given batch data.

:return: A :class:`~tianshou.data.Batch` which has 4 keys:
@@ -69,7 +79,8 @@ class PGPolicy(BasePolicy):
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)

def learn(self, batch, batch_size=None, repeat=1, **kwargs):
def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
losses = []
r = batch.returns
if self._rew_norm and r.std() > self.__eps:

@@ -1,9 +1,10 @@
import torch
import numpy as np
from torch import nn
from typing import Dict, List, Tuple, Union, Optional

from tianshou.data import Batch
from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer


class PPOPolicy(PGPolicy):
@@ -23,7 +24,7 @@ class PPOPolicy(PGPolicy):
:param float vf_coef: weight for value loss, defaults to 0.5.
:param float ent_coef: weight for entropy loss, defaults to 0.01.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
:type action_range: (float, float)
:param float gae_lambda: in [0, 1], param for Generalized Advantage
Estimation, defaults to 0.95.
:param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
@@ -40,11 +41,22 @@ class PPOPolicy(PGPolicy):
explanation.
"""

def __init__(self, actor, critic, optim, dist_fn,
discount_factor=0.99, max_grad_norm=.5, eps_clip=.2,
vf_coef=.5, ent_coef=.01, action_range=None, gae_lambda=0.95,
dual_clip=5., value_clip=True, reward_normalization=True,
**kwargs):
def __init__(self,
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: torch.distributions.Distribution,
discount_factor: Optional[float] = 0.99,
max_grad_norm: Optional[float] = None,
eps_clip: Optional[float] = .2,
vf_coef: Optional[float] = .5,
ent_coef: Optional[float] = .01,
action_range: Optional[Tuple[float, float]] = None,
gae_lambda: Optional[float] = 0.95,
dual_clip: Optional[float] = 5.,
value_clip: Optional[bool] = True,
reward_normalization: Optional[bool] = True,
**kwargs) -> None:
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
self._max_grad_norm = max_grad_norm
self._eps_clip = eps_clip
@@ -64,7 +76,8 @@ class PPOPolicy(PGPolicy):
self._rew_norm = reward_normalization
self.__eps = np.finfo(np.float32).eps.item()

def process_fn(self, batch, buffer, indice):
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
if self._rew_norm:
mean, std = batch.rew.mean(), batch.rew.std()
if std > self.__eps:
@@ -80,7 +93,9 @@ class PPOPolicy(PGPolicy):
return self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)

def forward(self, batch, state=None, **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
"""Compute action over the given batch data.

:return: A :class:`~tianshou.data.Batch` which has 4 keys:
@@ -105,7 +120,8 @@ class PPOPolicy(PGPolicy):
act = act.clamp(self._range[0], self._range[1])
return Batch(logits=logits, act=act, state=h, dist=dist)

def learn(self, batch, batch_size=None, repeat=1, **kwargs):
def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
self._batch = batch_size
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
v = []

@@ -2,6 +2,7 @@ import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Tuple, Union, Optional

from tianshou.data import Batch
from tianshou.policy import DDPGPolicy
@@ -28,7 +29,7 @@ class SACPolicy(DDPGPolicy):
defaults to 0.1.
:param float alpha: entropy regularization coefficient, default to 0.2.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
@@ -40,10 +41,20 @@ class SACPolicy(DDPGPolicy):
explanation.
"""

def __init__(self, actor, actor_optim, critic1, critic1_optim,
critic2, critic2_optim, tau=0.005, gamma=0.99,
alpha=0.2, action_range=None, reward_normalization=False,
ignore_done=False, **kwargs):
def __init__(self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic1: torch.nn.Module,
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
tau: Optional[float] = 0.005,
gamma: Optional[float] = 0.99,
alpha: Optional[float] = 0.2,
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: Optional[bool] = False,
ignore_done: Optional[bool] = False,
**kwargs) -> None:
super().__init__(None, None, None, None, tau, gamma, 0,
action_range, reward_normalization, ignore_done,
**kwargs)
@@ -57,19 +68,19 @@ class SACPolicy(DDPGPolicy):
self._alpha = alpha
self.__eps = np.finfo(np.float32).eps.item()

def train(self):
def train(self) -> None:
self.training = True
self.actor.train()
self.critic1.train()
self.critic2.train()

def eval(self):
def eval(self) -> None:
self.training = False
self.actor.eval()
self.critic1.eval()
self.critic2.eval()

def sync_weight(self):
def sync_weight(self) -> None:
for o, n in zip(
self.critic1_old.parameters(), self.critic1.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
@@ -77,7 +88,9 @@ class SACPolicy(DDPGPolicy):
self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)

def forward(self, batch, state=None, input='obs', **kwargs):
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
input: Optional[str] = 'obs', **kwargs) -> Batch:
obs = getattr(batch, input)
logits, h = self.actor(obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
@@ -92,7 +105,7 @@ class SACPolicy(DDPGPolicy):
return Batch(
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)

def learn(self, batch, **kwargs):
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
with torch.no_grad():
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act

@@ -1,7 +1,9 @@
import torch
from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Tuple, Optional

from tianshou.data import Batch
from tianshou.policy import DDPGPolicy


@@ -32,7 +34,7 @@ class TD3Policy(DDPGPolicy):
:param float noise_clip: the clipping range used in updating policy
network, default to 0.5.
:param action_range: the action range (minimum, maximum).
:type action_range: [float, float]
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
@@ -44,11 +46,23 @@ class TD3Policy(DDPGPolicy):
explanation.
"""

def __init__(self, actor, actor_optim, critic1, critic1_optim,
critic2, critic2_optim, tau=0.005, gamma=0.99,
exploration_noise=0.1, policy_noise=0.2, update_actor_freq=2,
noise_clip=0.5, action_range=None,
reward_normalization=False, ignore_done=False, **kwargs):
def __init__(self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic1: torch.nn.Module,
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
tau: Optional[float] = 0.005,
gamma: Optional[float] = 0.99,
exploration_noise: Optional[float] = 0.1,
policy_noise: Optional[float] = 0.2,
update_actor_freq: Optional[int] = 2,
noise_clip: Optional[float] = 0.5,
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: Optional[bool] = False,
ignore_done: Optional[bool] = False,
**kwargs) -> None:
super().__init__(actor, actor_optim, None, None, tau, gamma,
exploration_noise, action_range, reward_normalization,
ignore_done, **kwargs)
@@ -64,19 +78,19 @@ class TD3Policy(DDPGPolicy):
self._cnt = 0
self._last = 0

def train(self):
def train(self) -> None:
self.training = True
self.actor.train()
self.critic1.train()
self.critic2.train()

def eval(self):
def eval(self) -> None:
self.training = False
self.actor.eval()
self.critic1.eval()
self.critic2.eval()

def sync_weight(self):
def sync_weight(self) -> None:
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
for o, n in zip(
@@ -86,7 +100,7 @@ class TD3Policy(DDPGPolicy):
self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)

def learn(self, batch, **kwargs):
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
with torch.no_grad():
a_ = self(batch, model='actor_old', input='obs_next').act
dev = a_.device

@@ -1,16 +1,33 @@
import time
import tqdm
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional

from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info


def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
step_per_epoch, collect_per_step, episode_per_test,
batch_size,
train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
log_fn=None, writer=None, log_interval=1, verbose=True,
**kwargs):
def offpolicy_trainer(
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
max_epoch: int,
step_per_epoch: int,
collect_per_step: int,
episode_per_test: Union[int, List[int]],
batch_size: int,
train_fn: Optional[Callable[[int], None]] = None,
test_fn: Optional[Callable[[int], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
log_fn: Optional[Callable[[dict], None]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: Optional[int] = 1,
verbose: Optional[bool] = True,
**kwargs
) -> Dict[str, Union[float, str]]:
"""A wrapper for off-policy trainer procedure.

:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`
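The callback types spelled out above (`Callable[[BasePolicy], None]`, `Callable[[float], bool]`) match the helpers used in the SAC test earlier in this commit. A hedged sketch with an illustrative log path and reward threshold:

```python
import os
import torch
from tianshou.policy import BasePolicy

log_path = 'log/example'          # illustrative path, not from the commit
os.makedirs(log_path, exist_ok=True)


def save_fn(policy: BasePolicy) -> None:
    # Callable[[BasePolicy], None]: persist the current best policy
    torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))


def stop_fn(mean_reward: float) -> bool:
    # Callable[[float], bool]: stop once a reward threshold is reached
    return mean_reward >= 195.0   # illustrative threshold
```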

@@ -1,16 +1,34 @@
import time
import tqdm
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, List, Union, Callable, Optional

from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info


def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
step_per_epoch, collect_per_step, repeat_per_collect,
episode_per_test, batch_size,
train_fn=None, test_fn=None, stop_fn=None, save_fn=None,
log_fn=None, writer=None, log_interval=1, verbose=True,
**kwargs):
def onpolicy_trainer(
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
max_epoch: int,
step_per_epoch: int,
collect_per_step: int,
repeat_per_collect: int,
episode_per_test: Union[int, List[int]],
batch_size: int,
train_fn: Optional[Callable[[int], None]] = None,
test_fn: Optional[Callable[[int], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
log_fn: Optional[Callable[[dict], None]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: Optional[int] = 1,
verbose: Optional[bool] = True,
**kwargs
) -> Dict[str, Union[float, str]]:
"""A wrapper for on-policy trainer procedure.

:param policy: an instance of the :class:`~tianshou.policy.BasePolicy`

@@ -1,8 +1,17 @@
import time
import numpy as np
from typing import Dict, List, Union, Callable

from tianshou.data import Collector
from tianshou.policy import BasePolicy


def test_episode(policy, collector, test_fn, epoch, n_episode):
def test_episode(
policy: BasePolicy,
collector: Collector,
test_fn: Callable[[int], None],
epoch: int,
n_episode: Union[int, List[int]]) -> Dict[str, float]:
"""A simple wrapper of testing policy in collector."""
collector.reset_env()
collector.reset_buffer()
@@ -17,7 +26,11 @@ def test_episode(policy, collector, test_fn, epoch, n_episode):
return collector.collect(n_episode=n_episode)


def gather_info(start_time, train_c, test_c, best_reward):
def gather_info(start_time: float,
train_c: Collector,
test_c: Collector,
best_reward: float
) -> Dict[str, Union[float, str]]:
"""A simple wrapper of gathering information from collectors.

:return: A dictionary with the following keys:

@@ -1,5 +1,6 @@
import torch
import numpy as np
from typing import Union, Optional


class MovAvg(object):
@@ -19,19 +20,20 @@ class MovAvg(object):
>>> print(f'{stat.mean():.2f}±{stat.std():.2f}')
6.50±1.12
"""
def __init__(self, size=100):

def __init__(self, size: Optional[int] = 100) -> None:
super().__init__()
self.size = size
self.cache = []
self.banned = [np.inf, np.nan, -np.inf]

def add(self, x):
def add(self, x: Union[float, list, np.ndarray, torch.Tensor]) -> float:
"""Add a scalar into :class:`MovAvg`. You can add ``torch.Tensor`` with
only one element, a python scalar, or a list of python scalar.
"""
if isinstance(x, torch.Tensor):
x = x.item()
if isinstance(x, list):
if isinstance(x, list) or isinstance(x, np.ndarray):
for _ in x:
if _ not in self.banned:
self.cache.append(_)
@@ -41,17 +43,17 @@ class MovAvg(object):
self.cache = self.cache[-self.size:]
return self.get()

def get(self):
def get(self) -> float:
"""Get the average."""
if len(self.cache) == 0:
return 0
return np.mean(self.cache)

def mean(self):
def mean(self) -> float:
"""Get the average. Same as :meth:`get`."""
return self.get()

def std(self):
def std(self) -> float:
"""Get the standard deviation."""
if len(self.cache) == 0:
return 0
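Finally, the annotated `MovAvg` in use; the numbers reproduce the docstring example above (6.50±1.12), and the np.ndarray branch is the one this commit adds:

```python
import numpy as np
from tianshou.utils import MovAvg

stat = MovAvg(size=100)
stat.add(5.0)
stat.add([6, 7, 8])             # lists are flattened element by element
stat.add(np.array([np.inf]))    # np.ndarray is now accepted; inf/nan are banned
print(f'{stat.mean():.2f} +/- {stat.std():.2f}')   # 6.50 +/- 1.12
```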