code format and update function signatures (#213)

Cherry-picked from #200:

- update the function signatures
- format the code style
- move `_compile` into separate functions
- fix a bug in `to_torch` and `to_numpy` (Batch); see the usage sketch below
- remove `None` in `action_range`

In short, the code-format changes only cover function-signature style and `'` -> `"` (picked up from [black](https://github.com/psf/black)).
Author: n+e, 2020-09-12 15:39:01 +08:00, committed by GitHub
Parent: 16d8e9b051
Commit: c91def6cbc
51 changed files with 1325 additions and 991 deletions
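
As a quick illustration of the `to_torch`/`to_numpy` fix listed above: the converters now return the converted object for every supported input type (Batch, dict, list/tuple, ndarray, tensor, scalar); the dict branch, for instance, builds a new dict instead of mutating its argument. A minimal usage sketch with toy data (not taken from the diff):

```python
import numpy as np
import torch
from tianshou.data import Batch, to_numpy, to_torch

# toy data for illustration only
b = Batch(obs=np.zeros((4, 3)), act=np.ones(4, dtype=np.int64))
t = to_torch(b, dtype=torch.float32, device="cpu")  # every ndarray becomes a torch.Tensor
n = to_numpy(t)                                      # and back to np.ndarray
assert isinstance(t.obs, torch.Tensor) and isinstance(n.obs, np.ndarray)
```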

View File

@ -119,8 +119,8 @@ def test_sac_bipedal(args=get_args()):
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.alpha,
[env.action_space.low[0], env.action_space.high[0]],
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step)

View File

@ -78,14 +78,12 @@ def test_sac(args=get_args()):
target_entropy = -np.prod(env.action_space.shape)
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
alpha = (target_entropy, log_alpha, alpha_optim)
else:
alpha = args.alpha
args.alpha = (target_entropy, log_alpha, alpha_optim)
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, alpha,
[env.action_space.low[0], env.action_space.high[0]],
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm, ignore_done=True,
exploration_noise=OUNoise(0.0, args.noise_std))
# collector

View File

@ -66,8 +66,9 @@ def test_ddpg(args=get_args()):
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy = DDPGPolicy(
actor, actor_optim, critic, critic_optim,
args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
[env.action_space.low[0], env.action_space.high[0]],
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
reward_normalization=True, ignore_done=True)
# collector
train_collector = Collector(

View File

@ -71,8 +71,8 @@ def test_sac(args=get_args()):
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.alpha,
[env.action_space.low[0], env.action_space.high[0]],
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm, ignore_done=True)
# collector
train_collector = Collector(

View File

@ -73,10 +73,12 @@ def test_td3(args=get_args()):
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma,
GaussianNoise(sigma=args.exploration_noise), args.policy_noise,
args.update_actor_freq, args.noise_clip,
[env.action_space.low[0], env.action_space.high[0]],
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
policy_noise=args.policy_noise,
update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip,
reward_normalization=True, ignore_done=True)
# collector
train_collector = Collector(

View File

@ -79,8 +79,8 @@ def test_sac(args=get_args()):
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.alpha,
[env.action_space.low[0], env.action_space.high[0]],
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=True, ignore_done=True)
# collector
train_collector = Collector(

View File

@ -76,10 +76,12 @@ def test_td3(args=get_args()):
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma,
GaussianNoise(sigma=args.exploration_noise), args.policy_noise,
args.update_actor_freq, args.noise_clip,
[env.action_space.low[0], env.action_space.high[0]],
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
policy_noise=args.policy_noise,
update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip,
reward_normalization=True, ignore_done=True)
# collector
train_collector = Collector(

View File

@ -78,8 +78,9 @@ def test_ddpg(args=get_args()):
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy = DDPGPolicy(
actor, actor_optim, critic, critic_optim,
args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
[env.action_space.low[0], env.action_space.high[0]],
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step)

View File

@ -79,8 +79,8 @@ def test_sac_with_il(args=get_args()):
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.alpha,
[env.action_space.low[0], env.action_space.high[0]],
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step)

View File

@ -82,9 +82,12 @@ def test_td3(args=get_args()):
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
args.policy_noise, args.update_actor_freq, args.noise_clip,
[env.action_space.low[0], env.action_space.high[0]],
action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
policy_noise=args.policy_noise,
update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip,
reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step)

View File

@ -1,17 +1,13 @@
from tianshou import data, env, utils, policy, trainer, exploration
# pre-compile some common-type function-call to produce the correct benchmark
# result: https://github.com/thu-ml/tianshou/pull/193#discussion_r480536371
utils.pre_compile()
__version__ = '0.2.7'
__version__ = "0.2.7"
__all__ = [
'env',
'data',
'utils',
'policy',
'trainer',
'exploration',
"env",
"data",
"utils",
"policy",
"trainer",
"exploration",
]

View File

@ -1,19 +1,18 @@
from tianshou.data.batch import Batch
from tianshou.data.utils.converter import to_numpy, to_torch, \
to_torch_as
from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
from tianshou.data.utils.segtree import SegmentTree
from tianshou.data.buffer import ReplayBuffer, \
ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.collector import Collector
__all__ = [
'Batch',
'to_numpy',
'to_torch',
'to_torch_as',
'SegmentTree',
'ReplayBuffer',
'ListReplayBuffer',
'PrioritizedReplayBuffer',
'Collector',
"Batch",
"to_numpy",
"to_torch",
"to_torch_as",
"SegmentTree",
"ReplayBuffer",
"ListReplayBuffer",
"PrioritizedReplayBuffer",
"Collector",
]

View File

@ -5,8 +5,8 @@ import numpy as np
from copy import deepcopy
from numbers import Number
from collections.abc import Collection
from typing import Any, List, Tuple, Union, Iterator, KeysView, ValuesView, \
ItemsView, Optional
from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
Sequence, KeysView, ValuesView, ItemsView
# Disable pickle warning related to torch, since it has been removed
# on torch master branch. See Pull Request #39003 for details:
@ -23,8 +23,8 @@ def _is_batch_set(data: Any) -> bool:
# "for e in data" will just unpack the first dimension,
# but data.tolist() will flatten ndarray of objects
# so do not use data.tolist()
return data.dtype == np.object and \
all(isinstance(e, (dict, Batch)) for e in data)
return data.dtype == np.object and all(
isinstance(e, (dict, Batch)) for e in data)
elif isinstance(data, (list, tuple)):
if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
return True
@ -54,8 +54,9 @@ def _is_number(value: Any) -> bool:
def _to_array_with_correct_type(v: Any) -> np.ndarray:
if isinstance(v, np.ndarray) and \
issubclass(v.dtype.type, (np.bool_, np.number)): # most often case
if isinstance(v, np.ndarray) and issubclass(
v.dtype.type, (np.bool_, np.number)
): # most often case
return v
# convert the value to np.ndarray
# convert to np.object data type if neither bool nor number
@ -71,14 +72,16 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray:
# array([{}, array({}, dtype=object)], dtype=object)
if not v.shape:
v = v.item(0)
elif any(isinstance(e, (np.ndarray, torch.Tensor))
for e in v.reshape(-1)):
elif any(
isinstance(e, (np.ndarray, torch.Tensor)) for e in v.reshape(-1)
):
raise ValueError("Numpy arrays of tensors are not supported yet.")
return v
def _create_value(inst: Any, size: int, stack=True) -> Union[
'Batch', np.ndarray, torch.Tensor]:
def _create_value(
inst: Any, size: int, stack: bool = True
) -> Union["Batch", np.ndarray, torch.Tensor]:
"""Create empty place-holders accroding to inst's shape.
:param bool stack: whether to stack or to concatenate. E.g. if inst has
@ -100,12 +103,15 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[
target_type = inst.dtype.type
else:
target_type = np.object
return np.full(shape,
fill_value=None if target_type == np.object else 0,
dtype=target_type)
return np.full(
shape,
fill_value=None if target_type == np.object else 0,
dtype=target_type
)
elif isinstance(inst, torch.Tensor):
return torch.full(shape,
fill_value=0, device=inst.device, dtype=inst.dtype)
return torch.full(
shape, fill_value=0, device=inst.device, dtype=inst.dtype
)
elif isinstance(inst, (dict, Batch)):
zero_batch = Batch()
for key, val in inst.items():
@ -117,12 +123,13 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[
return np.array([None for _ in range(size)])
def _assert_type_keys(keys) -> None:
assert all(isinstance(e, str) for e in keys), \
f"keys should all be string, but got {keys}"
def _assert_type_keys(keys: Iterable[str]) -> None:
assert all(
isinstance(e, str) for e in keys
), f"keys should all be string, but got {keys}"
def _parse_value(v: Any):
def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
if isinstance(v, Batch): # most often case
return v
elif (isinstance(v, np.ndarray) and
@ -166,12 +173,14 @@ class Batch:
For a detailed description, please refer to :ref:`batch_concept`.
"""
def __init__(self,
batch_dict: Optional[Union[
dict, 'Batch', Tuple[Union[dict, 'Batch']],
List[Union[dict, 'Batch']], np.ndarray]] = None,
copy: bool = False,
**kwargs) -> None:
def __init__(
self,
batch_dict: Optional[
Union[dict, "Batch", Sequence[Union[dict, "Batch"]], np.ndarray]
] = None,
copy: bool = False,
**kwargs: Any,
) -> None:
if copy:
batch_dict = deepcopy(batch_dict)
if batch_dict is not None:
@ -188,7 +197,7 @@ class Batch:
"""Set self.key = value."""
self.__dict__[key] = _parse_value(value)
def __getstate__(self) -> dict:
def __getstate__(self) -> Dict[str, Any]:
"""Pickling interface.
Only the actual data are serialized for both efficiency and simplicity.
@ -200,7 +209,7 @@ class Batch:
state[k] = v
return state
def __setstate__(self, state) -> None:
def __setstate__(self, state: Dict[str, Any]) -> None:
"""Unpickling interface.
At this point, self is an empty Batch instance that has not been
@ -208,8 +217,9 @@ class Batch:
"""
self.__init__(**state)
def __getitem__(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]]) -> 'Batch':
def __getitem__(
self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]]
) -> Union["Batch", np.ndarray, torch.Tensor]:
"""Return self[index]."""
if isinstance(index, str):
return self.__dict__[index]
@ -225,9 +235,11 @@ class Batch:
else:
raise IndexError("Cannot access item from empty Batch object.")
def __setitem__(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]],
value: Any) -> None:
def __setitem__(
self,
index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
value: Any,
) -> None:
"""Assign value to self[index]."""
value = _parse_value(value)
if isinstance(index, str):
@ -252,12 +264,12 @@ class Batch:
else:
self.__dict__[key][index] = None
def __iadd__(self, other: Union['Batch', Number, np.number]):
def __iadd__(self, other: Union["Batch", Number, np.number]) -> "Batch":
"""Algebraic addition with another Batch instance in-place."""
if isinstance(other, Batch):
for (k, r), v in zip(self.__dict__.items(),
other.__dict__.values()):
# TODO are keys consistent?
for (k, r), v in zip(
self.__dict__.items(), other.__dict__.values()
): # TODO are keys consistent?
if isinstance(r, Batch) and r.is_empty():
continue
else:
@ -273,11 +285,11 @@ class Batch:
else:
raise TypeError("Only addition of Batch or number is supported.")
def __add__(self, other: Union['Batch', Number, np.number]):
def __add__(self, other: Union["Batch", Number, np.number]) -> "Batch":
"""Algebraic addition with another Batch instance out-of-place."""
return deepcopy(self).__iadd__(other)
def __imul__(self, val: Union[Number, np.number]):
def __imul__(self, val: Union[Number, np.number]) -> "Batch":
"""Algebraic multiplication with a scalar value in-place."""
assert _is_number(val), "Only multiplication by a number is supported."
for k, r in self.__dict__.items():
@ -286,11 +298,11 @@ class Batch:
self.__dict__[k] *= val
return self
def __mul__(self, val: Union[Number, np.number]):
def __mul__(self, val: Union[Number, np.number]) -> "Batch":
"""Algebraic multiplication with a scalar value out-of-place."""
return deepcopy(self).__imul__(val)
def __itruediv__(self, val: Union[Number, np.number]):
def __itruediv__(self, val: Union[Number, np.number]) -> "Batch":
"""Algebraic division with a scalar value in-place."""
assert _is_number(val), "Only division by a number is supported."
for k, r in self.__dict__.items():
@ -299,23 +311,23 @@ class Batch:
self.__dict__[k] /= val
return self
def __truediv__(self, val: Union[Number, np.number]):
def __truediv__(self, val: Union[Number, np.number]) -> "Batch":
"""Algebraic division with a scalar value out-of-place."""
return deepcopy(self).__itruediv__(val)
def __repr__(self) -> str:
"""Return str(self)."""
s = self.__class__.__name__ + '(\n'
s = self.__class__.__name__ + "(\n"
flag = False
for k, v in self.__dict__.items():
rpl = '\n' + ' ' * (6 + len(k))
obj = pprint.pformat(v).replace('\n', rpl)
s += f' {k}: {obj},\n'
rpl = "\n" + " " * (6 + len(k))
obj = pprint.pformat(v).replace("\n", rpl)
s += f" {k}: {obj},\n"
flag = True
if flag:
s += ')'
s += ")"
else:
s = self.__class__.__name__ + '()'
s = self.__class__.__name__ + "()"
return s
def __contains__(self, key: str) -> bool:
@ -350,8 +362,11 @@ class Batch:
elif isinstance(v, Batch):
v.to_numpy()
def to_torch(self, dtype: Optional[torch.dtype] = None,
device: Union[str, int, torch.device] = 'cpu') -> None:
def to_torch(
self,
dtype: Optional[torch.dtype] = None,
device: Union[str, int, torch.device] = "cpu",
) -> None:
"""Change all numpy.ndarray to torch.Tensor in-place."""
if not isinstance(device, torch.device):
device = torch.device(device)
@ -376,9 +391,9 @@ class Batch:
v = v.type(dtype)
self.__dict__[k] = v
def __cat(self,
batches: List[Union[dict, 'Batch']],
lens: List[int]) -> None:
def __cat(
self, batches: Sequence[Union[dict, "Batch"]], lens: List[int]
) -> None:
"""Private method for Batch.cat_.
::
@ -445,8 +460,9 @@ class Batch:
val, sum_lens[-1], stack=False)
self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val
def cat_(self,
batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None:
def cat_(
self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]
) -> None:
"""Concatenate a list of (or one) Batch objects into current batch."""
if isinstance(batches, Batch):
batches = [batches]
@ -460,20 +476,19 @@ class Batch:
# x.is_empty(recurse=True) here means x is a nested empty batch
# like Batch(a=Batch), and we have to treat it as length zero and
# keep it.
lens = [0 if x.is_empty(recurse=True) else len(x)
for x in batches]
lens = [0 if x.is_empty(recurse=True) else len(x) for x in batches]
except TypeError as e:
raise ValueError(
f'Batch.cat_ meets an exception. Maybe because there is any '
f'scalar in {batches} but Batch.cat_ does not support the '
f'concatenation of scalar.') from e
"Batch.cat_ meets an exception. Maybe because there is any "
f"scalar in {batches} but Batch.cat_ does not support the "
"concatenation of scalar.") from e
if not self.is_empty():
batches = [self] + list(batches)
lens = [0 if self.is_empty(recurse=True) else len(self)] + lens
return self.__cat(batches, lens)
self.__cat(batches, lens)
@staticmethod
def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
def cat(batches: Sequence[Union[dict, "Batch"]]) -> "Batch":
"""Concatenate a list of Batch object into a single new batch.
For keys that are not shared across all batches, batches that do not
@ -494,9 +509,9 @@ class Batch:
batch.cat_(batches)
return batch
def stack_(self,
batches: List[Union[dict, 'Batch']],
axis: int = 0) -> None:
def stack_(
self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0
) -> None:
"""Stack a list of Batch object into current batch."""
if len(batches) == 0:
return
@ -528,8 +543,8 @@ class Batch:
keys_partial = keys_reserve_or_partial.difference(keys_reserve)
if keys_partial and axis != 0:
raise ValueError(
f"Stack of Batch with non-shared keys {keys_partial} "
f"is only supported with axis=0, but got axis={axis}!")
f"Stack of Batch with non-shared keys {keys_partial} is only "
f"supported with axis=0, but got axis={axis}!")
for k in keys_reserve:
# reserved keys
self.__dict__[k] = Batch()
@ -547,7 +562,9 @@ class Batch:
self.__dict__[k][i] = val
@staticmethod
def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
def stack(
batches: Sequence[Union[dict, "Batch"]], axis: int = 0
) -> "Batch":
"""Stack a list of Batch object into a single new batch.
For keys that are not shared across all batches, batches that do not
@ -573,9 +590,12 @@ class Batch:
batch.stack_(batches, axis)
return batch
def empty_(self, index: Union[
str, slice, int, np.integer, np.ndarray, List[int]] = None
) -> 'Batch':
def empty_(
self,
index: Union[
str, slice, int, np.integer, np.ndarray, List[int]
] = None,
) -> "Batch":
"""Return an empty Batch object with 0 or None filled.
If "index" is specified, it will only reset the specific indexed-data.
@ -613,8 +633,8 @@ class Batch:
elif isinstance(v, Batch):
self.__dict__[k].empty_(index=index)
else: # scalar value
warnings.warn('You are calling Batch.empty on a NumPy scalar, '
'which may cause undefined behaviors.')
warnings.warn("You are calling Batch.empty on a NumPy scalar, "
"which may cause undefined behaviors.")
if _is_number(v):
self.__dict__[k] = v.__class__(0)
else:
@ -622,17 +642,21 @@ class Batch:
return self
@staticmethod
def empty(batch: 'Batch', index: Union[
str, slice, int, np.integer, np.ndarray, List[int]] = None
) -> 'Batch':
def empty(
batch: "Batch",
index: Union[
str, slice, int, np.integer, np.ndarray, List[int]
] = None,
) -> "Batch":
"""Return an empty Batch object with 0 or None filled.
The shape is the same as the given Batch.
"""
return deepcopy(batch).empty_(index)
def update(self, batch: Optional[Union[dict, 'Batch']] = None,
**kwargs) -> None:
def update(
self, batch: Optional[Union[dict, "Batch"]] = None, **kwargs: Any
) -> None:
"""Update this batch from another dict/Batch."""
if batch is None:
self.update(kwargs)
@ -648,8 +672,9 @@ class Batch:
for v in self.__dict__.values():
if isinstance(v, Batch) and v.is_empty(recurse=True):
continue
elif hasattr(v, '__len__') and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
elif hasattr(v, "__len__") and (not isinstance(
v, (np.ndarray, torch.Tensor)) or v.ndim > 0
):
r.append(len(v))
else:
raise TypeError(f"Object {v} in {self} has no len()")
@ -659,7 +684,7 @@ class Batch:
raise TypeError(f"Object {self} has no len()")
return min(r)
def is_empty(self, recurse: bool = False):
def is_empty(self, recurse: bool = False) -> bool:
"""Test if a Batch is empty.
If ``recurse=True``, it further tests the values of the object; else
@ -689,8 +714,9 @@ class Batch:
return True
if not recurse:
return False
return all(False if not isinstance(x, Batch)
else x.is_empty(recurse=True) for x in self.values())
return all(
False if not isinstance(x, Batch) else x.is_empty(recurse=True)
for x in self.values())
@property
def shape(self) -> List[int]:
@ -707,8 +733,9 @@ class Batch:
return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
else data_shape[0]
def split(self, size: int, shuffle: bool = True,
merge_last: bool = False) -> Iterator['Batch']:
def split(
self, size: int, shuffle: bool = True, merge_last: bool = False
) -> Iterator["Batch"]:
"""Split whole data into multiple small batches.
:param int size: divide the data batch with the given size, but one

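The batch.py changes above are mostly type-annotation and formatting updates. For reference, a small usage sketch of the `Batch.cat`/`split` interface being annotated (toy data, assuming only the public API visible in the diff):

```python
import numpy as np
from tianshou.data import Batch

b1 = Batch(obs=np.zeros((2, 3)), rew=np.array([0.0, 1.0]))
b2 = Batch(obs=np.ones((2, 3)), rew=np.array([2.0, 3.0]))
b = Batch.cat([b1, b2])                      # concatenate along the first axis
for mini in b.split(size=2, shuffle=False):  # iterate over mini-batches of size 2
    assert len(mini) == 2
```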
View File

@ -1,6 +1,7 @@
import torch
import numpy as np
from typing import Any, Tuple, Union, Optional
from numbers import Number
from typing import Any, Dict, Tuple, Union, Optional
from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.batch import _create_value
@ -11,7 +12,7 @@ class ReplayBuffer:
interaction between the policy and environment.
The current implementation of Tianshou typically use 7 reserved keys in
:class:`~tianshou.data.Batch`
:class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ;
* ``act`` the action of step :math:`t` ;
@ -124,14 +125,17 @@ class ReplayBuffer:
This feature is not supported in Prioritized Replay Buffer currently.
"""
def __init__(self, size: int, stack_num: int = 1,
ignore_obs_next: bool = False,
save_only_last_obs: bool = False,
sample_avail: bool = False) -> None:
def __init__(
self,
size: int,
stack_num: int = 1,
ignore_obs_next: bool = False,
save_only_last_obs: bool = False,
sample_avail: bool = False,
) -> None:
super().__init__()
self._maxsize = size
self._indices = np.arange(size)
self._stack = None
self.stack_num = stack_num
self._avail = sample_avail and stack_num > 1
self._avail_index = []
@ -157,7 +161,7 @@ class ReplayBuffer:
except KeyError as e:
raise AttributeError from e
def __setstate__(self, state):
def __setstate__(self, state: Dict[str, Any]) -> None:
"""Unpickling interface.
We need it because pickling buffer does not work out-of-the-box
@ -171,11 +175,12 @@ class ReplayBuffer:
except KeyError:
self._meta.__dict__[name] = _create_value(inst, self._maxsize)
value = self._meta.__dict__[name]
if isinstance(inst, (np.ndarray, torch.Tensor)) \
and value.shape[1:] != inst.shape:
if isinstance(inst, (torch.Tensor, np.ndarray)) \
and inst.shape != value.shape[1:]:
raise ValueError(
"Cannot add data to a buffer with different shape, with key "
f"{name}, expect {value.shape[1:]}, given {inst.shape}.")
f"{name}, expect {value.shape[1:]}, given {inst.shape}."
)
try:
value[self._index] = inst
except KeyError:
@ -184,15 +189,15 @@ class ReplayBuffer:
value[self._index] = inst
@property
def stack_num(self):
def stack_num(self) -> int:
return self._stack
@stack_num.setter
def stack_num(self, num):
assert num > 0, 'stack_num should greater than 0'
def stack_num(self, num: int) -> None:
assert num > 0, "stack_num should greater than 0"
self._stack = num
def update(self, buffer: 'ReplayBuffer') -> None:
def update(self, buffer: "ReplayBuffer") -> None:
"""Move the data from the given buffer to self."""
if len(buffer) == 0:
return
@ -206,32 +211,35 @@ class ReplayBuffer:
break
buffer.stack_num = stack_num_orig
def add(self,
obs: Union[dict, Batch, np.ndarray, float],
act: Union[dict, Batch, np.ndarray, float],
rew: Union[int, float],
done: Union[bool, int],
obs_next: Optional[Union[dict, Batch, np.ndarray, float]] = None,
info: Optional[Union[dict, Batch]] = {},
policy: Optional[Union[dict, Batch]] = {},
**kwargs) -> None:
def add(
self,
obs: Any,
act: Any,
rew: Union[Number, np.number, np.ndarray],
done: Union[Number, np.number, np.bool_],
obs_next: Any = None,
info: Optional[Union[dict, Batch]] = {},
policy: Optional[Union[dict, Batch]] = {},
**kwargs: Any,
) -> None:
"""Add a batch of data into replay buffer."""
assert isinstance(info, (dict, Batch)), \
'You should return a dict in the last argument of env.step().'
assert isinstance(
info, (dict, Batch)
), "You should return a dict in the last argument of env.step()."
if self._last_obs:
obs = obs[-1]
self._add_to_buffer('obs', obs)
self._add_to_buffer('act', act)
self._add_to_buffer('rew', rew)
self._add_to_buffer('done', done)
self._add_to_buffer("obs", obs)
self._add_to_buffer("act", act)
self._add_to_buffer("rew", rew)
self._add_to_buffer("done", done)
if self._save_s_:
if obs_next is None:
obs_next = Batch()
elif self._last_obs:
obs_next = obs_next[-1]
self._add_to_buffer('obs_next', obs_next)
self._add_to_buffer('info', info)
self._add_to_buffer('policy', policy)
self._add_to_buffer("obs_next", obs_next)
self._add_to_buffer("info", info)
self._add_to_buffer("policy", policy)
# maintain available index for frame-stack sampling
if self._avail:
@ -262,7 +270,8 @@ class ReplayBuffer:
self._avail_index = []
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with size equal to batch_size. \
"""Get a random sample from buffer with size equal to batch_size.
Return all the data in the buffer if batch_size is 0.
:return: Sample data and its corresponding index inside the buffer.
@ -278,11 +287,15 @@ class ReplayBuffer:
np.arange(self._index, self._size),
np.arange(0, self._index),
])
assert len(indice) > 0, 'No available indice can be sampled.'
assert len(indice) > 0, "No available indice can be sampled."
return self[indice], indice
def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]:
def get(
self,
indice: Union[slice, int, np.integer, np.ndarray],
key: str,
stack_num: Optional[int] = None,
) -> Union[Batch, np.ndarray]:
"""Return the stacked result.
E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the
@ -292,7 +305,7 @@ class ReplayBuffer:
if stack_num is None:
stack_num = self.stack_num
if stack_num == 1: # the most often case
if key != 'obs_next' or self._save_s_:
if key != "obs_next" or self._save_s_:
val = self._meta.__dict__[key]
try:
return val[indice]
@ -301,11 +314,11 @@ class ReplayBuffer:
raise e # val != Batch()
return Batch()
indice = self._indices[:self._size][indice]
done = self._meta.__dict__['done']
if key == 'obs_next' and not self._save_s_:
done = self._meta.__dict__["done"]
if key == "obs_next" and not self._save_s_:
indice += 1 - done[indice].astype(np.int)
indice[indice == self._size] = 0
key = 'obs'
key = "obs"
val = self._meta.__dict__[key]
try:
if stack_num == 1:
@ -319,30 +332,30 @@ class ReplayBuffer:
pre_indice + done[pre_indice].astype(np.int))
indice[indice == self._size] = 0
if isinstance(val, Batch):
stack = Batch.stack(stack, axis=indice.ndim)
return Batch.stack(stack, axis=indice.ndim)
else:
stack = np.stack(stack, axis=indice.ndim)
return stack
return np.stack(stack, axis=indice.ndim)
except IndexError as e:
if not (isinstance(val, Batch) and val.is_empty()):
raise e # val != Batch()
return Batch()
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
def __getitem__(
self, index: Union[slice, int, np.integer, np.ndarray]
) -> Batch:
"""Return a data batch: self[index].
If stack_num is larger than 1, return the stacked obs and obs_next
with shape (batch, len, ...).
If stack_num is larger than 1, return the stacked obs and obs_next with
shape (batch, len, ...).
"""
return Batch(
obs=self.get(index, 'obs'),
obs=self.get(index, "obs"),
act=self.act[index],
rew=self.rew[index],
done=self.done[index],
obs_next=self.get(index, 'obs_next'),
info=self.get(index, 'info'),
policy=self.get(index, 'policy'),
obs_next=self.get(index, "obs_next"),
info=self.get(index, "info"),
policy=self.get(index, "policy"),
)
@ -361,15 +374,15 @@ class ListReplayBuffer(ReplayBuffer):
explanation.
"""
def __init__(self, **kwargs) -> None:
def __init__(self, **kwargs: Any) -> None:
super().__init__(size=0, ignore_obs_next=False, **kwargs)
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
raise NotImplementedError("ListReplayBuffer cannot be sampled!")
def _add_to_buffer(
self, name: str,
inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool]
) -> None:
if self._meta.__dict__.get(name) is None:
self._meta.__dict__[name] = []
self._meta.__dict__[name].append(inst)
@ -393,25 +406,29 @@ class PrioritizedReplayBuffer(ReplayBuffer):
explanation.
"""
def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None:
def __init__(
self, size: int, alpha: float, beta: float, **kwargs: Any
) -> None:
super().__init__(size, **kwargs)
assert alpha > 0. and beta >= 0.
assert alpha > 0.0 and beta >= 0.0
self._alpha, self._beta = alpha, beta
self._max_prio = self._min_prio = 1.0
# save weight directly in this class instead of self._meta
self.weight = SegmentTree(size)
self.__eps = np.finfo(np.float32).eps.item()
def add(self,
obs: Union[dict, Batch, np.ndarray, float],
act: Union[dict, Batch, np.ndarray, float],
rew: Union[int, float],
done: Union[bool, int],
obs_next: Optional[Union[dict, Batch, np.ndarray, float]] = None,
info: Optional[Union[dict, Batch]] = {},
policy: Optional[Union[dict, Batch]] = {},
weight: Optional[float] = None,
**kwargs) -> None:
def add(
self,
obs: Any,
act: Any,
rew: Union[Number, np.number, np.ndarray],
done: Union[Number, np.number, np.bool_],
obs_next: Any = None,
info: Optional[Union[dict, Batch]] = {},
policy: Optional[Union[dict, Batch]] = {},
weight: Optional[Union[Number, np.number]] = None,
**kwargs: Any,
) -> None:
"""Add a batch of data into replay buffer."""
if weight is None:
weight = self._max_prio
@ -433,7 +450,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
to de-bias the sampling process (some transition tuples are sampled
more often so their losses are weighted less).
"""
assert self._size > 0, 'Cannot sample a buffer with 0 size!'
assert self._size > 0, "Cannot sample a buffer with 0 size!"
if batch_size == 0:
indice = np.concatenate([
np.arange(self._index, self._size),
@ -449,8 +466,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
batch.weight = (batch.weight / self._min_prio) ** (-self._beta)
return batch, indice
def update_weight(self, indice: Union[np.ndarray],
new_weight: Union[np.ndarray, torch.Tensor]) -> None:
def update_weight(
self,
indice: Union[np.ndarray],
new_weight: Union[np.ndarray, torch.Tensor]
) -> None:
"""Update priority weight by indice in this buffer.
:param np.ndarray indice: indice you want to update weight.
@ -461,15 +481,16 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._max_prio = max(self._max_prio, weight.max())
self._min_prio = min(self._min_prio, weight.min())
def __getitem__(self, index: Union[
slice, int, np.integer, np.ndarray]) -> Batch:
def __getitem__(
self, index: Union[slice, int, np.integer, np.ndarray]
) -> Batch:
return Batch(
obs=self.get(index, 'obs'),
obs=self.get(index, "obs"),
act=self.act[index],
rew=self.rew[index],
done=self.done[index],
obs_next=self.get(index, 'obs_next'),
info=self.get(index, 'info'),
policy=self.get(index, 'policy'),
obs_next=self.get(index, "obs_next"),
info=self.get(index, "info"),
policy=self.get(index, "policy"),
weight=self.weight[index],
)
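
The buffer changes above are likewise mostly signature and quoting updates. A minimal usage sketch of the `add`/`sample` interface shown in the diff (toy transition data):

```python
import numpy as np
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=16)
buf.add(obs=np.array([0.0]), act=1, rew=0.5, done=False,
        obs_next=np.array([1.0]), info={})
batch, indice = buf.sample(batch_size=1)  # returns (Batch, np.ndarray of sampled indices)
```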

View File

@ -4,13 +4,14 @@ import torch
import warnings
import numpy as np
from copy import deepcopy
from typing import Any, Dict, List, Union, Optional, Callable
from numbers import Number
from typing import Dict, List, Union, Optional, Callable
from tianshou.env import BaseVectorEnv, DummyVectorEnv
from tianshou.policy import BasePolicy
from tianshou.exploration import BaseNoise
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
from tianshou.data.batch import _create_value
from tianshou.env import BaseVectorEnv, DummyVectorEnv
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
class Collector(object):
@ -75,14 +76,15 @@ class Collector(object):
Please make sure the given environment has a time limitation.
"""
def __init__(self,
policy: BasePolicy,
env: Union[gym.Env, BaseVectorEnv],
buffer: Optional[ReplayBuffer] = None,
preprocess_fn: Callable[[Any], Batch] = None,
action_noise: Optional[BaseNoise] = None,
reward_metric: Optional[Callable[[np.ndarray], float]] = None,
) -> None:
def __init__(
self,
policy: BasePolicy,
env: Union[gym.Env, BaseVectorEnv],
buffer: Optional[ReplayBuffer] = None,
preprocess_fn: Optional[Callable[..., Batch]] = None,
action_noise: Optional[BaseNoise] = None,
reward_metric: Optional[Callable[[np.ndarray], float]] = None,
) -> None:
super().__init__()
if not isinstance(env, BaseVectorEnv):
env = DummyVectorEnv([lambda: env])
@ -108,12 +110,15 @@ class Collector(object):
self.reset()
@staticmethod
def _default_rew_metric(x):
def _default_rew_metric(
x: Union[Number, np.number]
) -> Union[Number, np.number]:
# this internal function is designed for single-agent RL
# for multi-agent RL, a reward_metric must be provided
assert np.asanyarray(x).size == 1, \
'Please specify the reward_metric ' \
'since the reward is not a scalar.'
assert np.asanyarray(x).size == 1, (
"Please specify the reward_metric "
"since the reward is not a scalar."
)
return x
def reset(self) -> None:
@ -124,7 +129,7 @@ class Collector(object):
obs_next={}, policy={})
self.reset_env()
self.reset_buffer()
self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
if self._action_noise is not None:
self._action_noise.reset()
@ -142,7 +147,7 @@ class Collector(object):
self._ready_env_ids = np.arange(self.env_num)
obs = self.env.reset()
if self.preprocess_fn:
obs = self.preprocess_fn(obs=obs).get('obs', obs)
obs = self.preprocess_fn(obs=obs).get("obs", obs)
self.data.obs = obs
for b in self._cached_buf:
b.reset()
@ -157,13 +162,14 @@ class Collector(object):
elif isinstance(state, Batch):
state.empty_(id)
def collect(self,
n_step: Optional[int] = None,
n_episode: Optional[Union[int, List[int]]] = None,
random: bool = False,
render: Optional[float] = None,
no_grad: bool = True,
) -> Dict[str, float]:
def collect(
self,
n_step: Optional[int] = None,
n_episode: Optional[Union[int, List[int]]] = None,
random: bool = False,
render: Optional[float] = None,
no_grad: bool = True,
) -> Dict[str, float]:
"""Collect a specified number of step or episode.
:param int n_step: how many steps you want to collect.
@ -217,8 +223,8 @@ class Collector(object):
while True:
if step_count >= 100000 and episode_count.sum() == 0:
warnings.warn(
'There are already many steps in an episode. '
'You should add a time limitation to your environment!',
"There are already many steps in an episode. "
"You should add a time limitation to your environment!",
Warning)
is_async = self.is_async or len(finished_env_ids) > 0
@ -250,11 +256,11 @@ class Collector(object):
else:
result = self.policy(self.data, last_state)
state = result.get('state', Batch())
state = result.get("state", Batch())
# convert None to Batch(), since None is reserved for 0-init
if state is None:
state = Batch()
self.data.update(state=state, policy=result.get('policy', Batch()))
self.data.update(state=state, policy=result.get("policy", Batch()))
# save hidden state to policy._state, in order to save into buffer
if not (isinstance(state, Batch) and state.is_empty()):
self.data.policy._state = self.data.state
@ -268,12 +274,12 @@ class Collector(object):
obs_next, rew, done, info = self.env.step(self.data.act)
else:
# store computed actions, states, etc
_batch_set_item(whole_data, self._ready_env_ids,
self.data, self.env_num)
_batch_set_item(
whole_data, self._ready_env_ids, self.data, self.env_num)
# fetch finished data
obs_next, rew, done, info = self.env.step(
self.data.act, id=self._ready_env_ids)
self._ready_env_ids = np.array([i['env_id'] for i in info])
self._ready_env_ids = np.array([i["env_id"] for i in info])
# get the stepped data
self.data = whole_data[self._ready_env_ids]
# move data to self.data
@ -319,15 +325,15 @@ class Collector(object):
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
obs_next[env_ind_local] = self.preprocess_fn(
obs=obs_reset).get('obs', obs_reset)
obs=obs_reset).get("obs", obs_reset)
else:
obs_next[env_ind_local] = obs_reset
self.data.obs = obs_next
if is_async:
# set data back
whole_data = deepcopy(whole_data) # avoid reference in ListBuf
_batch_set_item(whole_data, self._ready_env_ids,
self.data, self.env_num)
_batch_set_item(
whole_data, self._ready_env_ids, self.data, self.env_num)
# let self.data be the data in all environments again
self.data = whole_data
self._ready_env_ids = np.array(
@ -358,12 +364,12 @@ class Collector(object):
if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg
reward_avg = self._rew_metric(reward_avg)
return {
'n/ep': episode_count,
'n/st': step_count,
'v/st': step_count / duration,
'v/ep': episode_count / duration,
'rew': reward_avg,
'len': step_count / episode_count,
"n/ep": episode_count,
"n/st": step_count,
"v/st": step_count / duration,
"v/ep": episode_count / duration,
"rew": reward_avg,
"len": step_count / episode_count,
}
def sample(self, batch_size: int) -> Batch:
@ -377,9 +383,9 @@ class Collector(object):
batch_size.
"""
warnings.warn(
'Collector.sample is deprecated and will cause error if you use '
'prioritized experience replay! Collector.sample will be removed '
'upon version 0.3. Use policy.update instead!', Warning)
"Collector.sample is deprecated and will cause error if you use "
"prioritized experience replay! Collector.sample will be removed "
"upon version 0.3. Use policy.update instead!", Warning)
assert self.buffer is not None, "Cannot get sample from empty buffer!"
batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice)
@ -387,12 +393,13 @@ class Collector(object):
def close(self) -> None:
warnings.warn(
'Collector.close is deprecated and will be removed upon version '
'0.3.', Warning)
"Collector.close is deprecated and will be removed upon version "
"0.3.", Warning)
def _batch_set_item(source: Batch, indices: np.ndarray,
target: Batch, size: int):
def _batch_set_item(
source: Batch, indices: np.ndarray, target: Batch, size: int
) -> None:
# for any key chain k, there are four cases
# 1. source[k] is non-reserved, but target[k] does not exist or is reserved
# 2. source[k] does not exist or is reserved, but target[k] is non-reserved

View File

@ -1,72 +1,79 @@
import torch
import numpy as np
from copy import deepcopy
from numbers import Number
from typing import Union, Optional
from tianshou.data.batch import _parse_value, Batch
def to_numpy(x: Union[
Batch, dict, list, tuple, np.ndarray, torch.Tensor]) -> Union[
Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
def to_numpy(
x: Optional[Union[Batch, dict, list, tuple, np.number, np.bool_, Number,
np.ndarray, torch.Tensor]]
) -> Union[Batch, dict, list, tuple, np.ndarray]:
"""Return an object without torch.Tensor."""
if isinstance(x, torch.Tensor): # most often case
x = x.detach().cpu().numpy()
return x.detach().cpu().numpy()
elif isinstance(x, np.ndarray): # second often case
pass
return x
elif isinstance(x, (np.number, np.bool_, Number)):
x = np.asanyarray(x)
return np.asanyarray(x)
elif x is None:
x = np.array(None, dtype=np.object)
return np.array(None, dtype=np.object)
elif isinstance(x, Batch):
x = deepcopy(x)
x.to_numpy()
return x
elif isinstance(x, dict):
for k, v in x.items():
x[k] = to_numpy(v)
return {k: to_numpy(v) for k, v in x.items()}
elif isinstance(x, (list, tuple)):
try:
x = to_numpy(_parse_value(x))
return to_numpy(_parse_value(x))
except TypeError:
x = [to_numpy(e) for e in x]
return [to_numpy(e) for e in x]
else: # fallback
x = np.asanyarray(x)
return x
return np.asanyarray(x)
def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
dtype: Optional[torch.dtype] = None,
device: Union[str, int, torch.device] = 'cpu'
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
def to_torch(
x: Union[Batch, dict, list, tuple, np.number, np.bool_, Number, np.ndarray,
torch.Tensor],
dtype: Optional[torch.dtype] = None,
device: Union[str, int, torch.device] = "cpu",
) -> Union[Batch, dict, list, tuple, torch.Tensor]:
"""Return an object without np.ndarray."""
if isinstance(x, np.ndarray) and \
issubclass(x.dtype.type, (np.bool_, np.number)): # most often case
if isinstance(x, np.ndarray) and issubclass(
x.dtype.type, (np.bool_, np.number)
): # most often case
x = torch.from_numpy(x).to(device)
if dtype is not None:
x = x.type(dtype)
return x
elif isinstance(x, torch.Tensor): # second often case
if dtype is not None:
x = x.type(dtype)
x = x.to(device)
return x.to(device)
elif isinstance(x, (np.number, np.bool_, Number)):
x = to_torch(np.asanyarray(x), dtype, device)
return to_torch(np.asanyarray(x), dtype, device)
elif isinstance(x, dict):
for k, v in x.items():
x[k] = to_torch(v, dtype, device)
return {k: to_torch(v, dtype, device) for k, v in x.items()}
elif isinstance(x, Batch):
x = deepcopy(x)
x.to_torch(dtype, device)
return x
elif isinstance(x, (list, tuple)):
try:
x = to_torch(_parse_value(x), dtype, device)
return to_torch(_parse_value(x), dtype, device)
except TypeError:
x = [to_torch(e, dtype, device) for e in x]
return [to_torch(e, dtype, device) for e in x]
else: # fallback
raise TypeError(f"object {x} cannot be converted to torch.")
return x
def to_torch_as(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
y: torch.Tensor
) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
def to_torch_as(
x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
y: torch.Tensor,
) -> Union[Batch, dict, list, tuple, torch.Tensor]:
"""Return an object without np.ndarray.
Same as ``to_torch(x, dtype=y.dtype, device=y.device)``.

View File

@ -24,17 +24,20 @@ class SegmentTree:
self._size = size
self._bound = bound
self._value = np.zeros([bound * 2])
self._compile()
def __len__(self):
def __len__(self) -> int:
return self._size
def __getitem__(self, index: Union[int, np.ndarray]
) -> Union[float, np.ndarray]:
def __getitem__(
self, index: Union[int, np.ndarray]
) -> Union[float, np.ndarray]:
"""Return self[index]."""
return self._value[index + self._bound]
def __setitem__(self, index: Union[int, np.ndarray],
value: Union[float, np.ndarray]) -> None:
def __setitem__(
self, index: Union[int, np.ndarray], value: Union[float, np.ndarray]
) -> None:
"""Update values in segment tree.
Duplicate values in ``index`` are handled by numpy: later index
@ -62,7 +65,8 @@ class SegmentTree:
return _reduce(self._value, start + self._bound - 1, end + self._bound)
def get_prefix_sum_idx(
self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]:
self, value: Union[float, np.ndarray]
) -> Union[int, np.ndarray]:
r"""Find the index with given value.
Return the minimum index for each ``v`` in ``value`` so that
@ -74,7 +78,7 @@ class SegmentTree:
Please make sure all of the values inside the segment tree are
non-negative when using this function.
"""
assert np.all(value >= 0.) and np.all(value < self._value[1])
assert np.all(value >= 0.0) and np.all(value < self._value[1])
single = False
if not isinstance(value, np.ndarray):
value = np.array([value])
@ -82,6 +86,16 @@ class SegmentTree:
index = _get_prefix_sum_idx(value, self._bound, self._value)
return index.item() if single else index
def _compile(self) -> None:
f64 = np.array([0, 1], dtype=np.float64)
f32 = np.array([0, 1], dtype=np.float32)
i64 = np.array([0, 1], dtype=np.int64)
_setitem(f64, i64, f64)
_setitem(f64, i64, f32)
_reduce(f64, 0, 1)
_get_prefix_sum_idx(f64, 1, f64)
_get_prefix_sum_idx(f32, 1, f64)
@njit
def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None:
@ -96,7 +110,7 @@ def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None:
def _reduce(tree: np.ndarray, start: int, end: int) -> float:
"""Numba version, 2x faster: 0.009 -> 0.005."""
# nodes in (start, end) should be aggregated
result = 0.
result = 0.0
while end - start > 1: # (start, end) interval is not empty
if start % 2 == 0:
result += tree[start + 1]
@ -108,8 +122,9 @@ def _reduce(tree: np.ndarray, start: int, end: int) -> float:
@njit
def _get_prefix_sum_idx(value: np.ndarray, bound: int,
sums: np.ndarray) -> np.ndarray:
def _get_prefix_sum_idx(
value: np.ndarray, bound: int, sums: np.ndarray
) -> np.ndarray:
"""Numba version (v0.51), 5x speed up with size=100000 and bsz=64.
vectorized np: 0.0923 (numpy best) -> 0.024 (now)

View File

@ -3,11 +3,11 @@ from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, VectorEnv, \
from tianshou.env.maenv import MultiAgentEnv
__all__ = [
'BaseVectorEnv',
'DummyVectorEnv',
'VectorEnv', # TODO: remove in later version
'SubprocVectorEnv',
'ShmemVectorEnv',
'RayVectorEnv',
'MultiAgentEnv',
"BaseVectorEnv",
"DummyVectorEnv",
"VectorEnv", # TODO: remove in later version
"SubprocVectorEnv",
"ShmemVectorEnv",
"RayVectorEnv",
"MultiAgentEnv",
]

tianshou/env/maenv.py
View File

@ -1,6 +1,6 @@
import gym
import numpy as np
from typing import Tuple
from typing import Any, Dict, Tuple
from abc import ABC, abstractmethod
@ -22,7 +22,7 @@ class MultiAgentEnv(ABC, gym.Env):
usage can be found at :ref:`marl_example`.
"""
def __init__(self, **kwargs) -> None:
def __init__(self) -> None:
pass
@abstractmethod
@ -30,13 +30,14 @@ class MultiAgentEnv(ABC, gym.Env):
"""Reset the state.
Return the initial state, first agent_id, and the initial action set,
for example, ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}``
for example, ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}``.
"""
pass
@abstractmethod
def step(self, action: np.ndarray
) -> Tuple[dict, np.ndarray, np.ndarray, np.ndarray]:
def step(
self, action: np.ndarray
) -> Tuple[Dict[str, Any], np.ndarray, np.ndarray, np.ndarray]:
"""Run one timestep of the environments dynamics.
When the end of episode is reached, you are responsible for calling

View File

@ -1,14 +1,15 @@
import cloudpickle
from typing import Any
class CloudpickleWrapper(object):
"""A cloudpickle wrapper used in SubprocVectorEnv."""
def __init__(self, data):
def __init__(self, data: Any) -> None:
self.data = data
def __getstate__(self):
def __getstate__(self) -> str:
return cloudpickle.dumps(self.data)
def __setstate__(self, data):
def __setstate__(self, data: str) -> None:
self.data = cloudpickle.loads(data)

tianshou/env/venvs.py
View File

@ -1,7 +1,7 @@
import gym
import warnings
import numpy as np
from typing import List, Union, Optional, Callable, Any
from typing import Any, List, Union, Optional, Callable
from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \
RayEnvWorker
@ -59,12 +59,13 @@ class BaseVectorEnv(gym.Env):
within ``timeout`` seconds.
"""
def __init__(self,
env_fns: List[Callable[[], gym.Env]],
worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
def __init__(
self,
env_fns: List[Callable[[], gym.Env]],
worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
self._env_fns = env_fns
# A VectorEnv contains a pool of EnvWorkers, which corresponds to
# interact with the given envs (one worker <-> one env).
@ -75,11 +76,13 @@ class BaseVectorEnv(gym.Env):
self.env_num = len(env_fns)
self.wait_num = wait_num or len(env_fns)
assert 1 <= self.wait_num <= len(env_fns), \
f'wait_num should be in [1, {len(env_fns)}], but got {wait_num}'
assert (
1 <= self.wait_num <= len(env_fns)
), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
self.timeout = timeout
assert self.timeout is None or self.timeout > 0, \
f'timeout is {timeout}, it should be positive if provided!'
assert (
self.timeout is None or self.timeout > 0
), f"timeout is {timeout}, it should be positive if provided!"
self.is_async = self.wait_num != len(env_fns) or timeout is not None
self.waiting_conn = []
# environments in self.ready_id is actually ready
@ -92,8 +95,9 @@ class BaseVectorEnv(gym.Env):
self.is_closed = False
def _assert_is_not_closed(self) -> None:
assert not self.is_closed, f"Methods of {self.__class__.__name__} "\
"should not be called after close."
assert not self.is_closed, (
f"Methods of {self.__class__.__name__} cannot be called after "
"close.")
def __len__(self) -> int:
"""Return len(self), which is the number of environments."""
@ -113,7 +117,7 @@ class BaseVectorEnv(gym.Env):
else:
return super().__getattribute__(key)
def __getattr__(self, key: str) -> Any:
def __getattr__(self, key: str) -> List[Any]:
"""Fetch a list of env attributes.
This function tries to retrieve an attribute from each individual
@ -122,8 +126,9 @@ class BaseVectorEnv(gym.Env):
"""
return [getattr(worker, key) for worker in self.workers]
def _wrap_id(self, id: Optional[Union[int, List[int], np.ndarray]] = None
) -> List[int]:
def _wrap_id(
self, id: Optional[Union[int, List[int], np.ndarray]] = None
) -> Union[List[int], np.ndarray]:
if id is None:
id = list(range(self.env_num))
elif np.isscalar(id):
@ -132,13 +137,16 @@ class BaseVectorEnv(gym.Env):
def _assert_id(self, id: List[int]) -> None:
for i in id:
assert i not in self.waiting_id, \
f'Cannot interact with environment {i} which is stepping now.'
assert i in self.ready_id, \
f'Can only interact with ready environments {self.ready_id}.'
assert (
i not in self.waiting_id
), f"Cannot interact with environment {i} which is stepping now."
assert (
i in self.ready_id
), f"Can only interact with ready environments {self.ready_id}."
def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None
) -> np.ndarray:
def reset(
self, id: Optional[Union[int, List[int], np.ndarray]] = None
) -> np.ndarray:
"""Reset the state of some envs and return initial observations.
If id is None, reset the state of all the environments and return
@ -152,10 +160,11 @@ class BaseVectorEnv(gym.Env):
obs = np.stack([self.workers[i].reset() for i in id])
return obs
def step(self,
action: np.ndarray,
id: Optional[Union[int, List[int], np.ndarray]] = None
) -> List[np.ndarray]:
def step(
self,
action: np.ndarray,
id: Optional[Union[int, List[int], np.ndarray]] = None
) -> List[np.ndarray]:
"""Run one timestep of some environments' dynamics.
If id is None, run one timestep of all the environments dynamics;
@ -221,8 +230,9 @@ class BaseVectorEnv(gym.Env):
self.ready_id.append(env_id)
return list(map(np.stack, zip(*result)))
def seed(self,
seed: Optional[Union[int, List[int]]] = None) -> List[List[int]]:
def seed(
self, seed: Optional[Union[int, List[int]]] = None
) -> List[Optional[List[int]]]:
"""Set the seed for all environments.
Accept ``None``, an int (which will extend ``i`` to
@ -239,13 +249,13 @@ class BaseVectorEnv(gym.Env):
seed = [seed + i for i in range(self.env_num)]
return [w.seed(s) for w, s in zip(self.workers, seed)]
def render(self, **kwargs) -> List[Any]:
def render(self, **kwargs: Any) -> List[Any]:
"""Render all of the environments."""
self._assert_is_not_closed()
if self.is_async and len(self.waiting_id) > 0:
raise RuntimeError(
f"Environments {self.waiting_id} are still "
f"stepping, cannot render them now.")
f"Environments {self.waiting_id} are still stepping, cannot "
"render them now.")
return [w.render(**kwargs) for w in self.workers]
def close(self) -> None:
@ -275,20 +285,23 @@ class DummyVectorEnv(BaseVectorEnv):
explanation.
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None) -> None:
super().__init__(env_fns, DummyEnvWorker,
wait_num=wait_num, timeout=timeout)
def __init__(
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
super().__init__(
env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
class VectorEnv(DummyVectorEnv):
"""VectorEnv is renamed to DummyVectorEnv."""
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
'VectorEnv is renamed to DummyVectorEnv, and will be removed in '
'0.3. Use DummyVectorEnv instead!', Warning)
"VectorEnv is renamed to DummyVectorEnv, and will be removed in "
"0.3. Use DummyVectorEnv instead!", Warning)
super().__init__(*args, **kwargs)
@ -301,13 +314,17 @@ class SubprocVectorEnv(BaseVectorEnv):
explanation.
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None) -> None:
def worker_fn(fn):
def __init__(
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=False)
super().__init__(env_fns, worker_fn,
wait_num=wait_num, timeout=timeout)
super().__init__(
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
class ShmemVectorEnv(BaseVectorEnv):
@ -321,13 +338,17 @@ class ShmemVectorEnv(BaseVectorEnv):
detailed explanation.
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None) -> None:
def worker_fn(fn):
def __init__(
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
return SubprocEnvWorker(fn, share_memory=True)
super().__init__(env_fns, worker_fn,
wait_num=wait_num, timeout=timeout)
super().__init__(
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
class RayVectorEnv(BaseVectorEnv):
@ -341,16 +362,19 @@ class RayVectorEnv(BaseVectorEnv):
explanation.
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None) -> None:
def __init__(
self,
env_fns: List[Callable[[], gym.Env]],
wait_num: Optional[int] = None,
timeout: Optional[float] = None,
) -> None:
try:
import ray
except ImportError as e:
raise ImportError(
'Please install ray to support RayVectorEnv: pip install ray'
"Please install ray to support RayVectorEnv: pip install ray"
) from e
if not ray.is_initialized():
ray.init()
super().__init__(env_fns, RayEnvWorker,
wait_num=wait_num, timeout=timeout)
super().__init__(
env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)

View File

@ -4,8 +4,8 @@ from tianshou.env.worker.subproc import SubprocEnvWorker
from tianshou.env.worker.ray import RayEnvWorker
__all__ = [
'EnvWorker',
'DummyEnvWorker',
'SubprocEnvWorker',
'RayEnvWorker',
"EnvWorker",
"DummyEnvWorker",
"SubprocEnvWorker",
"RayEnvWorker",
]

View File

@ -1,7 +1,7 @@
import gym
import numpy as np
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, Callable, Any
from typing import Any, List, Tuple, Optional, Callable
class EnvWorker(ABC):
@ -24,12 +24,14 @@ class EnvWorker(ABC):
def send_action(self, action: np.ndarray) -> None:
pass
def get_result(self) -> Tuple[
np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
def get_result(
self,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
return self.result
def step(self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
def step(
self, action: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Perform one timestep of the environment's dynamic.
"send_action" and "get_result" are coupled in sync simulation, so
@ -41,19 +43,21 @@ class EnvWorker(ABC):
return self.get_result()
@staticmethod
def wait(workers: List['EnvWorker'],
wait_num: int,
timeout: Optional[float] = None) -> List['EnvWorker']:
def wait(
workers: List["EnvWorker"],
wait_num: int,
timeout: Optional[float] = None,
) -> List["EnvWorker"]:
"""Given a list of workers, return those ready ones."""
raise NotImplementedError
@abstractmethod
def seed(self, seed: Optional[int] = None) -> List[int]:
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
pass
@abstractmethod
def render(self, **kwargs) -> Any:
"""Renders the environment."""
def render(self, **kwargs: Any) -> Any:
"""Render the environment."""
pass
@abstractmethod

View File

@ -1,6 +1,6 @@
import gym
import numpy as np
from typing import List, Callable, Optional, Any
from typing import Any, List, Callable, Optional
from tianshou.env.worker import EnvWorker
@ -19,21 +19,24 @@ class DummyEnvWorker(EnvWorker):
return self.env.reset()
@staticmethod
def wait(workers: List['DummyEnvWorker'],
wait_num: int,
timeout: Optional[float] = None) -> List['DummyEnvWorker']:
def wait(
workers: List["DummyEnvWorker"],
wait_num: int,
timeout: Optional[float] = None,
) -> List["DummyEnvWorker"]:
# Sequential EnvWorker objects are always ready
return workers
def send_action(self, action: np.ndarray) -> None:
self.result = self.env.step(action)
def seed(self, seed: Optional[int] = None) -> List[int]:
return self.env.seed(seed) if hasattr(self.env, 'seed') else None
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
return self.env.seed(seed) if hasattr(self.env, "seed") else None
def render(self, **kwargs) -> Any:
return self.env.render(**kwargs) \
if hasattr(self.env, 'render') else None
def render(self, **kwargs: Any) -> Any:
return (
self.env.render(**kwargs) if hasattr(self.env, "render") else None
)
def close_env(self) -> None:
self.env.close()

View File

@ -1,6 +1,6 @@
import gym
import numpy as np
from typing import List, Callable, Tuple, Optional, Any
from typing import Any, List, Callable, Tuple, Optional
from tianshou.env.worker import EnvWorker
@ -24,31 +24,34 @@ class RayEnvWorker(EnvWorker):
return ray.get(self.env.reset.remote())
@staticmethod
def wait(workers: List['RayEnvWorker'],
wait_num: int,
timeout: Optional[float] = None) -> List['RayEnvWorker']:
def wait(
workers: List["RayEnvWorker"],
wait_num: int,
timeout: Optional[float] = None,
) -> List["RayEnvWorker"]:
results = [x.result for x in workers]
ready_results, _ = ray.wait(results,
num_returns=wait_num, timeout=timeout)
ready_results, _ = ray.wait(
results, num_returns=wait_num, timeout=timeout
)
return [workers[results.index(result)] for result in ready_results]
def send_action(self, action: np.ndarray) -> None:
# self.action is actually a handle
self.result = self.env.step.remote(action)
def get_result(self) -> Tuple[
np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
def get_result(
self,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
return ray.get(self.result)
def seed(self, seed: Optional[int] = None) -> List[int]:
if hasattr(self.env, 'seed'):
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
if hasattr(self.env, "seed"):
return ray.get(self.env.seed.remote(seed))
return None
def render(self, **kwargs) -> Any:
if hasattr(self.env, 'render'):
def render(self, **kwargs: Any) -> Any:
if hasattr(self.env, "render"):
return ray.get(self.env.render.remote(**kwargs))
return None
def close_env(self) -> None:
ray.get(self.env.close.remote())

View File

@ -5,14 +5,22 @@ import numpy as np
from collections import OrderedDict
from multiprocessing.context import Process
from multiprocessing import Array, Pipe, connection
from typing import Callable, Any, List, Tuple, Optional
from typing import Any, List, Tuple, Union, Callable, Optional
from tianshou.env.worker import EnvWorker
from tianshou.env.utils import CloudpickleWrapper
def _worker(parent, p, env_fn_wrapper, obs_bufs=None):
def _encode_obs(obs, buffer):
def _worker(
parent: connection.Connection,
p: connection.Connection,
env_fn_wrapper: CloudpickleWrapper,
obs_bufs: Optional[Union[dict, tuple, "ShArray"]] = None,
) -> None:
def _encode_obs(
obs: Union[dict, tuple, np.ndarray],
buffer: Union[dict, tuple, ShArray],
) -> None:
if isinstance(obs, np.ndarray):
buffer.save(obs)
elif isinstance(obs, tuple):
@ -32,25 +40,27 @@ def _worker(parent, p, env_fn_wrapper, obs_bufs=None):
except EOFError: # the pipe has been closed
p.close()
break
if cmd == 'step':
if cmd == "step":
obs, reward, done, info = env.step(data)
if obs_bufs is not None:
obs = _encode_obs(obs, obs_bufs)
_encode_obs(obs, obs_bufs)
obs = None
p.send((obs, reward, done, info))
elif cmd == 'reset':
elif cmd == "reset":
obs = env.reset()
if obs_bufs is not None:
obs = _encode_obs(obs, obs_bufs)
_encode_obs(obs, obs_bufs)
obs = None
p.send(obs)
elif cmd == 'close':
elif cmd == "close":
p.send(env.close())
p.close()
break
elif cmd == 'render':
p.send(env.render(**data) if hasattr(env, 'render') else None)
elif cmd == 'seed':
p.send(env.seed(data) if hasattr(env, 'seed') else None)
elif cmd == 'getattr':
elif cmd == "render":
p.send(env.render(**data) if hasattr(env, "render") else None)
elif cmd == "seed":
p.send(env.seed(data) if hasattr(env, "seed") else None)
elif cmd == "getattr":
p.send(getattr(env, data) if hasattr(env, data) else None)
else:
p.close()
@ -78,39 +88,39 @@ _NP_TO_CT = {
class ShArray:
"""Wrapper of multiprocessing Array."""
def __init__(self, dtype, shape):
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
self.dtype = dtype
self.shape = shape
def save(self, ndarray):
def save(self, ndarray: np.ndarray) -> None:
assert isinstance(ndarray, np.ndarray)
dst = self.arr.get_obj()
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
np.copyto(dst_np, ndarray)
def get(self):
return np.frombuffer(self.arr.get_obj(),
dtype=self.dtype).reshape(self.shape)
def get(self) -> np.ndarray:
obj = self.arr.get_obj()
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
def _setup_buf(space):
def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
if isinstance(space, gym.spaces.Dict):
assert isinstance(space.spaces, OrderedDict)
buffer = {k: _setup_buf(v) for k, v in space.spaces.items()}
return {k: _setup_buf(v) for k, v in space.spaces.items()}
elif isinstance(space, gym.spaces.Tuple):
assert isinstance(space.spaces, tuple)
buffer = tuple([_setup_buf(t) for t in space.spaces])
return tuple([_setup_buf(t) for t in space.spaces])
else:
buffer = ShArray(space.dtype, space.shape)
return buffer
return ShArray(space.dtype, space.shape)
class SubprocEnvWorker(EnvWorker):
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
def __init__(self, env_fn: Callable[[], gym.Env],
share_memory=False) -> None:
def __init__(
self, env_fn: Callable[[], gym.Env], share_memory: bool = False
) -> None:
super().__init__(env_fn)
self.parent_remote, self.child_remote = Pipe()
self.share_memory = share_memory
@ -121,18 +131,24 @@ class SubprocEnvWorker(EnvWorker):
dummy.close()
del dummy
self.buffer = _setup_buf(obs_space)
args = (self.parent_remote, self.child_remote,
CloudpickleWrapper(env_fn), self.buffer)
args = (
self.parent_remote,
self.child_remote,
CloudpickleWrapper(env_fn),
self.buffer,
)
self.process = Process(target=_worker, args=args, daemon=True)
self.process.start()
self.child_remote.close()
def __getattr__(self, key: str) -> Any:
self.parent_remote.send(['getattr', key])
self.parent_remote.send(["getattr", key])
return self.parent_remote.recv()
def _decode_obs(self, isNone):
def decode_obs(buffer):
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
def decode_obs(
buffer: Optional[Union[dict, tuple, ShArray]]
) -> Union[dict, tuple, np.ndarray]:
if isinstance(buffer, ShArray):
return buffer.get()
elif isinstance(buffer, tuple):
@ -145,16 +161,18 @@ class SubprocEnvWorker(EnvWorker):
return decode_obs(self.buffer)
def reset(self) -> Any:
self.parent_remote.send(['reset', None])
self.parent_remote.send(["reset", None])
obs = self.parent_remote.recv()
if self.share_memory:
obs = self._decode_obs(obs)
obs = self._decode_obs()
return obs
@staticmethod
def wait(workers: List['SubprocEnvWorker'],
wait_num: int,
timeout: Optional[float] = None) -> List['SubprocEnvWorker']:
def wait(
workers: List["SubprocEnvWorker"],
wait_num: int,
timeout: Optional[float] = None,
) -> List["SubprocEnvWorker"]:
conns, ready_conns = [x.parent_remote for x in workers], []
remain_conns = conns
t1 = time.time()
@ -169,31 +187,32 @@ class SubprocEnvWorker(EnvWorker):
new_ready_conns = connection.wait(
remain_conns, timeout=remain_time)
ready_conns.extend(new_ready_conns)
remain_conns = [conn for conn in remain_conns
if conn not in ready_conns]
remain_conns = [
conn for conn in remain_conns if conn not in ready_conns]
return [workers[conns.index(con)] for con in ready_conns]
def send_action(self, action: np.ndarray) -> None:
self.parent_remote.send(['step', action])
self.parent_remote.send(["step", action])
def get_result(self) -> Tuple[
np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
def get_result(
self,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
obs, rew, done, info = self.parent_remote.recv()
if self.share_memory:
obs = self._decode_obs(obs)
obs = self._decode_obs()
return obs, rew, done, info
def seed(self, seed: Optional[int] = None) -> List[int]:
self.parent_remote.send(['seed', seed])
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
self.parent_remote.send(["seed", seed])
return self.parent_remote.recv()
def render(self, **kwargs) -> Any:
self.parent_remote.send(['render', kwargs])
def render(self, **kwargs: Any) -> Any:
self.parent_remote.send(["render", kwargs])
return self.parent_remote.recv()
def close_env(self) -> None:
try:
self.parent_remote.send(['close', None])
self.parent_remote.send(["close", None])
# mp may be deleted so it may raise AttributeError
self.parent_remote.recv()
self.process.join()
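The shared-memory path above round-trips observations through ShArray: the child process copies the observation in with save() and sends None over the pipe, while the parent rebuilds it via _decode_obs()/get(). A tiny round-trip sketch, assuming ShArray is importable from tianshou.env.worker.subproc and using example dtype/shape values:

import numpy as np
from tianshou.env.worker.subproc import ShArray

buf = ShArray(np.dtype(np.float32), (2, 3))  # backed by a multiprocessing Array
buf.save(np.ones((2, 3), dtype=np.float32))  # child side: copy into shared memory
assert (buf.get() == 1.0).all()              # parent side: view it as an ndarray again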

View File

@ -1,7 +1,7 @@
from tianshou.exploration.random import BaseNoise, GaussianNoise, OUNoise
__all__ = [
'BaseNoise',
'GaussianNoise',
'OUNoise',
"BaseNoise",
"GaussianNoise",
"OUNoise",
]

View File

@ -1,16 +1,16 @@
import numpy as np
from typing import Union, Optional
from abc import ABC, abstractmethod
from typing import Union, Optional, Sequence
class BaseNoise(ABC, object):
"""The action noise base class."""
def __init__(self, **kwargs) -> None:
def __init__(self) -> None:
super().__init__()
@abstractmethod
def __call__(self, **kwargs) -> np.ndarray:
def __call__(self, size: Sequence[int]) -> np.ndarray:
"""Generate new noise."""
raise NotImplementedError
@ -22,15 +22,13 @@ class BaseNoise(ABC, object):
class GaussianNoise(BaseNoise):
"""The vanilla gaussian process, for exploration in DDPG by default."""
def __init__(self,
mu: float = 0.0,
sigma: float = 1.0):
def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None:
super().__init__()
self._mu = mu
assert 0 <= sigma, 'noise std should not be negative'
assert 0 <= sigma, "Noise std should not be negative."
self._sigma = sigma
def __call__(self, size: tuple) -> np.ndarray:
def __call__(self, size: Sequence[int]) -> np.ndarray:
return np.random.normal(self._mu, self._sigma, size)
@ -51,27 +49,30 @@ class OUNoise(BaseNoise):
Ornstein-Uhlenbeck process.
"""
def __init__(self,
mu: float = 0.0,
sigma: float = 0.3,
theta: float = 0.15,
dt: float = 1e-2,
x0: Optional[Union[float, np.ndarray]] = None
) -> None:
super(BaseNoise, self).__init__()
def __init__(
self,
mu: float = 0.0,
sigma: float = 0.3,
theta: float = 0.15,
dt: float = 1e-2,
x0: Optional[Union[float, np.ndarray]] = None,
) -> None:
super().__init__()
self._mu = mu
self._alpha = theta * dt
self._beta = sigma * np.sqrt(dt)
self._x0 = x0
self.reset()
def __call__(self, size: tuple, mu: Optional[float] = None) -> np.ndarray:
def __call__(
self, size: Sequence[int], mu: Optional[float] = None
) -> np.ndarray:
"""Generate new noise.
Return a ``numpy.ndarray`` which size is equal to ``size``.
Return a numpy array whose size is equal to ``size``.
"""
if self._x is None or self._x.shape != size:
self._x = 0
self._x = 0.0
if mu is None:
mu = self._mu
r = self._beta * np.random.normal(size=size)
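For reference, _alpha = theta * dt and _beta = sigma * sqrt(dt) above implement the usual Euler discretization of the Ornstein-Uhlenbeck process (the remaining in-place update of self._x lies outside the shown hunk):

x_{t+1} = x_t + \theta (\mu - x_t)\,\Delta t + \sigma \sqrt{\Delta t}\,\varepsilon_t,
\qquad \varepsilon_t \sim \mathcal{N}(0, 1)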

View File

@ -12,15 +12,15 @@ from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
__all__ = [
'BasePolicy',
'RandomPolicy',
'ImitationPolicy',
'DQNPolicy',
'PGPolicy',
'A2CPolicy',
'DDPGPolicy',
'PPOPolicy',
'TD3Policy',
'SACPolicy',
'MultiAgentPolicyManager',
"BasePolicy",
"RandomPolicy",
"ImitationPolicy",
"DQNPolicy",
"PGPolicy",
"A2CPolicy",
"DDPGPolicy",
"PPOPolicy",
"TD3Policy",
"SACPolicy",
"MultiAgentPolicyManager",
]

View File

@ -4,7 +4,7 @@ import numpy as np
from torch import nn
from numba import njit
from abc import ABC, abstractmethod
from typing import Dict, List, Union, Optional, Callable
from typing import Any, List, Union, Mapping, Optional, Callable
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
to_torch_as, to_numpy
@ -52,23 +52,28 @@ class BasePolicy(ABC, nn.Module):
policy.load_state_dict(torch.load("policy.pth"))
"""
def __init__(self,
observation_space: gym.Space = None,
action_space: gym.Space = None
) -> None:
def __init__(
self,
observation_space: gym.Space = None,
action_space: gym.Space = None
) -> None:
super().__init__()
self.observation_space = observation_space
self.action_space = action_space
self.agent_id = 0
self._compile()
def set_agent_id(self, agent_id: int) -> None:
"""Set self.agent_id = agent_id, for MARL."""
self.agent_id = agent_id
@abstractmethod
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which MUST have the following\
@ -96,8 +101,9 @@ class BasePolicy(ABC, nn.Module):
"""
pass
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
"""Pre-process the data from the provided replay buffer.
Used in :meth:`update`. Check out :ref:`process_fn` for more
@ -106,8 +112,9 @@ class BasePolicy(ABC, nn.Module):
return batch
@abstractmethod
def learn(self, batch: Batch, **kwargs
) -> Dict[str, Union[float, List[float]]]:
def learn(
self, batch: Batch, **kwargs: Any
) -> Mapping[str, Union[float, List[float]]]:
"""Update policy with a given batch of data.
:return: A dict which includes loss and its corresponding label.
@ -123,19 +130,22 @@ class BasePolicy(ABC, nn.Module):
"""
pass
def post_process_fn(self, batch: Batch,
buffer: ReplayBuffer, indice: np.ndarray) -> None:
def post_process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> None:
"""Post-process the data from the provided replay buffer.
Typical usage is to update the sampling weight in prioritized
experience replay. Used in :meth:`update`.
"""
if isinstance(buffer, PrioritizedReplayBuffer) \
and hasattr(batch, 'weight'):
if isinstance(buffer, PrioritizedReplayBuffer) and hasattr(
batch, "weight"
):
buffer.update_weight(indice, batch.weight)
def update(self, sample_size: int, buffer: Optional[ReplayBuffer],
*args, **kwargs) -> Dict[str, Union[float, List[float]]]:
def update(
self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any
) -> Mapping[str, Union[float, List[float]]]:
"""Update the policy network and replay buffer.
It includes 3 function steps: process_fn, learn, and post_process_fn.
@ -148,7 +158,7 @@ class BasePolicy(ABC, nn.Module):
return {}
batch, indice = buffer.sample(sample_size)
batch = self.process_fn(batch, buffer, indice)
result = self.learn(batch, *args, **kwargs)
result = self.learn(batch, **kwargs)
self.post_process_fn(batch, buffer, indice)
return result
@ -182,7 +192,7 @@ class BasePolicy(ABC, nn.Module):
rew = batch.rew
v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_).flatten()
returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda)
if rew_norm and not np.isclose(returns.std(), 0, 1e-2):
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
returns = (returns - returns.mean()) / returns.std()
batch.returns = returns
return batch
@ -231,9 +241,9 @@ class BasePolicy(ABC, nn.Module):
bfr = rew[:min(len(buffer), 1000)] # avoid large buffer
mean, std = bfr.mean(), bfr.std()
if np.isclose(std, 0, 1e-2):
mean, std = 0., 1.
mean, std = 0.0, 1.0
else:
mean, std = 0., 1.
mean, std = 0.0, 1.0
buf_len = len(buffer)
terminal = (indice + n_step - 1) % buf_len
target_q_torch = target_q_fn(buffer, terminal).flatten() # (bsz, )
@ -248,18 +258,30 @@ class BasePolicy(ABC, nn.Module):
batch.weight = to_torch_as(batch.weight, target_q_torch)
return batch
def _compile(self) -> None:
f64 = np.array([0, 1], dtype=np.float64)
f32 = np.array([0, 1], dtype=np.float32)
b = np.array([False, True], dtype=np.bool_)
i64 = np.array([0, 1], dtype=np.int64)
_episodic_return(f64, f64, b, 0.1, 0.1)
_episodic_return(f32, f64, b, 0.1, 0.1)
_nstep_return(f64, b, f32, i64, 0.1, 1, 4, 1.0, 0.0)
@njit
def _episodic_return(
v_s_: np.ndarray, rew: np.ndarray, done: np.ndarray,
gamma: float, gae_lambda: float,
v_s_: np.ndarray,
rew: np.ndarray,
done: np.ndarray,
gamma: float,
gae_lambda: float,
) -> np.ndarray:
"""Numba speedup: 4.1s -> 0.057s."""
returns = np.roll(v_s_, 1)
m = (1. - done) * gamma
m = (1.0 - done) * gamma
delta = rew + v_s_ * m - returns
m *= gae_lambda
gae = 0.
gae = 0.0
for i in range(len(rew) - 1, -1, -1):
gae = delta[i] + m[i] * gae
returns[i] += gae
@ -268,9 +290,15 @@ def _episodic_return(
@njit
def _nstep_return(
rew: np.ndarray, done: np.ndarray, target_q: np.ndarray,
indice: np.ndarray, gamma: float, n_step: int, buf_len: int,
mean: float, std: float
rew: np.ndarray,
done: np.ndarray,
target_q: np.ndarray,
indice: np.ndarray,
gamma: float,
n_step: int,
buf_len: int,
mean: float,
std: float,
) -> np.ndarray:
"""Numba speedup: 0.3s -> 0.15s."""
returns = np.zeros(indice.shape)
@ -278,8 +306,8 @@ def _nstep_return(
for n in range(n_step - 1, -1, -1):
now = (indice + n) % buf_len
gammas[done[now] > 0] = n
returns[done[now] > 0] = 0.
returns[done[now] > 0] = 0.0
returns = (rew[now] - mean) / std + gamma * returns
target_q[gammas != n_step] = 0
target_q[gammas != n_step] = 0.0
target_q = target_q * (gamma ** gammas) + returns
return target_q
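For reference, the two numba kernels above compute the standard GAE-style return and the reward-normalized n-step target:

\delta_t = r_t + \gamma (1 - d_t) V(s_{t+1}) - V(s_t), \qquad
\hat{A}_t = \delta_t + \gamma \lambda (1 - d_t) \hat{A}_{t+1}, \qquad
\mathrm{returns}_t = V(s_t) + \hat{A}_t

G_t^{(n)} = \sum_{i=0}^{n-1} \gamma^{i} \tilde{r}_{t+i} + \gamma^{n} Q_{\mathrm{target}}(s_{t+n}),
\qquad \tilde{r} = (r - \mathrm{mean}) / \mathrm{std}

with the sum truncated and the bootstrap term zeroed once a done flag is hit.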

View File

@ -1,7 +1,7 @@
import torch
import numpy as np
import torch.nn.functional as F
from typing import Dict, Union, Optional
from typing import Any, Dict, Union, Optional
from tianshou.data import Batch, to_torch
from tianshou.policy import BasePolicy
@ -22,36 +22,44 @@ class ImitationPolicy(BasePolicy):
explanation.
"""
def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer,
mode: str = 'continuous') -> None:
super().__init__()
def __init__(
self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
mode: str = "continuous",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.model = model
self.optim = optim
assert mode in ['continuous', 'discrete'], \
f'Mode {mode} is not in ["continuous", "discrete"]'
assert (
mode in ["continuous", "discrete"]
), f"Mode {mode} is not in ['continuous', 'discrete']."
self.mode = mode
def forward(self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> Batch:
logits, h = self.model(batch.obs, state=state, info=batch.info)
if self.mode == 'discrete':
if self.mode == "discrete":
a = logits.max(dim=1)[1]
else:
a = logits
return Batch(logits=logits, act=a, state=h)
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
self.optim.zero_grad()
if self.mode == 'continuous':
if self.mode == "continuous": # regression
a = self(batch).act
a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
loss = F.mse_loss(a, a_)
elif self.mode == 'discrete': # classification
elif self.mode == "discrete": # classification
a = self(batch).logits
a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
loss = F.nll_loss(a, a_)
loss.backward()
self.optim.step()
return {'loss': loss.item()}
return {"loss": loss.item()}
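Note that in discrete mode the loss above is F.nll_loss, which expects log-probabilities rather than raw scores, so the imitation model is assumed to end with a log-softmax. A minimal sketch of such a model head (class name and layer sizes are illustrative):

import torch
from torch import nn

class DiscreteImitationNet(nn.Module):
    def __init__(self, state_dim: int, action_num: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(),
            nn.Linear(128, action_num),
        )

    def forward(self, obs, state=None, info=None):
        obs = torch.as_tensor(obs, dtype=torch.float32)
        # return log-probabilities, as F.nll_loss in ImitationPolicy.learn expects
        return torch.log_softmax(self.net(obs), dim=-1), state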

View File

@ -2,7 +2,7 @@ import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Dict, List, Union, Optional
from typing import Any, Dict, List, Union, Optional, Callable
from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@ -17,6 +17,7 @@ class A2CPolicy(PGPolicy):
:param torch.optim.Optimizer optim: the optimizer for actor and critic
network.
:param dist_fn: distribution class for computing the action.
:type dist_fn: Callable[[], torch.distributions.Distribution]
:param float discount_factor: in [0, 1], defaults to 0.99.
:param float vf_coef: weight for value loss, defaults to 0.5.
:param float ent_coef: weight for entropy loss, defaults to 0.01.
@ -37,23 +38,25 @@ class A2CPolicy(PGPolicy):
explanation.
"""
def __init__(self,
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: torch.distributions.Distribution,
discount_factor: float = 0.99,
vf_coef: float = .5,
ent_coef: float = .01,
max_grad_norm: Optional[float] = None,
gae_lambda: float = 0.95,
reward_normalization: bool = False,
max_batchsize: int = 256,
**kwargs) -> None:
def __init__(
self,
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Callable[[], torch.distributions.Distribution],
discount_factor: float = 0.99,
vf_coef: float = 0.5,
ent_coef: float = 0.01,
max_grad_norm: Optional[float] = None,
gae_lambda: float = 0.95,
reward_normalization: bool = False,
max_batchsize: int = 256,
**kwargs: Any
) -> None:
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
self.actor = actor
self.critic = critic
assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
self._lambda = gae_lambda
self._w_vf = vf_coef
self._w_ent = ent_coef
@ -61,9 +64,10 @@ class A2CPolicy(PGPolicy):
self._batch = max_batchsize
self._rew_norm = reward_normalization
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
if self._lambda in [0, 1]:
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
if self._lambda in [0.0, 1.0]:
return self.compute_episodic_return(
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
v_ = []
@ -75,9 +79,12 @@ class A2CPolicy(PGPolicy):
batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
rew_norm=self._rew_norm)
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any
) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
@ -100,8 +107,9 @@ class A2CPolicy(PGPolicy):
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
def learn(
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True):
@ -110,8 +118,7 @@ class A2CPolicy(PGPolicy):
v = self.critic(b.obs).flatten()
a = to_torch_as(b.act, v)
r = to_torch_as(b.returns, v)
log_prob = dist.log_prob(a).reshape(
r.shape[0], -1).transpose(0, 1)
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
a_loss = -(log_prob * (r - v).detach()).mean()
vf_loss = F.mse_loss(r, v)
ent_loss = dist.entropy().mean()
@ -119,17 +126,18 @@ class A2CPolicy(PGPolicy):
loss.backward()
if self._grad_norm is not None:
nn.utils.clip_grad_norm_(
list(self.actor.parameters()) +
list(self.critic.parameters()),
max_norm=self._grad_norm)
list(self.actor.parameters())
+ list(self.critic.parameters()),
max_norm=self._grad_norm,
)
self.optim.step()
actor_losses.append(a_loss.item())
vf_losses.append(vf_loss.item())
ent_losses.append(ent_loss.item())
losses.append(loss.item())
return {
'loss': losses,
'loss/actor': actor_losses,
'loss/vf': vf_losses,
'loss/ent': ent_losses,
"loss": losses,
"loss/actor": actor_losses,
"loss/vf": vf_losses,
"loss/ent": ent_losses,
}
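In loss terms, the three pieces computed in learn above are presumably combined with the vf_coef/ent_coef weights from the signature (the combining line itself falls between the shown hunks):

L = -\mathbb{E}\big[\log \pi_\theta(a_t \mid s_t)\,(R_t - V_\phi(s_t))\big]
    + w_{vf}\,\mathbb{E}\big[(R_t - V_\phi(s_t))^2\big]
    - w_{ent}\,\mathbb{E}\big[\mathcal{H}(\pi_\theta(\cdot \mid s_t))\big]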

View File

@ -1,7 +1,7 @@
import torch
import numpy as np
from copy import deepcopy
from typing import Dict, Tuple, Union, Optional
from typing import Any, Dict, Tuple, Union, Optional
from tianshou.policy import BasePolicy
from tianshou.exploration import BaseNoise, GaussianNoise
@ -17,13 +17,13 @@ class DDPGPolicy(BasePolicy):
:param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic_optim: the optimizer for critic
network.
:param action_range: the action range (minimum, maximum).
:type action_range: Tuple[float, float]
:param float tau: param for soft update of the target network, defaults to
0.005.
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
:param BaseNoise exploration_noise: the exploration noise,
add to the action, defaults to ``GaussianNoise(sigma=0.1)``.
:param action_range: the action range (minimum, maximum).
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to False.
:param bool ignore_done: ignore the done flag while training the policy,
@ -37,20 +37,21 @@ class DDPGPolicy(BasePolicy):
explanation.
"""
def __init__(self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer,
tau: float = 0.005,
gamma: float = 0.99,
exploration_noise: Optional[BaseNoise]
= GaussianNoise(sigma=0.1),
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
**kwargs) -> None:
def __init__(
self,
actor: Optional[torch.nn.Module],
actor_optim: Optional[torch.optim.Optimizer],
critic: Optional[torch.nn.Module],
critic_optim: Optional[torch.optim.Optimizer],
action_range: Tuple[float, float],
tau: float = 0.005,
gamma: float = 0.99,
exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if actor is not None:
self.actor, self.actor_old = actor, deepcopy(actor)
@ -60,27 +61,26 @@ class DDPGPolicy(BasePolicy):
self.critic, self.critic_old = critic, deepcopy(critic)
self.critic_old.eval()
self.critic_optim = critic_optim
assert 0 <= tau <= 1, 'tau should in [0, 1]'
assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
self._tau = tau
assert 0 <= gamma <= 1, 'gamma should in [0, 1]'
assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
self._gamma = gamma
self._noise = exploration_noise
assert action_range is not None
self._range = action_range
self._action_bias = (action_range[0] + action_range[1]) / 2
self._action_scale = (action_range[1] - action_range[0]) / 2
# it is only a little difference to use rand_normal
self._action_bias = (action_range[0] + action_range[1]) / 2.0
self._action_scale = (action_range[1] - action_range[0]) / 2.0
# it makes only a small difference whether GaussianNoise or OUNoise is used
# self.noise = OUNoise()
self._rm_done = ignore_done
self._rew_norm = reward_normalization
assert estimation_step > 0, 'estimation_step should greater than 0'
assert estimation_step > 0, "estimation_step should be greater than 0"
self._n_step = estimation_step
def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
"""Set the exploration noise."""
self._noise = noise
def train(self, mode=True) -> torch.nn.Module:
def train(self, mode: bool = True) -> "DDPGPolicy":
"""Set the module in training mode, except for the target network."""
self.training = mode
self.actor.train(mode)
@ -90,13 +90,15 @@ class DDPGPolicy(BasePolicy):
def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
for o, n in zip(
self.critic_old.parameters(), self.critic.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
self.critic_old.parameters(), self.critic.parameters()
):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
def _target_q(self, buffer: ReplayBuffer,
indice: np.ndarray) -> torch.Tensor:
def _target_q(
self, buffer: ReplayBuffer, indice: np.ndarray
) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
with torch.no_grad():
target_q = self.critic_old(batch.obs_next, self(
@ -104,21 +106,25 @@ class DDPGPolicy(BasePolicy):
explorating=False).act)
return target_q
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
if self._rm_done:
batch.done = batch.done * 0.
batch.done = batch.done * 0.0
batch = self.compute_nstep_return(
batch, buffer, indice, self._target_q,
self._gamma, self._n_step, self._rew_norm)
return batch
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = 'actor',
input: str = 'obs',
explorating: bool = True,
**kwargs) -> Batch:
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = "actor",
input: str = "obs",
explorating: bool = True,
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 2 keys:
@ -140,8 +146,8 @@ class DDPGPolicy(BasePolicy):
actions = actions.clamp(self._range[0], self._range[1])
return Batch(act=actions, state=h)
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
weight = batch.pop('weight', 1.)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
weight = batch.pop("weight", 1.0)
current_q = self.critic(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
td = current_q - target_q
@ -157,6 +163,6 @@ class DDPGPolicy(BasePolicy):
self.actor_optim.step()
self.sync_weight()
return {
'loss/actor': actor_loss.item(),
'loss/critic': critic_loss.item(),
"loss/actor": actor_loss.item(),
"loss/critic": critic_loss.item(),
}
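Two formulas summarize the DDPG pieces visible above: the Polyak soft update performed in sync_weight, and the one-step bootstrap target built from the target networks in _target_q (compute_nstep_return generalizes it to n steps):

\theta^{-} \leftarrow (1 - \tau)\,\theta^{-} + \tau\,\theta,
\qquad
y_t = r_t + \gamma\, Q_{w^{-}}\big(s_{t+1}, \mu_{\theta^{-}}(s_{t+1})\big)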

View File

@ -1,7 +1,7 @@
import torch
import numpy as np
from copy import deepcopy
from typing import Dict, Union, Optional
from typing import Any, Dict, Union, Optional
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@ -32,21 +32,25 @@ class DQNPolicy(BasePolicy):
explanation.
"""
def __init__(self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
discount_factor: float = 0.99,
estimation_step: int = 1,
target_update_freq: int = 0,
reward_normalization: bool = False,
**kwargs) -> None:
def __init__(
self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
discount_factor: float = 0.99,
estimation_step: int = 1,
target_update_freq: int = 0,
reward_normalization: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.model = model
self.optim = optim
self.eps = 0
assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
self.eps = 0.0
assert (
0.0 <= discount_factor <= 1.0
), "discount factor should be in [0, 1]"
self._gamma = discount_factor
assert estimation_step > 0, 'estimation_step should greater than 0'
assert estimation_step > 0, "estimation_step should be greater than 0"
self._n_step = estimation_step
self._target = target_update_freq > 0
self._freq = target_update_freq
@ -60,7 +64,7 @@ class DQNPolicy(BasePolicy):
"""Set the eps for epsilon-greedy exploration."""
self.eps = eps
def train(self, mode=True) -> torch.nn.Module:
def train(self, mode: bool = True) -> "DQNPolicy":
"""Set the module in training mode, except for the target network."""
self.training = mode
self.model.train(mode)
@ -70,23 +74,26 @@ class DQNPolicy(BasePolicy):
"""Synchronize the weight for the target network."""
self.model_old.load_state_dict(self.model.state_dict())
def _target_q(self, buffer: ReplayBuffer,
indice: np.ndarray) -> torch.Tensor:
def _target_q(
self, buffer: ReplayBuffer, indice: np.ndarray
) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
a = self(batch, input='obs_next', eps=0).act
a = self(batch, input="obs_next", eps=0).act
with torch.no_grad():
target_q = self(
batch, model='model_old', input='obs_next').logits
batch, model="model_old", input="obs_next"
).logits
target_q = target_q[np.arange(len(a)), a]
else:
with torch.no_grad():
target_q = self(batch, input='obs_next').logits.max(dim=1)[0]
target_q = self(batch, input="obs_next").logits.max(dim=1)[0]
return target_q
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
"""Compute the n-step return for Q-learning targets.
More details can be found at
@ -97,12 +104,15 @@ class DQNPolicy(BasePolicy):
self._gamma, self._n_step, self._rew_norm)
return batch
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = 'model',
input: str = 'obs',
eps: Optional[float] = None,
**kwargs) -> Batch:
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: str = "model",
input: str = "obs",
eps: Optional[float] = None,
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.
If you need to mask the action, please add a "mask" into batch.obs, for
@ -134,7 +144,7 @@ class DQNPolicy(BasePolicy):
"""
model = getattr(self, model)
obs = getattr(batch, input)
obs_ = obs.obs if hasattr(obs, 'obs') else obs
obs_ = obs.obs if hasattr(obs, "obs") else obs
q, h = model(obs_, state=state, info=batch.info)
act = to_numpy(q.max(dim=1)[1])
has_mask = hasattr(obs, 'mask')
@ -146,7 +156,7 @@ class DQNPolicy(BasePolicy):
# add eps to act
if eps is None:
eps = self.eps
if not np.isclose(eps, 0):
if not np.isclose(eps, 0.0):
for i in range(len(q)):
if np.random.rand() < eps:
q_ = np.random.rand(*q[i].shape)
@ -155,12 +165,12 @@ class DQNPolicy(BasePolicy):
act[i] = q_.argmax()
return Batch(logits=q, act=act, state=h)
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._cnt % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
weight = batch.pop('weight', 1.)
q = self(batch, eps=0.).logits
weight = batch.pop("weight", 1.0)
q = self(batch, eps=0.0).logits
q = q[np.arange(len(q)), batch.act]
r = to_torch_as(batch.returns, q).flatten()
td = r - q
@ -169,4 +179,4 @@ class DQNPolicy(BasePolicy):
loss.backward()
self.optim.step()
self._cnt += 1
return {'loss': loss.item()}
return {"loss": loss.item()}
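The _target_q branch above is the double-DQN target spelled out in the inline comment; in the one-step case it reads

y_t = r_t + \gamma\, Q_{\theta^{-}}\big(s_{t+1}, \arg\max_{a} Q_{\theta}(s_{t+1}, a)\big),

while without a target network it falls back to y_t = r_t + \gamma \max_{a} Q_{\theta}(s_{t+1}, a); compute_nstep_return extends both to n steps.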

View File

@ -1,6 +1,6 @@
import torch
import numpy as np
from typing import Dict, List, Union, Optional
from typing import Any, Dict, List, Union, Optional, Callable
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, to_torch_as
@ -13,6 +13,7 @@ class PGPolicy(BasePolicy):
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param dist_fn: distribution class for computing the action.
:type dist_fn: Callable[[], torch.distributions.Distribution]
:param float discount_factor: in [0, 1].
.. seealso::
@ -21,23 +22,28 @@ class PGPolicy(BasePolicy):
explanation.
"""
def __init__(self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: torch.distributions.Distribution,
discount_factor: float = 0.99,
reward_normalization: bool = False,
**kwargs) -> None:
def __init__(
self,
model: Optional[torch.nn.Module],
optim: torch.optim.Optimizer,
dist_fn: Callable[[], torch.distributions.Distribution],
discount_factor: float = 0.99,
reward_normalization: bool = False,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.model = model
self.optim = optim
self.dist_fn = dist_fn
assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]'
assert (
0.0 <= discount_factor <= 1.0
), "discount factor should be in [0, 1]"
self._gamma = discount_factor
self._rew_norm = reward_normalization
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
r"""Compute the discounted returns for each frame.
.. math::
@ -48,13 +54,15 @@ class PGPolicy(BasePolicy):
"""
# batch.returns = self._vanilla_returns(batch)
# batch.returns = self._vectorized_returns(batch)
# return batch
return self.compute_episodic_return(
batch, gamma=self._gamma, gae_lambda=1., rew_norm=self._rew_norm)
batch, gamma=self._gamma, gae_lambda=1.0, rew_norm=self._rew_norm)
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
@ -77,8 +85,9 @@ class PGPolicy(BasePolicy):
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
def learn(
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
losses = []
for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True):
@ -86,13 +95,12 @@ class PGPolicy(BasePolicy):
dist = self(b).dist
a = to_torch_as(b.act, dist.logits)
r = to_torch_as(b.returns, dist.logits)
log_prob = dist.log_prob(a).reshape(
r.shape[0], -1).transpose(0, 1)
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
loss = -(log_prob * r).mean()
loss.backward()
self.optim.step()
losses.append(loss.item())
return {'loss': losses}
return {"loss": losses}
# def _vanilla_returns(self, batch):
# returns = batch.rew[:]
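The PG update above is plain REINFORCE on discounted Monte-Carlo returns: with gae_lambda=1.0 and no value baseline, compute_episodic_return reduces to the full return, so

G_t = \sum_{i=t}^{T} \gamma^{\,i-t} r_i,
\qquad
L = -\mathbb{E}\big[\log \pi_\theta(a_t \mid s_t)\, G_t\big]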

View File

@ -1,7 +1,7 @@
import torch
import numpy as np
from torch import nn
from typing import Dict, List, Tuple, Union, Optional
from typing import Any, Dict, List, Tuple, Union, Optional, Callable
from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
@ -16,6 +16,7 @@ class PPOPolicy(PGPolicy):
:param torch.optim.Optimizer optim: the optimizer for actor and critic
network.
:param dist_fn: distribution class for computing the action.
:type dist_fn: Callable[[], torch.distributions.Distribution]
:param float discount_factor: in [0, 1], defaults to 0.99.
:param float max_grad_norm: clipping gradients in back propagation,
defaults to None.
@ -45,24 +46,26 @@ class PPOPolicy(PGPolicy):
explanation.
"""
def __init__(self,
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: torch.distributions.Distribution,
discount_factor: float = 0.99,
max_grad_norm: Optional[float] = None,
eps_clip: float = .2,
vf_coef: float = .5,
ent_coef: float = .01,
action_range: Optional[Tuple[float, float]] = None,
gae_lambda: float = 0.95,
dual_clip: Optional[float] = None,
value_clip: bool = True,
reward_normalization: bool = True,
max_batchsize: int = 256,
**kwargs) -> None:
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
def __init__(
self,
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Callable[[], torch.distributions.Distribution],
discount_factor: float = 0.99,
max_grad_norm: Optional[float] = None,
eps_clip: float = 0.2,
vf_coef: float = 0.5,
ent_coef: float = 0.01,
action_range: Optional[Tuple[float, float]] = None,
gae_lambda: float = 0.95,
dual_clip: Optional[float] = None,
value_clip: bool = True,
reward_normalization: bool = True,
max_batchsize: int = 256,
**kwargs: Any,
) -> None:
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
self._max_grad_norm = max_grad_norm
self._eps_clip = eps_clip
self._w_vf = vf_coef
@ -70,29 +73,31 @@ class PPOPolicy(PGPolicy):
self._range = action_range
self.actor = actor
self.critic = critic
self.optim = optim
self._batch = max_batchsize
assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
self._lambda = gae_lambda
assert dual_clip is None or dual_clip > 1, \
'Dual-clip PPO parameter should greater than 1.'
assert (
dual_clip is None or dual_clip > 1.0
), "Dual-clip PPO parameter should be greater than 1.0."
self._dual_clip = dual_clip
self._value_clip = value_clip
self._rew_norm = reward_normalization
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
if self._rew_norm:
mean, std = batch.rew.mean(), batch.rew.std()
if not np.isclose(std, 0, 1e-2):
if not np.isclose(std, 0.0, 1e-2):
batch.rew = (batch.rew - mean) / std
v, v_, old_log_prob = [], [], []
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False, merge_last=True):
v_.append(self.critic(b.obs_next))
v.append(self.critic(b.obs))
old_log_prob.append(self(b).dist.log_prob(
to_torch_as(b.act, v[0])))
old_log_prob.append(
self(b).dist.log_prob(to_torch_as(b.act, v[0]))
)
v_ = to_numpy(torch.cat(v_, dim=0))
batch = self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
@ -104,13 +109,16 @@ class PPOPolicy(PGPolicy):
batch.adv = batch.returns - batch.v
if self._rew_norm:
mean, std = batch.adv.mean(), batch.adv.std()
if not np.isclose(std.item(), 0, 1e-2):
if not np.isclose(std.item(), 0.0, 1e-2):
batch.adv = (batch.adv - mean) / std
return batch
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> Batch:
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
@ -135,8 +143,9 @@ class PPOPolicy(PGPolicy):
act = act.clamp(self._range[0], self._range[1])
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn(self, batch: Batch, batch_size: int, repeat: int,
**kwargs) -> Dict[str, List[float]]:
def learn(
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
for _ in range(repeat):
for b in batch.split(batch_size, merge_last=True):
@ -145,8 +154,8 @@ class PPOPolicy(PGPolicy):
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
surr1 = ratio * b.adv
surr2 = ratio.clamp(1. - self._eps_clip,
1. + self._eps_clip) * b.adv
surr2 = ratio.clamp(1.0 - self._eps_clip,
1.0 + self._eps_clip) * b.adv
if self._dual_clip:
clip_loss = -torch.max(torch.min(surr1, surr2),
self._dual_clip * b.adv).mean()
@ -158,9 +167,9 @@ class PPOPolicy(PGPolicy):
-self._eps_clip, self._eps_clip)
vf1 = (b.returns - value).pow(2)
vf2 = (b.returns - v_clip).pow(2)
vf_loss = .5 * torch.max(vf1, vf2).mean()
vf_loss = 0.5 * torch.max(vf1, vf2).mean()
else:
vf_loss = .5 * (b.returns - value).pow(2).mean()
vf_loss = 0.5 * (b.returns - value).pow(2).mean()
vf_losses.append(vf_loss.item())
e_loss = dist.entropy().mean()
ent_losses.append(e_loss.item())
@ -168,13 +177,14 @@ class PPOPolicy(PGPolicy):
losses.append(loss.item())
self.optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(list(
self.actor.parameters()) + list(self.critic.parameters()),
nn.utils.clip_grad_norm_(
list(self.actor.parameters())
+ list(self.critic.parameters()),
self._max_grad_norm)
self.optim.step()
return {
'loss': losses,
'loss/clip': clip_losses,
'loss/vf': vf_losses,
'loss/ent': ent_losses,
"loss": losses,
"loss/clip": clip_losses,
"loss/vf": vf_losses,
"loss/ent": ent_losses,
}
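Written out, the surrogate and value terms computed in learn above are the standard clipped PPO objective (dual_clip, when set, replaces the min with max(min(surr1, surr2), c \hat{A}_t)) together with the optional clipped value loss:

r_t(\theta) = \exp\big(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\mathrm{old}}(a_t \mid s_t)\big),
\qquad
L^{\mathrm{clip}} = -\mathbb{E}\big[\min\big(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\,\hat{A}_t\big)\big],

L^{\mathrm{vf}} = \tfrac{1}{2}\,\mathbb{E}\big[\max\big((R_t - V_\phi(s_t))^2,\ (R_t - V_{\mathrm{clip}})^2\big)\big]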

View File

@ -1,12 +1,12 @@
import torch
import numpy as np
from copy import deepcopy
from typing import Dict, Tuple, Union, Optional
from torch.distributions import Normal, Independent
from torch.distributions import Independent, Normal
from typing import Any, Dict, Tuple, Union, Optional
from tianshou.policy import DDPGPolicy
from tianshou.data import Batch, to_torch_as, ReplayBuffer
from tianshou.exploration import BaseNoise
from tianshou.data import Batch, ReplayBuffer, to_torch_as
class SACPolicy(DDPGPolicy):
@ -23,6 +23,8 @@ class SACPolicy(DDPGPolicy):
a))
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
critic network.
:param action_range: the action range (minimum, maximum).
:type action_range: Tuple[float, float]
:param float tau: param for soft update of the target network, defaults to
0.005.
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
@ -32,8 +34,6 @@ class SACPolicy(DDPGPolicy):
regularization coefficient, default to 0.2.
If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
alpha is automatically tuned.
:param action_range: the action range (minimum, maximum).
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to False.
:param bool ignore_done: ignore the done flag while training the policy,
@ -55,20 +55,20 @@ class SACPolicy(DDPGPolicy):
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
action_range: Tuple[float, float],
tau: float = 0.005,
gamma: float = 0.99,
alpha: Union[
float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
] = 0.2,
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
exploration_noise: Optional[BaseNoise] = None,
**kwargs
**kwargs: Any,
) -> None:
super().__init__(None, None, None, None, tau, gamma, exploration_noise,
action_range, reward_normalization, ignore_done,
super().__init__(None, None, None, None, action_range, tau, gamma,
exploration_noise, reward_normalization, ignore_done,
estimation_step, **kwargs)
self.actor, self.actor_optim = actor, actor_optim
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
@ -79,6 +79,7 @@ class SACPolicy(DDPGPolicy):
self.critic2_optim = critic2_optim
self._is_auto_alpha = False
self._alpha: Union[float, torch.Tensor]
if isinstance(alpha, tuple):
self._is_auto_alpha = True
self._target_entropy, self._log_alpha, self._alpha_optim = alpha
@ -89,7 +90,7 @@ class SACPolicy(DDPGPolicy):
self.__eps = np.finfo(np.float32).eps.item()
def train(self, mode=True) -> torch.nn.Module:
def train(self, mode: bool = True) -> "SACPolicy":
self.training = mode
self.actor.train(mode)
self.critic1.train(mode)
@ -98,17 +99,22 @@ class SACPolicy(DDPGPolicy):
def sync_weight(self) -> None:
for o, n in zip(
self.critic1_old.parameters(), self.critic1.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
self.critic1_old.parameters(), self.critic1.parameters()
):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
for o, n in zip(
self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
self.critic2_old.parameters(), self.critic2.parameters()
):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
input: str = 'obs',
explorating: bool = True,
**kwargs) -> Batch:
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
input: str = "obs",
explorating: bool = True,
**kwargs: Any,
) -> Batch:
obs = getattr(batch, input)
logits, h = self.actor(obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
@ -125,8 +131,9 @@ class SACPolicy(DDPGPolicy):
return Batch(
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
def _target_q(self, buffer: ReplayBuffer,
indice: np.ndarray) -> torch.Tensor:
def _target_q(
self, buffer: ReplayBuffer, indice: np.ndarray
) -> torch.Tensor:
batch = buffer[indice] # batch.obs: s_{t+n}
with torch.no_grad():
obs_next_result = self(batch, input='obs_next', explorating=False)
@ -138,8 +145,8 @@ class SACPolicy(DDPGPolicy):
) - self._alpha * obs_next_result.log_prob
return target_q
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
weight = batch.pop('weight', 1.)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
weight = batch.pop("weight", 1.0)
# critic 1
current_q1 = self.critic1(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
@ -157,7 +164,7 @@ class SACPolicy(DDPGPolicy):
self.critic2_optim.zero_grad()
critic2_loss.backward()
self.critic2_optim.step()
batch.weight = (td1 + td2) / 2. # prio-buffer
batch.weight = (td1 + td2) / 2.0 # prio-buffer
# actor
obs_result = self(batch, explorating=False)
a = obs_result.act
@ -180,11 +187,11 @@ class SACPolicy(DDPGPolicy):
self.sync_weight()
result = {
'loss/actor': actor_loss.item(),
'loss/critic1': critic1_loss.item(),
'loss/critic2': critic2_loss.item(),
"loss/actor": actor_loss.item(),
"loss/critic1": critic1_loss.item(),
"loss/critic2": critic2_loss.item(),
}
if self._is_auto_alpha:
result['loss/alpha'] = alpha_loss.item()
result['v/alpha'] = self._alpha.item()
result["loss/alpha"] = alpha_loss.item()
result["v/alpha"] = self._alpha.item()
return result
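The _target_q above computes the usual soft Q target; the part elided from the hunk is, per standard SAC, the minimum over the two target critics:

y_t = r_t + \gamma\Big(\min\big(Q_1^{-}(s', a'),\ Q_2^{-}(s', a')\big) - \alpha \log \pi_\theta(a' \mid s')\Big),
\qquad a' \sim \pi_\theta(\cdot \mid s')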

View File

@ -1,7 +1,7 @@
import torch
import numpy as np
from copy import deepcopy
from typing import Dict, Tuple, Optional
from typing import Any, Dict, Tuple, Optional
from tianshou.policy import DDPGPolicy
from tianshou.data import Batch, ReplayBuffer
@ -22,6 +22,8 @@ class TD3Policy(DDPGPolicy):
a))
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
critic network.
:param action_range: the action range (minimum, maximum).
:type action_range: Tuple[float, float]
:param float tau: param for soft update of the target network, defaults to
0.005.
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
@ -33,8 +35,6 @@ class TD3Policy(DDPGPolicy):
default to 2.
:param float noise_clip: the clipping range used in updating policy
network, default to 0.5.
:param action_range: the action range (minimum, maximum).
:type action_range: (float, float)
:param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to False.
:param bool ignore_done: ignore the done flag while training the policy,
@ -46,27 +46,28 @@ class TD3Policy(DDPGPolicy):
explanation.
"""
def __init__(self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic1: torch.nn.Module,
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
tau: float = 0.005,
gamma: float = 0.99,
exploration_noise: Optional[BaseNoise]
= GaussianNoise(sigma=0.1),
policy_noise: float = 0.2,
update_actor_freq: int = 2,
noise_clip: float = 0.5,
action_range: Optional[Tuple[float, float]] = None,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
**kwargs) -> None:
super().__init__(actor, actor_optim, None, None, tau, gamma,
exploration_noise, action_range, reward_normalization,
def __init__(
self,
actor: torch.nn.Module,
actor_optim: torch.optim.Optimizer,
critic1: torch.nn.Module,
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
action_range: Tuple[float, float],
tau: float = 0.005,
gamma: float = 0.99,
exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
policy_noise: float = 0.2,
update_actor_freq: int = 2,
noise_clip: float = 0.5,
reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1,
**kwargs: Any,
) -> None:
super().__init__(actor, actor_optim, None, None, action_range,
tau, gamma, exploration_noise, reward_normalization,
ignore_done, estimation_step, **kwargs)
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
self.critic1_old.eval()
@ -80,7 +81,7 @@ class TD3Policy(DDPGPolicy):
self._cnt = 0
self._last = 0
def train(self, mode=True) -> torch.nn.Module:
def train(self, mode: bool = True) -> "TD3Policy":
self.training = mode
self.actor.train(mode)
self.critic1.train(mode)
@ -89,22 +90,25 @@ class TD3Policy(DDPGPolicy):
def sync_weight(self) -> None:
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
for o, n in zip(
self.critic1_old.parameters(), self.critic1.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
self.critic1_old.parameters(), self.critic1.parameters()
):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
for o, n in zip(
self.critic2_old.parameters(), self.critic2.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
self.critic2_old.parameters(), self.critic2.parameters()
):
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
def _target_q(self, buffer: ReplayBuffer,
indice: np.ndarray) -> torch.Tensor:
def _target_q(
self, buffer: ReplayBuffer, indice: np.ndarray
) -> torch.Tensor:
batch = buffer[indice] # batch.obs: s_{t+n}
with torch.no_grad():
a_ = self(batch, model='actor_old', input='obs_next').act
a_ = self(batch, model="actor_old", input="obs_next").act
dev = a_.device
noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
if self._noise_clip > 0:
if self._noise_clip > 0.0:
noise = noise.clamp(-self._noise_clip, self._noise_clip)
a_ += noise
a_ = a_.clamp(self._range[0], self._range[1])
@ -113,8 +117,8 @@ class TD3Policy(DDPGPolicy):
self.critic2_old(batch.obs_next, a_))
return target_q
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
weight = batch.pop('weight', 1.)
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
weight = batch.pop("weight", 1.0)
# critic 1
current_q1 = self.critic1(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
@ -132,10 +136,10 @@ class TD3Policy(DDPGPolicy):
self.critic2_optim.zero_grad()
critic2_loss.backward()
self.critic2_optim.step()
batch.weight = (td1 + td2) / 2. # prio-buffer
batch.weight = (td1 + td2) / 2.0 # prio-buffer
if self._cnt % self._freq == 0:
actor_loss = -self.critic1(
batch.obs, self(batch, eps=0).act).mean()
batch.obs, self(batch, eps=0.0).act).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
self._last = actor_loss.item()
@ -143,7 +147,7 @@ class TD3Policy(DDPGPolicy):
self.sync_weight()
self._cnt += 1
return {
'loss/actor': self._last,
'loss/critic1': critic1_loss.item(),
'loss/critic2': critic2_loss.item(),
"loss/actor": self._last,
"loss/critic1": critic1_loss.item(),
"loss/critic2": critic2_loss.item(),
}
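The TD3 _target_q shown above adds clipped target-policy smoothing noise to the target actor's action and then bootstraps from the smaller of the two target critics (the min call sits just above the visible critic2_old line):

a' = \mathrm{clip}\big(\mu_{\theta^{-}}(s') + \mathrm{clip}(\varepsilon, -c, c),\ a_{\min}, a_{\max}\big),
\quad \varepsilon \sim \mathcal{N}(0, \sigma_{\mathrm{policy}}^2),
\qquad
y_t = r_t + \gamma\, \min\big(Q_1^{-}(s', a'),\ Q_2^{-}(s', a')\big)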

View File

@ -1,5 +1,5 @@
import numpy as np
from typing import Union, Optional, Dict, List
from typing import Any, Dict, List, Union, Optional
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer
@ -15,21 +15,22 @@ class MultiAgentPolicyManager(BasePolicy):
:ref:`marl_example` can help you better understand this procedure.
"""
def __init__(self, policies: List[BasePolicy]):
super().__init__()
def __init__(self, policies: List[BasePolicy], **kwargs: Any) -> None:
super().__init__(**kwargs)
self.policies = policies
for i, policy in enumerate(policies):
# agent_id 0 is reserved for the environment proxy
# (this MultiAgentPolicyManager)
policy.set_agent_id(i + 1)
def replace_policy(self, policy, agent_id):
def replace_policy(self, policy: BasePolicy, agent_id: int) -> None:
"""Replace the "agent_id"th policy in this manager."""
self.policies[agent_id - 1] = policy
policy.set_agent_id(agent_id)
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch:
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
"""Dispatch batch data from obs.agent_id to every policy's process_fn.
Save original multi-dimensional rew in "save_rew", set rew to the
@ -46,21 +47,24 @@ class MultiAgentPolicyManager(BasePolicy):
for policy in self.policies:
agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
if len(agent_index) == 0:
results[f'agent_{policy.agent_id}'] = Batch()
results[f"agent_{policy.agent_id}"] = Batch()
continue
tmp_batch, tmp_indice = batch[agent_index], indice[agent_index]
if has_rew:
tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
buffer._meta.rew = save_rew[:, policy.agent_id - 1]
results[f'agent_{policy.agent_id}'] = \
policy.process_fn(tmp_batch, buffer, tmp_indice)
results[f"agent_{policy.agent_id}"] = policy.process_fn(
tmp_batch, buffer, tmp_indice)
if has_rew: # restore from save_rew
buffer._meta.rew = save_rew
return Batch(results)
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch]] = None,
**kwargs) -> Batch:
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch]] = None,
**kwargs: Any,
) -> Batch:
"""Dispatch batch data from obs.agent_id to every policy's forward.
:param state: if None, it means all agents have no state. If not
@ -107,15 +111,15 @@ class MultiAgentPolicyManager(BasePolicy):
**kwargs)
act = out.act
each_state = out.state \
if (hasattr(out, 'state') and out.state is not None) \
if (hasattr(out, "state") and out.state is not None) \
else Batch()
results.append((True, agent_index, out, act, each_state))
holder = Batch.cat([{'act': act} for
holder = Batch.cat([{"act": act} for
(has_data, agent_index, out, act, each_state)
in results if has_data])
state_dict, out_dict = {}, {}
for policy, (has_data, agent_index, out, act, state) in \
zip(self.policies, results):
for policy, (has_data, agent_index, out, act, state) in zip(
self.policies, results):
if has_data:
holder.act[agent_index] = act
state_dict["agent_" + str(policy.agent_id)] = state
@ -124,8 +128,9 @@ class MultiAgentPolicyManager(BasePolicy):
holder["state"] = state_dict
return holder
def learn(self, batch: Batch, **kwargs
) -> Dict[str, Union[float, List[float]]]:
def learn(
self, batch: Batch, **kwargs: Any
) -> Dict[str, Union[float, List[float]]]:
"""Dispatch the data to all policies for learning.
:return: a dict with the following contents:
@ -142,9 +147,9 @@ class MultiAgentPolicyManager(BasePolicy):
"""
results = {}
for policy in self.policies:
data = batch[f'agent_{policy.agent_id}']
data = batch[f"agent_{policy.agent_id}"]
if not data.is_empty():
out = policy.learn(batch=data, **kwargs)
for k, v in out.items():
results["agent_" + str(policy.agent_id) + '/' + k] = v
results["agent_" + str(policy.agent_id) + "/" + k] = v
return results
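As the hunk shows, each policy reads its slice of the batch under the key "agent_{agent_id}" and its losses come back prefixed with "agent_{agent_id}/". A tiny sketch of that key layout, using plain dicts in place of real Batch objects and policies (everything here is hypothetical):
# Hypothetical per-agent learn() outputs, keyed the way process_fn stores them.
batch_out = {
    "agent_1": {"loss": 0.12},
    "agent_2": {},               # empty: this agent had no transitions
}
results = {}
for agent_id, losses in batch_out.items():
    for k, v in losses.items():
        results[agent_id + "/" + k] = v
print(results)                   # {'agent_1/loss': 0.12}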

View File

@ -1,5 +1,5 @@
import numpy as np
from typing import Union, Optional, Dict, List
from typing import Any, Dict, Union, Optional
from tianshou.data import Batch
from tianshou.policy import BasePolicy
@ -11,9 +11,12 @@ class RandomPolicy(BasePolicy):
It randomly chooses an action from the legal action.
"""
def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> Batch:
"""Compute the random action over the given batch data.
The input should contain a mask in batch.obs, with "True" to be
@ -34,7 +37,6 @@ class RandomPolicy(BasePolicy):
logits[~mask] = -np.inf
return Batch(act=logits.argmax(axis=-1))
def learn(self, batch: Batch, **kwargs
) -> Dict[str, Union[float, List[float]]]:
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
"""Since a random agent learns nothing, it returns an empty dict."""
return {}
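The forward method above draws uniform logits, sets the logits of illegal actions to -inf, and takes argmax, so the chosen action is uniformly random among the legal ones. A NumPy-only sketch of that masking trick with a made-up two-environment mask:
import numpy as np
mask = np.array([[True, False, True, True],
                 [False, False, True, False]])   # legal-action masks for 2 envs
logits = np.random.rand(*mask.shape)
logits[~mask] = -np.inf                          # illegal actions can never win argmax
act = logits.argmax(axis=-1)                     # uniform over legal actions per row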

View File

@ -3,8 +3,8 @@ from tianshou.trainer.onpolicy import onpolicy_trainer
from tianshou.trainer.offpolicy import offpolicy_trainer
__all__ = [
'gather_info',
'test_episode',
'onpolicy_trainer',
'offpolicy_trainer',
"gather_info",
"test_episode",
"onpolicy_trainer",
"offpolicy_trainer",
]

View File

@ -10,23 +10,23 @@ from tianshou.trainer import test_episode, gather_info
def offpolicy_trainer(
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
max_epoch: int,
step_per_epoch: int,
collect_per_step: int,
episode_per_test: Union[int, List[int]],
batch_size: int,
update_per_step: int = 1,
train_fn: Optional[Callable[[int], None]] = None,
test_fn: Optional[Callable[[int], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: int = 1,
verbose: bool = True,
test_in_train: bool = True,
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
max_epoch: int,
step_per_epoch: int,
collect_per_step: int,
episode_per_test: Union[int, List[int]],
batch_size: int,
update_per_step: int = 1,
train_fn: Optional[Callable[[int], None]] = None,
test_fn: Optional[Callable[[int], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: int = 1,
verbose: bool = True,
test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
"""A wrapper for the off-policy training procedure.
@ -72,7 +72,7 @@ def offpolicy_trainer(
:return: See :func:`~tianshou.trainer.gather_info`.
"""
global_step = 0
best_epoch, best_reward = -1, -1.
best_epoch, best_reward = -1, -1.0
stat = {}
start_time = time.time()
test_in_train = test_in_train and train_collector.policy == policy
@ -81,42 +81,43 @@ def offpolicy_trainer(
policy.train()
if train_fn:
train_fn(epoch)
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
**tqdm_config) as t:
with tqdm.tqdm(
total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
) as t:
while t.n < t.total:
result = train_collector.collect(n_step=collect_per_step)
data = {}
if test_in_train and stop_fn and stop_fn(result['rew']):
if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
policy, test_collector, test_fn,
epoch, episode_per_test, writer, global_step)
if stop_fn and stop_fn(test_result['rew']):
if stop_fn and stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
for k in result.keys():
data[k] = f'{result[k]:.2f}'
data[k] = f"{result[k]:.2f}"
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
test_result['rew'])
test_result["rew"])
else:
policy.train()
if train_fn:
train_fn(epoch)
for i in range(update_per_step * min(
result['n/st'] // collect_per_step, t.total - t.n)):
result["n/st"] // collect_per_step, t.total - t.n)):
global_step += collect_per_step
losses = policy.update(batch_size, train_collector.buffer)
for k in result.keys():
data[k] = f'{result[k]:.2f}'
data[k] = f"{result[k]:.2f}"
if writer and global_step % log_interval == 0:
writer.add_scalar('train/' + k, result[k],
writer.add_scalar("train/" + k, result[k],
global_step=global_step)
for k in losses.keys():
if stat.get(k) is None:
stat[k] = MovAvg()
stat[k].add(losses[k])
data[k] = f'{stat[k].get():.6f}'
data[k] = f"{stat[k].get():.6f}"
if writer and global_step % log_interval == 0:
writer.add_scalar(
k, stat[k].get(), global_step=global_step)
@ -127,14 +128,14 @@ def offpolicy_trainer(
# test
result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, writer, global_step)
if best_epoch == -1 or best_reward < result['rew']:
best_reward = result['rew']
if best_epoch == -1 or best_reward < result["rew"]:
best_reward = result["rew"]
best_epoch = epoch
if save_fn:
save_fn(policy)
if verbose:
print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
f'best_reward: {best_reward:.6f} in #{best_epoch}')
print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, "
f"best_reward: {best_reward:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(
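With the keyword arguments listed in the signature above, a typical invocation looks roughly like the sketch below. `policy`, `train_collector`, and `test_collector` are assumed to be built elsewhere, and every numeric value is a placeholder rather than a recommendation:
result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10,
    step_per_epoch=1000,
    collect_per_step=10,
    episode_per_test=10,
    batch_size=64,
    stop_fn=lambda rew: rew >= 195,   # optional early-stopping criterion
    verbose=True,
)
print(result["best_reward"], result["duration"])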

View File

@ -10,23 +10,23 @@ from tianshou.trainer import test_episode, gather_info
def onpolicy_trainer(
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
max_epoch: int,
step_per_epoch: int,
collect_per_step: int,
repeat_per_collect: int,
episode_per_test: Union[int, List[int]],
batch_size: int,
train_fn: Optional[Callable[[int], None]] = None,
test_fn: Optional[Callable[[int], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: int = 1,
verbose: bool = True,
test_in_train: bool = True,
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
max_epoch: int,
step_per_epoch: int,
collect_per_step: int,
repeat_per_collect: int,
episode_per_test: Union[int, List[int]],
batch_size: int,
train_fn: Optional[Callable[[int], None]] = None,
test_fn: Optional[Callable[[int], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: int = 1,
verbose: bool = True,
test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
"""A wrapper for the on-policy training procedure.
@ -72,7 +72,7 @@ def onpolicy_trainer(
:return: See :func:`~tianshou.trainer.gather_info`.
"""
global_step = 0
best_epoch, best_reward = -1, -1.
best_epoch, best_reward = -1, -1.0
stat = {}
start_time = time.time()
test_in_train = test_in_train and train_collector.policy == policy
@ -81,30 +81,32 @@ def onpolicy_trainer(
policy.train()
if train_fn:
train_fn(epoch)
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
**tqdm_config) as t:
with tqdm.tqdm(
total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
) as t:
while t.n < t.total:
result = train_collector.collect(n_episode=collect_per_step)
data = {}
if test_in_train and stop_fn and stop_fn(result['rew']):
if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
policy, test_collector, test_fn,
epoch, episode_per_test, writer, global_step)
if stop_fn and stop_fn(test_result['rew']):
if stop_fn and stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
for k in result.keys():
data[k] = f'{result[k]:.2f}'
data[k] = f"{result[k]:.2f}"
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
test_result['rew'])
test_result["rew"])
else:
policy.train()
if train_fn:
train_fn(epoch)
losses = policy.update(
0, train_collector.buffer, batch_size, repeat_per_collect)
0, train_collector.buffer,
batch_size=batch_size, repeat=repeat_per_collect)
train_collector.reset_buffer()
step = 1
for k in losses.keys():
@ -112,15 +114,15 @@ def onpolicy_trainer(
step = max(step, len(losses[k]))
global_step += step * collect_per_step
for k in result.keys():
data[k] = f'{result[k]:.2f}'
data[k] = f"{result[k]:.2f}"
if writer and global_step % log_interval == 0:
writer.add_scalar(
'train/' + k, result[k], global_step=global_step)
"train/" + k, result[k], global_step=global_step)
for k in losses.keys():
if stat.get(k) is None:
stat[k] = MovAvg()
stat[k].add(losses[k])
data[k] = f'{stat[k].get():.6f}'
data[k] = f"{stat[k].get():.6f}"
if writer and global_step % log_interval == 0:
writer.add_scalar(
k, stat[k].get(), global_step=global_step)
@ -131,14 +133,14 @@ def onpolicy_trainer(
# test
result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, writer, global_step)
if best_epoch == -1 or best_reward < result['rew']:
best_reward = result['rew']
if best_epoch == -1 or best_reward < result["rew"]:
best_reward = result["rew"]
best_epoch = epoch
if save_fn:
save_fn(policy)
if verbose:
print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
f'best_reward: {best_reward:.6f} in #{best_epoch}')
print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, "
f"best_reward: {best_reward:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(
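The main difference from the off-policy loop above is the update step: on-policy training consumes the whole collected buffer at once (a sample size of 0 is assumed to mean "use the entire buffer"), repeats over it `repeat_per_collect` times, and then discards the data. A sketch of just that step, assuming `policy` and `train_collector` already exist and using placeholder values:
losses = policy.update(
    0, train_collector.buffer,
    batch_size=64,                  # placeholder minibatch size
    repeat=2,                       # placeholder repeat_per_collect
)
train_collector.reset_buffer()      # on-policy data is discarded after the update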

View File

@ -8,13 +8,14 @@ from tianshou.policy import BasePolicy
def test_episode(
policy: BasePolicy,
collector: Collector,
test_fn: Optional[Callable[[int], None]],
epoch: int,
n_episode: Union[int, List[int]],
writer: SummaryWriter = None,
global_step: int = None) -> Dict[str, float]:
policy: BasePolicy,
collector: Collector,
test_fn: Optional[Callable[[int], None]],
epoch: int,
n_episode: Union[int, List[int]],
writer: Optional[SummaryWriter] = None,
global_step: Optional[int] = None,
) -> Dict[str, float]:
"""A simple wrapper for testing a policy with a collector."""
collector.reset_env()
collector.reset_buffer()
@ -29,15 +30,16 @@ def test_episode(
result = collector.collect(n_episode=n_episode)
if writer is not None and global_step is not None:
for k in result.keys():
writer.add_scalar('test/' + k, result[k], global_step=global_step)
writer.add_scalar("test/" + k, result[k], global_step=global_step)
return result
def gather_info(start_time: float,
train_c: Collector,
test_c: Collector,
best_reward: float
) -> Dict[str, Union[float, str]]:
def gather_info(
start_time: float,
train_c: Collector,
test_c: Collector,
best_reward: float,
) -> Dict[str, Union[float, str]]:
"""A simple wrapper for gathering information from collectors.
:return: A dictionary with the following keys:
@ -60,15 +62,15 @@ def gather_info(start_time: float,
train_speed = train_c.collect_step / (duration - test_c.collect_time)
test_speed = test_c.collect_step / test_c.collect_time
return {
'train_step': train_c.collect_step,
'train_episode': train_c.collect_episode,
'train_time/collector': f'{train_c.collect_time:.2f}s',
'train_time/model': f'{model_time:.2f}s',
'train_speed': f'{train_speed:.2f} step/s',
'test_step': test_c.collect_step,
'test_episode': test_c.collect_episode,
'test_time': f'{test_c.collect_time:.2f}s',
'test_speed': f'{test_speed:.2f} step/s',
'best_reward': best_reward,
'duration': f'{duration:.2f}s',
"train_step": train_c.collect_step,
"train_episode": train_c.collect_episode,
"train_time/collector": f"{train_c.collect_time:.2f}s",
"train_time/model": f"{model_time:.2f}s",
"train_speed": f"{train_speed:.2f} step/s",
"test_step": test_c.collect_step,
"test_episode": test_c.collect_episode,
"test_time": f"{test_c.collect_time:.2f}s",
"test_speed": f"{test_speed:.2f} step/s",
"best_reward": best_reward,
"duration": f"{duration:.2f}s",
}
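Note that train_speed divides the training collector's environment steps by the wall-clock duration minus the time spent collecting test episodes, so evaluation does not deflate the reported training throughput. A small worked sketch with hypothetical numbers:
# Hypothetical collector statistics.
train_collect_step = 50_000
test_collect_step = 5_000
test_collect_time = 20.0        # seconds spent collecting test episodes
duration = 120.0                # total wall-clock seconds since start_time
train_speed = train_collect_step / (duration - test_collect_time)  # 500.0 step/s
test_speed = test_collect_step / test_collect_time                 # 250.0 step/s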

View File

@ -1,9 +1,7 @@
from tianshou.utils.config import tqdm_config
from tianshou.utils.compile import pre_compile
from tianshou.utils.moving_average import MovAvg
__all__ = [
"MovAvg",
"pre_compile",
"tqdm_config",
]

View File

@ -1,27 +0,0 @@
import numpy as np
from tianshou.policy.base import _episodic_return, _nstep_return
from tianshou.data.utils.segtree import _reduce, _setitem, _get_prefix_sum_idx
def pre_compile():
"""Functions that need to be pre-compiled for producing benchmark results.
Since Numba acceleration compiles a function on its first run,
here we call the common function signatures once with some fake data.
Otherwise, the measured training speed would not be comparable with previous results.
"""
f64 = np.array([0, 1], dtype=np.float64)
f32 = np.array([0, 1], dtype=np.float32)
b = np.array([False, True], dtype=np.bool_)
i64 = np.array([0, 1], dtype=np.int64)
# returns
_episodic_return(f64, f64, b, .1, .1)
_episodic_return(f32, f64, b, .1, .1)
_nstep_return(f64, b, f32, i64, .1, 1, 4, 1., 0.)
# segtree
_setitem(f64, i64, f64)
_setitem(f64, i64, f32)
_reduce(f64, 0, 1)
_get_prefix_sum_idx(f64, 1, f64)
_get_prefix_sum_idx(f32, 1, f64)
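The function above is meant to be called once before timing anything, so that Numba's JIT cost is paid outside the measured region. A minimal sketch of that intent; the import path follows the __all__ hunk above, and the benchmarked loop itself is elided:
import time
from tianshou.utils import pre_compile
pre_compile()                          # pay the one-off Numba compilation cost here
start = time.time()
# ... run the benchmarked training loop ...
print(f"{time.time() - start:.2f}s")   # timing no longer includes JIT compilation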

View File

@ -1,4 +1,4 @@
tqdm_config = {
'dynamic_ncols': True,
'ascii': True,
"dynamic_ncols": True,
"ascii": True,
}

View File

@ -1,5 +1,6 @@
import torch
import numpy as np
from numbers import Number
from typing import Union
from tianshou.data import to_numpy
@ -30,7 +31,9 @@ class MovAvg(object):
self.cache = []
self.banned = [np.inf, np.nan, -np.inf]
def add(self, x: Union[float, list, np.ndarray, torch.Tensor]) -> float:
def add(
self, x: Union[Number, np.number, list, np.ndarray, torch.Tensor]
) -> np.number:
"""Add a scalar into :class:`MovAvg`.
You can add a ``torch.Tensor`` with only one element, a Python scalar, or
@ -39,26 +42,26 @@ class MovAvg(object):
if isinstance(x, torch.Tensor):
x = to_numpy(x.flatten())
if isinstance(x, list) or isinstance(x, np.ndarray):
for _ in x:
if _ not in self.banned:
self.cache.append(_)
for i in x:
if i not in self.banned:
self.cache.append(i)
elif x not in self.banned:
self.cache.append(x)
if self.size > 0 and len(self.cache) > self.size:
self.cache = self.cache[-self.size:]
return self.get()
def get(self) -> float:
def get(self) -> np.number:
"""Get the average."""
if len(self.cache) == 0:
return 0
return np.mean(self.cache)
def mean(self) -> float:
def mean(self) -> np.number:
"""Get the average. Same as :meth:`get`."""
return self.get()
def std(self) -> float:
def std(self) -> np.number:
"""Get the standard deviation."""
if len(self.cache) == 0:
return 0
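A short usage sketch of the class above: values in the banned list (inf/nan) are dropped on add, and get/mean/std operate on the remaining cache. The positional window-size argument in the constructor is an assumption; the add/get behaviour follows the code shown:
import numpy as np
from tianshou.utils import MovAvg
stat = MovAvg(100)              # window size; this constructor argument is assumed
stat.add(1.0)
stat.add([2.0, np.inf, 3.0])    # inf is in the banned list and gets dropped
print(stat.get())               # mean of [1.0, 2.0, 3.0] -> 2.0
print(stat.std())               # standard deviation of the cached values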

View File

@ -1,13 +1,16 @@
import torch
import numpy as np
from torch import nn
from typing import List, Tuple, Union, Optional
from typing import Any, Dict, List, Tuple, Union, Callable, Optional, Sequence
from tianshou.data import to_torch
def miniblock(inp: int, oup: int,
norm_layer: nn.modules.Module) -> List[nn.modules.Module]:
def miniblock(
inp: int,
oup: int,
norm_layer: Optional[Callable[[int], nn.modules.Module]],
) -> List[nn.modules.Module]:
"""Construct a miniblock with the given input/output size and norm layer."""
ret = [nn.Linear(inp, oup)]
if norm_layer is not None:
@ -27,18 +30,22 @@ class Net(nn.Module):
shape, but affects the input shape.
:param dueling: whether to use a dueling network to calculate Q values
(for Dueling DQN); pass a tuple to enable it, defaults to None.
:param nn.modules.Module norm_layer: use which normalization before ReLU,
e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None.
:param norm_layer: use which normalization before ReLU, e.g.,
``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None.
"""
def __init__(self, layer_num: int, state_shape: tuple,
action_shape: Optional[Union[tuple, int]] = 0,
device: Union[str, torch.device] = 'cpu',
softmax: bool = False,
concat: bool = False,
hidden_layer_size: int = 128,
dueling: Optional[Tuple[int, int]] = None,
norm_layer: Optional[nn.modules.Module] = None):
def __init__(
self,
layer_num: int,
state_shape: tuple,
action_shape: Optional[Union[tuple, int]] = 0,
device: Union[str, int, torch.device] = "cpu",
softmax: bool = False,
concat: bool = False,
hidden_layer_size: int = 128,
dueling: Optional[Tuple[int, int]] = None,
norm_layer: Optional[Callable[[int], nn.modules.Module]] = None,
) -> None:
super().__init__()
self.device = device
self.dueling = dueling
@ -78,7 +85,12 @@ class Net(nn.Module):
self.V = nn.Sequential(*self.V)
self.model = nn.Sequential(*self.model)
def forward(self, s, state=None, info={}):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
"""Mapping: s -> flatten -> logits."""
s = to_torch(s, device=self.device, dtype=torch.float32)
s = s.reshape(s.size(0), -1)
@ -98,19 +110,33 @@ class Recurrent(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, layer_num, state_shape, action_shape,
device='cpu', hidden_layer_size=128):
def __init__(
self,
layer_num: int,
state_shape: Sequence[int],
action_shape: Sequence[int],
device: Union[str, int, torch.device] = "cpu",
hidden_layer_size: int = 128,
) -> None:
super().__init__()
self.state_shape = state_shape
self.action_shape = action_shape
self.device = device
self.nn = nn.LSTM(input_size=hidden_layer_size,
hidden_size=hidden_layer_size,
num_layers=layer_num, batch_first=True)
self.nn = nn.LSTM(
input_size=hidden_layer_size,
hidden_size=hidden_layer_size,
num_layers=layer_num,
batch_first=True,
)
self.fc1 = nn.Linear(np.prod(state_shape), hidden_layer_size)
self.fc2 = nn.Linear(hidden_layer_size, np.prod(action_shape))
def forward(self, s, state=None, info={}):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
state: Optional[Dict[str, torch.Tensor]] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Mapping: s -> flatten -> logits.
In evaluation mode, s should have shape ``[bsz, dim]``; in the
@ -130,9 +156,9 @@ class Recurrent(nn.Module):
else:
# we store the stack data in [bsz, len, ...] format
# but pytorch rnn needs [len, bsz, ...]
s, (h, c) = self.nn(s, (state['h'].transpose(0, 1).contiguous(),
state['c'].transpose(0, 1).contiguous()))
s, (h, c) = self.nn(s, (state["h"].transpose(0, 1).contiguous(),
state["c"].transpose(0, 1).contiguous()))
s = self.fc2(s[:, -1])
# please ensure the first dim is batch size: [bsz, len, ...]
return s, {'h': h.transpose(0, 1).detach(),
'c': c.transpose(0, 1).detach()}
return s, {"h": h.transpose(0, 1).detach(),
"c": c.transpose(0, 1).detach()}

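With the new Net signature above, constructing a small feed-forward Q-network and running a forward pass looks roughly like the sketch below. The import path is assumed from the package layout, the observation batch is fake data, and the argument values are placeholders:
import numpy as np
from tianshou.utils.net.common import Net
net = Net(layer_num=2, state_shape=(4,), action_shape=2, device="cpu")
obs = np.random.rand(8, 4).astype(np.float32)     # fake batch of 8 observations
logits, state = net(obs)                          # forward: s -> flatten -> logits
assert logits.shape == (8, 2) and state is None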
View File

@ -1,6 +1,7 @@
import torch
import numpy as np
from torch import nn
from typing import Any, Dict, Tuple, Union, Optional, Sequence
from tianshou.data import to_torch, to_torch_as
@ -12,14 +13,25 @@ class Actor(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, preprocess_net, action_shape, max_action=1.,
device='cpu', hidden_layer_size=128):
def __init__(
self,
preprocess_net: nn.Module,
action_shape: Sequence[int],
max_action: float = 1.0,
device: Union[str, int, torch.device] = "cpu",
hidden_layer_size: int = 128,
) -> None:
super().__init__()
self.preprocess = preprocess_net
self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
self._max = max_action
def forward(self, s, state=None, info={}):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
"""Mapping: s -> logits -> action."""
logits, h = self.preprocess(s, state)
logits = self._max * torch.tanh(self.last(logits))
@ -33,13 +45,23 @@ class Critic(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, preprocess_net, device='cpu', hidden_layer_size=128):
def __init__(
self,
preprocess_net: nn.Module,
device: Union[str, int, torch.device] = "cpu",
hidden_layer_size: int = 128,
) -> None:
super().__init__()
self.device = device
self.preprocess = preprocess_net
self.last = nn.Linear(hidden_layer_size, 1)
def forward(self, s, a=None, info={}):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
a: Optional[Union[np.ndarray, torch.Tensor]] = None,
info: Dict[str, Any] = {},
) -> torch.Tensor:
"""Mapping: (s, a) -> logits -> Q(s, a)."""
s = to_torch(s, device=self.device, dtype=torch.float32)
s = s.flatten(1)
@ -59,8 +81,15 @@ class ActorProb(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, preprocess_net, action_shape, max_action=1.,
device='cpu', unbounded=False, hidden_layer_size=128):
def __init__(
self,
preprocess_net: nn.Module,
action_shape: Sequence[int],
max_action: float = 1.0,
device: Union[str, int, torch.device] = "cpu",
unbounded: bool = False,
hidden_layer_size: int = 128,
) -> None:
super().__init__()
self.preprocess = preprocess_net
self.device = device
@ -69,7 +98,12 @@ class ActorProb(nn.Module):
self._max = max_action
self._unbounded = unbounded
def forward(self, s, state=None, info={}):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]:
"""Mapping: s -> logits -> (mu, sigma)."""
logits, h = self.preprocess(s, state)
mu = self.mu(logits)
@ -78,7 +112,7 @@ class ActorProb(nn.Module):
shape = [1] * len(mu.shape)
shape[1] = -1
sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
return (mu, sigma), None
return (mu, sigma), state
class RecurrentActorProb(nn.Module):
@ -88,19 +122,35 @@ class RecurrentActorProb(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, layer_num, state_shape, action_shape, max_action=1.,
device='cpu', unbounded=False, hidden_layer_size=128):
def __init__(
self,
layer_num: int,
state_shape: Sequence[int],
action_shape: Sequence[int],
max_action: float = 1.0,
device: Union[str, int, torch.device] = "cpu",
unbounded: bool = False,
hidden_layer_size: int = 128,
) -> None:
super().__init__()
self.device = device
self.nn = nn.LSTM(input_size=np.prod(state_shape),
hidden_size=hidden_layer_size,
num_layers=layer_num, batch_first=True)
self.nn = nn.LSTM(
input_size=np.prod(state_shape),
hidden_size=hidden_layer_size,
num_layers=layer_num,
batch_first=True,
)
self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
self._max = max_action
self._unbounded = unbounded
def forward(self, s, state=None, info={}):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
state: Optional[Dict[str, torch.Tensor]] = None,
info: Dict[str, Any] = {},
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
s = to_torch(s, device=self.device, dtype=torch.float32)
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
@ -114,8 +164,8 @@ class RecurrentActorProb(nn.Module):
else:
# we store the stack data in [bsz, len, ...] format
# but pytorch rnn needs [len, bsz, ...]
s, (h, c) = self.nn(s, (state['h'].transpose(0, 1).contiguous(),
state['c'].transpose(0, 1).contiguous()))
s, (h, c) = self.nn(s, (state["h"].transpose(0, 1).contiguous(),
state["c"].transpose(0, 1).contiguous()))
logits = s[:, -1]
mu = self.mu(logits)
if not self._unbounded:
@ -124,8 +174,8 @@ class RecurrentActorProb(nn.Module):
shape[1] = -1
sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
# please ensure the first dim is batch size: [bsz, len, ...]
return (mu, sigma), {'h': h.transpose(0, 1).detach(),
'c': c.transpose(0, 1).detach()}
return (mu, sigma), {"h": h.transpose(0, 1).detach(),
"c": c.transpose(0, 1).detach()}
class RecurrentCritic(nn.Module):
@ -135,18 +185,32 @@ class RecurrentCritic(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, layer_num, state_shape,
action_shape=0, device='cpu', hidden_layer_size=128):
def __init__(
self,
layer_num: int,
state_shape: Sequence[int],
action_shape: Sequence[int] = [0],
device: Union[str, int, torch.device] = "cpu",
hidden_layer_size: int = 128,
) -> None:
super().__init__()
self.state_shape = state_shape
self.action_shape = action_shape
self.device = device
self.nn = nn.LSTM(input_size=np.prod(state_shape),
hidden_size=hidden_layer_size,
num_layers=layer_num, batch_first=True)
self.nn = nn.LSTM(
input_size=np.prod(state_shape),
hidden_size=hidden_layer_size,
num_layers=layer_num,
batch_first=True,
)
self.fc2 = nn.Linear(hidden_layer_size + np.prod(action_shape), 1)
def forward(self, s, a=None):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
a: Optional[Union[np.ndarray, torch.Tensor]] = None,
info: Dict[str, Any] = {},
) -> torch.Tensor:
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
s = to_torch(s, device=self.device, dtype=torch.float32)
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
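The continuous-action heads above wrap a shared preprocessing Net. A rough sketch of wiring them together as a deterministic actor and a Q-critic; the import paths and the concat=True flag for the critic's preprocess net are assumptions based on how these modules are usually combined, and all sizes are placeholders:
import numpy as np
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
state_shape, action_shape = (3,), (1,)
net_a = Net(layer_num=1, state_shape=state_shape)
actor = Actor(net_a, action_shape, max_action=1.0)
# the critic's preprocess net consumes (s, a) concatenated, hence concat=True
net_c = Net(layer_num=1, state_shape=state_shape,
            action_shape=action_shape, concat=True)
critic = Critic(net_c)
obs = np.random.rand(5, 3).astype(np.float32)
act, _ = actor(obs)                       # tanh-squashed action in [-1, 1]
q = critic(obs, act)                      # Q(s, a), shape (5, 1)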

View File

@ -2,6 +2,7 @@ import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, Tuple, Union, Optional, Sequence
class Actor(nn.Module):
@ -11,12 +12,22 @@ class Actor(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, preprocess_net, action_shape, hidden_layer_size=128):
def __init__(
self,
preprocess_net: nn.Module,
action_shape: Sequence[int],
hidden_layer_size: int = 128,
) -> None:
super().__init__()
self.preprocess = preprocess_net
self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
def forward(self, s, state=None, info={}):
def forward(
self,
s: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: s -> Q(s, \*)."""
logits, h = self.preprocess(s, state)
logits = F.softmax(self.last(logits), dim=-1)
@ -30,14 +41,18 @@ class Critic(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, preprocess_net, hidden_layer_size=128):
def __init__(
self, preprocess_net: nn.Module, hidden_layer_size: int = 128
) -> None:
super().__init__()
self.preprocess = preprocess_net
self.last = nn.Linear(hidden_layer_size, 1)
def forward(self, s, **kwargs):
def forward(
self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
) -> torch.Tensor:
"""Mapping: s -> V(s)."""
logits, h = self.preprocess(s, state=kwargs.get('state', None))
logits, h = self.preprocess(s, state=kwargs.get("state", None))
logits = self.last(logits)
return logits
@ -49,17 +64,31 @@ class DQN(nn.Module):
:ref:`build_the_network`.
"""
def __init__(self, c, h, w, action_shape, device='cpu'):
super(DQN, self).__init__()
def __init__(
self,
c: int,
h: int,
w: int,
action_shape: Sequence[int],
device: Union[str, int, torch.device] = "cpu",
) -> None:
super().__init__()
self.device = device
def conv2d_size_out(size, kernel_size=5, stride=2):
def conv2d_size_out(
size: int, kernel_size: int = 5, stride: int = 2
) -> int:
return (size - (kernel_size - 1) - 1) // stride + 1
def conv2d_layers_size_out(size,
kernel_size_1=8, stride_1=4,
kernel_size_2=4, stride_2=2,
kernel_size_3=3, stride_3=1):
def conv2d_layers_size_out(
size: int,
kernel_size_1: int = 8,
stride_1: int = 4,
kernel_size_2: int = 4,
stride_2: int = 2,
kernel_size_3: int = 3,
stride_3: int = 1,
) -> int:
size = conv2d_size_out(size, kernel_size_1, stride_1)
size = conv2d_size_out(size, kernel_size_2, stride_2)
size = conv2d_size_out(size, kernel_size_3, stride_3)
@ -78,10 +107,15 @@ class DQN(nn.Module):
nn.ReLU(inplace=True),
nn.Flatten(),
nn.Linear(linear_input_size, 512),
nn.Linear(512, np.prod(action_shape))
nn.Linear(512, np.prod(action_shape)),
)
def forward(self, x, state=None, info={}):
def forward(
self,
x: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Q(x, \*)."""
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, device=self.device, dtype=torch.float32)
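conv2d_size_out implements the standard output-size formula for an unpadded convolution, (size - (kernel - 1) - 1) // stride + 1, and conv2d_layers_size_out chains it over the three conv layers. For a typical 84x84 Atari frame the arithmetic works out as below; this is a worked example rather than output from the class above, and the 64 output channels of the last conv layer are an assumption:
def conv2d_size_out(size, kernel_size=5, stride=2):
    return (size - (kernel_size - 1) - 1) // stride + 1
h = w = 84                                   # typical Atari frame after preprocessing
h = conv2d_size_out(h, 8, 4)                 # 84 -> 20
h = conv2d_size_out(h, 4, 2)                 # 20 -> 9
h = conv2d_size_out(h, 3, 1)                 # 9  -> 7
linear_input_size = h * h * 64               # 7 * 7 * 64 = 3136 inputs to nn.Linear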