From e767de044b969bc34ccfafcb40be7988a6e7c95b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 9 Jul 2020 22:57:01 +0800 Subject: [PATCH] Remove dummy net code (#123) * remove dummy net; delete two files * split code to have backbone and head * rename class * change torch.float to torch.float32 * use flatten(1) instead of view(batch, -1) * remove dummy net in docs * bugfix for rnn * fix cuda error * minor fix of docs * do not change the example code in dqn tutorial, since it is for demonstration Co-authored-by: Trinkle23897 <463003665@qq.com> --- README.md | 18 +-- docs/api/tianshou.utils.rst | 15 ++ docs/tutorials/dqn.rst | 2 +- examples/ant_v2_ddpg.py | 17 ++- examples/ant_v2_sac.py | 17 ++- examples/ant_v2_td3.py | 17 ++- examples/continuous_net.py | 81 ----------- examples/halfcheetahBullet_v0_sac.py | 17 ++- examples/point_maze_td3.py | 16 +-- examples/pong_a2c.py | 3 +- examples/pong_dqn.py | 3 +- examples/pong_ppo.py | 4 +- examples/sac_mcc.py | 17 ++- test/continuous/test_ddpg.py | 16 +-- test/continuous/test_ppo.py | 14 +- test/continuous/test_sac_with_il.py | 25 ++-- test/continuous/test_td3.py | 20 ++- test/discrete/test_a2c_with_il.py | 7 +- test/discrete/test_dqn.py | 10 +- test/discrete/test_drqn.py | 9 +- test/discrete/test_pdqn.py | 10 +- test/discrete/test_pg.py | 9 +- test/discrete/test_ppo.py | 7 +- tianshou/policy/imitation/base.py | 2 +- tianshou/utils/net/__init__.py | 0 .../net.py => tianshou/utils/net/common.py | 65 ++++----- .../utils/net/continuous.py | 132 ++++++++---------- .../utils/net/discrete.py | 39 ++---- 28 files changed, 219 insertions(+), 373 deletions(-) delete mode 100644 examples/continuous_net.py create mode 100644 tianshou/utils/net/__init__.py rename test/discrete/net.py => tianshou/utils/net/common.py (62%) rename test/continuous/net.py => tianshou/utils/net/continuous.py (53%) rename examples/discrete_net.py => tianshou/utils/net/discrete.py (69%) diff --git a/README.md b/README.md index 150489a..fbc8bbb 100644 --- a/README.md +++ b/README.md @@ -206,26 +206,12 @@ test_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(test_num)]) Define the network: ```python -class Net(nn.Module): - def __init__(self, state_shape, action_shape): - super().__init__() - self.model = nn.Sequential(*[ - nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True), - nn.Linear(128, 128), nn.ReLU(inplace=True), - nn.Linear(128, 128), nn.ReLU(inplace=True), - nn.Linear(128, np.prod(action_shape)) - ]) - def forward(self, s, state=None, info={}): - if not isinstance(s, torch.Tensor): - s = torch.tensor(s, dtype=torch.float) - batch = s.shape[0] - logits = self.model(s.view(batch, -1)) - return logits, state +from tianshou.utils.net.common import Net env = gym.make(task) state_shape = env.observation_space.shape or env.observation_space.n action_shape = env.action_space.shape or env.action_space.n -net = Net(state_shape, action_shape) +net = Net(layer_num=2, state_shape=state_shape, action_shape=action_shape) optim = torch.optim.Adam(net.parameters(), lr=lr) ``` diff --git a/docs/api/tianshou.utils.rst b/docs/api/tianshou.utils.rst index 82c1952..3a293b1 100644 --- a/docs/api/tianshou.utils.rst +++ b/docs/api/tianshou.utils.rst @@ -5,3 +5,18 @@ tianshou.utils :members: :undoc-members: :show-inheritance: + +.. automodule:: tianshou.utils.net.common + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: tianshou.utils.net.discrete + :members: + :undoc-members: + :show-inheritance: + +.. 
automodule:: tianshou.utils.net.continuous
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst
index 0d43031..d981a1c 100644
--- a/docs/tutorials/dqn.rst
+++ b/docs/tutorials/dqn.rst
@@ -74,7 +74,7 @@ Tianshou supports any user-defined PyTorch networks and optimizers but with the
     net = Net(state_shape, action_shape)
     optim = torch.optim.Adam(net.parameters(), lr=1e-3)
 
-The rules of self-defined networks are:
+You can also try out the pre-defined networks in :mod:`~tianshou.utils.net.common`, :mod:`~tianshou.utils.net.discrete`, and :mod:`~tianshou.utils.net.continuous`. The rules of self-defined networks are:
 
 1. Input: observation ``obs`` (may be a ``numpy.ndarray``, ``torch.Tensor``, dict, or self-defined class), hidden state ``state`` (for RNN usage), and other information ``info`` provided by the environment.
 2. Output: some ``logits``, the next hidden state ``state``, and intermediate result during the policy forwarding procedure ``policy``. The logits could be a tuple instead of a ``torch.Tensor``. It depends on how the policy process the network output. For example, in PPO :cite:`PPO`, the return of the network might be ``(mu, sigma), state`` for Gaussian policy. The ``policy`` can be a Batch of torch.Tensor or other things, which will be stored in the replay buffer, and can be accessed in the policy update process (e.g. in ``policy.learn()``, the ``batch.policy`` is what you need).
diff --git a/examples/ant_v2_ddpg.py b/examples/ant_v2_ddpg.py
index 16b0299..6b9ba0a 100644
--- a/examples/ant_v2_ddpg.py
+++ b/examples/ant_v2_ddpg.py
@@ -10,8 +10,8 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 from tianshou.exploration import GaussianNoise
-
-from continuous_net import Actor, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import Actor, Critic
 
 
 def get_args():
@@ -57,14 +57,13 @@ def test_ddpg(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    actor = Actor(
-        args.layer_num, args.state_shape, args.action_shape,
-        args.max_action, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape, device=args.device)
+    actor = Actor(net, args.action_shape, args.max_action,
+                  args.device).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
-    critic = Critic(
-        args.layer_num, args.state_shape, args.action_shape, args.device
-    ).to(args.device)
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, concat=True, device=args.device)
+    critic = Critic(net, args.device).to(args.device)
     critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
diff --git a/examples/ant_v2_sac.py b/examples/ant_v2_sac.py
index 1d28615..cdfc813 100644
--- a/examples/ant_v2_sac.py
+++ b/examples/ant_v2_sac.py
@@ -10,8 +10,8 @@ from tianshou.policy import SACPolicy
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
-
-from continuous_net import ActorProb, Critic
+from tianshou.utils.net.common import Net
+from tianshou.utils.net.continuous import ActorProb, Critic
 
 
 def get_args():
@@ -58,18 +58,17 @@ def test_sac(args=get_args()):
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
+    net = Net(args.layer_num, args.state_shape, 
device=args.device) actor = ActorProb( - args.layer_num, args.state_shape, args.action_shape, + net, args.action_shape, args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, diff --git a/examples/ant_v2_td3.py b/examples/ant_v2_td3.py index 45770b2..495891b 100644 --- a/examples/ant_v2_td3.py +++ b/examples/ant_v2_td3.py @@ -10,8 +10,8 @@ from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.exploration import GaussianNoise - -from continuous_net import Actor, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import Actor, Critic def get_args(): @@ -60,18 +60,17 @@ def test_td3(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model + net = Net(args.layer_num, args.state_shape, device=args.device) actor = Actor( - args.layer_num, args.state_shape, args.action_shape, + net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, diff --git a/examples/continuous_net.py b/examples/continuous_net.py deleted file mode 100644 index c76ab17..0000000 --- a/examples/continuous_net.py +++ /dev/null @@ -1,81 +0,0 @@ -import torch -import numpy as np -from torch import nn - - -class Actor(nn.Module): - def __init__(self, layer_num, state_shape, action_shape, - max_action, device='cpu'): - super().__init__() - self.device = device - self.model = [ - nn.Linear(np.prod(state_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - self.model += [nn.Linear(128, np.prod(action_shape))] - self.model = nn.Sequential(*self.model) - self._max = max_action - - def forward(self, s, **kwargs): - s = torch.tensor(s, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) - logits = self.model(s) - logits = self._max * torch.tanh(logits) - return logits, None - - -class ActorProb(nn.Module): - def __init__(self, layer_num, state_shape, action_shape, - max_action, device='cpu', unbounded=False): - 
super().__init__() - self.device = device - self.model = [ - nn.Linear(np.prod(state_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - self.model = nn.Sequential(*self.model) - self.mu = nn.Linear(128, np.prod(action_shape)) - self.sigma = nn.Linear(128, np.prod(action_shape)) - self._max = max_action - self._unbounded = unbounded - - def forward(self, s, **kwargs): - if not isinstance(s, torch.Tensor): - s = torch.tensor(s, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) - logits = self.model(s) - if not self._unbounded: - mu = self._max * torch.tanh(self.mu(logits)) - sigma = torch.exp(self.sigma(logits)) - return (mu, sigma), None - - -class Critic(nn.Module): - def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'): - super().__init__() - self.device = device - self.model = [ - nn.Linear(np.prod(state_shape) + np.prod(action_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - self.model += [nn.Linear(128, 1)] - self.model = nn.Sequential(*self.model) - - def forward(self, s, a=None): - if not isinstance(s, torch.Tensor): - s = torch.tensor(s, device=self.device, dtype=torch.float) - if a is not None and not isinstance(a, torch.Tensor): - a = torch.tensor(a, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) - if a is None: - logits = self.model(s) - else: - a = a.view(batch, -1) - logits = self.model(torch.cat([s, a], dim=1)) - return logits diff --git a/examples/halfcheetahBullet_v0_sac.py b/examples/halfcheetahBullet_v0_sac.py index 57ca8ba..3da77cc 100644 --- a/examples/halfcheetahBullet_v0_sac.py +++ b/examples/halfcheetahBullet_v0_sac.py @@ -15,8 +15,8 @@ try: import pybullet_envs except ImportError: pass - -from continuous_net import ActorProb, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorProb, Critic def get_args(): @@ -66,18 +66,17 @@ def test_sac(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model + net = Net(args.layer_num, args.state_shape, device=args.device) actor = ActorProb( - args.layer_num, args.state_shape, args.action_shape, + net, args.action_shape, args.max_action, args.device, unbounded=True ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, diff --git a/examples/point_maze_td3.py b/examples/point_maze_td3.py index b3f2c95..ce04599 100644 --- a/examples/point_maze_td3.py +++ b/examples/point_maze_td3.py @@ -10,7 +10,8 @@ from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv, SubprocVectorEnv from tianshou.exploration import GaussianNoise -from continuous_net import Actor, Critic +from 
tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import Actor, Critic from mujoco.register import reg @@ -63,18 +64,17 @@ def test_td3(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model + net = Net(args.layer_num, args.state_shape, device=args.device) actor = Actor( - args.layer_num, args.state_shape, args.action_shape, + net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, diff --git a/examples/pong_a2c.py b/examples/pong_a2c.py index 0490a43..544153e 100644 --- a/examples/pong_a2c.py +++ b/examples/pong_a2c.py @@ -10,7 +10,8 @@ from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env.atari import create_atari_environment -from discrete_net import Net, Actor, Critic +from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.common import Net def get_args(): diff --git a/examples/pong_dqn.py b/examples/pong_dqn.py index 98c49a8..85a7e6e 100644 --- a/examples/pong_dqn.py +++ b/examples/pong_dqn.py @@ -6,12 +6,11 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy from tianshou.env import SubprocVectorEnv +from tianshou.utils.net.discrete import DQN from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env.atari import create_atari_environment -from discrete_net import DQN - def get_args(): parser = argparse.ArgumentParser() diff --git a/examples/pong_ppo.py b/examples/pong_ppo.py index f9976b5..e475344 100644 --- a/examples/pong_ppo.py +++ b/examples/pong_ppo.py @@ -9,8 +9,8 @@ from tianshou.env import SubprocVectorEnv from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env.atari import create_atari_environment - -from discrete_net import Net, Actor, Critic +from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.common import Net def get_args(): diff --git a/examples/sac_mcc.py b/examples/sac_mcc.py index 6455975..fcd2ce4 100644 --- a/examples/sac_mcc.py +++ b/examples/sac_mcc.py @@ -11,8 +11,8 @@ from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv from tianshou.exploration import OUNoise - -from continuous_net import ActorProb, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorProb, Critic def get_args(): @@ -62,18 +62,17 @@ def test_sac(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model + net = Net(args.layer_num, args.state_shape, device=args.device) actor = ActorProb( - args.layer_num, args.state_shape, args.action_shape, + net, args.action_shape, args.max_action, args.device, unbounded=True 
).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) if args.auto_alpha: diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index 9428f11..6f078bc 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -11,11 +11,8 @@ from tianshou.policy import DDPGPolicy from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.exploration import GaussianNoise - -if __name__ == '__main__': - from net import Actor, Critic -else: # pytest - from test.continuous.net import Actor, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import Actor, Critic def get_args(): @@ -69,14 +66,15 @@ def test_ddpg(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model + net = Net(args.layer_num, args.state_shape, device=args.device) actor = Actor( - args.layer_num, args.state_shape, args.action_shape, + net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic = Critic(net, args.device).to(args.device) critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr) policy = DDPGPolicy( actor, actor_optim, critic, critic_optim, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index dd0e765..daa7b06 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -11,11 +11,8 @@ from tianshou.policy import PPOPolicy from tianshou.policy.dist import DiagGaussian from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer - -if __name__ == '__main__': - from net import ActorProb, Critic -else: # pytest - from test.continuous.net import ActorProb, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorProb, Critic def get_args(): @@ -72,13 +69,14 @@ def test_ppo(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model + net = Net(args.layer_num, args.state_shape, device=args.device) actor = ActorProb( - args.layer_num, args.state_shape, args.action_shape, + net, args.action_shape, args.max_action, args.device ).to(args.device) - critic = Critic( + critic = Critic(Net( args.layer_num, args.state_shape, device=args.device - ).to(args.device) + ), device=args.device).to(args.device) # orthogonal initialization for m in list(actor.modules()) + list(critic.modules()): if isinstance(m, torch.nn.Linear): diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index 8b1a3d1..96e0340 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -10,11 +10,8 @@ from tianshou.env import VectorEnv 
from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.policy import SACPolicy, ImitationPolicy - -if __name__ == '__main__': - from net import Actor, ActorProb, Critic -else: # pytest - from test.continuous.net import Actor, ActorProb, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import Actor, ActorProb, Critic def get_args(): @@ -68,18 +65,17 @@ def test_sac_with_il(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model + net = Net(args.layer_num, args.state_shape, device=args.device) actor = ActorProb( - args.layer_num, args.state_shape, args.action_shape, + net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + critic2 = Critic(net, args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = SACPolicy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, @@ -122,8 +118,9 @@ def test_sac_with_il(args=get_args()): # here we define an imitation collector with a trivial policy if args.task == 'Pendulum-v0': env.spec.reward_threshold = -300 # lower the goal - net = Actor(1, args.state_shape, args.action_shape, - args.max_action, args.device).to(args.device) + net = Actor(Net(1, args.state_shape, device=args.device), + args.action_shape, args.max_action, args.device + ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) il_policy = ImitationPolicy(net, optim, mode='continuous') il_test_collector = Collector(il_policy, test_envs) diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 6d3133b..096290b 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -11,11 +11,8 @@ from tianshou.policy import TD3Policy from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.exploration import GaussianNoise - -if __name__ == '__main__': - from net import Actor, Critic -else: # pytest - from test.continuous.net import Actor, Critic +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import Actor, Critic def get_args(): @@ -71,18 +68,17 @@ def test_td3(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model + net = Net(args.layer_num, args.state_shape, device=args.device) actor = Actor( - args.layer_num, args.state_shape, args.action_shape, + net, args.action_shape, args.max_action, args.device ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - critic1 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, concat=True, device=args.device) + critic1 = Critic(net, args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) - critic2 = Critic( - args.layer_num, args.state_shape, args.action_shape, args.device - ).to(args.device) + critic2 = Critic(net, 
args.device).to(args.device) critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) policy = TD3Policy( actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 267732f..365fb12 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -10,11 +10,8 @@ from tianshou.env import VectorEnv from tianshou.data import Collector, ReplayBuffer from tianshou.policy import A2CPolicy, ImitationPolicy from tianshou.trainer import onpolicy_trainer, offpolicy_trainer - -if __name__ == '__main__': - from net import Net, Actor, Critic -else: # pytest - from test.discrete.net import Net, Actor, Critic +from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.common import Net def get_args(): diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index b378194..96ddb70 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -10,11 +10,7 @@ from tianshou.env import VectorEnv from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer - -if __name__ == '__main__': - from net import Net -else: # pytest - from test.discrete.net import Net +from tianshou.utils.net.common import Net def get_args(): @@ -61,8 +57,8 @@ def test_dqn(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.layer_num, args.state_shape, args.action_shape, args.device) - net = net.to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( net, optim, args.gamma, args.n_step, diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 8c15e3b..f0f34ed 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -10,11 +10,7 @@ from tianshou.env import VectorEnv from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer - -if __name__ == '__main__': - from net import Recurrent -else: # pytest - from test.discrete.net import Recurrent +from tianshou.utils.net.common import Recurrent def get_args(): @@ -63,8 +59,7 @@ def test_drqn(args=get_args()): test_envs.seed(args.seed) # model net = Recurrent(args.layer_num, args.state_shape, - args.action_shape, args.device) - net = net.to(args.device) + args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( net, optim, args.gamma, args.n_step, diff --git a/test/discrete/test_pdqn.py b/test/discrete/test_pdqn.py index 70bfdbb..22fa347 100644 --- a/test/discrete/test_pdqn.py +++ b/test/discrete/test_pdqn.py @@ -6,16 +6,12 @@ import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils.net.common import Net from tianshou.env import VectorEnv from tianshou.policy import DQNPolicy from tianshou.trainer import offpolicy_trainer from tianshou.data import Collector, ReplayBuffer, PrioritizedReplayBuffer -if __name__ == '__main__': - from net import Net -else: # pytest - from test.discrete.net import Net - def get_args(): parser = argparse.ArgumentParser() @@ -64,8 +60,8 @@ def test_pdqn(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.layer_num, args.state_shape, args.action_shape, args.device) - net = 
net.to(args.device) + net = Net(args.layer_num, args.state_shape, + args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy( net, optim, args.gamma, args.n_step, diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index d2817a7..c820762 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -7,16 +7,12 @@ import argparse import numpy as np from torch.utils.tensorboard import SummaryWriter +from tianshou.utils.net.common import Net from tianshou.env import VectorEnv from tianshou.policy import PGPolicy from tianshou.trainer import onpolicy_trainer from tianshou.data import Batch, Collector, ReplayBuffer -if __name__ == '__main__': - from net import Net -else: # pytest - from test.discrete.net import Net - def compute_return_base(batch, aa=None, bb=None, gamma=0.1): returns = np.zeros_like(batch.rew) @@ -129,8 +125,7 @@ def test_pg(args=get_args()): # model net = Net( args.layer_num, args.state_shape, args.action_shape, - device=args.device, softmax=True) - net = net.to(args.device) + device=args.device, softmax=True).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) dist = torch.distributions.Categorical policy = PGPolicy(net, optim, dist, args.gamma, diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 44850a8..ca0e879 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -10,11 +10,8 @@ from tianshou.env import VectorEnv from tianshou.policy import PPOPolicy from tianshou.trainer import onpolicy_trainer from tianshou.data import Collector, ReplayBuffer - -if __name__ == '__main__': - from net import Net, Actor, Critic -else: # pytest - from test.discrete.net import Net, Actor, Critic +from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.net.common import Net def get_args(): diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index aeb0eb3..57bdba9 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -46,7 +46,7 @@ class ImitationPolicy(BasePolicy): self.optim.zero_grad() if self.mode == 'continuous': a = self(batch).act - a_ = to_torch(batch.act, dtype=torch.float, device=a.device) + a_ = to_torch(batch.act, dtype=torch.float32, device=a.device) loss = F.mse_loss(a, a_) elif self.mode == 'discrete': # classification a = self(batch).logits diff --git a/tianshou/utils/net/__init__.py b/tianshou/utils/net/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/discrete/net.py b/tianshou/utils/net/common.py similarity index 62% rename from test/discrete/net.py rename to tianshou/utils/net/common.py index 1dcf783..8dd376b 100644 --- a/test/discrete/net.py +++ b/tianshou/utils/net/common.py @@ -1,82 +1,67 @@ -import torch import numpy as np +import torch from torch import nn -import torch.nn.functional as F from tianshou.data import to_torch class Net(nn.Module): + """Simple MLP backbone. For advanced usage (how to customize the network), + please refer to :ref:`build_the_network`. + + :param concat: whether the input shape is concatenated by state_shape + and action_shape. If it is True, ``action_shape`` is not the output + shape, but affects the input shape. 
+ """ + def __init__(self, layer_num, state_shape, action_shape=0, device='cpu', - softmax=False): + softmax=False, concat=False): super().__init__() self.device = device + input_size = np.prod(state_shape) + if concat: + input_size += np.prod(action_shape) self.model = [ - nn.Linear(np.prod(state_shape), 128), + nn.Linear(input_size, 128), nn.ReLU(inplace=True)] for i in range(layer_num): self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - if action_shape: + if action_shape and not concat: self.model += [nn.Linear(128, np.prod(action_shape))] if softmax: self.model += [nn.Softmax(dim=-1)] self.model = nn.Sequential(*self.model) def forward(self, s, state=None, info={}): - s = to_torch(s, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) + s = to_torch(s, device=self.device, dtype=torch.float32) + s = s.flatten(1) logits = self.model(s) return logits, state -class Actor(nn.Module): - def __init__(self, preprocess_net, action_shape): - super().__init__() - self.preprocess = preprocess_net - self.last = nn.Linear(128, np.prod(action_shape)) - - def forward(self, s, state=None, info={}): - logits, h = self.preprocess(s, state) - logits = F.softmax(self.last(logits), dim=-1) - return logits, h - - -class Critic(nn.Module): - def __init__(self, preprocess_net): - super().__init__() - self.preprocess = preprocess_net - self.last = nn.Linear(128, 1) - - def forward(self, s, **kwargs): - logits, h = self.preprocess(s, state=kwargs.get('state', None)) - logits = self.last(logits) - return logits - - class Recurrent(nn.Module): + """Simple Recurrent network based on LSTM. For advanced usage (how to + customize the network), please refer to :ref:`build_the_network`. + """ + def __init__(self, layer_num, state_shape, action_shape, device='cpu'): super().__init__() self.state_shape = state_shape self.action_shape = action_shape self.device = device - self.fc1 = nn.Linear(np.prod(state_shape), 128) self.nn = nn.LSTM(input_size=128, hidden_size=128, num_layers=layer_num, batch_first=True) + self.fc1 = nn.Linear(np.prod(state_shape), 128) self.fc2 = nn.Linear(128, np.prod(action_shape)) def forward(self, s, state=None, info={}): - s = to_torch(s, device=self.device, dtype=torch.float) + s = to_torch(s, device=self.device, dtype=torch.float32) # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. if len(s.shape) == 2: - bsz, dim = s.shape - length = 1 - else: - bsz, length, dim = s.shape - s = self.fc1(s.view([bsz * length, dim])) - s = s.view(bsz, length, -1) + s = s.unsqueeze(-2) + s = self.fc1(s) self.nn.flatten_parameters() if state is None: s, (h, c) = self.nn(s) diff --git a/test/continuous/net.py b/tianshou/utils/net/continuous.py similarity index 53% rename from test/continuous/net.py rename to tianshou/utils/net/continuous.py index 043044a..aac6fc4 100644 --- a/test/continuous/net.py +++ b/tianshou/utils/net/continuous.py @@ -6,85 +6,77 @@ from tianshou.data import to_torch class Actor(nn.Module): - def __init__(self, layer_num, state_shape, action_shape, + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. 
+ """ + + def __init__(self, preprocess_net, action_shape, max_action, device='cpu'): super().__init__() - self.device = device - self.model = [ - nn.Linear(np.prod(state_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - self.model += [nn.Linear(128, np.prod(action_shape))] - self.model = nn.Sequential(*self.model) + self.preprocess = preprocess_net + self.last = nn.Linear(128, np.prod(action_shape)) self._max = max_action - def forward(self, s, **kwargs): - s = to_torch(s, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) - logits = self.model(s) - logits = self._max * torch.tanh(logits) - return logits, None - - -class ActorProb(nn.Module): - def __init__(self, layer_num, state_shape, action_shape, - max_action, device='cpu'): - super().__init__() - self.device = device - self.model = [ - nn.Linear(np.prod(state_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - self.model = nn.Sequential(*self.model) - self.mu = nn.Linear(128, np.prod(action_shape)) - self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1)) - # self.sigma = nn.Linear(128, np.prod(action_shape)) - self._max = max_action - - def forward(self, s, **kwargs): - s = to_torch(s, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) - logits = self.model(s) - mu = self.mu(logits) - shape = [1] * len(mu.shape) - shape[1] = -1 - sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp() - # assert sigma.shape == mu.shape - # mu = self._max * torch.tanh(self.mu(logits)) - # sigma = torch.exp(self.sigma(logits)) - return (mu, sigma), None + def forward(self, s, state=None, info={}): + logits, h = self.preprocess(s, state) + logits = self._max * torch.tanh(self.last(logits)) + return logits, h class Critic(nn.Module): - def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + + def __init__(self, preprocess_net, device='cpu'): super().__init__() self.device = device - self.model = [ - nn.Linear(np.prod(state_shape) + np.prod(action_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - self.model += [nn.Linear(128, 1)] - self.model = nn.Sequential(*self.model) + self.preprocess = preprocess_net + self.last = nn.Linear(128, 1) def forward(self, s, a=None, **kwargs): - s = to_torch(s, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) + s = to_torch(s, device=self.device, dtype=torch.float32) + s = s.flatten(1) if a is not None: - if not isinstance(a, torch.Tensor): - a = torch.tensor(a, device=self.device, dtype=torch.float) - a = a.view(batch, -1) + a = to_torch(a, device=self.device, dtype=torch.float32) + a = a.flatten(1) s = torch.cat([s, a], dim=1) - logits = self.model(s) + logits, h = self.preprocess(s) + logits = self.last(logits) return logits +class ActorProb(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. 
+ """ + + def __init__(self, preprocess_net, action_shape, + max_action, device='cpu', unbounded=False): + super().__init__() + self.preprocess = preprocess_net + self.device = device + self.mu = nn.Linear(128, np.prod(action_shape)) + self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1)) + self._max = max_action + self._unbounded = unbounded + + def forward(self, s, state=None, **kwargs): + logits, h = self.preprocess(s, state) + mu = self.mu(logits) + if not self._unbounded: + mu = self._max * torch.tanh(mu) + shape = [1] * len(mu.shape) + shape[1] = -1 + sigma = (self.sigma.view(shape) + torch.zeros_like(mu)).exp() + return (mu, sigma), None + + class RecurrentActorProb(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + def __init__(self, layer_num, state_shape, action_shape, max_action, device='cpu'): super().__init__() @@ -95,16 +87,12 @@ class RecurrentActorProb(nn.Module): self.sigma = nn.Parameter(torch.zeros(np.prod(action_shape), 1)) def forward(self, s, **kwargs): - s = to_torch(s, device=self.device, dtype=torch.float) + s = to_torch(s, device=self.device, dtype=torch.float32) # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. if len(s.shape) == 2: - bsz, dim = s.shape - length = 1 - else: - bsz, length, dim = s.shape - s = s.view(bsz, length, -1) + s = s.unsqueeze(-2) logits, _ = self.nn(s) logits = logits[:, -1] mu = self.mu(logits) @@ -115,6 +103,10 @@ class RecurrentActorProb(nn.Module): class RecurrentCritic(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'): super().__init__() self.state_shape = state_shape @@ -125,7 +117,7 @@ class RecurrentCritic(nn.Module): self.fc2 = nn.Linear(128 + np.prod(action_shape), 1) def forward(self, s, a=None): - s = to_torch(s, device=self.device, dtype=torch.float) + s = to_torch(s, device=self.device, dtype=torch.float32) # s [bsz, len, dim] (training) or [bsz, dim] (evaluation) # In short, the tensor's shape in training phase is longer than which # in evaluation phase. 
@@ -135,7 +127,7 @@ class RecurrentCritic(nn.Module): s = s[:, -1] if a is not None: if not isinstance(a, torch.Tensor): - a = torch.tensor(a, device=self.device, dtype=torch.float) + a = torch.tensor(a, device=self.device, dtype=torch.float32) s = torch.cat([s, a], dim=1) s = self.fc2(s) return s diff --git a/examples/discrete_net.py b/tianshou/utils/net/discrete.py similarity index 69% rename from examples/discrete_net.py rename to tianshou/utils/net/discrete.py index eda7108..4cad50d 100644 --- a/examples/discrete_net.py +++ b/tianshou/utils/net/discrete.py @@ -4,29 +4,11 @@ from torch import nn import torch.nn.functional as F -class Net(nn.Module): - def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'): - super().__init__() - self.device = device - self.model = [ - nn.Linear(np.prod(state_shape), 128), - nn.ReLU(inplace=True)] - for i in range(layer_num): - self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - if action_shape: - self.model += [nn.Linear(128, np.prod(action_shape))] - self.model = nn.Sequential(*self.model) - - def forward(self, s, state=None, info={}): - if not isinstance(s, torch.Tensor): - s = torch.tensor(s, device=self.device, dtype=torch.float) - batch = s.shape[0] - s = s.view(batch, -1) - logits = self.model(s) - return logits, state - - class Actor(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + def __init__(self, preprocess_net, action_shape): super().__init__() self.preprocess = preprocess_net @@ -39,18 +21,25 @@ class Actor(nn.Module): class Critic(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + def __init__(self, preprocess_net): super().__init__() self.preprocess = preprocess_net self.last = nn.Linear(128, 1) - def forward(self, s): - logits, h = self.preprocess(s, None) + def forward(self, s, **kwargs): + logits, h = self.preprocess(s, state=kwargs.get('state', None)) logits = self.last(logits) return logits class DQN(nn.Module): + """For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ def __init__(self, h, w, action_shape, device='cpu'): super(DQN, self).__init__() @@ -74,7 +63,7 @@ class DQN(nn.Module): def forward(self, x, state=None, info={}): if not isinstance(x, torch.Tensor): - x = torch.tensor(x, device=self.device, dtype=torch.float) + x = torch.tensor(x, device=self.device, dtype=torch.float32) x = x.permute(0, 3, 1, 2) x = F.relu(self.bn1(self.conv1(x))) x = F.relu(self.bn2(self.conv2(x)))
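Taken together, the patch splits each network into a reusable MLP backbone (``Net`` in ``tianshou.utils.net.common``) plus small task-specific heads (``Actor``, ``Critic``, ``ActorProb``, ... in ``tianshou.utils.net.continuous`` and ``tianshou.utils.net.discrete``). The sketch below is assembled from the call sites changed above and shows how the pieces compose for a continuous-control task; the environment name, ``layer_num``, and learning rates are illustrative only and not part of this patch. Note also that the SAC/TD3 examples above pass the same ``Net`` instance to both ``Critic`` heads, so the two critics share backbone parameters; if fully independent critics are intended, construct a separate ``Net`` for each critic.

    import gym
    import torch

    from tianshou.utils.net.common import Net
    from tianshou.utils.net.continuous import Actor, Critic

    # Illustrative environment; any continuous-action gym task works here.
    env = gym.make('Pendulum-v0')
    state_shape = env.observation_space.shape or env.observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    max_action = env.action_space.high[0]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Actor: the shared backbone only sees the observation.
    actor_net = Net(layer_num=1, state_shape=state_shape, device=device)
    actor = Actor(actor_net, action_shape, max_action, device).to(device)

    # Critic: with concat=True the backbone's input is the concatenation of
    # observation and action; action_shape only affects the input size here.
    critic_net = Net(layer_num=1, state_shape=state_shape,
                     action_shape=action_shape, concat=True, device=device)
    critic = Critic(critic_net, device).to(device)

    actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
    critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)

These modules then plug into the policies exactly as before, e.g. ``DDPGPolicy(actor, actor_optim, critic, critic_optim, ...)`` as in ``examples/ant_v2_ddpg.py``.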