fix optional type syntax

This commit is contained in:
Trinkle23897 2020-05-16 20:08:32 +08:00
parent 3243484f8e
commit 0eef0ca198
17 changed files with 70 additions and 70 deletions

View File

@ -20,7 +20,7 @@ else: # pytest
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626) parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--il-lr', type=float, default=1e-3) parser.add_argument('--il-lr', type=float, default=1e-3)
@ -48,7 +48,7 @@ def get_args():
return args return args
def test_a2c(args=get_args()): def test_a2c_with_il(args=get_args()):
torch.set_num_threads(1) # for poor CPU torch.set_num_threads(1) # for poor CPU
env = gym.make(args.task) env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n args.state_shape = env.observation_space.shape or env.observation_space.n
@ -108,8 +108,8 @@ def test_a2c(args=get_args()):
collector.close() collector.close()
# here we define an imitation collector with a trivial policy # here we define an imitation collector with a trivial policy
if args.task == 'Pendulum-v0': if args.task == 'CartPole-v0':
env.spec.reward_threshold = -300 # lower the goal env.spec.reward_threshold = 190 # lower the goal
net = Net(1, args.state_shape, device=args.device) net = Net(1, args.state_shape, device=args.device)
net = Actor(net, args.action_shape).to(args.device) net = Actor(net, args.action_shape).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
@ -134,4 +134,4 @@ def test_a2c(args=get_args()):
if __name__ == '__main__': if __name__ == '__main__':
test_a2c() test_a2c_with_il()

View File

@ -177,7 +177,7 @@ class Batch(object):
if k != '_meta' and self.__dict__[k] is not None]) if k != '_meta' and self.__dict__[k] is not None])
def split(self, size: Optional[int] = None, def split(self, size: Optional[int] = None,
shuffle: Optional[bool] = True) -> Iterator['Batch']: shuffle: bool = True) -> Iterator['Batch']:
"""Split whole data into multiple small batch. """Split whole data into multiple small batch.
:param int size: if it is ``None``, it does not split the data batch; :param int size: if it is ``None``, it does not split the data batch;

View File

@ -96,7 +96,7 @@ class ReplayBuffer(object):
""" """
def __init__(self, size: int, stack_num: Optional[int] = 0, def __init__(self, size: int, stack_num: Optional[int] = 0,
ignore_obs_next: Optional[bool] = False, **kwargs) -> None: ignore_obs_next: bool = False, **kwargs) -> None:
super().__init__() super().__init__()
self._maxsize = size self._maxsize = size
self._stack = stack_num self._stack = stack_num
@ -192,7 +192,7 @@ class ReplayBuffer(object):
rew: float, rew: float,
done: bool, done: bool,
obs_next: Optional[Union[dict, np.ndarray]] = None, obs_next: Optional[Union[dict, np.ndarray]] = None,
info: Optional[dict] = {}, info: dict = {},
policy: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {},
**kwargs) -> None: **kwargs) -> None:
"""Add a batch of data into replay buffer.""" """Add a batch of data into replay buffer."""
@ -353,7 +353,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
""" """
def __init__(self, size: int, alpha: float, beta: float, def __init__(self, size: int, alpha: float, beta: float,
mode: Optional[str] = 'weight', **kwargs) -> None: mode: str = 'weight', **kwargs) -> None:
if mode != 'weight': if mode != 'weight':
raise NotImplementedError raise NotImplementedError
super().__init__(size, **kwargs) super().__init__(size, **kwargs)
@ -370,9 +370,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
rew: float, rew: float,
done: bool, done: bool,
obs_next: Optional[Union[dict, np.ndarray]] = None, obs_next: Optional[Union[dict, np.ndarray]] = None,
info: Optional[dict] = {}, info: dict = {},
policy: Optional[Union[dict, Batch]] = {}, policy: Optional[Union[dict, Batch]] = {},
weight: Optional[float] = 1.0, weight: float = 1.0,
**kwargs) -> None: **kwargs) -> None:
"""Add a batch of data into replay buffer.""" """Add a batch of data into replay buffer."""
self._weight_sum += np.abs(weight) ** self._alpha - \ self._weight_sum += np.abs(weight) ** self._alpha - \
@ -382,8 +382,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
super().add(obs, act, rew, done, obs_next, info, policy) super().add(obs, act, rew, done, obs_next, info, policy)
self._check_weight_sum() self._check_weight_sum()
def sample(self, batch_size: Optional[int] = 0, def sample(self, batch_size: int,
importance_sample: Optional[bool] = True importance_sample: bool = True
) -> Tuple[Batch, np.ndarray]: ) -> Tuple[Batch, np.ndarray]:
"""Get a random sample from buffer with priority probability. \ """Get a random sample from buffer with priority probability. \
Return all the data in the buffer if batch_size is ``0``. Return all the data in the buffer if batch_size is ``0``.

