Remove dummy net code (#123)
* remove dummy net; delete two files
* split code to have backbone and head
* rename class
* change torch.float to torch.float32
* use flatten(1) instead of view(batch, -1)
* remove dummy net in docs
* bugfix for rnn
* fix cuda error
* minor fix of docs
* do not change the example code in dqn tutorial, since it is for demonstration

Co-authored-by: Trinkle23897 <463003665@qq.com>
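In short, the per-example "dummy net" modules are replaced by a shared MLP backbone (`tianshou.utils.net.common.Net`) plus small heads (`Actor`, `Critic`, ...). A rough sketch of the resulting usage pattern, mirroring the updated examples in the diff below (the shapes and learning rates here are illustrative, not values from this commit):

```python
import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic

# illustrative shapes/hyperparameters, not taken from the diff
state_shape, action_shape, max_action, device = (3,), (1,), 1.0, 'cpu'

# shared MLP backbone for the actor
net_a = Net(layer_num=1, state_shape=state_shape, device=device)
actor = Actor(net_a, action_shape, max_action, device).to(device)

# the critic backbone consumes the concatenated (state, action) input
net_c = Net(layer_num=1, state_shape=state_shape, action_shape=action_shape,
            concat=True, device=device)
critic = Critic(net_c, device).to(device)

actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)
```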
commit e767de044b (parent aa3c453f42)
README.md
@@ -206,26 +206,12 @@ test_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(test_num)])
 Define the network:
 
 ```python
-class Net(nn.Module):
-    def __init__(self, state_shape, action_shape):
-        super().__init__()
-        self.model = nn.Sequential(*[
-            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
-            nn.Linear(128, 128), nn.ReLU(inplace=True),
-            nn.Linear(128, 128), nn.ReLU(inplace=True),
-            nn.Linear(128, np.prod(action_shape))
-        ])
-
-    def forward(self, s, state=None, info={}):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, dtype=torch.float)
-        batch = s.shape[0]
-        logits = self.model(s.view(batch, -1))
-        return logits, state
+from tianshou.utils.net.common import Net
 
 env = gym.make(task)
 state_shape = env.observation_space.shape or env.observation_space.n
 action_shape = env.action_space.shape or env.action_space.n
-net = Net(state_shape, action_shape)
+net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape)
 optim = torch.optim.Adam(net.parameters(), lr=lr)
 ```
@@ -5,3 +5,18 @@ tianshou.utils
    :members:
    :undoc-members:
    :show-inheritance:
+
+.. automodule:: tianshou.utils.net.common
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+.. automodule:: tianshou.utils.net.discrete
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+.. automodule:: tianshou.utils.net.continuous
+   :members:
+   :undoc-members:
+   :show-inheritance:
@@ -74,7 +74,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers but with the
     net = Net(state_shape, action_shape)
     optim = torch.optim.Adam(net.parameters(), lr=1e-3)
 
-The rules of self-defined networks are:
+You can also have a try with those pre-defined networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are:
 
 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
 2. Output: some ``logits``, the next hidden state ``state``, and intermediate result during the policy forwarding procedure ``policy``. The logits could be a tuple instead of a ``torch.Tensor``. It depends on how the policy process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. The ``policy`` can be a Batch of torch.Tensor or other things, which will be stored in the replay buffer, and can be accessed in the policy update process (e.g. in ``policy.learn()``, the ``batch.policy`` is what you need).
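As a minimal illustration of the two rules above (not part of this diff; the class name and layer sizes are made up), a self-defined network could look like:

```python
import numpy as np
import torch
from torch import nn


class MyNet(nn.Module):
    """Follows the convention above: (obs, state, info) in, (logits, state) out."""

    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(np.prod(state_shape), 64), nn.ReLU(inplace=True),
            nn.Linear(64, np.prod(action_shape)),
        )

    def forward(self, obs, state=None, info={}):
        if not isinstance(obs, torch.Tensor):
            obs = torch.as_tensor(obs, dtype=torch.float32)
        logits = self.model(obs.flatten(1))
        return logits, state  # no hidden state for a feed-forward net
```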
@@ -10,8 +10,8 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 from tianshou.exploration import GaussianNoise
-from continuous_net import Actor, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, Critic
 
 
 def get_args():
@@ -57,14 +57,13 @@ def test_ddpg(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    actor = Actor(
-        args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape, device=args.device)
+    actor = Actor(net, args.action_shape, args.max_action,
+                  args.device).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic = Critic(net, args.device).to(args.device)
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
@@ -10,8 +10,8 @@ from tianshou.policy import SACPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
-from continuous_net import ActorProb, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import ActorProb, Critic
 
 
 def get_args():
@@ -58,18 +58,17 @@ def test_sac(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
@@ -10,8 +10,8 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 from tianshou.exploration import GaussianNoise
-from continuous_net import Actor, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, Critic
 
 
 def get_args():
@@ -60,18 +60,17 @@ def test_td3(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = Actor(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
@@ -1,81 +0,0 @@
-import torch
-import numpy as np
-from torch import nn
-
-
-class Actor(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape,
-                 max_action, device='cpu'):
-        super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.model += [nn.Linear(128, np.prod(action_shape))]
-        self.model = nn.Sequential(*self.model)
-        self._max = max_action
-
-    def forward(self, s, **kwargs):
-        s = torch.tensor(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        logits = self.model(s)
-        logits = self._max * torch.tanh(logits)
-        return logits, None
-
-
-class ActorProb(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape,
-                 max_action, device='cpu', unbounded=False):
-        super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.model = nn.Sequential(*self.model)
-        self.mu = nn.Linear(128, np.prod(action_shape))
-        self.sigma = nn.Linear(128, np.prod(action_shape))
-        self._max = max_action
-        self._unbounded = unbounded
-
-    def forward(self, s, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        logits = self.model(s)
-        if not self._unbounded:
-            mu = self._max * torch.tanh(self.mu(logits))
-        sigma = torch.exp(self.sigma(logits))
-        return (mu, sigma), None
-
-
-class Critic(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
-        super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape) + np.prod(action_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.model += [nn.Linear(128, 1)]
-        self.model = nn.Sequential(*self.model)
-
-    def forward(self, s, a=None):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
-        if a is not None and not isinstance(a, torch.Tensor):
-            a = torch.tensor(a, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        if a is None:
-            logits = self.model(s)
-        else:
-            a = a.view(batch, -1)
-            logits = self.model(torch.cat([s, a], dim=1))
-        return logits
@@ -15,8 +15,8 @@ try:
     import pybullet_envs
 except ImportError:
     pass
-from continuous_net import ActorProb, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import ActorProb, Critic
 
 
 def get_args():
@@ -66,18 +66,17 @@ def test_sac(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
        args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
@@ -10,7 +10,8 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 from tianshou.exploration import GaussianNoise
-from continuous_net import Actor, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, Critic
 from mujoco.register import reg
 
 
@@ -63,18 +64,17 @@ def test_td3(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = Actor(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
@@ -10,7 +10,8 @@ from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env.atari import create_atari_environment
 
-from discrete_net import Net, Actor, Critic
+from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.utils.net.common import Net
 
 
 def get_args():
@@ -6,12 +6,11 @@ from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import DQNPolicy
 from tianshou.env import SubprocVectorEnv
+from tianshou.utils.net.discrete import DQN
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env.atari import create_atari_environment
 
-from discrete_net import DQN
-
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -9,8 +9,8 @@ from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env.atari import create_atari_environment
-from discrete_net import Net, Actor, Critic
+from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.utils.net.common import Net
 
 
 def get_args():
@@ -11,8 +11,8 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv
 from tianshou.exploration import OUNoise
-from continuous_net import ActorProb, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import ActorProb, Critic
 
 
 def get_args():
@@ -62,18 +62,17 @@ def test_sac(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
 
     if args.auto_alpha:
@@ -11,11 +11,8 @@ from tianshou.policy import DDPGPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.exploration import GaussianNoise
-if __name__ == '__main__':
-    from net import Actor, Critic
-else:  # pytest
-    from test.continuous.net import Actor, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, Critic
 
 
 def get_args():
|
|||||||
train_envs.seed(args.seed)
|
train_envs.seed(args.seed)
|
||||||
test_envs.seed(args.seed)
|
test_envs.seed(args.seed)
|
||||||
# model
|
# model
|
||||||
|
net = Net(args.layer_num, args.state_shape, device=args.device)
|
||||||
actor = Actor(
|
actor = Actor(
|
||||||
args.layer_num, args.state_shape, args.action_shape,
|
net, args.action_shape,
|
||||||
args.max_action, args.device
|
args.max_action, args.device
|
||||||
).to(args.device)
|
).to(args.device)
|
||||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||||
critic = Critic(
|
net = Net(args.layer_num, args.state_shape,
|
||||||
args.layer_num, args.state_shape, args.action_shape, args.device
|
args.action_shape, concat=True, device=args.device)
|
||||||
).to(args.device)
|
critic = Critic(net, args.device).to(args.device)
|
||||||
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
|
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
|
||||||
policy = DDPGPolicy(
|
policy = DDPGPolicy(
|
||||||
actor, actor_optim, critic, critic_optim,
|
actor, actor_optim, critic, critic_optim,
|
||||||
|
@ -11,11 +11,8 @@ from tianshou.policy import PPOPolicy
|
|||||||
from tianshou.policy.dist import DiagGaussian
|
from tianshou.policy.dist import DiagGaussian
|
||||||
from tianshou.trainer import onpolicy_trainer
|
from tianshou.trainer import onpolicy_trainer
|
||||||
from tianshou.data import Collector, ReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer
|
||||||
|
from tianshou.utils.net.common import Net
|
||||||
if __name__ == '__main__':
|
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||||
from net import ActorProb, Critic
|
|
||||||
else: # pytest
|
|
||||||
from test.continuous.net import ActorProb, Critic
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
@@ -72,13 +69,14 @@ def test_ppo(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
-    critic = Critic(
+    critic = Critic(Net(
         args.layer_num, args.state_shape, device=args.device
-    ).to(args.device)
+    ), device=args.device).to(args.device)
     # orthogonal initialization
     for m in list(actor.modules()) + list(critic.modules()):
         if isinstance(m, torch.nn.Linear):
@@ -10,11 +10,8 @@ from tianshou.env import VectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.policy import SACPolicy, ImitationPolicy
-if __name__ == '__main__':
-    from net import Actor, ActorProb, Critic
-else:  # pytest
-    from test.continuous.net import Actor, ActorProb, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, ActorProb, Critic
 
 
 def get_args():
@@ -68,18 +65,17 @@ def test_sac_with_il(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
@@ -122,8 +118,9 @@ def test_sac_with_il(args=get_args()):
     # here we define an imitation collector with a trivial policy
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -300  # lower the goal
-    net = Actor(1, args.state_shape, args.action_shape,
-                args.max_action, args.device).to(args.device)
+    net = Actor(Net(1, args.state_shape, device=args.device),
+                args.action_shape, args.max_action, args.device
+                ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
     il_policy = ImitationPolicy(net, optim, mode='continuous')
     il_test_collector = Collector(il_policy, test_envs)
@@ -11,11 +11,8 @@ from tianshou.policy import TD3Policy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.exploration import GaussianNoise
-if __name__ == '__main__':
-    from net import Actor, Critic
-else:  # pytest
-    from test.continuous.net import Actor, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, Critic
 
 
 def get_args():
@@ -71,18 +68,17 @@ def test_td3(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
+    net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = Actor(
-        args.layer_num, args.state_shape, args.action_shape,
+        net, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic1 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic1 = Critic(net, args.device).to(args.device)
     critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
-    critic2 = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    critic2 = Critic(net, args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
     policy = TD3Policy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
@@ -10,11 +10,8 @@ from tianshou.env import VectorEnv
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.policy import A2CPolicy, ImitationPolicy
 from tianshou.trainer import onpolicy_trainer, offpolicy_trainer
-if __name__ == '__main__':
-    from net import Net, Actor, Critic
-else:  # pytest
-    from test.discrete.net import Net, Actor, Critic
+from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.utils.net.common import Net
 
 
 def get_args():
@@ -10,11 +10,7 @@ from tianshou.env import VectorEnv
 from tianshou.policy import DQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-if __name__ == '__main__':
-    from net import Net
-else:  # pytest
-    from test.discrete.net import Net
+from tianshou.utils.net.common import Net
 
 
 def get_args():
@@ -61,8 +57,8 @@ def test_dqn(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
-    net = net.to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(
         net, optim, args.gamma, args.n_step,
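For discrete-action algorithms the backbone is used directly as the Q-network, since `Net` appends a `Linear(128, np.prod(action_shape))` output layer whenever `action_shape` is given without `concat`. A hedged sketch of the resulting construction (the shapes and hyperparameters are illustrative, mirroring the call pattern in the test above):

```python
import torch
from tianshou.utils.net.common import Net
from tianshou.policy import DQNPolicy

state_shape, action_shape = (4,), 2   # illustrative, e.g. a CartPole-like task
net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
# positional args follow the call above: discount factor, then n-step
policy = DQNPolicy(net, optim, 0.99, 3)
```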
@@ -10,11 +10,7 @@ from tianshou.env import VectorEnv
 from tianshou.policy import DQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-if __name__ == '__main__':
-    from net import Recurrent
-else:  # pytest
-    from test.discrete.net import Recurrent
+from tianshou.utils.net.common import Recurrent
 
 
 def get_args():
@@ -63,8 +59,7 @@ def test_drqn(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Recurrent(args.layer_num, args.state_shape,
-                    args.action_shape, args.device)
-    net = net.to(args.device)
+                    args.action_shape, args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(
         net, optim, args.gamma, args.n_step,
@@ -6,16 +6,12 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
+from tianshou.utils.net.common import Net
 from tianshou.env import VectorEnv
 from tianshou.policy import DQNPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer
 
-if __name__ == '__main__':
-    from net import Net
-else:  # pytest
-    from test.discrete.net import Net
-
 
 def get_args():
     parser = argparse.ArgumentParser()
@@ -64,8 +60,8 @@ def test_pdqn(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
-    net = net.to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, args.device).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(
         net, optim, args.gamma, args.n_step,
@@ -7,16 +7,12 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
+from tianshou.utils.net.common import Net
 from tianshou.env import VectorEnv
 from tianshou.policy import PGPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Batch, Collector, ReplayBuffer
 
-if __name__ == '__main__':
-    from net import Net
-else:  # pytest
-    from test.discrete.net import Net
-
 
 def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
     returns = np.zeros_like(batch.rew)
@@ -129,8 +125,7 @@ def test_pg(args=get_args()):
     # model
     net = Net(
         args.layer_num, args.state_shape, args.action_shape,
-        device=args.device, softmax=True)
-    net = net.to(args.device)
+        device=args.device, softmax=True).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = PGPolicy(net, optim, dist, args.gamma,
@@ -10,11 +10,8 @@ from tianshou.env import VectorEnv
 from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
-if __name__ == '__main__':
-    from net import Net, Actor, Critic
-else:  # pytest
-    from test.discrete.net import Net, Actor, Critic
+from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.utils.net.common import Net
 
 
 def get_args():
@@ -46,7 +46,7 @@ class ImitationPolicy(BasePolicy):
         self.optim.zero_grad()
         if self.mode == 'continuous':
             a = self(batch).act
-            a_ = to_torch(batch.act, dtype=torch.float, device=a.device)
+            a_ = to_torch(batch.act, dtype=torch.float32, device=a.device)
             loss = F.mse_loss(a, a_)
         elif self.mode == 'discrete':  # classification
             a = self(batch).logits
tianshou/utils/net/__init__.py (new empty file)
@@ -1,82 +1,67 @@
-import torch
 import numpy as np
+import torch
 from torch import nn
-import torch.nn.functional as F
 
 from tianshou.data import to_torch
 
 
 class Net(nn.Module):
+    """Simple MLP backbone. For advanced usage (how to customize the network),
+    please refer to :ref:`build_the_network`.
+
+    :param concat: whether the input shape is concatenated by state_shape
+        and action_shape. If it is True, ``action_shape`` is not the output
+        shape, but affects the input shape.
+    """
+
     def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
-                 softmax=False):
+                 softmax=False, concat=False):
         super().__init__()
         self.device = device
+        input_size = np.prod(state_shape)
+        if concat:
+            input_size += np.prod(action_shape)
         self.model = [
-            nn.Linear(np.prod(state_shape), 128),
+            nn.Linear(input_size, 128),
             nn.ReLU(inplace=True)]
         for i in range(layer_num):
            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        if action_shape:
+        if action_shape and not concat:
            self.model += [nn.Linear(128, np.prod(action_shape))]
         if softmax:
            self.model += [nn.Softmax(dim=-1)]
         self.model = nn.Sequential(*self.model)
 
     def forward(self, s, state=None, info={}):
-        s = to_torch(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
+        s = to_torch(s, device=self.device, dtype=torch.float32)
+        s = s.flatten(1)
         logits = self.model(s)
         return logits, state
 
 
-class Actor(nn.Module):
-    def __init__(self, preprocess_net, action_shape):
-        super().__init__()
-        self.preprocess = preprocess_net
-        self.last = nn.Linear(128, np.prod(action_shape))
-
-    def forward(self, s, state=None, info={}):
-        logits, h = self.preprocess(s, state)
-        logits = F.softmax(self.last(logits), dim=-1)
-        return logits, h
-
-
-class Critic(nn.Module):
-    def __init__(self, preprocess_net):
-        super().__init__()
-        self.preprocess = preprocess_net
-        self.last = nn.Linear(128, 1)
-
-    def forward(self, s, **kwargs):
-        logits, h = self.preprocess(s, state=kwargs.get('state', None))
-        logits = self.last(logits)
-        return logits
-
-
 class Recurrent(nn.Module):
+    """Simple Recurrent network based on LSTM. For advanced usage (how to
+    customize the network), please refer to :ref:`build_the_network`.
+    """
+
     def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
         super().__init__()
         self.state_shape = state_shape
         self.action_shape = action_shape
         self.device = device
-        self.fc1 = nn.Linear(np.prod(state_shape), 128)
         self.nn = nn.LSTM(input_size=128, hidden_size=128,
                           num_layers=layer_num, batch_first=True)
+        self.fc1 = nn.Linear(np.prod(state_shape), 128)
         self.fc2 = nn.Linear(128, np.prod(action_shape))
 
     def forward(self, s, state=None, info={}):
-        s = to_torch(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float32)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
         if len(s.shape) == 2:
-            bsz, dim = s.shape
-            length = 1
-        else:
-            bsz, length, dim = s.shape
-        s = self.fc1(s.view([bsz * length, dim]))
-        s = s.view(bsz, length, -1)
+            s = s.unsqueeze(-2)
+        s = self.fc1(s)
         self.nn.flatten_parameters()
         if state is None:
             s, (h, c) = self.nn(s)
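A short usage sketch of the `concat` switch added above (shapes are illustrative, not part of the diff): with `concat=True` the backbone expects the flattened state and action concatenated along the feature dimension and keeps its 128-dim output for a downstream head, since the final action layer is skipped.

```python
import torch
from tianshou.utils.net.common import Net

state_shape, action_shape = (4,), (2,)        # illustrative shapes
backbone = Net(layer_num=1, state_shape=state_shape,
               action_shape=action_shape, concat=True)
s = torch.randn(8, 4)                         # batch of states
a = torch.randn(8, 2)                         # batch of actions
feat, _ = backbone(torch.cat([s, a], dim=1))  # feat has shape [8, 128]
```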
@@ -6,85 +6,77 @@ from tianshou.data import to_torch
 
 
 class Actor(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape,
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
+    def __init__(self, preprocess_net, action_shape,
                  max_action, device='cpu'):
         super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.model += [nn.Linear(128, np.prod(action_shape))]
-        self.model = nn.Sequential(*self.model)
+        self.preprocess = preprocess_net
+        self.last = nn.Linear(128, np.prod(action_shape))
         self._max = max_action
 
-    def forward(self, s, **kwargs):
-        s = to_torch(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        logits = self.model(s)
-        logits = self._max * torch.tanh(logits)
-        return logits, None
-
-
-class ActorProb(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape,
-                 max_action, device='cpu'):
-        super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.model = nn.Sequential(*self.model)
-        self.mu = nn.Linear(128, np.prod(action_shape))
-        self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
-        # self.sigma = nn.Linear(128, np.prod(action_shape))
-        self._max = max_action
-
-    def forward(self, s, **kwargs):
-        s = to_torch(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        logits = self.model(s)
-        mu = self.mu(logits)
-        shape = [1] * len(mu.shape)
-        shape[1] = -1
-        sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
-        # assert sigma.shape == mu.shape
-        # mu = self._max * torch.tanh(self.mu(logits))
-        # sigma = torch.exp(self.sigma(logits))
-        return (mu, sigma), None
+    def forward(self, s, state=None, info={}):
+        logits, h = self.preprocess(s, state)
+        logits = self._max * torch.tanh(self.last(logits))
+        return logits, h
 
 
 class Critic(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
+    def __init__(self, preprocess_net, device='cpu'):
         super().__init__()
         self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape) + np.prod(action_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.model += [nn.Linear(128, 1)]
-        self.model = nn.Sequential(*self.model)
+        self.preprocess = preprocess_net
+        self.last = nn.Linear(128, 1)
 
     def forward(self, s, a=None, **kwargs):
-        s = to_torch(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
+        s = to_torch(s, device=self.device, dtype=torch.float32)
+        s = s.flatten(1)
         if a is not None:
-            if not isinstance(a, torch.Tensor):
-                a = torch.tensor(a, device=self.device, dtype=torch.float)
-            a = a.view(batch, -1)
+            a = to_torch(a, device=self.device, dtype=torch.float32)
+            a = a.flatten(1)
             s = torch.cat([s, a], dim=1)
-        logits = self.model(s)
+        logits, h = self.preprocess(s)
+        logits = self.last(logits)
         return logits
 
 
+class ActorProb(nn.Module):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
+    def __init__(self, preprocess_net, action_shape,
+                 max_action, device='cpu', unbounded=False):
+        super().__init__()
+        self.preprocess = preprocess_net
+        self.device = device
+        self.mu = nn.Linear(128, np.prod(action_shape))
+        self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
+        self._max = max_action
+        self._unbounded = unbounded
+
+    def forward(self, s, state=None, **kwargs):
+        logits, h = self.preprocess(s, state)
+        mu = self.mu(logits)
+        if not self._unbounded:
+            mu = self._max * torch.tanh(mu)
+        shape = [1] * len(mu.shape)
+        shape[1] = -1
+        sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp()
+        return (mu, sigma), None
+
+
 class RecurrentActorProb(nn.Module):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
     def __init__(self, layer_num, state_shape, action_shape,
                  max_action, device='cpu'):
         super().__init__()
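For context, the new head classes above consume the 128-dim feature produced by the backbone's forward. A hedged end-to-end sketch of actor construction with the new API (shapes and values are illustrative, not taken from the diff):

```python
import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb

state_shape, action_shape, max_action = (3,), (1,), 2.0  # illustrative
backbone = Net(layer_num=1, state_shape=state_shape)     # 128-dim features
actor = ActorProb(backbone, action_shape, max_action)
obs = torch.randn(16, 3)
(mu, sigma), _ = actor(obs)  # mu is tanh-bounded by max_action unless unbounded=True
```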
@@ -95,16 +87,12 @@ class RecurrentActorProb(nn.Module):
         self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1))
 
     def forward(self, s, **kwargs):
-        s = to_torch(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float32)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
         if len(s.shape) == 2:
-            bsz, dim = s.shape
-            length = 1
-        else:
-            bsz, length, dim = s.shape
-        s = s.view(bsz, length, -1)
+            s = s.unsqueeze(-2)
         logits, _ = self.nn(s)
         logits = logits[:, -1]
         mu = self.mu(logits)
@@ -115,6 +103,10 @@ class RecurrentActorProb(nn.Module):
 
 
 class RecurrentCritic(nn.Module):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
     def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
         super().__init__()
         self.state_shape = state_shape
@@ -125,7 +117,7 @@ class RecurrentCritic(nn.Module):
         self.fc2 = nn.Linear(128 + np.prod(action_shape), 1)
 
     def forward(self, s, a=None):
-        s = to_torch(s, device=self.device, dtype=torch.float)
+        s = to_torch(s, device=self.device, dtype=torch.float32)
         # s [bsz, len, dim] (training) or [bsz, dim] (evaluation)
         # In short, the tensor's shape in training phase is longer than which
         # in evaluation phase.
@@ -135,7 +127,7 @@ class RecurrentCritic(nn.Module):
         s = s[:, -1]
         if a is not None:
             if not isinstance(a, torch.Tensor):
-                a = torch.tensor(a, device=self.device, dtype=torch.float)
+                a = torch.tensor(a, device=self.device, dtype=torch.float32)
             s = torch.cat([s, a], dim=1)
         s = self.fc2(s)
         return s
@@ -4,29 +4,11 @@ from torch import nn
 import torch.nn.functional as F
 
 
-class Net(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
-        super().__init__()
-        self.device = device
-        self.model = [
-            nn.Linear(np.prod(state_shape), 128),
-            nn.ReLU(inplace=True)]
-        for i in range(layer_num):
-            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        if action_shape:
-            self.model += [nn.Linear(128, np.prod(action_shape))]
-        self.model = nn.Sequential(*self.model)
-
-    def forward(self, s, state=None, info={}):
-        if not isinstance(s, torch.Tensor):
-            s = torch.tensor(s, device=self.device, dtype=torch.float)
-        batch = s.shape[0]
-        s = s.view(batch, -1)
-        logits = self.model(s)
-        return logits, state
-
-
 class Actor(nn.Module):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
     def __init__(self, preprocess_net, action_shape):
         super().__init__()
         self.preprocess = preprocess_net
@@ -39,18 +21,25 @@ class Actor(nn.Module):
 
 
 class Critic(nn.Module):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
     def __init__(self, preprocess_net):
         super().__init__()
         self.preprocess = preprocess_net
         self.last = nn.Linear(128, 1)
 
-    def forward(self, s):
-        logits, h = self.preprocess(s, None)
+    def forward(self, s, **kwargs):
+        logits, h = self.preprocess(s, state=kwargs.get('state', None))
         logits = self.last(logits)
         return logits
 
 
 class DQN(nn.Module):
+    """For advanced usage (how to customize the network), please refer to
+    :ref:`build_the_network`.
+    """
+
     def __init__(self, h, w, action_shape, device='cpu'):
         super(DQN, self).__init__()
@@ -74,7 +63,7 @@ class DQN(nn.Module):
 
     def forward(self, x, state=None, info={}):
         if not isinstance(x, torch.Tensor):
-            x = torch.tensor(x, device=self.device, dtype=torch.float)
+            x = torch.tensor(x, device=self.device, dtype=torch.float32)
         x = x.permute(0, 3, 1, 2)
         x = F.relu(self.bn1(self.conv1(x)))
         x = F.relu(self.bn2(self.conv2(x)))