From 0eef0ca1985b4659e7e57aadaa7902613ff35001 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sat, 16 May 2020 20:08:32 +0800
Subject: [PATCH] fix optional type syntax

---
 test/discrete/test_a2c_with_il.py | 10 +++++-----
 tianshou/data/batch.py            |  2 +-
 tianshou/data/buffer.py           | 14 +++++++-------
 tianshou/data/collector.py        |  4 ++--
 tianshou/exploration/random.py    |  8 ++++----
 tianshou/policy/base.py           |  4 ++--
 tianshou/policy/imitation/base.py |  2 +-
 tianshou/policy/modelfree/a2c.py  | 12 ++++++------
 tianshou/policy/modelfree/ddpg.py | 14 +++++++-------
 tianshou/policy/modelfree/dqn.py  |  8 ++++----
 tianshou/policy/modelfree/pg.py   |  6 +++---
 tianshou/policy/modelfree/ppo.py  | 16 ++++++++--------
 tianshou/policy/modelfree/sac.py  | 12 ++++++------
 tianshou/policy/modelfree/td3.py  | 16 ++++++++--------
 tianshou/trainer/offpolicy.py     |  4 ++--
 tianshou/trainer/onpolicy.py      |  4 ++--
 tianshou/utils/moving_average.py  |  4 ++--
 17 files changed, 70 insertions(+), 70 deletions(-)

diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py
index 60d3e44..1b12803 100644
--- a/test/discrete/test_a2c_with_il.py
+++ b/test/discrete/test_a2c_with_il.py
@@ -20,7 +20,7 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--seed', type=int, default=1)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--il-lr', type=float, default=1e-3)
@@ -48,7 +48,7 @@ def get_args():
     return args
 
 
-def test_a2c(args=get_args()):
+def test_a2c_with_il(args=get_args()):
     torch.set_num_threads(1)  # for poor CPU
     env = gym.make(args.task)
     args.state_shape = env.observation_space.shape or env.observation_space.n
@@ -108,8 +108,8 @@ def test_a2c(args=get_args()):
         collector.close()
 
     # here we define an imitation collector with a trivial policy
-    if args.task == 'Pendulum-v0':
-        env.spec.reward_threshold = -300  # lower the goal
+    if args.task == 'CartPole-v0':
+        env.spec.reward_threshold = 190  # lower the goal
     net = Net(1, args.state_shape, device=args.device)
     net = Actor(net, args.action_shape).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
@@ -134,4 +134,4 @@
 
 
 if __name__ == '__main__':
-    test_a2c()
+    test_a2c_with_il()
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index f8352ec..8015b86 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -177,7 +177,7 @@ class Batch(object):
             if k != '_meta' and self.__dict__[k] is not None])
 
     def split(self, size: Optional[int] = None,
-              shuffle: Optional[bool] = True) -> Iterator['Batch']:
+              shuffle: bool = True) -> Iterator['Batch']:
         """Split whole data into multiple small batch.
 
         :param int size: if it is ``None``, it does not split the data batch;
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 028c189..163e852 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -96,7 +96,7 @@ class ReplayBuffer(object):
     """
 
     def __init__(self, size: int, stack_num: Optional[int] = 0,
-                 ignore_obs_next: Optional[bool] = False, **kwargs) -> None:
+                 ignore_obs_next: bool = False, **kwargs) -> None:
         super().__init__()
         self._maxsize = size
         self._stack = stack_num
@@ -192,7 +192,7 @@ class ReplayBuffer(object):
             rew: float,
             done: bool,
             obs_next: Optional[Union[dict, np.ndarray]] = None,
-            info: Optional[dict] = {},
+            info: dict = {},
             policy: Optional[Union[dict, Batch]] = {},
             **kwargs) -> None:
         """Add a batch of data into replay buffer."""
@@ -353,7 +353,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
     """
 
     def __init__(self, size: int, alpha: float, beta: float,
-                 mode: Optional[str] = 'weight', **kwargs) -> None:
+                 mode: str = 'weight', **kwargs) -> None:
         if mode != 'weight':
             raise NotImplementedError
         super().__init__(size, **kwargs)
@@ -370,9 +370,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
             rew: float,
             done: bool,
             obs_next: Optional[Union[dict, np.ndarray]] = None,
-            info: Optional[dict] = {},
+            info: dict = {},
             policy: Optional[Union[dict, Batch]] = {},
-            weight: Optional[float] = 1.0,
+            weight: float = 1.0,
             **kwargs) -> None:
         """Add a batch of data into replay buffer."""
         self._weight_sum += np.abs(weight) ** self._alpha - \
@@ -382,8 +382,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         super().add(obs, act, rew, done, obs_next, info, policy)
         self._check_weight_sum()
 
-    def sample(self, batch_size: Optional[int] = 0,
-               importance_sample: Optional[bool] = True
+    def sample(self, batch_size: int,
+               importance_sample: bool = True
                ) -> Tuple[Batch, np.ndarray]:
         """Get a random sample from buffer with priority probability. \
         Return all the data in the buffer if batch_size is ``0``.
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index ce6a301..9a2d604 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -219,8 +219,8 @@ class Collector(object):
         return x
 
     def collect(self,
-                n_step: Optional[int] = 0,
-                n_episode: Optional[Union[int, List[int]]] = 0,
+                n_step: int = 0,
+                n_episode: Union[int, List[int]] = 0,
                 render: Optional[float] = None,
                 log_fn: Optional[Callable[[dict], None]] = None
                 ) -> Dict[str, float]:
diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py
index 385388d..1e83dfc 100644
--- a/tianshou/exploration/random.py
+++ b/tianshou/exploration/random.py
@@ -19,9 +19,9 @@ class OUNoise(object):
     """
 
     def __init__(self,
-                 sigma: Optional[float] = 0.3,
-                 theta: Optional[float] = 0.15,
-                 dt: Optional[float] = 1e-2,
+                 sigma: float = 0.3,
+                 theta: float = 0.15,
+                 dt: float = 1e-2,
                  x0: Optional[Union[float, np.ndarray]] = None
                  ) -> None:
         self.alpha = theta * dt
@@ -29,7 +29,7 @@ class OUNoise(object):
         self.x0 = x0
         self.reset()
 
-    def __call__(self, size: tuple, mu: Optional[float] = .1) -> np.ndarray:
+    def __call__(self, size: tuple, mu: float = .1) -> np.ndarray:
         """Generate new noise. Return a ``numpy.ndarray`` which size is equal
         to ``size``.
         """
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 481f462..737aaf5 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -99,8 +99,8 @@ class BasePolicy(ABC, nn.Module):
     def compute_episodic_return(
             batch: Batch,
             v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
-            gamma: Optional[float] = 0.99,
-            gae_lambda: Optional[float] = 0.95) -> Batch:
+            gamma: float = 0.99,
+            gae_lambda: float = 0.95) -> Batch:
         """Compute returns over given full-length episodes, including the
         implementation of Generalized Advantage Estimation (arXiv:1506.02438).
 
diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py
index 9e4ec97..64c3c7e 100644
--- a/tianshou/policy/imitation/base.py
+++ b/tianshou/policy/imitation/base.py
@@ -23,7 +23,7 @@ class ImitationPolicy(BasePolicy):
     """
 
     def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer,
-                 mode: Optional[str] = 'continuous', **kwargs) -> None:
+                 mode: str = 'continuous', **kwargs) -> None:
         super().__init__()
         self.model = model
         self.optim = optim
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index 3a821b0..e74e19e 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -36,14 +36,14 @@ class A2CPolicy(PGPolicy):
                  actor: torch.nn.Module,
                  critic: torch.nn.Module,
                  optim: torch.optim.Optimizer,
-                 dist_fn: Optional[torch.distributions.Distribution]
+                 dist_fn: torch.distributions.Distribution
                  = torch.distributions.Categorical,
-                 discount_factor: Optional[float] = 0.99,
-                 vf_coef: Optional[float] = .5,
-                 ent_coef: Optional[float] = .01,
+                 discount_factor: float = 0.99,
+                 vf_coef: float = .5,
+                 ent_coef: float = .01,
                  max_grad_norm: Optional[float] = None,
-                 gae_lambda: Optional[float] = 0.95,
-                 reward_normalization: Optional[bool] = False,
+                 gae_lambda: float = 0.95,
+                 reward_normalization: bool = False,
                  **kwargs) -> None:
         super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
         self.actor = actor
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index 60c147e..0a4aa9c 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -41,12 +41,12 @@ class DDPGPolicy(BasePolicy):
                  actor_optim: torch.optim.Optimizer,
                  critic: torch.nn.Module,
                  critic_optim: torch.optim.Optimizer,
-                 tau: Optional[float] = 0.005,
-                 gamma: Optional[float] = 0.99,
-                 exploration_noise: Optional[float] = 0.1,
+                 tau: float = 0.005,
+                 gamma: float = 0.99,
+                 exploration_noise: float = 0.1,
                  action_range: Optional[Tuple[float, float]] = None,
-                 reward_normalization: Optional[bool] = False,
-                 ignore_done: Optional[bool] = False,
+                 reward_normalization: bool = False,
+                 ignore_done: bool = False,
                  **kwargs) -> None:
         super().__init__(**kwargs)
         if actor is not None:
@@ -110,8 +110,8 @@ class DDPGPolicy(BasePolicy):
 
     def forward(self, batch: Batch,
                 state: Optional[Union[dict, Batch, np.ndarray]] = None,
-                model: Optional[str] = 'actor',
-                input: Optional[str] = 'obs',
+                model: str = 'actor',
+                input: str = 'obs',
                 eps: Optional[float] = None,
                 **kwargs) -> Batch:
         """Compute action over the given batch data.
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index 611d9fa..b337e9c 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -29,8 +29,8 @@ class DQNPolicy(BasePolicy):
     def __init__(self,
                  model: torch.nn.Module,
                  optim: torch.optim.Optimizer,
-                 discount_factor: Optional[float] = 0.99,
-                 estimation_step: Optional[int] = 1,
+                 discount_factor: float = 0.99,
+                 estimation_step: int = 1,
                  target_update_freq: Optional[int] = 0,
                  **kwargs) -> None:
         super().__init__(**kwargs)
@@ -124,8 +124,8 @@ class DQNPolicy(BasePolicy):
 
     def forward(self, batch: Batch,
                 state: Optional[Union[dict, Batch, np.ndarray]] = None,
-                model: Optional[str] = 'model',
-                input: Optional[str] = 'obs',
+                model: str = 'model',
+                input: str = 'obs',
                 eps: Optional[float] = None,
                 **kwargs) -> Batch:
         """Compute action over the given batch data.
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 1df44bf..cd0e20a 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -24,10 +24,10 @@ class PGPolicy(BasePolicy):
     def __init__(self,
                  model: torch.nn.Module,
                  optim: torch.optim.Optimizer,
-                 dist_fn: Optional[torch.distributions.Distribution]
+                 dist_fn: torch.distributions.Distribution
                  = torch.distributions.Categorical,
-                 discount_factor: Optional[float] = 0.99,
-                 reward_normalization: Optional[bool] = False,
+                 discount_factor: float = 0.99,
+                 reward_normalization: bool = False,
                  **kwargs) -> None:
         super().__init__(**kwargs)
         self.model = model
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index 9f0a4ab..85cb7e1 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -46,16 +46,16 @@ class PPOPolicy(PGPolicy):
                  critic: torch.nn.Module,
                  optim: torch.optim.Optimizer,
                  dist_fn: torch.distributions.Distribution,
-                 discount_factor: Optional[float] = 0.99,
+                 discount_factor: float = 0.99,
                  max_grad_norm: Optional[float] = None,
-                 eps_clip: Optional[float] = .2,
-                 vf_coef: Optional[float] = .5,
-                 ent_coef: Optional[float] = .01,
+                 eps_clip: float = .2,
+                 vf_coef: float = .5,
+                 ent_coef: float = .01,
                  action_range: Optional[Tuple[float, float]] = None,
-                 gae_lambda: Optional[float] = 0.95,
-                 dual_clip: Optional[float] = 5.,
-                 value_clip: Optional[bool] = True,
-                 reward_normalization: Optional[bool] = True,
+                 gae_lambda: float = 0.95,
+                 dual_clip: float = 5.,
+                 value_clip: bool = True,
+                 reward_normalization: bool = True,
                  **kwargs) -> None:
         super().__init__(None, None, dist_fn, discount_factor, **kwargs)
         self._max_grad_norm = max_grad_norm
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 0bb78d4..d1357a5 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -48,12 +48,12 @@ class SACPolicy(DDPGPolicy):
                  critic1_optim: torch.optim.Optimizer,
                  critic2: torch.nn.Module,
                  critic2_optim: torch.optim.Optimizer,
-                 tau: Optional[float] = 0.005,
-                 gamma: Optional[float] = 0.99,
-                 alpha: Optional[float] = 0.2,
+                 tau: float = 0.005,
+                 gamma: float = 0.99,
+                 alpha: float = 0.2,
                  action_range: Optional[Tuple[float, float]] = None,
-                 reward_normalization: Optional[bool] = False,
-                 ignore_done: Optional[bool] = False,
+                 reward_normalization: bool = False,
+                 ignore_done: bool = False,
                  **kwargs) -> None:
         super().__init__(None, None, None, None, tau, gamma, 0,
                          action_range, reward_normalization, ignore_done,
@@ -90,7 +90,7 @@ class SACPolicy(DDPGPolicy):
 
     def forward(self, batch: Batch, state: Optional[Union[dict,
                 Batch, np.ndarray]] = None,
-                input: Optional[str] = 'obs', **kwargs) -> Batch:
+                input: str = 'obs', **kwargs) -> Batch:
         obs = getattr(batch, input)
         logits, h = self.actor(obs, state=state, info=batch.info)
         assert isinstance(logits, tuple)
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
index 807c661..ee6519a 100644
--- a/tianshou/policy/modelfree/td3.py
+++ b/tianshou/policy/modelfree/td3.py
@@ -53,15 +53,15 @@ class TD3Policy(DDPGPolicy):
                  critic1_optim: torch.optim.Optimizer,
                  critic2: torch.nn.Module,
                  critic2_optim: torch.optim.Optimizer,
-                 tau: Optional[float] = 0.005,
-                 gamma: Optional[float] = 0.99,
-                 exploration_noise: Optional[float] = 0.1,
-                 policy_noise: Optional[float] = 0.2,
-                 update_actor_freq: Optional[int] = 2,
-                 noise_clip: Optional[float] = 0.5,
+                 tau: float = 0.005,
+                 gamma: float = 0.99,
+                 exploration_noise: float = 0.1,
+                 policy_noise: float = 0.2,
+                 update_actor_freq: int = 2,
+                 noise_clip: float = 0.5,
                  action_range: Optional[Tuple[float, float]] = None,
-                 reward_normalization: Optional[bool] = False,
-                 ignore_done: Optional[bool] = False,
+                 reward_normalization: bool = False,
+                 ignore_done: bool = False,
                  **kwargs) -> None:
         super().__init__(actor, actor_optim, None, None, tau, gamma,
                          exploration_noise, action_range, reward_normalization,
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index 5dbdf41..e760a5d 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -24,8 +24,8 @@ def offpolicy_trainer(
         save_fn: Optional[Callable[[BasePolicy], None]] = None,
         log_fn: Optional[Callable[[dict], None]] = None,
         writer: Optional[SummaryWriter] = None,
-        log_interval: Optional[int] = 1,
-        verbose: Optional[bool] = True,
+        log_interval: int = 1,
+        verbose: bool = True,
         **kwargs
 ) -> Dict[str, Union[float, str]]:
     """A wrapper for off-policy trainer procedure.
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index 072ce05..e3849fe 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -25,8 +25,8 @@ def onpolicy_trainer(
         save_fn: Optional[Callable[[BasePolicy], None]] = None,
         log_fn: Optional[Callable[[dict], None]] = None,
         writer: Optional[SummaryWriter] = None,
-        log_interval: Optional[int] = 1,
-        verbose: Optional[bool] = True,
+        log_interval: int = 1,
+        verbose: bool = True,
         **kwargs
 ) -> Dict[str, Union[float, str]]:
     """A wrapper for on-policy trainer procedure.
diff --git a/tianshou/utils/moving_average.py b/tianshou/utils/moving_average.py
index c700f1a..d994762 100644
--- a/tianshou/utils/moving_average.py
+++ b/tianshou/utils/moving_average.py
@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from typing import Union, Optional
+from typing import Union
 
 
 class MovAvg(object):
@@ -21,7 +21,7 @@ class MovAvg(object):
         6.50±1.12
     """
 
-    def __init__(self, size: Optional[int] = 100) -> None:
+    def __init__(self, size: int = 100) -> None:
         super().__init__()
         self.size = size
         self.cache = []
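
The rule this patch applies throughout the code base: ``Optional[X]`` is kept only for
parameters that may legitimately be ``None`` (e.g. ``action_range``, ``max_grad_norm``),
while a parameter that merely has a concrete non-None default is annotated with the bare
type. A minimal sketch of the convention on a hypothetical signature (``collect_example``
is illustrative only, not a Tianshou API):

    from typing import Optional

    # ``n_step`` always holds an int, so the bare ``int`` annotation plus a
    # default is enough; ``render`` may genuinely be None (no rendering), so
    # it keeps ``Optional[float]``.
    def collect_example(n_step: int = 0,
                        render: Optional[float] = None) -> None:
        pass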