diff --git a/README.md b/README.md
index 8d057d2..414e288 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,8 @@
 - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
 - [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
-- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
+- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
+- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
 - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
diff --git a/docs/index.rst b/docs/index.rst
index f7169c0..25f0410 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -11,6 +11,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
 * :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
+* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
 * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
 * :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst
index e4e2f0f..fe68ccf 100644
--- a/docs/tutorials/tictactoe.rst
+++ b/docs/tutorials/tictactoe.rst
@@ -193,7 +193,7 @@ The explanation of each Tianshou class/function will be deferred to their first
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
-    parser.add_argument('--gamma', type=float, default=0.1,
+    parser.add_argument('--gamma', type=float, default=0.9,
                         help='a smaller gamma favors earlier win')
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=320)
diff --git a/examples/acrobot_dualdqn.py b/examples/acrobot_dualdqn.py
new file mode 100644
index 0000000..a5b0ac2
--- /dev/null
+++ b/examples/acrobot_dualdqn.py
@@ -0,0 +1,116 @@
+import os
+import gym
+import torch
+import pprint
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.env import VectorEnv
+from tianshou.policy import DQNPolicy
+from tianshou.trainer import offpolicy_trainer
+from tianshou.data import Collector, ReplayBuffer
+from tianshou.utils.net.common import Net
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='Acrobot-v1')
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--eps-test', type=float, default=0.05)
+    parser.add_argument('--eps-train', type=float, default=0.5)
+    parser.add_argument('--buffer-size', type=int, default=20000)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.95)
+    parser.add_argument('--n-step', type=int, default=3)
+    parser.add_argument('--target-update-freq', type=int, default=320)
+    parser.add_argument('--epoch', type=int, default=10)
+    parser.add_argument('--step-per-epoch', type=int, default=1000)
+    parser.add_argument('--collect-per-step', type=int, default=100)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--layer-num', type=int, default=0)
+    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_dqn(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    # train_envs = gym.make(args.task)
+    # you can also use tianshou.env.SubprocVectorEnv
+    train_envs = VectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
+    # test_envs = gym.make(args.task)
+    test_envs = VectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    net = Net(args.layer_num, args.state_shape,
+              args.action_shape, args.device, dueling=(2, 2)).to(args.device)
+    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+    policy = DQNPolicy(
+        net, optim, args.gamma, args.n_step,
+        target_update_freq=args.target_update_freq)
+    # collector
+    train_collector = Collector(
+        policy, train_envs, ReplayBuffer(args.buffer_size))
+    test_collector = Collector(policy, test_envs)
+    # policy.set_eps(1)
+    train_collector.collect(n_step=args.batch_size)
+    # log
+    log_path = os.path.join(args.logdir, args.task, 'dqn')
+    writer = SummaryWriter(log_path)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(x):
+        return x >= env.spec.reward_threshold
+
+    def train_fn(x):
+        if x <= int(0.1 * args.epoch):
+            policy.set_eps(args.eps_train)
+        elif x <= int(0.5 * args.epoch):
+            eps = args.eps_train - (x - 0.1 * args.epoch) / \
+                (0.4 * args.epoch) * (0.5 * args.eps_train)
+            policy.set_eps(eps)
+        else:
+            policy.set_eps(0.5 * args.eps_train)
+
+    def test_fn(x):
+        policy.set_eps(args.eps_test)
+
+    # trainer
+    result = offpolicy_trainer(
+        policy, train_collector, test_collector, args.epoch,
+        args.step_per_epoch, args.collect_per_step, args.test_num,
+        args.batch_size, train_fn=train_fn, test_fn=test_fn,
+        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+
+    assert stop_fn(result['best_reward'])
+    train_collector.close()
+    test_collector.close()
+    if __name__ == '__main__':
+        pprint.pprint(result)
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=args.render)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()
+
+
+if __name__ == '__main__':
+    test_dqn(get_args())
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index 0a3a5b0..0455f70 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -58,7 +58,8 @@ def test_dqn(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Net(args.layer_num, args.state_shape,
-              args.action_shape, args.device).to(args.device)
+              args.action_shape, args.device,
+              dueling=(2, 2)).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(
         net, optim, args.gamma, args.n_step,
@@ -80,7 +81,15 @@ def test_dqn(args=get_args()):
         return x >= env.spec.reward_threshold
 
     def train_fn(x):
-        policy.set_eps(args.eps_train)
+        # eps annealing, just a demo
+        if x <= int(0.1 * args.epoch):
+            policy.set_eps(args.eps_train)
+        elif x <= int(0.5 * args.epoch):
+            eps = args.eps_train - (x - 0.1 * args.epoch) / \
+                (0.4 * args.epoch) * (0.9 * args.eps_train)
+            policy.set_eps(eps)
+        else:
+            policy.set_eps(0.1 * args.eps_train)
 
     def test_fn(x):
         policy.set_eps(args.eps_test)
diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py
index 5b718ce..00b8f56 100644
--- a/test/multiagent/tic_tac_toe.py
+++ b/test/multiagent/tic_tac_toe.py
@@ -23,7 +23,7 @@ def get_parser() -> argparse.ArgumentParser:
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
-    parser.add_argument('--gamma', type=float, default=0.1,
+    parser.add_argument('--gamma', type=float, default=0.9,
                         help='a smaller gamma favors earlier win')
     parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=320)
@@ -38,7 +38,7 @@ def get_parser() -> argparse.ArgumentParser:
     parser.add_argument('--render', type=float, default=0.1)
    parser.add_argument('--board_size', type=int, default=6)
     parser.add_argument('--win_size', type=int, default=4)
-    parser.add_argument('--win-rate', type=float, default=0.8,
+    parser.add_argument('--win_rate', type=float, default=0.9,
                         help='the expected winning rate')
     parser.add_argument('--watch', default=False, action='store_true',
                         help='no training, '
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index 9bdbd11..9bf60a6 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -11,8 +11,12 @@ from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
 class DQNPolicy(BasePolicy):
     """Implementation of Deep Q Network. arXiv:1312.5602
 
+    Implementation of Double Q-Learning. arXiv:1509.06461
+    Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
+    implemented in the network side, not here)
+
     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> logits)
     :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py
index 2401ebd..a84a7e7 100644
--- a/tianshou/utils/net/common.py
+++ b/tianshou/utils/net/common.py
@@ -1,36 +1,77 @@
-import numpy as np
 import torch
+import numpy as np
 from torch import nn
+from typing import Tuple, Union, Optional
 
 from tianshou.data import to_torch
 
 
+def miniblock(inp: int, oup: int, norm_layer: nn.modules.Module):
+    ret = [nn.Linear(inp, oup)]
+    if norm_layer is not None:
+        ret += [norm_layer(oup)]
+    ret += [nn.ReLU(inplace=True)]
+    return ret
+
+
 class Net(nn.Module):
     """Simple MLP backbone. For advanced usage (how to customize the network),
     please refer to :ref:`build_the_network`.
 
-    :param concat: whether the input shape is concatenated by state_shape
+    :param bool concat: whether the input shape is concatenated by state_shape
         and action_shape. If it is True, ``action_shape`` is not the output
         shape, but affects the input shape.
+    :param dueling: whether to use dueling network to calculate Q values
+        (for Dueling DQN), e.g., ``(2, 2)``; defaults to None.
+    :param nn.modules.Module norm_layer: use which normalization before ReLU,
+        e.g., ``nn.LayerNorm`` and ``nn.BatchNorm1d``, defaults to None.
     """
 
-    def __init__(self, layer_num, state_shape, action_shape=0, device='cpu',
-                 softmax=False, concat=False, hidden_layer_size=128):
+    def __init__(self, layer_num: int, state_shape: tuple,
+                 action_shape: Optional[tuple] = 0,
+                 device: Union[str, torch.device] = 'cpu',
+                 softmax: bool = False,
+                 concat: bool = False,
+                 hidden_layer_size: int = 128,
+                 dueling: Optional[Tuple[int, int]] = None,
+                 norm_layer: Optional[nn.modules.Module] = None):
         super().__init__()
         self.device = device
+        self.dueling = dueling
+        self.softmax = softmax
         input_size = np.prod(state_shape)
         if concat:
             input_size += np.prod(action_shape)
-        self.model = [
-            nn.Linear(input_size, hidden_layer_size),
-            nn.ReLU(inplace=True)]
+
+        self.model = miniblock(input_size, hidden_layer_size, norm_layer)
+
         for i in range(layer_num):
-            self.model += [nn.Linear(hidden_layer_size, hidden_layer_size),
-                           nn.ReLU(inplace=True)]
-        if action_shape and not concat:
-            self.model += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
-        if softmax:
-            self.model += [nn.Softmax(dim=-1)]
+            self.model += miniblock(hidden_layer_size,
+                                    hidden_layer_size, norm_layer)
+
+        if self.dueling is None:
+            if action_shape and not concat:
+                self.model += [nn.Linear(hidden_layer_size,
+                                         np.prod(action_shape))]
+        else:  # dueling DQN
+            assert isinstance(self.dueling, tuple) and len(self.dueling) == 2
+
+            q_layer_num, v_layer_num = self.dueling
+            self.Q, self.V = [], []
+
+            for i in range(q_layer_num):
+                self.Q += miniblock(hidden_layer_size,
+                                    hidden_layer_size, norm_layer)
+            for i in range(v_layer_num):
+                self.V += miniblock(hidden_layer_size,
+                                    hidden_layer_size, norm_layer)
+
+            if action_shape and not concat:
+                self.Q += [nn.Linear(hidden_layer_size, np.prod(action_shape))]
+                self.V += [nn.Linear(hidden_layer_size, 1)]
+
+            self.Q = nn.Sequential(*self.Q)
+            self.V = nn.Sequential(*self.V)
         self.model = nn.Sequential(*self.model)
 
     def forward(self, s, state=None, info={}):
@@ -38,6 +79,11 @@ class Net(nn.Module):
         s = to_torch(s, device=self.device, dtype=torch.float32)
         s = s.reshape(s.size(0), -1)
         logits = self.model(s)
+        if self.dueling is not None:  # Dueling DQN
+            q, v = self.Q(logits), self.V(logits)
+            logits = q - q.mean(dim=1, keepdim=True) + v
+        if self.softmax:
+            logits = torch.softmax(logits, dim=-1)
         return logits, state
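Note (not part of the patch): as a reading aid for the dueling head added to tianshou/utils/net/common.py above, here is a minimal standalone sketch of the same aggregation, Q(s, a) = V(s) + A(s, a) - mean_a A(s, a). The class name, layer sizes and the demo under __main__ are illustrative only; in the patch itself, self.Q plays the role of the advantage stream and self.V the state-value stream, combined the same way inside Net.forward.

import torch
from torch import nn


class DuelingHead(nn.Module):
    """Combine an advantage stream and a value stream into Q-values."""

    def __init__(self, hidden_size: int, num_actions: int):
        super().__init__()
        self.advantage = nn.Linear(hidden_size, num_actions)  # A(s, .)
        self.value = nn.Linear(hidden_size, 1)                # V(s)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        a = self.advantage(feature)
        v = self.value(feature)
        # subtracting the mean advantage keeps the V/A decomposition
        # identifiable, exactly as done in Net.forward above
        return v + a - a.mean(dim=1, keepdim=True)


if __name__ == '__main__':
    head = DuelingHead(hidden_size=128, num_actions=3)
    print(head(torch.randn(4, 128)).shape)  # torch.Size([4, 3])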
diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py
index 7916a9d..ae6a1ef 100644
--- a/tianshou/utils/net/discrete.py
+++ b/tianshou/utils/net/discrete.py
@@ -41,25 +41,33 @@ class Critic(nn.Module):
 class DQN(nn.Module):
     """For advanced usage (how to customize the network), please refer to
     :ref:`build_the_network`.
+
+    Reference paper: "Human-level control through deep reinforcement learning".
     """
 
     def __init__(self, h, w, action_shape, device='cpu'):
         super(DQN, self).__init__()
         self.device = device
 
-        self.conv1 = nn.Conv2d(4, 16, kernel_size=5, stride=2)
-        self.bn1 = nn.BatchNorm2d(16)
-        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
-        self.bn2 = nn.BatchNorm2d(32)
-        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
-        self.bn3 = nn.BatchNorm2d(32)
+        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
+        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
+        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
 
         def conv2d_size_out(size, kernel_size=5, stride=2):
             return (size - (kernel_size - 1) - 1) // stride + 1
 
-        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(w)))
-        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(h)))
-        linear_input_size = convw * convh * 32
+        def conv2d_layers_size_out(size,
+                                   kernel_size_1=8, stride_1=4,
+                                   kernel_size_2=4, stride_2=2,
+                                   kernel_size_3=3, stride_3=1):
+            size = conv2d_size_out(size, kernel_size_1, stride_1)
+            size = conv2d_size_out(size, kernel_size_2, stride_2)
+            size = conv2d_size_out(size, kernel_size_3, stride_3)
+            return size
+
+        convw = conv2d_layers_size_out(w)
+        convh = conv2d_layers_size_out(h)
+        linear_input_size = convw * convh * 64
 
         self.fc = nn.Linear(linear_input_size, 512)
         self.head = nn.Linear(512, action_shape)
 
@@ -68,8 +76,8 @@ class DQN(nn.Module):
         if not isinstance(x, torch.Tensor):
             x = torch.tensor(x, device=self.device, dtype=torch.float32)
         x = x.permute(0, 3, 1, 2)
-        x = F.relu(self.bn1(self.conv1(x)))
-        x = F.relu(self.bn2(self.conv2(x)))
-        x = F.relu(self.bn3(self.conv3(x)))
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
+        x = F.relu(self.conv3(x))
         x = self.fc(x.reshape(x.size(0), -1))
         return self.head(x), state
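Note (not part of the patch): a minimal usage sketch for the new dueling argument, mirroring the calls added in test/discrete/test_dqn.py and examples/acrobot_dualdqn.py above; the CartPole environment and the hyperparameters here are illustrative, not prescribed by this change.

import gym
import torch

from tianshou.policy import DQNPolicy
from tianshou.utils.net.common import Net

env = gym.make('CartPole-v0')
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n

# dueling=(2, 2): two hidden mini-blocks for the Q (advantage) head and two
# for the V head, both stacked on the shared MLP body built by Net
net = Net(1, state_shape, action_shape, 'cpu', dueling=(2, 2))
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, 0.9, 3, target_update_freq=320)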