View File

@ -219,8 +219,8 @@ class Collector(object):
return x return x
def collect(self, def collect(self,
n_step: Optional[int] = 0, n_step: int = 0,
n_episode: Optional[Union[int, List[int]]] = 0, n_episode: Union[int, List[int]] = 0,
render: Optional[float] = None, render: Optional[float] = None,
log_fn: Optional[Callable[[dict], None]] = None log_fn: Optional[Callable[[dict], None]] = None
) -> Dict[str, float]: ) -> Dict[str, float]:

View File

@ -19,9 +19,9 @@ class OUNoise(object):
""" """
def __init__(self, def __init__(self,
sigma: Optional[float] = 0.3, sigma: float = 0.3,
theta: Optional[float] = 0.15, theta: float = 0.15,
dt: Optional[float] = 1e-2, dt: float = 1e-2,
x0: Optional[Union[float, np.ndarray]] = None x0: Optional[Union[float, np.ndarray]] = None
) -> None: ) -> None:
self.alpha = theta * dt self.alpha = theta * dt
@ -29,7 +29,7 @@ class OUNoise(object):
self.x0 = x0 self.x0 = x0
self.reset() self.reset()
def __call__(self, size: tuple, mu: Optional[float] = .1) -> np.ndarray: def __call__(self, size: tuple, mu: float = .1) -> np.ndarray:
"""Generate new noise. Return a ``numpy.ndarray`` which size is equal """Generate new noise. Return a ``numpy.ndarray`` which size is equal
to ``size``. to ``size``.
""" """

View File

@ -99,8 +99,8 @@ class BasePolicy(ABC, nn.Module):
def compute_episodic_return( def compute_episodic_return(
batch: Batch, batch: Batch,
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None, v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
gamma: Optional[float] = 0.99, gamma: float = 0.99,
gae_lambda: Optional[float] = 0.95) -> Batch: gae_lambda: float = 0.95) -> Batch:
"""Compute returns over given full-length episodes, including the """Compute returns over given full-length episodes, including the
implementation of Generalized Advantage Estimation (arXiv:1506.02438). implementation of Generalized Advantage Estimation (arXiv:1506.02438).

View File

@ -23,7 +23,7 @@ class ImitationPolicy(BasePolicy):
""" """
def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer, def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer,
mode: Optional[str] = 'continuous', **kwargs) -> None: mode: str = 'continuous', **kwargs) -> None:
super().__init__() super().__init__()
self.model = model self.model = model
self.optim = optim self.optim = optim

View File

@ -36,14 +36,14 @@ class A2CPolicy(PGPolicy):
actor: torch.nn.Module, actor: torch.nn.Module,
critic: torch.nn.Module, critic: torch.nn.Module,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: Optional[torch.distributions.Distribution] dist_fn: torch.distributions.Distribution
= torch.distributions.Categorical, = torch.distributions.Categorical,
discount_factor: Optional[float] = 0.99, discount_factor: float = 0.99,
vf_coef: Optional[float] = .5, vf_coef: float = .5,
ent_coef: Optional[float] = .01, ent_coef: float = .01,
max_grad_norm: Optional[float] = None, max_grad_norm: Optional[float] = None,
gae_lambda: Optional[float] = 0.95, gae_lambda: float = 0.95,
reward_normalization: Optional[bool] = False, reward_normalization: bool = False,
**kwargs) -> None: **kwargs) -> None:
super().__init__(None, optim, dist_fn, discount_factor, **kwargs) super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
self.actor = actor self.actor = actor

View File

