update utils.network (#275)

This is the first of six commits mentioned in #274. It features:

1. a refactor of `Net` to support any form of MLP;
2. type checking enabled in `utils.network`;
3. the corresponding changes in docs/tests/examples;
4. the Atari-related networks moved to `examples/atari/atari_network.py`.

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>

parent 866e35d550
commit a633a6a028
@@ -213,7 +213,7 @@ from tianshou.utils.net.common import Net
 env = gym.make(task)
 state_shape = env.observation_space.shape or env.observation_space.n
 action_shape = env.action_space.shape or env.action_space.n
-net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape)
+net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
 optim = torch.optim.Adam(net.parameters(), lr=lr)
 ```
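For scripts migrating to the new signature, the hunk above is the whole story: the positional `layer_num` is gone and depth now comes from an explicit `hidden_sizes` list. A minimal runnable sketch (the CartPole task and learning rate are illustrative, not part of the diff):

```python
import gym
import torch
from tianshou.utils.net.common import Net

env = gym.make('CartPole-v0')
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n

# old API: Net(layer_num=2, state_shape=state_shape, action_shape=action_shape)
# new API: one Linear block per entry of hidden_sizes
net = Net(state_shape=state_shape, action_shape=action_shape,
          hidden_sizes=[128, 128, 128])
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
```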
@@ -203,7 +203,8 @@ The explanation of each Tianshou class/function will be deferred to their first
 parser.add_argument('--step-per-epoch', type=int, default=1000)
 parser.add_argument('--collect-per-step', type=int, default=10)
 parser.add_argument('--batch-size', type=int, default=64)
-parser.add_argument('--layer-num', type=int, default=3)
+parser.add_argument('--hidden-sizes', type=int,
+                    nargs='*', default=[128, 128, 128, 128])
 parser.add_argument('--training-num', type=int, default=8)
 parser.add_argument('--test-num', type=int, default=100)
 parser.add_argument('--logdir', type=str, default='log')
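With `nargs='*'`, a single flag now controls both the width and the depth of the network; a standalone check of the parsing behavior:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--hidden-sizes', type=int,
                    nargs='*', default=[128, 128, 128, 128])

args = parser.parse_args(['--hidden-sizes', '64', '64'])
print(args.hidden_sizes)  # [64, 64] -- two hidden layers of width 64
```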
@@ -249,7 +250,8 @@ Here it is:
 args.action_shape = env.action_space.shape or env.action_space.n

 if agent_learn is None:
-    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device).to(args.device)
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
     if optim is None:
         optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     agent_learn = DQNPolicy(
@@ -7,10 +7,10 @@ from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import C51Policy
 from tianshou.env import SubprocVectorEnv
-from tianshou.utils.net.discrete import C51
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer

+from atari_network import C51
 from atari_wrapper import wrap_deepmind
@@ -40,8 +40,8 @@ def get_args():
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
-    parser.add_argument('--frames_stack', type=int, default=4)
-    parser.add_argument('--resume_path', type=str, default=None)
+    parser.add_argument('--frames-stack', type=int, default=4)
+    parser.add_argument('--resume-path', type=str, default=None)
     parser.add_argument('--watch', default=False, action='store_true',
                         help='watch the play of pre-trained policy only')
     return parser.parse_args()
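The rename only touches the command line: argparse converts hyphens to underscores when it builds the namespace, so the scripts still read `args.frames_stack` and `args.resume_path`. A quick check:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)

args = parser.parse_args(['--frames-stack', '8'])
print(args.frames_stack)  # 8 -- hyphens become underscores in attribute names
```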
@@ -7,10 +7,10 @@ from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DQNPolicy
 from tianshou.env import SubprocVectorEnv
-from tianshou.utils.net.discrete import DQN
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer

+from atari_network import DQN
 from atari_wrapper import wrap_deepmind
@@ -37,8 +37,8 @@ def get_args():
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
-    parser.add_argument('--frames_stack', type=int, default=4)
-    parser.add_argument('--resume_path', type=str, default=None)
+    parser.add_argument('--frames-stack', type=int, default=4)
+    parser.add_argument('--resume-path', type=str, default=None)
     parser.add_argument('--watch', default=False, action='store_true',
                         help='watch the play of pre-trained policy only')
     return parser.parse_args()
examples/atari/atari_network.py (new file, 82 lines)
@@ -0,0 +1,82 @@
+import torch
+import numpy as np
+from torch import nn
+from typing import Any, Dict, Tuple, Union, Optional, Sequence
+
+
+class DQN(nn.Module):
+    """Reference: Human-level control through deep reinforcement learning.
+
+    For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
+    def __init__(
+        self,
+        c: int,
+        h: int,
+        w: int,
+        action_shape: Sequence[int],
+        device: Union[str, int, torch.device] = "cpu",
+        features_only: bool = False,
+    ) -> None:
+        super().__init__()
+        self.device = device
+        self.net = nn.Sequential(
+            nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(inplace=True),
+            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(inplace=True),
+            nn.Flatten())
+        with torch.no_grad():
+            self.output_dim = np.prod(
+                self.net(torch.zeros(1, c, h, w)).shape[1:])
+        if not features_only:
+            self.net = nn.Sequential(
+                self.net,
+                nn.Linear(self.output_dim, 512), nn.ReLU(inplace=True),
+                nn.Linear(512, np.prod(action_shape)))
+            self.output_dim = np.prod(action_shape)
+
+    def forward(
+        self,
+        x: Union[np.ndarray, torch.Tensor],
+        state: Optional[Any] = None,
+        info: Dict[str, Any] = {},
+    ) -> Tuple[torch.Tensor, Any]:
+        r"""Mapping: x -> Q(x, \*)."""
+        x = torch.as_tensor(
+            x, device=self.device, dtype=torch.float32)  # type: ignore
+        return self.net(x), state
+
+
+class C51(DQN):
+    """Reference: A distributional perspective on reinforcement learning.
+
+    For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
+    def __init__(
+        self,
+        c: int,
+        h: int,
+        w: int,
+        action_shape: Sequence[int],
+        num_atoms: int = 51,
+        device: Union[str, int, torch.device] = "cpu",
+    ) -> None:
+        super().__init__(c, h, w, [np.prod(action_shape) * num_atoms], device)
+        self.action_shape = action_shape
+        self.num_atoms = num_atoms
+
+    def forward(
+        self,
+        x: Union[np.ndarray, torch.Tensor],
+        state: Optional[Any] = None,
+        info: Dict[str, Any] = {},
+    ) -> Tuple[torch.Tensor, Any]:
+        r"""Mapping: x -> Z(x, \*)."""
+        x, state = super().forward(x)
+        x = x.view(-1, self.num_atoms).softmax(dim=-1)
+        x = x.view(-1, np.prod(self.action_shape), self.num_atoms)
+        return x, state
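A short sketch of how these two classes are consumed; the shapes mirror the assertions this PR removes from test_utils.py (4 stacked 84x84 frames, 6 discrete actions):

```python
import numpy as np
import torch
from atari_network import C51, DQN

c, h, w, n_act = 4, 84, 84, 6
obs = np.random.rand(32, c, h, w)       # a batch of stacked frames

q_net = DQN(c, h, w, action_shape=(n_act,))
q, _ = q_net(obs)                       # Q-values: torch.Size([32, 6])

z_net = C51(c, h, w, action_shape=(n_act,), num_atoms=51)
z, _ = z_net(obs)                       # atom probabilities: torch.Size([32, 6, 51])
assert torch.allclose(z.sum(-1), torch.ones(32, n_act))
```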
@@ -7,10 +7,10 @@ from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import A2CPolicy
 from tianshou.env import SubprocVectorEnv
-from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.utils.net.common import Net

 from atari import create_atari_environment, preprocess_fn
@@ -27,7 +27,8 @@ def get_args():
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=2)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=8)
     parser.add_argument('--logdir', type=str, default='log')
@@ -40,7 +41,7 @@ def get_args():
     parser.add_argument('--vf-coef', type=float, default=0.5)
     parser.add_argument('--ent-coef', type=float, default=0.001)
     parser.add_argument('--max-grad-norm', type=float, default=None)
-    parser.add_argument('--max_episode_steps', type=int, default=2000)
+    parser.add_argument('--max-episode-steps', type=int, default=2000)
     return parser.parse_args()
@@ -62,11 +63,12 @@ def test_a2c(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
     actor = Actor(net, args.action_shape).to(args.device)
     critic = Critic(net).to(args.device)
-    optim = torch.optim.Adam(list(
-        actor.parameters()) + list(critic.parameters()), lr=args.lr)
+    optim = torch.optim.Adam(set(
+        actor.parameters()).union(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = A2CPolicy(
         actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef,
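The optimizer change above is not cosmetic: `actor` and `critic` share the same `net` backbone, so concatenating their parameter lists hands each shared tensor to Adam twice, and duplicated entries get visited twice per `step()`. A standalone sketch of the difference (the tiny backbone/head sizes are illustrative):

```python
import torch
from torch import nn

backbone = nn.Linear(4, 16)                         # shared feature extractor
actor = nn.Sequential(backbone, nn.Linear(16, 2))   # policy head
critic = nn.Sequential(backbone, nn.Linear(16, 1))  # value head

dup = list(actor.parameters()) + list(critic.parameters())
dedup = set(actor.parameters()).union(critic.parameters())
print(len(dup), len(dedup))  # 8 6 -- backbone weight/bias counted twice in dup

# Adam over the deduplicated set touches each shared tensor exactly once per step.
optim = torch.optim.Adam(dedup, lr=1e-3)
```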
@@ -7,10 +7,10 @@ from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import PPOPolicy
 from tianshou.env import SubprocVectorEnv
-from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.utils.net.common import Net

 from atari import create_atari_environment, preprocess_fn
@@ -27,7 +27,8 @@ def get_args():
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=8)
     parser.add_argument('--logdir', type=str, default='log')
@@ -40,7 +41,7 @@ def get_args():
     parser.add_argument('--ent-coef', type=float, default=0.0)
     parser.add_argument('--eps-clip', type=float, default=0.2)
     parser.add_argument('--max-grad-norm', type=float, default=0.5)
-    parser.add_argument('--max_episode_steps', type=int, default=2000)
+    parser.add_argument('--max-episode-steps', type=int, default=2000)
     return parser.parse_args()
@@ -62,11 +63,12 @@ def test_ppo(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
     actor = Actor(net, args.action_shape).to(args.device)
     critic = Critic(net).to(args.device)
-    optim = torch.optim.Adam(list(
-        actor.parameters()) + list(critic.parameters()), lr=args.lr)
+    optim = torch.optim.Adam(set(
+        actor.parameters()).union(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PPOPolicy(
         actor, critic, optim, dist, args.gamma,
@@ -28,7 +28,12 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=100)
     parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=0)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128])
+    parser.add_argument('--dueling-q-hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
+    parser.add_argument('--dueling-v-hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -56,8 +61,11 @@ def test_dqn(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape,
-              args.action_shape, args.device, dueling=(2, 2)).to(args.device)
+    Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes}
+    V_param = {"hidden_sizes": args.dueling_v_hidden_sizes}
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes, device=args.device,
+              dueling_param=(Q_param, V_param)).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(
         net, optim, args.gamma, args.n_step,
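The dueling head is now configured with per-head keyword dicts (mirroring the new `MLP` signature) instead of the old `(q_layer_num, v_layer_num)` integer pair. A standalone sketch with an illustrative 4-dim state and 2 actions:

```python
import torch
from tianshou.utils.net.common import Net

Q_param = {"hidden_sizes": [128, 128]}  # advantage head
V_param = {"hidden_sizes": [128, 128]}  # state-value head
net = Net(state_shape=4, action_shape=2, hidden_sizes=[128],
          dueling_param=(Q_param, V_param))

obs = torch.rand(8, 4)
q, _ = net(obs)        # Q(s, a) combined from the two heads
print(q.shape)         # torch.Size([8, 2])
```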
@@ -30,7 +30,8 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=10000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -87,20 +88,23 @@ def test_sac_bipedal(args=get_args()):
     test_envs.seed(args.seed)

     # model
-    net_a = Net(args.layer_num, args.state_shape, device=args.device)
+    net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+                device=args.device)
     actor = ActorProb(
-        net_a, args.action_shape, args.max_action, args.device, unbounded=True
-    ).to(args.device)
+        net_a, args.action_shape, max_action=args.max_action,
+        device=args.device, unbounded=True).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

-    net_c1 = Net(args.layer_num, args.state_shape,
-                 args.action_shape, concat=True, device=args.device)
-    critic1 = Critic(net_c1, args.device).to(args.device)
+    net_c1 = Net(args.state_shape, args.action_shape,
+                 hidden_sizes=args.hidden_sizes,
+                 concat=True, device=args.device)
+    critic1 = Critic(net_c1, device=args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)

-    net_c2 = Net(args.layer_num, args.state_shape,
-                 args.action_shape, concat=True, device=args.device)
-    critic2 = Critic(net_c2, args.device).to(args.device)
+    net_c2 = Net(args.state_shape, args.action_shape,
+                 hidden_sizes=args.hidden_sizes,
+                 concat=True, device=args.device)
+    critic2 = Critic(net_c2, device=args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

     if args.auto_alpha:
@@ -29,7 +29,12 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=5000)
     parser.add_argument('--collect-per-step', type=int, default=16)
     parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
+    parser.add_argument('--dueling-q-hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
+    parser.add_argument('--dueling-v-hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=10)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -57,9 +62,11 @@ def test_dqn(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape,
-              args.action_shape, args.device,
-              dueling=(2, 2)).to(args.device)
+    Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes}
+    V_param = {"hidden_sizes": args.dueling_v_hidden_sizes}
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes, device=args.device,
+              dueling_param=(Q_param, V_param)).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(
         net, optim, args.gamma, args.n_step,
@@ -32,7 +32,8 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=5)
     parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=16)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -61,19 +62,22 @@ def test_sac(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
     actor = ActorProb(
         net, args.action_shape,
-        args.max_action, args.device, unbounded=True
+        max_action=args.max_action, device=args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    net_c1 = Net(args.layer_num, args.state_shape,
-                 args.action_shape, concat=True, device=args.device)
-    critic1 = Critic(net_c1, args.device).to(args.device)
+    net_c1 = Net(args.state_shape, args.action_shape,
+                 hidden_sizes=args.hidden_sizes, concat=True,
+                 device=args.device)
+    critic1 = Critic(net_c1, device=args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    net_c2 = Net(args.layer_num, args.state_shape,
-                 args.action_shape, concat=True, device=args.device)
-    critic2 = Critic(net_c2, args.device).to(args.device)
+    net_c2 = Net(args.state_shape, args.action_shape,
+                 hidden_sizes=args.hidden_sizes, concat=True,
+                 device=args.device)
+    critic2 = Critic(net_c2, device=args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

     if args.auto_alpha:
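Throughout these continuous-control scripts, `concat=True` makes the backbone expect the flattened state concatenated with the action, which is what `Critic` feeds it. A standalone sketch (the 4-dim state and 2-dim action are illustrative):

```python
import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Critic

net_c = Net(state_shape=4, action_shape=2,
            hidden_sizes=[128, 128], concat=True)
critic = Critic(net_c)

obs = torch.rand(8, 4)
act = torch.rand(8, 2)
q = critic(obs, act)   # Critic concatenates obs and act internally
print(q.shape)         # torch.Size([8, 1])
```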
@@ -33,8 +33,8 @@ def get_args():
     parser.add_argument('--update-per-step', type=int, default=1)
     parser.add_argument('--pre-collect-step', type=int, default=10000)
     parser.add_argument('--batch-size', type=int, default=256)
-    parser.add_argument('--hidden-layer-size', type=int, default=256)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=16)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -70,26 +70,22 @@ def test_sac(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device,
-              hidden_layer_size=args.hidden_layer_size)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
     actor = ActorProb(
-        net, args.action_shape, args.max_action, args.device, unbounded=True,
-        hidden_layer_size=args.hidden_layer_size, conditioned_sigma=True,
+        net, args.action_shape, max_action=args.max_action,
+        device=args.device, unbounded=True, conditioned_sigma=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    net_c1 = Net(args.layer_num, args.state_shape, args.action_shape,
-                 concat=True, device=args.device,
-                 hidden_layer_size=args.hidden_layer_size)
-    critic1 = Critic(
-        net_c1, args.device, hidden_layer_size=args.hidden_layer_size
-    ).to(args.device)
+    net_c1 = Net(args.state_shape, args.action_shape,
+                 hidden_sizes=args.hidden_sizes,
+                 concat=True, device=args.device)
+    critic1 = Critic(net_c1, device=args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    net_c2 = Net(args.layer_num, args.state_shape, args.action_shape,
-                 concat=True, device=args.device,
-                 hidden_layer_size=args.hidden_layer_size)
-    critic2 = Critic(
-        net_c2, args.device, hidden_layer_size=args.hidden_layer_size
-    ).to(args.device)
+    net_c2 = Net(args.state_shape, args.action_shape,
+                 hidden_sizes=args.hidden_sizes,
+                 concat=True, device=args.device)
+    critic2 = Critic(net_c2, device=args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

     if args.auto_alpha:
@@ -28,7 +28,8 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=4)
     parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -56,13 +57,14 @@ def test_ddpg(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
-    actor = Actor(net, args.action_shape, args.max_action,
-                  args.device).to(args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
+    actor = Actor(net, args.action_shape, max_action=args.max_action,
+                  device=args.device).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    net = Net(args.layer_num, args.state_shape,
-              args.action_shape, concat=True, device=args.device)
-    critic = Critic(net, args.device).to(args.device)
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
+    critic = Critic(net, device=args.device).to(args.device)
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
@@ -31,7 +31,8 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -59,17 +60,16 @@ def test_td3(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
-    actor = Actor(
-        net, args.action_shape,
-        args.max_action, args.device
-    ).to(args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
+    actor = Actor(net, args.action_shape, max_action=args.max_action,
+                  device=args.device).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    net = Net(args.layer_num, args.state_shape,
-              args.action_shape, concat=True, device=args.device)
-    critic1 = Critic(net, args.device).to(args.device)
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
+    critic1 = Critic(net, device=args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(net, args.device).to(args.device)
+    critic2 = Critic(net, device=args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
@@ -30,7 +30,8 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=4)
     parser.add_argument('--logdir', type=str, default='log')
@@ -61,19 +62,18 @@ def test_sac(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
-    actor = ActorProb(
-        net, args.action_shape,
-        args.max_action, args.device, unbounded=True
-    ).to(args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
+    actor = ActorProb(net, args.action_shape, max_action=args.max_action,
+                      device=args.device, unbounded=True).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    net = Net(args.layer_num, args.state_shape,
-              args.action_shape, concat=True, device=args.device)
-    critic1 = Critic(net, args.device).to(args.device)
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
+    critic1 = Critic(net, device=args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    net = Net(args.layer_num, args.state_shape,
-              args.action_shape, concat=True, device=args.device)
-    critic2 = Critic(net, args.device).to(args.device)
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
+    critic2 = Critic(net, device=args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
@@ -33,7 +33,8 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -62,19 +63,18 @@ def test_td3(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
-    actor = Actor(
-        net, args.action_shape,
-        args.max_action, args.device
-    ).to(args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
+    actor = Actor(net, args.action_shape, max_action=args.max_action,
+                  device=args.device).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    net = Net(args.layer_num, args.state_shape,
-              args.action_shape, concat=True, device=args.device)
-    critic1 = Critic(net, args.device).to(args.device)
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
+    critic1 = Critic(net, device=args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    net = Net(args.layer_num, args.state_shape,
-              args.action_shape, concat=True, device=args.device)
-    critic2 = Critic(net, args.device).to(args.device)
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
+    critic2 = Critic(net, device=args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
@@ -18,9 +18,6 @@ warn_unreachable = True
 warn_unused_configs = True
 warn_unused_ignores = True

-[mypy-tianshou.utils.net.*]
-ignore_errors = True
-
 [pydocstyle]
 ignore = D100,D102,D104,D105,D107,D203,D213,D401,D402
@@ -3,8 +3,7 @@ import numpy as np

 from tianshou.utils import MovAvg
 from tianshou.utils import SummaryWriter
-from tianshou.utils.net.common import Net
-from tianshou.utils.net.discrete import DQN, C51
+from tianshou.utils.net.common import MLP, Net
 from tianshou.exploration import GaussianNoise, OUNoise
 from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
@@ -35,17 +34,39 @@ def test_moving_average():

 def test_net():
     # here test the networks that do not appear in the other scripts
     bsz = 64
+    # MLP
+    data = torch.rand([bsz, 3])
+    mlp = MLP(3, 6, hidden_sizes=[128])
+    assert list(mlp(data).shape) == [bsz, 6]
+    # output == 0 and len(hidden_sizes) == 0 means identity model
+    mlp = MLP(6, 0)
+    assert data.shape == mlp(data).shape
+    # common net
     state_shape = (10, 2)
     action_shape = (5, )
     data = torch.rand([bsz, *state_shape])
     expect_output_shape = [bsz, *action_shape]
-    net = Net(3, state_shape, action_shape, norm_layer=torch.nn.LayerNorm)
+    net = Net(state_shape, action_shape, hidden_sizes=[128, 128],
+              norm_layer=torch.nn.LayerNorm, activation=None)
     assert list(net(data)[0].shape) == expect_output_shape
-    net = Net(3, state_shape, action_shape, dueling=(2, 2))
+    assert str(net).count("LayerNorm") == 2
+    assert str(net).count("ReLU") == 0
+    Q_param = V_param = {"hidden_sizes": [128, 128]}
+    net = Net(state_shape, action_shape, hidden_sizes=[128, 128],
+              dueling_param=(Q_param, V_param))
     assert list(net(data)[0].shape) == expect_output_shape
+    # concat
+    net = Net(state_shape, action_shape, hidden_sizes=[128],
+              concat=True)
+    data = torch.rand([bsz, np.prod(state_shape) + np.prod(action_shape)])
+    expect_output_shape = [bsz, 128]
+    assert list(net(data)[0].shape) == expect_output_shape
+    net = Net(state_shape, action_shape, hidden_sizes=[128],
+              concat=True, dueling_param=(Q_param, V_param))
+    assert list(net(data)[0].shape) == expect_output_shape
     # recurrent actor/critic
-    data = data.flatten(1)
+    data = torch.rand([bsz, *state_shape]).flatten(1)
+    expect_output_shape = [bsz, *action_shape]
     net = RecurrentActorProb(3, state_shape, action_shape)
     mu, sigma = net(data)[0]
     assert mu.shape == sigma.shape
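Two behaviors the new `MLP` tests above rely on: each entry of `hidden_sizes` contributes one Linear(+norm+activation) miniblock, and `output_dim=0` with no hidden sizes degenerates to an identity map whose `output_dim` attribute falls back to the last (here: input) dimension. A standalone sketch:

```python
import torch
from tianshou.utils.net.common import MLP

mlp = MLP(3, 6, hidden_sizes=[128])      # 3 -> 128 -> 6
x = torch.rand(64, 3)
assert mlp(x).shape == (64, 6)
assert mlp.output_dim == 6

identity = MLP(6, 0)                     # no hidden layers, no output layer
y = torch.rand(64, 6)
assert torch.equal(identity(y), y)       # pure identity mapping
assert identity.output_dim == 6          # falls back to the input dimension
```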
@@ -54,17 +75,6 @@ def test_net():
     data = torch.rand([bsz, 8, np.prod(state_shape)])
     act = torch.rand(expect_output_shape)
     assert list(net(data, act).shape) == [bsz, 1]
-    # DQN
-    state_shape = (4, 84, 84)
-    action_shape = (6, )
-    data = np.random.rand(bsz, *state_shape)
-    expect_output_shape = [bsz, *action_shape]
-    net = DQN(*state_shape, action_shape)
-    assert list(net(data)[0].shape) == expect_output_shape
-    num_atoms = 51
-    net = C51(*state_shape, action_shape, num_atoms)
-    expect_output_shape = [bsz, *action_shape, num_atoms]
-    assert list(net(data)[0].shape) == expect_output_shape


 def test_summary_writer():
@@ -30,7 +30,8 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=4)
     parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -66,15 +67,14 @@ def test_ddpg(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
-    actor = Actor(
-        net, args.action_shape,
-        args.max_action, args.device
-    ).to(args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
+    actor = Actor(net, args.action_shape, max_action=args.max_action,
+                  device=args.device).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    net = Net(args.layer_num, args.state_shape,
-              args.action_shape, concat=True, device=args.device)
-    critic = Critic(net, args.device).to(args.device)
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes, concat=True, device=args.device)
+    critic = Critic(net, device=args.device).to(args.device)
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
@@ -27,7 +27,8 @@ def get_args():
     parser.add_argument('--collect-per-step', type=int, default=1)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=16)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -69,21 +70,20 @@ def test_ppo(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
-    actor = ActorProb(
-        net, args.action_shape,
-        args.max_action, args.device
-    ).to(args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
+    actor = ActorProb(net, args.action_shape, max_action=args.max_action,
+                      device=args.device).to(args.device)
     critic = Critic(Net(
-        args.layer_num, args.state_shape, device=args.device
+        args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device
     ), device=args.device).to(args.device)
     # orthogonal initialization
     for m in list(actor.modules()) + list(critic.modules()):
         if isinstance(m, torch.nn.Linear):
             torch.nn.init.orthogonal_(m.weight)
             torch.nn.init.zeros_(m.bias)
-    optim = torch.optim.Adam(list(
-        actor.parameters()) + list(critic.parameters()), lr=args.lr)
+    optim = torch.optim.Adam(set(
+        actor.parameters()).union(critic.parameters()), lr=args.lr)

     # replace DiagGuassian with Independent(Normal) which is equivalent
     # pass *logits to be consistent with policy.forward
@@ -29,7 +29,10 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
+    parser.add_argument('--imitation-hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -65,18 +68,20 @@ def test_sac_with_il(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
-    actor = ActorProb(
-        net, args.action_shape, args.max_action, args.device, unbounded=True
-    ).to(args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
+    actor = ActorProb(net, args.action_shape, max_action=args.max_action,
+                      device=args.device, unbounded=True).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    net_c1 = Net(args.layer_num, args.state_shape,
-                 args.action_shape, concat=True, device=args.device)
-    critic1 = Critic(net_c1, args.device).to(args.device)
+    net_c1 = Net(args.state_shape, args.action_shape,
+                 hidden_sizes=args.hidden_sizes,
+                 concat=True, device=args.device)
+    critic1 = Critic(net_c1, device=args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    net_c2 = Net(args.layer_num, args.state_shape,
-                 args.action_shape, concat=True, device=args.device)
-    critic2 = Critic(net_c2, args.device).to(args.device)
+    net_c2 = Net(args.state_shape, args.action_shape,
+                 hidden_sizes=args.hidden_sizes,
+                 concat=True, device=args.device)
+    critic2 = Critic(net_c2, device=args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
@@ -120,8 +125,9 @@ def test_sac_with_il(args=get_args()):
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -300  # lower the goal
     net = Actor(
-        Net(1, args.state_shape, device=args.device),
-        args.action_shape, args.max_action, args.device
+        Net(args.state_shape, hidden_sizes=args.imitation_hidden_sizes,
+            device=args.device),
+        args.action_shape, max_action=args.max_action, device=args.device
     ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
     il_policy = ImitationPolicy(net, optim, mode='continuous')
@@ -32,7 +32,8 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=2400)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -68,19 +69,20 @@ def test_td3(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
-    actor = Actor(
-        net, args.action_shape,
-        args.max_action, args.device
-    ).to(args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
+    actor = Actor(net, args.action_shape, max_action=args.max_action,
+                  device=args.device).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    net_c1 = Net(args.layer_num, args.state_shape,
-                 args.action_shape, concat=True, device=args.device)
-    critic1 = Critic(net_c1, args.device).to(args.device)
+    net_c1 = Net(args.state_shape, args.action_shape,
+                 hidden_sizes=args.hidden_sizes,
+                 concat=True, device=args.device)
+    critic1 = Critic(net_c1, device=args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    net_c2 = Net(args.layer_num, args.state_shape,
-                 args.action_shape, concat=True, device=args.device)
-    critic2 = Critic(net_c2, args.device).to(args.device)
+    net_c2 = Net(args.state_shape, args.action_shape,
+                 hidden_sizes=args.hidden_sizes,
+                 concat=True, device=args.device)
+    critic2 = Critic(net_c2, device=args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
@@ -27,7 +27,10 @@ def get_args():
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=2)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128, 128])
+    parser.add_argument('--imitation-hidden-sizes', type=int,
+                        nargs='*', default=[128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -63,11 +66,12 @@ def test_a2c_with_il(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
     actor = Actor(net, args.action_shape).to(args.device)
     critic = Critic(net).to(args.device)
-    optim = torch.optim.Adam(list(
-        actor.parameters()) + list(critic.parameters()), lr=args.lr)
+    optim = torch.optim.Adam(set(
+        actor.parameters()).union(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = A2CPolicy(
         actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
@@ -107,7 +111,8 @@ def test_a2c_with_il(args=get_args()):
     # here we define an imitation collector with a trivial policy
     if args.task == 'CartPole-v0':
         env.spec.reward_threshold = 190  # lower the goal
-    net = Net(1, args.state_shape, device=args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
     net = Actor(net, args.action_shape).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
     il_policy = ImitationPolicy(net, optim, mode='discrete')
@@ -31,12 +31,14 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=3)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128, 128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
-    parser.add_argument('--prioritized-replay', type=int, default=0)
+    parser.add_argument('--prioritized-replay',
+                        action="store_true", default=False)
     parser.add_argument('--alpha', type=float, default=0.6)
     parser.add_argument('--beta', type=float, default=0.4)
     parser.add_argument(
@@ -63,7 +65,8 @@ def test_c51(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device,
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes, device=args.device,
               softmax=True, num_atoms=args.num_atoms)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = C51Policy(
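For reference, `softmax=True` together with `num_atoms` makes `Net` output a per-action probability distribution over atoms rather than scalar Q-values; a sketch of the expected shapes, assuming the post-refactor `Net.forward` reshapes to `(batch, num_actions, num_atoms)` as in the C51 setup above:

```python
import torch
from tianshou.utils.net.common import Net

net = Net(state_shape=4, action_shape=2, hidden_sizes=[128, 128, 128, 128],
          softmax=True, num_atoms=51)
obs = torch.rand(8, 4)
dist, _ = net(obs)                      # Net returns (logits, state)
print(dist.shape)                       # torch.Size([8, 2, 51])
assert torch.allclose(dist.sum(-1), torch.ones(8, 2))
```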
@@ -71,7 +74,7 @@ def test_c51(args=get_args()):
         args.n_step, target_update_freq=args.target_update_freq
     ).to(args.device)
     # buffer
-    if args.prioritized_replay > 0:
+    if args.prioritized_replay:
         buf = PrioritizedReplayBuffer(
             args.buffer_size, alpha=args.alpha, beta=args.beta)
     else:
|
||||
|
||||
|
||||
def test_pc51(args=get_args()):
|
||||
args.prioritized_replay = 1
|
||||
args.prioritized_replay = True
|
||||
args.gamma = .95
|
||||
args.seed = 1
|
||||
test_c51(args)
|
||||
|
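`--prioritized-replay` becomes a plain boolean switch: absent means False, present means True, with no value to pass. A quick check of the new behavior:

```python
import argparse

parser = argparse.ArgumentParser()
# old: parser.add_argument('--prioritized-replay', type=int, default=0)
parser.add_argument('--prioritized-replay', action="store_true", default=False)

assert parser.parse_args([]).prioritized_replay is False
assert parser.parse_args(['--prioritized-replay']).prioritized_replay is True
```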
@@ -28,12 +28,14 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=3)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128, 128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
-    parser.add_argument('--prioritized-replay', type=int, default=0)
+    parser.add_argument('--prioritized-replay',
+                        action="store_true", default=False)
     parser.add_argument('--alpha', type=float, default=0.6)
     parser.add_argument('--beta', type=float, default=0.4)
     parser.add_argument(
@@ -59,16 +61,18 @@ def test_dqn(args=get_args()):
     torch.manual_seed(args.seed)
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
+    # Q_param = V_param = {"hidden_sizes": [128]}
     # model
-    net = Net(args.layer_num, args.state_shape,
-              args.action_shape, args.device,  # dueling=(1, 1)
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes, device=args.device,
+              # dueling=(Q_param, V_param),
               ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(
         net, optim, args.gamma, args.n_step,
         target_update_freq=args.target_update_freq)
     # buffer
-    if args.prioritized_replay > 0:
+    if args.prioritized_replay:
         buf = PrioritizedReplayBuffer(
             args.buffer_size, alpha=args.alpha, beta=args.beta)
     else:
@@ -122,7 +126,7 @@ def test_dqn(args=get_args()):


 def test_pdqn(args=get_args()):
-    args.prioritized_replay = 1
+    args.prioritized_replay = True
     args.gamma = .95
     args.seed = 1
     test_dqn(args)
@@ -25,7 +25,8 @@ def get_args():
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=3)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128, 128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -55,9 +56,9 @@ def test_pg(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(
-        args.layer_num, args.state_shape, args.action_shape,
-        device=args.device, softmax=True).to(args.device)
+    net = Net(args.state_shape, args.action_shape,
+              hidden_sizes=args.hidden_sizes,
+              device=args.device, softmax=True).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PGPolicy(net, optim, dist, args.gamma,
@@ -26,7 +26,8 @@ def get_args():
     parser.add_argument('--collect-per-step', type=int, default=20)
     parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=20)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -65,7 +66,8 @@ def test_ppo(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
     actor = Actor(net, args.action_shape).to(args.device)
     critic = Critic(net).to(args.device)
     # orthogonal initialization
@@ -73,8 +75,8 @@ def test_ppo(args=get_args()):
         if isinstance(m, torch.nn.Linear):
             torch.nn.init.orthogonal_(m.weight)
             torch.nn.init.zeros_(m.bias)
-    optim = torch.optim.Adam(list(
-        actor.parameters()) + list(critic.parameters()), lr=args.lr)
+    optim = torch.optim.Adam(set(
+        actor.parameters()).union(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PPOPolicy(
         actor, critic, optim, dist, args.gamma,
@@ -30,7 +30,8 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=5)
     parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128])
     parser.add_argument('--training-num', type=int, default=16)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -59,13 +60,16 @@ def test_discrete_sac(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, device=args.device)
+    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+              device=args.device)
     actor = Actor(net, args.action_shape, softmax_output=False).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    net_c1 = Net(args.layer_num, args.state_shape, device=args.device)
+    net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+                 device=args.device)
     critic1 = Critic(net_c1, last_size=args.action_shape).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    net_c2 = Net(args.layer_num, args.state_shape, device=args.device)
+    net_c2 = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
+                 device=args.device)
     critic2 = Critic(net_c2, last_size=args.action_shape).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
@@ -31,7 +31,8 @@ def get_parser() -> argparse.ArgumentParser:
     parser.add_argument('--step-per-epoch', type=int, default=500)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
-    parser.add_argument('--layer-num', type=int, default=3)
+    parser.add_argument('--hidden-sizes', type=int,
+                        nargs='*', default=[128, 128, 128, 128])
     parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
@@ -75,8 +76,9 @@ def get_agents(
     args.action_shape = env.action_space.shape or env.action_space.n
     if agent_learn is None:
         # model
-        net = Net(args.layer_num, args.state_shape, args.action_shape,
-                  args.device).to(args.device)
+        net = Net(args.state_shape, args.action_shape,
+                  hidden_sizes=args.hidden_sizes, device=args.device
+                  ).to(args.device)
         if optim is None:
             optim = torch.optim.Adam(net.parameters(), lr=args.lr)
         agent_learn = DQNPolicy(
@ -1,93 +1,180 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from typing import Any, Dict, List, Tuple, Union, Callable, Optional, Sequence
|
||||
from typing import Any, Dict, List, Type, Tuple, Union, Optional, Sequence
|
||||
|
||||
from tianshou.data import to_torch
|
||||
ModuleType = Type[nn.Module]
|
||||
|
||||
|
||||
def miniblock(
|
||||
inp: int,
|
||||
oup: int,
|
||||
norm_layer: Optional[Callable[[int], nn.modules.Module]],
|
||||
) -> List[nn.modules.Module]:
|
||||
"""Construct a miniblock with given input/output-size and norm layer."""
|
||||
ret: List[nn.modules.Module] = [nn.Linear(inp, oup)]
|
||||
input_size: int,
|
||||
output_size: int = 0,
|
||||
norm_layer: Optional[ModuleType] = None,
|
||||
activation: Optional[ModuleType] = None,
|
||||
) -> List[nn.Module]:
|
||||
"""Construct a miniblock with given input/output-size, norm layer and \
|
||||
activation."""
|
||||
layers: List[nn.Module] = [nn.Linear(input_size, output_size)]
|
||||
if norm_layer is not None:
|
||||
ret += [norm_layer(oup)]
|
||||
ret += [nn.ReLU(inplace=True)]
|
||||
return ret
|
||||
layers += [norm_layer(output_size)] # type: ignore
|
||||
if activation is not None:
|
||||
layers += [activation()]
|
||||
return layers
|
||||
|
||||
|
||||
class Net(nn.Module):
class MLP(nn.Module):
    """Simple MLP backbone.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    Create an MLP of size input_dim * hidden_sizes[0] * hidden_sizes[1] * ...
    * hidden_sizes[-1] * output_dim.

    :param bool concat: whether the input shape is concatenated by state_shape
        and action_shape. If it is True, ``action_shape`` is not the output
        shape, but affects the input shape.
    :param bool dueling: whether to use dueling network to calculate Q values
        (for Dueling DQN), defaults to False.
    :param norm_layer: use which normalization before ReLU, e.g.,
        ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None.
    :param int num_atoms: in order to expand to the net of distributional RL,
        defaults to 1.
    :param int input_dim: dimension of the input vector.
    :param int output_dim: dimension of the output vector. If set to 0, there
        is no final linear layer.
    :param hidden_sizes: shape of MLP passed in as a list, not including
        input_dim and output_dim.
    :param norm_layer: use which normalization before activation, e.g.,
        ``nn.LayerNorm`` and ``nn.BatchNorm1d``. You can also pass a list of
        normalization modules with the same length as hidden_sizes to use a
        different normalization module in each layer. Default to no
        normalization.
    :param activation: which activation to use after each layer; pass a single
        nn.Module class to use the same activation in all layers, or a list to
        use a different activation per layer. Default to nn.ReLU.
    """

    def __init__(
        self,
        layer_num: int,
        state_shape: tuple,
        action_shape: Optional[Union[tuple, int]] = 0,
        device: Union[str, int, torch.device] = "cpu",
        softmax: bool = False,
        concat: bool = False,
        hidden_layer_size: int = 128,
        dueling: Optional[Tuple[int, int]] = None,
        norm_layer: Optional[Callable[[int], nn.modules.Module]] = None,
        num_atoms: int = 1,
        input_dim: int,
        output_dim: int = 0,
        hidden_sizes: Sequence[int] = (),
        norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
        activation: Optional[Union[ModuleType, Sequence[ModuleType]]]
        = nn.ReLU,
        device: Optional[Union[str, int, torch.device]] = None,
    ) -> None:
        super().__init__()
        self.device = device
        if norm_layer:
            if isinstance(norm_layer, list):
                assert len(norm_layer) == len(hidden_sizes)
                norm_layer_list = norm_layer
            else:
                norm_layer_list = [
                    norm_layer for _ in range(len(hidden_sizes))]
        else:
            norm_layer_list = [None] * len(hidden_sizes)
        if activation:
            if isinstance(activation, list):
                assert len(activation) == len(hidden_sizes)
                activation_list = activation
            else:
                activation_list = [
                    activation for _ in range(len(hidden_sizes))]
        else:
            activation_list = [None] * len(hidden_sizes)
        hidden_sizes = [input_dim] + list(hidden_sizes)
        model = []
        for in_dim, out_dim, norm, activ in zip(
                hidden_sizes[:-1], hidden_sizes[1:],
                norm_layer_list, activation_list):
            model += miniblock(in_dim, out_dim, norm, activ)
        if output_dim > 0:
            model += [nn.Linear(hidden_sizes[-1], output_dim)]
        self.output_dim = output_dim or hidden_sizes[-1]
        self.model = nn.Sequential(*model)

    def forward(
        self, x: Union[np.ndarray, torch.Tensor]
    ) -> torch.Tensor:
        x = torch.as_tensor(
            x, device=self.device, dtype=torch.float32)  # type: ignore
        return self.model(x.flatten(1))

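A short sketch of the new MLP in use (shapes chosen arbitrarily; forward accepts numpy arrays and flattens everything after the batch axis):

```python
import numpy as np
from tianshou.utils.net.common import MLP

mlp = MLP(input_dim=4, output_dim=2, hidden_sizes=[128, 128])
out = mlp(np.zeros((8, 4), dtype=np.float32))  # tensor of shape [8, 2]
print(mlp.output_dim)  # 2; would be hidden_sizes[-1] if output_dim == 0
```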
class Net(nn.Module):
    """Wrapper of MLP to support more specific DRL usage.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    :param state_shape: int or a sequence of int of the shape of state.
    :param action_shape: int or a sequence of int of the shape of action.
    :param hidden_sizes: shape of MLP passed in as a list.
    :param norm_layer: use which normalization before activation, e.g.,
        ``nn.LayerNorm`` and ``nn.BatchNorm1d``. You can also pass a list of
        normalization modules with the same length as hidden_sizes to use a
        different normalization module in each layer. Default to no
        normalization.
    :param activation: which activation to use after each layer; pass a single
        nn.Module class to use the same activation in all layers, or a list to
        use a different activation per layer. Default to nn.ReLU.
    :param device: specify the device when the network actually runs. Default
        to "cpu".
    :param bool softmax: whether to apply a softmax layer over the last layer's
        output.
    :param bool concat: whether the input shape is concatenated by state_shape
        and action_shape. If it is True, ``action_shape`` is not the output
        shape, but affects the input shape only.
    :param int num_atoms: in order to expand to the net of distributional RL,
        defaults to 1 (not used).
    :param dueling_param: whether to use dueling network to calculate Q values
        (for Dueling DQN). If you want to use the dueling option, you should
        pass a tuple of two dicts (first for Q and second for V) stating
        self-defined arguments as stated in
        :class:`~tianshou.utils.net.common.MLP`. Defaults to None.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.MLP` for more
        detailed explanation on the usage of activation, norm_layer, etc.

        You can also refer to :class:`~tianshou.utils.net.continuous.Actor`,
        :class:`~tianshou.utils.net.continuous.Critic`, etc, to see how it's
        suggested to be used.
    """

    def __init__(
        self,
        state_shape: Union[int, Sequence[int]],
        action_shape: Optional[Union[int, Sequence[int]]] = 0,
        hidden_sizes: Sequence[int] = (),
        norm_layer: Optional[ModuleType] = None,
        activation: Optional[ModuleType] = nn.ReLU,
        device: Union[str, int, torch.device] = "cpu",
        softmax: bool = False,
        concat: bool = False,
        num_atoms: int = 1,
        dueling_param: Optional[Tuple[Dict[str, Any], Dict[str, Any]]] = None,
    ) -> None:
        super().__init__()
        self.device = device
        self.dueling = dueling
        self.softmax = softmax
        self.num_atoms = num_atoms
        self.action_num = np.prod(action_shape)
        input_size = np.prod(state_shape)
        input_dim = np.prod(state_shape)
        action_dim = np.prod(action_shape) * num_atoms
        if concat:
            input_size += np.prod(action_shape)

        model = miniblock(input_size, hidden_layer_size, norm_layer)

        for i in range(layer_num):
            model += miniblock(
                hidden_layer_size, hidden_layer_size, norm_layer)

        if dueling is None:
            if action_shape and not concat:
                model += [nn.Linear(
                    hidden_layer_size, num_atoms * self.action_num)]
        else:  # dueling DQN
            q_layer_num, v_layer_num = dueling
            Q, V = [], []

            for i in range(q_layer_num):
                Q += miniblock(
                    hidden_layer_size, hidden_layer_size, norm_layer)
            for i in range(v_layer_num):
                V += miniblock(
                    hidden_layer_size, hidden_layer_size, norm_layer)

            if action_shape and not concat:
                Q += [nn.Linear(
                    hidden_layer_size, num_atoms * self.action_num)]
                V += [nn.Linear(hidden_layer_size, num_atoms)]

            self.Q = nn.Sequential(*Q)
            self.V = nn.Sequential(*V)
        self.model = nn.Sequential(*model)
            input_dim += action_dim
        self.use_dueling = dueling_param is not None
        output_dim = action_dim if not self.use_dueling and not concat else 0
        self.model = MLP(input_dim, output_dim, hidden_sizes,
                         norm_layer, activation, device)
        self.output_dim = self.model.output_dim
        if self.use_dueling:  # dueling DQN
            q_kwargs, v_kwargs = dueling_param  # type: ignore
            q_output_dim, v_output_dim = 0, 0
            if not concat:
                q_output_dim, v_output_dim = action_dim, num_atoms
            q_kwargs: Dict[str, Any] = {
                **q_kwargs, "input_dim": self.output_dim,
                "output_dim": q_output_dim}
            v_kwargs: Dict[str, Any] = {
                **v_kwargs, "input_dim": self.output_dim,
                "output_dim": v_output_dim}
            self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs)
            self.output_dim = self.Q.output_dim

    def forward(
        self,
@ -95,18 +182,17 @@ class Net(nn.Module):
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        """Mapping: s -> flatten -> logits."""
        s = to_torch(s, device=self.device, dtype=torch.float32)
        s = s.reshape(s.size(0), -1)
        """Mapping: s -> flatten (inside MLP) -> logits."""
        logits = self.model(s)
        if self.dueling is not None:  # Dueling DQN
        bsz = logits.shape[0]
        if self.use_dueling:  # Dueling DQN
            q, v = self.Q(logits), self.V(logits)
            if self.num_atoms > 1:
                v = v.view(-1, 1, self.num_atoms)
                q = q.view(-1, self.action_num, self.num_atoms)
                q = q.view(bsz, -1, self.num_atoms)
                v = v.view(bsz, -1, self.num_atoms)
            logits = q - q.mean(dim=1, keepdim=True) + v
        elif self.num_atoms > 1:
            logits = logits.view(-1, self.action_num, self.num_atoms)
            logits = logits.view(bsz, -1, self.num_atoms)
        if self.softmax:
            logits = torch.softmax(logits, dim=-1)
        return logits, state
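To make the new dueling interface concrete, a sketch (hyper-parameters arbitrary): each dict configures one head's MLP, and Net fills in input_dim/output_dim itself.

```python
import numpy as np
from tianshou.utils.net.common import Net

net = Net(state_shape=4, action_shape=2, hidden_sizes=[128, 128],
          dueling_param=({"hidden_sizes": [128]},   # Q head
                         {"hidden_sizes": [128]}))  # V head
logits, state = net(np.zeros((8, 4), dtype=np.float32))
# logits.shape == torch.Size([8, 2]); logits = Q - Q.mean() + V
```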
@ -122,14 +208,12 @@ class Recurrent(nn.Module):
    def __init__(
        self,
        layer_num: int,
        state_shape: Sequence[int],
        action_shape: Sequence[int],
        state_shape: Union[int, Sequence[int]],
        action_shape: Union[int, Sequence[int]],
        device: Union[str, int, torch.device] = "cpu",
        hidden_layer_size: int = 128,
    ) -> None:
        super().__init__()
        self.state_shape = state_shape
        self.action_shape = action_shape
        self.device = device
        self.nn = nn.LSTM(
            input_size=hidden_layer_size,
@ -152,7 +236,8 @@ class Recurrent(nn.Module):
        training mode, s should be with shape ``[bsz, len, dim]``. See the code
        and comment for more detail.
        """
        s = to_torch(s, device=self.device, dtype=torch.float32)
        s = torch.as_tensor(
            s, device=self.device, dtype=torch.float32)  # type: ignore
        # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
        # In short, the tensor has one more dimension in the training phase
        # than in the evaluation phase.
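A usage sketch for the relaxed Recurrent signature (assuming the evaluation-phase path, where the module itself adds the length axis to a [bsz, dim] input):

```python
import numpy as np
from tianshou.utils.net.common import Recurrent

net = Recurrent(layer_num=1, state_shape=4, action_shape=2)
logits, hidden = net(np.zeros((8, 4), dtype=np.float32))
# logits: [8, 2]; hidden: {"h": ..., "c": ...} to feed back on the next step
```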
@ -3,7 +3,7 @@ import numpy as np
from torch import nn
from typing import Any, Dict, Tuple, Union, Optional, Sequence

from tianshou.data import to_torch, to_torch_as
from tianshou.utils.net.common import MLP


SIGMA_MIN = -20
@ -11,23 +11,45 @@ SIGMA_MAX = 2


class Actor(nn.Module):
    """Simple actor network with MLP.
    """Simple actor network. Will create an actor operated in continuous \
    action space with structure of preprocess_net ---> action_shape.

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param float max_action: the scale for the final action logits. Default to
        1.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_sizes: Sequence[int] = (),
        max_action: float = 1.0,
        device: Union[str, int, torch.device] = "cpu",
        hidden_layer_size: int = 128,
        preprocess_net_output_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
        self.output_dim = np.prod(action_shape)
        input_dim = getattr(preprocess_net, "output_dim",
                            preprocess_net_output_dim)
        self.last = MLP(input_dim, self.output_dim, hidden_sizes)
        self._max = max_action

    def forward(
@ -43,22 +65,40 @@ class Actor(nn.Module):


class Critic(nn.Module):
    """Simple critic network with MLP.
    """Simple critic network. Will create a critic operated in continuous \
    action space with structure of preprocess_net ---> 1(q value).

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        hidden_sizes: Sequence[int] = (),
        device: Union[str, int, torch.device] = "cpu",
        hidden_layer_size: int = 128,
        preprocess_net_output_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.device = device
        self.preprocess = preprocess_net
        self.last = nn.Linear(hidden_layer_size, 1)
        self.output_dim = 1
        input_dim = getattr(preprocess_net, "output_dim",
                            preprocess_net_output_dim)
        self.last = MLP(input_dim, 1, hidden_sizes)

    def forward(
        self,
@ -67,11 +107,13 @@ class Critic(nn.Module):
        info: Dict[str, Any] = {},
    ) -> torch.Tensor:
        """Mapping: (s, a) -> logits -> Q(s, a)."""
        s = to_torch(s, device=self.device, dtype=torch.float32)
        s = s.flatten(1)
        s = torch.as_tensor(
            s, device=self.device, dtype=torch.float32  # type: ignore
        ).flatten(1)
        if a is not None:
            a = to_torch_as(a, s)
            a = a.flatten(1)
            a = torch.as_tensor(
                a, device=self.device, dtype=torch.float32  # type: ignore
            ).flatten(1)
            s = torch.cat([s, a], dim=1)
        logits, h = self.preprocess(s)
        logits = self.last(logits)
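Putting the pieces together, a sketch of the intended preprocess_net pattern (hyper-parameters arbitrary): a shared Net trunk whose output_dim the heads pick up via getattr, as shown above.

```python
import numpy as np
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic

trunk = Net(state_shape=3, hidden_sizes=[128, 128])   # no action head
actor = Actor(trunk, action_shape=1, max_action=2.0)  # reads trunk.output_dim
critic = Critic(Net(state_shape=3, action_shape=1,
                    hidden_sizes=[128, 128], concat=True))
act, _ = actor(np.zeros((8, 3), dtype=np.float32))    # [8, 1]
q = critic(np.zeros((8, 3), dtype=np.float32), act)   # [8, 1]
```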
@ -79,31 +121,55 @@ class Critic(nn.Module):


class ActorProb(nn.Module):
    """Simple actor network (output with a Gauss distribution) with MLP.
    """Simple actor network (output with a Gauss distribution).

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param float max_action: the scale for the final action logits. Default to
        1.
    :param bool unbounded: whether to leave the final logits unbounded, i.e.,
        skip the tanh squashing. Default to False.
    :param bool conditioned_sigma: True when sigma is calculated from the
        input, False when sigma is an independent parameter. Default to False.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_sizes: Sequence[int] = (),
        max_action: float = 1.0,
        device: Union[str, int, torch.device] = "cpu",
        unbounded: bool = False,
        hidden_layer_size: int = 128,
        conditioned_sigma: bool = False,
        preprocess_net_output_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.preprocess = preprocess_net
        self.device = device
        self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
        self.output_dim = np.prod(action_shape)
        input_dim = getattr(preprocess_net, "output_dim",
                            preprocess_net_output_dim)
        self.mu = MLP(input_dim, self.output_dim, hidden_sizes)
        self._c_sigma = conditioned_sigma
        if conditioned_sigma:
            self.sigma = nn.Linear(hidden_layer_size, np.prod(action_shape))
            self.sigma = MLP(input_dim, self.output_dim, hidden_sizes)
        else:
            self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
            self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
        self._max = max_action
        self._unbounded = unbounded

@ -125,7 +191,7 @@ class ActorProb(nn.Module):
        else:
            shape = [1] * len(mu.shape)
            shape[1] = -1
            sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
            sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
        return (mu, sigma), state
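A sketch of the Gaussian actor with the renamed state-independent sigma_param (the default path; with conditioned_sigma=True the sigma head becomes an MLP instead):

```python
import numpy as np
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb

trunk = Net(state_shape=3, hidden_sizes=[128, 128])
actor = ActorProb(trunk, action_shape=1, max_action=2.0)
(mu, sigma), _ = actor(np.zeros((8, 3), dtype=np.float32))
# mu: [8, 1], tanh-squashed unless unbounded=True; sigma: [8, 1]
```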
@ -141,10 +207,10 @@ class RecurrentActorProb(nn.Module):
        layer_num: int,
        state_shape: Sequence[int],
        action_shape: Sequence[int],
        hidden_layer_size: int = 128,
        max_action: float = 1.0,
        device: Union[str, int, torch.device] = "cpu",
        unbounded: bool = False,
        hidden_layer_size: int = 128,
        conditioned_sigma: bool = False,
    ) -> None:
        super().__init__()
@ -155,12 +221,13 @@ class RecurrentActorProb(nn.Module):
            num_layers=layer_num,
            batch_first=True,
        )
        self.mu = nn.Linear(hidden_layer_size, np.prod(action_shape))
        output_dim = np.prod(action_shape)
        self.mu = nn.Linear(hidden_layer_size, output_dim)
        self._c_sigma = conditioned_sigma
        if conditioned_sigma:
            self.sigma = nn.Linear(hidden_layer_size, np.prod(action_shape))
            self.sigma = nn.Linear(hidden_layer_size, output_dim)
        else:
            self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
            self.sigma_param = nn.Parameter(torch.zeros(output_dim, 1))
        self._max = max_action
        self._unbounded = unbounded

@ -171,7 +238,8 @@ class RecurrentActorProb(nn.Module):
        info: Dict[str, Any] = {},
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor]]:
        """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
        s = to_torch(s, device=self.device, dtype=torch.float32)
        s = torch.as_tensor(
            s, device=self.device, dtype=torch.float32)  # type: ignore
        # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
        # In short, the tensor has one more dimension in the training phase
        # than in the evaluation phase.
@ -196,7 +264,7 @@ class RecurrentActorProb(nn.Module):
        else:
            shape = [1] * len(mu.shape)
            shape[1] = -1
            sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
            sigma = (self.sigma_param.view(shape) + torch.zeros_like(mu)).exp()
        # please ensure the first dim is batch size: [bsz, len, ...]
        return (mu, sigma), {"h": h.transpose(0, 1).detach(),
                             "c": c.transpose(0, 1).detach()}
@ -236,7 +304,8 @@ class RecurrentCritic(nn.Module):
        info: Dict[str, Any] = {},
    ) -> torch.Tensor:
        """Almost the same as :class:`~tianshou.utils.net.common.Recurrent`."""
        s = to_torch(s, device=self.device, dtype=torch.float32)
        s = torch.as_tensor(
            s, device=self.device, dtype=torch.float32)  # type: ignore
        # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
        # In short, the tensor has one more dimension in the training phase
        # than in the evaluation phase.
@ -245,7 +314,8 @@ class RecurrentCritic(nn.Module):
        s, (h, c) = self.nn(s)
        s = s[:, -1]
        if a is not None:
            a = to_torch_as(a, s)
            a = torch.as_tensor(
                a, device=self.device, dtype=torch.float32)  # type: ignore
            s = torch.cat([s, a], dim=1)
        s = self.fc2(s)
        return s
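The recurrent variants keep their nn.Linear heads (an LSTM trunk already fixes the feature width). A sketch of stepping RecurrentActorProb one observation at a time, feeding the returned hidden state back in (this relies on forward internals not shown in this hunk):

```python
import numpy as np
from tianshou.utils.net.continuous import RecurrentActorProb

actor = RecurrentActorProb(layer_num=1, state_shape=3, action_shape=1)
(mu, sigma), hidden = actor(np.zeros((8, 3), dtype=np.float32))
(mu, sigma), hidden = actor(np.zeros((8, 3), dtype=np.float32), state=hidden)
```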
@ -4,26 +4,49 @@ from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, Tuple, Union, Optional, Sequence

from tianshou.data import to_torch
from tianshou.utils.net.common import MLP


class Actor(nn.Module):
    """Simple actor network with MLP.
    """Simple actor network.

    Will create an actor operated in discrete action space with structure of
    preprocess_net ---> action_shape.

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param action_shape: a sequence of int for the shape of action.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param bool softmax_output: whether to apply a softmax layer over the last
        layer's output.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        action_shape: Sequence[int],
        hidden_layer_size: int = 128,
        hidden_sizes: Sequence[int] = (),
        softmax_output: bool = True,
        preprocess_net_output_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.preprocess = preprocess_net
        self.last = nn.Linear(hidden_layer_size, np.prod(action_shape))
        self.output_dim = np.prod(action_shape)
        input_dim = getattr(preprocess_net, "output_dim",
                            preprocess_net_output_dim)
        self.last = MLP(input_dim, self.output_dim, hidden_sizes)
        self.softmax_output = softmax_output

    def forward(
@ -41,125 +64,44 @@ class Actor(nn.Module):


class Critic(nn.Module):
    """Simple critic network with MLP.
    """Simple critic network. Will create a critic operated in discrete \
    action space with structure of preprocess_net ---> 1(q value).

    :param preprocess_net: a self-defined preprocess_net which outputs a
        flattened hidden state.
    :param hidden_sizes: a sequence of int for constructing the MLP after
        preprocess_net. Default to empty sequence (where the MLP now contains
        only a single linear layer).
    :param int last_size: the output dimension of Critic network. Default to 1.
    :param int preprocess_net_output_dim: the output dimension of
        preprocess_net.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.

    .. seealso::

        Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
        of how preprocess_net is suggested to be defined.
    """

    def __init__(
        self,
        preprocess_net: nn.Module,
        hidden_layer_size: int = 128,
        last_size: int = 1
        hidden_sizes: Sequence[int] = (),
        last_size: int = 1,
        preprocess_net_output_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.preprocess = preprocess_net
        self.last = nn.Linear(hidden_layer_size, last_size)
        self.output_dim = last_size
        input_dim = getattr(preprocess_net, "output_dim",
                            preprocess_net_output_dim)
        self.last = MLP(input_dim, last_size, hidden_sizes)

    def forward(
        self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
    ) -> torch.Tensor:
        """Mapping: s -> V(s)."""
        logits, h = self.preprocess(s, state=kwargs.get("state", None))
        logits = self.last(logits)
        return logits
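The discrete heads follow the same preprocess_net pattern; a sketch with a shared trunk, as in an A2C-style setup (sizes arbitrary):

```python
import numpy as np
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor, Critic

trunk = Net(state_shape=4, hidden_sizes=[64, 64])
actor = Actor(trunk, action_shape=2)   # softmax over 2 actions by default
critic = Critic(trunk)                 # scalar V(s)
probs, _ = actor(np.zeros((8, 4), dtype=np.float32))   # [8, 2]
values = critic(np.zeros((8, 4), dtype=np.float32))    # [8, 1]
```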
class DQN(nn.Module):
    """Reference: Human-level control through deep reinforcement learning.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        super().__init__()
        self.device = device

        def conv2d_size_out(
            size: int, kernel_size: int = 5, stride: int = 2
        ) -> int:
            return (size - (kernel_size - 1) - 1) // stride + 1

        def conv2d_layers_size_out(
            size: int,
            kernel_size_1: int = 8,
            stride_1: int = 4,
            kernel_size_2: int = 4,
            stride_2: int = 2,
            kernel_size_3: int = 3,
            stride_3: int = 1,
        ) -> int:
            size = conv2d_size_out(size, kernel_size_1, stride_1)
            size = conv2d_size_out(size, kernel_size_2, stride_2)
            size = conv2d_size_out(size, kernel_size_3, stride_3)
            return size

        convw = conv2d_layers_size_out(w)
        convh = conv2d_layers_size_out(h)
        linear_input_size = convw * convh * 64

        self.net = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(linear_input_size, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, np.prod(action_shape)),
        )

    def forward(
        self,
        x: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Q(x, \*)."""
        if not isinstance(x, torch.Tensor):
            x = to_torch(x, device=self.device, dtype=torch.float32)
        return self.net(x), state


class C51(DQN):
    """Reference: A distributional perspective on reinforcement learning.

    For advanced usage (how to customize the network), please refer to
    :ref:`build_the_network`.
    """

    def __init__(
        self,
        c: int,
        h: int,
        w: int,
        action_shape: Sequence[int],
        num_atoms: int = 51,
        device: Union[str, int, torch.device] = "cpu",
    ) -> None:
        super().__init__(c, h, w, [np.prod(action_shape) * num_atoms], device)
        self.action_shape = action_shape
        self.num_atoms = num_atoms

    def forward(
        self,
        x: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Dict[str, Any] = {},
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: x -> Z(x, \*)."""
        x, state = super().forward(x)
        x = x.view(-1, self.num_atoms).softmax(dim=-1)
        x = x.view(-1, np.prod(self.action_shape), self.num_atoms)
        return x, state
        logits, _ = self.preprocess(s, state=kwargs.get("state", None))
        return self.last(logits)
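These two classes were moved to examples/atari/atari_network.py in this commit (see the import changes in the atari scripts above); a usage sketch against the signature shown here, with illustrative frame-stack and action-count values:

```python
import numpy as np
from atari_network import DQN  # now lives in examples/atari/

net = DQN(c=4, h=84, w=84, action_shape=6)
q, state = net(np.zeros((1, 4, 84, 84), dtype=np.float32))
# q.shape == torch.Size([1, 6])
```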