diff --git a/.gitignore b/.gitignore index 082dcef..be8453a 100644 --- a/.gitignore +++ b/.gitignore @@ -145,3 +145,5 @@ MUJOCO_LOG.TXT *.zip *.pstats *.swp +*.pkl +*.hdf5 diff --git a/README.md b/README.md index 69f95ce..2848f9c 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ - [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf) - [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf) - Vanilla Imitation Learning +- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf) - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf) - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf) - [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf) diff --git a/docs/index.rst b/docs/index.rst index 3b1fe02..72704f4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -20,8 +20,9 @@ Welcome to Tianshou! * :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG `_ * :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic `_ * :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic `_ -* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning +* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning `_ +* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning `_ * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay `_ * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator `_ diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 7d64458..8ea2d27 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -201,7 +201,7 @@ Trainer Once you have a collector and a policy, you can start writing the training method for your RL agent. Trainer, to be honest, is a simple wrapper. It helps you save energy for writing the training loop. You can also construct your own trainer: :ref:`customized_trainer`. -Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage. +Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage. .. _pseudocode: diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index b1971ea..faea6a8 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -120,7 +120,7 @@ In each step, the collector will let the policy perform (at least) a specified n Train Policy with a Trainer --------------------------- -Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tianshou.trainer.offpolicy_trainer`. The trainer will automatically stop training when the policy reach the stop condition ``stop_fn`` on test collector. 
Since DQN is an off-policy algorithm, we use the :class:`~tianshou.trainer.offpolicy_trainer` as follows: +Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.trainer.offpolicy_trainer`, and :func:`~tianshou.trainer.offline_trainer`. The trainer will automatically stop training when the policy reaches the stop condition ``stop_fn`` on the test collector. Since DQN is an off-policy algorithm, we use the :func:`~tianshou.trainer.offpolicy_trainer` as follows: :: result = ts.trainer.offpolicy_trainer( @@ -133,7 +133,7 @@ Tianshou provides :class:`~tianshou.trainer.onpolicy_trainer` and :class:`~tians writer=None) print(f'Finished training! Use {result["duration"]}') -The meaning of each parameter is as follows (full description can be found at :meth:`~tianshou.trainer.offpolicy_trainer`): +The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`): * ``max_epoch``: The maximum of epochs for training. The training process might be finished before reaching the ``max_epoch``; * ``step_per_epoch``: The number of step for updating policy network in one epoch; diff --git a/examples/atari/README.md b/examples/atari/README.md index 7fd0344..933415e 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -38,4 +38,15 @@ One epoch here is equal to 100,000 env step, 100 epochs stand for 10M. | SeaquestNoFrameskip-v4 | 6226 | ![](results/c51/Seaquest_rew.png) | `python3 atari_c51.py --task "SeaquestNoFrameskip-v4"` | | SpaceInvadersNoFrameskip-v4 | 988.5 | ![](results/c51/SpaceInvader_rew.png) | `python3 atari_c51.py --task "SpaceInvadersNoFrameskip-v4"` | -Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper. \ No newline at end of file +Note: The selection of `n_step` is based on Figure 6 in the [Rainbow](https://arxiv.org/abs/1710.02298) paper. + +# BCQ + +TODO: after the `done` issue is fixed, the results should be re-tuned and placed here. + +To run the BCQ algorithm on Atari, you need to do the following: + +- Train an expert by using the command listed in the DQN section above; +- Generate a buffer with noise: `python3 atari_dqn.py --task {your_task} --watch --resume-path log/{your_task}/dqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M-transition Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error); +- Train BCQ: `python3 atari_bcq.py --task {your_task} --load-buffer-name expert.hdf5`. 
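Before launching the long BCQ run, a quick sanity check that the generated buffer loads correctly can save time. The snippet below is a minimal sketch, not part of this patch; it only uses `ReplayBuffer.load_hdf5`, the same call `atari_bcq.py` makes, and assumes the buffer was saved as `expert.hdf5`:

```python
from tianshou.data import ReplayBuffer

# Load the buffer written by `atari_dqn.py --watch --save-buffer-name expert.hdf5`.
buffer = ReplayBuffer.load_hdf5("expert.hdf5")
print(f"loaded {len(buffer)} transitions")  # should be close to --buffer-size
# Observations are single frames here, since the buffer was built with
# save_only_last_obs=True and stack_num=--frames_stack.
print("obs:", buffer.obs.dtype, buffer.obs.shape)
```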
+ diff --git a/examples/atari/atari_bcq.py b/examples/atari/atari_bcq.py new file mode 100644 index 0000000..e9edb83 --- /dev/null +++ b/examples/atari/atari_bcq.py @@ -0,0 +1,153 @@ +import os +import torch +import pickle +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.env import SubprocVectorEnv +from tianshou.trainer import offline_trainer +from tianshou.utils.net.discrete import Actor +from tianshou.policy import DiscreteBCQPolicy +from tianshou.data import Collector, ReplayBuffer + +from atari_network import DQN +from atari_wrapper import wrap_deepmind + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="PongNoFrameskip-v4") + parser.add_argument("--seed", type=int, default=1626) + parser.add_argument("--eps-test", type=float, default=0.001) + parser.add_argument("--lr", type=float, default=6.25e-5) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=8000) + parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) + parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) + parser.add_argument("--epoch", type=int, default=100) + parser.add_argument("--step-per-epoch", type=int, default=10000) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument('--hidden-sizes', type=int, + nargs='*', default=[512]) + parser.add_argument("--test-num", type=int, default=100) + parser.add_argument('--frames_stack', type=int, default=4) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) 
+ parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--watch", default=False, action="store_true", + help="watch the play of pre-trained policy only") + parser.add_argument("--log-interval", type=int, default=1000) + parser.add_argument( + "--load-buffer-name", type=str, + default="./expert_DQN_PongNoFrameskip-v4.hdf5", + ) + parser.add_argument( + "--device", type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + args = parser.parse_known_args()[0] + return args + + +def make_atari_env(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack) + + +def make_atari_env_watch(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack, + episode_life=False, clip_rewards=False) + + +def test_discrete_bcq(args=get_args()): + # envs + env = make_atari_env(args) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # make environments + test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) + for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + test_envs.seed(args.seed) + # model + feature_net = DQN(*args.state_shape, args.action_shape, + device=args.device, features_only=True).to(args.device) + policy_net = Actor(feature_net, args.action_shape, + hidden_sizes=args.hidden_sizes).to(args.device) + imitation_net = Actor(feature_net, args.action_shape, + hidden_sizes=args.hidden_sizes).to(args.device) + optim = torch.optim.Adam( + set(policy_net.parameters()).union(imitation_net.parameters()), + lr=args.lr, + ) + # define policy + policy = DiscreteBCQPolicy( + policy_net, imitation_net, optim, args.gamma, args.n_step, + args.target_update_freq, args.eps_test, + args.unlikely_action_threshold, args.imitation_logits_penalty, + ) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load( + args.resume_path, map_location=args.device + )) + print("Loaded agent from: ", args.resume_path) + # buffer + assert os.path.exists(args.load_buffer_name), \ + "Please run atari_dqn.py first to get expert's data buffer." 
+ if args.load_buffer_name.endswith('.pkl'): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + elif args.load_buffer_name.endswith('.hdf5'): + buffer = ReplayBuffer.load_hdf5(args.load_buffer_name) + else: + print(f"Unknown buffer format: {args.load_buffer_name}") + exit(0) + + # collector + test_collector = Collector(policy, test_envs) + + log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') + writer = SummaryWriter(log_path) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return False + + # watch agent's performance + def watch(): + print("Setup test envs ...") + policy.eval() + policy.set_eps(args.eps_test) + test_envs.seed(args.seed) + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) + pprint.pprint(result) + + if args.watch: + watch() + exit(0) + + result = offline_trainer( + policy, buffer, test_collector, + args.epoch, args.step_per_epoch, args.test_num, args.batch_size, + stop_fn=stop_fn, save_fn=save_fn, writer=writer, + log_interval=args.log_interval, + ) + + pprint.pprint(result) + watch() + + +if __name__ == "__main__": + test_discrete_bcq(get_args()) diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index cf43189..4f4c4f2 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -41,6 +41,7 @@ def get_args(): parser.add_argument('--resume-path', type=str, default=None) parser.add_argument('--watch', default=False, action='store_true', help='watch the play of pre-trained policy only') + parser.add_argument('--save-buffer-name', type=str, default=None) return parser.parse_args() @@ -120,13 +121,25 @@ def test_dqn(args=get_args()): # watch agent's performance def watch(): - print("Testing agent ...") + print("Setup test envs ...") policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) - test_collector.reset() - result = test_collector.collect(n_episode=[1] * args.test_num, - render=args.render) + if args.save_buffer_name: + print(f"Generate buffer with size {args.buffer_size}") + buffer = ReplayBuffer( + args.buffer_size, ignore_obs_next=True, + save_only_last_obs=True, stack_num=args.frames_stack) + collector = Collector(policy, test_envs, buffer) + result = collector.collect(n_step=args.buffer_size) + print(f"Save buffer into {args.save_buffer_name}") + # Unfortunately, pickle will cause oom with 1M buffer size + buffer.save_hdf5(args.save_buffer_name) + else: + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=[1] * args.test_num, + render=args.render) pprint.pprint(result) if args.watch: diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 0b7adaa..c31a6c8 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -44,8 +44,7 @@ class DQN(nn.Module): info: Dict[str, Any] = {}, ) -> Tuple[torch.Tensor, Any]: r"""Mapping: x -> Q(x, \*).""" - x = torch.as_tensor( - x, device=self.device, dtype=torch.float32) # type: ignore + x = torch.as_tensor(x, device=self.device, dtype=torch.float32) return self.net(x), state diff --git a/setup.py b/setup.py index 5af0b20..4b22472 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ setup( "tensorboard", "torch>=1.4.0", "numba>=0.51.0", - "h5py>=3.1.0" + "h5py>=2.10.0", # to match tensorflow's minimal requirements ], extras_require={ "dev": [ diff --git a/test/discrete/test_dqn.py 
b/test/discrete/test_dqn.py index 541f249..1e9f089 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,6 +1,7 @@ import os import gym import torch +import pickle import pprint import argparse import numpy as np @@ -38,6 +39,9 @@ def get_args(): action="store_true", default=False) parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--beta', type=float, default=0.4) + parser.add_argument( + '--save-buffer-name', type=str, + default="./expert_DQN_CartPole-v0.pkl") parser.add_argument( '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') @@ -114,6 +118,7 @@ def test_dqn(args=get_args()): stop_fn=stop_fn, save_fn=save_fn, writer=writer) assert stop_fn(result['best_reward']) + if __name__ == '__main__': pprint.pprint(result) # Let's watch its performance! @@ -124,6 +129,12 @@ def test_dqn(args=get_args()): result = collector.collect(n_episode=1, render=args.render) print(f'Final reward: {result["rew"]}, length: {result["len"]}') + # save buffer in pickle format, for imitation learning unittest + buf = ReplayBuffer(args.buffer_size) + collector = Collector(policy, test_envs, buf) + collector.collect(n_step=args.buffer_size) + pickle.dump(buf, open(args.save_buffer_name, "wb")) + def test_pdqn(args=get_args()): args.prioritized_replay = True diff --git a/test/discrete/test_il_bcq.py b/test/discrete/test_il_bcq.py new file mode 100644 index 0000000..e9e857e --- /dev/null +++ b/test/discrete/test_il_bcq.py @@ -0,0 +1,111 @@ +import os +import gym +import torch +import pickle +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector +from tianshou.env import DummyVectorEnv +from tianshou.utils.net.common import Net +from tianshou.trainer import offline_trainer +from tianshou.policy import DiscreteBCQPolicy + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--seed", type=int, default=1626) + parser.add_argument("--eps-test", type=float, default=0.001) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--gamma", type=float, default=0.9) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--unlikely-action-threshold", type=float, default=0.3) + parser.add_argument("--imitation-logits-penalty", type=float, default=0.01) + parser.add_argument("--epoch", type=int, default=5) + parser.add_argument("--step-per-epoch", type=int, default=1000) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument('--hidden-sizes', type=int, + nargs='*', default=[128, 128, 128]) + parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) 
+ parser.add_argument( + "--load-buffer-name", type=str, + default="./expert_DQN_CartPole-v0.pkl", + ) + parser.add_argument( + "--device", type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + ) + args = parser.parse_known_args()[0] + return args + + +def test_discrete_bcq(args=get_args()): + # envs + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + test_envs = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + test_envs.seed(args.seed) + # model + policy_net = Net( + args.state_shape, args.action_shape, + hidden_sizes=args.hidden_sizes, device=args.device).to(args.device) + imitation_net = Net( + args.state_shape, args.action_shape, + hidden_sizes=args.hidden_sizes, device=args.device).to(args.device) + optim = torch.optim.Adam( + set(policy_net.parameters()).union(imitation_net.parameters()), + lr=args.lr) + + policy = DiscreteBCQPolicy( + policy_net, imitation_net, optim, args.gamma, args.n_step, + args.target_update_freq, args.eps_test, + args.unlikely_action_threshold, args.imitation_logits_penalty, + ) + # buffer + assert os.path.exists(args.load_buffer_name), \ + "Please run test_dqn.py first to get expert's data buffer." + buffer = pickle.load(open(args.load_buffer_name, "rb")) + + # collector + test_collector = Collector(policy, test_envs) + + log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') + writer = SummaryWriter(log_path) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + result = offline_trainer( + policy, buffer, test_collector, + args.epoch, args.step_per_epoch, args.test_num, args.batch_size, + stop_fn=stop_fn, save_fn=save_fn, writer=writer) + + assert stop_fn(result['best_reward']) + + if __name__ == '__main__': + pprint.pprint(result) + # Let's watch its performance! 
+ env = gym.make(args.task) + policy.eval() + policy.set_eps(args.eps_test) + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + + +if __name__ == "__main__": + test_discrete_bcq(get_args()) diff --git a/tianshou/__init__.py b/tianshou/__init__.py index cc37f27..0f7a59c 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,7 +1,7 @@ from tianshou import data, env, utils, policy, trainer, exploration -__version__ = "0.3.0" +__version__ = "0.3.1" __all__ = [ "env", diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 9626d97..968aaf6 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -10,6 +10,7 @@ from tianshou.policy.modelfree.ppo import PPOPolicy from tianshou.policy.modelfree.td3 import TD3Policy from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy +from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.modelbase.psrl import PSRLPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager @@ -27,6 +28,7 @@ __all__ = [ "TD3Policy", "SACPolicy", "DiscreteSACPolicy", + "DiscreteBCQPolicy", "PSRLPolicy", "MultiAgentPolicyManager", ] diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index d65cbc8..954bc81 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -57,7 +57,7 @@ class ImitationPolicy(BasePolicy): a_ = to_torch(batch.act, dtype=torch.float32, device=a.device) loss = F.mse_loss(a, a_) # type: ignore elif self.mode == "discrete": # classification - a = self(batch).logits + a = F.log_softmax(self(batch).logits, dim=-1) a_ = to_torch(batch.act, dtype=torch.long, device=a.device) loss = F.nll_loss(a, a_) # type: ignore loss.backward() diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py new file mode 100644 index 0000000..faae34a --- /dev/null +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -0,0 +1,139 @@ +import math +import torch +import numpy as np +import torch.nn.functional as F +from typing import Any, Dict, Union, Optional + +from tianshou.policy import DQNPolicy +from tianshou.data import Batch, ReplayBuffer, to_torch + + +class DiscreteBCQPolicy(DQNPolicy): + """Implementation of discrete BCQ algorithm. arXiv:1910.01708. + + :param torch.nn.Module model: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> q_value) + :param torch.nn.Module imitator: a model following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) + :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. + :param float discount_factor: in [0, 1]. + :param int estimation_step: greater than 1, the number of steps to look + ahead. + :param int target_update_freq: the target network update frequency. + :param float eval_eps: the epsilon-greedy noise added in evaluation. + :param float unlikely_action_threshold: the threshold (tau) for unlikely + actions, as shown in Eq. (17) in the paper, defaults to 0.3. + :param float imitation_logits_penalty: regularization weight for imitation + logits, defaults to 1e-2. + :param bool reward_normalization: normalize the reward to Normal(0, 1), + defaults to False. + + .. seealso:: + + Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed + explanation. 
+ """ + + def __init__( + self, + model: torch.nn.Module, + imitator: torch.nn.Module, + optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + estimation_step: int = 1, + target_update_freq: int = 8000, + eval_eps: float = 1e-3, + unlikely_action_threshold: float = 0.3, + imitation_logits_penalty: float = 1e-2, + reward_normalization: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(model, optim, discount_factor, estimation_step, + target_update_freq, reward_normalization, **kwargs) + assert target_update_freq > 0, "BCQ needs target network setting." + self.imitator = imitator + assert ( + 0.0 <= unlikely_action_threshold < 1.0 + ), "unlikely_action_threshold should be in [0, 1)" + if unlikely_action_threshold > 0: + self._log_tau = math.log(unlikely_action_threshold) + else: + self._log_tau = -np.inf + assert 0.0 <= eval_eps < 1.0 + self._eps = eval_eps + self._weight_reg = imitation_logits_penalty + + def train(self, mode: bool = True) -> "DiscreteBCQPolicy": + self.training = mode + self.model.train(mode) + self.imitator.train(mode) + return self + + def _target_q( + self, buffer: ReplayBuffer, indice: np.ndarray + ) -> torch.Tensor: + batch = buffer[indice] # batch.obs_next: s_{t+n} + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + with torch.no_grad(): + act = self(batch, input="obs_next", eps=0.0).act + target_q, _ = self.model_old(batch.obs_next) + target_q = target_q[np.arange(len(act)), act] + return target_q + + def forward( # type: ignore + self, + batch: Batch, + state: Optional[Union[dict, Batch, np.ndarray]] = None, + input: str = "obs", + eps: Optional[float] = None, + **kwargs: Any, + ) -> Batch: + if eps is None: + eps = self._eps + obs = batch[input] + q_value, state = self.model(obs, state=state, info=batch.info) + imitation_logits, _ = self.imitator(obs, state=state, info=batch.info) + + # mask actions for argmax + ratio = imitation_logits - imitation_logits.max( + dim=-1, keepdim=True).values + mask = (ratio < self._log_tau).float() + action = (q_value - np.inf * mask).argmax(dim=-1) + + # add eps to act + if not np.isclose(eps, 0.0): + bsz, action_num = q_value.shape + mask = np.random.rand(bsz) < eps + action_rand = torch.randint( + action_num, size=[bsz], device=action.device) + action[mask] = action_rand[mask] + + return Batch(act=action, state=state, q_value=q_value, + imitation_logits=imitation_logits) + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: + if self._iter % self._freq == 0: + self.sync_weight() + self._iter += 1 + + target_q = batch.returns.flatten() + result = self(batch, eps=0.0) + imitation_logits = result.imitation_logits + current_q = result.q_value[np.arange(len(target_q)), batch.act] + act = to_torch(batch.act, dtype=torch.long, device=target_q.device) + q_loss = F.smooth_l1_loss(current_q, target_q) + i_loss = F.nll_loss( + F.log_softmax(imitation_logits, dim=-1), act) # type: ignore + reg_loss = imitation_logits.pow(2).mean() + loss = q_loss + i_loss + self._weight_reg * reg_loss + + self.optim.zero_grad() + loss.backward() + self.optim.step() + + return { + "loss": loss.item(), + "q_loss": q_loss.item(), + "i_loss": i_loss.item(), + "reg_loss": reg_loss.item(), + } diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index df34e01..e215c6b 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -58,8 +58,8 @@ class A2CPolicy(PGPolicy): self.critic = critic assert 0.0 <= gae_lambda <= 1.0, "GAE lambda should be in [0, 1]." 
self._lambda = gae_lambda - self._w_vf = vf_coef - self._w_ent = ent_coef + self._weight_vf = vf_coef + self._weight_ent = ent_coef self._grad_norm = max_grad_norm self._batch = max_batchsize self._rew_norm = reward_normalization @@ -122,7 +122,8 @@ class A2CPolicy(PGPolicy): a_loss = -(log_prob * (r - v).detach()).mean() vf_loss = F.mse_loss(r, v) # type: ignore ent_loss = dist.entropy().mean() - loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss + loss = a_loss + self._weight_vf * vf_loss - \ + self._weight_ent * ent_loss loss.backward() if self._grad_norm is not None: nn.utils.clip_grad_norm_( diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index 96cc680..706872e 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -124,7 +124,7 @@ class C51Policy(DQNPolicy): return target_dist.sum(-1) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: - if self._target and self._cnt % self._freq == 0: + if self._target and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() with torch.no_grad(): @@ -139,5 +139,5 @@ class C51Policy(DQNPolicy): batch.weight = cross_entropy.detach() # prio-buffer loss.backward() self.optim.step() - self._cnt += 1 + self._iter += 1 return {"loss": loss.item()} diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 91cca61..a8f705b 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -54,7 +54,7 @@ class DQNPolicy(BasePolicy): self._n_step = estimation_step self._target = target_update_freq > 0 self._freq = target_update_freq - self._cnt = 0 + self._iter = 0 if self._target: self.model_old = deepcopy(self.model) self.model_old.eval() @@ -78,16 +78,15 @@ class DQNPolicy(BasePolicy): self, buffer: ReplayBuffer, indice: np.ndarray ) -> torch.Tensor: batch = buffer[indice] # batch.obs_next: s_{t+n} - if self._target: - # target_Q = Q_old(s_, argmax(Q_new(s_, *))) - a = self(batch, input="obs_next").act - with torch.no_grad(): + # target_Q = Q_old(s_, argmax(Q_new(s_, *))) + with torch.no_grad(): + if self._target: + a = self(batch, input="obs_next").act target_q = self( batch, model="model_old", input="obs_next" ).logits - target_q = target_q[np.arange(len(a)), a] - else: - with torch.no_grad(): + target_q = target_q[np.arange(len(a)), a] + else: target_q = self(batch, input="obs_next").logits.max(dim=1)[0] return target_q @@ -162,7 +161,7 @@ class DQNPolicy(BasePolicy): return Batch(logits=q, act=act, state=h) def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: - if self._target and self._cnt % self._freq == 0: + if self._target and self._iter % self._freq == 0: self.sync_weight() self.optim.zero_grad() weight = batch.pop("weight", 1.0) @@ -174,5 +173,5 @@ class DQNPolicy(BasePolicy): batch.weight = td # prio-buffer loss.backward() self.optim.step() - self._cnt += 1 + self._iter += 1 return {"loss": loss.item()} diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 5a04ec6..4cf9f90 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -68,8 +68,8 @@ class PPOPolicy(PGPolicy): super().__init__(None, optim, dist_fn, discount_factor, **kwargs) self._max_grad_norm = max_grad_norm self._eps_clip = eps_clip - self._w_vf = vf_coef - self._w_ent = ent_coef + self._weight_vf = vf_coef + self._weight_ent = ent_coef self._range = action_range self.actor = actor self.critic = critic @@ -174,7 +174,8 @@ class PPOPolicy(PGPolicy): 
vf_losses.append(vf_loss.item()) e_loss = dist.entropy().mean() ent_losses.append(e_loss.item()) - loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss + loss = clip_loss + self._weight_vf * vf_loss - \ + self._weight_ent * e_loss losses.append(loss.item()) self.optim.zero_grad() loss.backward() diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index 014fbe6..983df71 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -143,7 +143,6 @@ class SACPolicy(DDPGPolicy): with torch.no_grad(): obs_next_result = self(batch, input='obs_next') a_ = obs_next_result.act - batch.act = to_torch_as(batch.act, a_) target_q = torch.min( self.critic1_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, a_), diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py index 36a8ed4..22fc1ee 100644 --- a/tianshou/trainer/__init__.py +++ b/tianshou/trainer/__init__.py @@ -1,10 +1,12 @@ from tianshou.trainer.utils import test_episode, gather_info from tianshou.trainer.onpolicy import onpolicy_trainer from tianshou.trainer.offpolicy import offpolicy_trainer +from tianshou.trainer.offline import offline_trainer __all__ = [ "gather_info", "test_episode", "onpolicy_trainer", "offpolicy_trainer", + "offline_trainer", ] diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py new file mode 100644 index 0000000..e693641 --- /dev/null +++ b/tianshou/trainer/offline.py @@ -0,0 +1,97 @@ +import time +import tqdm +from collections import defaultdict +from torch.utils.tensorboard import SummaryWriter +from typing import Dict, List, Union, Callable, Optional + +from tianshou.policy import BasePolicy +from tianshou.utils import tqdm_config, MovAvg +from tianshou.data import Collector, ReplayBuffer +from tianshou.trainer import test_episode, gather_info + + +def offline_trainer( + policy: BasePolicy, + buffer: ReplayBuffer, + test_collector: Collector, + max_epoch: int, + step_per_epoch: int, + episode_per_test: Union[int, List[int]], + batch_size: int, + test_fn: Optional[Callable[[int, Optional[int]], None]] = None, + stop_fn: Optional[Callable[[float], bool]] = None, + save_fn: Optional[Callable[[BasePolicy], None]] = None, + writer: Optional[SummaryWriter] = None, + log_interval: int = 1, + verbose: bool = True, +) -> Dict[str, Union[float, str]]: + """A wrapper for offline trainer procedure. + + The "step" in trainer means a policy network update. + + :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` + class. + :param test_collector: the collector used for testing. + :type test_collector: :class:`~tianshou.data.Collector` + :param int max_epoch: the maximum number of epochs for training. The + training process might be finished before reaching the ``max_epoch``. + :param int step_per_epoch: the number of policy network updates, so-called + gradient steps, per epoch. + :param episode_per_test: the number of episodes for one policy evaluation. + :param int batch_size: the batch size of sample data, which is going to + feed in the policy network. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature ``f(policy: + BasePolicy) -> None``. 
+ :param function stop_fn: a function with signature ``f(mean_rewards: float) + -> bool``, receives the average undiscounted returns of the testing + result, returns a boolean which indicates whether reaching the goal. + :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard + SummaryWriter; if None is given, it will not write logs to TensorBoard. + :param int log_interval: the log interval of the writer. + :param bool verbose: whether to print the information. + + :return: See :func:`~tianshou.trainer.gather_info`. + """ + gradient_step = 0 + best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 + stat: Dict[str, MovAvg] = defaultdict(MovAvg) + start_time = time.time() + test_collector.reset_stat() + + for epoch in range(1, 1 + max_epoch): + policy.train() + with tqdm.trange( + step_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config + ) as t: + for i in t: + gradient_step += 1 + losses = policy.update(batch_size, buffer) + data = {"gradient_step": str(gradient_step)} + for k in losses.keys(): + stat[k].add(losses[k]) + data[k] = f"{stat[k].get():.6f}" + if writer and gradient_step % log_interval == 0: + writer.add_scalar( + "train/" + k, stat[k].get(), + global_step=gradient_step) + t.set_postfix(**data) + # test + result = test_episode(policy, test_collector, test_fn, epoch, + episode_per_test, writer, gradient_step) + if best_epoch == -1 or best_reward < result["rew"]: + best_reward, best_reward_std = result["rew"], result["rew_std"] + best_epoch = epoch + if save_fn: + save_fn(policy) + if verbose: + print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± " + f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± " + f"{best_reward_std:.6f} in #{best_epoch}") + if stop_fn and stop_fn(best_reward): + break + return gather_info(start_time, None, test_collector, + best_reward, best_reward_std) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index fb4b6f4..f34f5b2 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -1,5 +1,6 @@ import time import tqdm +from collections import defaultdict from torch.utils.tensorboard import SummaryWriter from typing import Dict, List, Union, Callable, Optional @@ -38,10 +39,10 @@ def offpolicy_trainer( :type train_collector: :class:`~tianshou.data.Collector` :param test_collector: the collector used for testing. :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum of epochs for training. The training - process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of step for updating policy network - in one epoch. + :param int max_epoch: the maximum number of epochs for training. The + training process might be finished before reaching the ``max_epoch``. + :param int step_per_epoch: the number of policy network updates, so-called + gradient steps, per epoch. :param int collect_per_step: the number of frames the collector would collect before the network update. In other words, collect some frames and do some policy network update. @@ -52,19 +53,20 @@ def offpolicy_trainer( be updated after frames are collected, for example, set it to 256 means it updates policy 256 times once after ``collect_per_step`` frames are collected. - :param function train_fn: a function receives the current number of epoch - and step index, and performs some operations at the beginning of - training in this epoch. 
- :param function test_fn: a function receives the current number of epoch - and step index, and performs some operations at the beginning of - testing in this epoch. - :param function save_fn: a function for saving policy when the undiscounted - average mean reward in evaluation phase gets better. - :param function stop_fn: a function receives the average undiscounted - returns of the testing result, return a boolean which indicates whether - reaching the goal. + :param function train_fn: a hook called at the beginning of training in + each epoch. It can be used to perform custom additional operations, + with the signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature ``f(policy: + BasePolicy) -> None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) + -> bool``, receives the average undiscounted returns of the testing + result, returns a boolean which indicates whether reaching the goal. :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter. + SummaryWriter; if None is given, it will not write logs to TensorBoard. :param int log_interval: the log interval of the writer. :param bool verbose: whether to print the information. :param bool test_in_train: whether to test in the training phase. @@ -73,7 +75,7 @@ def offpolicy_trainer( """ env_step, gradient_step = 0, 0 best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 - stat: Dict[str, MovAvg] = {} + stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() @@ -122,8 +124,6 @@ def offpolicy_trainer( gradient_step += 1 losses = policy.update(batch_size, train_collector.buffer) for k in losses.keys(): - if stat.get(k) is None: - stat[k] = MovAvg() stat[k].add(losses[k]) data[k] = f"{stat[k].get():.6f}" if writer and gradient_step % log_interval == 0: diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index 5aff68b..f094ddd 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -1,5 +1,6 @@ import time import tqdm +from collections import defaultdict from torch.utils.tensorboard import SummaryWriter from typing import Dict, List, Union, Callable, Optional @@ -38,10 +39,10 @@ def onpolicy_trainer( :type train_collector: :class:`~tianshou.data.Collector` :param test_collector: the collector used for testing. :type test_collector: :class:`~tianshou.data.Collector` - :param int max_epoch: the maximum of epochs for training. The training - process might be finished before reaching the ``max_epoch``. - :param int step_per_epoch: the number of step for updating policy network - in one epoch. + :param int max_epoch: the maximum number of epochs for training. The + training process might be finished before reaching the ``max_epoch``. + :param int step_per_epoch: the number of policy network updates, so-called + gradient steps, per epoch. :param int collect_per_step: the number of episodes the collector would collect before the network update. In other words, collect some episodes and do one policy network update. 
@@ -52,19 +53,20 @@ def onpolicy_trainer( :type episode_per_test: int or list of ints :param int batch_size: the batch size of sample data, which is going to feed in the policy network. - :param function train_fn: a function receives the current number of epoch - and step index, and performs some operations at the beginning of - training in this poch. - :param function test_fn: a function receives the current number of epoch - and step index, and performs some operations at the beginning of - testing in this epoch. - :param function save_fn: a function for saving policy when the undiscounted - average mean reward in evaluation phase gets better. - :param function stop_fn: a function receives the average undiscounted - returns of the testing result, return a boolean which indicates whether - reaching the goal. + :param function train_fn: a hook called at the beginning of training in + each epoch. It can be used to perform custom additional operations, + with the signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function test_fn: a hook called at the beginning of testing in each + epoch. It can be used to perform custom additional operations, with the + signature ``f(num_epoch: int, step_idx: int) -> None``. + :param function save_fn: a hook called when the undiscounted average mean + reward in evaluation phase gets better, with the signature ``f(policy: + BasePolicy) -> None``. + :param function stop_fn: a function with signature ``f(mean_rewards: float) + -> bool``, receives the average undiscounted returns of the testing + result, returns a boolean which indicates whether reaching the goal. :param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard - SummaryWriter. + SummaryWriter; if None is given, it will not write logs to TensorBoard. :param int log_interval: the log interval of the writer. :param bool verbose: whether to print the information. :param bool test_in_train: whether to test in the training phase. @@ -73,7 +75,7 @@ def onpolicy_trainer( """ env_step, gradient_step = 0, 0 best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0 - stat: Dict[str, MovAvg] = {} + stat: Dict[str, MovAvg] = defaultdict(MovAvg) start_time = time.time() train_collector.reset_stat() test_collector.reset_stat() @@ -125,8 +127,6 @@ def onpolicy_trainer( len(v) for v in losses.values() if isinstance(v, list)]) gradient_step += step for k in losses.keys(): - if stat.get(k) is None: - stat[k] = MovAvg() stat[k].add(losses[k]) data[k] = f"{stat[k].get():.6f}" if writer and gradient_step % log_interval == 0: diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index da9dea8..dfffd71 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -36,7 +36,7 @@ def test_episode( def gather_info( start_time: float, - train_c: Collector, + train_c: Optional[Collector], test_c: Collector, best_reward: float, best_reward_std: float, @@ -59,15 +59,9 @@ def gather_info( * ``duration`` the total elapsed time. 
""" duration = time.time() - start_time - model_time = duration - train_c.collect_time - test_c.collect_time - train_speed = train_c.collect_step / (duration - test_c.collect_time) + model_time = duration - test_c.collect_time test_speed = test_c.collect_step / test_c.collect_time - return { - "train_step": train_c.collect_step, - "train_episode": train_c.collect_episode, - "train_time/collector": f"{train_c.collect_time:.2f}s", - "train_time/model": f"{model_time:.2f}s", - "train_speed": f"{train_speed:.2f} step/s", + result: Dict[str, Union[float, str]] = { "test_step": test_c.collect_step, "test_episode": test_c.collect_episode, "test_time": f"{test_c.collect_time:.2f}s", @@ -75,4 +69,16 @@ def gather_info( "best_reward": best_reward, "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}", "duration": f"{duration:.2f}s", + "train_time/model": f"{model_time:.2f}s", } + if train_c is not None: + model_time -= train_c.collect_time + train_speed = train_c.collect_step / (duration - test_c.collect_time) + result.update({ + "train_step": train_c.collect_step, + "train_episode": train_c.collect_episode, + "train_time/collector": f"{train_c.collect_time:.2f}s", + "train_time/model": f"{model_time:.2f}s", + "train_speed": f"{train_speed:.2f} step/s", + }) + return result