@ -41,12 +41,12 @@ class DDPGPolicy(BasePolicy):
actor_optim: torch.optim.Optimizer, actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module, critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer, critic_optim: torch.optim.Optimizer,
tau: Optional[float] = 0.005, tau: float = 0.005,
gamma: Optional[float] = 0.99, gamma: float = 0.99,
exploration_noise: Optional[float] = 0.1, exploration_noise: float = 0.1,
action_range: Optional[Tuple[float, float]] = None, action_range: Optional[Tuple[float, float]] = None,
reward_normalization: Optional[bool] = False, reward_normalization: bool = False,
ignore_done: Optional[bool] = False, ignore_done: bool = False,
**kwargs) -> None: **kwargs) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
if actor is not None: if actor is not None:
@ -110,8 +110,8 @@ class DDPGPolicy(BasePolicy):
def forward(self, batch: Batch, def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None, state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: Optional[str] = 'actor', model: str = 'actor',
input: Optional[str] = 'obs', input: str = 'obs',
eps: Optional[float] = None, eps: Optional[float] = None,
**kwargs) -> Batch: **kwargs) -> Batch:
"""Compute action over the given batch data. """Compute action over the given batch data.

View File

@ -29,8 +29,8 @@ class DQNPolicy(BasePolicy):
def __init__(self, def __init__(self,
model: torch.nn.Module, model: torch.nn.Module,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
discount_factor: Optional[float] = 0.99, discount_factor: float = 0.99,
estimation_step: Optional[int] = 1, estimation_step: int = 1,
target_update_freq: Optional[int] = 0, target_update_freq: Optional[int] = 0,
**kwargs) -> None: **kwargs) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
@ -124,8 +124,8 @@ class DQNPolicy(BasePolicy):
def forward(self, batch: Batch, def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None, state: Optional[Union[dict, Batch, np.ndarray]] = None,
model: Optional[str] = 'model', model: str = 'model',
input: Optional[str] = 'obs', input: str = 'obs',
eps: Optional[float] = None, eps: Optional[float] = None,
**kwargs) -> Batch: **kwargs) -> Batch:
"""Compute action over the given batch data. """Compute action over the given batch data.

View File

@ -24,10 +24,10 @@ class PGPolicy(BasePolicy):
def __init__(self, def __init__(self,
model: torch.nn.Module, model: torch.nn.Module,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: Optional[torch.distributions.Distribution] dist_fn: torch.distributions.Distribution
= torch.distributions.Categorical, = torch.distributions.Categorical,
discount_factor: Optional[float] = 0.99, discount_factor: float = 0.99,
reward_normalization: Optional[bool] = False, reward_normalization: bool = False,
**kwargs) -> None: **kwargs) -> None:
super().__init__(**kwargs) super().__init__(**kwargs)
self.model = model self.model = model

View File

@ -46,16 +46,16 @@ class PPOPolicy(PGPolicy):
critic: torch.nn.Module, critic: torch.nn.Module,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: torch.distributions.Distribution, dist_fn: torch.distributions.Distribution,
discount_factor: Optional[float] = 0.99, discount_factor: float = 0.99,
max_grad_norm: Optional[float] = None, max_grad_norm: Optional[float] = None,
eps_clip: Optional[float] = .2, eps_clip: float = .2,
vf_coef: Optional[float] = .5, vf_coef: float = .5,
ent_coef: Optional[float] = .01, ent_coef: float = .01,
action_range: Optional[Tuple[float, float]] = None, action_range: Optional[Tuple[float, float]] = None,
gae_lambda: Optional[float] = 0.95, gae_lambda: float = 0.95,
dual_clip: Optional[float] = 5., dual_clip: float = 5.,
value_clip: Optional[bool] = True, value_clip: bool = True,
reward_normalization: Optional[bool] = True, reward_normalization: bool = True,
**kwargs) -> None: **kwargs) -> None:
super().__init__(None, None, dist_fn, discount_factor, **kwargs) super().__init__(None, None, dist_fn, discount_factor, **kwargs)
self._max_grad_norm = max_grad_norm self._max_grad_norm = max_grad_norm

View File

