fix optional type syntax
This commit is contained in:
parent
3243484f8e
commit
0eef0ca198
@ -20,7 +20,7 @@ else: # pytest
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=3e-4)
|
||||
parser.add_argument('--il-lr', type=float, default=1e-3)
|
||||
@ -48,7 +48,7 @@ def get_args():
|
||||
return args
|
||||
|
||||
|
||||
def test_a2c(args=get_args()):
|
||||
def test_a2c_with_il(args=get_args()):
|
||||
torch.set_num_threads(1) # for poor CPU
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
@ -108,8 +108,8 @@ def test_a2c(args=get_args()):
|
||||
collector.close()
|
||||
|
||||
# here we define an imitation collector with a trivial policy
|
||||
if args.task == 'Pendulum-v0':
|
||||
env.spec.reward_threshold = -300 # lower the goal
|
||||
if args.task == 'CartPole-v0':
|
||||
env.spec.reward_threshold = 190 # lower the goal
|
||||
net = Net(1, args.state_shape, device=args.device)
|
||||
net = Actor(net, args.action_shape).to(args.device)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
|
||||
@ -134,4 +134,4 @@ def test_a2c(args=get_args()):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_a2c()
|
||||
test_a2c_with_il()
|
||||
|
@ -177,7 +177,7 @@ class Batch(object):
|
||||
if k != '_meta' and self.__dict__[k] is not None])
|
||||
|
||||
def split(self, size: Optional[int] = None,
|
||||
shuffle: Optional[bool] = True) -> Iterator['Batch']:
|
||||
shuffle: bool = True) -> Iterator['Batch']:
|
||||
"""Split whole data into multiple small batch.
|
||||
|
||||
:param int size: if it is ``None``, it does not split the data batch;
|
||||
|
@ -96,7 +96,7 @@ class ReplayBuffer(object):
|
||||
"""
|
||||
|
||||
def __init__(self, size: int, stack_num: Optional[int] = 0,
|
||||
ignore_obs_next: Optional[bool] = False, **kwargs) -> None:
|
||||
ignore_obs_next: bool = False, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self._maxsize = size
|
||||
self._stack = stack_num
|
||||
@ -192,7 +192,7 @@ class ReplayBuffer(object):
|
||||
rew: float,
|
||||
done: bool,
|
||||
obs_next: Optional[Union[dict, np.ndarray]] = None,
|
||||
info: Optional[dict] = {},
|
||||
info: dict = {},
|
||||
policy: Optional[Union[dict, Batch]] = {},
|
||||
**kwargs) -> None:
|
||||
"""Add a batch of data into replay buffer."""
|
||||
@ -353,7 +353,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
"""
|
||||
|
||||
def __init__(self, size: int, alpha: float, beta: float,
|
||||
mode: Optional[str] = 'weight', **kwargs) -> None:
|
||||
mode: str = 'weight', **kwargs) -> None:
|
||||
if mode != 'weight':
|
||||
raise NotImplementedError
|
||||
super().__init__(size, **kwargs)
|
||||
@ -370,9 +370,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
rew: float,
|
||||
done: bool,
|
||||
obs_next: Optional[Union[dict, np.ndarray]] = None,
|
||||
info: Optional[dict] = {},
|
||||
info: dict = {},
|
||||
policy: Optional[Union[dict, Batch]] = {},
|
||||
weight: Optional[float] = 1.0,
|
||||
weight: float = 1.0,
|
||||
**kwargs) -> None:
|
||||
"""Add a batch of data into replay buffer."""
|
||||
self._weight_sum += np.abs(weight) ** self._alpha - \
|
||||
@ -382,8 +382,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
super().add(obs, act, rew, done, obs_next, info, policy)
|
||||
self._check_weight_sum()
|
||||
|
||||
def sample(self, batch_size: Optional[int] = 0,
|
||||
importance_sample: Optional[bool] = True
|
||||
def sample(self, batch_size: int,
|
||||
importance_sample: bool = True
|
||||
) -> Tuple[Batch, np.ndarray]:
|
||||
"""Get a random sample from buffer with priority probability. \
|
||||
Return all the data in the buffer if batch_size is ``0``.
|
||||
|
@ -219,8 +219,8 @@ class Collector(object):
|
||||
return x
|
||||
|
||||
def collect(self,
|
||||
n_step: Optional[int] = 0,
|
||||
n_episode: Optional[Union[int, List[int]]] = 0,
|
||||
n_step: int = 0,
|
||||
n_episode: Union[int, List[int]] = 0,
|
||||
render: Optional[float] = None,
|
||||
log_fn: Optional[Callable[[dict], None]] = None
|
||||
) -> Dict[str, float]:
|
||||
|
@ -19,9 +19,9 @@ class OUNoise(object):
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
sigma: Optional[float] = 0.3,
|
||||
theta: Optional[float] = 0.15,
|
||||
dt: Optional[float] = 1e-2,
|
||||
sigma: float = 0.3,
|
||||
theta: float = 0.15,
|
||||
dt: float = 1e-2,
|
||||
x0: Optional[Union[float, np.ndarray]] = None
|
||||
) -> None:
|
||||
self.alpha = theta * dt
|
||||
@ -29,7 +29,7 @@ class OUNoise(object):
|
||||
self.x0 = x0
|
||||
self.reset()
|
||||
|
||||
def __call__(self, size: tuple, mu: Optional[float] = .1) -> np.ndarray:
|
||||
def __call__(self, size: tuple, mu: float = .1) -> np.ndarray:
|
||||
"""Generate new noise. Return a ``numpy.ndarray`` which size is equal
|
||||
to ``size``.
|
||||
"""
|
||||
|
@ -99,8 +99,8 @@ class BasePolicy(ABC, nn.Module):
|
||||
def compute_episodic_return(
|
||||
batch: Batch,
|
||||
v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None,
|
||||
gamma: Optional[float] = 0.99,
|
||||
gae_lambda: Optional[float] = 0.95) -> Batch:
|
||||
gamma: float = 0.99,
|
||||
gae_lambda: float = 0.95) -> Batch:
|
||||
"""Compute returns over given full-length episodes, including the
|
||||
implementation of Generalized Advantage Estimation (arXiv:1506.02438).
|
||||
|
||||
|
@ -23,7 +23,7 @@ class ImitationPolicy(BasePolicy):
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module, optim: torch.optim.Optimizer,
|
||||
mode: Optional[str] = 'continuous', **kwargs) -> None:
|
||||
mode: str = 'continuous', **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.optim = optim
|
||||
|
@ -36,14 +36,14 @@ class A2CPolicy(PGPolicy):
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: Optional[torch.distributions.Distribution]
|
||||
dist_fn: torch.distributions.Distribution
|
||||
= torch.distributions.Categorical,
|
||||
discount_factor: Optional[float] = 0.99,
|
||||
vf_coef: Optional[float] = .5,
|
||||
ent_coef: Optional[float] = .01,
|
||||
discount_factor: float = 0.99,
|
||||
vf_coef: float = .5,
|
||||
ent_coef: float = .01,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
gae_lambda: Optional[float] = 0.95,
|
||||
reward_normalization: Optional[bool] = False,
|
||||
gae_lambda: float = 0.95,
|
||||
reward_normalization: bool = False,
|
||||
**kwargs) -> None:
|
||||
super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
|
||||
self.actor = actor
|
||||
|
@ -41,12 +41,12 @@ class DDPGPolicy(BasePolicy):
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critic: torch.nn.Module,
|
||||
critic_optim: torch.optim.Optimizer,
|
||||
tau: Optional[float] = 0.005,
|
||||
gamma: Optional[float] = 0.99,
|
||||
exploration_noise: Optional[float] = 0.1,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
exploration_noise: float = 0.1,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
reward_normalization: Optional[bool] = False,
|
||||
ignore_done: Optional[bool] = False,
|
||||
reward_normalization: bool = False,
|
||||
ignore_done: bool = False,
|
||||
**kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
if actor is not None:
|
||||
@ -110,8 +110,8 @@ class DDPGPolicy(BasePolicy):
|
||||
|
||||
def forward(self, batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
model: Optional[str] = 'actor',
|
||||
input: Optional[str] = 'obs',
|
||||
model: str = 'actor',
|
||||
input: str = 'obs',
|
||||
eps: Optional[float] = None,
|
||||
**kwargs) -> Batch:
|
||||
"""Compute action over the given batch data.
|
||||
|
@ -29,8 +29,8 @@ class DQNPolicy(BasePolicy):
|
||||
def __init__(self,
|
||||
model: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
discount_factor: Optional[float] = 0.99,
|
||||
estimation_step: Optional[int] = 1,
|
||||
discount_factor: float = 0.99,
|
||||
estimation_step: int = 1,
|
||||
target_update_freq: Optional[int] = 0,
|
||||
**kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
@ -124,8 +124,8 @@ class DQNPolicy(BasePolicy):
|
||||
|
||||
def forward(self, batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
model: Optional[str] = 'model',
|
||||
input: Optional[str] = 'obs',
|
||||
model: str = 'model',
|
||||
input: str = 'obs',
|
||||
eps: Optional[float] = None,
|
||||
**kwargs) -> Batch:
|
||||
"""Compute action over the given batch data.
|
||||
|
@ -24,10 +24,10 @@ class PGPolicy(BasePolicy):
|
||||
def __init__(self,
|
||||
model: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: Optional[torch.distributions.Distribution]
|
||||
dist_fn: torch.distributions.Distribution
|
||||
= torch.distributions.Categorical,
|
||||
discount_factor: Optional[float] = 0.99,
|
||||
reward_normalization: Optional[bool] = False,
|
||||
discount_factor: float = 0.99,
|
||||
reward_normalization: bool = False,
|
||||
**kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.model = model
|
||||
|
@ -46,16 +46,16 @@ class PPOPolicy(PGPolicy):
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: torch.distributions.Distribution,
|
||||
discount_factor: Optional[float] = 0.99,
|
||||
discount_factor: float = 0.99,
|
||||
max_grad_norm: Optional[float] = None,
|
||||
eps_clip: Optional[float] = .2,
|
||||
vf_coef: Optional[float] = .5,
|
||||
ent_coef: Optional[float] = .01,
|
||||
eps_clip: float = .2,
|
||||
vf_coef: float = .5,
|
||||
ent_coef: float = .01,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
gae_lambda: Optional[float] = 0.95,
|
||||
dual_clip: Optional[float] = 5.,
|
||||
value_clip: Optional[bool] = True,
|
||||
reward_normalization: Optional[bool] = True,
|
||||
gae_lambda: float = 0.95,
|
||||
dual_clip: float = 5.,
|
||||
value_clip: bool = True,
|
||||
reward_normalization: bool = True,
|
||||
**kwargs) -> None:
|
||||
super().__init__(None, None, dist_fn, discount_factor, **kwargs)
|
||||
self._max_grad_norm = max_grad_norm
|
||||
|
@ -48,12 +48,12 @@ class SACPolicy(DDPGPolicy):
|
||||
critic1_optim: torch.optim.Optimizer,
|
||||
critic2: torch.nn.Module,
|
||||
critic2_optim: torch.optim.Optimizer,
|
||||
tau: Optional[float] = 0.005,
|
||||
gamma: Optional[float] = 0.99,
|
||||
alpha: Optional[float] = 0.2,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
alpha: float = 0.2,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
reward_normalization: Optional[bool] = False,
|
||||
ignore_done: Optional[bool] = False,
|
||||
reward_normalization: bool = False,
|
||||
ignore_done: bool = False,
|
||||
**kwargs) -> None:
|
||||
super().__init__(None, None, None, None, tau, gamma, 0,
|
||||
action_range, reward_normalization, ignore_done,
|
||||
@ -90,7 +90,7 @@ class SACPolicy(DDPGPolicy):
|
||||
|
||||
def forward(self, batch: Batch,
|
||||
state: Optional[Union[dict, Batch, np.ndarray]] = None,
|
||||
input: Optional[str] = 'obs', **kwargs) -> Batch:
|
||||
input: str = 'obs', **kwargs) -> Batch:
|
||||
obs = getattr(batch, input)
|
||||
logits, h = self.actor(obs, state=state, info=batch.info)
|
||||
assert isinstance(logits, tuple)
|
||||
|
@ -53,15 +53,15 @@ class TD3Policy(DDPGPolicy):
|
||||
critic1_optim: torch.optim.Optimizer,
|
||||
critic2: torch.nn.Module,
|
||||
critic2_optim: torch.optim.Optimizer,
|
||||
tau: Optional[float] = 0.005,
|
||||
gamma: Optional[float] = 0.99,
|
||||
exploration_noise: Optional[float] = 0.1,
|
||||
policy_noise: Optional[float] = 0.2,
|
||||
update_actor_freq: Optional[int] = 2,
|
||||
noise_clip: Optional[float] = 0.5,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
exploration_noise: float = 0.1,
|
||||
policy_noise: float = 0.2,
|
||||
update_actor_freq: int = 2,
|
||||
noise_clip: float = 0.5,
|
||||
action_range: Optional[Tuple[float, float]] = None,
|
||||
reward_normalization: Optional[bool] = False,
|
||||
ignore_done: Optional[bool] = False,
|
||||
reward_normalization: bool = False,
|
||||
ignore_done: bool = False,
|
||||
**kwargs) -> None:
|
||||
super().__init__(actor, actor_optim, None, None, tau, gamma,
|
||||
exploration_noise, action_range, reward_normalization,
|
||||
|
@ -24,8 +24,8 @@ def offpolicy_trainer(
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
log_fn: Optional[Callable[[dict], None]] = None,
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
log_interval: Optional[int] = 1,
|
||||
verbose: Optional[bool] = True,
|
||||
log_interval: int = 1,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> Dict[str, Union[float, str]]:
|
||||
"""A wrapper for off-policy trainer procedure.
|
||||
|
@ -25,8 +25,8 @@ def onpolicy_trainer(
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
log_fn: Optional[Callable[[dict], None]] = None,
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
log_interval: Optional[int] = 1,
|
||||
verbose: Optional[bool] = True,
|
||||
log_interval: int = 1,
|
||||
verbose: bool = True,
|
||||
**kwargs
|
||||
) -> Dict[str, Union[float, str]]:
|
||||
"""A wrapper for on-policy trainer procedure.
|
||||
|
@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Union, Optional
|
||||
from typing import Union
|
||||
|
||||
|
||||
class MovAvg(object):
|
||||
@ -21,7 +21,7 @@ class MovAvg(object):
|
||||
6.50±1.12
|
||||
"""
|
||||
|
||||
def __init__(self, size: Optional[int] = 100) -> None:
|
||||
def __init__(self, size: int = 100) -> None:
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.cache = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user