code format and update function signatures (#213)
Cherry-pick from #200:

- update the function signatures
- format code style
- move `_compile` into separate functions
- fix a bug in `to_torch` and `to_numpy` (Batch)
- remove `None` in `action_range`

In short, the code-format change only touches function-signature style and `'` -> `"` (picked up from [black](https://github.com/psf/black)).
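Most of the diff below is mechanical: positional constructor arguments in the tests become explicit keyword arguments, long signatures are reflowed, and single quotes become double quotes. As a rough illustration of the call-site change only (a sketch with a made-up `make_policy` helper, not tianshou's actual API):

```python
# Hypothetical stand-in for a policy constructor with many optional arguments;
# SACPolicy / TD3Policy / DDPGPolicy themselves are not used here.
def make_policy(actor, critic, action_range=None, tau=0.005, gamma=0.99,
                alpha=0.2, reward_normalization=False, ignore_done=False,
                estimation_step=1):
    return dict(actor=actor, critic=critic, action_range=action_range,
                tau=tau, gamma=gamma, alpha=alpha,
                reward_normalization=reward_normalization,
                ignore_done=ignore_done, estimation_step=estimation_step)

# Old style: positional arguments, easy to pass in the wrong order.
policy = make_policy("actor", "critic", [-1.0, 1.0], 0.005, 0.99, 0.05)

# New style (as in the updated tests below): every optional argument is named.
policy = make_policy(
    "actor", "critic",
    action_range=[-1.0, 1.0],
    tau=0.005, gamma=0.99, alpha=0.05,
)
```

Named arguments keep these long constructor calls self-documenting and robust against parameter reordering.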
parent 16d8e9b051
commit c91def6cbc
@@ -119,8 +119,8 @@ def test_sac_bipedal(args=get_args()):
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, args.alpha,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
         estimation_step=args.n_step)
@@ -78,14 +78,12 @@ def test_sac(args=get_args()):
         target_entropy = -np.prod(env.action_space.shape)
         log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
         alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
-        alpha = (target_entropy, log_alpha, alpha_optim)
-    else:
-        alpha = args.alpha
+        args.alpha = (target_entropy, log_alpha, alpha_optim)

     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, alpha,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm, ignore_done=True,
         exploration_noise=OUNoise(0.0, args.noise_std))
     # collector
@@ -66,8 +66,9 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma,
+        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         reward_normalization=True, ignore_done=True)
     # collector
     train_collector = Collector(
@@ -71,8 +71,8 @@ def test_sac(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, args.alpha,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm, ignore_done=True)
     # collector
     train_collector = Collector(
@@ -73,10 +73,12 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma,
-        GaussianNoise(sigma=args.exploration_noise), args.policy_noise,
-        args.update_actor_freq, args.noise_clip,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma,
+        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
+        policy_noise=args.policy_noise,
+        update_actor_freq=args.update_actor_freq,
+        noise_clip=args.noise_clip,
         reward_normalization=True, ignore_done=True)
     # collector
     train_collector = Collector(
@@ -79,8 +79,8 @@ def test_sac(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, args.alpha,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=True, ignore_done=True)
     # collector
     train_collector = Collector(
@@ -76,10 +76,12 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma,
-        GaussianNoise(sigma=args.exploration_noise), args.policy_noise,
-        args.update_actor_freq, args.noise_clip,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma,
+        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
+        policy_noise=args.policy_noise,
+        update_actor_freq=args.update_actor_freq,
+        noise_clip=args.noise_clip,
         reward_normalization=True, ignore_done=True)
     # collector
     train_collector = Collector(
@@ -78,8 +78,9 @@ def test_ddpg(args=get_args()):
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
-        args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma,
+        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
         estimation_step=args.n_step)
@@ -79,8 +79,8 @@ def test_sac_with_il(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, args.alpha,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
         estimation_step=args.n_step)
@@ -82,9 +82,12 @@ def test_td3(args=get_args()):
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
-        args.tau, args.gamma, GaussianNoise(sigma=args.exploration_noise),
-        args.policy_noise, args.update_actor_freq, args.noise_clip,
-        [env.action_space.low[0], env.action_space.high[0]],
+        action_range=[env.action_space.low[0], env.action_space.high[0]],
+        tau=args.tau, gamma=args.gamma,
+        exploration_noise=GaussianNoise(sigma=args.exploration_noise),
+        policy_noise=args.policy_noise,
+        update_actor_freq=args.update_actor_freq,
+        noise_clip=args.noise_clip,
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
         estimation_step=args.n_step)
@@ -1,17 +1,13 @@
 from tianshou import data, env, utils, policy, trainer, exploration

-# pre-compile some common-type function-call to produce the correct benchmark
-# result: https://github.com/thu-ml/tianshou/pull/193#discussion_r480536371
-utils.pre_compile()
-
-__version__ = '0.2.7'
+__version__ = "0.2.7"

 __all__ = [
-    'env',
-    'data',
-    'utils',
-    'policy',
-    'trainer',
-    'exploration',
+    "env",
+    "data",
+    "utils",
+    "policy",
+    "trainer",
+    "exploration",
 ]
@@ -1,19 +1,18 @@
 from tianshou.data.batch import Batch
-from tianshou.data.utils.converter import to_numpy, to_torch, \
-    to_torch_as
+from tianshou.data.utils.converter import to_numpy, to_torch, to_torch_as
 from tianshou.data.utils.segtree import SegmentTree
 from tianshou.data.buffer import ReplayBuffer, \
     ListReplayBuffer, PrioritizedReplayBuffer
 from tianshou.data.collector import Collector

 __all__ = [
-    'Batch',
-    'to_numpy',
-    'to_torch',
-    'to_torch_as',
-    'SegmentTree',
-    'ReplayBuffer',
-    'ListReplayBuffer',
-    'PrioritizedReplayBuffer',
-    'Collector',
+    "Batch",
+    "to_numpy",
+    "to_torch",
+    "to_torch_as",
+    "SegmentTree",
+    "ReplayBuffer",
+    "ListReplayBuffer",
+    "PrioritizedReplayBuffer",
+    "Collector",
 ]
@@ -5,8 +5,8 @@ import numpy as np
 from copy import deepcopy
 from numbers import Number
 from collections.abc import Collection
-from typing import Any, List, Tuple, Union, Iterator, KeysView, ValuesView, \
-    ItemsView, Optional
+from typing import Any, List, Dict, Union, Iterator, Optional, Iterable, \
+    Sequence, KeysView, ValuesView, ItemsView

 # Disable pickle warning related to torch, since it has been removed
 # on torch master branch. See Pull Request #39003 for details:

@@ -23,8 +23,8 @@ def _is_batch_set(data: Any) -> bool:
         # "for e in data" will just unpack the first dimension,
         # but data.tolist() will flatten ndarray of objects
         # so do not use data.tolist()
-        return data.dtype == np.object and \
-            all(isinstance(e, (dict, Batch)) for e in data)
+        return data.dtype == np.object and all(
+            isinstance(e, (dict, Batch)) for e in data)
     elif isinstance(data, (list, tuple)):
         if len(data) > 0 and all(isinstance(e, (dict, Batch)) for e in data):
             return True

@@ -54,8 +54,9 @@ def _is_number(value: Any) -> bool:


 def _to_array_with_correct_type(v: Any) -> np.ndarray:
-    if isinstance(v, np.ndarray) and \
-            issubclass(v.dtype.type, (np.bool_, np.number)): # most often case
+    if isinstance(v, np.ndarray) and issubclass(
+        v.dtype.type, (np.bool_, np.number)
+    ): # most often case
         return v
     # convert the value to np.ndarray
     # convert to np.object data type if neither bool nor number

@@ -71,14 +72,16 @@ def _to_array_with_correct_type(v: Any) -> np.ndarray:
         # array([{}, array({}, dtype=object)], dtype=object)
         if not v.shape:
             v = v.item(0)
-        elif any(isinstance(e, (np.ndarray, torch.Tensor))
-                 for e in v.reshape(-1)):
+        elif any(
+            isinstance(e, (np.ndarray, torch.Tensor)) for e in v.reshape(-1)
+        ):
             raise ValueError("Numpy arrays of tensors are not supported yet.")
     return v


-def _create_value(inst: Any, size: int, stack=True) -> Union[
-        'Batch', np.ndarray, torch.Tensor]:
+def _create_value(
+    inst: Any, size: int, stack: bool = True
+) -> Union["Batch", np.ndarray, torch.Tensor]:
     """Create empty place-holders accroding to inst's shape.

     :param bool stack: whether to stack or to concatenate. E.g. if inst has

@@ -100,12 +103,15 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[
             target_type = inst.dtype.type
         else:
             target_type = np.object
-        return np.full(shape,
-                       fill_value=None if target_type == np.object else 0,
-                       dtype=target_type)
+        return np.full(
+            shape,
+            fill_value=None if target_type == np.object else 0,
+            dtype=target_type
+        )
     elif isinstance(inst, torch.Tensor):
-        return torch.full(shape,
-                          fill_value=0, device=inst.device, dtype=inst.dtype)
+        return torch.full(
+            shape, fill_value=0, device=inst.device, dtype=inst.dtype
+        )
     elif isinstance(inst, (dict, Batch)):
         zero_batch = Batch()
         for key, val in inst.items():

@@ -117,12 +123,13 @@ def _create_value(inst: Any, size: int, stack=True) -> Union[
     return np.array([None for _ in range(size)])


-def _assert_type_keys(keys) -> None:
-    assert all(isinstance(e, str) for e in keys), \
-        f"keys should all be string, but got {keys}"
+def _assert_type_keys(keys: Iterable[str]) -> None:
+    assert all(
+        isinstance(e, str) for e in keys
+    ), f"keys should all be string, but got {keys}"


-def _parse_value(v: Any):
+def _parse_value(v: Any) -> Optional[Union["Batch", np.ndarray, torch.Tensor]]:
     if isinstance(v, Batch): # most often case
         return v
     elif (isinstance(v, np.ndarray) and
@@ -166,12 +173,14 @@ class Batch:
     For a detailed description, please refer to :ref:`batch_concept`.
     """

-    def __init__(self,
-                 batch_dict: Optional[Union[
-                     dict, 'Batch', Tuple[Union[dict, 'Batch']],
-                     List[Union[dict, 'Batch']], np.ndarray]] = None,
-                 copy: bool = False,
-                 **kwargs) -> None:
+    def __init__(
+        self,
+        batch_dict: Optional[
+            Union[dict, "Batch", Sequence[Union[dict, "Batch"]], np.ndarray]
+        ] = None,
+        copy: bool = False,
+        **kwargs: Any,
+    ) -> None:
         if copy:
             batch_dict = deepcopy(batch_dict)
         if batch_dict is not None:

@@ -188,7 +197,7 @@ class Batch:
         """Set self.key = value."""
         self.__dict__[key] = _parse_value(value)

-    def __getstate__(self) -> dict:
+    def __getstate__(self) -> Dict[str, Any]:
         """Pickling interface.

         Only the actual data are serialized for both efficiency and simplicity.

@@ -200,7 +209,7 @@ class Batch:
             state[k] = v
         return state

-    def __setstate__(self, state) -> None:
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         """Unpickling interface.

         At this point, self is an empty Batch instance that has not been

@@ -208,8 +217,9 @@ class Batch:
         """
         self.__init__(**state)

-    def __getitem__(self, index: Union[
-            str, slice, int, np.integer, np.ndarray, List[int]]) -> 'Batch':
+    def __getitem__(
+        self, index: Union[str, slice, int, np.integer, np.ndarray, List[int]]
+    ) -> Union["Batch", np.ndarray, torch.Tensor]:
         """Return self[index]."""
         if isinstance(index, str):
             return self.__dict__[index]

@@ -225,9 +235,11 @@ class Batch:
         else:
             raise IndexError("Cannot access item from empty Batch object.")

-    def __setitem__(self, index: Union[
-            str, slice, int, np.integer, np.ndarray, List[int]],
-            value: Any) -> None:
+    def __setitem__(
+        self,
+        index: Union[str, slice, int, np.integer, np.ndarray, List[int]],
+        value: Any,
+    ) -> None:
         """Assign value to self[index]."""
         value = _parse_value(value)
         if isinstance(index, str):

@@ -252,12 +264,12 @@ class Batch:
             else:
                 self.__dict__[key][index] = None

-    def __iadd__(self, other: Union['Batch', Number, np.number]):
+    def __iadd__(self, other: Union["Batch", Number, np.number]) -> "Batch":
         """Algebraic addition with another Batch instance in-place."""
         if isinstance(other, Batch):
-            for (k, r), v in zip(self.__dict__.items(),
-                                 other.__dict__.values()):
-                # TODO are keys consistent?
+            for (k, r), v in zip(
+                self.__dict__.items(), other.__dict__.values()
+            ): # TODO are keys consistent?
                 if isinstance(r, Batch) and r.is_empty():
                     continue
                 else:

@@ -273,11 +285,11 @@ class Batch:
         else:
             raise TypeError("Only addition of Batch or number is supported.")

-    def __add__(self, other: Union['Batch', Number, np.number]):
+    def __add__(self, other: Union["Batch", Number, np.number]) -> "Batch":
         """Algebraic addition with another Batch instance out-of-place."""
         return deepcopy(self).__iadd__(other)

-    def __imul__(self, val: Union[Number, np.number]):
+    def __imul__(self, val: Union[Number, np.number]) -> "Batch":
         """Algebraic multiplication with a scalar value in-place."""
         assert _is_number(val), "Only multiplication by a number is supported."
         for k, r in self.__dict__.items():

@@ -286,11 +298,11 @@ class Batch:
                 self.__dict__[k] *= val
         return self

-    def __mul__(self, val: Union[Number, np.number]):
+    def __mul__(self, val: Union[Number, np.number]) -> "Batch":
         """Algebraic multiplication with a scalar value out-of-place."""
         return deepcopy(self).__imul__(val)

-    def __itruediv__(self, val: Union[Number, np.number]):
+    def __itruediv__(self, val: Union[Number, np.number]) -> "Batch":
         """Algebraic division with a scalar value in-place."""
         assert _is_number(val), "Only division by a number is supported."
         for k, r in self.__dict__.items():

@@ -299,23 +311,23 @@ class Batch:
                 self.__dict__[k] /= val
         return self

-    def __truediv__(self, val: Union[Number, np.number]):
+    def __truediv__(self, val: Union[Number, np.number]) -> "Batch":
         """Algebraic division with a scalar value out-of-place."""
         return deepcopy(self).__itruediv__(val)

     def __repr__(self) -> str:
         """Return str(self)."""
-        s = self.__class__.__name__ + '(\n'
+        s = self.__class__.__name__ + "(\n"
         flag = False
         for k, v in self.__dict__.items():
-            rpl = '\n' + ' ' * (6 + len(k))
-            obj = pprint.pformat(v).replace('\n', rpl)
-            s += f' {k}: {obj},\n'
+            rpl = "\n" + " " * (6 + len(k))
+            obj = pprint.pformat(v).replace("\n", rpl)
+            s += f" {k}: {obj},\n"
             flag = True
         if flag:
-            s += ')'
+            s += ")"
         else:
-            s = self.__class__.__name__ + '()'
+            s = self.__class__.__name__ + "()"
         return s

     def __contains__(self, key: str) -> bool:
@@ -350,8 +362,11 @@ class Batch:
             elif isinstance(v, Batch):
                 v.to_numpy()

-    def to_torch(self, dtype: Optional[torch.dtype] = None,
-                 device: Union[str, int, torch.device] = 'cpu') -> None:
+    def to_torch(
+        self,
+        dtype: Optional[torch.dtype] = None,
+        device: Union[str, int, torch.device] = "cpu",
+    ) -> None:
         """Change all numpy.ndarray to torch.Tensor in-place."""
         if not isinstance(device, torch.device):
             device = torch.device(device)

@@ -376,9 +391,9 @@ class Batch:
                     v = v.type(dtype)
                 self.__dict__[k] = v

-    def __cat(self,
-              batches: List[Union[dict, 'Batch']],
-              lens: List[int]) -> None:
+    def __cat(
+        self, batches: Sequence[Union[dict, "Batch"]], lens: List[int]
+    ) -> None:
         """Private method for Batch.cat_.

         ::

@@ -445,8 +460,9 @@ class Batch:
                         val, sum_lens[-1], stack=False)
                 self.__dict__[k][sum_lens[i]:sum_lens[i + 1]] = val

-    def cat_(self,
-             batches: Union['Batch', List[Union[dict, 'Batch']]]) -> None:
+    def cat_(
+        self, batches: Union["Batch", Sequence[Union[dict, "Batch"]]]
+    ) -> None:
         """Concatenate a list of (or one) Batch objects into current batch."""
         if isinstance(batches, Batch):
             batches = [batches]

@@ -460,20 +476,19 @@ class Batch:
             # x.is_empty(recurse=True) here means x is a nested empty batch
             # like Batch(a=Batch), and we have to treat it as length zero and
             # keep it.
-            lens = [0 if x.is_empty(recurse=True) else len(x)
-                    for x in batches]
+            lens = [0 if x.is_empty(recurse=True) else len(x) for x in batches]
         except TypeError as e:
             raise ValueError(
-                f'Batch.cat_ meets an exception. Maybe because there is any '
-                f'scalar in {batches} but Batch.cat_ does not support the '
-                f'concatenation of scalar.') from e
+                "Batch.cat_ meets an exception. Maybe because there is any "
+                f"scalar in {batches} but Batch.cat_ does not support the "
+                "concatenation of scalar.") from e
         if not self.is_empty():
             batches = [self] + list(batches)
             lens = [0 if self.is_empty(recurse=True) else len(self)] + lens
-        return self.__cat(batches, lens)
+        self.__cat(batches, lens)

     @staticmethod
-    def cat(batches: List[Union[dict, 'Batch']]) -> 'Batch':
+    def cat(batches: Sequence[Union[dict, "Batch"]]) -> "Batch":
         """Concatenate a list of Batch object into a single new batch.

         For keys that are not shared across all batches, batches that do not

@@ -494,9 +509,9 @@ class Batch:
         batch.cat_(batches)
         return batch

-    def stack_(self,
-               batches: List[Union[dict, 'Batch']],
-               axis: int = 0) -> None:
+    def stack_(
+        self, batches: Sequence[Union[dict, "Batch"]], axis: int = 0
+    ) -> None:
         """Stack a list of Batch object into current batch."""
         if len(batches) == 0:
             return

@@ -528,8 +543,8 @@ class Batch:
         keys_partial = keys_reserve_or_partial.difference(keys_reserve)
         if keys_partial and axis != 0:
             raise ValueError(
-                f"Stack of Batch with non-shared keys {keys_partial} "
-                f"is only supported with axis=0, but got axis={axis}!")
+                f"Stack of Batch with non-shared keys {keys_partial} is only "
+                f"supported with axis=0, but got axis={axis}!")
         for k in keys_reserve:
             # reserved keys
             self.__dict__[k] = Batch()

@@ -547,7 +562,9 @@ class Batch:
                 self.__dict__[k][i] = val

     @staticmethod
-    def stack(batches: List[Union[dict, 'Batch']], axis: int = 0) -> 'Batch':
+    def stack(
+        batches: Sequence[Union[dict, "Batch"]], axis: int = 0
+    ) -> "Batch":
         """Stack a list of Batch object into a single new batch.

         For keys that are not shared across all batches, batches that do not

@@ -573,9 +590,12 @@ class Batch:
         batch.stack_(batches, axis)
         return batch

-    def empty_(self, index: Union[
-            str, slice, int, np.integer, np.ndarray, List[int]] = None
-            ) -> 'Batch':
+    def empty_(
+        self,
+        index: Union[
+            str, slice, int, np.integer, np.ndarray, List[int]
+        ] = None,
+    ) -> "Batch":
         """Return an empty Batch object with 0 or None filled.

         If "index" is specified, it will only reset the specific indexed-data.
@@ -613,8 +633,8 @@ class Batch:
             elif isinstance(v, Batch):
                 self.__dict__[k].empty_(index=index)
             else: # scalar value
-                warnings.warn('You are calling Batch.empty on a NumPy scalar, '
-                              'which may cause undefined behaviors.')
+                warnings.warn("You are calling Batch.empty on a NumPy scalar, "
+                              "which may cause undefined behaviors.")
                 if _is_number(v):
                     self.__dict__[k] = v.__class__(0)
                 else:

@@ -622,17 +642,21 @@ class Batch:
         return self

     @staticmethod
-    def empty(batch: 'Batch', index: Union[
-            str, slice, int, np.integer, np.ndarray, List[int]] = None
-            ) -> 'Batch':
+    def empty(
+        batch: "Batch",
+        index: Union[
+            str, slice, int, np.integer, np.ndarray, List[int]
+        ] = None,
+    ) -> "Batch":
         """Return an empty Batch object with 0 or None filled.

         The shape is the same as the given Batch.
         """
         return deepcopy(batch).empty_(index)

-    def update(self, batch: Optional[Union[dict, 'Batch']] = None,
-               **kwargs) -> None:
+    def update(
+        self, batch: Optional[Union[dict, "Batch"]] = None, **kwargs: Any
+    ) -> None:
         """Update this batch from another dict/Batch."""
         if batch is None:
             self.update(kwargs)

@@ -648,8 +672,9 @@ class Batch:
         for v in self.__dict__.values():
             if isinstance(v, Batch) and v.is_empty(recurse=True):
                 continue
-            elif hasattr(v, '__len__') and (not isinstance(
-                    v, (np.ndarray, torch.Tensor)) or v.ndim > 0):
+            elif hasattr(v, "__len__") and (not isinstance(
+                v, (np.ndarray, torch.Tensor)) or v.ndim > 0
+            ):
                 r.append(len(v))
             else:
                 raise TypeError(f"Object {v} in {self} has no len()")

@@ -659,7 +684,7 @@ class Batch:
             raise TypeError(f"Object {self} has no len()")
         return min(r)

-    def is_empty(self, recurse: bool = False):
+    def is_empty(self, recurse: bool = False) -> bool:
         """Test if a Batch is empty.

         If ``recurse=True``, it further tests the values of the object; else

@@ -689,8 +714,9 @@ class Batch:
             return True
         if not recurse:
             return False
-        return all(False if not isinstance(x, Batch)
-                   else x.is_empty(recurse=True) for x in self.values())
+        return all(
+            False if not isinstance(x, Batch) else x.is_empty(recurse=True)
+            for x in self.values())

     @property
     def shape(self) -> List[int]:

@@ -707,8 +733,9 @@ class Batch:
         return list(map(min, zip(*data_shape))) if len(data_shape) > 1 \
             else data_shape[0]

-    def split(self, size: int, shuffle: bool = True,
-              merge_last: bool = False) -> Iterator['Batch']:
+    def split(
+        self, size: int, shuffle: bool = True, merge_last: bool = False
+    ) -> Iterator["Batch"]:
         """Split whole data into multiple small batches.

         :param int size: divide the data batch with the given size, but one
@@ -1,6 +1,7 @@
 import torch
 import numpy as np
-from typing import Any, Tuple, Union, Optional
+from numbers import Number
+from typing import Any, Dict, Tuple, Union, Optional

 from tianshou.data import Batch, SegmentTree, to_numpy
 from tianshou.data.batch import _create_value

@@ -11,7 +12,7 @@ class ReplayBuffer:
     interaction between the policy and environment.

     The current implementation of Tianshou typically use 7 reserved keys in
-    :class:`~tianshou.data.Batch`
+    :class:`~tianshou.data.Batch`:

     * ``obs`` the observation of step :math:`t` ;
     * ``act`` the action of step :math:`t` ;

@@ -124,14 +125,17 @@ class ReplayBuffer:
     This feature is not supported in Prioritized Replay Buffer currently.
     """

-    def __init__(self, size: int, stack_num: int = 1,
-                 ignore_obs_next: bool = False,
-                 save_only_last_obs: bool = False,
-                 sample_avail: bool = False) -> None:
+    def __init__(
+        self,
+        size: int,
+        stack_num: int = 1,
+        ignore_obs_next: bool = False,
+        save_only_last_obs: bool = False,
+        sample_avail: bool = False,
+    ) -> None:
         super().__init__()
         self._maxsize = size
         self._indices = np.arange(size)
         self._stack = None
         self.stack_num = stack_num
         self._avail = sample_avail and stack_num > 1
         self._avail_index = []

@@ -157,7 +161,7 @@ class ReplayBuffer:
         except KeyError as e:
             raise AttributeError from e

-    def __setstate__(self, state):
+    def __setstate__(self, state: Dict[str, Any]) -> None:
         """Unpickling interface.

         We need it because pickling buffer does not work out-of-the-box

@@ -171,11 +175,12 @@ class ReplayBuffer:
         except KeyError:
             self._meta.__dict__[name] = _create_value(inst, self._maxsize)
         value = self._meta.__dict__[name]
-        if isinstance(inst, (np.ndarray, torch.Tensor)) \
-                and value.shape[1:] != inst.shape:
+        if isinstance(inst, (torch.Tensor, np.ndarray)) \
+                and inst.shape != value.shape[1:]:
             raise ValueError(
                 "Cannot add data to a buffer with different shape, with key "
-                f"{name}, expect {value.shape[1:]}, given {inst.shape}.")
+                f"{name}, expect {value.shape[1:]}, given {inst.shape}."
+            )
         try:
             value[self._index] = inst
         except KeyError:

@@ -184,15 +189,15 @@ class ReplayBuffer:
             value[self._index] = inst

     @property
-    def stack_num(self):
+    def stack_num(self) -> int:
         return self._stack

     @stack_num.setter
-    def stack_num(self, num):
-        assert num > 0, 'stack_num should greater than 0'
+    def stack_num(self, num: int) -> None:
+        assert num > 0, "stack_num should greater than 0"
         self._stack = num

-    def update(self, buffer: 'ReplayBuffer') -> None:
+    def update(self, buffer: "ReplayBuffer") -> None:
         """Move the data from the given buffer to self."""
         if len(buffer) == 0:
             return
@@ -206,32 +211,35 @@ class ReplayBuffer:
                 break
         buffer.stack_num = stack_num_orig

-    def add(self,
-            obs: Union[dict, Batch, np.ndarray, float],
-            act: Union[dict, Batch, np.ndarray, float],
-            rew: Union[int, float],
-            done: Union[bool, int],
-            obs_next: Optional[Union[dict, Batch, np.ndarray, float]] = None,
-            info: Optional[Union[dict, Batch]] = {},
-            policy: Optional[Union[dict, Batch]] = {},
-            **kwargs) -> None:
+    def add(
+        self,
+        obs: Any,
+        act: Any,
+        rew: Union[Number, np.number, np.ndarray],
+        done: Union[Number, np.number, np.bool_],
+        obs_next: Any = None,
+        info: Optional[Union[dict, Batch]] = {},
+        policy: Optional[Union[dict, Batch]] = {},
+        **kwargs: Any,
+    ) -> None:
         """Add a batch of data into replay buffer."""
-        assert isinstance(info, (dict, Batch)), \
-            'You should return a dict in the last argument of env.step().'
+        assert isinstance(
+            info, (dict, Batch)
+        ), "You should return a dict in the last argument of env.step()."
         if self._last_obs:
             obs = obs[-1]
-        self._add_to_buffer('obs', obs)
-        self._add_to_buffer('act', act)
-        self._add_to_buffer('rew', rew)
-        self._add_to_buffer('done', done)
+        self._add_to_buffer("obs", obs)
+        self._add_to_buffer("act", act)
+        self._add_to_buffer("rew", rew)
+        self._add_to_buffer("done", done)
         if self._save_s_:
             if obs_next is None:
                 obs_next = Batch()
             elif self._last_obs:
                 obs_next = obs_next[-1]
-            self._add_to_buffer('obs_next', obs_next)
-        self._add_to_buffer('info', info)
-        self._add_to_buffer('policy', policy)
+            self._add_to_buffer("obs_next", obs_next)
+        self._add_to_buffer("info", info)
+        self._add_to_buffer("policy", policy)

         # maintain available index for frame-stack sampling
         if self._avail:

@@ -262,7 +270,8 @@ class ReplayBuffer:
             self._avail_index = []

     def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
-        """Get a random sample from buffer with size equal to batch_size. \
+        """Get a random sample from buffer with size equal to batch_size.
+
         Return all the data in the buffer if batch_size is 0.

         :return: Sample data and its corresponding index inside the buffer.

@@ -278,11 +287,15 @@ class ReplayBuffer:
                 np.arange(self._index, self._size),
                 np.arange(0, self._index),
             ])
-        assert len(indice) > 0, 'No available indice can be sampled.'
+        assert len(indice) > 0, "No available indice can be sampled."
         return self[indice], indice

-    def get(self, indice: Union[slice, int, np.integer, np.ndarray], key: str,
-            stack_num: Optional[int] = None) -> Union[Batch, np.ndarray]:
+    def get(
+        self,
+        indice: Union[slice, int, np.integer, np.ndarray],
+        key: str,
+        stack_num: Optional[int] = None,
+    ) -> Union[Batch, np.ndarray]:
         """Return the stacked result.

         E.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], where s is self.key, t is the
|
||||
if stack_num is None:
|
||||
stack_num = self.stack_num
|
||||
if stack_num == 1: # the most often case
|
||||
if key != 'obs_next' or self._save_s_:
|
||||
if key != "obs_next" or self._save_s_:
|
||||
val = self._meta.__dict__[key]
|
||||
try:
|
||||
return val[indice]
|
||||
@ -301,11 +314,11 @@ class ReplayBuffer:
|
||||
raise e # val != Batch()
|
||||
return Batch()
|
||||
indice = self._indices[:self._size][indice]
|
||||
done = self._meta.__dict__['done']
|
||||
if key == 'obs_next' and not self._save_s_:
|
||||
done = self._meta.__dict__["done"]
|
||||
if key == "obs_next" and not self._save_s_:
|
||||
indice += 1 - done[indice].astype(np.int)
|
||||
indice[indice == self._size] = 0
|
||||
key = 'obs'
|
||||
key = "obs"
|
||||
val = self._meta.__dict__[key]
|
||||
try:
|
||||
if stack_num == 1:
|
||||
@ -319,30 +332,30 @@ class ReplayBuffer:
|
||||
pre_indice + done[pre_indice].astype(np.int))
|
||||
indice[indice == self._size] = 0
|
||||
if isinstance(val, Batch):
|
||||
stack = Batch.stack(stack, axis=indice.ndim)
|
||||
return Batch.stack(stack, axis=indice.ndim)
|
||||
else:
|
||||
stack = np.stack(stack, axis=indice.ndim)
|
||||
return stack
|
||||
return np.stack(stack, axis=indice.ndim)
|
||||
except IndexError as e:
|
||||
if not (isinstance(val, Batch) and val.is_empty()):
|
||||
raise e # val != Batch()
|
||||
return Batch()
|
||||
|
||||
def __getitem__(self, index: Union[
|
||||
slice, int, np.integer, np.ndarray]) -> Batch:
|
||||
def __getitem__(
|
||||
self, index: Union[slice, int, np.integer, np.ndarray]
|
||||
) -> Batch:
|
||||
"""Return a data batch: self[index].
|
||||
|
||||
If stack_num is larger than 1, return the stacked obs and obs_next
|
||||
with shape (batch, len, ...).
|
||||
If stack_num is larger than 1, return the stacked obs and obs_next with
|
||||
shape (batch, len, ...).
|
||||
"""
|
||||
return Batch(
|
||||
obs=self.get(index, 'obs'),
|
||||
obs=self.get(index, "obs"),
|
||||
act=self.act[index],
|
||||
rew=self.rew[index],
|
||||
done=self.done[index],
|
||||
obs_next=self.get(index, 'obs_next'),
|
||||
info=self.get(index, 'info'),
|
||||
policy=self.get(index, 'policy'),
|
||||
obs_next=self.get(index, "obs_next"),
|
||||
info=self.get(index, "info"),
|
||||
policy=self.get(index, "policy"),
|
||||
)
|
||||
|
||||
|
||||
@ -361,15 +374,15 @@ class ListReplayBuffer(ReplayBuffer):
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(size=0, ignore_obs_next=False, **kwargs)
|
||||
|
||||
def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
|
||||
raise NotImplementedError("ListReplayBuffer cannot be sampled!")
|
||||
|
||||
def _add_to_buffer(
|
||||
self, name: str,
|
||||
inst: Union[dict, Batch, np.ndarray, float, int, bool]) -> None:
|
||||
self, name: str, inst: Union[dict, Batch, np.ndarray, float, int, bool]
|
||||
) -> None:
|
||||
if self._meta.__dict__.get(name) is None:
|
||||
self._meta.__dict__[name] = []
|
||||
self._meta.__dict__[name].append(inst)
|
||||
@@ -393,25 +406,29 @@ class PrioritizedReplayBuffer(ReplayBuffer):
     explanation.
     """

-    def __init__(self, size: int, alpha: float, beta: float, **kwargs) -> None:
+    def __init__(
+        self, size: int, alpha: float, beta: float, **kwargs: Any
+    ) -> None:
         super().__init__(size, **kwargs)
-        assert alpha > 0. and beta >= 0.
+        assert alpha > 0.0 and beta >= 0.0
         self._alpha, self._beta = alpha, beta
         self._max_prio = self._min_prio = 1.0
         # save weight directly in this class instead of self._meta
         self.weight = SegmentTree(size)
         self.__eps = np.finfo(np.float32).eps.item()

-    def add(self,
-            obs: Union[dict, Batch, np.ndarray, float],
-            act: Union[dict, Batch, np.ndarray, float],
-            rew: Union[int, float],
-            done: Union[bool, int],
-            obs_next: Optional[Union[dict, Batch, np.ndarray, float]] = None,
-            info: Optional[Union[dict, Batch]] = {},
-            policy: Optional[Union[dict, Batch]] = {},
-            weight: Optional[float] = None,
-            **kwargs) -> None:
+    def add(
+        self,
+        obs: Any,
+        act: Any,
+        rew: Union[Number, np.number, np.ndarray],
+        done: Union[Number, np.number, np.bool_],
+        obs_next: Any = None,
+        info: Optional[Union[dict, Batch]] = {},
+        policy: Optional[Union[dict, Batch]] = {},
+        weight: Optional[Union[Number, np.number]] = None,
+        **kwargs: Any,
+    ) -> None:
         """Add a batch of data into replay buffer."""
         if weight is None:
             weight = self._max_prio

@@ -433,7 +450,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         to de-bias the sampling process (some transition tuples are sampled
         more often so their losses are weighted less).
         """
-        assert self._size > 0, 'Cannot sample a buffer with 0 size!'
+        assert self._size > 0, "Cannot sample a buffer with 0 size!"
         if batch_size == 0:
             indice = np.concatenate([
                 np.arange(self._index, self._size),

@@ -449,8 +466,11 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         batch.weight = (batch.weight / self._min_prio) ** (-self._beta)
         return batch, indice

-    def update_weight(self, indice: Union[np.ndarray],
-                      new_weight: Union[np.ndarray, torch.Tensor]) -> None:
+    def update_weight(
+        self,
+        indice: Union[np.ndarray],
+        new_weight: Union[np.ndarray, torch.Tensor]
+    ) -> None:
         """Update priority weight by indice in this buffer.

         :param np.ndarray indice: indice you want to update weight.

@@ -461,15 +481,16 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         self._max_prio = max(self._max_prio, weight.max())
         self._min_prio = min(self._min_prio, weight.min())

-    def __getitem__(self, index: Union[
-            slice, int, np.integer, np.ndarray]) -> Batch:
+    def __getitem__(
+        self, index: Union[slice, int, np.integer, np.ndarray]
+    ) -> Batch:
         return Batch(
-            obs=self.get(index, 'obs'),
+            obs=self.get(index, "obs"),
             act=self.act[index],
             rew=self.rew[index],
             done=self.done[index],
-            obs_next=self.get(index, 'obs_next'),
-            info=self.get(index, 'info'),
-            policy=self.get(index, 'policy'),
+            obs_next=self.get(index, "obs_next"),
+            info=self.get(index, "info"),
+            policy=self.get(index, "policy"),
             weight=self.weight[index],
         )
@@ -4,13 +4,14 @@ import torch
 import warnings
 import numpy as np
 from copy import deepcopy
-from typing import Any, Dict, List, Union, Optional, Callable
+from numbers import Number
+from typing import Dict, List, Union, Optional, Callable

-from tianshou.env import BaseVectorEnv, DummyVectorEnv
 from tianshou.policy import BasePolicy
 from tianshou.exploration import BaseNoise
-from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
 from tianshou.data.batch import _create_value
+from tianshou.env import BaseVectorEnv, DummyVectorEnv
+from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy


 class Collector(object):

@@ -75,14 +76,15 @@ class Collector(object):
     Please make sure the given environment has a time limitation.
     """

-    def __init__(self,
-                 policy: BasePolicy,
-                 env: Union[gym.Env, BaseVectorEnv],
-                 buffer: Optional[ReplayBuffer] = None,
-                 preprocess_fn: Callable[[Any], Batch] = None,
-                 action_noise: Optional[BaseNoise] = None,
-                 reward_metric: Optional[Callable[[np.ndarray], float]] = None,
-                 ) -> None:
+    def __init__(
+        self,
+        policy: BasePolicy,
+        env: Union[gym.Env, BaseVectorEnv],
+        buffer: Optional[ReplayBuffer] = None,
+        preprocess_fn: Optional[Callable[..., Batch]] = None,
+        action_noise: Optional[BaseNoise] = None,
+        reward_metric: Optional[Callable[[np.ndarray], float]] = None,
+    ) -> None:
         super().__init__()
         if not isinstance(env, BaseVectorEnv):
             env = DummyVectorEnv([lambda: env])

@@ -108,12 +110,15 @@ class Collector(object):
         self.reset()

     @staticmethod
-    def _default_rew_metric(x):
+    def _default_rew_metric(
+        x: Union[Number, np.number]
+    ) -> Union[Number, np.number]:
         # this internal function is designed for single-agent RL
         # for multi-agent RL, a reward_metric must be provided
-        assert np.asanyarray(x).size == 1, \
-            'Please specify the reward_metric ' \
-            'since the reward is not a scalar.'
+        assert np.asanyarray(x).size == 1, (
+            "Please specify the reward_metric "
+            "since the reward is not a scalar."
+        )
         return x

     def reset(self) -> None:

@@ -124,7 +129,7 @@ class Collector(object):
                           obs_next={}, policy={})
         self.reset_env()
         self.reset_buffer()
-        self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
+        self.collect_time, self.collect_step, self.collect_episode = 0.0, 0, 0
         if self._action_noise is not None:
             self._action_noise.reset()
@@ -142,7 +147,7 @@ class Collector(object):
         self._ready_env_ids = np.arange(self.env_num)
         obs = self.env.reset()
         if self.preprocess_fn:
-            obs = self.preprocess_fn(obs=obs).get('obs', obs)
+            obs = self.preprocess_fn(obs=obs).get("obs", obs)
         self.data.obs = obs
         for b in self._cached_buf:
             b.reset()

@@ -157,13 +162,14 @@ class Collector(object):
         elif isinstance(state, Batch):
             state.empty_(id)

-    def collect(self,
-                n_step: Optional[int] = None,
-                n_episode: Optional[Union[int, List[int]]] = None,
-                random: bool = False,
-                render: Optional[float] = None,
-                no_grad: bool = True,
-                ) -> Dict[str, float]:
+    def collect(
+        self,
+        n_step: Optional[int] = None,
+        n_episode: Optional[Union[int, List[int]]] = None,
+        random: bool = False,
+        render: Optional[float] = None,
+        no_grad: bool = True,
+    ) -> Dict[str, float]:
         """Collect a specified number of step or episode.

         :param int n_step: how many steps you want to collect.

@@ -217,8 +223,8 @@ class Collector(object):
         while True:
             if step_count >= 100000 and episode_count.sum() == 0:
                 warnings.warn(
-                    'There are already many steps in an episode. '
-                    'You should add a time limitation to your environment!',
+                    "There are already many steps in an episode. "
+                    "You should add a time limitation to your environment!",
                     Warning)

             is_async = self.is_async or len(finished_env_ids) > 0

@@ -250,11 +256,11 @@ class Collector(object):
                 else:
                     result = self.policy(self.data, last_state)

-            state = result.get('state', Batch())
+            state = result.get("state", Batch())
             # convert None to Batch(), since None is reserved for 0-init
             if state is None:
                 state = Batch()
-            self.data.update(state=state, policy=result.get('policy', Batch()))
+            self.data.update(state=state, policy=result.get("policy", Batch()))
             # save hidden state to policy._state, in order to save into buffer
             if not (isinstance(state, Batch) and state.is_empty()):
                 self.data.policy._state = self.data.state

@@ -268,12 +274,12 @@ class Collector(object):
                 obs_next, rew, done, info = self.env.step(self.data.act)
             else:
                 # store computed actions, states, etc
-                _batch_set_item(whole_data, self._ready_env_ids,
-                                self.data, self.env_num)
+                _batch_set_item(
+                    whole_data, self._ready_env_ids, self.data, self.env_num)
                 # fetch finished data
                 obs_next, rew, done, info = self.env.step(
                     self.data.act, id=self._ready_env_ids)
-                self._ready_env_ids = np.array([i['env_id'] for i in info])
+                self._ready_env_ids = np.array([i["env_id"] for i in info])
                 # get the stepped data
                 self.data = whole_data[self._ready_env_ids]
                 # move data to self.data
@@ -319,15 +325,15 @@ class Collector(object):
                 obs_reset = self.env.reset(env_ind_global)
                 if self.preprocess_fn:
                     obs_next[env_ind_local] = self.preprocess_fn(
-                        obs=obs_reset).get('obs', obs_reset)
+                        obs=obs_reset).get("obs", obs_reset)
                 else:
                     obs_next[env_ind_local] = obs_reset
             self.data.obs = obs_next
             if is_async:
                 # set data back
                 whole_data = deepcopy(whole_data) # avoid reference in ListBuf
-                _batch_set_item(whole_data, self._ready_env_ids,
-                                self.data, self.env_num)
+                _batch_set_item(
+                    whole_data, self._ready_env_ids, self.data, self.env_num)
                 # let self.data be the data in all environments again
                 self.data = whole_data
                 self._ready_env_ids = np.array(

@@ -358,12 +364,12 @@ class Collector(object):
         if np.asanyarray(reward_avg).size > 1: # non-scalar reward_avg
             reward_avg = self._rew_metric(reward_avg)
         return {
-            'n/ep': episode_count,
-            'n/st': step_count,
-            'v/st': step_count / duration,
-            'v/ep': episode_count / duration,
-            'rew': reward_avg,
-            'len': step_count / episode_count,
+            "n/ep": episode_count,
+            "n/st": step_count,
+            "v/st": step_count / duration,
+            "v/ep": episode_count / duration,
+            "rew": reward_avg,
+            "len": step_count / episode_count,
         }

     def sample(self, batch_size: int) -> Batch:

@@ -377,9 +383,9 @@ class Collector(object):
         batch_size.
         """
         warnings.warn(
-            'Collector.sample is deprecated and will cause error if you use '
-            'prioritized experience replay! Collector.sample will be removed '
-            'upon version 0.3. Use policy.update instead!', Warning)
+            "Collector.sample is deprecated and will cause error if you use "
+            "prioritized experience replay! Collector.sample will be removed "
+            "upon version 0.3. Use policy.update instead!", Warning)
         assert self.buffer is not None, "Cannot get sample from empty buffer!"
         batch_data, indice = self.buffer.sample(batch_size)
         batch_data = self.process_fn(batch_data, self.buffer, indice)

@@ -387,12 +393,13 @@ class Collector(object):

     def close(self) -> None:
         warnings.warn(
-            'Collector.close is deprecated and will be removed upon version '
-            '0.3.', Warning)
+            "Collector.close is deprecated and will be removed upon version "
+            "0.3.", Warning)


-def _batch_set_item(source: Batch, indices: np.ndarray,
-                    target: Batch, size: int):
+def _batch_set_item(
+    source: Batch, indices: np.ndarray, target: Batch, size: int
+) -> None:
     # for any key chain k, there are four cases
     # 1. source[k] is non-reserved, but target[k] does not exist or is reserved
     # 2. source[k] does not exist or is reserved, but target[k] is non-reserved
@@ -1,72 +1,79 @@
 import torch
 import numpy as np
 from copy import deepcopy
 from numbers import Number
 from typing import Union, Optional

 from tianshou.data.batch import _parse_value, Batch


-def to_numpy(x: Union[
-    Batch, dict, list, tuple, np.ndarray, torch.Tensor]) -> Union[
-        Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
+def to_numpy(
+    x: Optional[Union[Batch, dict, list, tuple, np.number, np.bool_, Number,
+                      np.ndarray, torch.Tensor]]
+) -> Union[Batch, dict, list, tuple, np.ndarray]:
     """Return an object without torch.Tensor."""
     if isinstance(x, torch.Tensor): # most often case
-        x = x.detach().cpu().numpy()
+        return x.detach().cpu().numpy()
     elif isinstance(x, np.ndarray): # second often case
-        pass
+        return x
     elif isinstance(x, (np.number, np.bool_, Number)):
-        x = np.asanyarray(x)
+        return np.asanyarray(x)
     elif x is None:
-        x = np.array(None, dtype=np.object)
+        return np.array(None, dtype=np.object)
     elif isinstance(x, Batch):
         x = deepcopy(x)
         x.to_numpy()
+        return x
     elif isinstance(x, dict):
-        for k, v in x.items():
-            x[k] = to_numpy(v)
+        return {k: to_numpy(v) for k, v in x.items()}
     elif isinstance(x, (list, tuple)):
         try:
-            x = to_numpy(_parse_value(x))
+            return to_numpy(_parse_value(x))
         except TypeError:
-            x = [to_numpy(e) for e in x]
+            return [to_numpy(e) for e in x]
     else: # fallback
-        x = np.asanyarray(x)
-    return x
+        return np.asanyarray(x)


-def to_torch(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
-             dtype: Optional[torch.dtype] = None,
-             device: Union[str, int, torch.device] = 'cpu'
-             ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
+def to_torch(
+    x: Union[Batch, dict, list, tuple, np.number, np.bool_, Number, np.ndarray,
+             torch.Tensor],
+    dtype: Optional[torch.dtype] = None,
+    device: Union[str, int, torch.device] = "cpu",
+) -> Union[Batch, dict, list, tuple, torch.Tensor]:
     """Return an object without np.ndarray."""
-    if isinstance(x, np.ndarray) and \
-            issubclass(x.dtype.type, (np.bool_, np.number)): # most often case
+    if isinstance(x, np.ndarray) and issubclass(
+        x.dtype.type, (np.bool_, np.number)
+    ): # most often case
         x = torch.from_numpy(x).to(device)
         if dtype is not None:
             x = x.type(dtype)
         return x
     elif isinstance(x, torch.Tensor): # second often case
         if dtype is not None:
             x = x.type(dtype)
-        x = x.to(device)
+        return x.to(device)
     elif isinstance(x, (np.number, np.bool_, Number)):
-        x = to_torch(np.asanyarray(x), dtype, device)
+        return to_torch(np.asanyarray(x), dtype, device)
     elif isinstance(x, dict):
-        for k, v in x.items():
-            x[k] = to_torch(v, dtype, device)
+        return {k: to_torch(v, dtype, device) for k, v in x.items()}
     elif isinstance(x, Batch):
         x = deepcopy(x)
         x.to_torch(dtype, device)
+        return x
     elif isinstance(x, (list, tuple)):
         try:
-            x = to_torch(_parse_value(x), dtype, device)
+            return to_torch(_parse_value(x), dtype, device)
         except TypeError:
-            x = [to_torch(e, dtype, device) for e in x]
+            return [to_torch(e, dtype, device) for e in x]
     else: # fallback
         raise TypeError(f"object {x} cannot be converted to torch.")
-    return x


-def to_torch_as(x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
-                y: torch.Tensor
-                ) -> Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor]:
+def to_torch_as(
+    x: Union[Batch, dict, list, tuple, np.ndarray, torch.Tensor],
+    y: torch.Tensor,
+) -> Union[Batch, dict, list, tuple, torch.Tensor]:
     """Return an object without np.ndarray.

     Same as ``to_torch(x, dtype=y.dtype, device=y.device)``.
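The commit message mentions a bug fix in `to_torch`/`to_numpy` for `Batch` inputs; the exact failing case is not shown in this diff, but the rewritten converters above now return a converted value on every branch instead of assigning to `x` and falling through, and the dict branches return a new dict rather than mutating the argument. A minimal sketch of the resulting behaviour (assuming this version of tianshou plus numpy and torch are installed):

```python
import numpy as np
import torch
from tianshou.data import Batch, to_numpy, to_torch

# to_numpy deep-copies a Batch and converts its tensors; the input is untouched.
b = Batch(a=torch.zeros(3))
out = to_numpy(b)
assert isinstance(out.a, np.ndarray)
assert isinstance(b.a, torch.Tensor)

# to_torch on a dict returns a new dict instead of mutating the argument.
d = {"x": np.arange(3)}
t = to_torch(d)
assert isinstance(t["x"], torch.Tensor)
assert isinstance(d["x"], np.ndarray)
```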
@@ -24,17 +24,20 @@ class SegmentTree:
         self._size = size
         self._bound = bound
         self._value = np.zeros([bound * 2])
+        self._compile()

-    def __len__(self):
+    def __len__(self) -> int:
         return self._size

-    def __getitem__(self, index: Union[int, np.ndarray]
-                    ) -> Union[float, np.ndarray]:
+    def __getitem__(
+        self, index: Union[int, np.ndarray]
+    ) -> Union[float, np.ndarray]:
         """Return self[index]."""
         return self._value[index + self._bound]

-    def __setitem__(self, index: Union[int, np.ndarray],
-                    value: Union[float, np.ndarray]) -> None:
+    def __setitem__(
+        self, index: Union[int, np.ndarray], value: Union[float, np.ndarray]
+    ) -> None:
         """Update values in segment tree.

         Duplicate values in ``index`` are handled by numpy: later index

@@ -62,7 +65,8 @@ class SegmentTree:
         return _reduce(self._value, start + self._bound - 1, end + self._bound)

     def get_prefix_sum_idx(
-            self, value: Union[float, np.ndarray]) -> Union[int, np.ndarray]:
+        self, value: Union[float, np.ndarray]
+    ) -> Union[int, np.ndarray]:
         r"""Find the index with given value.

         Return the minimum index for each ``v`` in ``value`` so that

@@ -74,7 +78,7 @@ class SegmentTree:
         Please make sure all of the values inside the segment tree are
         non-negative when using this function.
         """
-        assert np.all(value >= 0.) and np.all(value < self._value[1])
+        assert np.all(value >= 0.0) and np.all(value < self._value[1])
         single = False
         if not isinstance(value, np.ndarray):
             value = np.array([value])

@@ -82,6 +86,16 @@ class SegmentTree:
         index = _get_prefix_sum_idx(value, self._bound, self._value)
         return index.item() if single else index

+    def _compile(self) -> None:
+        f64 = np.array([0, 1], dtype=np.float64)
+        f32 = np.array([0, 1], dtype=np.float32)
+        i64 = np.array([0, 1], dtype=np.int64)
+        _setitem(f64, i64, f64)
+        _setitem(f64, i64, f32)
+        _reduce(f64, 0, 1)
+        _get_prefix_sum_idx(f64, 1, f64)
+        _get_prefix_sum_idx(f32, 1, f64)


 @njit
 def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None:

@@ -96,7 +110,7 @@ def _setitem(tree: np.ndarray, index: np.ndarray, value: np.ndarray) -> None:
 def _reduce(tree: np.ndarray, start: int, end: int) -> float:
     """Numba version, 2x faster: 0.009 -> 0.005."""
     # nodes in (start, end) should be aggregated
-    result = 0.
+    result = 0.0
     while end - start > 1: # (start, end) interval is not empty
         if start % 2 == 0:
             result += tree[start + 1]

@@ -108,8 +122,9 @@ def _reduce(tree: np.ndarray, start: int, end: int) -> float:


 @njit
-def _get_prefix_sum_idx(value: np.ndarray, bound: int,
-                        sums: np.ndarray) -> np.ndarray:
+def _get_prefix_sum_idx(
+    value: np.ndarray, bound: int, sums: np.ndarray
+) -> np.ndarray:
     """Numba version (v0.51), 5x speed up with size=100000 and bsz=64.

     vectorized np: 0.0923 (numpy best) -> 0.024 (now)
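The new `SegmentTree._compile` method above replaces the module-level `utils.pre_compile()` call dropped from `tianshou/__init__.py`: it warms up the numba-jitted helpers by calling them once on tiny arrays, so the first real buffer operation (and any benchmark) does not pay the JIT compilation cost. A rough illustration of the same warm-up trick with a made-up function (assuming numba is installed):

```python
import numpy as np
from numba import njit


@njit
def _tree_sum(values: np.ndarray) -> float:
    # Trivial stand-in for the jitted segment-tree helpers.
    total = 0.0
    for v in values:
        total += v
    return total


# Warm-up call: compilation happens here, on a trivially small input...
_tree_sum(np.zeros(1, dtype=np.float64))
# ...so timed or benchmarked calls later only measure the compiled code.
print(_tree_sum(np.arange(1000, dtype=np.float64)))
```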
tianshou/env/__init__.py

@@ -3,11 +3,11 @@ from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, VectorEnv, \
 from tianshou.env.maenv import MultiAgentEnv

 __all__ = [
-    'BaseVectorEnv',
-    'DummyVectorEnv',
-    'VectorEnv', # TODO: remove in later version
-    'SubprocVectorEnv',
-    'ShmemVectorEnv',
-    'RayVectorEnv',
-    'MultiAgentEnv',
+    "BaseVectorEnv",
+    "DummyVectorEnv",
+    "VectorEnv", # TODO: remove in later version
+    "SubprocVectorEnv",
+    "ShmemVectorEnv",
+    "RayVectorEnv",
+    "MultiAgentEnv",
 ]
tianshou/env/maenv.py

@@ -1,6 +1,6 @@
 import gym
 import numpy as np
-from typing import Tuple
+from typing import Any, Dict, Tuple
 from abc import ABC, abstractmethod


@@ -22,7 +22,7 @@ class MultiAgentEnv(ABC, gym.Env):
     usage can be found at :ref:`marl_example`.
     """

-    def __init__(self, **kwargs) -> None:
+    def __init__(self) -> None:
         pass

     @abstractmethod

@@ -30,13 +30,14 @@ class MultiAgentEnv(ABC, gym.Env):
         """Reset the state.

         Return the initial state, first agent_id, and the initial action set,
-        for example, ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}``
+        for example, ``{'obs': obs, 'agent_id': agent_id, 'mask': mask}``.
         """
         pass

     @abstractmethod
-    def step(self, action: np.ndarray
-             ) -> Tuple[dict, np.ndarray, np.ndarray, np.ndarray]:
+    def step(
+        self, action: np.ndarray
+    ) -> Tuple[Dict[str, Any], np.ndarray, np.ndarray, np.ndarray]:
         """Run one timestep of the environment’s dynamics.

         When the end of episode is reached, you are responsible for calling
7 tianshou/env/utils.py
@ -1,14 +1,15 @@
import cloudpickle
from typing import Any


class CloudpickleWrapper(object):
"""A cloudpickle wrapper used in SubprocVectorEnv."""

def __init__(self, data):
def __init__(self, data: Any) -> None:
self.data = data

def __getstate__(self):
def __getstate__(self) -> str:
return cloudpickle.dumps(self.data)

def __setstate__(self, data):
def __setstate__(self, data: str) -> None:
self.data = cloudpickle.loads(data)

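The annotations make the wrapper's contract explicit: `__getstate__` hands a cloudpickle byte string to the regular pickler, which is what lets closure-based env factories cross the subprocess boundary. A quick usage sketch (assumes tianshou is importable):

    import pickle
    from tianshou.env.utils import CloudpickleWrapper

    make_env = lambda: "CartPole-v0"   # lambdas are not picklable by the stdlib pickle alone
    restored = pickle.loads(pickle.dumps(CloudpickleWrapper(make_env)))
    print(restored.data())             # -> CartPole-v0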
138 tianshou/env/venvs.py
@ -1,7 +1,7 @@
|
||||
import gym
|
||||
import warnings
|
||||
import numpy as np
|
||||
from typing import List, Union, Optional, Callable, Any
|
||||
from typing import Any, List, Union, Optional, Callable
|
||||
|
||||
from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \
|
||||
RayEnvWorker
|
||||
@ -59,12 +59,13 @@ class BaseVectorEnv(gym.Env):
|
||||
within ``timeout`` seconds.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
self._env_fns = env_fns
|
||||
# A VectorEnv contains a pool of EnvWorkers, which corresponds to
|
||||
# interact with the given envs (one worker <-> one env).
|
||||
@ -75,11 +76,13 @@ class BaseVectorEnv(gym.Env):
|
||||
|
||||
self.env_num = len(env_fns)
|
||||
self.wait_num = wait_num or len(env_fns)
|
||||
assert 1 <= self.wait_num <= len(env_fns), \
|
||||
f'wait_num should be in [1, {len(env_fns)}], but got {wait_num}'
|
||||
assert (
|
||||
1 <= self.wait_num <= len(env_fns)
|
||||
), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
|
||||
self.timeout = timeout
|
||||
assert self.timeout is None or self.timeout > 0, \
|
||||
f'timeout is {timeout}, it should be positive if provided!'
|
||||
assert (
|
||||
self.timeout is None or self.timeout > 0
|
||||
), f"timeout is {timeout}, it should be positive if provided!"
|
||||
self.is_async = self.wait_num != len(env_fns) or timeout is not None
|
||||
self.waiting_conn = []
|
||||
# environments in self.ready_id is actually ready
|
||||
@ -92,8 +95,9 @@ class BaseVectorEnv(gym.Env):
|
||||
self.is_closed = False
|
||||
|
||||
def _assert_is_not_closed(self) -> None:
|
||||
assert not self.is_closed, f"Methods of {self.__class__.__name__} "\
|
||||
"should not be called after close."
|
||||
assert not self.is_closed, (
|
||||
f"Methods of {self.__class__.__name__} cannot be called after "
|
||||
"close.")
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self), which is the number of environments."""
|
||||
@ -113,7 +117,7 @@ class BaseVectorEnv(gym.Env):
|
||||
else:
|
||||
return super().__getattribute__(key)
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
def __getattr__(self, key: str) -> List[Any]:
|
||||
"""Fetch a list of env attributes.
|
||||
|
||||
This function tries to retrieve an attribute from each individual
|
||||
@ -122,8 +126,9 @@ class BaseVectorEnv(gym.Env):
|
||||
"""
|
||||
return [getattr(worker, key) for worker in self.workers]
|
||||
|
||||
def _wrap_id(self, id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||
) -> List[int]:
|
||||
def _wrap_id(
|
||||
self, id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||
) -> Union[List[int], np.ndarray]:
|
||||
if id is None:
|
||||
id = list(range(self.env_num))
|
||||
elif np.isscalar(id):
|
||||
@ -132,13 +137,16 @@ class BaseVectorEnv(gym.Env):
|
||||
|
||||
def _assert_id(self, id: List[int]) -> None:
|
||||
for i in id:
|
||||
assert i not in self.waiting_id, \
|
||||
f'Cannot interact with environment {i} which is stepping now.'
|
||||
assert i in self.ready_id, \
|
||||
f'Can only interact with ready environments {self.ready_id}.'
|
||||
assert (
|
||||
i not in self.waiting_id
|
||||
), f"Cannot interact with environment {i} which is stepping now."
|
||||
assert (
|
||||
i in self.ready_id
|
||||
), f"Can only interact with ready environments {self.ready_id}."
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||
) -> np.ndarray:
|
||||
def reset(
|
||||
self, id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||
) -> np.ndarray:
|
||||
"""Reset the state of some envs and return initial observations.
|
||||
|
||||
If id is None, reset the state of all the environments and return
|
||||
@ -152,10 +160,11 @@ class BaseVectorEnv(gym.Env):
|
||||
obs = np.stack([self.workers[i].reset() for i in id])
|
||||
return obs
|
||||
|
||||
def step(self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||
) -> List[np.ndarray]:
|
||||
def step(
|
||||
self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int], np.ndarray]] = None
|
||||
) -> List[np.ndarray]:
|
||||
"""Run one timestep of some environments' dynamics.
|
||||
|
||||
If id is None, run one timestep of all the environments’ dynamics;
|
||||
@ -221,8 +230,9 @@ class BaseVectorEnv(gym.Env):
|
||||
self.ready_id.append(env_id)
|
||||
return list(map(np.stack, zip(*result)))
|
||||
|
||||
def seed(self,
|
||||
seed: Optional[Union[int, List[int]]] = None) -> List[List[int]]:
|
||||
def seed(
|
||||
self, seed: Optional[Union[int, List[int]]] = None
|
||||
) -> List[Optional[List[int]]]:
|
||||
"""Set the seed for all environments.
|
||||
|
||||
Accept ``None``, an int (which will extend ``i`` to
|
||||
@ -239,13 +249,13 @@ class BaseVectorEnv(gym.Env):
|
||||
seed = [seed + i for i in range(self.env_num)]
|
||||
return [w.seed(s) for w, s in zip(self.workers, seed)]
|
||||
|
||||
def render(self, **kwargs) -> List[Any]:
|
||||
def render(self, **kwargs: Any) -> List[Any]:
|
||||
"""Render all of the environments."""
|
||||
self._assert_is_not_closed()
|
||||
if self.is_async and len(self.waiting_id) > 0:
|
||||
raise RuntimeError(
|
||||
f"Environments {self.waiting_id} are still "
|
||||
f"stepping, cannot render them now.")
|
||||
f"Environments {self.waiting_id} are still stepping, cannot "
|
||||
"render them now.")
|
||||
return [w.render(**kwargs) for w in self.workers]
|
||||
|
||||
def close(self) -> None:
|
||||
@ -275,20 +285,23 @@ class DummyVectorEnv(BaseVectorEnv):
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None) -> None:
|
||||
super().__init__(env_fns, DummyEnvWorker,
|
||||
wait_num=wait_num, timeout=timeout)
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
env_fns, DummyEnvWorker, wait_num=wait_num, timeout=timeout)
|
||||
|
||||
|
||||
class VectorEnv(DummyVectorEnv):
|
||||
"""VectorEnv is renamed to DummyVectorEnv."""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
warnings.warn(
|
||||
'VectorEnv is renamed to DummyVectorEnv, and will be removed in '
|
||||
'0.3. Use DummyVectorEnv instead!', Warning)
|
||||
"VectorEnv is renamed to DummyVectorEnv, and will be removed in "
|
||||
"0.3. Use DummyVectorEnv instead!", Warning)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
@ -301,13 +314,17 @@ class SubprocVectorEnv(BaseVectorEnv):
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None) -> None:
|
||||
def worker_fn(fn):
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||
return SubprocEnvWorker(fn, share_memory=False)
|
||||
super().__init__(env_fns, worker_fn,
|
||||
wait_num=wait_num, timeout=timeout)
|
||||
|
||||
super().__init__(
|
||||
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
|
||||
|
||||
|
||||
class ShmemVectorEnv(BaseVectorEnv):
|
||||
@ -321,13 +338,17 @@ class ShmemVectorEnv(BaseVectorEnv):
|
||||
detailed explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None) -> None:
|
||||
def worker_fn(fn):
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
def worker_fn(fn: Callable[[], gym.Env]) -> SubprocEnvWorker:
|
||||
return SubprocEnvWorker(fn, share_memory=True)
|
||||
super().__init__(env_fns, worker_fn,
|
||||
wait_num=wait_num, timeout=timeout)
|
||||
|
||||
super().__init__(
|
||||
env_fns, worker_fn, wait_num=wait_num, timeout=timeout)
|
||||
|
||||
|
||||
class RayVectorEnv(BaseVectorEnv):
|
||||
@ -341,16 +362,19 @@ class RayVectorEnv(BaseVectorEnv):
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
try:
|
||||
import ray
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
'Please install ray to support RayVectorEnv: pip install ray'
|
||||
"Please install ray to support RayVectorEnv: pip install ray"
|
||||
) from e
|
||||
if not ray.is_initialized():
|
||||
ray.init()
|
||||
super().__init__(env_fns, RayEnvWorker,
|
||||
wait_num=wait_num, timeout=timeout)
|
||||
super().__init__(
|
||||
env_fns, RayEnvWorker, wait_num=wait_num, timeout=timeout)
|
||||
|
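The reworked constructors above keep the public call pattern unchanged; a typical construction with the keyword `wait_num`/`timeout` knobs looks like the sketch below (assumes gym is installed and the environment id exists):

    import gym
    from tianshou.env import DummyVectorEnv, SubprocVectorEnv

    env_fns = [lambda: gym.make("CartPole-v0") for _ in range(4)]
    sync_envs = DummyVectorEnv(env_fns)                              # sequential, in-process workers
    async_envs = SubprocVectorEnv(env_fns, wait_num=2, timeout=0.1)  # asynchronous subprocess workers
    print(sync_envs.reset().shape)                                   # (4, obs_dim)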
8 tianshou/env/worker/__init__.py
@ -4,8 +4,8 @@ from tianshou.env.worker.subproc import SubprocEnvWorker
from tianshou.env.worker.ray import RayEnvWorker

__all__ = [
'EnvWorker',
'DummyEnvWorker',
'SubprocEnvWorker',
'RayEnvWorker',
"EnvWorker",
"DummyEnvWorker",
"SubprocEnvWorker",
"RayEnvWorker",
]

26 tianshou/env/worker/base.py
@ -1,7 +1,7 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Optional, Callable, Any
|
||||
from typing import Any, List, Tuple, Optional, Callable
|
||||
|
||||
|
||||
class EnvWorker(ABC):
|
||||
@ -24,12 +24,14 @@ class EnvWorker(ABC):
|
||||
def send_action(self, action: np.ndarray) -> None:
|
||||
pass
|
||||
|
||||
def get_result(self) -> Tuple[
|
||||
np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
def get_result(
|
||||
self,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
return self.result
|
||||
|
||||
def step(self, action: np.ndarray
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
def step(
|
||||
self, action: np.ndarray
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Perform one timestep of the environment's dynamic.
|
||||
|
||||
"send_action" and "get_result" are coupled in sync simulation, so
|
||||
@ -41,19 +43,21 @@ class EnvWorker(ABC):
|
||||
return self.get_result()
|
||||
|
||||
@staticmethod
|
||||
def wait(workers: List['EnvWorker'],
|
||||
wait_num: int,
|
||||
timeout: Optional[float] = None) -> List['EnvWorker']:
|
||||
def wait(
|
||||
workers: List["EnvWorker"],
|
||||
wait_num: int,
|
||||
timeout: Optional[float] = None,
|
||||
) -> List["EnvWorker"]:
|
||||
"""Given a list of workers, return those ready ones."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def seed(self, seed: Optional[int] = None) -> List[int]:
|
||||
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def render(self, **kwargs) -> Any:
|
||||
"""Renders the environment."""
|
||||
def render(self, **kwargs: Any) -> Any:
|
||||
"""Render the environment."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
21 tianshou/env/worker/dummy.py
@ -1,6 +1,6 @@
import gym
import numpy as np
from typing import List, Callable, Optional, Any
from typing import Any, List, Callable, Optional

from tianshou.env.worker import EnvWorker

@ -19,21 +19,24 @@ class DummyEnvWorker(EnvWorker):
return self.env.reset()

@staticmethod
def wait(workers: List['DummyEnvWorker'],
wait_num: int,
timeout: Optional[float] = None) -> List['DummyEnvWorker']:
def wait(
workers: List["DummyEnvWorker"],
wait_num: int,
timeout: Optional[float] = None,
) -> List["DummyEnvWorker"]:
# Sequential EnvWorker objects are always ready
return workers

def send_action(self, action: np.ndarray) -> None:
self.result = self.env.step(action)

def seed(self, seed: Optional[int] = None) -> List[int]:
return self.env.seed(seed) if hasattr(self.env, 'seed') else None
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
return self.env.seed(seed) if hasattr(self.env, "seed") else None

def render(self, **kwargs) -> Any:
return self.env.render(**kwargs) \
if hasattr(self.env, 'render') else None
def render(self, **kwargs: Any) -> Any:
return (
self.env.render(**kwargs) if hasattr(self.env, "render") else None
)

def close_env(self) -> None:
self.env.close()

29 tianshou/env/worker/ray.py
@ -1,6 +1,6 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from typing import List, Callable, Tuple, Optional, Any
|
||||
from typing import Any, List, Callable, Tuple, Optional
|
||||
|
||||
from tianshou.env.worker import EnvWorker
|
||||
|
||||
@ -24,31 +24,34 @@ class RayEnvWorker(EnvWorker):
|
||||
return ray.get(self.env.reset.remote())
|
||||
|
||||
@staticmethod
|
||||
def wait(workers: List['RayEnvWorker'],
|
||||
wait_num: int,
|
||||
timeout: Optional[float] = None) -> List['RayEnvWorker']:
|
||||
def wait(
|
||||
workers: List["RayEnvWorker"],
|
||||
wait_num: int,
|
||||
timeout: Optional[float] = None,
|
||||
) -> List["RayEnvWorker"]:
|
||||
results = [x.result for x in workers]
|
||||
ready_results, _ = ray.wait(results,
|
||||
num_returns=wait_num, timeout=timeout)
|
||||
ready_results, _ = ray.wait(
|
||||
results, num_returns=wait_num, timeout=timeout
|
||||
)
|
||||
return [workers[results.index(result)] for result in ready_results]
|
||||
|
||||
def send_action(self, action: np.ndarray) -> None:
|
||||
# self.action is actually a handle
|
||||
self.result = self.env.step.remote(action)
|
||||
|
||||
def get_result(self) -> Tuple[
|
||||
np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
def get_result(
|
||||
self,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
return ray.get(self.result)
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> List[int]:
|
||||
if hasattr(self.env, 'seed'):
|
||||
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
|
||||
if hasattr(self.env, "seed"):
|
||||
return ray.get(self.env.seed.remote(seed))
|
||||
return None
|
||||
|
||||
def render(self, **kwargs) -> Any:
|
||||
if hasattr(self.env, 'render'):
|
||||
def render(self, **kwargs: Any) -> Any:
|
||||
if hasattr(self.env, "render"):
|
||||
return ray.get(self.env.render.remote(**kwargs))
|
||||
return None
|
||||
|
||||
def close_env(self) -> None:
|
||||
ray.get(self.env.close.remote())
|
||||
|
111 tianshou/env/worker/subproc.py
@ -5,14 +5,22 @@ import numpy as np
|
||||
from collections import OrderedDict
|
||||
from multiprocessing.context import Process
|
||||
from multiprocessing import Array, Pipe, connection
|
||||
from typing import Callable, Any, List, Tuple, Optional
|
||||
from typing import Any, List, Tuple, Union, Callable, Optional
|
||||
|
||||
from tianshou.env.worker import EnvWorker
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
|
||||
|
||||
def _worker(parent, p, env_fn_wrapper, obs_bufs=None):
|
||||
def _encode_obs(obs, buffer):
|
||||
def _worker(
|
||||
parent: connection.Connection,
|
||||
p: connection.Connection,
|
||||
env_fn_wrapper: CloudpickleWrapper,
|
||||
obs_bufs: Optional[Union[dict, tuple, "ShArray"]] = None,
|
||||
) -> None:
|
||||
def _encode_obs(
|
||||
obs: Union[dict, tuple, np.ndarray],
|
||||
buffer: Union[dict, tuple, ShArray],
|
||||
) -> None:
|
||||
if isinstance(obs, np.ndarray):
|
||||
buffer.save(obs)
|
||||
elif isinstance(obs, tuple):
|
||||
@ -32,25 +40,27 @@ def _worker(parent, p, env_fn_wrapper, obs_bufs=None):
|
||||
except EOFError: # the pipe has been closed
|
||||
p.close()
|
||||
break
|
||||
if cmd == 'step':
|
||||
if cmd == "step":
|
||||
obs, reward, done, info = env.step(data)
|
||||
if obs_bufs is not None:
|
||||
obs = _encode_obs(obs, obs_bufs)
|
||||
_encode_obs(obs, obs_bufs)
|
||||
obs = None
|
||||
p.send((obs, reward, done, info))
|
||||
elif cmd == 'reset':
|
||||
elif cmd == "reset":
|
||||
obs = env.reset()
|
||||
if obs_bufs is not None:
|
||||
obs = _encode_obs(obs, obs_bufs)
|
||||
_encode_obs(obs, obs_bufs)
|
||||
obs = None
|
||||
p.send(obs)
|
||||
elif cmd == 'close':
|
||||
elif cmd == "close":
|
||||
p.send(env.close())
|
||||
p.close()
|
||||
break
|
||||
elif cmd == 'render':
|
||||
p.send(env.render(**data) if hasattr(env, 'render') else None)
|
||||
elif cmd == 'seed':
|
||||
p.send(env.seed(data) if hasattr(env, 'seed') else None)
|
||||
elif cmd == 'getattr':
|
||||
elif cmd == "render":
|
||||
p.send(env.render(**data) if hasattr(env, "render") else None)
|
||||
elif cmd == "seed":
|
||||
p.send(env.seed(data) if hasattr(env, "seed") else None)
|
||||
elif cmd == "getattr":
|
||||
p.send(getattr(env, data) if hasattr(env, data) else None)
|
||||
else:
|
||||
p.close()
|
||||
@ -78,39 +88,39 @@ _NP_TO_CT = {
|
||||
class ShArray:
|
||||
"""Wrapper of multiprocessing Array."""
|
||||
|
||||
def __init__(self, dtype, shape):
|
||||
def __init__(self, dtype: np.generic, shape: Tuple[int]) -> None:
|
||||
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
|
||||
self.dtype = dtype
|
||||
self.shape = shape
|
||||
|
||||
def save(self, ndarray):
|
||||
def save(self, ndarray: np.ndarray) -> None:
|
||||
assert isinstance(ndarray, np.ndarray)
|
||||
dst = self.arr.get_obj()
|
||||
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
|
||||
np.copyto(dst_np, ndarray)
|
||||
|
||||
def get(self):
|
||||
return np.frombuffer(self.arr.get_obj(),
|
||||
dtype=self.dtype).reshape(self.shape)
|
||||
def get(self) -> np.ndarray:
|
||||
obj = self.arr.get_obj()
|
||||
return np.frombuffer(obj, dtype=self.dtype).reshape(self.shape)
|
||||
|
||||
|
||||
def _setup_buf(space):
|
||||
def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
|
||||
if isinstance(space, gym.spaces.Dict):
|
||||
assert isinstance(space.spaces, OrderedDict)
|
||||
buffer = {k: _setup_buf(v) for k, v in space.spaces.items()}
|
||||
return {k: _setup_buf(v) for k, v in space.spaces.items()}
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
assert isinstance(space.spaces, tuple)
|
||||
buffer = tuple([_setup_buf(t) for t in space.spaces])
|
||||
return tuple([_setup_buf(t) for t in space.spaces])
|
||||
else:
|
||||
buffer = ShArray(space.dtype, space.shape)
|
||||
return buffer
|
||||
return ShArray(space.dtype, space.shape)
|
||||
|
||||
|
||||
class SubprocEnvWorker(EnvWorker):
|
||||
"""Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""
|
||||
|
||||
def __init__(self, env_fn: Callable[[], gym.Env],
|
||||
share_memory=False) -> None:
|
||||
def __init__(
|
||||
self, env_fn: Callable[[], gym.Env], share_memory: bool = False
|
||||
) -> None:
|
||||
super().__init__(env_fn)
|
||||
self.parent_remote, self.child_remote = Pipe()
|
||||
self.share_memory = share_memory
|
||||
@ -121,18 +131,24 @@ class SubprocEnvWorker(EnvWorker):
|
||||
dummy.close()
|
||||
del dummy
|
||||
self.buffer = _setup_buf(obs_space)
|
||||
args = (self.parent_remote, self.child_remote,
|
||||
CloudpickleWrapper(env_fn), self.buffer)
|
||||
args = (
|
||||
self.parent_remote,
|
||||
self.child_remote,
|
||||
CloudpickleWrapper(env_fn),
|
||||
self.buffer,
|
||||
)
|
||||
self.process = Process(target=_worker, args=args, daemon=True)
|
||||
self.process.start()
|
||||
self.child_remote.close()
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
self.parent_remote.send(['getattr', key])
|
||||
self.parent_remote.send(["getattr", key])
|
||||
return self.parent_remote.recv()
|
||||
|
||||
def _decode_obs(self, isNone):
|
||||
def decode_obs(buffer):
|
||||
def _decode_obs(self) -> Union[dict, tuple, np.ndarray]:
|
||||
def decode_obs(
|
||||
buffer: Optional[Union[dict, tuple, ShArray]]
|
||||
) -> Union[dict, tuple, np.ndarray]:
|
||||
if isinstance(buffer, ShArray):
|
||||
return buffer.get()
|
||||
elif isinstance(buffer, tuple):
|
||||
@ -145,16 +161,18 @@ class SubprocEnvWorker(EnvWorker):
|
||||
return decode_obs(self.buffer)
|
||||
|
||||
def reset(self) -> Any:
|
||||
self.parent_remote.send(['reset', None])
|
||||
self.parent_remote.send(["reset", None])
|
||||
obs = self.parent_remote.recv()
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs(obs)
|
||||
obs = self._decode_obs()
|
||||
return obs
|
||||
|
||||
@staticmethod
|
||||
def wait(workers: List['SubprocEnvWorker'],
|
||||
wait_num: int,
|
||||
timeout: Optional[float] = None) -> List['SubprocEnvWorker']:
|
||||
def wait(
|
||||
workers: List["SubprocEnvWorker"],
|
||||
wait_num: int,
|
||||
timeout: Optional[float] = None,
|
||||
) -> List["SubprocEnvWorker"]:
|
||||
conns, ready_conns = [x.parent_remote for x in workers], []
|
||||
remain_conns = conns
|
||||
t1 = time.time()
|
||||
@ -169,31 +187,32 @@ class SubprocEnvWorker(EnvWorker):
|
||||
new_ready_conns = connection.wait(
|
||||
remain_conns, timeout=remain_time)
|
||||
ready_conns.extend(new_ready_conns)
|
||||
remain_conns = [conn for conn in remain_conns
|
||||
if conn not in ready_conns]
|
||||
remain_conns = [
|
||||
conn for conn in remain_conns if conn not in ready_conns]
|
||||
return [workers[conns.index(con)] for con in ready_conns]
|
||||
|
||||
def send_action(self, action: np.ndarray) -> None:
|
||||
self.parent_remote.send(['step', action])
|
||||
self.parent_remote.send(["step", action])
|
||||
|
||||
def get_result(self) -> Tuple[
|
||||
np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
def get_result(
|
||||
self,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
obs, rew, done, info = self.parent_remote.recv()
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs(obs)
|
||||
obs = self._decode_obs()
|
||||
return obs, rew, done, info
|
||||
|
||||
def seed(self, seed: Optional[int] = None) -> List[int]:
|
||||
self.parent_remote.send(['seed', seed])
|
||||
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
|
||||
self.parent_remote.send(["seed", seed])
|
||||
return self.parent_remote.recv()
|
||||
|
||||
def render(self, **kwargs) -> Any:
|
||||
self.parent_remote.send(['render', kwargs])
|
||||
def render(self, **kwargs: Any) -> Any:
|
||||
self.parent_remote.send(["render", kwargs])
|
||||
return self.parent_remote.recv()
|
||||
|
||||
def close_env(self) -> None:
|
||||
try:
|
||||
self.parent_remote.send(['close', None])
|
||||
self.parent_remote.send(["close", None])
|
||||
# mp may be deleted so it may raise AttributeError
|
||||
self.parent_remote.recv()
|
||||
self.process.join()
|
||||
|
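`ShArray.save`/`get` are a thin copy-in/copy-out layer over `multiprocessing.Array`; the standalone sketch below mirrors that round trip for a float64 vector (the "d" typecode is the ctypes double, the analogue of the float64 entry in `_NP_TO_CT`):

    import numpy as np
    from multiprocessing import Array

    shape, dtype = (4,), np.float64
    shared = Array("d", int(np.prod(shape)))     # same allocation ShArray.__init__ performs
    view = np.frombuffer(shared.get_obj(), dtype=dtype).reshape(shape)
    np.copyto(view, np.arange(4.0))              # ShArray.save
    print(np.frombuffer(shared.get_obj(), dtype=dtype).reshape(shape))  # ShArray.get -> [0. 1. 2. 3.]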
@ -1,7 +1,7 @@
from tianshou.exploration.random import BaseNoise, GaussianNoise, OUNoise

__all__ = [
'BaseNoise',
'GaussianNoise',
'OUNoise',
"BaseNoise",
"GaussianNoise",
"OUNoise",
]

@ -1,16 +1,16 @@
import numpy as np
from typing import Union, Optional
from abc import ABC, abstractmethod
from typing import Union, Optional, Sequence


class BaseNoise(ABC, object):
"""The action noise base class."""

def __init__(self, **kwargs) -> None:
def __init__(self) -> None:
super().__init__()

@abstractmethod
def __call__(self, **kwargs) -> np.ndarray:
def __call__(self, size: Sequence[int]) -> np.ndarray:
"""Generate new noise."""
raise NotImplementedError

@ -22,15 +22,13 @@ class BaseNoise(ABC, object):
class GaussianNoise(BaseNoise):
"""The vanilla gaussian process, for exploration in DDPG by default."""

def __init__(self,
mu: float = 0.0,
sigma: float = 1.0):
def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None:
super().__init__()
self._mu = mu
assert 0 <= sigma, 'noise std should not be negative'
assert 0 <= sigma, "Noise std should not be negative."
self._sigma = sigma

def __call__(self, size: tuple) -> np.ndarray:
def __call__(self, size: Sequence[int]) -> np.ndarray:
return np.random.normal(self._mu, self._sigma, size)


@ -51,27 +49,30 @@ class OUNoise(BaseNoise):
Ornstein-Uhlenbeck process.
"""

def __init__(self,
mu: float = 0.0,
sigma: float = 0.3,
theta: float = 0.15,
dt: float = 1e-2,
x0: Optional[Union[float, np.ndarray]] = None
) -> None:
super(BaseNoise, self).__init__()
def __init__(
self,
mu: float = 0.0,
sigma: float = 0.3,
theta: float = 0.15,
dt: float = 1e-2,
x0: Optional[Union[float, np.ndarray]] = None,
) -> None:
super().__init__()
self._mu = mu
self._alpha = theta * dt
self._beta = sigma * np.sqrt(dt)
self._x0 = x0
self.reset()

def __call__(self, size: tuple, mu: Optional[float] = None) -> np.ndarray:
def __call__(
self, size: Sequence[int], mu: Optional[float] = None
) -> np.ndarray:
"""Generate new noise.

Return a ``numpy.ndarray`` which size is equal to ``size``.
Return an numpy array which size is equal to ``size``.
"""
if self._x is None or self._x.shape != size:
self._x = 0
self._x = 0.0
if mu is None:
mu = self._mu
r = self._beta * np.random.normal(size=size)

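With `size` now typed as `Sequence[int]`, both noise classes are called directly with the action batch shape; a short usage sketch:

    import numpy as np
    from tianshou.exploration import GaussianNoise, OUNoise

    act = np.zeros((8, 2))                               # a batch of 8 two-dimensional actions
    act_g = act + GaussianNoise(sigma=0.1)(act.shape)    # i.i.d. perturbation
    ou = OUNoise(sigma=0.3)
    act_o = act + ou(act.shape)                          # temporally correlated across calls
    print(act_g.shape, act_o.shape)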
@ -12,15 +12,15 @@ from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager


__all__ = [
'BasePolicy',
'RandomPolicy',
'ImitationPolicy',
'DQNPolicy',
'PGPolicy',
'A2CPolicy',
'DDPGPolicy',
'PPOPolicy',
'TD3Policy',
'SACPolicy',
'MultiAgentPolicyManager',
"BasePolicy",
"RandomPolicy",
"ImitationPolicy",
"DQNPolicy",
"PGPolicy",
"A2CPolicy",
"DDPGPolicy",
"PPOPolicy",
"TD3Policy",
"SACPolicy",
"MultiAgentPolicyManager",
]

@ -4,7 +4,7 @@ import numpy as np
|
||||
from torch import nn
|
||||
from numba import njit
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Union, Optional, Callable
|
||||
from typing import Any, List, Union, Mapping, Optional, Callable
|
||||
|
||||
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
|
||||
to_torch_as, to_numpy
|
||||
@ -52,23 +52,28 @@ class BasePolicy(ABC, nn.Module):
|
||||
policy.load_state_dict(torch.load("policy.pth"))
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
observation_space: gym.Space = None,
|
||||
action_space: gym.Space = None
|
||||
) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
observation_space: gym.Space = None,
|
||||
action_space: gym.Space = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.agent_id = 0
|
||||
self._compile()
|
||||
|
||||
def set_agent_id(self, agent_id: int) -> None:
|
||||
"""Set self.agent_id = agent_id, for MARL."""
|
||||
self.agent_id = agent_id
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs) -> Batch:
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` which MUST have the following\
|
||||
@ -96,8 +101,9 @@ class BasePolicy(ABC, nn.Module):
|
||||
"""
|
||||
pass
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
|
||||
indice: np.ndarray) -> Batch:
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> Batch:
|
||||
"""Pre-process the data from the provided replay buffer.
|
||||
|
||||
Used in :meth:`update`. Check out :ref:`process_fn` for more
|
||||
@ -106,8 +112,9 @@ class BasePolicy(ABC, nn.Module):
|
||||
return batch
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, batch: Batch, **kwargs
|
||||
) -> Dict[str, Union[float, List[float]]]:
|
||||
def learn(
|
||||
self, batch: Batch, **kwargs: Any
|
||||
) -> Mapping[str, Union[float, List[float]]]:
|
||||
"""Update policy with a given batch of data.
|
||||
|
||||
:return: A dict which includes loss and its corresponding label.
|
||||
@ -123,19 +130,22 @@ class BasePolicy(ABC, nn.Module):
|
||||
"""
|
||||
pass
|
||||
|
||||
def post_process_fn(self, batch: Batch,
|
||||
buffer: ReplayBuffer, indice: np.ndarray) -> None:
|
||||
def post_process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> None:
|
||||
"""Post-process the data from the provided replay buffer.
|
||||
|
||||
Typical usage is to update the sampling weight in prioritized
|
||||
experience replay. Used in :meth:`update`.
|
||||
"""
|
||||
if isinstance(buffer, PrioritizedReplayBuffer) \
|
||||
and hasattr(batch, 'weight'):
|
||||
if isinstance(buffer, PrioritizedReplayBuffer) and hasattr(
|
||||
batch, "weight"
|
||||
):
|
||||
buffer.update_weight(indice, batch.weight)
|
||||
|
||||
def update(self, sample_size: int, buffer: Optional[ReplayBuffer],
|
||||
*args, **kwargs) -> Dict[str, Union[float, List[float]]]:
|
||||
def update(
|
||||
self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any
|
||||
) -> Mapping[str, Union[float, List[float]]]:
|
||||
"""Update the policy network and replay buffer.
|
||||
|
||||
It includes 3 function steps: process_fn, learn, and post_process_fn.
|
||||
@ -148,7 +158,7 @@ class BasePolicy(ABC, nn.Module):
|
||||
return {}
|
||||
batch, indice = buffer.sample(sample_size)
|
||||
batch = self.process_fn(batch, buffer, indice)
|
||||
result = self.learn(batch, *args, **kwargs)
|
||||
result = self.learn(batch, **kwargs)
|
||||
self.post_process_fn(batch, buffer, indice)
|
||||
return result
|
||||
|
||||
@ -182,7 +192,7 @@ class BasePolicy(ABC, nn.Module):
|
||||
rew = batch.rew
|
||||
v_s_ = np.zeros_like(rew) if v_s_ is None else to_numpy(v_s_).flatten()
|
||||
returns = _episodic_return(v_s_, rew, batch.done, gamma, gae_lambda)
|
||||
if rew_norm and not np.isclose(returns.std(), 0, 1e-2):
|
||||
if rew_norm and not np.isclose(returns.std(), 0.0, 1e-2):
|
||||
returns = (returns - returns.mean()) / returns.std()
|
||||
batch.returns = returns
|
||||
return batch
|
||||
@ -231,9 +241,9 @@ class BasePolicy(ABC, nn.Module):
|
||||
bfr = rew[:min(len(buffer), 1000)] # avoid large buffer
|
||||
mean, std = bfr.mean(), bfr.std()
|
||||
if np.isclose(std, 0, 1e-2):
|
||||
mean, std = 0., 1.
|
||||
mean, std = 0.0, 1.0
|
||||
else:
|
||||
mean, std = 0., 1.
|
||||
mean, std = 0.0, 1.0
|
||||
buf_len = len(buffer)
|
||||
terminal = (indice + n_step - 1) % buf_len
|
||||
target_q_torch = target_q_fn(buffer, terminal).flatten() # (bsz, )
|
||||
@ -248,18 +258,30 @@ class BasePolicy(ABC, nn.Module):
|
||||
batch.weight = to_torch_as(batch.weight, target_q_torch)
|
||||
return batch
|
||||
|
||||
def _compile(self) -> None:
|
||||
f64 = np.array([0, 1], dtype=np.float64)
|
||||
f32 = np.array([0, 1], dtype=np.float32)
|
||||
b = np.array([False, True], dtype=np.bool_)
|
||||
i64 = np.array([0, 1], dtype=np.int64)
|
||||
_episodic_return(f64, f64, b, 0.1, 0.1)
|
||||
_episodic_return(f32, f64, b, 0.1, 0.1)
|
||||
_nstep_return(f64, b, f32, i64, 0.1, 1, 4, 1.0, 0.0)
|
||||
|
||||
|
||||
@njit
|
||||
def _episodic_return(
|
||||
v_s_: np.ndarray, rew: np.ndarray, done: np.ndarray,
|
||||
gamma: float, gae_lambda: float,
|
||||
v_s_: np.ndarray,
|
||||
rew: np.ndarray,
|
||||
done: np.ndarray,
|
||||
gamma: float,
|
||||
gae_lambda: float,
|
||||
) -> np.ndarray:
|
||||
"""Numba speedup: 4.1s -> 0.057s."""
|
||||
returns = np.roll(v_s_, 1)
|
||||
m = (1. - done) * gamma
|
||||
m = (1.0 - done) * gamma
|
||||
delta = rew + v_s_ * m - returns
|
||||
m *= gae_lambda
|
||||
gae = 0.
|
||||
gae = 0.0
|
||||
for i in range(len(rew) - 1, -1, -1):
|
||||
gae = delta[i] + m[i] * gae
|
||||
returns[i] += gae
|
||||
@ -268,9 +290,15 @@ def _episodic_return(
|
||||
|
||||
@njit
|
||||
def _nstep_return(
|
||||
rew: np.ndarray, done: np.ndarray, target_q: np.ndarray,
|
||||
indice: np.ndarray, gamma: float, n_step: int, buf_len: int,
|
||||
mean: float, std: float
|
||||
rew: np.ndarray,
|
||||
done: np.ndarray,
|
||||
target_q: np.ndarray,
|
||||
indice: np.ndarray,
|
||||
gamma: float,
|
||||
n_step: int,
|
||||
buf_len: int,
|
||||
mean: float,
|
||||
std: float,
|
||||
) -> np.ndarray:
|
||||
"""Numba speedup: 0.3s -> 0.15s."""
|
||||
returns = np.zeros(indice.shape)
|
||||
@ -278,8 +306,8 @@ def _nstep_return(
|
||||
for n in range(n_step - 1, -1, -1):
|
||||
now = (indice + n) % buf_len
|
||||
gammas[done[now] > 0] = n
|
||||
returns[done[now] > 0] = 0.
|
||||
returns[done[now] > 0] = 0.0
|
||||
returns = (rew[now] - mean) / std + gamma * returns
|
||||
target_q[gammas != n_step] = 0
|
||||
target_q[gammas != n_step] = 0.0
|
||||
target_q = target_q * (gamma ** gammas) + returns
|
||||
return target_q
|
||||
|
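The reshaped `_episodic_return` signature is unchanged in meaning; importing the private helper purely for illustration, a tiny worked GAE example (assumes tianshou and numba are installed):

    import numpy as np
    from tianshou.policy.base import _episodic_return

    rew = np.array([1.0, 1.0, 1.0], dtype=np.float64)
    done = np.array([False, False, True])
    v_s_ = np.zeros(3, dtype=np.float64)     # bootstrap values V(s_{t+1})
    print(_episodic_return(v_s_, rew, done, 0.99, 0.95))
    # discounted lambda-returns, largest for the earliest transition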
@ -1,7 +1,7 @@
import torch
import numpy as np
import torch.nn.functional as F
from typing import Dict, Union, Optional
from typing import Any, Dict, Union, Optional

from tianshou.data import Batch, to_torch
from tianshou.policy import BasePolicy
@ -22,36 +22,44 @@ class ImitationPolicy(BasePolicy):
explanation.
"""

def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer,
mode: str = 'continuous') -> None:
super().__init__()
def __init__(
self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
mode: str = "continuous",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.model = model
self.optim = optim
assert mode in ['continuous', 'discrete'], \
f'Mode {mode} is not in ["continuous", "discrete"]'
assert (
mode in ["continuous", "discrete"]
), f"Mode {mode} is not in ['continuous', 'discrete']."
self.mode = mode

def forward(self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs) -> Batch:
def forward(
self,
batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None,
**kwargs: Any,
) -> Batch:
logits, h = self.model(batch.obs, state=state, info=batch.info)
if self.mode == 'discrete':
if self.mode == "discrete":
a = logits.max(dim=1)[1]
else:
a = logits
return Batch(logits=logits, act=a, state=h)

def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
self.optim.zero_grad()
if self.mode == 'continuous':
if self.mode == "continuous":  # regression
a = self(batch).act
a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
loss = F.mse_loss(a, a_)
elif self.mode == 'discrete':  # classification
elif self.mode == "discrete":  # classification
a = self(batch).logits
a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
loss = F.nll_loss(a, a_)
loss.backward()
self.optim.step()
return {'loss': loss.item()}
return {"loss": loss.item()}

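A minimal discrete-mode usage sketch of the keyword-friendly constructor; the tiny `Net` here is hypothetical and only has to follow the `(obs, state, info) -> (logits, state)` convention, returning log-probabilities so that `F.nll_loss` applies:

    import torch
    import numpy as np
    from tianshou.data import Batch
    from tianshou.policy import ImitationPolicy

    class Net(torch.nn.Module):
        def __init__(self, state_dim: int = 4, action_num: int = 2) -> None:
            super().__init__()
            self.fc = torch.nn.Linear(state_dim, action_num)

        def forward(self, obs, state=None, info={}):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            return torch.log_softmax(self.fc(obs), dim=-1), state

    net = Net()
    policy = ImitationPolicy(net, torch.optim.Adam(net.parameters(), lr=1e-3), mode="discrete")
    batch = Batch(obs=np.random.rand(8, 4), act=np.random.randint(0, 2, size=8), info={})
    print(policy.learn(batch))    # {'loss': ...}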
@ -2,7 +2,7 @@ import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, List, Union, Optional
|
||||
from typing import Any, Dict, List, Union, Optional, Callable
|
||||
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
|
||||
@ -17,6 +17,7 @@ class A2CPolicy(PGPolicy):
|
||||
:param torch.optim.Optimizer optim: the optimizer for actor and critic
|
||||
network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
:type dist_fn: Callable[[], torch.distributions.Distribution]
|
||||
:param float discount_factor: in [0, 1], defaults to 0.99.
|
||||
:param float vf_coef: weight for value loss, defaults to 0.5.
|
||||
:param float ent_coef: weight for entropy loss, defaults to 0.01.
|
||||
@ -37,23 +38,25 @@ class A2CPolicy(PGPolicy):
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: torch.distributions.Distribution,
|
||||
discount_factor: float = 0.99,
|
||||
vf_coef: float = .5,
|
||||
ent_coef: float = .01,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
gae_lambda: float = 0.95,
|
||||
reward_normalization: bool = False,
|
||||
max_batchsize: int = 256,
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: Callable[[], torch.distributions.Distribution],
|
||||
discount_factor: float = 0.99,
|
||||
vf_coef: float = 0.5,
|
||||
ent_coef: float = 0.01,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
gae_lambda: float = 0.95,
|
||||
reward_normalization: bool = False,
|
||||
max_batchsize: int = 256,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
|
||||
assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
|
||||
self._lambda = gae_lambda
|
||||
self._w_vf = vf_coef
|
||||
self._w_ent = ent_coef
|
||||
@ -61,9 +64,10 @@ class A2CPolicy(PGPolicy):
|
||||
self._batch = max_batchsize
|
||||
self._rew_norm = reward_normalization
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
|
||||
indice: np.ndarray) -> Batch:
|
||||
if self._lambda in [0, 1]:
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> Batch:
|
||||
if self._lambda in [0.0, 1.0]:
|
||||
return self.compute_episodic_return(
|
||||
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
||||
v_ = []
|
||||
@ -75,9 +79,12 @@ class A2CPolicy(PGPolicy):
|
||||
batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
|
||||
rew_norm=self._rew_norm)
|
||||
|
||||
def forward(self, batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs) -> Batch:
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs: Any
|
||||
) -> Batch:
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
|
||||
@ -100,8 +107,9 @@ class A2CPolicy(PGPolicy):
|
||||
act = dist.sample()
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(self, batch: Batch, batch_size: int, repeat: int,
|
||||
**kwargs) -> Dict[str, List[float]]:
|
||||
def learn(
|
||||
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
|
||||
) -> Dict[str, List[float]]:
|
||||
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size, merge_last=True):
|
||||
@ -110,8 +118,7 @@ class A2CPolicy(PGPolicy):
|
||||
v = self.critic(b.obs).flatten()
|
||||
a = to_torch_as(b.act, v)
|
||||
r = to_torch_as(b.returns, v)
|
||||
log_prob = dist.log_prob(a).reshape(
|
||||
r.shape[0], -1).transpose(0, 1)
|
||||
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
|
||||
a_loss = -(log_prob * (r - v).detach()).mean()
|
||||
vf_loss = F.mse_loss(r, v)
|
||||
ent_loss = dist.entropy().mean()
|
||||
@ -119,17 +126,18 @@ class A2CPolicy(PGPolicy):
|
||||
loss.backward()
|
||||
if self._grad_norm is not None:
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) +
|
||||
list(self.critic.parameters()),
|
||||
max_norm=self._grad_norm)
|
||||
list(self.actor.parameters())
|
||||
+ list(self.critic.parameters()),
|
||||
max_norm=self._grad_norm,
|
||||
)
|
||||
self.optim.step()
|
||||
actor_losses.append(a_loss.item())
|
||||
vf_losses.append(vf_loss.item())
|
||||
ent_losses.append(ent_loss.item())
|
||||
losses.append(loss.item())
|
||||
return {
|
||||
'loss': losses,
|
||||
'loss/actor': actor_losses,
|
||||
'loss/vf': vf_losses,
|
||||
'loss/ent': ent_losses,
|
||||
"loss": losses,
|
||||
"loss/actor": actor_losses,
|
||||
"loss/vf": vf_losses,
|
||||
"loss/ent": ent_losses,
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Tuple, Union, Optional
|
||||
from typing import Any, Dict, Tuple, Union, Optional
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.exploration import BaseNoise, GaussianNoise
|
||||
@ -17,13 +17,13 @@ class DDPGPolicy(BasePolicy):
|
||||
:param torch.nn.Module critic: the critic network. (s, a -> Q(s, a))
|
||||
:param torch.optim.Optimizer critic_optim: the optimizer for critic
|
||||
network.
|
||||
:param action_range: the action range (minimum, maximum).
|
||||
:type action_range: Tuple[float, float]
|
||||
:param float tau: param for soft update of the target network, defaults to
|
||||
0.005.
|
||||
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
|
||||
:param BaseNoise exploration_noise: the exploration noise,
|
||||
add to the action, defaults to ``GaussianNoise(sigma=0.1)``.
|
||||
:param action_range: the action range (minimum, maximum).
|
||||
:type action_range: (float, float)
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1),
|
||||
defaults to False.
|
||||
:param bool ignore_done: ignore the done flag while training the policy,
|
||||
@ -37,20 +37,21 @@ class DDPGPolicy(BasePolicy):
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
actor: torch.nn.Module,
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critic: torch.nn.Module,
|
||||
critic_optim: torch.optim.Optimizer,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
exploration_noise: Optional[BaseNoise]
|
||||
= GaussianNoise(sigma=0.1),
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
reward_normalization: bool = False,
|
||||
ignore_done: bool = False,
|
||||
estimation_step: int = 1,
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
actor: Optional[torch.nn.Module],
|
||||
actor_optim: Optional[torch.optim.Optimizer],
|
||||
critic: Optional[torch.nn.Module],
|
||||
critic_optim: Optional[torch.optim.Optimizer],
|
||||
action_range: Tuple[float, float],
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
|
||||
reward_normalization: bool = False,
|
||||
ignore_done: bool = False,
|
||||
estimation_step: int = 1,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
if actor is not None:
|
||||
self.actor, self.actor_old = actor, deepcopy(actor)
|
||||
@ -60,27 +61,26 @@ class DDPGPolicy(BasePolicy):
|
||||
self.critic, self.critic_old = critic, deepcopy(critic)
|
||||
self.critic_old.eval()
|
||||
self.critic_optim = critic_optim
|
||||
assert 0 <= tau <= 1, 'tau should in [0, 1]'
|
||||
assert 0.0 <= tau <= 1.0, "tau should be in [0, 1]"
|
||||
self._tau = tau
|
||||
assert 0 <= gamma <= 1, 'gamma should in [0, 1]'
|
||||
assert 0.0 <= gamma <= 1.0, "gamma should be in [0, 1]"
|
||||
self._gamma = gamma
|
||||
self._noise = exploration_noise
|
||||
assert action_range is not None
|
||||
self._range = action_range
|
||||
self._action_bias = (action_range[0] + action_range[1]) / 2
|
||||
self._action_scale = (action_range[1] - action_range[0]) / 2
|
||||
# it is only a little difference to use rand_normal
|
||||
self._action_bias = (action_range[0] + action_range[1]) / 2.0
|
||||
self._action_scale = (action_range[1] - action_range[0]) / 2.0
|
||||
# it is only a little difference to use GaussianNoise
|
||||
# self.noise = OUNoise()
|
||||
self._rm_done = ignore_done
|
||||
self._rew_norm = reward_normalization
|
||||
assert estimation_step > 0, 'estimation_step should greater than 0'
|
||||
assert estimation_step > 0, "estimation_step should be greater than 0"
|
||||
self._n_step = estimation_step
|
||||
|
||||
def set_exp_noise(self, noise: Optional[BaseNoise]) -> None:
|
||||
"""Set the exploration noise."""
|
||||
self._noise = noise
|
||||
|
||||
def train(self, mode=True) -> torch.nn.Module:
|
||||
def train(self, mode: bool = True) -> "DDPGPolicy":
|
||||
"""Set the module in training mode, except for the target network."""
|
||||
self.training = mode
|
||||
self.actor.train(mode)
|
||||
@ -90,13 +90,15 @@ class DDPGPolicy(BasePolicy):
|
||||
def sync_weight(self) -> None:
|
||||
"""Soft-update the weight for the target network."""
|
||||
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
|
||||
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
|
||||
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
|
||||
for o, n in zip(
|
||||
self.critic_old.parameters(), self.critic.parameters()):
|
||||
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
|
||||
self.critic_old.parameters(), self.critic.parameters()
|
||||
):
|
||||
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer,
|
||||
indice: np.ndarray) -> torch.Tensor:
|
||||
def _target_q(
|
||||
self, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> torch.Tensor:
|
||||
batch = buffer[indice] # batch.obs_next: s_{t+n}
|
||||
with torch.no_grad():
|
||||
target_q = self.critic_old(batch.obs_next, self(
|
||||
@ -104,21 +106,25 @@ class DDPGPolicy(BasePolicy):
|
||||
explorating=False).act)
|
||||
return target_q
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
|
||||
indice: np.ndarray) -> Batch:
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> Batch:
|
||||
if self._rm_done:
|
||||
batch.done = batch.done * 0.
|
||||
batch.done = batch.done * 0.0
|
||||
batch = self.compute_nstep_return(
|
||||
batch, buffer, indice, self._target_q,
|
||||
self._gamma, self._n_step, self._rew_norm)
|
||||
return batch
|
||||
|
||||
def forward(self, batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
model: str = 'actor',
|
||||
input: str = 'obs',
|
||||
explorating: bool = True,
|
||||
**kwargs) -> Batch:
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
model: str = "actor",
|
||||
input: str = "obs",
|
||||
explorating: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` which has 2 keys:
|
||||
@ -140,8 +146,8 @@ class DDPGPolicy(BasePolicy):
|
||||
actions = actions.clamp(self._range[0], self._range[1])
|
||||
return Batch(act=actions, state=h)
|
||||
|
||||
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
|
||||
weight = batch.pop('weight', 1.)
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
weight = batch.pop("weight", 1.0)
|
||||
current_q = self.critic(batch.obs, batch.act).flatten()
|
||||
target_q = batch.returns.flatten()
|
||||
td = current_q - target_q
|
||||
@ -157,6 +163,6 @@ class DDPGPolicy(BasePolicy):
|
||||
self.actor_optim.step()
|
||||
self.sync_weight()
|
||||
return {
|
||||
'loss/actor': actor_loss.item(),
|
||||
'loss/critic': critic_loss.item(),
|
||||
"loss/actor": actor_loss.item(),
|
||||
"loss/critic": critic_loss.item(),
|
||||
}
|
||||
|
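For reference, the `_action_bias`/`_action_scale` pair computed in the constructor is the usual affine map between a normalized action in [-1, 1] and the environment's [low, high]; a quick numeric check:

    low, high = -2.0, 2.0
    bias, scale = (low + high) / 2.0, (high - low) / 2.0
    for a in (-1.0, 0.0, 1.0):
        print(a, "->", bias + scale * a)   # -2.0, 0.0, 2.0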
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Union, Optional
|
||||
from typing import Any, Dict, Union, Optional
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
|
||||
@ -32,21 +32,25 @@ class DQNPolicy(BasePolicy):
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
discount_factor: float = 0.99,
|
||||
estimation_step: int = 1,
|
||||
target_update_freq: int = 0,
|
||||
reward_normalization: bool = False,
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
discount_factor: float = 0.99,
|
||||
estimation_step: int = 1,
|
||||
target_update_freq: int = 0,
|
||||
reward_normalization: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.model = model
|
||||
self.optim = optim
|
||||
self.eps = 0
|
||||
assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
|
||||
self.eps = 0.0
|
||||
assert (
|
||||
0.0 <= discount_factor <= 1.0
|
||||
), "discount factor should be in [0, 1]"
|
||||
self._gamma = discount_factor
|
||||
assert estimation_step > 0, 'estimation_step should greater than 0'
|
||||
assert estimation_step > 0, "estimation_step should be greater than 0"
|
||||
self._n_step = estimation_step
|
||||
self._target = target_update_freq > 0
|
||||
self._freq = target_update_freq
|
||||
@ -60,7 +64,7 @@ class DQNPolicy(BasePolicy):
|
||||
"""Set the eps for epsilon-greedy exploration."""
|
||||
self.eps = eps
|
||||
|
||||
def train(self, mode=True) -> torch.nn.Module:
|
||||
def train(self, mode: bool = True) -> "DQNPolicy":
|
||||
"""Set the module in training mode, except for the target network."""
|
||||
self.training = mode
|
||||
self.model.train(mode)
|
||||
@ -70,23 +74,26 @@ class DQNPolicy(BasePolicy):
|
||||
"""Synchronize the weight for the target network."""
|
||||
self.model_old.load_state_dict(self.model.state_dict())
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer,
|
||||
indice: np.ndarray) -> torch.Tensor:
|
||||
def _target_q(
|
||||
self, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> torch.Tensor:
|
||||
batch = buffer[indice] # batch.obs_next: s_{t+n}
|
||||
if self._target:
|
||||
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
|
||||
a = self(batch, input='obs_next', eps=0).act
|
||||
a = self(batch, input="obs_next", eps=0).act
|
||||
with torch.no_grad():
|
||||
target_q = self(
|
||||
batch, model='model_old', input='obs_next').logits
|
||||
batch, model="model_old", input="obs_next"
|
||||
).logits
|
||||
target_q = target_q[np.arange(len(a)), a]
|
||||
else:
|
||||
with torch.no_grad():
|
||||
target_q = self(batch, input='obs_next').logits.max(dim=1)[0]
|
||||
target_q = self(batch, input="obs_next").logits.max(dim=1)[0]
|
||||
return target_q
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
|
||||
indice: np.ndarray) -> Batch:
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> Batch:
|
||||
"""Compute the n-step return for Q-learning targets.
|
||||
|
||||
More details can be found at
|
||||
@ -97,12 +104,15 @@ class DQNPolicy(BasePolicy):
|
||||
self._gamma, self._n_step, self._rew_norm)
|
||||
return batch
|
||||
|
||||
def forward(self, batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
model: str = 'model',
|
||||
input: str = 'obs',
|
||||
eps: Optional[float] = None,
|
||||
**kwargs) -> Batch:
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
model: str = "model",
|
||||
input: str = "obs",
|
||||
eps: Optional[float] = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
If you need to mask the action, please add a "mask" into batch.obs, for
|
||||
@ -134,7 +144,7 @@ class DQNPolicy(BasePolicy):
|
||||
"""
|
||||
model = getattr(self, model)
|
||||
obs = getattr(batch, input)
|
||||
obs_ = obs.obs if hasattr(obs, 'obs') else obs
|
||||
obs_ = obs.obs if hasattr(obs, "obs") else obs
|
||||
q, h = model(obs_, state=state, info=batch.info)
|
||||
act = to_numpy(q.max(dim=1)[1])
|
||||
has_mask = hasattr(obs, 'mask')
|
||||
@ -146,7 +156,7 @@ class DQNPolicy(BasePolicy):
|
||||
# add eps to act
|
||||
if eps is None:
|
||||
eps = self.eps
|
||||
if not np.isclose(eps, 0):
|
||||
if not np.isclose(eps, 0.0):
|
||||
for i in range(len(q)):
|
||||
if np.random.rand() < eps:
|
||||
q_ = np.random.rand(*q[i].shape)
|
||||
@ -155,12 +165,12 @@ class DQNPolicy(BasePolicy):
|
||||
act[i] = q_.argmax()
|
||||
return Batch(logits=q, act=act, state=h)
|
||||
|
||||
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
if self._target and self._cnt % self._freq == 0:
|
||||
self.sync_weight()
|
||||
self.optim.zero_grad()
|
||||
weight = batch.pop('weight', 1.)
|
||||
q = self(batch, eps=0.).logits
|
||||
weight = batch.pop("weight", 1.0)
|
||||
q = self(batch, eps=0.0).logits
|
||||
q = q[np.arange(len(q)), batch.act]
|
||||
r = to_torch_as(batch.returns, q).flatten()
|
||||
td = r - q
|
||||
@ -169,4 +179,4 @@ class DQNPolicy(BasePolicy):
|
||||
loss.backward()
|
||||
self.optim.step()
|
||||
self._cnt += 1
|
||||
return {'loss': loss.item()}
|
||||
return {"loss": loss.item()}
|
||||
|
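The `eps` handling above draws a fresh uniform Q vector for roughly an `eps` fraction of the batch and takes its argmax; a standalone sketch of that selection rule (action mask omitted for brevity):

    import numpy as np

    def eps_greedy(q: np.ndarray, eps: float) -> np.ndarray:
        act = q.argmax(axis=1)                 # greedy action per row
        for i in range(len(q)):
            if np.random.rand() < eps:         # exploration branch, as in forward()
                act[i] = np.random.rand(*q[i].shape).argmax()
        return act

    print(eps_greedy(np.random.rand(5, 3), eps=0.1))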
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Dict, List, Union, Optional
|
||||
from typing import Any, Dict, List, Union, Optional, Callable
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch_as
|
||||
@ -13,6 +13,7 @@ class PGPolicy(BasePolicy):
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
:type dist_fn: Callable[[], torch.distributions.Distribution]
|
||||
:param float discount_factor: in [0, 1].
|
||||
|
||||
.. seealso::
|
||||
@ -21,23 +22,28 @@ class PGPolicy(BasePolicy):
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: torch.distributions.Distribution,
|
||||
discount_factor: float = 0.99,
|
||||
reward_normalization: bool = False,
|
||||
**kwargs) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[torch.nn.Module],
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: Callable[[], torch.distributions.Distribution],
|
||||
discount_factor: float = 0.99,
|
||||
reward_normalization: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.model = model
|
||||
self.optim = optim
|
||||
self.dist_fn = dist_fn
|
||||
assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]'
|
||||
assert (
|
||||
0.0 <= discount_factor <= 1.0
|
||||
), "discount factor should be in [0, 1]"
|
||||
self._gamma = discount_factor
|
||||
self._rew_norm = reward_normalization
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
|
||||
indice: np.ndarray) -> Batch:
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> Batch:
|
||||
r"""Compute the discounted returns for each frame.
|
||||
|
||||
.. math::
|
||||
@ -48,13 +54,15 @@ class PGPolicy(BasePolicy):
|
||||
"""
|
||||
# batch.returns = self._vanilla_returns(batch)
|
||||
# batch.returns = self._vectorized_returns(batch)
|
||||
# return batch
|
||||
return self.compute_episodic_return(
|
||||
batch, gamma=self._gamma, gae_lambda=1., rew_norm=self._rew_norm)
|
||||
batch, gamma=self._gamma, gae_lambda=1.0, rew_norm=self._rew_norm)
|
||||
|
||||
def forward(self, batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs) -> Batch:
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
|
||||
@ -77,8 +85,9 @@ class PGPolicy(BasePolicy):
|
||||
act = dist.sample()
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(self, batch: Batch, batch_size: int, repeat: int,
|
||||
**kwargs) -> Dict[str, List[float]]:
|
||||
def learn(
|
||||
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
|
||||
) -> Dict[str, List[float]]:
|
||||
losses = []
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size, merge_last=True):
|
||||
@ -86,13 +95,12 @@ class PGPolicy(BasePolicy):
|
||||
dist = self(b).dist
|
||||
a = to_torch_as(b.act, dist.logits)
|
||||
r = to_torch_as(b.returns, dist.logits)
|
||||
log_prob = dist.log_prob(a).reshape(
|
||||
r.shape[0], -1).transpose(0, 1)
|
||||
log_prob = dist.log_prob(a).reshape(len(r), -1).transpose(0, 1)
|
||||
loss = -(log_prob * r).mean()
|
||||
loss.backward()
|
||||
self.optim.step()
|
||||
losses.append(loss.item())
|
||||
return {'loss': losses}
|
||||
return {"loss": losses}
|
||||
|
||||
# def _vanilla_returns(self, batch):
|
||||
# returns = batch.rew[:]
|
||||
|
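The objective optimized in PGPolicy.learn above is the plain REINFORCE loss; a minimal self-contained sketch with fake data (assumed discrete actions):

import torch
from torch.distributions import Categorical

logits = torch.randn(8, 3, requires_grad=True)  # fake policy logits
act = torch.randint(0, 3, (8,))                 # fake sampled actions
returns = torch.randn(8)                        # fake discounted returns
dist = Categorical(logits=logits)
loss = -(dist.log_prob(act) * returns).mean()   # same form as learn() above
loss.backward()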
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from typing import Dict, List, Tuple, Union, Optional
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional, Callable
|
||||
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
|
||||
@ -16,6 +16,7 @@ class PPOPolicy(PGPolicy):
|
||||
:param torch.optim.Optimizer optim: the optimizer for actor and critic
|
||||
network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
:type dist_fn: Callable[[], torch.distributions.Distribution]
|
||||
:param float discount_factor: in [0, 1], defaults to 0.99.
|
||||
:param float max_grad_norm: clipping gradients in back propagation,
|
||||
defaults to None.
|
||||
@ -45,24 +46,26 @@ class PPOPolicy(PGPolicy):
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: torch.distributions.Distribution,
|
||||
discount_factor: float = 0.99,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
eps_clip: float = .2,
|
||||
vf_coef: float = .5,
|
||||
ent_coef: float = .01,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
gae_lambda: float = 0.95,
|
||||
dual_clip: Optional[float] = None,
|
||||
value_clip: bool = True,
|
||||
reward_normalization: bool = True,
|
||||
max_batchsize: int = 256,
|
||||
**kwargs) -> None:
|
||||
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: Callable[[], torch.distributions.Distribution],
|
||||
discount_factor: float = 0.99,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
eps_clip: float = 0.2,
|
||||
vf_coef: float = 0.5,
|
||||
ent_coef: float = 0.01,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
gae_lambda: float = 0.95,
|
||||
dual_clip: Optional[float] = None,
|
||||
value_clip: bool = True,
|
||||
reward_normalization: bool = True,
|
||||
max_batchsize: int = 256,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
|
||||
self._max_grad_norm = max_grad_norm
|
||||
self._eps_clip = eps_clip
|
||||
self._w_vf = vf_coef
|
||||
@ -70,29 +73,31 @@ class PPOPolicy(PGPolicy):
|
||||
self._range = action_range
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
self.optim = optim
|
||||
self._batch = max_batchsize
|
||||
assert 0 <= gae_lambda <= 1, 'GAE lambda should be in [0, 1].'
|
||||
assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]."
|
||||
self._lambda = gae_lambda
|
||||
assert dual_clip is None or dual_clip > 1, \
|
||||
'Dual-clip PPO parameter should greater than 1.'
|
||||
assert (
|
||||
dual_clip is None or dual_clip > 1.0
|
||||
), "Dual-clip PPO parameter should greater than 1.0."
|
||||
self._dual_clip = dual_clip
|
||||
self._value_clip = value_clip
|
||||
self._rew_norm = reward_normalization
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
|
||||
indice: np.ndarray) -> Batch:
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> Batch:
|
||||
if self._rew_norm:
|
||||
mean, std = batch.rew.mean(), batch.rew.std()
|
||||
if not np.isclose(std, 0, 1e-2):
|
||||
if not np.isclose(std, 0.0, 1e-2):
|
||||
batch.rew = (batch.rew - mean) / std
|
||||
v, v_, old_log_prob = [], [], []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(self._batch, shuffle=False, merge_last=True):
|
||||
v_.append(self.critic(b.obs_next))
|
||||
v.append(self.critic(b.obs))
|
||||
old_log_prob.append(self(b).dist.log_prob(
|
||||
to_torch_as(b.act, v[0])))
|
||||
old_log_prob.append(
|
||||
self(b).dist.log_prob(to_torch_as(b.act, v[0]))
|
||||
)
|
||||
v_ = to_numpy(torch.cat(v_, dim=0))
|
||||
batch = self.compute_episodic_return(
|
||||
batch, v_, gamma=self._gamma, gae_lambda=self._lambda,
|
||||
@ -104,13 +109,16 @@ class PPOPolicy(PGPolicy):
|
||||
batch.adv = batch.returns - batch.v
|
||||
if self._rew_norm:
|
||||
mean, std = batch.adv.mean(), batch.adv.std()
|
||||
if not np.isclose(std.item(), 0, 1e-2):
|
||||
if not np.isclose(std.item(), 0.0, 1e-2):
|
||||
batch.adv = (batch.adv - mean) / std
|
||||
return batch
|
||||
|
||||
def forward(self, batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs) -> Batch:
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` which has 4 keys:
|
||||
@ -135,8 +143,9 @@ class PPOPolicy(PGPolicy):
|
||||
act = act.clamp(self._range[0], self._range[1])
|
||||
return Batch(logits=logits, act=act, state=h, dist=dist)
|
||||
|
||||
def learn(self, batch: Batch, batch_size: int, repeat: int,
|
||||
**kwargs) -> Dict[str, List[float]]:
|
||||
def learn(
|
||||
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
|
||||
) -> Dict[str, List[float]]:
|
||||
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
|
||||
for _ in range(repeat):
|
||||
for b in batch.split(batch_size, merge_last=True):
|
||||
@ -145,8 +154,8 @@ class PPOPolicy(PGPolicy):
|
||||
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
|
||||
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
|
||||
surr1 = ratio * b.adv
|
||||
surr2 = ratio.clamp(1. - self._eps_clip,
|
||||
1. + self._eps_clip) * b.adv
|
||||
surr2 = ratio.clamp(1.0 - self._eps_clip,
|
||||
1.0 + self._eps_clip) * b.adv
|
||||
if self._dual_clip:
|
||||
clip_loss = -torch.max(torch.min(surr1, surr2),
|
||||
self._dual_clip * b.adv).mean()
|
||||
@ -158,9 +167,9 @@ class PPOPolicy(PGPolicy):
|
||||
-self._eps_clip, self._eps_clip)
|
||||
vf1 = (b.returns - value).pow(2)
|
||||
vf2 = (b.returns - v_clip).pow(2)
|
||||
vf_loss = .5 * torch.max(vf1, vf2).mean()
|
||||
vf_loss = 0.5 * torch.max(vf1, vf2).mean()
|
||||
else:
|
||||
vf_loss = .5 * (b.returns - value).pow(2).mean()
|
||||
vf_loss = 0.5 * (b.returns - value).pow(2).mean()
|
||||
vf_losses.append(vf_loss.item())
|
||||
e_loss = dist.entropy().mean()
|
||||
ent_losses.append(e_loss.item())
|
||||
@ -168,13 +177,14 @@ class PPOPolicy(PGPolicy):
|
||||
losses.append(loss.item())
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(list(
|
||||
self.actor.parameters()) + list(self.critic.parameters()),
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters())
|
||||
+ list(self.critic.parameters()),
|
||||
self._max_grad_norm)
|
||||
self.optim.step()
|
||||
return {
|
||||
'loss': losses,
|
||||
'loss/clip': clip_losses,
|
||||
'loss/vf': vf_losses,
|
||||
'loss/ent': ent_losses,
|
||||
"loss": losses,
|
||||
"loss/clip": clip_losses,
|
||||
"loss/vf": vf_losses,
|
||||
"loss/ent": ent_losses,
|
||||
}
|
||||
|
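The clipped surrogate in PPOPolicy.learn above, isolated as a short sketch with fake tensors (no dual clip, no value loss):

import torch

eps_clip = 0.2                                  # same default as above
adv = torch.randn(16)                           # fake advantages
ratio = torch.exp(0.1 * torch.randn(16))        # fake pi_new / pi_old
surr1 = ratio * adv
surr2 = ratio.clamp(1.0 - eps_clip, 1.0 + eps_clip) * adv
clip_loss = -torch.min(surr1, surr2).mean()     # standard PPO-clip branch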
@ -1,12 +1,12 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Tuple, Union, Optional
|
||||
from torch.distributions import Normal, Independent
|
||||
from torch.distributions import Independent, Normal
|
||||
from typing import Any, Dict, Tuple, Union, Optional
|
||||
|
||||
from tianshou.policy import DDPGPolicy
|
||||
from tianshou.data import Batch, to_torch_as, ReplayBuffer
|
||||
from tianshou.exploration import BaseNoise
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch_as
|
||||
|
||||
|
||||
class SACPolicy(DDPGPolicy):
|
||||
@ -23,6 +23,8 @@ class SACPolicy(DDPGPolicy):
|
||||
a))
|
||||
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
|
||||
critic network.
|
||||
:param action_range: the action range (minimum, maximum).
|
||||
:type action_range: Tuple[float, float]
|
||||
:param float tau: param for soft update of the target network, defaults to
|
||||
0.005.
|
||||
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
|
||||
@ -32,8 +34,6 @@ class SACPolicy(DDPGPolicy):
|
||||
regularization coefficient, default to 0.2.
|
||||
If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
|
||||
alpha is automatically tuned.
|
||||
:param action_range: the action range (minimum, maximum).
|
||||
:type action_range: (float, float)
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1),
|
||||
defaults to False.
|
||||
:param bool ignore_done: ignore the done flag while training the policy,
|
||||
@ -55,20 +55,20 @@ class SACPolicy(DDPGPolicy):
|
||||
critic1_optim: torch.optim.Optimizer,
|
||||
critic2: torch.nn.Module,
|
||||
critic2_optim: torch.optim.Optimizer,
|
||||
action_range: Tuple[float, float],
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
alpha: Union[
|
||||
float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
|
||||
] = 0.2,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
reward_normalization: bool = False,
|
||||
ignore_done: bool = False,
|
||||
estimation_step: int = 1,
|
||||
exploration_noise: Optional[BaseNoise] = None,
|
||||
**kwargs
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(None, None, None, None, tau, gamma, exploration_noise,
|
||||
action_range, reward_normalization, ignore_done,
|
||||
super().__init__(None, None, None, None, action_range, tau, gamma,
|
||||
exploration_noise, reward_normalization, ignore_done,
|
||||
estimation_step, **kwargs)
|
||||
self.actor, self.actor_optim = actor, actor_optim
|
||||
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
|
||||
@ -79,6 +79,7 @@ class SACPolicy(DDPGPolicy):
|
||||
self.critic2_optim = critic2_optim
|
||||
|
||||
self._is_auto_alpha = False
|
||||
self._alpha: Union[float, torch.Tensor]
|
||||
if isinstance(alpha, tuple):
|
||||
self._is_auto_alpha = True
|
||||
self._target_entropy, self._log_alpha, self._alpha_optim = alpha
|
||||
@ -89,7 +90,7 @@ class SACPolicy(DDPGPolicy):
|
||||
|
||||
self.__eps = np.finfo(np.float32).eps.item()
|
||||
|
||||
def train(self, mode=True) -> torch.nn.Module:
|
||||
def train(self, mode: bool = True) -> "SACPolicy":
|
||||
self.training = mode
|
||||
self.actor.train(mode)
|
||||
self.critic1.train(mode)
|
||||
@ -98,17 +99,22 @@ class SACPolicy(DDPGPolicy):
|
||||
|
||||
def sync_weight(self) -> None:
|
||||
for o, n in zip(
|
||||
self.critic1_old.parameters(), self.critic1.parameters()):
|
||||
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
|
||||
self.critic1_old.parameters(), self.critic1.parameters()
|
||||
):
|
||||
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
|
||||
for o, n in zip(
|
||||
self.critic2_old.parameters(), self.critic2.parameters()):
|
||||
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
|
||||
self.critic2_old.parameters(), self.critic2.parameters()
|
||||
):
|
||||
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
|
||||
|
||||
def forward(self, batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
input: str = 'obs',
|
||||
explorating: bool = True,
|
||||
**kwargs) -> Batch:
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
input: str = "obs",
|
||||
explorating: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
obs = getattr(batch, input)
|
||||
logits, h = self.actor(obs, state=state, info=batch.info)
|
||||
assert isinstance(logits, tuple)
|
||||
@ -125,8 +131,9 @@ class SACPolicy(DDPGPolicy):
|
||||
return Batch(
|
||||
logits=logits, act=act, state=h, dist=dist, log_prob=log_prob)
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer,
|
||||
indice: np.ndarray) -> torch.Tensor:
|
||||
def _target_q(
|
||||
self, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> torch.Tensor:
|
||||
batch = buffer[indice] # batch.obs: s_{t+n}
|
||||
with torch.no_grad():
|
||||
obs_next_result = self(batch, input='obs_next', explorating=False)
|
||||
@ -138,8 +145,8 @@ class SACPolicy(DDPGPolicy):
|
||||
) - self._alpha * obs_next_result.log_prob
|
||||
return target_q
|
||||
|
||||
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
|
||||
weight = batch.pop('weight', 1.)
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
weight = batch.pop("weight", 1.0)
|
||||
# critic 1
|
||||
current_q1 = self.critic1(batch.obs, batch.act).flatten()
|
||||
target_q = batch.returns.flatten()
|
||||
@ -157,7 +164,7 @@ class SACPolicy(DDPGPolicy):
|
||||
self.critic2_optim.zero_grad()
|
||||
critic2_loss.backward()
|
||||
self.critic2_optim.step()
|
||||
batch.weight = (td1 + td2) / 2. # prio-buffer
|
||||
batch.weight = (td1 + td2) / 2.0 # prio-buffer
|
||||
# actor
|
||||
obs_result = self(batch, explorating=False)
|
||||
a = obs_result.act
|
||||
@ -180,11 +187,11 @@ class SACPolicy(DDPGPolicy):
|
||||
self.sync_weight()
|
||||
|
||||
result = {
|
||||
'loss/actor': actor_loss.item(),
|
||||
'loss/critic1': critic1_loss.item(),
|
||||
'loss/critic2': critic2_loss.item(),
|
||||
"loss/actor": actor_loss.item(),
|
||||
"loss/critic1": critic1_loss.item(),
|
||||
"loss/critic2": critic2_loss.item(),
|
||||
}
|
||||
if self._is_auto_alpha:
|
||||
result['loss/alpha'] = alpha_loss.item()
|
||||
result['v/alpha'] = self._alpha.item()
|
||||
result["loss/alpha"] = alpha_loss.item()
|
||||
result["v/alpha"] = self._alpha.item()
|
||||
return result
|
||||
|
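The soft target update performed by sync_weight above is a standard Polyak average; a minimal sketch with a throwaway module:

import torch
from copy import deepcopy

tau = 0.005                        # same default as above
net = torch.nn.Linear(4, 2)
net_old = deepcopy(net)
# old <- (1 - tau) * old + tau * new, parameter by parameter
for o, n in zip(net_old.parameters(), net.parameters()):
    o.data.copy_(o.data * (1.0 - tau) + n.data * tau)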
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Tuple, Optional
|
||||
from typing import Any, Dict, Tuple, Optional
|
||||
|
||||
from tianshou.policy import DDPGPolicy
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
@ -22,6 +22,8 @@ class TD3Policy(DDPGPolicy):
|
||||
a))
|
||||
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
|
||||
critic network.
|
||||
:param action_range: the action range (minimum, maximum).
|
||||
:type action_range: Tuple[float, float]
|
||||
:param float tau: param for soft update of the target network, defaults to
|
||||
0.005.
|
||||
:param float gamma: discount factor, in [0, 1], defaults to 0.99.
|
||||
@ -33,8 +35,6 @@ class TD3Policy(DDPGPolicy):
|
||||
default to 2.
|
||||
:param float noise_clip: the clipping range used in updating policy
|
||||
network, default to 0.5.
|
||||
:param action_range: the action range (minimum, maximum).
|
||||
:type action_range: (float, float)
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1),
|
||||
defaults to False.
|
||||
:param bool ignore_done: ignore the done flag while training the policy,
|
||||
@ -46,27 +46,28 @@ class TD3Policy(DDPGPolicy):
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
actor: torch.nn.Module,
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critic1: torch.nn.Module,
|
||||
critic1_optim: torch.optim.Optimizer,
|
||||
critic2: torch.nn.Module,
|
||||
critic2_optim: torch.optim.Optimizer,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
exploration_noise: Optional[BaseNoise]
|
||||
= GaussianNoise(sigma=0.1),
|
||||
policy_noise: float = 0.2,
|
||||
update_actor_freq: int = 2,
|
||||
noise_clip: float = 0.5,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
reward_normalization: bool = False,
|
||||
ignore_done: bool = False,
|
||||
estimation_step: int = 1,
|
||||
**kwargs) -> None:
|
||||
super().__init__(actor, actor_optim, None, None, tau, gamma,
|
||||
exploration_noise, action_range, reward_normalization,
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critic1: torch.nn.Module,
|
||||
critic1_optim: torch.optim.Optimizer,
|
||||
critic2: torch.nn.Module,
|
||||
critic2_optim: torch.optim.Optimizer,
|
||||
action_range: Tuple[float, float],
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
|
||||
policy_noise: float = 0.2,
|
||||
update_actor_freq: int = 2,
|
||||
noise_clip: float = 0.5,
|
||||
reward_normalization: bool = False,
|
||||
ignore_done: bool = False,
|
||||
estimation_step: int = 1,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(actor, actor_optim, None, None, action_range,
|
||||
tau, gamma, exploration_noise, reward_normalization,
|
||||
ignore_done, estimation_step, **kwargs)
|
||||
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
|
||||
self.critic1_old.eval()
|
||||
@ -80,7 +81,7 @@ class TD3Policy(DDPGPolicy):
|
||||
self._cnt = 0
|
||||
self._last = 0
|
||||
|
||||
def train(self, mode=True) -> torch.nn.Module:
|
||||
def train(self, mode: bool = True) -> "TD3Policy":
|
||||
self.training = mode
|
||||
self.actor.train(mode)
|
||||
self.critic1.train(mode)
|
||||
@ -89,22 +90,25 @@ class TD3Policy(DDPGPolicy):
|
||||
|
||||
def sync_weight(self) -> None:
|
||||
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
|
||||
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
|
||||
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
|
||||
for o, n in zip(
|
||||
self.critic1_old.parameters(), self.critic1.parameters()):
|
||||
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
|
||||
self.critic1_old.parameters(), self.critic1.parameters()
|
||||
):
|
||||
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
|
||||
for o, n in zip(
|
||||
self.critic2_old.parameters(), self.critic2.parameters()):
|
||||
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
|
||||
self.critic2_old.parameters(), self.critic2.parameters()
|
||||
):
|
||||
o.data.copy_(o.data * (1.0 - self._tau) + n.data * self._tau)
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer,
|
||||
indice: np.ndarray) -> torch.Tensor:
|
||||
def _target_q(
|
||||
self, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> torch.Tensor:
|
||||
batch = buffer[indice] # batch.obs: s_{t+n}
|
||||
with torch.no_grad():
|
||||
a_ = self(batch, model='actor_old', input='obs_next').act
|
||||
a_ = self(batch, model="actor_old", input="obs_next").act
|
||||
dev = a_.device
|
||||
noise = torch.randn(size=a_.shape, device=dev) * self._policy_noise
|
||||
if self._noise_clip > 0:
|
||||
if self._noise_clip > 0.0:
|
||||
noise = noise.clamp(-self._noise_clip, self._noise_clip)
|
||||
a_ += noise
|
||||
a_ = a_.clamp(self._range[0], self._range[1])
|
||||
@ -113,8 +117,8 @@ class TD3Policy(DDPGPolicy):
|
||||
self.critic2_old(batch.obs_next, a_))
|
||||
return target_q
|
||||
|
||||
def learn(self, batch: Batch, **kwargs) -> Dict[str, float]:
|
||||
weight = batch.pop('weight', 1.)
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
weight = batch.pop("weight", 1.0)
|
||||
# critic 1
|
||||
current_q1 = self.critic1(batch.obs, batch.act).flatten()
|
||||
target_q = batch.returns.flatten()
|
||||
@ -132,10 +136,10 @@ class TD3Policy(DDPGPolicy):
|
||||
self.critic2_optim.zero_grad()
|
||||
critic2_loss.backward()
|
||||
self.critic2_optim.step()
|
||||
batch.weight = (td1 + td2) / 2. # prio-buffer
|
||||
batch.weight = (td1 + td2) / 2.0 # prio-buffer
|
||||
if self._cnt % self._freq == 0:
|
||||
actor_loss = -self.critic1(
|
||||
batch.obs, self(batch, eps=0).act).mean()
|
||||
batch.obs, self(batch, eps=0.0).act).mean()
|
||||
self.actor_optim.zero_grad()
|
||||
actor_loss.backward()
|
||||
self._last = actor_loss.item()
|
||||
@ -143,7 +147,7 @@ class TD3Policy(DDPGPolicy):
|
||||
self.sync_weight()
|
||||
self._cnt += 1
|
||||
return {
|
||||
'loss/actor': self._last,
|
||||
'loss/critic1': critic1_loss.item(),
|
||||
'loss/critic2': critic2_loss.item(),
|
||||
"loss/actor": self._last,
|
||||
"loss/critic1": critic1_loss.item(),
|
||||
"loss/critic2": critic2_loss.item(),
|
||||
}
|
||||
|
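The target-policy smoothing used in TD3Policy._target_q above, as a standalone sketch (hypothetical action bounds):

import torch

policy_noise, noise_clip = 0.2, 0.5    # same defaults as above
low, high = -1.0, 1.0                  # hypothetical action_range
a_ = torch.randn(16, 3)                # fake target-actor actions
noise = torch.randn_like(a_) * policy_noise
if noise_clip > 0.0:
    noise = noise.clamp(-noise_clip, noise_clip)
a_ = (a_ + noise).clamp(low, high)     # smoothed action fed to the target critics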
@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
from typing import Union, Optional, Dict, List
|
||||
from typing import Any, Dict, List, Union, Optional
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
@ -15,21 +15,22 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
:ref:`marl_example` can help you better understand this procedure.
|
||||
"""
|
||||
|
||||
def __init__(self, policies: List[BasePolicy]):
|
||||
super().__init__()
|
||||
def __init__(self, policies: List[BasePolicy], **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.policies = policies
|
||||
for i, policy in enumerate(policies):
|
||||
# agent_id 0 is reserved for the environment proxy
|
||||
# (this MultiAgentPolicyManager)
|
||||
policy.set_agent_id(i + 1)
|
||||
|
||||
def replace_policy(self, policy, agent_id):
|
||||
def replace_policy(self, policy: BasePolicy, agent_id: int) -> None:
|
||||
"""Replace the "agent_id"th policy in this manager."""
|
||||
self.policies[agent_id - 1] = policy
|
||||
policy.set_agent_id(agent_id)
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
|
||||
indice: np.ndarray) -> Batch:
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> Batch:
|
||||
"""Dispatch batch data from obs.agent_id to every policy's process_fn.
|
||||
|
||||
Save original multi-dimensional rew in "save_rew", set rew to the
|
||||
@ -46,21 +47,24 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
for policy in self.policies:
|
||||
agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
|
||||
if len(agent_index) == 0:
|
||||
results[f'agent_{policy.agent_id}'] = Batch()
|
||||
results[f"agent_{policy.agent_id}"] = Batch()
|
||||
continue
|
||||
tmp_batch, tmp_indice = batch[agent_index], indice[agent_index]
|
||||
if has_rew:
|
||||
tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
|
||||
buffer._meta.rew = save_rew[:, policy.agent_id - 1]
|
||||
results[f'agent_{policy.agent_id}'] = \
|
||||
policy.process_fn(tmp_batch, buffer, tmp_indice)
|
||||
results[f"agent_{policy.agent_id}"] = policy.process_fn(
|
||||
tmp_batch, buffer, tmp_indice)
|
||||
if has_rew: # restore from save_rew
|
||||
buffer._meta.rew = save_rew
|
||||
return Batch(results)
|
||||
|
||||
def forward(self, batch: Batch,
|
||||
state: Optional[Union[dict, Batch]] = None,
|
||||
**kwargs) -> Batch:
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
"""Dispatch batch data from obs.agent_id to every policy's forward.
|
||||
|
||||
:param state: if None, it means all agents have no state. If not
|
||||
@ -107,15 +111,15 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
**kwargs)
|
||||
act = out.act
|
||||
each_state = out.state \
|
||||
if (hasattr(out, 'state') and out.state is not None) \
|
||||
if (hasattr(out, "state") and out.state is not None) \
|
||||
else Batch()
|
||||
results.append((True, agent_index, out, act, each_state))
|
||||
holder = Batch.cat([{'act': act} for
|
||||
holder = Batch.cat([{"act": act} for
|
||||
(has_data, agent_index, out, act, each_state)
|
||||
in results if has_data])
|
||||
state_dict, out_dict = {}, {}
|
||||
for policy, (has_data, agent_index, out, act, state) in \
|
||||
zip(self.policies, results):
|
||||
for policy, (has_data, agent_index, out, act, state) in zip(
|
||||
self.policies, results):
|
||||
if has_data:
|
||||
holder.act[agent_index] = act
|
||||
state_dict["agent_" + str(policy.agent_id)] = state
|
||||
@ -124,8 +128,9 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
holder["state"] = state_dict
|
||||
return holder
|
||||
|
||||
def learn(self, batch: Batch, **kwargs
|
||||
) -> Dict[str, Union[float, List[float]]]:
|
||||
def learn(
|
||||
self, batch: Batch, **kwargs: Any
|
||||
) -> Dict[str, Union[float, List[float]]]:
|
||||
"""Dispatch the data to all policies for learning.
|
||||
|
||||
:return: a dict with the following contents:
|
||||
@ -142,9 +147,9 @@ class MultiAgentPolicyManager(BasePolicy):
|
||||
"""
|
||||
results = {}
|
||||
for policy in self.policies:
|
||||
data = batch[f'agent_{policy.agent_id}']
|
||||
data = batch[f"agent_{policy.agent_id}"]
|
||||
if not data.is_empty():
|
||||
out = policy.learn(batch=data, **kwargs)
|
||||
for k, v in out.items():
|
||||
results["agent_" + str(policy.agent_id) + '/' + k] = v
|
||||
results["agent_" + str(policy.agent_id) + "/" + k] = v
|
||||
return results
|
||||
|
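The per-agent dispatch in process_fn and forward above boils down to index selection on obs.agent_id; a tiny NumPy sketch with fake ids:

import numpy as np

agent_id = np.array([1, 2, 1, 2, 1])   # fake per-transition agent ids
for this_id in (1, 2):
    index = np.nonzero(agent_id == this_id)[0]
    # each sub-policy only receives its own transitions
    print(this_id, index)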
@ -1,5 +1,5 @@
|
||||
import numpy as np
|
||||
from typing import Union, Optional, Dict, List
|
||||
from typing import Any, Dict, Union, Optional
|
||||
|
||||
from tianshou.data import Batch
|
||||
from tianshou.policy import BasePolicy
|
||||
@ -11,9 +11,12 @@ class RandomPolicy(BasePolicy):
|
||||
It randomly chooses an action from the legal actions.
|
||||
"""
|
||||
|
||||
def forward(self, batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs) -> Batch:
|
||||
def forward(
|
||||
self,
|
||||
batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
"""Compute the random action over the given batch data.
|
||||
|
||||
The input should contain a mask in batch.obs, with "True" to be
|
||||
@ -34,7 +37,6 @@ class RandomPolicy(BasePolicy):
|
||||
logits[~mask] = -np.inf
|
||||
return Batch(act=logits.argmax(axis=-1))
|
||||
|
||||
def learn(self, batch: Batch, **kwargs
|
||||
) -> Dict[str, Union[float, List[float]]]:
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
"""Since a random agent learn nothing, it returns an empty dict."""
|
||||
return {}
|
||||
|
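The mask handling in RandomPolicy.forward above can be reproduced in a few lines of NumPy (fake mask, two samples with three actions each):

import numpy as np

mask = np.array([[True, False, True],
                 [False, True, True]])  # legal actions per sample
logits = np.random.rand(*mask.shape)
logits[~mask] = -np.inf                 # forbid illegal actions
act = logits.argmax(axis=-1)            # a random legal action per row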
@ -3,8 +3,8 @@ from tianshou.trainer.onpolicy import onpolicy_trainer
|
||||
from tianshou.trainer.offpolicy import offpolicy_trainer
|
||||
|
||||
__all__ = [
|
||||
'gather_info',
|
||||
'test_episode',
|
||||
'onpolicy_trainer',
|
||||
'offpolicy_trainer',
|
||||
"gather_info",
|
||||
"test_episode",
|
||||
"onpolicy_trainer",
|
||||
"offpolicy_trainer",
|
||||
]
|
||||
|
@ -10,23 +10,23 @@ from tianshou.trainer import test_episode, gather_info
|
||||
|
||||
|
||||
def offpolicy_trainer(
|
||||
policy: BasePolicy,
|
||||
train_collector: Collector,
|
||||
test_collector: Collector,
|
||||
max_epoch: int,
|
||||
step_per_epoch: int,
|
||||
collect_per_step: int,
|
||||
episode_per_test: Union[int, List[int]],
|
||||
batch_size: int,
|
||||
update_per_step: int = 1,
|
||||
train_fn: Optional[Callable[[int], None]] = None,
|
||||
test_fn: Optional[Callable[[int], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
log_interval: int = 1,
|
||||
verbose: bool = True,
|
||||
test_in_train: bool = True,
|
||||
policy: BasePolicy,
|
||||
train_collector: Collector,
|
||||
test_collector: Collector,
|
||||
max_epoch: int,
|
||||
step_per_epoch: int,
|
||||
collect_per_step: int,
|
||||
episode_per_test: Union[int, List[int]],
|
||||
batch_size: int,
|
||||
update_per_step: int = 1,
|
||||
train_fn: Optional[Callable[[int], None]] = None,
|
||||
test_fn: Optional[Callable[[int], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
log_interval: int = 1,
|
||||
verbose: bool = True,
|
||||
test_in_train: bool = True,
|
||||
) -> Dict[str, Union[float, str]]:
|
||||
"""A wrapper for off-policy trainer procedure.
|
||||
|
||||
@ -72,7 +72,7 @@ def offpolicy_trainer(
|
||||
:return: See :func:`~tianshou.trainer.gather_info`.
|
||||
"""
|
||||
global_step = 0
|
||||
best_epoch, best_reward = -1, -1.
|
||||
best_epoch, best_reward = -1, -1.0
|
||||
stat = {}
|
||||
start_time = time.time()
|
||||
test_in_train = test_in_train and train_collector.policy == policy
|
||||
@ -81,42 +81,43 @@ def offpolicy_trainer(
|
||||
policy.train()
|
||||
if train_fn:
|
||||
train_fn(epoch)
|
||||
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
|
||||
**tqdm_config) as t:
|
||||
with tqdm.tqdm(
|
||||
total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
|
||||
) as t:
|
||||
while t.n < t.total:
|
||||
result = train_collector.collect(n_step=collect_per_step)
|
||||
data = {}
|
||||
if test_in_train and stop_fn and stop_fn(result['rew']):
|
||||
if test_in_train and stop_fn and stop_fn(result["rew"]):
|
||||
test_result = test_episode(
|
||||
policy, test_collector, test_fn,
|
||||
epoch, episode_per_test, writer, global_step)
|
||||
if stop_fn and stop_fn(test_result['rew']):
|
||||
if stop_fn and stop_fn(test_result["rew"]):
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
for k in result.keys():
|
||||
data[k] = f'{result[k]:.2f}'
|
||||
data[k] = f"{result[k]:.2f}"
|
||||
t.set_postfix(**data)
|
||||
return gather_info(
|
||||
start_time, train_collector, test_collector,
|
||||
test_result['rew'])
|
||||
test_result["rew"])
|
||||
else:
|
||||
policy.train()
|
||||
if train_fn:
|
||||
train_fn(epoch)
|
||||
for i in range(update_per_step * min(
|
||||
result['n/st'] // collect_per_step, t.total - t.n)):
|
||||
result["n/st"] // collect_per_step, t.total - t.n)):
|
||||
global_step += collect_per_step
|
||||
losses = policy.update(batch_size, train_collector.buffer)
|
||||
for k in result.keys():
|
||||
data[k] = f'{result[k]:.2f}'
|
||||
data[k] = f"{result[k]:.2f}"
|
||||
if writer and global_step % log_interval == 0:
|
||||
writer.add_scalar('train/' + k, result[k],
|
||||
writer.add_scalar("train/" + k, result[k],
|
||||
global_step=global_step)
|
||||
for k in losses.keys():
|
||||
if stat.get(k) is None:
|
||||
stat[k] = MovAvg()
|
||||
stat[k].add(losses[k])
|
||||
data[k] = f'{stat[k].get():.6f}'
|
||||
data[k] = f"{stat[k].get():.6f}"
|
||||
if writer and global_step % log_interval == 0:
|
||||
writer.add_scalar(
|
||||
k, stat[k].get(), global_step=global_step)
|
||||
@ -127,14 +128,14 @@ def offpolicy_trainer(
|
||||
# test
|
||||
result = test_episode(policy, test_collector, test_fn, epoch,
|
||||
episode_per_test, writer, global_step)
|
||||
if best_epoch == -1 or best_reward < result['rew']:
|
||||
best_reward = result['rew']
|
||||
if best_epoch == -1 or best_reward < result["rew"]:
|
||||
best_reward = result["rew"]
|
||||
best_epoch = epoch
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
if verbose:
|
||||
print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
|
||||
f'best_reward: {best_reward:.6f} in #{best_epoch}')
|
||||
print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, "
|
||||
f"best_reward: {best_reward:.6f} in #{best_epoch}")
|
||||
if stop_fn and stop_fn(best_reward):
|
||||
break
|
||||
return gather_info(
|
||||
|
@ -10,23 +10,23 @@ from tianshou.trainer import test_episode, gather_info
|
||||
|
||||
|
||||
def onpolicy_trainer(
|
||||
policy: BasePolicy,
|
||||
train_collector: Collector,
|
||||
test_collector: Collector,
|
||||
max_epoch: int,
|
||||
step_per_epoch: int,
|
||||
collect_per_step: int,
|
||||
repeat_per_collect: int,
|
||||
episode_per_test: Union[int, List[int]],
|
||||
batch_size: int,
|
||||
train_fn: Optional[Callable[[int], None]] = None,
|
||||
test_fn: Optional[Callable[[int], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
log_interval: int = 1,
|
||||
verbose: bool = True,
|
||||
test_in_train: bool = True,
|
||||
policy: BasePolicy,
|
||||
train_collector: Collector,
|
||||
test_collector: Collector,
|
||||
max_epoch: int,
|
||||
step_per_epoch: int,
|
||||
collect_per_step: int,
|
||||
repeat_per_collect: int,
|
||||
episode_per_test: Union[int, List[int]],
|
||||
batch_size: int,
|
||||
train_fn: Optional[Callable[[int], None]] = None,
|
||||
test_fn: Optional[Callable[[int], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
log_interval: int = 1,
|
||||
verbose: bool = True,
|
||||
test_in_train: bool = True,
|
||||
) -> Dict[str, Union[float, str]]:
|
||||
"""A wrapper for on-policy trainer procedure.
|
||||
|
||||
@ -72,7 +72,7 @@ def onpolicy_trainer(
|
||||
:return: See :func:`~tianshou.trainer.gather_info`.
|
||||
"""
|
||||
global_step = 0
|
||||
best_epoch, best_reward = -1, -1.
|
||||
best_epoch, best_reward = -1, -1.0
|
||||
stat = {}
|
||||
start_time = time.time()
|
||||
test_in_train = test_in_train and train_collector.policy == policy
|
||||
@ -81,30 +81,32 @@ def onpolicy_trainer(
|
||||
policy.train()
|
||||
if train_fn:
|
||||
train_fn(epoch)
|
||||
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
|
||||
**tqdm_config) as t:
|
||||
with tqdm.tqdm(
|
||||
total=step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
|
||||
) as t:
|
||||
while t.n < t.total:
|
||||
result = train_collector.collect(n_episode=collect_per_step)
|
||||
data = {}
|
||||
if test_in_train and stop_fn and stop_fn(result['rew']):
|
||||
if test_in_train and stop_fn and stop_fn(result["rew"]):
|
||||
test_result = test_episode(
|
||||
policy, test_collector, test_fn,
|
||||
epoch, episode_per_test, writer, global_step)
|
||||
if stop_fn and stop_fn(test_result['rew']):
|
||||
if stop_fn and stop_fn(test_result["rew"]):
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
for k in result.keys():
|
||||
data[k] = f'{result[k]:.2f}'
|
||||
data[k] = f"{result[k]:.2f}"
|
||||
t.set_postfix(**data)
|
||||
return gather_info(
|
||||
start_time, train_collector, test_collector,
|
||||
test_result['rew'])
|
||||
test_result["rew"])
|
||||
else:
|
||||
policy.train()
|
||||
if train_fn:
|
||||
train_fn(epoch)
|
||||
losses = policy.update(
|
||||
0, train_collector.buffer, batch_size, repeat_per_collect)
|
||||
0, train_collector.buffer,
|
||||
batch_size=batch_size, repeat=repeat_per_collect)
|
||||
train_collector.reset_buffer()
|
||||
step = 1
|
||||
for k in losses.keys():
|
||||
@ -112,15 +114,15 @@ def onpolicy_trainer(
|
||||
step = max(step, len(losses[k]))
|
||||
global_step += step * collect_per_step
|
||||
for k in result.keys():
|
||||
data[k] = f'{result[k]:.2f}'
|
||||
data[k] = f"{result[k]:.2f}"
|
||||
if writer and global_step % log_interval == 0:
|
||||
writer.add_scalar(
|
||||
'train/' + k, result[k], global_step=global_step)
|
||||
"train/" + k, result[k], global_step=global_step)
|
||||
for k in losses.keys():
|
||||
if stat.get(k) is None:
|
||||
stat[k] = MovAvg()
|
||||
stat[k].add(losses[k])
|
||||
data[k] = f'{stat[k].get():.6f}'
|
||||
data[k] = f"{stat[k].get():.6f}"
|
||||
if writer and global_step % log_interval == 0:
|
||||
writer.add_scalar(
|
||||
k, stat[k].get(), global_step=global_step)
|
||||
@ -131,14 +133,14 @@ def onpolicy_trainer(
|
||||
# test
|
||||
result = test_episode(policy, test_collector, test_fn, epoch,
|
||||
episode_per_test, writer, global_step)
|
||||
if best_epoch == -1 or best_reward < result['rew']:
|
||||
best_reward = result['rew']
|
||||
if best_epoch == -1 or best_reward < result["rew"]:
|
||||
best_reward = result["rew"]
|
||||
best_epoch = epoch
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
if verbose:
|
||||
print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
|
||||
f'best_reward: {best_reward:.6f} in #{best_epoch}')
|
||||
print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, "
|
||||
f"best_reward: {best_reward:.6f} in #{best_epoch}")
|
||||
if stop_fn and stop_fn(best_reward):
|
||||
break
|
||||
return gather_info(
|
||||
|
@ -8,13 +8,14 @@ from tianshou.policy import BasePolicy
|
||||
|
||||
|
||||
def test_episode(
|
||||
policy: BasePolicy,
|
||||
collector: Collector,
|
||||
test_fn: Optional[Callable[[int], None]],
|
||||
epoch: int,
|
||||
n_episode: Union[int, List[int]],
|
||||
writer: SummaryWriter = None,
|
||||
global_step: int = None) -> Dict[str, float]:
|
||||
policy: BasePolicy,
|
||||
collector: Collector,
|
||||
test_fn: Optional[Callable[[int], None]],
|
||||
epoch: int,
|
||||
n_episode: Union[int, List[int]],
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
global_step: Optional[int] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""A simple wrapper of testing policy in collector."""
|
||||
collector.reset_env()
|
||||
collector.reset_buffer()
|
||||
@ -29,15 +30,16 @@ def test_episode(
|
||||
result = collector.collect(n_episode=n_episode)
|
||||
if writer is not None and global_step is not None:
|
||||
for k in result.keys():
|
||||
writer.add_scalar('test/' + k, result[k], global_step=global_step)
|
||||
writer.add_scalar("test/" + k, result[k], global_step=global_step)
|
||||
return result
|
||||
|
||||
|
||||
def gather_info(start_time: float,
|
||||
train_c: Collector,
|
||||
test_c: Collector,
|
||||
best_reward: float
|
||||
) -> Dict[str, Union[float, str]]:
|
||||
def gather_info(
|
||||
start_time: float,
|
||||
train_c: Collector,
|
||||
test_c: Collector,
|
||||
best_reward: float,
|
||||
) -> Dict[str, Union[float, str]]:
|
||||
"""A simple wrapper of gathering information from collectors.
|
||||
|
||||
:return: A dictionary with the following keys:
|
||||
@ -60,15 +62,15 @@ def gather_info(start_time: float,
|
||||
train_speed = train_c.collect_step / (duration - test_c.collect_time)
|
||||
test_speed = test_c.collect_step / test_c.collect_time
|
||||
return {
|
||||
'train_step': train_c.collect_step,
|
||||
'train_episode': train_c.collect_episode,
|
||||
'train_time/collector': f'{train_c.collect_time:.2f}s',
|
||||
'train_time/model': f'{model_time:.2f}s',
|
||||
'train_speed': f'{train_speed:.2f} step/s',
|
||||
'test_step': test_c.collect_step,
|
||||
'test_episode': test_c.collect_episode,
|
||||
'test_time': f'{test_c.collect_time:.2f}s',
|
||||
'test_speed': f'{test_speed:.2f} step/s',
|
||||
'best_reward': best_reward,
|
||||
'duration': f'{duration:.2f}s',
|
||||
"train_step": train_c.collect_step,
|
||||
"train_episode": train_c.collect_episode,
|
||||
"train_time/collector": f"{train_c.collect_time:.2f}s",
|
||||
"train_time/model": f"{model_time:.2f}s",
|
||||
"train_speed": f"{train_speed:.2f} step/s",
|
||||
"test_step": test_c.collect_step,
|
||||
"test_episode": test_c.collect_episode,
|
||||
"test_time": f"{test_c.collect_time:.2f}s",
|
||||
"test_speed": f"{test_speed:.2f} step/s",
|
||||
"best_reward": best_reward,
|
||||
"duration": f"{duration:.2f}s",
|
||||
}
|
||||
|
@ -1,9 +1,7 @@
|
||||
from tianshou.utils.config import tqdm_config
|
||||
from tianshou.utils.compile import pre_compile
|
||||
from tianshou.utils.moving_average import MovAvg
|
||||
|
||||
__all__ = [
|
||||
"MovAvg",
|
||||
"pre_compile",
|
||||
"tqdm_config",
|
||||
]
|
||||
|
@ -1,27 +0,0 @@
|
||||
import numpy as np
|
||||
|
||||
from tianshou.policy.base import _episodic_return, _nstep_return
|
||||
from tianshou.data.utils.segtree import _reduce, _setitem, _get_prefix_sum_idx
|
||||
|
||||
|
||||
def pre_compile():
|
||||
"""Functions that need to pre-compile for producing benchmark result.
|
||||
|
||||
Since Numba acceleration needs to compile the function in the first run,
|
||||
here we use some fake data for the common-type function-call compilation.
|
||||
Otherwise, the current training speed cannot compare with the previous.
|
||||
"""
|
||||
f64 = np.array([0, 1], dtype=np.float64)
|
||||
f32 = np.array([0, 1], dtype=np.float32)
|
||||
b = np.array([False, True], dtype=np.bool_)
|
||||
i64 = np.array([0, 1], dtype=np.int64)
|
||||
# returns
|
||||
_episodic_return(f64, f64, b, .1, .1)
|
||||
_episodic_return(f32, f64, b, .1, .1)
|
||||
_nstep_return(f64, b, f32, i64, .1, 1, 4, 1., 0.)
|
||||
# segtree
|
||||
_setitem(f64, i64, f64)
|
||||
_setitem(f64, i64, f32)
|
||||
_reduce(f64, 0, 1)
|
||||
_get_prefix_sum_idx(f64, 1, f64)
|
||||
_get_prefix_sum_idx(f32, 1, f64)
|
@ -1,4 +1,4 @@
|
||||
tqdm_config = {
|
||||
'dynamic_ncols': True,
|
||||
'ascii': True,
|
||||
"dynamic_ncols": True,
|
||||
"ascii": True,
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from numbers import Number
|
||||
from typing import Union
|
||||
|
||||
from tianshou.data import to_numpy
|
||||
@ -30,7 +31,9 @@ class MovAvg(object):
|
||||
self.cache = []
|
||||
self.banned = [np.inf, np.nan, -np.inf]
|
||||
|
||||
def add(self, x: Union[float, list, np.ndarray, torch.Tensor]) -> float:
|
||||
def add(
|
||||
self, x: Union[Number, np.number, list, np.ndarray, torch.Tensor]
|
||||
) -> np.number:
|
||||
"""Add a scalar into :class:`MovAvg`.
|
||||
|
||||
You can add ``torch.Tensor`` with only one element, a python scalar, or
|
||||
@ -39,26 +42,26 @@ class MovAvg(object):
|
||||
if isinstance(x, torch.Tensor):
|
||||
x = to_numpy(x.flatten())
|
||||
if isinstance(x, list) or isinstance(x, np.ndarray):
|
||||
for _ in x:
|
||||
if _ not in self.banned:
|
||||
self.cache.append(_)
|
||||
for i in x:
|
||||
if i not in self.banned:
|
||||
self.cache.append(i)
|
||||
elif x not in self.banned:
|
||||
self.cache.append(x)
|
||||
if self.size > 0 and len(self.cache) > self.size:
|
||||
self.cache = self.cache[-self.size:]
|
||||
return self.get()
|
||||
|
||||
def get(self) -> float:
|
||||
def get(self) -> np.number:
|
||||
"""Get the average."""
|
||||
if len(self.cache) == 0:
|
||||
return 0
|
||||
return np.mean(self.cache)
|
||||
|
||||
def mean(self) -> float:
|
||||
def mean(self) -> np.number:
|
||||
"""Get the average. Same as :meth:`get`."""
|
||||
return self.get()
|
||||
|
||||
def std(self) -> float:
|
||||
def std(self) -> np.number:
|
||||
"""Get the standard deviation."""
|
||||
if len(self.cache) == 0:
|
||||
return 0
|
||||
|
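A short usage sketch of the MovAvg class shown above (the constructor's size argument is an assumption; inf/nan values are ignored via the banned list):

import numpy as np
from tianshou.utils import MovAvg

stat = MovAvg(100)                       # keep at most the last 100 values (assumed ctor)
stat.add(1.0)
stat.add(np.array([2.0, np.inf, 3.0]))   # inf is dropped by the banned check above
print(stat.get())                        # 2.0, the mean of [1.0, 2.0, 3.0]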
@ -1,13 +1,16 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from typing import List, Tuple, Union, Optional
|
||||
from typing import Any, Dict, List, Tuple, Union, Callable, Optional, Sequence
|
||||
|
||||
from tianshou.data import to_torch
|
||||
|
||||
|
||||
def miniblock(inp: int, oup: int,
|
||||
norm_layer: nn.modules.Module) -> List[nn.modules.Module]:
|
||||
def miniblock(
|
||||
inp: int,
|
||||
oup: int,
|
||||
norm_layer: Optional[Callable[[int], nn.modules.Module]],
|
||||
) -> List[nn.modules.Module]:
|
||||
"""Construct a miniblock with given input/output-size and norm layer."""
|
||||
ret = [nn.Linear(inp, oup)]
|
||||
if norm_layer is not None:
|
||||
@ -27,18 +30,22 @@ class Net(nn.Module):
|
||||
shape, but affects the input shape.
|
||||
:param bool dueling: whether to use dueling network to calculate Q values
|
||||
(for Dueling DQN), defaults to False.
|
||||
:param nn.modules.Module norm_layer: use which normalization before ReLU,
|
||||
e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None.
|
||||
:param norm_layer: which normalization layer to use before ReLU, e.g.,
|
||||
``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_num: int, state_shape: tuple,
|
||||
action_shape: Optional[Union[tuple, int]] = 0,
|
||||
device: Union[str, torch.device] = 'cpu',
|
||||
softmax: bool = False,
|
||||
concat: bool = False,
|
||||
hidden_layer_size: int = 128,
|
||||
dueling: Optional[Tuple[int, int]] = None,
|
||||
norm_layer: Optional[nn.modules.Module] = None):
|
||||
def __init__(
|
||||
self,
|
||||
layer_num: int,
|
||||
state_shape: tuple,
|
||||
action_shape: Optional[Union[tuple, int]] = 0,
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
softmax: bool = False,
|
||||
concat: bool = False,
|
||||
hidden_layer_size: int = 128,
|
||||
dueling: Optional[Tuple[int, int]] = None,
|
||||
norm_layer: Optional[Callable[[int], nn.modules.Module]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.dueling = dueling
|
||||
@ -78,7 +85,12 @@ class Net(nn.Module):
|
||||
self.V = nn.Sequential(*self.V)
|
||||
self.model = nn.Sequential(*self.model)
|
||||
|
||||
def forward(self, s, state=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
s: Union[np.ndarray, torch.Tensor],
|
||||
state: Optional[Any] = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
"""Mapping: s -> flatten -> logits."""
|
||||
s = to_torch(s, device=self.device, dtype=torch.float32)
|
||||
s = s.reshape(s.size(0), -1)
|
||||
@ -98,19 +110,33 @@ class Recurrent(nn.Module):
|
||||
:ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_num, state_shape, action_shape,
|
||||
device='cpu', hidden_layer_size=128):
|
||||
def __init__(
|
||||
self,
|
||||
layer_num: int,
|
||||
state_shape: Sequence[int],
|
||||
action_shape: Sequence[int],
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
hidden_layer_size: int = 128,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.state_shape = state_shape
|
||||
self.action_shape = action_shape
|
||||
self.device = device
|
||||
self.nn = nn.LSTM(input_size=hidden_layer_size,
|
||||
hidden_size=hidden_layer_size,
|
||||
num_layers=layer_num, batch_first=True)
|
||||
self.nn = nn.LSTM(
|
||||
input_size=hidden_layer_size,
|
||||
hidden_size=hidden_layer_size,
|
||||
num_layers=layer_num,
|
||||
batch_first=True,
|
||||
)
|
||||
self.fc1 = nn.Linear(np.prod(state_shape), hidden_layer_size)
|
||||
self.fc2 = nn.Linear(hidden_layer_size, np.prod(action_shape))
|
||||
|
||||
def forward(self, s, state=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
s: Union[np.ndarray, torch.Tensor],
|
||||
state: Optional[Dict[str, torch.Tensor]] = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
"""Mapping: s -> flatten -> logits.
|
||||
|
||||
In evaluation mode, s should have shape ``[bsz, dim]``; in the
|
||||
@ -130,9 +156,9 @@ class Recurrent(nn.Module):
|
||||
else:
|
||||
# we store the stack data in [bsz, len, ...] format
|
||||
# but pytorch rnn needs [len, bsz, ...]
|
||||
s, (h, c) = self.nn(s, (state['h'].transpose(0, 1).contiguous(),
|
||||
state['c'].transpose(0, 1).contiguous()))
|
||||
s, (h, c) = self.nn(s, (state["h"].transpose(0, 1).contiguous(),
|
||||
state["c"].transpose(0, 1).contiguous()))
|
||||
s = self.fc2(s[:, -1])
|
||||
# please ensure the first dim is batch size: [bsz, len, ...]
|
||||
return s, {'h': h.transpose(0, 1).detach(),
|
||||
'c': c.transpose(0, 1).detach()}
|
||||
return s, {"h": h.transpose(0, 1).detach(),
|
||||
"c": c.transpose(0, 1).detach()}
|
||||
|
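The transpose in Recurrent.forward above exists because hidden states are stored batch-first while torch.nn.LSTM expects [num_layers, bsz, hidden]; a standalone sketch with made-up sizes:

import torch
from torch import nn

bsz, length, hidden, num_layers = 4, 5, 8, 2
lstm = nn.LSTM(input_size=hidden, hidden_size=hidden,
               num_layers=num_layers, batch_first=True)
s = torch.randn(bsz, length, hidden)
h = torch.zeros(bsz, num_layers, hidden)  # stored as [bsz, num_layers, hidden]
c = torch.zeros(bsz, num_layers, hidden)
out, (h, c) = lstm(s, (h.transpose(0, 1).contiguous(),
                       c.transpose(0, 1).contiguous()))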
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from typing import Any, Dict, Tuple, Union, Optional, Sequence
|
||||
|
||||
from tianshou.data import to_torch, to_torch_as
|
||||
|
||||
@ -12,14 +13,25 @@ class Actor(nn.Module):
|
||||
:ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(self, preprocess_net, action_shape, max_action=1.,
|
||||
device='cpu', hidden_layer_size=128):
|
||||
def __init__(
|
||||
self,
|
||||
preprocess_net: nn.Module,
|
||||
action_shape: Sequence[int],
|
||||
max_action: float = 1.0,
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
hidden_layer_size: int = 128,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.preprocess = preprocess_net
|
||||
self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
|
||||
self._max = max_action
|
||||
|
||||
def forward(self, s, state=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
s: Union[np.ndarray, torch.Tensor],
|
||||
state: Optional[Any] = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> Tuple[torch.Tensor, Any]:
|
||||
"""Mapping: s -> logits -> action."""
|
||||
logits, h = self.preprocess(s, state)
|
||||
logits = self._max * torch.tanh(self.last(logits))
|
||||
@ -33,13 +45,23 @@ class Critic(nn.Module):
|
||||
:ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(self, preprocess_net, device='cpu', hidden_layer_size=128):
|
||||
def __init__(
|
||||
self,
|
||||
preprocess_net: nn.Module,
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
hidden_layer_size: int = 128,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.preprocess = preprocess_net
|
||||
self.last = nn.Linear(hidden_layer_size, 1)
|
||||
|
||||
def forward(self, s, a=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
s: Union[np.ndarray, torch.Tensor],
|
||||
a: Optional[Union[np.ndarray, torch.Tensor]] = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> torch.Tensor:
|
||||
"""Mapping: (s, a) -> logits -> Q(s, a)."""
|
||||
s = to_torch(s, device=self.device, dtype=torch.float32)
|
||||
s = s.flatten(1)
|
||||
@ -59,8 +81,15 @@ class ActorProb(nn.Module):
|
||||
:ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(self, preprocess_net, action_shape, max_action=1.,
|
||||
device='cpu', unbounded=False, hidden_layer_size=128):
|
||||
def __init__(
|
||||
self,
|
||||
preprocess_net: nn.Module,
|
||||
action_shape: Sequence[int],
|
||||
max_action: float = 1.0,
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
unbounded: bool = False,
|
||||
hidden_layer_size: int = 128,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.preprocess = preprocess_net
|
||||
self.device = device
|
||||
@ -69,7 +98,12 @@ class ActorProb(nn.Module):
|
||||
self._max = max_action
|
||||
self._unbounded = unbounded
|
||||
|
||||
def forward(self, s, state=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
s: Union[np.ndarray, torch.Tensor],
|
||||
state: Optional[Any] = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Any]:
|
||||
"""Mapping: s -> logits -> (mu, sigma)."""
|
||||
logits, h = self.preprocess(s, state)
|
||||
mu = self.mu(logits)
|
||||
@ -78,7 +112,7 @@ class ActorProb(nn.Module):
|
||||
shape = [1] * len(mu.shape)
|
||||
shape[1] = -1
|
||||
sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
|
||||
return (mu, sigma), None
|
||||
return (mu, sigma), state
|
||||
|
||||
|
||||
class RecurrentActorProb(nn.Module):
|
||||
@ -88,19 +122,35 @@ class RecurrentActorProb(nn.Module):
|
||||
:ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_num, state_shape, action_shape, max_action=1.,
|
||||
device='cpu', unbounded=False, hidden_layer_size=128):
|
||||
def __init__(
|
||||
self,
|
||||
layer_num: int,
|
||||
state_shape: Sequence[int],
|
||||
action_shape: Sequence[int],
|
||||
max_action: float = 1.0,
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
unbounded: bool = False,
|
||||
hidden_layer_size: int = 128,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.nn = nn.LSTM(input_size=np.prod(state_shape),
|
||||
hidden_size=hidden_layer_size,
|
||||
num_layers=layer_num, batch_first=True)
|
||||
self.nn = nn.LSTM(
|
||||
input_size=np.prod(state_shape),
|
||||
hidden_size=hidden_layer_size,
|
||||
num_layers=layer_num,
|
||||
batch_first=True,
|
||||
)
|
||||
self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
|
||||
self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
|
||||
self._max = max_action
|
||||
self._unbounded = unbounded
|
||||
|
||||
def forward(self, s, state=None, info={}):
|
||||
def forward(
|
||||
self,
|
||||
s: Union[np.ndarray, torch.Tensor],
|
||||
state: Optional[Dict[str, torch.Tensor]] = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
|
||||
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
|
||||
s = to_torch(s, device=self.device, dtype=torch.float32)
|
||||
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
|
||||
@ -114,8 +164,8 @@ class RecurrentActorProb(nn.Module):
|
||||
else:
|
||||
# we store the stack data in [bsz, len, ...] format
|
||||
# but pytorch rnn needs [len, bsz, ...]
|
||||
s, (h, c) = self.nn(s, (state['h'].transpose(0, 1).contiguous(),
|
||||
state['c'].transpose(0, 1).contiguous()))
|
||||
s, (h, c) = self.nn(s, (state["h"].transpose(0, 1).contiguous(),
|
||||
state["c"].transpose(0, 1).contiguous()))
|
||||
logits = s[:, -1]
|
||||
mu = self.mu(logits)
|
||||
if not self._unbounded:
|
||||
@ -124,8 +174,8 @@ class RecurrentActorProb(nn.Module):
|
||||
shape[1] = -1
|
||||
sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
|
||||
# please ensure the first dim is batch size: [bsz, len, ...]
|
||||
return (mu, sigma), {'h': h.transpose(0, 1).detach(),
|
||||
'c': c.transpose(0, 1).detach()}
|
||||
return (mu, sigma), {"h": h.transpose(0, 1).detach(),
|
||||
"c": c.transpose(0, 1).detach()}
|
||||
|
||||
|
||||
class RecurrentCritic(nn.Module):
|
||||
@ -135,18 +185,32 @@ class RecurrentCritic(nn.Module):
|
||||
:ref:`build_the_network`.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_num, state_shape,
|
||||
action_shape=0, device='cpu', hidden_layer_size=128):
|
||||
def __init__(
|
||||
self,
|
||||
layer_num: int,
|
||||
state_shape: Sequence[int],
|
||||
action_shape: Sequence[int] = [0],
|
||||
device: Union[str, int, torch.device] = "cpu",
|
||||
hidden_layer_size: int = 128,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.state_shape = state_shape
|
||||
self.action_shape = action_shape
|
||||
self.device = device
|
||||
self.nn = nn.LSTM(input_size=np.prod(state_shape),
|
||||
hidden_size=hidden_layer_size,
|
||||
num_layers=layer_num, batch_first=True)
|
||||
self.nn = nn.LSTM(
|
||||
input_size=np.prod(state_shape),
|
||||
hidden_size=hidden_layer_size,
|
||||
num_layers=layer_num,
|
||||
batch_first=True,
|
||||
)
|
||||
self.fc2 = nn.Linear(hidden_layer_size + np.prod(action_shape), 1)
|
||||
|
||||
def forward(self, s, a=None):
|
||||
def forward(
|
||||
self,
|
||||
s: Union[np.ndarray, torch.Tensor],
|
||||
a: Optional[Union[np.ndarray, torch.Tensor]] = None,
|
||||
info: Dict[str, Any] = {},
|
||||
) -> torch.Tensor:
|
||||
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
|
||||
s = to_torch(s, device=self.device, dtype=torch.float32)
|
||||
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
@ -2,6 +2,7 @@ import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, Tuple, Union, Optional, Sequence


class Actor(nn.Module):
@ -11,12 +12,22 @@ class Actor(nn.Module):
    :ref:`build_the_network`.
    """

    def __init__(self, preprocess_net, action_shape, hidden_layer_size=128):
    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_layer_size: int = 128,
    ) -> None:
        super().__init__()
        self.preprocess = preprocess_net
        self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))

    def forward(self, s, state=None, info={}):
    def forward(
        self,
        s: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: s -> Q(s, \*)."""
        logits, h = self.preprocess(s, state)
        logits = F.softmax(self.last(logits), dim=-1)
@ -30,14 +41,18 @@ class Critic(nn.Module):
    :ref:`build_the_network`.
    """

    def __init__(self, preprocess_net, hidden_layer_size=128):
    def __init__(
        self, preprocess_net: nn.Module, hidden_layer_size: int = 128
    ) -> None:
        super().__init__()
        self.preprocess = preprocess_net
        self.last = nn.Linear(hidden_layer_size, 1)

    def forward(self, s, **kwargs):
    def forward(
        self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
    ) -> torch.Tensor:
        """Mapping: s -> V(s)."""
        logits, h = self.preprocess(s, state=kwargs.get('state', None))
        logits, h = self.preprocess(s, state=kwargs.get("state", None))
        logits = self.last(logits)
        return logits
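Aside (not part of this diff): a sketch of the discrete Actor and Critic signatures above, using a hypothetical stand-in preprocess network. The import path tianshou.utils.net.discrete is assumed, and Actor is assumed to return (probs, state) as its annotation suggests; the Preprocess class and all sizes are made up for illustration.

import numpy as np
import torch
from torch import nn
from tianshou.utils.net.discrete import Actor, Critic  # assumed path

class Preprocess(nn.Module):
    """Hypothetical stand-in: emits hidden_layer_size features, passes state through."""
    def __init__(self, in_dim: int = 4, hidden_layer_size: int = 128) -> None:
        super().__init__()
        self.model = nn.Sequential(nn.Linear(in_dim, hidden_layer_size), nn.ReLU())

    def forward(self, s, state=None):
        s = torch.as_tensor(s, dtype=torch.float32)
        return self.model(s), state

prep = Preprocess()
actor = Actor(prep, action_shape=(2,), hidden_layer_size=128)
critic = Critic(prep, hidden_layer_size=128)
obs = np.zeros((5, 4), dtype=np.float32)
probs, _ = actor(obs)   # [5, 2]; rows sum to 1 after the softmax above
values = critic(obs)    # [5, 1]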
@ -49,17 +64,31 @@ class DQN(nn.Module):
    :ref:`build_the_network`.
    """

    def __init__(self, c, h, w, action_shape, device='cpu'):
        super(DQN, self).__init__()
    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        super().__init__()
        self.device = device

        def conv2d_size_out(size, kernel_size=5, stride=2):
        def conv2d_size_out(
            size: int, kernel_size: int = 5, stride: int = 2
        ) -> int:
            return (size - (kernel_size - 1) - 1) // stride + 1

        def conv2d_layers_size_out(size,
                                   kernel_size_1=8, stride_1=4,
                                   kernel_size_2=4, stride_2=2,
                                   kernel_size_3=3, stride_3=1):
        def conv2d_layers_size_out(
            size: int,
            kernel_size_1: int = 8,
            stride_1: int = 4,
            kernel_size_2: int = 4,
            stride_2: int = 2,
            kernel_size_3: int = 3,
            stride_3: int = 1,
        ) -> int:
            size = conv2d_size_out(size, kernel_size_1, stride_1)
            size = conv2d_size_out(size, kernel_size_2, stride_2)
            size = conv2d_size_out(size, kernel_size_3, stride_3)
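Aside (not part of this diff): plugging a standard 84 x 84 Atari frame into the size formula above gives the usual 7 x 7 feature map. The 64-channel final conv layer is assumed from the classic DQN architecture and is outside this hunk.

def conv2d_size_out(size, kernel_size=5, stride=2):
    return (size - (kernel_size - 1) - 1) // stride + 1

size = conv2d_size_out(84, 8, 4)      # (84 - 7 - 1) // 4 + 1 = 20
size = conv2d_size_out(size, 4, 2)    # (20 - 3 - 1) // 2 + 1 = 9
size = conv2d_size_out(size, 3, 1)    # (9 - 2 - 1) // 1 + 1 = 7
linear_input_size = size * size * 64  # 3136, assuming 64 output channels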
@ -78,10 +107,15 @@ class DQN(nn.Module):
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(linear_input_size, 512),
            nn.Linear(512, np.prod(action_shape))
            nn.Linear(512, np.prod(action_shape)),
        )

    def forward(self, x, state=None, info={}):
    def forward(
        self,
        x: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Q(x, \*)."""
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, device=self.device, dtype=torch.float32)
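Aside (not part of this diff): a minimal sketch of calling the updated DQN signature, assuming the class lives in tianshou.utils.net.discrete and that forward returns (q_values, state) as its annotation suggests; the frame and action sizes are hypothetical.

import numpy as np
from tianshou.utils.net.discrete import DQN  # assumed path

dqn = DQN(c=4, h=84, w=84, action_shape=(6,), device="cpu")
obs = np.zeros((1, 4, 84, 84), dtype=np.float32)  # NCHW batch of stacked frames
q_values, state = dqn(obs)                        # q_values: [1, 6]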