diff --git a/README.md b/README.md index 4c2a0af..8f442e2 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ - [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf) - [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf) - [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf) +- [Implicit Quantile Network (IQN)](https://arxiv.org/pdf/1806.06923.pdf) - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf) - [Natural Policy Gradient (NPG)](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf) - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/) diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst index 32f124f..d857943 100644 --- a/docs/api/tianshou.policy.rst +++ b/docs/api/tianshou.policy.rst @@ -35,6 +35,11 @@ DQN Family :undoc-members: :show-inheritance: +.. autoclass:: tianshou.policy.IQNPolicy + :members: + :undoc-members: + :show-inheritance: + On-policy ~~~~~~~~~ diff --git a/docs/index.rst b/docs/index.rst index fe8f7d7..87189fd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -14,6 +14,7 @@ Welcome to Tianshou! * :class:`~tianshou.policy.DQNPolicy` `Dueling DQN `_ * :class:`~tianshou.policy.C51Policy` `Categorical DQN `_ * :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN `_ +* :class:`~tianshou.policy.IQNPolicy` `Implicit Quantile Network `_ * :class:`~tianshou.policy.PGPolicy` `Policy Gradient `_ * :class:`~tianshou.policy.NPGPolicy` `Natural Policy Gradient `_ * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic `_ diff --git a/examples/atari/README.md b/examples/atari/README.md index fb45858..47c50d3 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -54,6 +54,20 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. | SeaquestNoFrameskip-v4 | 5676 | ![](results/qrdqn/Seaquest_rew.png) | `python3 atari_qrdqn.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 938 | ![](results/qrdqn/SpaceInvader_rew.png) | `python3 atari_qrdqn.py --task "SpaceInvadersNoFrameskip-v4"` | +# IQN (single run) + +One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. + +| task | best reward | reward curve | parameters | +| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | +| PongNoFrameskip-v4 | 20.9 | ![](results/iqn/Pong_rew.png) | `python3 atari_iqn.py --task "PongNoFrameskip-v4" --batch-size 64` | +| BreakoutNoFrameskip-v4 | 578.3 | ![](results/iqn/Breakout_rew.png) | `python3 atari_iqn.py --task "BreakoutNoFrameskip-v4" --n-step 1` | +| EnduroNoFrameskip-v4 | 1507 | ![](results/iqn/Enduro_rew.png) | `python3 atari_iqn.py --task "EnduroNoFrameskip-v4"` | +| QbertNoFrameskip-v4 | 15520 | ![](results/iqn/Qbert_rew.png) | `python3 atari_iqn.py --task "QbertNoFrameskip-v4"` | +| MsPacmanNoFrameskip-v4 | 2911 | ![](results/iqn/MsPacman_rew.png) | `python3 atari_iqn.py --task "MsPacmanNoFrameskip-v4"` | +| SeaquestNoFrameskip-v4 | 6236 | ![](results/iqn/Seaquest_rew.png) | `python3 atari_iqn.py --task "SeaquestNoFrameskip-v4"` | +| SpaceInvadersNoFrameskip-v4 | 1370 | ![](results/iqn/SpaceInvader_rew.png) | `python3 atari_iqn.py --task "SpaceInvadersNoFrameskip-v4"` | + # BCQ To running BCQ algorithm on Atari, you need to do the following things: diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 4642b87..558284f 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -107,7 +107,7 @@ def test_c51(args=get_args()): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) def stop_fn(mean_rewards): - if env.env.spec.reward_threshold: + if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold elif 'Pong' in args.task: return mean_rewards >= 20 diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index 00476c3..a7785c8 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -102,7 +102,7 @@ def test_dqn(args=get_args()): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) def stop_fn(mean_rewards): - if env.env.spec.reward_threshold: + if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold elif 'Pong' in args.task: return mean_rewards >= 20 diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py new file mode 100644 index 0000000..ad34c8f --- /dev/null +++ b/examples/atari/atari_iqn.py @@ -0,0 +1,183 @@ +import os +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import IQNPolicy +from tianshou.utils import BasicLogger +from tianshou.env import SubprocVectorEnv +from tianshou.trainer import offpolicy_trainer +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.utils.net.discrete import ImplicitQuantileNetwork + +from atari_network import DQN +from atari_wrapper import wrap_deepmind + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='PongNoFrameskip-v4') + parser.add_argument('--seed', type=int, default=1234) + parser.add_argument('--eps-test', type=float, default=0.005) + parser.add_argument('--eps-train', type=float, default=1.) + parser.add_argument('--eps-train-final', type=float, default=0.05) + parser.add_argument('--buffer-size', type=int, default=100000) + parser.add_argument('--lr', type=float, default=0.0001) + parser.add_argument('--gamma', type=float, default=0.99) + parser.add_argument('--sample-size', type=int, default=32) + parser.add_argument('--online-sample-size', type=int, default=8) + parser.add_argument('--target-sample-size', type=int, default=8) + parser.add_argument('--num-cosines', type=int, default=64) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512]) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=500) + parser.add_argument('--epoch', type=int, default=100) + parser.add_argument('--step-per-epoch', type=int, default=100000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + parser.add_argument('--frames-stack', type=int, default=4) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument('--watch', default=False, action='store_true', + help='watch the play of pre-trained policy only') + parser.add_argument('--save-buffer-name', type=str, default=None) + return parser.parse_args() + + +def make_atari_env(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack) + + +def make_atari_env_watch(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack, + episode_life=False, clip_rewards=False) + + +def test_iqn(args=get_args()): + env = make_atari_env(args) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # make environments + train_envs = SubprocVectorEnv([lambda: make_atari_env(args) + for _ in range(args.training_num)]) + test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) + for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # define model + feature_net = DQN(*args.state_shape, args.action_shape, args.device, + features_only=True) + net = ImplicitQuantileNetwork( + feature_net, args.action_shape, args.hidden_sizes, + num_cosines=args.num_cosines, device=args.device + ).to(args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + # define policy + policy = IQNPolicy( + net, optim, args.gamma, args.sample_size, args.online_sample_size, + args.target_sample_size, args.n_step, + target_update_freq=args.target_update_freq + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # replay buffer: `save_last_obs` and `stack_num` can be removed together + # when you have enough RAM + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), + ignore_obs_next=True, save_only_last_obs=True, stack_num=args.frames_stack) + # collector + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # log + log_path = os.path.join(args.logdir, args.task, 'iqn') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + if env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + elif 'Pong' in args.task: + return mean_rewards >= 20 + else: + return False + + def train_fn(epoch, env_step): + # nature DQN setting, linear decay in the first 1M steps + if env_step <= 1e6: + eps = args.eps_train - env_step / 1e6 * \ + (args.eps_train - args.eps_train_final) + else: + eps = args.eps_train_final + policy.set_eps(eps) + logger.write('train/eps', env_step, eps) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # watch agent's performance + def watch(): + print("Setup test envs ...") + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = VectorReplayBuffer( + args.buffer_size, buffer_num=len(test_envs), + ignore_obs_next=True, save_only_last_obs=True, + stack_num=args.frames_stack) + collector = Collector(policy, test_envs, buffer, + exploration_noise=True) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, + render=args.render) + rew = result["rews"].mean() + print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + + if args.watch: + watch() + exit(0) + + # test train_collector and start filling replay buffer + train_collector.collect(n_step=args.batch_size * args.training_num) + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, save_fn=save_fn, logger=logger, + update_per_step=args.update_per_step, test_in_train=False) + + pprint.pprint(result) + watch() + + +if __name__ == '__main__': + test_iqn(get_args()) diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index 061ad6a..6677f68 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -105,7 +105,7 @@ def test_qrdqn(args=get_args()): torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) def stop_fn(mean_rewards): - if env.env.spec.reward_threshold: + if env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold elif 'Pong' in args.task: return mean_rewards >= 20 diff --git a/examples/atari/results/iqn/Breakout_rew.png b/examples/atari/results/iqn/Breakout_rew.png new file mode 100644 index 0000000..d486bf6 Binary files /dev/null and b/examples/atari/results/iqn/Breakout_rew.png differ diff --git a/examples/atari/results/iqn/Enduro_rew.png b/examples/atari/results/iqn/Enduro_rew.png new file mode 100644 index 0000000..6d0a8a0 Binary files /dev/null and b/examples/atari/results/iqn/Enduro_rew.png differ diff --git a/examples/atari/results/iqn/MsPacman_rew.png b/examples/atari/results/iqn/MsPacman_rew.png new file mode 100644 index 0000000..da9bcba Binary files /dev/null and b/examples/atari/results/iqn/MsPacman_rew.png differ diff --git a/examples/atari/results/iqn/Pong_rew.png b/examples/atari/results/iqn/Pong_rew.png new file mode 100644 index 0000000..d5c91bd Binary files /dev/null and b/examples/atari/results/iqn/Pong_rew.png differ diff --git a/examples/atari/results/iqn/Qbert_rew.png b/examples/atari/results/iqn/Qbert_rew.png new file mode 100644 index 0000000..749e818 Binary files /dev/null and b/examples/atari/results/iqn/Qbert_rew.png differ diff --git a/examples/atari/results/iqn/Seaquest_rew.png b/examples/atari/results/iqn/Seaquest_rew.png new file mode 100644 index 0000000..0b14469 Binary files /dev/null and b/examples/atari/results/iqn/Seaquest_rew.png differ diff --git a/examples/atari/results/iqn/SpaceInvaders_rew.png b/examples/atari/results/iqn/SpaceInvaders_rew.png new file mode 100644 index 0000000..39b67dd Binary files /dev/null and b/examples/atari/results/iqn/SpaceInvaders_rew.png differ diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py new file mode 100644 index 0000000..f40407f --- /dev/null +++ b/test/discrete/test_iqn.py @@ -0,0 +1,149 @@ +import os +import gym +import torch +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.policy import IQNPolicy +from tianshou.utils import BasicLogger +from tianshou.env import DummyVectorEnv +from tianshou.utils.net.common import Net +from tianshou.trainer import offpolicy_trainer +from tianshou.utils.net.discrete import ImplicitQuantileNetwork +from tianshou.data import Collector, VectorReplayBuffer, PrioritizedVectorReplayBuffer + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='CartPole-v0') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--eps-test', type=float, default=0.05) + parser.add_argument('--eps-train', type=float, default=0.1) + parser.add_argument('--buffer-size', type=int, default=20000) + parser.add_argument('--lr', type=float, default=3e-3) + parser.add_argument('--gamma', type=float, default=0.9) + parser.add_argument('--sample-size', type=int, default=32) + parser.add_argument('--online-sample-size', type=int, default=8) + parser.add_argument('--target-sample-size', type=int, default=8) + parser.add_argument('--num-cosines', type=int, default=64) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--target-update-freq', type=int, default=320) + parser.add_argument('--epoch', type=int, default=10) + parser.add_argument('--step-per-epoch', type=int, default=10000) + parser.add_argument('--step-per-collect', type=int, default=10) + parser.add_argument('--update-per-step', type=float, default=0.1) + parser.add_argument('--batch-size', type=int, default=64) + parser.add_argument('--hidden-sizes', type=int, + nargs='*', default=[64, 64, 64]) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=100) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=0.) + parser.add_argument('--prioritized-replay', + action="store_true", default=False) + parser.add_argument('--alpha', type=float, default=0.6) + parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument( + '--device', type=str, + default='cuda' if torch.cuda.is_available() else 'cpu') + args = parser.parse_known_args()[0] + return args + + +def test_iqn(args=get_args()): + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # train_envs = gym.make(args.task) + # you can also use tianshou.env.SubprocVectorEnv + train_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)]) + # test_envs = gym.make(args.task) + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + # model + feature_net = Net(args.state_shape, args.hidden_sizes[-1], + hidden_sizes=args.hidden_sizes[:-1], device=args.device, + softmax=False) + net = ImplicitQuantileNetwork( + feature_net, args.action_shape, + num_cosines=args.num_cosines, device=args.device) + optim = torch.optim.Adam(net.parameters(), lr=args.lr) + policy = IQNPolicy( + net, optim, args.gamma, args.sample_size, args.online_sample_size, + args.target_sample_size, args.n_step, + target_update_freq=args.target_update_freq + ).to(args.device) + # buffer + if args.prioritized_replay: + buf = PrioritizedVectorReplayBuffer( + args.buffer_size, buffer_num=len(train_envs), + alpha=args.alpha, beta=args.beta) + else: + buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) + # collector + train_collector = Collector(policy, train_envs, buf, exploration_noise=True) + test_collector = Collector(policy, test_envs, exploration_noise=True) + # policy.set_eps(1) + train_collector.collect(n_step=args.batch_size * args.training_num) + # log + log_path = os.path.join(args.logdir, args.task, 'iqn') + writer = SummaryWriter(log_path) + logger = BasicLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def train_fn(epoch, env_step): + # eps annnealing, just a demo + if env_step <= 10000: + policy.set_eps(args.eps_train) + elif env_step <= 50000: + eps = args.eps_train - (env_step - 10000) / \ + 40000 * (0.9 * args.eps_train) + policy.set_eps(eps) + else: + policy.set_eps(0.1 * args.eps_train) + + def test_fn(epoch, env_step): + policy.set_eps(args.eps_test) + + # trainer + result = offpolicy_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.step_per_collect, args.test_num, + args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, save_fn=save_fn, logger=logger, + update_per_step=args.update_per_step) + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! + env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +def test_piqn(args=get_args()): + args.prioritized_replay = True + args.gamma = .95 + test_iqn(args) + + +if __name__ == '__main__': + test_iqn(get_args()) diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 54f0f4a..a350891 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -3,6 +3,7 @@ from tianshou.policy.random import RandomPolicy from tianshou.policy.modelfree.dqn import DQNPolicy from tianshou.policy.modelfree.c51 import C51Policy from tianshou.policy.modelfree.qrdqn import QRDQNPolicy +from tianshou.policy.modelfree.iqn import IQNPolicy from tianshou.policy.modelfree.pg import PGPolicy from tianshou.policy.modelfree.a2c import A2CPolicy from tianshou.policy.modelfree.npg import NPGPolicy @@ -26,6 +27,7 @@ __all__ = [ "DQNPolicy", "C51Policy", "QRDQNPolicy", + "IQNPolicy", "PGPolicy", "A2CPolicy", "NPGPolicy", diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py new file mode 100644 index 0000000..4c54d35 --- /dev/null +++ b/tianshou/policy/modelfree/iqn.py @@ -0,0 +1,105 @@ +import torch +import numpy as np +import torch.nn.functional as F +from typing import Any, Dict, Optional, Union + +from tianshou.policy import QRDQNPolicy +from tianshou.data import Batch, to_numpy + + +class IQNPolicy(QRDQNPolicy): + """Implementation of Implicit Quantile Network. arXiv:1806.06923. + + :param torch.nn.Module model: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. + :param float discount_factor: in [0, 1]. + :param int sample_size: the number of samples for policy evaluation. + Default to 32. + :param int online_sample_size: the number of samples for online model + in training. Default to 8. + :param int target_sample_size: the number of samples for target model + in training. Default to 8. + :param int estimation_step: the number of steps to look ahead. Default to 1. + :param int target_update_freq: the target network update frequency (0 if + you do not use the target network). + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed + explanation. + """ + + def __init__( + self, + model: torch.nn.Module, + optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + sample_size: int = 32, + online_sample_size: int = 8, + target_sample_size: int = 8, + estimation_step: int = 1, + target_update_freq: int = 0, + reward_normalization: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(model, optim, discount_factor, sample_size, estimation_step, + target_update_freq, reward_normalization, **kwargs) + assert sample_size > 1, "sample_size should be greater than 1" + assert online_sample_size > 1, "online_sample_size should be greater than 1" + assert target_sample_size > 1, "target_sample_size should be greater than 1" + self._sample_size = sample_size # for policy eval + self._online_sample_size = online_sample_size + self._target_sample_size = target_sample_size + + def forward( + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + model: str = "model", + input: str = "obs", + **kwargs: Any, + ) -> Batch: + if model == "model_old": + sample_size = self._target_sample_size + elif self.training: + sample_size = self._online_sample_size + else: + sample_size = self._sample_size + model = getattr(self, model) + obs = batch[input] + obs_ = obs.obs if hasattr(obs, "obs") else obs + (logits, taus), h = model( + obs_, sample_size=sample_size, state=state, info=batch.info + ) + q = self.compute_q_value(logits, getattr(obs, "mask", None)) + if not hasattr(self, "max_action_num"): + self.max_action_num = q.shape[1] + act = to_numpy(q.max(dim=1)[1]) + return Batch(logits=logits, act=act, state=h, taus=taus) + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + if self._target and self._iter % self._freq == 0: + self.sync_weight() + self.optim.zero_grad() + weight = batch.pop("weight", 1.0) + out = self(batch) + curr_dist, taus = out.logits, out.taus + act = batch.act + curr_dist = curr_dist[np.arange(len(act)), act, :].unsqueeze(2) + target_dist = batch.returns.unsqueeze(1) + # calculate each element's difference between curr_dist and target_dist + u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none") + huber_loss = (u * ( + taus.unsqueeze(2) - (target_dist - curr_dist).detach().le(0.).float() + ).abs()).sum(-1).mean(1) + loss = (huber_loss * weight).mean() + # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/ + # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130 + batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer + loss.backward() + self.optim.step() + self._iter += 1 + return {"loss": loss.item()} diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index ee1294f..14cc85e 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -111,3 +111,94 @@ class Critic(nn.Module): """Mapping: s -> V(s).""" logits, _ = self.preprocess(s, state=kwargs.get("state", None)) return self.last(logits) + + +class CosineEmbeddingNetwork(nn.Module): + """Cosine embedding network for IQN. Convert a scalar in [0, 1] to a list \ + of n-dim vectors. + + :param num_cosines: the number of cosines used for the embedding. + :param embedding_dim: the dimension of the embedding/output. + + .. note:: + + From https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/blob/master + /fqf_iqn_qrdqn/network.py . + """ + + def __init__(self, num_cosines: int, embedding_dim: int) -> None: + super().__init__() + self.net = nn.Sequential(nn.Linear(num_cosines, embedding_dim), nn.ReLU()) + self.num_cosines = num_cosines + self.embedding_dim = embedding_dim + + def forward(self, taus: torch.Tensor) -> torch.Tensor: + batch_size = taus.shape[0] + N = taus.shape[1] + # Calculate i * \pi (i=1,...,N). + i_pi = np.pi * torch.arange( + start=1, end=self.num_cosines + 1, dtype=taus.dtype, device=taus.device + ).view(1, 1, self.num_cosines) + # Calculate cos(i * \pi * \tau). + cosines = torch.cos(taus.view(batch_size, N, 1) * i_pi).view( + batch_size * N, self.num_cosines + ) + # Calculate embeddings of taus. + tau_embeddings = self.net(cosines).view(batch_size, N, self.embedding_dim) + return tau_embeddings + + +class ImplicitQuantileNetwork(Critic): + """Implicit Quantile Network. + + :param preprocess_net: a self-defined preprocess_net which output a + flattened hidden state. + :param int action_dim: the dimension of action space. + :param hidden_sizes: a sequence of int for constructing the MLP after + preprocess_net. Default to empty sequence (where the MLP now contains + only a single linear layer). + :param int num_cosines: the number of cosines to use for cosine embedding. + Default to 64. + :param int preprocess_net_output_dim: the output dimension of + preprocess_net. + + .. note:: + + Although this class inherits Critic, it is actually a quantile Q-Network + with output shape (batch_size, action_dim, sample_size). + + The second item of the first return value is tau vector. + """ + + def __init__( + self, + preprocess_net: nn.Module, + action_shape: Sequence[int], + hidden_sizes: Sequence[int] = (), + num_cosines: int = 64, + preprocess_net_output_dim: Optional[int] = None, + device: Union[str, int, torch.device] = "cpu" + ) -> None: + last_size = np.prod(action_shape) + super().__init__(preprocess_net, hidden_sizes, last_size, + preprocess_net_output_dim, device) + self.input_dim = getattr(preprocess_net, "output_dim", + preprocess_net_output_dim) + self.embed_model = CosineEmbeddingNetwork(num_cosines, + self.input_dim).to(device) + + def forward( # type: ignore + self, s: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any + ) -> Tuple[Any, torch.Tensor]: + r"""Mapping: s -> Q(s, \*).""" + logits, h = self.preprocess(s, state=kwargs.get("state", None)) + # Sample fractions. + batch_size = logits.size(0) + taus = torch.rand(batch_size, sample_size, + dtype=logits.dtype, device=logits.device) + embedding = (logits.unsqueeze(1) * self.embed_model(taus)).view( + batch_size * sample_size, -1 + ) + out = self.last(embedding).view(batch_size, + sample_size, -1).transpose(1, 2) + return (out, taus), h