diff --git a/README.md b/README.md
index 41935f1..4c2a0af 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,7 @@
 - Vanilla Imitation Learning
 - [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
 - [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
+- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
 - [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
 - [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
 - [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst
index aa17040..32f124f 100644
--- a/docs/api/tianshou.policy.rst
+++ b/docs/api/tianshou.policy.rst
@@ -104,6 +104,11 @@ Imitation
    :undoc-members:
    :show-inheritance:
 
+.. autoclass:: tianshou.policy.DiscreteCRRPolicy
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 Model-based
 -----------
 
diff --git a/docs/index.rst b/docs/index.rst
index 72f479d..c0a3c02 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -26,6 +26,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
 * :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
 * :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
+* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
 * :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
 * :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
 * :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
diff --git a/examples/atari/README.md b/examples/atari/README.md
index eb49a02..fb45858 100644
--- a/examples/atari/README.md
+++ b/examples/atari/README.md
@@ -99,3 +99,20 @@ Buffer size 10000:
 | ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
 | PongNoFrameskip-v4     | 20.5       | nan        | 1.8 (epoch 5)                     | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
 | BreakoutNoFrameskip-v4 | 394.3      | 31.7       | 22.5 (epoch 12)                   | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
+
+# CRR
+
+To run the CRR algorithm on Atari, you need to do the following things:
+
+- Train an expert, using the command listed in the QRDQN section above;
+- Generate a buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
+- Train CRR: `python3 atari_crr.py --task {your_task} --load-buffer-name expert.hdf5` (see the buffer-loading sketch below).
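+
+As a quick sanity check, you can load the saved buffer back before training (a minimal sketch; it assumes the buffer was saved as `expert.hdf5` by the command above, and simply mirrors what `atari_crr.py` does internally):
+
+```python
+from tianshou.data import VectorReplayBuffer
+
+# load the noisy expert buffer written by atari_qrdqn.py --watch
+buffer = VectorReplayBuffer.load_hdf5("expert.hdf5")
+print(len(buffer))  # number of stored transitions (at most --buffer-size)
+```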
+
+We test our CRR implementation on two example tasks (unlike the authors' version, we use v4 instead of v0; one epoch means 10k gradient steps):
+
+| Task                   | Online QRDQN | Behavioral | CRR              | CRR w/ CQL        | parameters                                                   |
+| ---------------------- | ------------ | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ |
+| PongNoFrameskip-v4     | 20.5         | 6.8        | -21 (epoch 5)    | 16.1 (epoch 5)    | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
+| BreakoutNoFrameskip-v4 | 394.3        | 46.9       | 26.4 (epoch 12)  | 125.0 (epoch 12)  | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
+
+Note that CRR itself does not work well on Atari tasks, but adding the CQL loss/regularizer helps.
diff --git a/examples/atari/atari_crr.py b/examples/atari/atari_crr.py
new file mode 100644
index 0000000..6bd9167
--- /dev/null
+++ b/examples/atari/atari_crr.py
@@ -0,0 +1,155 @@
+import os
+import torch
+import pickle
+import pprint
+import datetime
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.utils import BasicLogger
+from tianshou.env import SubprocVectorEnv
+from tianshou.trainer import offline_trainer
+from tianshou.utils.net.discrete import Actor
+from tianshou.policy import DiscreteCRRPolicy
+from tianshou.data import Collector, VectorReplayBuffer
+
+from atari_network import DQN
+from atari_wrapper import wrap_deepmind
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
+    parser.add_argument("--seed", type=int, default=1626)
+    parser.add_argument("--lr", type=float, default=0.0001)
+    parser.add_argument("--gamma", type=float, default=0.99)
+    parser.add_argument("--policy-improvement-mode", type=str, default="exp")
+    parser.add_argument("--ratio-upper-bound", type=float, default=20.)
+    parser.add_argument("--beta", type=float, default=1.)
+    parser.add_argument("--min-q-weight", type=float, default=10.)
+    parser.add_argument("--target-update-freq", type=int, default=500)
+    parser.add_argument("--epoch", type=int, default=100)
+    parser.add_argument("--update-per-epoch", type=int, default=10000)
+    parser.add_argument("--batch-size", type=int, default=32)
+    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
+    parser.add_argument("--test-num", type=int, default=10)
+    parser.add_argument('--frames-stack', type=int, default=4)
+    parser.add_argument("--logdir", type=str, default="log")
+    parser.add_argument("--render", type=float, default=0.)
+ parser.add_argument("--resume-path", type=str, default=None) + parser.add_argument("--watch", default=False, action="store_true", + help="watch the play of pre-trained policy only") + parser.add_argument("--log-interval", type=int, default=100) + parser.add_argument( + "--load-buffer-name", type=str, + default="./expert_DQN_PongNoFrameskip-v4.hdf5") + parser.add_argument( + "--device", type=str, + default="cuda" if torch.cuda.is_available() else "cpu") + args = parser.parse_known_args()[0] + return args + + +def make_atari_env(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack) + + +def make_atari_env_watch(args): + return wrap_deepmind(args.task, frame_stack=args.frames_stack, + episode_life=False, clip_rewards=False) + + +def test_discrete_crr(args=get_args()): + # envs + env = make_atari_env(args) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + # should be N_FRAMES x H x W + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + # make environments + test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args) + for _ in range(args.test_num)]) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + test_envs.seed(args.seed) + # model + feature_net = DQN(*args.state_shape, args.action_shape, + device=args.device, features_only=True).to(args.device) + actor = Actor(feature_net, args.action_shape, device=args.device, + hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device) + critic = DQN(*args.state_shape, args.action_shape, + device=args.device).to(args.device) + optim = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()), + lr=args.lr) + # define policy + policy = DiscreteCRRPolicy( + actor, critic, optim, args.gamma, + policy_improvement_mode=args.policy_improvement_mode, + ratio_upper_bound=args.ratio_upper_bound, beta=args.beta, + min_q_weight=args.min_q_weight, + target_update_freq=args.target_update_freq + ).to(args.device) + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load( + args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + # buffer + assert os.path.exists(args.load_buffer_name), \ + "Please run atari_qrdqn.py first to get expert's data buffer." 
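+    # the buffer may be either a pickled ReplayBuffer or an HDF5 file
+    # produced by `atari_qrdqn.py --watch ... --save-buffer-name expert.hdf5`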
+ if args.load_buffer_name.endswith('.pkl'): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + elif args.load_buffer_name.endswith('.hdf5'): + buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name) + else: + print(f"Unknown buffer format: {args.load_buffer_name}") + exit(0) + + # collector + test_collector = Collector(policy, test_envs, exploration_noise=True) + + # log + log_path = os.path.join( + args.logdir, args.task, 'crr', + f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}') + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer, update_interval=args.log_interval) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return False + + # watch agent's performance + def watch(): + print("Setup test envs ...") + policy.eval() + test_envs.seed(args.seed) + print("Testing agent ...") + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, + render=args.render) + pprint.pprint(result) + rew = result["rews"].mean() + print(f'Mean reward (over {result["n/ep"]} episodes): {rew}') + + if args.watch: + watch() + exit(0) + + result = offline_trainer( + policy, buffer, test_collector, args.epoch, + args.update_per_epoch, args.test_num, args.batch_size, + stop_fn=stop_fn, save_fn=save_fn, logger=logger) + + pprint.pprint(result) + watch() + + +if __name__ == "__main__": + test_discrete_crr(get_args()) diff --git a/test/discrete/test_il_crr.py b/test/discrete/test_il_crr.py new file mode 100644 index 0000000..736edef --- /dev/null +++ b/test/discrete/test_il_crr.py @@ -0,0 +1,110 @@ +import os +import gym +import torch +import pickle +import pprint +import argparse +import numpy as np +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector +from tianshou.utils import BasicLogger +from tianshou.env import DummyVectorEnv +from tianshou.utils.net.common import Net +from tianshou.trainer import offline_trainer +from tianshou.policy import DiscreteCRRPolicy + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="CartPole-v0") + parser.add_argument("--seed", type=int, default=1626) + parser.add_argument("--lr", type=float, default=7e-4) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--n-step", type=int, default=3) + parser.add_argument("--target-update-freq", type=int, default=320) + parser.add_argument("--epoch", type=int, default=5) + parser.add_argument("--update-per-epoch", type=int, default=1000) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument('--hidden-sizes', type=int, + nargs='*', default=[64, 64]) + parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument("--render", type=float, default=0.) 
+    parser.add_argument(
+        "--load-buffer-name", type=str,
+        default="./expert_DQN_CartPole-v0.pkl",
+    )
+    parser.add_argument(
+        "--device", type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+    )
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_discrete_crr(args=get_args()):
+    # envs
+    env = gym.make(args.task)
+    if args.task == 'CartPole-v0':
+        env.spec.reward_threshold = 190  # lower the goal
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    test_envs = DummyVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    actor = Net(args.state_shape, args.action_shape,
+                hidden_sizes=args.hidden_sizes, device=args.device,
+                softmax=False)
+    critic = Net(args.state_shape, args.action_shape,
+                 hidden_sizes=args.hidden_sizes, device=args.device,
+                 softmax=False)
+    optim = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()),
+                             lr=args.lr)
+
+    policy = DiscreteCRRPolicy(
+        actor, critic, optim, args.gamma,
+        target_update_freq=args.target_update_freq,
+    ).to(args.device)
+    # buffer
+    assert os.path.exists(args.load_buffer_name), \
+        "Please run test_dqn.py first to get expert's data buffer."
+    buffer = pickle.load(open(args.load_buffer_name, "rb"))
+
+    # collector
+    test_collector = Collector(policy, test_envs, exploration_noise=True)
+
+    log_path = os.path.join(args.logdir, args.task, 'discrete_crr')
+    writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)
+
+    def save_fn(policy):
+        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+
+    def stop_fn(mean_rewards):
+        return mean_rewards >= env.spec.reward_threshold
+
+    result = offline_trainer(
+        policy, buffer, test_collector,
+        args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
+        stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+
+    assert stop_fn(result['best_reward'])
+
+    if __name__ == '__main__':
+        pprint.pprint(result)
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        policy.eval()
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=args.render)
+        rews, lens = result["rews"], result["lens"]
+        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
+
+
+if __name__ == "__main__":
+    test_discrete_crr(get_args())
diff --git a/test/discrete/test_qrdqn_il_cql.py b/test/discrete/test_qrdqn_il_cql.py
index eb7de42..7a782dd 100644
--- a/test/discrete/test_qrdqn_il_cql.py
+++ b/test/discrete/test_qrdqn_il_cql.py
@@ -71,7 +71,7 @@ def test_discrete_cql(args=get_args()):
     ).to(args.device)
     # buffer
     assert os.path.exists(args.load_buffer_name), \
-        "Please run test_dqn.py first to get expert's data buffer."
+        "Please run test_qrdqn.py first to get expert's data buffer."
buffer = pickle.load(open(args.load_buffer_name, "rb")) # collector diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index 30b1af9..54f0f4a 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -15,6 +15,7 @@ from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy +from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy from tianshou.policy.modelbased.psrl import PSRLPolicy from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager @@ -37,6 +38,7 @@ __all__ = [ "ImitationPolicy", "DiscreteBCQPolicy", "DiscreteCQLPolicy", + "DiscreteCRRPolicy", "PSRLPolicy", "MultiAgentPolicyManager", ] diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py new file mode 100644 index 0000000..f5b203f --- /dev/null +++ b/tianshou/policy/imitation/discrete_crr.py @@ -0,0 +1,123 @@ +import torch +from copy import deepcopy +from typing import Any, Dict +import torch.nn.functional as F +from torch.distributions import Categorical + +from tianshou.policy.modelfree.pg import PGPolicy +from tianshou.data import Batch, to_torch, to_torch_as + + +class DiscreteCRRPolicy(PGPolicy): + r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134. + + :param torch.nn.Module actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param torch.nn.Module critic: the action-value critic (i.e., Q function) + network. (s -> Q(s, \*)) + :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. + :param float discount_factor: in [0, 1]. Default to 0.99. + :param str policy_improvement_mode: type of the weight function f. Possible + values: "binary"/"exp"/"all". Default to "exp". + :param float ratio_upper_bound: when policy_improvement_mode is "exp", the value + of the exp function is upper-bounded by this parameter. Default to 20. + :param float beta: when policy_improvement_mode is "exp", this is the denominator + of the exp function. Default to 1. + :param float min_q_weight: weight for CQL loss/regularizer. Default to 10. + :param int target_update_freq: the target network update frequency (0 if + you do not use the target network). Default to 0. + :param bool reward_normalization: normalize the reward to Normal(0, 1). + Default to False. + + .. seealso:: + Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed + explanation. 
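+
+    .. note::
+
+        As a rough summary of the weight used in the actor loss (see
+        ``learn()``): "binary" uses :math:`\mathbb{1}[A(s, a) > 0]`, "exp"
+        uses :math:`\min(\exp(A(s, a) / \beta), M)` with :math:`M` being
+        ``ratio_upper_bound``, and "all" uses a constant weight of 1, which
+        reduces the actor loss to behavior cloning.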
+ """ + + def __init__( + self, + actor: torch.nn.Module, + critic: torch.nn.Module, + optim: torch.optim.Optimizer, + discount_factor: float = 0.99, + policy_improvement_mode: str = "exp", + ratio_upper_bound: float = 20.0, + beta: float = 1.0, + min_q_weight: float = 10.0, + target_update_freq: int = 0, + reward_normalization: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + actor, + optim, + lambda x: Categorical(logits=x), # type: ignore + discount_factor, + reward_normalization, + **kwargs, + ) + self.critic = critic + self._target = target_update_freq > 0 + self._freq = target_update_freq + self._iter = 0 + if self._target: + self.actor_old = deepcopy(self.actor) + self.actor_old.eval() + self.critic_old = deepcopy(self.critic) + self.critic_old.eval() + else: + self.actor_old = self.actor + self.critic_old = self.critic + assert policy_improvement_mode in ["exp", "binary", "all"] + self._policy_improvement_mode = policy_improvement_mode + self._ratio_upper_bound = ratio_upper_bound + self._beta = beta + self._min_q_weight = min_q_weight + + def sync_weight(self) -> None: + self.actor_old.load_state_dict(self.actor.state_dict()) # type: ignore + self.critic_old.load_state_dict(self.critic.state_dict()) # type: ignore + + def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # type: ignore + if self._target and self._iter % self._freq == 0: + self.sync_weight() + self.optim.zero_grad() + q_t, _ = self.critic(batch.obs) + act = to_torch(batch.act, dtype=torch.long, device=q_t.device) + qa_t = q_t.gather(1, act.unsqueeze(1)) + # Critic loss + with torch.no_grad(): + target_a_t, _ = self.actor_old(batch.obs_next) + target_m = Categorical(logits=target_a_t) + q_t_target, _ = self.critic_old(batch.obs_next) + rew = to_torch_as(batch.rew, q_t_target) + expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True) + expected_target_q[batch.done > 0] = 0.0 + target = rew.unsqueeze(1) + self._gamma * expected_target_q + critic_loss = 0.5 * F.mse_loss(qa_t, target) + # Actor loss + a_t, _ = self.actor(batch.obs) + m = Categorical(logits=a_t) + expected_policy_q = (q_t * m.probs).sum(-1, keepdim=True) + advantage = qa_t - expected_policy_q + if self._policy_improvement_mode == "binary": + actor_loss_coef = (advantage > 0).float() + elif self._policy_improvement_mode == "exp": + actor_loss_coef = ( + (advantage / self._beta).exp().clamp(0, self._ratio_upper_bound) + ) + else: + actor_loss_coef = 1.0 # effectively behavior cloning + actor_loss = (-m.log_prob(act) * actor_loss_coef).mean() + # CQL loss/regularizer + min_q_loss = (q_t.logsumexp(1) - qa_t).mean() + loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss + loss.backward() + self.optim.step() + self._iter += 1 + return { + "loss": loss.item(), + "loss/actor": actor_loss.item(), + "loss/critic": critic_loss.item(), + "loss/cql": min_q_loss.item(), + }