@ -48,12 +48,12 @@ class SACPolicy(DDPGPolicy):
critic1_optim: torch.optim.Optimizer, critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module, critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer, critic2_optim: torch.optim.Optimizer,
tau: Optional[float] = 0.005, tau: float = 0.005,
gamma: Optional[float] = 0.99, gamma: float = 0.99,
alpha: Optional[float] = 0.2, alpha: float = 0.2,
action_range: Optional[Tuple[float, float]] = None, action_range: Optional[Tuple[float, float]] = None,
reward_normalization: Optional[bool] = False, reward_normalization: bool = False,
ignore_done: Optional[bool] = False, ignore_done: bool = False,
**kwargs) -> None: **kwargs) -> None:
super().__init__(None, None, None, None, tau, gamma, 0, super().__init__(None, None, None, None, tau, gamma, 0,
action_range, reward_normalization, ignore_done, action_range, reward_normalization, ignore_done,
@ -90,7 +90,7 @@ class SACPolicy(DDPGPolicy):
def forward(self, batch: Batch, def forward(self, batch: Batch,
state: Optional[Union[dict, Batch, np.ndarray]] = None, state: Optional[Union[dict, Batch, np.ndarray]] = None,
input: Optional[str] = 'obs', **kwargs) -> Batch: input: str = 'obs', **kwargs) -> Batch:
obs = getattr(batch, input) obs = getattr(batch, input)
logits, h = self.actor(obs, state=state, info=batch.info) logits, h = self.actor(obs, state=state, info=batch.info)
assert isinstance(logits, tuple) assert isinstance(logits, tuple)

View File

@ -53,15 +53,15 @@ class TD3Policy(DDPGPolicy):
critic1_optim: torch.optim.Optimizer, critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module, critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer, critic2_optim: torch.optim.Optimizer,
tau: Optional[float] = 0.005, tau: float = 0.005,
gamma: Optional[float] = 0.99, gamma: float = 0.99,
exploration_noise: Optional[float] = 0.1, exploration_noise: float = 0.1,
policy_noise: Optional[float] = 0.2, policy_noise: float = 0.2,
update_actor_freq: Optional[int] = 2, update_actor_freq: int = 2,
noise_clip: Optional[float] = 0.5, noise_clip: float = 0.5,
action_range: Optional[Tuple[float, float]] = None, action_range: Optional[Tuple[float, float]] = None,
reward_normalization: Optional[bool] = False, reward_normalization: bool = False,
ignore_done: Optional[bool] = False, ignore_done: bool = False,
**kwargs) -> None: **kwargs) -> None:
super().__init__(actor, actor_optim, None, None, tau, gamma, super().__init__(actor, actor_optim, None, None, tau, gamma,
exploration_noise, action_range, reward_normalization, exploration_noise, action_range, reward_normalization,

View File

@ -24,8 +24,8 @@ def offpolicy_trainer(
save_fn: Optional[Callable[[BasePolicy], None]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None,
log_fn: Optional[Callable[[dict], None]] = None, log_fn: Optional[Callable[[dict], None]] = None,
writer: Optional[SummaryWriter] = None, writer: Optional[SummaryWriter] = None,
log_interval: Optional[int] = 1, log_interval: int = 1,
verbose: Optional[bool] = True, verbose: bool = True,
**kwargs **kwargs
) -> Dict[str, Union[float, str]]: ) -> Dict[str, Union[float, str]]:
"""A wrapper for off-policy trainer procedure. """A wrapper for off-policy trainer procedure.

View File

@ -25,8 +25,8 @@ def onpolicy_trainer(
save_fn: Optional[Callable[[BasePolicy], None]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None,
log_fn: Optional[Callable[[dict], None]] = None, log_fn: Optional[Callable[[dict], None]] = None,
writer: Optional[SummaryWriter] = None, writer: Optional[SummaryWriter] = None,
log_interval: Optional[int] = 1, log_interval: int = 1,
verbose: Optional[bool] = True, verbose: bool = True,
**kwargs **kwargs
) -> Dict[str, Union[float, str]]: ) -> Dict[str, Union[float, str]]:
"""A wrapper for on-policy trainer procedure. """A wrapper for on-policy trainer procedure.

View File

@ -1,6 +1,6 @@
import torch import torch
import numpy as np import numpy as np
from typing import Union, Optional from typing import Union
class MovAvg(object): class MovAvg(object):
@ -21,7 +21,7 @@ class MovAvg(object):
6.50±1.12 6.50±1.12
""" """
def __init__(self, size: Optional[int] = 100) -> None: def __init__(self, size: int = 100) -> None:
super().__init__() super().__init__()
self.size = size self.size = size
self.cache = [] self.cache = []