update utils.network (#275)

This is the first of the 6 commits mentioned in #274. It:

1. Refactors the `Net` class to support any form of MLP.
2. Enables type checking in utils.network.
3. Makes the corresponding changes in docs/tests/examples.
4. Moves the Atari-related networks to examples/atari/atari_network.py.

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
ChenDRAG 2021-01-20 16:54:13 +08:00 committed by GitHub
parent 866e35d550
commit a633a6a028
32 changed files with 655 additions and 413 deletions
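
For quick reference, a minimal before/after sketch of the `Net` signature change that drives most of the diff below (the CartPole-like shapes and the hidden sizes are illustrative assumptions, not taken from any single file):

```python
from tianshou.utils.net.common import Net

# illustrative shapes, assumed for this sketch only
state_shape, action_shape = (4,), (2,)

# before this commit: depth via layer_num, width via a single hidden_layer_size
# net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape)

# after this commit: the full hidden layout is an explicit list of sizes
net = Net(state_shape, action_shape, hidden_sizes=[128, 128, 128])
```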

View File

@ -213,7 +213,7 @@ from tianshou.utils.net.common import Net
env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape)
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
optim = torch.optim.Adam(net.parameters(), lr=lr)
```

View File

@ -203,7 +203,8 @@ The explanation of each Tianshou class/function will be deferred to their first
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=3)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -249,7 +250,8 @@ Here it is:
args.action_shape = env.action_space.shape or env.action_space.n
if agent_learn is None:
net = Net(args.layer_num, args.state_shape, args.action_shape, args.device).to(args.device)
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
if optim is None:
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
agent_learn = DQNPolicy(

View File

@ -7,10 +7,10 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import C51Policy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.discrete import C51
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from atari_network import C51
from atari_wrapper import wrap_deepmind
@ -40,8 +40,8 @@ def get_args():
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--frames_stack', type=int, default=4)
parser.add_argument('--resume_path', type=str, default=None)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()

View File

@ -7,10 +7,10 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.discrete import DQN
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from atari_network import DQN
from atari_wrapper import wrap_deepmind
@ -37,8 +37,8 @@ def get_args():
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--frames_stack', type=int, default=4)
parser.add_argument('--resume_path', type=str, default=None)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()

View File

@ -0,0 +1,82 @@
import torch
import numpy as np
from torch import nn
from typing import Any, Dict, Tuple, Union, Optional, Sequence
class DQN(nn.Module):
"""Reference: Human-level control through deep reinforcement learning.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
def __init__(
self,
c: int,
h: int,
w: int,
action_shape: Sequence[int],
device: Union[str, int, torch.device] = "cpu",
features_only: bool = False,
) -> None:
super().__init__()
self.device = device
self.net = nn.Sequential(
nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
nn.Flatten())
with torch.no_grad():
self.output_dim = np.prod(
self.net(torch.zeros(1, c, h, w)).shape[1:])
if not features_only:
self.net = nn.Sequential(
self.net,
nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True),
nn.Linear(512, np.prod(action_shape)))
self.output_dim = np.prod(action_shape)
def forward(
self,
x: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Q(x, \*)."""
x = torch.as_tensor(
x, device=self.device, dtype=torch.float32) # type: ignore
return self.net(x), state
class C51(DQN):
"""Reference: A distributional perspective on reinforcement learning.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
def __init__(
self,
c: int,
h: int,
w: int,
action_shape: Sequence[int],
num_atoms: int = 51,
device: Union[str, int, torch.device] = "cpu",
) -> None:
super().__init__(c, h, w, [np.prod(action_shape) * num_atoms], device)
self.action_shape = action_shape
self.num_atoms = num_atoms
def forward(
self,
x: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x)
x = x.view(-1, self.num_atoms).softmax(dim=-1)
x = x.view(-1, np.prod(self.action_shape), self.num_atoms)
return x, state
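
A short sanity-check sketch of how the `DQN` and `C51` classes added above behave, assuming the usual 4x84x84 stacked-frame Atari observation and a 6-action game (both shapes are assumptions for illustration; run from examples/atari so `atari_network` is importable):

```python
import numpy as np

from atari_network import C51, DQN  # the module added in this file

c, h, w, action_shape = 4, 84, 84, (6,)
obs = np.random.rand(16, c, h, w)   # fake batch of 16 stacked-frame observations

q_net = DQN(c, h, w, action_shape)
q_values, _ = q_net(obs)            # Q(x, *): shape [16, 6]

dist_net = C51(c, h, w, action_shape, num_atoms=51)
dist, _ = dist_net(obs)             # Z(x, *): shape [16, 6, 51], softmax over atoms
```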

View File

@ -7,10 +7,10 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import A2CPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.discrete import Actor, Critic
from tianshou.utils.net.common import Net
from atari import create_atari_environment, preprocess_fn
@ -27,7 +27,8 @@ def get_args():
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--repeat-per-collect', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=2)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=8)
parser.add_argument('--logdir', type=str, default='log')
@ -40,7 +41,7 @@ def get_args():
parser.add_argument('--vf-coef', type=float, default=0.5)
parser.add_argument('--ent-coef', type=float, default=0.001)
parser.add_argument('--max-grad-norm', type=float, default=None)
parser.add_argument('--max_episode_steps', type=int, default=2000)
parser.add_argument('--max-episode-steps', type=int, default=2000)
return parser.parse_args()
@ -62,11 +63,12 @@ def test_a2c(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape).to(args.device)
critic = Critic(net).to(args.device)
optim = torch.optim.Adam(list(
actor.parameters()) + list(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = A2CPolicy(
actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef,

View File

@ -7,10 +7,10 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PPOPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.utils.net.discrete import Actor, Critic
from tianshou.utils.net.common import Net
from atari import create_atari_environment, preprocess_fn
@ -27,7 +27,8 @@ def get_args():
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=8)
parser.add_argument('--logdir', type=str, default='log')
@ -40,7 +41,7 @@ def get_args():
parser.add_argument('--ent-coef', type=float, default=0.0)
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--max_episode_steps', type=int, default=2000)
parser.add_argument('--max-episode-steps', type=int, default=2000)
return parser.parse_args()
@ -62,11 +63,12 @@ def test_ppo(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape).to(args.device)
critic = Critic(net).to(args.device)
optim = torch.optim.Adam(list(
actor.parameters()) + list(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = PPOPolicy(
actor, critic, optim, dist, args.gamma,

View File

@ -28,7 +28,12 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=100)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=0)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128])
parser.add_argument('--dueling-q-hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--dueling-v-hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -56,8 +61,11 @@ def test_dqn(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape,
args.action_shape, args.device, dueling=(2, 2)).to(args.device)
Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes}
V_param = {"hidden_sizes": args.dueling_v_hidden_sizes}
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device,
dueling_param=(Q_param, V_param)).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
net, optim, args.gamma, args.n_step,

View File

@ -30,7 +30,8 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=10000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -87,20 +88,23 @@ def test_sac_bipedal(args=get_args()):
test_envs.seed(args.seed)
# model
net_a = Net(args.layer_num, args.state_shape, device=args.device)
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = ActorProb(
net_a, args.action_shape, args.max_action, args.device, unbounded=True
).to(args.device)
net_a, args.action_shape, max_action=args.max_action,
device=args.device, unbounded=True).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c1 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net_c1, args.device).to(args.device)
net_c1 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net_c2 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net_c2, args.device).to(args.device)
net_c2 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
if args.auto_alpha:

View File

@ -29,7 +29,12 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=5000)
parser.add_argument('--collect-per-step', type=int, default=16)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--dueling-q-hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--dueling-v-hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -57,9 +62,11 @@ def test_dqn(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape,
args.action_shape, args.device,
dueling=(2, 2)).to(args.device)
Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes}
V_param = {"hidden_sizes": args.dueling_v_hidden_sizes}
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device,
dueling_param=(Q_param, V_param)).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
net, optim, args.gamma, args.n_step,

View File

@ -32,7 +32,8 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=5)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -61,19 +62,22 @@ def test_sac(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = ActorProb(
net, args.action_shape,
args.max_action, args.device, unbounded=True
max_action=args.max_action, device=args.device, unbounded=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c1 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net_c1, args.device).to(args.device)
net_c1 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, concat=True,
device=args.device)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net_c2 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net_c2, args.device).to(args.device)
net_c2 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, concat=True,
device=args.device)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
if args.auto_alpha:

View File

@ -33,8 +33,8 @@ def get_args():
parser.add_argument('--update-per-step', type=int, default=1)
parser.add_argument('--pre-collect-step', type=int, default=10000)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--hidden-layer-size', type=int, default=256)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -70,26 +70,22 @@ def test_sac(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device,
hidden_layer_size=args.hidden_layer_size)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = ActorProb(
net, args.action_shape, args.max_action, args.device, unbounded=True,
hidden_layer_size=args.hidden_layer_size, conditioned_sigma=True,
net, args.action_shape, max_action=args.max_action,
device=args.device, unbounded=True, conditioned_sigma=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c1 = Net(args.layer_num, args.state_shape, args.action_shape,
concat=True, device=args.device,
hidden_layer_size=args.hidden_layer_size)
critic1 = Critic(
net_c1, args.device, hidden_layer_size=args.hidden_layer_size
).to(args.device)
net_c1 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net_c2 = Net(args.layer_num, args.state_shape, args.action_shape,
concat=True, device=args.device,
hidden_layer_size=args.hidden_layer_size)
critic2 = Critic(
net_c2, args.device, hidden_layer_size=args.hidden_layer_size
).to(args.device)
net_c2 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
if args.auto_alpha:

View File

@ -28,7 +28,8 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=4)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -56,13 +57,14 @@ def test_ddpg(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
actor = Actor(net, args.action_shape, args.max_action,
args.device).to(args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape, max_action=args.max_action,
device=args.device).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic = Critic(net, args.device).to(args.device)
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
critic = Critic(net, device=args.device).to(args.device)
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy = DDPGPolicy(
actor, actor_optim, critic, critic_optim,

View File

@ -31,7 +31,8 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -59,17 +60,16 @@ def test_td3(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
actor = Actor(
net, args.action_shape,
args.max_action, args.device
).to(args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape, max_action=args.max_action,
device=args.device).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
critic1 = Critic(net, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net, args.device).to(args.device)
critic2 = Critic(net, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,

View File

@ -30,7 +30,8 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=4)
parser.add_argument('--logdir', type=str, default='log')
@ -61,19 +62,18 @@ def test_sac(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
actor = ActorProb(
net, args.action_shape,
args.max_action, args.device, unbounded=True
).to(args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = ActorProb(net, args.action_shape, max_action=args.max_action,
device=args.device, unbounded=True).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
critic1 = Critic(net, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net, args.device).to(args.device)
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
critic2 = Critic(net, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,

View File

@ -33,7 +33,8 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -62,19 +63,18 @@ def test_td3(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
actor = Actor(
net, args.action_shape,
args.max_action, args.device
).to(args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape, max_action=args.max_action,
device=args.device).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net, args.device).to(args.device)
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
critic1 = Critic(net, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net, args.device).to(args.device)
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
critic2 = Critic(net, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,

View File

@ -18,9 +18,6 @@ warn_unreachable = True
warn_unused_configs = True
warn_unused_ignores = True
[mypy-tianshou.utils.net.*]
ignore_errors = True
[pydocstyle]
ignore = D100,D102,D104,D105,D107,D203,D213,D401,D402

View File

@ -3,8 +3,7 @@ import numpy as np
from tianshou.utils import MovAvg
from tianshou.utils import SummaryWriter
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import DQN, C51
from tianshou.utils.net.common import MLP, Net
from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
@ -35,17 +34,39 @@ def test_moving_average():
def test_net():
# here test the networks that do not appear in the other scripts
bsz = 64
# MLP
data = torch.rand([bsz, 3])
mlp = MLP(3, 6, hidden_sizes=[128])
assert list(mlp(data).shape) == [bsz, 6]
# output == 0 and len(hidden_sizes) == 0 means identity model
mlp = MLP(6, 0)
assert data.shape == mlp(data).shape
# common net
state_shape = (10, 2)
action_shape = (5, )
data = torch.rand([bsz, *state_shape])
expect_output_shape = [bsz, *action_shape]
net = Net(3, state_shape, action_shape, norm_layer=torch.nn.LayerNorm)
net = Net(state_shape, action_shape, hidden_sizes=[128, 128],
norm_layer=torch.nn.LayerNorm, activation=None)
assert list(net(data)[0].shape) == expect_output_shape
net = Net(3, state_shape, action_shape, dueling=(2, 2))
assert str(net).count("LayerNorm") == 2
assert str(net).count("ReLU") == 0
Q_param = V_param = {"hidden_sizes": [128, 128]}
net = Net(state_shape, action_shape, hidden_sizes=[128, 128],
dueling_param=(Q_param, V_param))
assert list(net(data)[0].shape) == expect_output_shape
# concat
net = Net(state_shape, action_shape, hidden_sizes=[128],
concat=True)
data = torch.rand([bsz, np.prod(state_shape) + np.prod(action_shape)])
expect_output_shape = [bsz, 128]
assert list(net(data)[0].shape) == expect_output_shape
net = Net(state_shape, action_shape, hidden_sizes=[128],
concat=True, dueling_param=(Q_param, V_param))
assert list(net(data)[0].shape) == expect_output_shape
# recurrent actor/critic
data = data.flatten(1)
data = torch.rand([bsz, *state_shape]).flatten(1)
expect_output_shape = [bsz, *action_shape]
net = RecurrentActorProb(3, state_shape, action_shape)
mu, sigma = net(data)[0]
assert mu.shape == sigma.shape
@ -54,17 +75,6 @@ def test_net():
data = torch.rand([bsz, 8, np.prod(state_shape)])
act = torch.rand(expect_output_shape)
assert list(net(data, act).shape) == [bsz, 1]
# DQN
state_shape = (4, 84, 84)
action_shape = (6, )
data = np.random.rand(bsz, *state_shape)
expect_output_shape = [bsz, *action_shape]
net = DQN(*state_shape, action_shape)
assert list(net(data)[0].shape) == expect_output_shape
num_atoms = 51
net = C51(*state_shape, action_shape, num_atoms)
expect_output_shape = [bsz, *action_shape, num_atoms]
assert list(net(data)[0].shape) == expect_output_shape
def test_summary_writer():

View File

@ -30,7 +30,8 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=4)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -66,15 +67,14 @@ def test_ddpg(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
actor = Actor(
net, args.action_shape,
args.max_action, args.device
).to(args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape, max_action=args.max_action,
device=args.device).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic = Critic(net, args.device).to(args.device)
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
critic = Critic(net, device=args.device).to(args.device)
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy = DDPGPolicy(
actor, actor_optim, critic, critic_optim,

View File

@ -27,7 +27,8 @@ def get_args():
parser.add_argument('--collect-per-step', type=int, default=1)
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -69,21 +70,20 @@ def test_ppo(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
actor = ActorProb(
net, args.action_shape,
args.max_action, args.device
).to(args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = ActorProb(net, args.action_shape, max_action=args.max_action,
device=args.device).to(args.device)
critic = Critic(Net(
args.layer_num, args.state_shape, device=args.device
args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device
), device=args.device).to(args.device)
# orthogonal initialization
for m in list(actor.modules()) + list(critic.modules()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(list(
actor.parameters()) + list(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward

View File

@ -29,7 +29,10 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--imitation-hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -65,18 +68,20 @@ def test_sac_with_il(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
actor = ActorProb(
net, args.action_shape, args.max_action, args.device, unbounded=True
).to(args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = ActorProb(net, args.action_shape, max_action=args.max_action,
device=args.device, unbounded=True).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c1 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net_c1, args.device).to(args.device)
net_c1 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net_c2 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net_c2, args.device).to(args.device)
net_c2 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = SACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
@ -120,8 +125,9 @@ def test_sac_with_il(args=get_args()):
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -300 # lower the goal
net = Actor(
Net(1, args.state_shape, device=args.device),
args.action_shape, args.max_action, args.device
Net(args.state_shape, hidden_sizes=args.imitation_hidden_sizes,
device=args.device),
args.action_shape, max_action=args.max_action, device=args.device
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
il_policy = ImitationPolicy(net, optim, mode='continuous')

View File

@ -32,7 +32,8 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -68,19 +69,20 @@ def test_td3(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
actor = Actor(
net, args.action_shape,
args.max_action, args.device
).to(args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape, max_action=args.max_action,
device=args.device).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c1 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic1 = Critic(net_c1, args.device).to(args.device)
net_c1 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net_c2 = Net(args.layer_num, args.state_shape,
args.action_shape, concat=True, device=args.device)
critic2 = Critic(net_c2, args.device).to(args.device)
net_c2 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
policy = TD3Policy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,

View File

@ -27,7 +27,10 @@ def get_args():
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--repeat-per-collect', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=2)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128])
parser.add_argument('--imitation-hidden-sizes', type=int,
nargs='*', default=[128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -63,11 +66,12 @@ def test_a2c_with_il(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape).to(args.device)
critic = Critic(net).to(args.device)
optim = torch.optim.Adam(list(
actor.parameters()) + list(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = A2CPolicy(
actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
@ -107,7 +111,8 @@ def test_a2c_with_il(args=get_args()):
# here we define an imitation collector with a trivial policy
if args.task == 'CartPole-v0':
env.spec.reward_threshold = 190 # lower the goal
net = Net(1, args.state_shape, device=args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
net = Actor(net, args.action_shape).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
il_policy = ImitationPolicy(net, optim, mode='discrete')

View File

@ -31,12 +31,14 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=3)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--prioritized-replay', type=int, default=0)
parser.add_argument('--prioritized-replay',
action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
@ -63,7 +65,8 @@ def test_c51(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, args.action_shape, args.device,
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device,
softmax=True, num_atoms=args.num_atoms)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = C51Policy(
@ -71,7 +74,7 @@ def test_c51(args=get_args()):
args.n_step, target_update_freq=args.target_update_freq
).to(args.device)
# buffer
if args.prioritized_replay > 0:
if args.prioritized_replay:
buf = PrioritizedReplayBuffer(
args.buffer_size, alpha=args.alpha, beta=args.beta)
else:
@ -125,7 +128,7 @@ def test_c51(args=get_args()):
def test_pc51(args=get_args()):
args.prioritized_replay = 1
args.prioritized_replay = True
args.gamma = .95
args.seed = 1
test_c51(args)

View File

@ -28,12 +28,14 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=3)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--prioritized-replay', type=int, default=0)
parser.add_argument('--prioritized-replay',
action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
@ -59,16 +61,18 @@ def test_dqn(args=get_args()):
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# Q_param = V_param = {"hidden_sizes": [128]}
# model
net = Net(args.layer_num, args.state_shape,
args.action_shape, args.device, # dueling=(1, 1)
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device,
# dueling=(Q_param, V_param),
).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(
net, optim, args.gamma, args.n_step,
target_update_freq=args.target_update_freq)
# buffer
if args.prioritized_replay > 0:
if args.prioritized_replay:
buf = PrioritizedReplayBuffer(
args.buffer_size, alpha=args.alpha, beta=args.beta)
else:
@ -122,7 +126,7 @@ def test_dqn(args=get_args()):
def test_pdqn(args=get_args()):
args.prioritized_replay = 1
args.prioritized_replay = True
args.gamma = .95
args.seed = 1
test_dqn(args)

View File

@ -25,7 +25,8 @@ def get_args():
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=3)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -55,8 +56,8 @@ def test_pg(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(
args.layer_num, args.state_shape, args.action_shape,
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device, softmax=True).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
dist = torch.distributions.Categorical

View File

@ -26,7 +26,8 @@ def get_args():
parser.add_argument('--collect-per-step', type=int, default=20)
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=20)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -65,7 +66,8 @@ def test_ppo(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape).to(args.device)
critic = Critic(net).to(args.device)
# orthogonal initialization
@ -73,8 +75,8 @@ def test_ppo(args=get_args()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(list(
actor.parameters()) + list(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = PPOPolicy(
actor, critic, optim, dist, args.gamma,

View File

@ -30,7 +30,8 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=5)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128])
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -59,13 +60,16 @@ def test_discrete_sac(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, device=args.device)
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
actor = Actor(net, args.action_shape, softmax_output=False).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c1 = Net(args.layer_num, args.state_shape, device=args.device)
net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
critic1 = Critic(net_c1, last_size=args.action_shape).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net_c2 = Net(args.layer_num, args.state_shape, device=args.device)
net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
device=args.device)
critic2 = Critic(net_c2, last_size=args.action_shape).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

View File

@ -31,7 +31,8 @@ def get_parser() -> argparse.ArgumentParser:
parser.add_argument('--step-per-epoch', type=int, default=500)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=3)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
@ -75,8 +76,9 @@ def get_agents(
args.action_shape = env.action_space.shape or env.action_space.n
if agent_learn is None:
# model
net = Net(args.layer_num, args.state_shape, args.action_shape,
args.device).to(args.device)
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device
).to(args.device)
if optim is None:
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
agent_learn = DQNPolicy(

View File

@ -1,93 +1,180 @@
import torch
import numpy as np
from torch import nn
from typing import Any, Dict, List, Tuple, Union, Callable, Optional, Sequence
from typing import Any, Dict, List, Type, Tuple, Union, Optional, Sequence
from tianshou.data import to_torch
ModuleType = Type[nn.Module]
def miniblock(
inp: int,
oup: int,
norm_layer: Optional[Callable[[int], nn.modules.Module]],
) -> List[nn.modules.Module]:
"""Construct a miniblock with given input/output-size and norm layer."""
ret: List[nn.modules.Module] = [nn.Linear(inp, oup)]
input_size: int,
output_size: int = 0,
norm_layer: Optional[ModuleType] = None,
activation: Optional[ModuleType] = None,
) -> List[nn.Module]:
"""Construct a miniblock with given input/output-size, norm layer and \
activation."""
layers: List[nn.Module] = [nn.Linear(input_size, output_size)]
if norm_layer is not None:
ret += [norm_layer(oup)]
ret += [nn.ReLU(inplace=True)]
return ret
layers += [norm_layer(output_size)] # type: ignore
if activation is not None:
layers += [activation()]
return layers
class Net(nn.Module):
class MLP(nn.Module):
"""Simple MLP backbone.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
Create an MLP of size input_dim * hidden_sizes[0] * hidden_sizes[1] * ...
* hidden_sizes[-1] * output_dim.
:param bool concat: whether the input shape is concatenated by state_shape
and action_shape. If it is True, ``action_shape`` is not the output
shape, but affects the input shape.
:param bool dueling: whether to use dueling network to calculate Q values
(for Dueling DQN), defaults to False.
:param norm_layer: use which normalization before ReLU, e.g.,
``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None.
:param int num_atoms: in order to expand to the net of distributional RL,
defaults to 1.
:param int input_dim: dimension of the input vector.
:param int output_dim: dimension of the output vector. If set to 0, there
is no final linear layer.
:param hidden_sizes: shape of MLP passed in as a list, not including
input_dim and output_dim.
:param norm_layer: which normalization to use before activation, e.g.,
``nn.LayerNorm`` and ``nn.BatchNorm1d``. You can also pass a list of
normalization modules with the same length as hidden_sizes to use
different normalization modules in different layers. Default to no
normalization.
:param activation: which activation to use after each layer. It can be
the same activation for all layers if a single nn.Module is passed, or
a different activation per layer if a list is passed. Default to
nn.ReLU.
"""
def __init__(
self,
layer_num: int,
state_shape: tuple,
action_shape: Optional[Union[tuple, int]] = 0,
device: Union[str, int, torch.device] = "cpu",
softmax: bool = False,
concat: bool = False,
hidden_layer_size: int = 128,
dueling: Optional[Tuple[int, int]] = None,
norm_layer: Optional[Callable[[int], nn.modules.Module]] = None,
num_atoms: int = 1,
input_dim: int,
output_dim: int = 0,
hidden_sizes: Sequence[int] = (),
norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
activation: Optional[Union[ModuleType, Sequence[ModuleType]]]
= nn.ReLU,
device: Optional[Union[str, int, torch.device]] = None,
) -> None:
super().__init__()
self.device = device
if norm_layer:
if isinstance(norm_layer, list):
assert len(norm_layer) == len(hidden_sizes)
norm_layer_list = norm_layer
else:
norm_layer_list = [
norm_layer for _ in range(len(hidden_sizes))]
else:
norm_layer_list = [None] * len(hidden_sizes)
if activation:
if isinstance(activation, list):
assert len(activation) == len(hidden_sizes)
activation_list = activation
else:
activation_list = [
activation for _ in range(len(hidden_sizes))]
else:
activation_list = [None] * len(hidden_sizes)
hidden_sizes = [input_dim] + list(hidden_sizes)
model = []
for in_dim, out_dim, norm, activ in zip(
hidden_sizes[:-1], hidden_sizes[1:],
norm_layer_list, activation_list):
model += miniblock(in_dim, out_dim, norm, activ)
if output_dim > 0:
model += [nn.Linear(hidden_sizes[-1], output_dim)]
self.output_dim = output_dim or hidden_sizes[-1]
self.model = nn.Sequential(*model)
def forward(
self, x: Union[np.ndarray, torch.Tensor]
) -> torch.Tensor:
x = torch.as_tensor(
x, device=self.device, dtype=torch.float32) # type: ignore
return self.model(x.flatten(1))
class Net(nn.Module):
"""Wrapper of MLP to support more specific DRL usage.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
:param state_shape: int or a sequence of int of the shape of state.
:param action_shape: int or a sequence of int of the shape of action.
:param hidden_sizes: shape of MLP passed in as a list.
:param norm_layer: which normalization to use before activation, e.g.,
``nn.LayerNorm`` and ``nn.BatchNorm1d``. You can also pass a list of
normalization modules with the same length as hidden_sizes to use
different normalization modules in different layers. Default to no
normalization.
:param activation: which activation to use after each layer. It can be
the same activation for all layers if a single nn.Module is passed, or
a different activation per layer if a list is passed. Default to
nn.ReLU.
:param device: specify the device when the network actually runs. Default
to "cpu".
:param bool softmax: whether to apply a softmax layer over the last layer's
output.
:param bool concat: whether the input shape is concatenated by state_shape
and action_shape. If it is True, ``action_shape`` is not the output
shape, but affects the input shape only.
:param int num_atoms: number of atoms used to expand to the net of
distributional RL. Default to 1 (no expansion).
:param dueling_param: whether to use dueling network to calculate Q
values (for Dueling DQN). If you want to use the dueling option, pass
a tuple of two dicts (first for Q and second for V) with the keyword
arguments described in
:class:`~tianshou.utils.net.common.MLP`. Default to None.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.MLP` for a more
detailed explanation of the usage of activation, norm_layer, etc.
You can also refer to :class:`~tianshou.utils.net.continuous.Actor`,
:class:`~tianshou.utils.net.continuous.Critic`, etc., to see how it is
suggested to be used.
"""
def __init__(
self,
state_shape: Union[int, Sequence[int]],
action_shape: Optional[Union[int, Sequence[int]]] = 0,
hidden_sizes: Sequence[int] = (),
norm_layer: Optional[ModuleType] = None,
activation: Optional[ModuleType] = nn.ReLU,
device: Union[str, int, torch.device] = "cpu",
softmax: bool = False,
concat: bool = False,
num_atoms: int = 1,
dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
) -> None:
super().__init__()
self.device = device
self.dueling = dueling
self.softmax = softmax
self.num_atoms = num_atoms
self.action_num = np.prod(action_shape)
input_size = np.prod(state_shape)
input_dim = np.prod(state_shape)
action_dim = np.prod(action_shape) * num_atoms
if concat:
input_size += np.prod(action_shape)
model = miniblock(input_size, hidden_layer_size, norm_layer)
for i in range(layer_num):
model += miniblock(
hidden_layer_size, hidden_layer_size, norm_layer)
if dueling is None:
if action_shape and not concat:
model += [nn.Linear(
hidden_layer_size, num_atoms * self.action_num)]
else: # dueling DQN
q_layer_num, v_layer_num = dueling
Q, V = [], []
for i in range(q_layer_num):
Q += miniblock(
hidden_layer_size, hidden_layer_size, norm_layer)
for i in range(v_layer_num):
V += miniblock(
hidden_layer_size, hidden_layer_size, norm_layer)
if action_shape and not concat:
Q += [nn.Linear(
hidden_layer_size, num_atoms * self.action_num)]
V += [nn.Linear(hidden_layer_size, num_atoms)]
self.Q = nn.Sequential(*Q)
self.V = nn.Sequential(*V)
self.model = nn.Sequential(*model)
input_dim += action_dim
self.use_dueling = dueling_param is not None
output_dim = action_dim if not self.use_dueling and not concat else 0
self.model = MLP(input_dim, output_dim, hidden_sizes,
norm_layer, activation, device)
self.output_dim = self.model.output_dim
if self.use_dueling: # dueling DQN
q_kwargs, v_kwargs = dueling_param # type: ignore
q_output_dim, v_output_dim = 0, 0
if not concat:
q_output_dim, v_output_dim = action_dim, num_atoms
q_kwargs: Dict[str, Any] = {
**q_kwargs, "input_dim": self.output_dim,
"output_dim": q_output_dim}
v_kwargs: Dict[str, Any] = {
**v_kwargs, "input_dim": self.output_dim,
"output_dim": v_output_dim}
self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
self.output_dim = self.Q.output_dim
def forward(
self,
@ -95,18 +182,17 @@ class Net(nn.Module):
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
"""Mapping: s -> flatten -> logits."""
s = to_torch(s, device=self.device, dtype=torch.float32)
s = s.reshape(s.size(0), -1)
"""Mapping: s -> flatten (inside MLP)-> logits."""
logits = self.model(s)
if self.dueling is not None: # Dueling DQN
bsz = logits.shape[0]
if self.use_dueling: # Dueling DQN
q, v = self.Q(logits), self.V(logits)
if self.num_atoms > 1:
v = v.view(-1, 1, self.num_atoms)
q = q.view(-1, self.action_num, self.num_atoms)
q = q.view(bsz, -1, self.num_atoms)
v = v.view(bsz, -1, self.num_atoms)
logits = q - q.mean(dim=1, keepdim=True) + v
elif self.num_atoms > 1:
logits = logits.view(-1, self.action_num, self.num_atoms)
logits = logits.view(bsz, -1, self.num_atoms)
if self.softmax:
logits = torch.softmax(logits, dim=-1)
return logits, state
@ -122,14 +208,12 @@ class Recurrent(nn.Module):
def __init__(
self,
layer_num: int,
state_shape: Sequence[int],
action_shape: Sequence[int],
state_shape: Union[int, Sequence[int]],
action_shape: Union[int, Sequence[int]],
device: Union[str, int, torch.device] = "cpu",
hidden_layer_size: int = 128,
) -> None:
super().__init__()
self.state_shape = state_shape
self.action_shape = action_shape
self.device = device
self.nn = nn.LSTM(
input_size=hidden_layer_size,
@ -152,7 +236,8 @@ class Recurrent(nn.Module):
training mode, s should be with shape ``[bsz, len, dim]``. See the code
and comment for more detail.
"""
s = to_torch(s, device=self.device, dtype=torch.float32)
s = torch.as_tensor(
s, device=self.device, dtype=torch.float32) # type: ignore
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
# In short, the tensor's shape in the training phase is longer than that
# in the evaluation phase.
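
Putting the refactored pieces of this file together, a short usage sketch mirroring the new `test_net` above (all sizes are illustrative):

```python
import torch

from tianshou.utils.net.common import MLP, Net

obs = torch.rand(64, 10, 2)  # Net flattens non-batch dims internally

# plain MLP: 3 -> 128 -> 6
mlp = MLP(3, 6, hidden_sizes=[128])
assert mlp(torch.rand(64, 3)).shape == (64, 6)

# Net wraps MLP for DRL usage: state_shape/action_shape instead of raw dims
net = Net((10, 2), (5,), hidden_sizes=[128, 128])
logits, state = net(obs)
assert logits.shape == (64, 5)

# dueling DQN: per-stream MLP kwargs replace the old dueling=(2, 2) tuple
q_kwargs = v_kwargs = {"hidden_sizes": [128, 128]}
dueling_net = Net((10, 2), (5,), hidden_sizes=[128, 128],
                  dueling_param=(q_kwargs, v_kwargs))
logits, state = dueling_net(obs)
assert logits.shape == (64, 5)
```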

View File

@ -3,7 +3,7 @@ import numpy as np
from torch import nn
from typing import Any, Dict, Tuple, Union, Optional, Sequence
from tianshou.data import to_torch, to_torch_as
from tianshou.utils.net.common import MLP
SIGMA_MIN = -20
@ -11,23 +11,45 @@ SIGMA_MAX = 2
class Actor(nn.Module):
"""Simple actor network with MLP.
"""Simple actor network. Will create an actor operated in continuous \
action space with structure of preprocess_net ---> action_shape.
    :param preprocess_net: a self-defined preprocess_net which outputs a
flattened hidden state.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param float max_action: the scale for the final action logits. Default to
1.
:param int preprocess_net_output_dim: the output dimension of
preprocess_net.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
    Please refer to :class:`~tianshou.utils.net.common.Net` as an example
of how preprocess_net is suggested to be defined.
"""
def __init__(
self,
preprocess_net: nn.Module,
action_shape: Sequence[int],
hidden_sizes: Sequence[int] = (),
max_action: float = 1.0,
device: Union[str, int, torch.device] = "cpu",
hidden_layer_size: int = 128,
preprocess_net_output_dim: Optional[int] = None,
) -> None:
super().__init__()
self.device = device
self.preprocess = preprocess_net
self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
self.output_dim = np.prod(action_shape)
input_dim = getattr(preprocess_net, "output_dim",
preprocess_net_output_dim)
self.last = MLP(input_dim, self.output_dim, hidden_sizes)
self._max = max_action
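
A sketch of the intended composition (placeholder sizes; this assumes Net's action_shape can be left at its default so that it acts as a pure feature extractor whose output_dim is the last hidden size):

import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor

# hypothetical Pendulum-like task: 3-dim observation, 1-dim action in [-2, 2]
preprocess = Net((3,), hidden_sizes=[128, 128])
actor = Actor(preprocess, action_shape=(1,), max_action=2.0)
obs = torch.zeros(5, 3)
act, _ = actor(obs)
print(act.shape)        # expected: torch.Size([5, 1]), values within [-2, 2]
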
def forward(
@ -43,22 +65,40 @@ class Actor(nn.Module):
class Critic(nn.Module):
"""Simple critic network with MLP.
"""Simple critic network. Will create an actor operated in continuous \
action space with structure of preprocess_net ---> 1(q value).
    :param preprocess_net: a self-defined preprocess_net which outputs a
flattened hidden state.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param int preprocess_net_output_dim: the output dimension of
preprocess_net.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
    Please refer to :class:`~tianshou.utils.net.common.Net` as an example
of how preprocess_net is suggested to be defined.
"""
def __init__(
self,
preprocess_net: nn.Module,
hidden_sizes: Sequence[int] = (),
device: Union[str, int, torch.device] = "cpu",
hidden_layer_size: int = 128,
preprocess_net_output_dim: Optional[int] = None,
) -> None:
super().__init__()
self.device = device
self.preprocess = preprocess_net
self.last = nn.Linear(hidden_layer_size, 1)
self.output_dim = 1
input_dim = getattr(preprocess_net, "output_dim",
preprocess_net_output_dim)
self.last = MLP(input_dim, 1, hidden_sizes)
def forward(
self,
@ -67,11 +107,13 @@ class Critic(nn.Module):
info: Dict[str, Any] = {},
) -> torch.Tensor:
"""Mapping: (s, a) -> logits -> Q(s, a)."""
s = to_torch(s, device=self.device, dtype=torch.float32)
s = s.flatten(1)
s = torch.as_tensor(
s, device=self.device, dtype=torch.float32 # type: ignore
).flatten(1)
if a is not None:
a = to_torch_as(a, s)
a = a.flatten(1)
a = torch.as_tensor(
a, device=self.device, dtype=torch.float32 # type: ignore
).flatten(1)
s = torch.cat([s, a], dim=1)
logits, h = self.preprocess(s)
logits = self.last(logits)
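
The critic consumes the state-action pair, so its preprocess Net is the concat variant; a sketch with placeholder sizes:

import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Critic

# concat=True: the preprocess Net takes the flattened [s, a] pair as input
net_c = Net((3,), (1,), hidden_sizes=[128, 128], concat=True)
critic = Critic(net_c)
obs, act = torch.zeros(5, 3), torch.zeros(5, 1)
q = critic(obs, act)
print(q.shape)          # expected: torch.Size([5, 1])
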
@ -79,31 +121,55 @@ class Critic(nn.Module):
class ActorProb(nn.Module):
"""Simple actor network (output with a Gauss distribution) with MLP.
"""Simple actor network (output with a Gauss distribution).
    :param preprocess_net: a self-defined preprocess_net which outputs a
flattened hidden state.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param float max_action: the scale for the final action logits. Default to
1.
    :param bool unbounded: whether to leave the final action logits unbounded,
        i.e. skip the tanh squashing and max_action scaling. Default to False.
:param bool conditioned_sigma: True when sigma is calculated from the
input, False when sigma is an independent parameter. Default to False.
:param int preprocess_net_output_dim: the output dimension of
preprocess_net.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
    Please refer to :class:`~tianshou.utils.net.common.Net` as an example
of how preprocess_net is suggested to be defined.
"""
def __init__(
self,
preprocess_net: nn.Module,
action_shape: Sequence[int],
hidden_sizes: Sequence[int] = (),
max_action: float = 1.0,
device: Union[str, int, torch.device] = "cpu",
unbounded: bool = False,
hidden_layer_size: int = 128,
conditioned_sigma: bool = False,
preprocess_net_output_dim: Optional[int] = None,
) -> None:
super().__init__()
self.preprocess = preprocess_net
self.device = device
self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
self.output_dim = np.prod(action_shape)
input_dim = getattr(preprocess_net, "output_dim",
preprocess_net_output_dim)
self.mu = MLP(input_dim, self.output_dim, hidden_sizes)
self._c_sigma = conditioned_sigma
if conditioned_sigma:
self.sigma = nn.Linear(hidden_layer_size, np.prod(action_shape))
self.sigma = MLP(input_dim, self.output_dim, hidden_sizes)
else:
self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
self._max = max_action
self._unbounded = unbounded
@ -125,7 +191,7 @@ class ActorProb(nn.Module):
else:
shape = [1] * len(mu.shape)
shape[1] = -1
sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
return (mu, sigma), state
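
A sketch of consuming the (mu, sigma) pair with torch.distributions (the distribution wiring is illustrative, not part of the diff; sizes are placeholders and Net is again used action-head-free as the preprocess):

import torch
from torch.distributions import Independent, Normal
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb

net = Net((3,), hidden_sizes=[64, 64])
actor = ActorProb(net, (1,), unbounded=True, conditioned_sigma=True)
obs = torch.zeros(4, 3)
(mu, sigma), _ = actor(obs)
dist = Independent(Normal(mu, sigma), 1)    # one event dim over the action vector
action = dist.sample()
print(mu.shape, sigma.shape, action.shape)  # all torch.Size([4, 1])
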
@ -141,10 +207,10 @@ class RecurrentActorProb(nn.Module):
layer_num: int,
state_shape: Sequence[int],
action_shape: Sequence[int],
hidden_layer_size: int = 128,
max_action: float = 1.0,
device: Union[str, int, torch.device] = "cpu",
unbounded: bool = False,
hidden_layer_size: int = 128,
conditioned_sigma: bool = False,
) -> None:
super().__init__()
@ -155,12 +221,13 @@ class RecurrentActorProb(nn.Module):
num_layers=layer_num,
batch_first=True,
)
self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
output_dim = np.prod(action_shape)
self.mu = nn.Linear(hidden_layer_size, output_dim)
self._c_sigma = conditioned_sigma
if conditioned_sigma:
self.sigma = nn.Linear(hidden_layer_size, np.prod(action_shape))
self.sigma = nn.Linear(hidden_layer_size, output_dim)
else:
self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1))
self._max = max_action
self._unbounded = unbounded
@ -171,7 +238,8 @@ class RecurrentActorProb(nn.Module):
info: Dict[str, Any] = {},
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
s = to_torch(s, device=self.device, dtype=torch.float32)
s = torch.as_tensor(
s, device=self.device, dtype=torch.float32) # type: ignore
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
        # In short, the tensor's shape in the training phase is longer than
        # that in the evaluation phase.
@ -196,7 +264,7 @@ class RecurrentActorProb(nn.Module):
else:
shape = [1] * len(mu.shape)
shape[1] = -1
sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
# please ensure the first dim is batch size: [bsz, len, ...]
return (mu, sigma), {"h": h.transpose(0, 1).detach(),
"c": c.transpose(0, 1).detach()}
@ -236,7 +304,8 @@ class RecurrentCritic(nn.Module):
info: Dict[str, Any] = {},
) -> torch.Tensor:
"""Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
s = to_torch(s, device=self.device, dtype=torch.float32)
s = torch.as_tensor(
s, device=self.device, dtype=torch.float32) # type: ignore
# s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
        # In short, the tensor's shape in the training phase is longer than
        # that in the evaluation phase.
@ -245,7 +314,8 @@ class RecurrentCritic(nn.Module):
s, (h, c) = self.nn(s)
s = s[:, -1]
if a is not None:
a = to_torch_as(a, s)
a = torch.as_tensor(
a, device=self.device, dtype=torch.float32) # type: ignore
s = torch.cat([s, a], dim=1)
s = self.fc2(s)
return s
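
For completeness, a sketch of the recurrent critic too; only its forward appears in this hunk, so the constructor arguments below (mirroring RecurrentActorProb's layer_num/state_shape/action_shape) are an assumption:

import torch
from tianshou.utils.net.continuous import RecurrentCritic

# assumed constructor: (layer_num, state_shape, action_shape, ...)
critic = RecurrentCritic(layer_num=1, state_shape=(3,), action_shape=(1,))
obs = torch.zeros(4, 1, 3)    # [bsz, len, dim], the training-time layout
act = torch.zeros(4, 1)
q = critic(obs, act)
print(q.shape)                # expected: torch.Size([4, 1])
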

View File

@ -4,26 +4,49 @@ from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, Tuple, Union, Optional, Sequence
from tianshou.data import to_torch
from tianshou.utils.net.common import MLP
class Actor(nn.Module):
"""Simple actor network with MLP.
"""Simple actor network.
    Will create an actor operating in a discrete action space with structure of
preprocess_net ---> action_shape.
    :param preprocess_net: a self-defined preprocess_net which outputs a
flattened hidden state.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param bool softmax_output: whether to apply a softmax layer over the last
layer's output.
:param int preprocess_net_output_dim: the output dimension of
preprocess_net.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
    Please refer to :class:`~tianshou.utils.net.common.Net` as an example
of how preprocess_net is suggested to be defined.
"""
def __init__(
self,
preprocess_net: nn.Module,
action_shape: Sequence[int],
hidden_layer_size: int = 128,
hidden_sizes: Sequence[int] = (),
softmax_output: bool = True,
preprocess_net_output_dim: Optional[int] = None,
) -> None:
super().__init__()
self.preprocess = preprocess_net
self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
self.output_dim = np.prod(action_shape)
input_dim = getattr(preprocess_net, "output_dim",
preprocess_net_output_dim)
self.last = MLP(input_dim, self.output_dim, hidden_sizes)
self.softmax_output = softmax_output
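
A sketch pairing a feature-extractor Net (action head omitted, as in the continuous sketches above) with the discrete Actor; sizes are placeholders for a CartPole-like task:

import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor

net = Net((4,), hidden_sizes=[64, 64])
actor = Actor(net, action_shape=2)   # softmax_output=True by default
obs = torch.zeros(8, 4)
probs, _ = actor(obs)
print(probs.shape)                   # expected: torch.Size([8, 2])
print(probs.sum(-1))                 # each row sums to 1
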
def forward(
@ -41,125 +64,44 @@ class Actor(nn.Module):
class Critic(nn.Module):
"""Simple critic network with MLP.
"""Simple critic network. Will create an actor operated in discrete \
action space with structure of preprocess_net ---> 1(q value).
    :param preprocess_net: a self-defined preprocess_net which outputs a
flattened hidden state.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
    :param int last_size: the output dimension of the Critic network. Default to 1.
:param int preprocess_net_output_dim: the output dimension of
preprocess_net.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
    Please refer to :class:`~tianshou.utils.net.common.Net` as an example
of how preprocess_net is suggested to be defined.
"""
def __init__(
self,
preprocess_net: nn.Module,
hidden_layer_size: int = 128,
last_size: int = 1
hidden_sizes: Sequence[int] = (),
last_size: int = 1,
preprocess_net_output_dim: Optional[int] = None,
) -> None:
super().__init__()
self.preprocess = preprocess_net
self.last = nn.Linear(hidden_layer_size, last_size)
self.output_dim = last_size
input_dim = getattr(preprocess_net, "output_dim",
preprocess_net_output_dim)
self.last = MLP(input_dim, last_size, hidden_sizes)
def forward(
self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
) -> torch.Tensor:
"""Mapping: s -> V(s)."""
logits, h = self.preprocess(s, state=kwargs.get("state", None))
logits = self.last(logits)
return logits
class DQN(nn.Module):
"""Reference: Human-level control through deep reinforcement learning.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
def __init__(
self,
c: int,
h: int,
w: int,
action_shape: Sequence[int],
device: Union[str, int, torch.device] = "cpu",
) -> None:
super().__init__()
self.device = device
def conv2d_size_out(
size: int, kernel_size: int = 5, stride: int = 2
) -> int:
return (size - (kernel_size - 1) - 1) // stride + 1
def conv2d_layers_size_out(
size: int,
kernel_size_1: int = 8,
stride_1: int = 4,
kernel_size_2: int = 4,
stride_2: int = 2,
kernel_size_3: int = 3,
stride_3: int = 1,
) -> int:
size = conv2d_size_out(size, kernel_size_1, stride_1)
size = conv2d_size_out(size, kernel_size_2, stride_2)
size = conv2d_size_out(size, kernel_size_3, stride_3)
return size
convw = conv2d_layers_size_out(w)
convh = conv2d_layers_size_out(h)
linear_input_size = convw * convh * 64
self.net = nn.Sequential(
nn.Conv2d(c, 32, kernel_size=8, stride=4),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU(inplace=True),
nn.Flatten(),
nn.Linear(linear_input_size, 512),
nn.ReLU(inplace=True),
nn.Linear(512, np.prod(action_shape)),
)
def forward(
self,
x: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Q(x, \*)."""
if not isinstance(x, torch.Tensor):
x = to_torch(x, device=self.device, dtype=torch.float32)
return self.net(x), state
class C51(DQN):
"""Reference: A distributional perspective on reinforcement learning.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
"""
def __init__(
self,
c: int,
h: int,
w: int,
action_shape: Sequence[int],
num_atoms: int = 51,
device: Union[str, int, torch.device] = "cpu",
) -> None:
super().__init__(c, h, w, [np.prod(action_shape) * num_atoms], device)
self.action_shape = action_shape
self.num_atoms = num_atoms
def forward(
self,
x: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x)
x = x.view(-1, self.num_atoms).softmax(dim=-1)
x = x.view(-1, np.prod(self.action_shape), self.num_atoms)
return x, state
logits, _ = self.preprocess(s, state=kwargs.get("state", None))
return self.last(logits)
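
And the matching value head; a sketch that shares one preprocess Net between actor and critic, a common (but optional) wiring for on-policy algorithms:

import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor, Critic

net = Net((4,), hidden_sizes=[64, 64])
actor = Actor(net, action_shape=2)
critic = Critic(net)                 # last_size=1 -> scalar V(s)
obs = torch.zeros(8, 4)
value = critic(obs)
print(value.shape)                   # expected: torch.Size([8, 1])
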