Remove dummy net code (#123)

* remove dummy net; delete two files

* split code into backbone and head

* rename class

* change torch.float to torch.float32

* use flatten(1) instead of view(batch, -1)

* remove dummy net in docs

* bug fix for RNN

* fix CUDA error

* minor docs fix

* do not change the example code in the DQN tutorial, since it is for demonstration

Co-authored-by: Trinkle23897 <463003665@qq.com>
Author: youkaichao, 2020-07-09 22:57:01 +08:00 (committed by GitHub)
Parent: aa3c453f42
Commit: e767de044b
28 changed files with 219 additions and 373 deletions
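
Two of the changes listed in the commit message, switching torch.float to torch.float32 and view(batch, -1) to flatten(1), are pure refactorings. A minimal sketch (not part of the diff) showing why they are equivalent:

```python
import torch

# torch.float is an alias of torch.float32, so the dtype change only makes
# the intent explicit; behaviour is identical.
assert torch.float == torch.float32

# flatten(1) keeps dim 0 (the batch dimension) and flattens everything after
# it, which is exactly what view(batch, -1) did, without having to read the
# batch size first.
s = torch.zeros(4, 3, 2)
assert torch.equal(s.flatten(1), s.view(s.shape[0], -1))  # both are (4, 6)
```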


@@ -206,26 +206,12 @@ test_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(test_num)])
 Define the network:
 ```python
-class Net(nn.Module):
-    def __init__(self, state_shape, action_shape):
-        super().__init__()
-        self.model = nn.Sequential(*[
-            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
-            nn.Linear(128, 128), nn.ReLU(inplace=True),
-            nn.Linear(128, 128), nn.ReLU(inplace=True),
-            nn.Linear(128, np.prod(action_shape))
-        ])
-
-    def forward(self, s, state=None, info={}):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, dtype=torch.float)
-        batch = s.shape[0]
-        logits = self.model(s.view(batch, -1))
-        return logits, state
+from tianshou.utils.net.common import Net

 env = gym.make(task)
 state_shape = env.observation_space.shape or env.observation_space.n
 action_shape = env.action_space.shape or env.action_space.n
-net = Net(state_shape, action_shape)
+net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape)
 optim = torch.optim.Adam(net.parameters(), lr=lr)
 ```


@@ -5,3 +5,18 @@ tianshou.utils
    :members:
    :undoc-members:
    :show-inheritance:
+
+.. automodule:: tianshou.utils.net.common
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+.. automodule:: tianshou.utils.net.discrete
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+.. automodule:: tianshou.utils.net.continuous
+   :members:
+   :undoc-members:
+   :show-inheritance:


@@ -74,7 +74,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers but with the
     net = Net(state_shape, action_shape)
     optim = torch.optim.Adam(net.parameters(), lr=1e-3)

-The rules of self-defined networks are:
+You can also have a try with those pre-defined networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are:

 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
 2. Output: some ``logits``, the next hidden state ``state``, and intermediate result during the policy forwarding procedure ``policy``. The logits could be a tuple instead of a ``torch.Tensor``. It depends on how the policy process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. The ``policy`` can be a Batch of torch.Tensor or other things, which will be stored in the replay buffer, and can be accessed in the policy update process (e.g. in ``policy.learn()``, the ``batch.policy`` is what you need).
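
For concreteness, a minimal self-defined network that satisfies these two rules (essentially the dummy net removed from the README above, updated to the commit's float32/flatten(1) conventions) could be sketched as:

```python
import numpy as np
import torch
from torch import nn


class MyNet(nn.Module):
    """Sketch of a self-defined network: observation in, (logits, state) out."""

    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape)),
        )

    def forward(self, obs, state=None, info={}):
        # Rule 1: accept the observation, the hidden state, and env info.
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float32)
        logits = self.model(obs.flatten(1))
        # Rule 2: return logits and the (here unchanged) hidden state.
        return logits, state
```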


@@ -10,8 +10,8 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 from tianshou.exploration import GaussianNoise
-from continuous_net import Actor, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, Critic

 def get_args():
@@ -57,14 +57,13 @@ def test_ddpg(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    actor = Actor(
-        args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape, device=args.device)
+    actor = Actor(net, args.action_shape, args.max_action,
+                  args.device).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic = Critic(net, args.device).to(args.device)
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,


@@ -10,8 +10,8 @@ from tianshou.policy import SACPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
-from continuous_net import ActorProb, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import ActorProb, Critic

 def get_args():
@@ -58,18 +58,17 @@ def test_sac(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,


@@ -10,8 +10,8 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 from tianshou.exploration import GaussianNoise
-from continuous_net import Actor, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, Critic

 def get_args():
@@ -60,18 +60,17 @@ def test_td3(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = Actor(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,


@@ -1,81 +0,0 @@
-import torch
-import numpy as np
-from torch import nn
-
-
-class Actor(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape,
-                 max_action, device='cpu'):
-        super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.model += [nn.Linear(128, np.prod(action_shape))]
-        self.model = nn.Sequential(*self.model)
-        self._max = max_action
-
-    def forward(self, s, **kwargs):
-        s = torch.tensor(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        logits = self.model(s)
-        logits = self._max * torch.tanh(logits)
-        return logits, None
-
-
-class ActorProb(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape,
-                 max_action, device='cpu', unbounded=False):
-        super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.model = nn.Sequential(*self.model)
-        self.mu = nn.Linear(128, np.prod(action_shape))
-        self.sigma = nn.Linear(128, np.prod(action_shape))
-        self._max = max_action
-        self._unbounded = unbounded
-
-    def forward(self, s, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        logits = self.model(s)
-        if not self._unbounded:
-            mu = self._max * torch.tanh(self.mu(logits))
-        sigma = torch.exp(self.sigma(logits))
-        return (mu, sigma), None
-
-
-class Critic(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
-        super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape) + np.prod(action_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.model += [nn.Linear(128, 1)]
-        self.model = nn.Sequential(*self.model)
-
-    def forward(self, s, a=None):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
-        if a is not None and not isinstance(a, torch.Tensor):
-            a = torch.tensor(a, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        if a is None:
-            logits = self.model(s)
-        else:
-            a = a.view(batch, -1)
-            logits = self.model(torch.cat([s, a], dim=1))
-        return logits


@@ -15,8 +15,8 @@ try:
     import pybullet_envs
 except ImportError:
     pass
-from continuous_net import ActorProb, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import ActorProb, Critic

 def get_args():
@@ -66,18 +66,17 @@ def test_sac(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,


@@ -10,7 +10,8 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 from tianshou.exploration import GaussianNoise
-from continuous_net import Actor, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, Critic

 from mujoco.register import reg
@@ -63,18 +64,17 @@ def test_td3(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = Actor(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,


@@ -10,7 +10,8 @@ from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env.atari import create_atari_environment
-from discrete_net import Net, Actor, Critic
+from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.utils.net.common import Net

 def get_args():


@@ -6,12 +6,11 @@ from torch.utils.tensorboard import SummaryWriter
 from tianshou.policy import DQNPolicy
 from tianshou.env import SubprocVectorEnv
+from tianshou.utils.net.discrete import DQN
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env.atari import create_atari_environment
-from discrete_net import DQN

 def get_args():
     parser = argparse.ArgumentParser()


@@ -9,8 +9,8 @@ from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env.atari import create_atari_environment
-from discrete_net import Net, Actor, Critic
+from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.utils.net.common import Net

 def get_args():


@@ -11,8 +11,8 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv
 from tianshou.exploration import OUNoise
-from continuous_net import ActorProb, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import ActorProb, Critic

 def get_args():
@@ -62,18 +62,17 @@ def test_sac(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     if args.auto_alpha:


@@ -11,11 +11,8 @@ from tianshou.policy import DDPGPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.exploration import GaussianNoise
-if __name__ == '__main__':
-    from net import Actor, Critic
-else:  # pytest
-    from test.continuous.net import Actor, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, Critic

 def get_args():
@@ -69,14 +66,15 @@ def test_ddpg(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = Actor(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic = Critic(net, args.device).to(args.device)
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,


@@ -11,11 +11,8 @@ from tianshou.policy import PPOPolicy
 from tianshou.policy.dist import DiagGaussian
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-if __name__ == '__main__':
-    from net import ActorProb, Critic
-else:  # pytest
-    from test.continuous.net import ActorProb, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import ActorProb, Critic

 def get_args():
@@ -72,13 +69,14 @@ def test_ppo(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
-    critic = Critic(
+    critic = Critic(Net(
         args.layer_num, args.state_shape, device=args.device
-    ).to(args.device)
+    ), device=args.device).to(args.device)
     # orthogonal initialization
     for m in list(actor.modules()) + list(critic.modules()):
         if isinstance(m, torch.nn.Linear):


@@ -10,11 +10,8 @@ from tianshou.env import VectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.policy import SACPolicy, ImitationPolicy
-if __name__ == '__main__':
-    from net import Actor, ActorProb, Critic
-else:  # pytest
-    from test.continuous.net import Actor, ActorProb, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, ActorProb, Critic

 def get_args():
@@ -68,18 +65,17 @@ def test_sac_with_il(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
@@ -122,8 +118,9 @@ def test_sac_with_il(args=get_args()):
     # here we define an imitation collector with a trivial policy
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -300  # lower the goal
-    net = Actor(1, args.state_shape, args.action_shape,
-                args.max_action, args.device).to(args.device)
+    net = Actor(Net(1, args.state_shape, device=args.device),
+                args.action_shape, args.max_action, args.device
+                ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
     il_policy = ImitationPolicy(net, optim, mode='continuous')
     il_test_collector = Collector(il_policy, test_envs)


@@ -11,11 +11,8 @@ from tianshou.policy import TD3Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.exploration import GaussianNoise
-if __name__ == '__main__':
-    from net import Actor, Critic
-else:  # pytest
-    from test.continuous.net import Actor, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, Critic

 def get_args():
@@ -71,18 +68,17 @@ def test_td3(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = Actor(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,


@@ -10,11 +10,8 @@ from tianshou.env import VectorEnv
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.policy import A2CPolicy, ImitationPolicy
 from tianshou.trainer import onpolicy_trainer, offpolicy_trainer
-if __name__ == '__main__':
-    from net import Net, Actor, Critic
-else:  # pytest
-    from test.discrete.net import Net, Actor, Critic
+from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.utils.net.common import Net

 def get_args():


@@ -10,11 +10,7 @@ from tianshou.env import VectorEnv
 from tianshou.policy import DQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-if __name__ == '__main__':
-    from net import Net
-else:  # pytest
-    from test.discrete.net import Net
+from tianshou.utils.net.common import Net

 def get_args():
@@ -61,8 +57,8 @@ def test_dqn(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
-    net = net.to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(
         net, optim, args.gamma, args.n_step,


@@ -10,11 +10,7 @@ from tianshou.env import VectorEnv
 from tianshou.policy import DQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-if __name__ == '__main__':
-    from net import Recurrent
-else:  # pytest
-    from test.discrete.net import Recurrent
+from tianshou.utils.net.common import Recurrent

 def get_args():
@@ -63,8 +59,7 @@ def test_drqn(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Recurrent(args.layer_num, args.state_shape,
-                    args.action_shape, args.device)
-    net = net.to(args.device)
+                    args.action_shape, args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(
         net, optim, args.gamma, args.n_step,


@@ -6,16 +6,12 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter

+from tianshou.utils.net.common import Net
 from tianshou.env import VectorEnv
 from tianshou.policy import DQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
-if __name__ == '__main__':
-    from net import Net
-else:  # pytest
-    from test.discrete.net import Net

 def get_args():
     parser = argparse.ArgumentParser()
@@ -64,8 +60,8 @@ def test_pdqn(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
-    net = net.to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(
         net, optim, args.gamma, args.n_step,


@@ -7,16 +7,12 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter

+from tianshou.utils.net.common import Net
 from tianshou.env import VectorEnv
 from tianshou.policy import PGPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Batch, Collector, ReplayBuffer
-if __name__ == '__main__':
-    from net import Net
-else:  # pytest
-    from test.discrete.net import Net

 def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
     returns = np.zeros_like(batch.rew)
@@ -129,8 +125,7 @@ def test_pg(args=get_args()):
     # model
     net = Net(
         args.layer_num, args.state_shape, args.action_shape,
-        device=args.device, softmax=True)
-    net = net.to(args.device)
+        device=args.device, softmax=True).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PGPolicy(net, optim, dist, args.gamma,


@@ -10,11 +10,8 @@ from tianshou.env import VectorEnv
 from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-if __name__ == '__main__':
-    from net import Net, Actor, Critic
-else:  # pytest
-    from test.discrete.net import Net, Actor, Critic
+from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.utils.net.common import Net

 def get_args():


@@ -46,7 +46,7 @@ class ImitationPolicy(BasePolicy):
         self.optim.zero_grad()
         if self.mode == 'continuous':
             a = self(batch).act
-            a_ = to_torch(batch.act, dtype=torch.float, device=a.device)
+            a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
             loss = F.mse_loss(a, a_)
         elif self.mode == 'discrete':  # classification
             a = self(batch).logits


@@ -1,82 +1,67 @@
-import torch
 import numpy as np
+import torch
 from torch import nn
-import torch.nn.functional as F

 from tianshou.data import to_torch


 class Net(nn.Module):
+    """Simple MLP backbone. For advanced usage (how to customize the network),
+    please refer to :ref:`build_the_network`.
+
+    :param concat: whether the input shape is concatenated by state_shape
+        and action_shape. If it is True, ``action_shape`` is not the output
+        shape, but affects the input shape.
+    """
+
     def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
-                 softmax=False):
+                 softmax=False, concat=False):
         super().__init__()
         self.device = device
+        input_size = np.prod(state_shape)
+        if concat:
+            input_size += np.prod(action_shape)
         self.model = [
-            nn.Linear(np.prod(state_shape), 128),
+            nn.Linear(input_size, 128),
             nn.ReLU(inplace=True)]
         for i in range(layer_num):
             self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        if action_shape:
+        if action_shape and not concat:
             self.model += [nn.Linear(128, np.prod(action_shape))]
         if softmax:
             self.model += [nn.Softmax(dim=-1)]
         self.model = nn.Sequential(*self.model)

     def forward(self, s, state=None, info={}):
-        s = to_torch(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
+        s = to_torch(s, device=self.device, dtype=torch.float32)
+        s = s.flatten(1)
         logits = self.model(s)
         return logits, state


-class Actor(nn.Module):
-    def __init__(self, preprocess_net, action_shape):
-        super().__init__()
-        self.preprocess = preprocess_net
-        self.last = nn.Linear(128, np.prod(action_shape))
-
-    def forward(self, s, state=None, info={}):
-        logits, h = self.preprocess(s, state)
-        logits = F.softmax(self.last(logits), dim=-1)
-        return logits, h
-
-
-class Critic(nn.Module):
-    def __init__(self, preprocess_net):
-        super().__init__()
-        self.preprocess = preprocess_net
-        self.last = nn.Linear(128, 1)
-
-    def forward(self, s, **kwargs):
-        logits, h = self.preprocess(s, state=kwargs.get('state', None))
-        logits = self.last(logits)
-        return logits
-
-
 class Recurrent(nn.Module):
+    """Simple Recurrent network based on LSTM. For advanced usage (how to
+    customize the network), please refer to :ref:`build_the_network`.
+    """
+
     def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
         super().__init__()
         self.state_shape = state_shape
         self.action_shape = action_shape
         self.device = device
+        self.fc1 = nn.Linear(np.prod(state_shape), 128)
         self.nn = nn.LSTM(input_size=128, hidden_size=128,
                           num_layers=layer_num, batch_first=True)
-        self.fc1 = nn.Linear(np.prod(state_shape), 128)
         self.fc2 = nn.Linear(128, np.prod(action_shape))

     def forward(self, s, state=None, info={}):
-        s = to_torch(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float32)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
         if len(s.shape) == 2:
-            bsz, dim = s.shape
-            length = 1
-        else:
-            bsz, length, dim = s.shape
-        s = self.fc1(s.view([bsz * length, dim]))
-        s = s.view(bsz, length, -1)
+            s = s.unsqueeze(-2)
+        s = self.fc1(s)
         self.nn.flatten_parameters()
         if state is None:
             s, (h, c) = self.nn(s)
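
A usage sketch of the new backbone/head split, mirroring the updated DDPG-style examples above (the Pendulum-v0 task, layer count, and learning rates are illustrative):

```python
import gym
import torch

from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic

device = 'cpu'
env = gym.make('Pendulum-v0')
state_shape = env.observation_space.shape
action_shape = env.action_space.shape
max_action = env.action_space.high[0]

# Actor: the Net backbone maps the observation to 128-dim features,
# and the Actor head turns them into a bounded action.
actor = Actor(Net(1, state_shape, device=device),
              action_shape, max_action, device).to(device)

# Critic: its backbone takes [obs, act] concatenated, so pass concat=True
# together with action_shape so the input size is computed correctly.
critic = Critic(Net(1, state_shape, action_shape, concat=True, device=device),
                device).to(device)

actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)
```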


@@ -6,85 +6,77 @@ from tianshou.data import to_torch

 class Actor(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape,
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
+    def __init__(self, preprocess_net, action_shape,
                  max_action, device='cpu'):
         super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.model += [nn.Linear(128, np.prod(action_shape))]
-        self.model = nn.Sequential(*self.model)
+        self.preprocess = preprocess_net
+        self.last = nn.Linear(128, np.prod(action_shape))
         self._max = max_action

-    def forward(self, s, **kwargs):
-        s = to_torch(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        logits = self.model(s)
-        logits = self._max * torch.tanh(logits)
-        return logits, None
-
-
-class ActorProb(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape,
-                 max_action, device='cpu'):
-        super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.model = nn.Sequential(*self.model)
-        self.mu = nn.Linear(128, np.prod(action_shape))
-        self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
-        # self.sigma = nn.Linear(128, np.prod(action_shape))
-        self._max = max_action
-
-    def forward(self, s, **kwargs):
-        s = to_torch(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        logits = self.model(s)
-        mu = self.mu(logits)
-        shape = [1] * len(mu.shape)
-        shape[1] = -1
-        sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
-        # assert sigma.shape == mu.shape
-        # mu = self._max * torch.tanh(self.mu(logits))
-        # sigma = torch.exp(self.sigma(logits))
-        return (mu, sigma), None
+    def forward(self, s, state=None, info={}):
+        logits, h = self.preprocess(s, state)
+        logits = self._max * torch.tanh(self.last(logits))
+        return logits, h


 class Critic(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
+    def __init__(self, preprocess_net, device='cpu'):
         super().__init__()
         self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape) + np.prod(action_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.model += [nn.Linear(128, 1)]
-        self.model = nn.Sequential(*self.model)
+        self.preprocess = preprocess_net
+        self.last = nn.Linear(128, 1)

     def forward(self, s, a=None, **kwargs):
-        s = to_torch(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
+        s = to_torch(s, device=self.device, dtype=torch.float32)
+        s = s.flatten(1)
         if a is not None:
-            if not isinstance(a, torch.Tensor):
-                a = torch.tensor(a, device=self.device, dtype=torch.float)
-            a = a.view(batch, -1)
+            a = to_torch(a, device=self.device, dtype=torch.float32)
+            a = a.flatten(1)
             s = torch.cat([s, a], dim=1)
-        logits = self.model(s)
+        logits, h = self.preprocess(s)
+        logits = self.last(logits)
         return logits


+class ActorProb(nn.Module):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
+    def __init__(self, preprocess_net, action_shape,
+                 max_action, device='cpu', unbounded=False):
+        super().__init__()
+        self.preprocess = preprocess_net
+        self.device = device
+        self.mu = nn.Linear(128, np.prod(action_shape))
+        self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
+        self._max = max_action
+        self._unbounded = unbounded
+
+    def forward(self, s, state=None, **kwargs):
+        logits, h = self.preprocess(s, state)
+        mu = self.mu(logits)
+        if not self._unbounded:
+            mu = self._max * torch.tanh(mu)
+        shape = [1] * len(mu.shape)
+        shape[1] = -1
+        sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
+        return (mu, sigma), None
+
+
 class RecurrentActorProb(nn.Module):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
     def __init__(self, layer_num, state_shape, action_shape,
                  max_action, device='cpu'):
         super().__init__()
@@ -95,16 +87,12 @@ class RecurrentActorProb(nn.Module):
         self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))

     def forward(self, s, **kwargs):
-        s = to_torch(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float32)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
         if len(s.shape) == 2:
-            bsz, dim = s.shape
-            length = 1
-        else:
-            bsz, length, dim = s.shape
-        s = s.view(bsz, length, -1)
+            s = s.unsqueeze(-2)
         logits, _ = self.nn(s)
         logits = logits[:, -1]
         mu = self.mu(logits)
@@ -115,6 +103,10 @@ class RecurrentActorProb(nn.Module):

 class RecurrentCritic(nn.Module):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
     def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
         super().__init__()
         self.state_shape = state_shape
@@ -125,7 +117,7 @@ class RecurrentCritic(nn.Module):
         self.fc2 = nn.Linear(128 + np.prod(action_shape), 1)

     def forward(self, s, a=None):
-        s = to_torch(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float32)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
@@ -135,7 +127,7 @@ class RecurrentCritic(nn.Module):
         s = s[:, -1]
         if a is not None:
             if not isinstance(a, torch.Tensor):
-                a = torch.tensor(a, device=self.device, dtype=torch.float)
+                a = torch.tensor(a, device=self.device, dtype=torch.float32)
             s = torch.cat([s, a], dim=1)
         s = self.fc2(s)
         return s


@@ -4,29 +4,11 @@ from torch import nn
 import torch.nn.functional as F


-class Net(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
-        super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        if action_shape:
-            self.model += [nn.Linear(128, np.prod(action_shape))]
-        self.model = nn.Sequential(*self.model)
-
-    def forward(self, s, state=None, info={}):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        logits = self.model(s)
-        return logits, state
-
-
 class Actor(nn.Module):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
     def __init__(self, preprocess_net, action_shape):
         super().__init__()
         self.preprocess = preprocess_net
@@ -39,18 +21,25 @@ class Actor(nn.Module):

 class Critic(nn.Module):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
     def __init__(self, preprocess_net):
         super().__init__()
         self.preprocess = preprocess_net
         self.last = nn.Linear(128, 1)

-    def forward(self, s):
-        logits, h = self.preprocess(s, None)
+    def forward(self, s, **kwargs):
+        logits, h = self.preprocess(s, state=kwargs.get('state', None))
         logits = self.last(logits)
         return logits


 class DQN(nn.Module):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
     def __init__(self, h, w, action_shape, device='cpu'):
         super(DQN, self).__init__()
@@ -74,7 +63,7 @@ class DQN(nn.Module):

     def forward(self, x, state=None, info={}):
         if not isinstance(x, torch.Tensor):
-            x = torch.tensor(x, device=self.device, dtype=torch.float)
+            x = torch.tensor(x, device=self.device, dtype=torch.float32)
         x = x.permute(0, 3, 1, 2)
         x = F.relu(self.bn1(self.conv1(x)))
         x = F.relu(self.bn2(self.conv2(x)))