Add discrete Critic Regularized Regression (#367)

Yi Su 2021-05-18 22:29:56 -07:00 committed by GitHub
parent b5c3ddabfa
commit 8f7bc65ac7
9 changed files with 415 additions and 1 deletion


@ -36,6 +36,7 @@
- Vanilla Imitation Learning
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)


@ -104,6 +104,11 @@ Imitation
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.DiscreteCRRPolicy
:members:
:undoc-members:
:show-inheritance:
Model-based
-----------


@ -26,6 +26,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_


@ -99,3 +99,20 @@ Buffer size 10000:
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
# CRR
To run the CRR algorithm on Atari, you need to do the following things:
- Train an expert by using the command listed in the QRDQN section above;
- Generate a buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M-transition Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train CRR: `python3 atari_crr.py --task {your_task} --load-buffer-name expert.hdf5`.
We test our CRR implementation on two example tasks (unlike the author's version, we use v4 instead of v0; one epoch corresponds to 10k gradient steps):
| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 16.1 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 26.4 (epoch 12) | 125.0 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
Note that CRR itself does not work well on Atari tasks, but adding the CQL loss/regularizer helps; the combined objective is sketched below.
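As a rough illustration of what the "CRR w/ CQL" column optimizes, the sketch below mirrors the `learn()` method of `DiscreteCRRPolicy` added by this commit: an advantage-weighted behavior-cloning term (the "exp" improvement mode), a one-step TD critic loss, and a CQL regularizer scaled by `--min-q-weight`. This is an illustrative sketch, not part of the commit; the helper name and tensor shapes are assumptions.

```python
import torch.nn.functional as F
from torch.distributions import Categorical


def crr_cql_loss(q_t, logits, act, td_target,
                 beta=1.0, ratio_upper_bound=20.0, min_q_weight=10.0):
    """q_t: Q(s, .), shape (B, A); logits: actor output, (B, A); act: (B,); td_target: (B, 1)."""
    qa_t = q_t.gather(1, act.unsqueeze(1))                           # Q(s, a), shape (B, 1)
    dist = Categorical(logits=logits)
    advantage = qa_t - (q_t * dist.probs).sum(-1, keepdim=True)      # A(s, a) = Q(s, a) - E_pi[Q(s, .)]
    weight = (advantage / beta).exp().clamp(0, ratio_upper_bound)    # CRR "exp" weight f(A)
    actor_loss = (-dist.log_prob(act).unsqueeze(1) * weight).mean()  # weighted behavior cloning
    critic_loss = 0.5 * F.mse_loss(qa_t, td_target)                  # one-step TD regression
    cql_loss = (q_t.logsumexp(1, keepdim=True) - qa_t).mean()        # push down out-of-data actions
    return actor_loss + critic_loss + min_q_weight * cql_loss
```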

examples/atari/atari_crr.py (new file, 155 lines)

@ -0,0 +1,155 @@
import os
import torch
import pickle
import pprint
import datetime
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.utils.net.discrete import Actor
from tianshou.policy import DiscreteCRRPolicy
from tianshou.data import Collector, VectorReplayBuffer
from atari_network import DQN
from atari_wrapper import wrap_deepmind
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--policy-improvement-mode", type=str, default="exp")
parser.add_argument("--ratio-upper-bound", type=float, default=20.)
parser.add_argument("--beta", type=float, default=1.)
parser.add_argument("--min-q-weight", type=float, default=10.)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--update-per-epoch", type=int, default=10000)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--watch", default=False, action="store_true",
help="watch the play of pre-trained policy only")
parser.add_argument("--log-interval", type=int, default=100)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_DQN_PongNoFrameskip-v4.hdf5")
parser.add_argument(
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu")
args = parser.parse_known_args()[0]
return args
def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)
def make_atari_env_watch(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
episode_life=False, clip_rewards=False)
def test_discrete_crr(args=get_args()):
# envs
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
feature_net = DQN(*args.state_shape, args.action_shape,
device=args.device, features_only=True).to(args.device)
actor = Actor(feature_net, args.action_shape, device=args.device,
hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
critic = DQN(*args.state_shape, args.action_shape,
device=args.device).to(args.device)
optim = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()),
lr=args.lr)
# define policy
policy = DiscreteCRRPolicy(
actor, critic, optim, args.gamma,
policy_improvement_mode=args.policy_improvement_mode,
ratio_upper_bound=args.ratio_upper_bound, beta=args.beta,
min_q_weight=args.min_q_weight,
target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run atari_qrdqn.py first to get expert's data buffer."
if args.load_buffer_name.endswith('.pkl'):
buffer = pickle.load(open(args.load_buffer_name, "rb"))
elif args.load_buffer_name.endswith('.hdf5'):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
else:
print(f"Unknown buffer format: {args.load_buffer_name}")
exit(0)
# collector
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(
args.logdir, args.task, 'crr',
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=args.log_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return False
# watch agent's performance
def watch():
print("Setup test envs ...")
policy.eval()
test_envs.seed(args.seed)
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
pprint.pprint(result)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
if args.watch:
watch()
exit(0)
result = offline_trainer(
policy, buffer, test_collector, args.epoch,
args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
pprint.pprint(result)
watch()
if __name__ == "__main__":
test_discrete_crr(get_args())


@ -0,0 +1,110 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCRRPolicy
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--lr", type=float, default=7e-4)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=320)
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--update-per-epoch", type=int, default=1000)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[64, 64])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_DQN_CartPole-v0.pkl",
)
parser.add_argument(
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
args = parser.parse_known_args()[0]
return args
def test_discrete_crr(args=get_args()):
# envs
env = gym.make(args.task)
if args.task == 'CartPole-v0':
env.spec.reward_threshold = 190 # lower the goal
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
actor = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device,
softmax=False)
critic = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device,
softmax=False)
optim = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()),
lr=args.lr)
policy = DiscreteCRRPolicy(
actor, critic, optim, args.gamma,
target_update_freq=args.target_update_freq,
).to(args.device)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run test_dqn.py first to get expert's data buffer."
buffer = pickle.load(open(args.load_buffer_name, "rb"))
# collector
test_collector = Collector(policy, test_envs, exploration_noise=True)
    log_path = os.path.join(args.logdir, args.task, 'discrete_crr')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
if __name__ == "__main__":
test_discrete_crr(get_args())


@ -71,7 +71,7 @@ def test_discrete_cql(args=get_args()):
).to(args.device)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run test_dqn.py first to get expert's data buffer."
"Please run test_qrdqn.py first to get expert's data buffer."
buffer = pickle.load(open(args.load_buffer_name, "rb"))
# collector


@ -15,6 +15,7 @@ from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.policy.imitation.base import ImitationPolicy
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
from tianshou.policy.modelbased.psrl import PSRLPolicy
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
@ -37,6 +38,7 @@ __all__ = [
"ImitationPolicy",
"DiscreteBCQPolicy",
"DiscreteCQLPolicy",
"DiscreteCRRPolicy",
"PSRLPolicy",
"MultiAgentPolicyManager",
]


@ -0,0 +1,123 @@
import torch
from copy import deepcopy
from typing import Any, Dict
import torch.nn.functional as F
from torch.distributions import Categorical
from tianshou.policy.modelfree.pg import PGPolicy
from tianshou.data import Batch, to_torch, to_torch_as
class DiscreteCRRPolicy(PGPolicy):
r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.nn.Module critic: the action-value critic (i.e., Q function)
network. (s -> Q(s, \*))
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param float discount_factor: in [0, 1]. Default to 0.99.
:param str policy_improvement_mode: type of the weight function f. Possible
values: "binary"/"exp"/"all". Default to "exp".
:param float ratio_upper_bound: when policy_improvement_mode is "exp", the value
of the exp function is upper-bounded by this parameter. Default to 20.
:param float beta: when policy_improvement_mode is "exp", this is the denominator
of the exp function. Default to 1.
:param float min_q_weight: weight for CQL loss/regularizer. Default to 10.
:param int target_update_freq: the target network update frequency (0 if
you do not use the target network). Default to 0.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
.. seealso::
Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed
explanation.
"""
def __init__(
self,
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
discount_factor: float = 0.99,
policy_improvement_mode: str = "exp",
ratio_upper_bound: float = 20.0,
beta: float = 1.0,
min_q_weight: float = 10.0,
target_update_freq: int = 0,
reward_normalization: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
actor,
optim,
lambda x: Categorical(logits=x), # type: ignore
discount_factor,
reward_normalization,
**kwargs,
)
self.critic = critic
self._target = target_update_freq > 0
self._freq = target_update_freq
self._iter = 0
if self._target:
self.actor_old = deepcopy(self.actor)
self.actor_old.eval()
self.critic_old = deepcopy(self.critic)
self.critic_old.eval()
else:
self.actor_old = self.actor
self.critic_old = self.critic
assert policy_improvement_mode in ["exp", "binary", "all"]
self._policy_improvement_mode = policy_improvement_mode
self._ratio_upper_bound = ratio_upper_bound
self._beta = beta
self._min_q_weight = min_q_weight
def sync_weight(self) -> None:
self.actor_old.load_state_dict(self.actor.state_dict()) # type: ignore
self.critic_old.load_state_dict(self.critic.state_dict()) # type: ignore
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # type: ignore
if self._target and self._iter % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
q_t, _ = self.critic(batch.obs)
act = to_torch(batch.act, dtype=torch.long, device=q_t.device)
qa_t = q_t.gather(1, act.unsqueeze(1))
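        # q_t: Q(s, .) with shape (batch, n_actions); qa_t: Q(s, a) for the taken action, shape (batch, 1)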
# Critic loss
with torch.no_grad():
target_a_t, _ = self.actor_old(batch.obs_next)
target_m = Categorical(logits=target_a_t)
q_t_target, _ = self.critic_old(batch.obs_next)
rew = to_torch_as(batch.rew, q_t_target)
expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True)
expected_target_q[batch.done > 0] = 0.0
target = rew.unsqueeze(1) + self._gamma * expected_target_q
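        # one-step TD target: r + gamma * E_{a'~pi_old}[Q_old(s', a')]; bootstrapping is disabled on terminal steps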
critic_loss = 0.5 * F.mse_loss(qa_t, target)
# Actor loss
a_t, _ = self.actor(batch.obs)
m = Categorical(logits=a_t)
expected_policy_q = (q_t * m.probs).sum(-1, keepdim=True)
advantage = qa_t - expected_policy_q
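        # A(s, a) = Q(s, a) - E_{a'~pi}[Q(s, a')], shape (batch, 1); the weight f(A) below is the
        # indicator 1[A > 0] in "binary" mode and exp(A / beta) clipped to ratio_upper_bound in
        # "exp" mode, as in the CRR paper; "all" weights every sample equally (behavior cloning)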
if self._policy_improvement_mode == "binary":
actor_loss_coef = (advantage > 0).float()
elif self._policy_improvement_mode == "exp":
actor_loss_coef = (
(advantage / self._beta).exp().clamp(0, self._ratio_upper_bound)
)
else:
actor_loss_coef = 1.0 # effectively behavior cloning
        # align shapes: log_prob(act) is (batch,) while the weight is (batch, 1) in "binary"/"exp" mode
        actor_loss = (-m.log_prob(act).unsqueeze(-1) * actor_loss_coef).mean()
# CQL loss/regularizer
min_q_loss = (q_t.logsumexp(1) - qa_t).mean()
loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss
loss.backward()
self.optim.step()
self._iter += 1
return {
"loss": loss.item(),
"loss/actor": actor_loss.item(),
"loss/critic": critic_loss.item(),
"loss/cql": min_q_loss.item(),
}
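
For a quick sanity check of the policy above, here is a minimal smoke test (a sketch, not part of this commit; the synthetic `Batch` fields, network sizes, and hyper-parameters are assumptions) that wires `DiscreteCRRPolicy` to tianshou's `Net` and runs a single `learn()` step on fake offline data, using the same fields that `offline_trainer` would feed it from a saved buffer:

```python
import numpy as np
import torch

from tianshou.data import Batch
from tianshou.policy import DiscreteCRRPolicy
from tianshou.utils.net.common import Net

state_dim, n_actions, batch_size = 4, 2, 8
actor = Net(state_dim, n_actions, hidden_sizes=[64], softmax=False)
critic = Net(state_dim, n_actions, hidden_sizes=[64], softmax=False)
optim = torch.optim.Adam(
    list(actor.parameters()) + list(critic.parameters()), lr=1e-3)
policy = DiscreteCRRPolicy(
    actor, critic, optim, discount_factor=0.99,
    policy_improvement_mode="exp", min_q_weight=10.0, target_update_freq=100)

# one synthetic offline batch with the fields learn() reads: obs/act/rew/done/obs_next
batch = Batch(
    obs=np.random.randn(batch_size, state_dim).astype(np.float32),
    act=np.random.randint(n_actions, size=batch_size),
    rew=np.random.randn(batch_size).astype(np.float32),
    done=np.zeros(batch_size, dtype=bool),
    obs_next=np.random.randn(batch_size, state_dim).astype(np.float32),
)
print(policy.learn(batch))  # -> {'loss': ..., 'loss/actor': ..., 'loss/critic': ..., 'loss/cql': ...}
```

This only confirms that the shapes line up and that all four loss terms are reported; real offline training goes through `offline_trainer` with an expert replay buffer, as in the example scripts above.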