diff --git a/README.md b/README.md index ab9e43f..f189f36 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ - [Double DQN](https://arxiv.org/pdf/1509.06461.pdf) - [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf) - [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf) +- [Rainbow DQN (Rainbow)](https://arxiv.org/pdf/1710.02298.pdf) - [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf) - [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf) - [Fully-parameterized Quantile Function (FQF)](https://arxiv.org/pdf/1911.02140.pdf) diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst index 39f478c..b05f5be 100644 --- a/docs/api/tianshou.policy.rst +++ b/docs/api/tianshou.policy.rst @@ -30,6 +30,11 @@ DQN Family :undoc-members: :show-inheritance: +.. autoclass:: tianshou.policy.RainbowPolicy + :members: + :undoc-members: + :show-inheritance: + .. autoclass:: tianshou.policy.QRDQNPolicy :members: :undoc-members: diff --git a/docs/index.rst b/docs/index.rst index 4afe03a..5c33245 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -13,6 +13,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.DQNPolicy` `Double DQN `_ * :class:`~tianshou.policy.DQNPolicy` `Dueling DQN `_ * :class:`~tianshou.policy.C51Policy` `Categorical DQN `_ +* :class:`~tianshou.policy.RainbowPolicy` `Rainbow DQN `_ * :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN `_ * :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network `_ * :class:`~tianshou.policy.FQFPolicy` `Fully-parameterized Quantile Function `_ diff --git a/examples/atari/README.md b/examples/atari/README.md index a46fcac..ffccecb 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -82,6 +82,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. | SeaquestNoFrameskip-v4 | 10775 | ![](results/fqf/Seaquest_rew.png) | `python3 atari_fqf.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 2482 | ![](results/fqf/SpaceInvaders_rew.png) | `python3 atari_fqf.py --task "SpaceInvadersNoFrameskip-v4"` | +# Rainbow (single run) + +One epoch here is equal to 100,000 env steps, and 100 epochs stand for 10M.
+ +| task | best reward | reward curve | parameters | +| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 21 | ![](results/rainbow/Pong_rew.png) | `python3 atari_rainbow.py --task "PongNoFrameskip-v4" --batch-size 64` | +| BreakoutNoFrameskip-v4 | 684.6 | ![](results/rainbow/Breakout_rew.png) | `python3 atari_rainbow.py --task "BreakoutNoFrameskip-v4" --n-step 1` | +| EnduroNoFrameskip-v4 | 1625.9 | ![](results/rainbow/Enduro_rew.png) | `python3 atari_rainbow.py --task "EnduroNoFrameskip-v4"` | +| QbertNoFrameskip-v4 | 16192.5 | ![](results/rainbow/Qbert_rew.png) | `python3 atari_rainbow.py --task "QbertNoFrameskip-v4"` | +| MsPacmanNoFrameskip-v4 | 3101 | ![](results/rainbow/MsPacman_rew.png) | `python3 atari_rainbow.py --task "MsPacmanNoFrameskip-v4"` | +| SeaquestNoFrameskip-v4 | 2126 | ![](results/rainbow/Seaquest_rew.png) | `python3 atari_rainbow.py --task "SeaquestNoFrameskip-v4"` | +| SpaceInvadersNoFrameskip-v4 | 1794.5 | ![](results/rainbow/SpaceInvaders_rew.png) | `python3 atari_rainbow.py --task "SpaceInvadersNoFrameskip-v4"` | + # BCQ To running BCQ algorithm on Atari, you need to do the following things: diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index f531d96..2eccf11 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -2,6 +2,7 @@ import torch import numpy as np from torch import nn from typing import Any, Dict, Tuple, Union, Optional, Sequence +from tianshou.utils.net.discrete import NoisyLinear class DQN(nn.Module): @@ -81,6 +82,65 @@ class C51(DQN): return x, state +class Rainbow(DQN): + """Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning. + + For advanced usage (how to customize the network), please refer to + :ref:`build_the_network`. + """ + + def __init__( + self, + c: int, + h: int, + w: int, + action_shape: Sequence[int], + num_atoms: int = 51, + noisy_std: float = 0.5, + device: Union[str, int, torch.device] = "cpu", + is_dueling: bool = True, + is_noisy: bool = True, + ) -> None: + super().__init__(c, h, w, action_shape, device, features_only=True) + self.action_num = np.prod(action_shape) + self.num_atoms = num_atoms + + def linear(x, y): + if is_noisy: + return NoisyLinear(x, y, noisy_std) + else: + return nn.Linear(x, y) + + self.Q = nn.Sequential( + linear(self.output_dim, 512), nn.ReLU(inplace=True), + linear(512, self.action_num * self.num_atoms)) + self._is_dueling = is_dueling + if self._is_dueling: + self.V = nn.Sequential( + linear(self.output_dim, 512), nn.ReLU(inplace=True), + linear(512, self.num_atoms)) + self.output_dim = self.action_num * self.num_atoms + + def forward( + self, + x: Union[np.ndarray, torch.Tensor], + state: Optional[Any] = None, + info: Dict[str, Any] = {}, + ) -> Tuple[torch.Tensor, Any]: + r"""Mapping: x -> Z(x, \*).""" + x, state = super().forward(x) + q = self.Q(x) + q = q.view(-1, self.action_num, self.num_atoms) + if self._is_dueling: + v = self.V(x) + v = v.view(-1, 1, self.num_atoms) + logits = q - q.mean(dim=1, keepdim=True) + v + else: + logits = q + y = logits.softmax(dim=2) + return y, state + + class QRDQN(DQN): """Reference: Distributional Reinforcement Learning with Quantile \ Regression. 
diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py new file mode 100644 index 0000000..941df5c --- /dev/null +++ b/examples/atari/atari_rainbow.py @@ -0,0 +1,204 @@ +import os +import torch +import pprint +import datetime +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import RainbowPolicy +from tianshou.utils import BasicLogger +from tianshou.env import SubprocVectorEnv +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer + +from atari_network import Rainbow +from atari_wrapper import wrap_deepmind + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--eps-test', type=float, default=0.005) + parser.add_argument('--eps-train', type=float, default=1.) + parser.add_argument('--eps-train-final', type=float, default=0.05) + parser.add_argument('--buffer-size', type=int, default=100000) + parser.add_argument('--lr', type=float, default=0.0000625) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--num-atoms', type=int, default=51) + parser.add_argument('--v-min', type=float, default=-10.) + parser.add_argument('--v-max', type=float, default=10.) + parser.add_argument('--noisy-std', type=float, default=0.1) + parser.add_argument('--no-dueling', action='store_true', default=False) + parser.add_argument('--no-noisy', action='store_true', default=False) + parser.add_argument('--no-priority', action='store_true', default=False) + parser.add_argument('--alpha', type=float, default=0.5) + parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument('--beta-final', type=float, default=1.) + parser.add_argument('--beta-anneal-step', type=int, default=5000000) + parser.add_argument('--no-weight-norm', action='store_true', default=False) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=500) + parser.add_argument('--epoch', type=int, default=100) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) 
+ parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + parser.add_argument('--frames-stack', type=int, default=4) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument('--watch', default=False, action='store_true', + help='watch the play of pre-trained policy only') + parser.add_argument('--save-buffer-name', type=str, default=None) + return parser.parse_args() + + +def make_atari_env(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack) + + +def make_atari_env_watch(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack, + episode_life=False, clip_rewards=False) + + +def test_rainbow(args=get_args()): + env = make_atari_env(args) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # make environments + train_envs = SubprocVectorEnv([lambda: make_atari_env(args) + for _ in range(args.training_num)]) + test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) + for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # define model + net = Rainbow(*args.state_shape, args.action_shape, + args.num_atoms, args.noisy_std, args.device, + is_dueling=not args.no_dueling, + is_noisy=not args.no_noisy) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + # define policy + policy = RainbowPolicy( + net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, + args.n_step, target_update_freq=args.target_update_freq + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_only_last_obs` and `stack_num` can be removed together + # when you have enough RAM + if args.no_priority: + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, + save_only_last_obs=True, stack_num=args.frames_stack) + else: + buffer = PrioritizedVectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True, + save_only_last_obs=True, stack_num=args.frames_stack, alpha=args.alpha, + beta=args.beta, weight_norm=not args.no_weight_norm) + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # log + log_path = os.path.join( + args.logdir, args.task, 'rainbow', + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + elif 'Pong' in args.task: + return mean_rewards >= 20 + else: + return False + + def train_fn(epoch, env_step): + # nature DQN setting, linear decay in the first 1M steps + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * \ + (args.eps_train - args.eps_train_final) + else: + eps = args.eps_train_final + policy.set_eps(eps) + logger.write('train/eps', env_step, eps) + if not args.no_priority: + if
env_step <= args.beta_anneal_step: + beta = args.beta - env_step / args.beta_anneal_step * \ + (args.beta - args.beta_final) + else: + beta = args.beta_final + buffer.set_beta(beta) + logger.write('train/beta', env_step, beta) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # watch agent's performance + def watch(): + print("Setup test envs ...") + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = PrioritizedVectorReplayBuffer( + args.buffer_size, buffer_num=len(test_envs), + ignore_obs_next=True, save_only_last_obs=True, + stack_num=args.frames_stack, alpha=args.alpha, + beta=args.beta) + collector = Collector(policy, test_envs, buffer, + exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, + render=args.render) + rew = result["rews"].mean() + print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + + if args.watch: + watch() + exit(0) + + # test train_collector and start filling replay buffer + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, save_fn=save_fn, logger=logger, + update_per_step=args.update_per_step, test_in_train=False) + + pprint.pprint(result) + watch() + + +if __name__ == '__main__': + test_rainbow(get_args()) diff --git a/examples/atari/results/rainbow/Breakout_rew.png b/examples/atari/results/rainbow/Breakout_rew.png new file mode 100644 index 0000000..b2071cc Binary files /dev/null and b/examples/atari/results/rainbow/Breakout_rew.png differ diff --git a/examples/atari/results/rainbow/Enduro_rew.png b/examples/atari/results/rainbow/Enduro_rew.png new file mode 100644 index 0000000..f6b913f Binary files /dev/null and b/examples/atari/results/rainbow/Enduro_rew.png differ diff --git a/examples/atari/results/rainbow/MsPacman_rew.png b/examples/atari/results/rainbow/MsPacman_rew.png new file mode 100644 index 0000000..2b51f2d Binary files /dev/null and b/examples/atari/results/rainbow/MsPacman_rew.png differ diff --git a/examples/atari/results/rainbow/Pong_rew.png b/examples/atari/results/rainbow/Pong_rew.png new file mode 100644 index 0000000..3566cfb Binary files /dev/null and b/examples/atari/results/rainbow/Pong_rew.png differ diff --git a/examples/atari/results/rainbow/Qbert_rew.png b/examples/atari/results/rainbow/Qbert_rew.png new file mode 100644 index 0000000..1644ab8 Binary files /dev/null and b/examples/atari/results/rainbow/Qbert_rew.png differ diff --git a/examples/atari/results/rainbow/Seaquest_rew.png b/examples/atari/results/rainbow/Seaquest_rew.png new file mode 100644 index 0000000..9c5898a Binary files /dev/null and b/examples/atari/results/rainbow/Seaquest_rew.png differ diff --git a/examples/atari/results/rainbow/SpaceInvaders_rew.png b/examples/atari/results/rainbow/SpaceInvaders_rew.png new file mode 100644 index 0000000..2182ee8 Binary files /dev/null and b/examples/atari/results/rainbow/SpaceInvaders_rew.png differ diff --git a/test/base/test_buffer.py 
b/test/base/test_buffer.py index 04b99aa..39e5bad 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -193,7 +193,7 @@ def test_priortized_replaybuffer(size=32, bufsize=15): mask = np.isin(np.arange(buf2.maxsize), indices) assert np.all(weight[mask] == weight[mask][0]) assert np.all(weight[~mask] == weight[~mask][0]) - assert weight[~mask][0] < weight[mask][0] and weight[mask][0] < 1 + assert weight[~mask][0] < weight[mask][0] and weight[mask][0] <= 1 def test_update(): diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index 27c6d65..956099b 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -54,6 +54,8 @@ def get_args(): def test_qrdqn(args=get_args()): env = gym.make(args.task) + if args.task == 'CartPole-v0': + env.spec.reward_threshold = 190 # lower the goal args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n # train_envs = gym.make(args.task) diff --git a/test/discrete/test_qrdqn_il_cql.py b/test/discrete/test_qrdqn_il_cql.py index 7a782dd..423506b 100644 --- a/test/discrete/test_qrdqn_il_cql.py +++ b/test/discrete/test_qrdqn_il_cql.py @@ -50,7 +50,7 @@ def test_discrete_cql(args=get_args()): # envs env = gym.make(args.task) if args.task == 'CartPole-v0': - env.spec.reward_threshold = 190 # lower the goal + env.spec.reward_threshold = 185 # lower the goal args.state_shape = env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n test_envs = DummyVectorEnv( diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py new file mode 100644 index 0000000..cb48fe8 --- /dev/null +++ b/test/discrete/test_rainbow.py @@ -0,0 +1,198 @@ +import os +import gym +import torch +import pickle +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import RainbowPolicy +from tianshou.utils import BasicLogger +from tianshou.env import DummyVectorEnv +from tianshou.utils.net.common import Net +from tianshou.utils.net.discrete import NoisyLinear +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--seed', type=int, default=1626) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--num-atoms', type=int, default=51) + parser.add_argument('--v-min', type=float, default=-10.) + parser.add_argument('--v-max', type=float, default=10.) 
+ parser.add_argument('--noisy-std', type=float, default=0.1) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=8000) + parser.add_argument('--step-per-collect', type=int, default=8) + parser.add_argument('--update-per-step', type=float, default=0.125) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--hidden-sizes', type=int, + nargs='*', default=[128, 128, 128, 128]) + parser.add_argument('--training-num', type=int, default=8) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + parser.add_argument('--prioritized-replay', + action="store_true", default=False) + parser.add_argument('--alpha', type=float, default=0.6) + parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument('--beta-final', type=float, default=1.) + parser.add_argument('--resume', action="store_true") + parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + parser.add_argument("--save-interval", type=int, default=4) + args = parser.parse_known_args()[0] + return args + + +def test_rainbow(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # train_envs = gym.make(args.task) + # you can also use tianshou.env.SubprocVectorEnv + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)]) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + + def noisy_linear(x, y): + return NoisyLinear(x, y, args.noisy_std) + + net = Net(args.state_shape, args.action_shape, + hidden_sizes=args.hidden_sizes, device=args.device, + softmax=True, num_atoms=args.num_atoms, + dueling_param=({"linear_layer": noisy_linear}, + {"linear_layer": noisy_linear})) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + policy = RainbowPolicy( + net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max, + args.n_step, target_update_freq=args.target_update_freq + ).to(args.device) + # buffer + if args.prioritized_replay: + buf = PrioritizedVectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), + alpha=args.alpha, beta=args.beta, weight_norm=True) + else: + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector + train_collector = Collector(policy, train_envs, buf, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size * args.training_num) + # log + log_path = os.path.join(args.logdir, args.task, 'rainbow') + writer = SummaryWriter(log_path) + logger = BasicLogger(writer, save_interval=args.save_interval) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def train_fn(epoch, env_step): + # eps annealing, just a demo + if env_step <= 10000: + policy.set_eps(args.eps_train) + elif 
env_step <= 50000: + eps = args.eps_train - (env_step - 10000) / \ + 40000 * (0.9 * args.eps_train) + policy.set_eps(eps) + else: + policy.set_eps(0.1 * args.eps_train) + # beta annealing, just a demo + if args.prioritized_replay: + if env_step <= 10000: + beta = args.beta + elif env_step <= 50000: + beta = args.beta - (env_step - 10000) / \ + 40000 * (args.beta - args.beta_final) + else: + beta = args.beta_final + buf.set_beta(beta) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + def save_checkpoint_fn(epoch, env_step, gradient_step): + # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html + torch.save({ + 'model': policy.state_dict(), + 'optim': optim.state_dict(), + }, os.path.join(log_path, 'checkpoint.pth')) + pickle.dump(train_collector.buffer, + open(os.path.join(log_path, 'train_buffer.pkl'), "wb")) + + if args.resume: + # load from existing checkpoint + print(f"Loading agent under {log_path}") + ckpt_path = os.path.join(log_path, 'checkpoint.pth') + if os.path.exists(ckpt_path): + checkpoint = torch.load(ckpt_path, map_location=args.device) + policy.load_state_dict(checkpoint['model']) + policy.optim.load_state_dict(checkpoint['optim']) + print("Successfully restore policy and optim.") + else: + print("Fail to restore policy and optim.") + buffer_path = os.path.join(log_path, 'train_buffer.pkl') + if os.path.exists(buffer_path): + train_collector.buffer = pickle.load(open(buffer_path, "rb")) + print("Successfully restore buffer.") + else: + print("Fail to restore buffer.") + + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, + test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger, + resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn) + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +def test_rainbow_resume(args=get_args()): + args.resume = True + test_rainbow(args) + + +def test_prainbow(args=get_args()): + args.prioritized_replay = True + args.gamma = .95 + args.seed = 1 + test_rainbow(args) + + +if __name__ == '__main__': + test_rainbow(get_args()) diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py index d89678a..a4357d5 100644 --- a/tianshou/data/buffer/prio.py +++ b/tianshou/data/buffer/prio.py @@ -10,13 +10,22 @@ class PrioritizedReplayBuffer(ReplayBuffer): :param float alpha: the prioritization exponent. :param float beta: the importance sample soft coefficient. + :param bool weight_norm: whether to normalize returned weights with the maximum + weight value within the batch. Default to True. .. seealso:: Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. 
""" - def __init__(self, size: int, alpha: float, beta: float, **kwargs: Any) -> None: + def __init__( + self, + size: int, + alpha: float, + beta: float, + weight_norm: bool = True, + **kwargs: Any + ) -> None: # will raise KeyError in PrioritizedVectorReplayBuffer # super().__init__(size, **kwargs) ReplayBuffer.__init__(self, size, **kwargs) @@ -27,6 +36,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): self.weight = SegmentTree(size) self.__eps = np.finfo(np.float32).eps.item() self.options.update(alpha=alpha, beta=beta) + self._weight_norm = weight_norm def init_weight(self, index: Union[int, np.ndarray]) -> None: self.weight[index] = self._max_prio ** self._alpha @@ -83,5 +93,10 @@ class PrioritizedReplayBuffer(ReplayBuffer): else: indices = index batch = super().__getitem__(indices) - batch.weight = self.get_weight(indices) + weight = self.get_weight(indices) + # ref: https://github.com/Kaixhin/Rainbow/blob/master/memory.py L154 + batch.weight = weight / np.max(weight) if self._weight_norm else weight return batch + + def set_beta(self, beta: float) -> None: + self._beta = beta diff --git a/tianshou/data/buffer/vecbuf.py b/tianshou/data/buffer/vecbuf.py index 1cfeae9..374765b 100644 --- a/tianshou/data/buffer/vecbuf.py +++ b/tianshou/data/buffer/vecbuf.py @@ -55,3 +55,7 @@ class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager): PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num) ] super().__init__(buffer_list) + + def set_beta(self, beta: float) -> None: + for buffer in self.buffers: + buffer.set_beta(beta) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 9dd879c..8a9c647 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -2,6 +2,7 @@ from tianshou.policy.base import BasePolicy from tianshou.policy.random import RandomPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.c51 import C51Policy +from tianshou.policy.modelfree.rainbow import RainbowPolicy from tianshou.policy.modelfree.qrdqn import QRDQNPolicy from tianshou.policy.modelfree.iqn import IQNPolicy from tianshou.policy.modelfree.fqf import FQFPolicy @@ -27,6 +28,7 @@ __all__ = [ "RandomPolicy", "DQNPolicy", "C51Policy", + "RainbowPolicy", "QRDQNPolicy", "IQNPolicy", "FQFPolicy", diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py new file mode 100644 index 0000000..7aa4c68 --- /dev/null +++ b/tianshou/policy/modelfree/rainbow.py @@ -0,0 +1,37 @@ +from typing import Any, Dict + +from tianshou.policy import C51Policy +from tianshou.data import Batch +from tianshou.utils.net.discrete import sample_noise + + +class RainbowPolicy(C51Policy): + """Implementation of Rainbow DQN. arXiv:1710.02298. + + :param torch.nn.Module model: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. + :param float discount_factor: in [0, 1]. + :param int num_atoms: the number of atoms in the support set of the + value distribution. Default to 51. + :param float v_min: the value of the smallest atom in the support set. + Default to -10.0. + :param float v_max: the value of the largest atom in the support set. + Default to 10.0. + :param int estimation_step: the number of steps to look ahead. Default to 1. + :param int target_update_freq: the target network update frequency (0 if + you do not use the target network). Default to 0. 
+ :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.C51Policy` for more detailed + explanation. + """ + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + sample_noise(self.model) + if self._target and sample_noise(self.model_old): + self.model_old.train() # so that NoisyLinear takes effect + return super().learn(batch, **kwargs) diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 664b488..cb11abc 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -11,10 +11,11 @@ def miniblock( output_size: int = 0, norm_layer: Optional[ModuleType] = None, activation: Optional[ModuleType] = None, + linear_layer: Type[nn.Linear] = nn.Linear, ) -> List[nn.Module]: """Construct a miniblock with given input/output-size, norm layer and \ activation.""" - layers: List[nn.Module] = [nn.Linear(input_size, output_size)] + layers: List[nn.Module] = [linear_layer(input_size, output_size)] if norm_layer is not None: layers += [norm_layer(output_size)] # type: ignore if activation is not None: @@ -42,6 +43,8 @@ class MLP(nn.Module): the same actvition for all layers if passed in nn.Module, or different activation for different Modules if passed in a list. Default to nn.ReLU. + :param device: which device to create this model on. Default to None. + :param linear_layer: use this module as linear layer. Default to nn.Linear. """ def __init__( @@ -52,6 +55,7 @@ class MLP(nn.Module): norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None, activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU, device: Optional[Union[str, int, torch.device]] = None, + linear_layer: Type[nn.Linear] = nn.Linear, ) -> None: super().__init__() self.device = device @@ -78,9 +82,9 @@ class MLP(nn.Module): for in_dim, out_dim, norm, activ in zip( hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, activation_list): - model += miniblock(in_dim, out_dim, norm, activ) + model += miniblock(in_dim, out_dim, norm, activ, linear_layer) if output_dim > 0: - model += [nn.Linear(hidden_sizes[-1], output_dim)] + model += [linear_layer(hidden_sizes[-1], output_dim)] self.output_dim = output_dim or hidden_sizes[-1] self.model = nn.Sequential(*model) @@ -168,10 +172,10 @@ class Net(nn.Module): q_output_dim, v_output_dim = action_dim, num_atoms q_kwargs: Dict[str, Any] = { **q_kwargs, "input_dim": self.output_dim, - "output_dim": q_output_dim} + "output_dim": q_output_dim, "device": self.device} v_kwargs: Dict[str, Any] = { **v_kwargs, "input_dim": self.output_dim, - "output_dim": v_output_dim} + "output_dim": v_output_dim, "device": self.device} self.Q, self.V = MLP(**q_kwargs), MLP(**v_kwargs) self.output_dim = self.Q.output_dim diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index b104b4e..200ae9d 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -307,3 +307,83 @@ class FullQuantileFunction(ImplicitQuantileNetwork): with torch.no_grad(): quantiles_tau = self._compute_quantiles(logits, taus[:, 1:-1]) return (quantiles, fractions, quantiles_tau), h + + +class NoisyLinear(nn.Module): + """Implementation of Noisy Networks. arXiv:1706.10295. + + :param int in_features: the number of input features. + :param int out_features: the number of output features. + :param float noisy_std: initial standard deviation of noisy linear layers. + + .. 
note:: + + Adapted from https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master + /fqf_iqn_qrdqn/network.py . + """ + + def __init__( + self, in_features: int, out_features: int, noisy_std: float = 0.5 + ) -> None: + super().__init__() + + # Learnable parameters. + self.mu_W = nn.Parameter( + torch.FloatTensor(out_features, in_features)) + self.sigma_W = nn.Parameter( + torch.FloatTensor(out_features, in_features)) + self.mu_bias = nn.Parameter(torch.FloatTensor(out_features)) + self.sigma_bias = nn.Parameter(torch.FloatTensor(out_features)) + + # Factorized noise parameters. + self.register_buffer('eps_p', torch.FloatTensor(in_features)) + self.register_buffer('eps_q', torch.FloatTensor(out_features)) + + self.in_features = in_features + self.out_features = out_features + self.sigma = noisy_std + + self.reset() + self.sample() + + def reset(self) -> None: + bound = 1 / np.sqrt(self.in_features) + self.mu_W.data.uniform_(-bound, bound) + self.mu_bias.data.uniform_(-bound, bound) + self.sigma_W.data.fill_(self.sigma / np.sqrt(self.in_features)) + self.sigma_bias.data.fill_(self.sigma / np.sqrt(self.in_features)) + + def f(self, x: torch.Tensor) -> torch.Tensor: + x = torch.randn(x.size(0), device=x.device) + return x.sign().mul_(x.abs().sqrt_()) + + def sample(self) -> None: + self.eps_p.copy_(self.f(self.eps_p)) # type: ignore + self.eps_q.copy_(self.f(self.eps_q)) # type: ignore + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.training: + weight = self.mu_W + self.sigma_W * ( + self.eps_q.ger(self.eps_p) # type: ignore + ) + bias = self.mu_bias + self.sigma_bias * self.eps_q.clone() # type: ignore + else: + weight = self.mu_W + bias = self.mu_bias + + return F.linear(x, weight, bias) + + +def sample_noise(model: nn.Module) -> bool: + """Sample the random noises of NoisyLinear modules in the model. + + :param model: a PyTorch module which may have NoisyLinear submodules. + :returns: True if model has at least one NoisyLinear submodule; + otherwise, False. + """ + done = False + for m in model.modules(): + if isinstance(m, NoisyLinear): + m.sample() + done = True + return done
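
For quick orientation, below is a minimal sketch of how the pieces added in this diff fit together on CartPole: a dueling `Net` whose linear layers are swapped for the new `NoisyLinear`, wrapped in the new `RainbowPolicy`, fed by a `PrioritizedVectorReplayBuffer` whose `beta` can be annealed with the new `set_beta`. It is condensed from `test/discrete/test_rainbow.py` above rather than being part of the patch; the hyperparameter values are illustrative placeholders, and it assumes the standard `BasePolicy.update(sample_size, buffer)` helper.

```python
import gym
import torch

from tianshou.data import Collector, PrioritizedVectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import RainbowPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import NoisyLinear

env = gym.make("CartPole-v0")
state_shape = env.observation_space.shape
action_shape = env.action_space.n


def noisy_linear(in_dim, out_dim):
    # noisy layers replace epsilon-greedy exploration with learned parameter noise
    return NoisyLinear(in_dim, out_dim, 0.1)


# dueling + distributional head: both the Q- and V-branches are built from noisy
# layers, and the network outputs a softmax over 51 atoms per action
net = Net(state_shape, action_shape, hidden_sizes=[128, 128],
          softmax=True, num_atoms=51,
          dueling_param=({"linear_layer": noisy_linear},
                         {"linear_layer": noisy_linear}))
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = RainbowPolicy(net, optim, 0.99, 51, -10.0, 10.0,
                       estimation_step=3, target_update_freq=320)

train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(8)])
buf = PrioritizedVectorReplayBuffer(
    20000, buffer_num=len(train_envs), alpha=0.5, beta=0.4, weight_norm=True)
collector = Collector(policy, train_envs, buf, exploration_noise=True)
collector.collect(n_step=1024)

buf.set_beta(0.6)  # anneal beta toward 1 over training, as train_fn above does
losses = policy.update(64, buf)  # samples with IS weights; learn() resamples noise
print(losses)
```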