Add discrete Conservative Q-Learning for offline RL (#359)

Co-authored-by: Yi Su <yi.su@antgroup.com>
Co-authored-by: Yi Su <yi.su@antfin.com>
Yi Su 2021-05-11 18:24:48 -07:00 committed by GitHub
parent 84f58636eb
commit b5c3ddabfa
15 changed files with 395 additions and 6 deletions


@@ -35,6 +35,7 @@
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)


@@ -99,6 +99,11 @@ Imitation
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.DiscreteCQLPolicy
:members:
:undoc-members:
:show-inheritance:
Model-based
-----------


@@ -25,6 +25,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_


@@ -67,4 +67,35 @@ We test our BCQ implementation on two example tasks (different from author's ver
| Task | Online DQN | Behavioral | BCQ |
| ---------------------- | ---------- | ---------- | --------------------------------- |
| PongNoFrameskip-v4 | 21 | 7.7 | 21 (epoch 5) |
| BreakoutNoFrameskip-v4 | 303        | 61         | 167.4 (epoch 12, could be higher) |
# CQL
To run the CQL algorithm on Atari, do the following:
- Train an expert, using the command listed in the QRDQN section above;
- Generate the buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train CQL: `python3 atari_cql.py --task {your_task} --load-buffer-name expert.hdf5`.
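As a quick sanity check before training, the saved buffer can be loaded back with the same call `atari_cql.py` uses. A minimal sketch (`expert.hdf5` is whatever you passed to `--save-buffer-name`):

```python
# Minimal sketch: inspect the generated expert buffer before training CQL.
# VectorReplayBuffer.load_hdf5 is the same call atari_cql.py makes below.
from tianshou.data import VectorReplayBuffer

buffer = VectorReplayBuffer.load_hdf5("expert.hdf5")
print(len(buffer))                  # number of stored transitions (up to 1M here)
batch, indices = buffer.sample(32)  # a random minibatch, as the offline trainer draws
print(batch.obs.shape, batch.act.shape, batch.rew.mean())
```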
We test our CQL implementation on two example tasks (different from the authors' version, we use v4 instead of v0; one epoch means 10k gradient steps):
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 6.8 | 19.5 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 248.3 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
We then reduce the size of the offline data to 10% and 1% of the above and get:
Buffer size 100000:
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | 5.8 | 21 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3 | 41.4 | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |
Buffer size 10000:
| Task | Online QRDQN | Behavioral | CQL | parameters |
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
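Note that `--min-q-weight` is tuned per task and per dataset size: in the runs above the best value shrinks together with the buffer (Breakout: 50 → 20 → 10; Pong: the default 10 down to 1 at 10k transitions).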


@@ -135,6 +135,8 @@ def test_discrete_bcq(args=get_args()):
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
pprint.pprint(result)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
if args.watch:
watch()

examples/atari/atari_cql.py (new file, 147 lines)

@@ -0,0 +1,147 @@
import os
import torch
import pickle
import pprint
import datetime
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCQLPolicy
from tianshou.data import Collector, VectorReplayBuffer
from atari_network import QRDQN
from atari_wrapper import wrap_deepmind
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument('--num-quantiles', type=int, default=200)
parser.add_argument("--n-step", type=int, default=1)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--min-q-weight", type=float, default=10.)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--update-per-epoch", type=int, default=10000)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--watch", default=False, action="store_true",
help="watch the play of pre-trained policy only")
parser.add_argument("--log-interval", type=int, default=100)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_DQN_PongNoFrameskip-v4.hdf5")
parser.add_argument(
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu")
args = parser.parse_known_args()[0]
return args
def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)
def make_atari_env_watch(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
episode_life=False, clip_rewards=False)
def test_discrete_cql(args=get_args()):
# envs
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
net = QRDQN(*args.state_shape, args.action_shape,
args.num_quantiles, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = DiscreteCQLPolicy(
net, optim, args.gamma, args.num_quantiles, args.n_step,
args.target_update_freq, min_q_weight=args.min_q_weight
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run atari_qrdqn.py first to get expert's data buffer."
if args.load_buffer_name.endswith('.pkl'):
buffer = pickle.load(open(args.load_buffer_name, "rb"))
elif args.load_buffer_name.endswith('.hdf5'):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
else:
print(f"Unknown buffer format: {args.load_buffer_name}")
exit(0)
# collector
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(
args.logdir, args.task, 'cql',
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=args.log_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return False
# watch agent's performance
def watch():
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
pprint.pprint(result)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
if args.watch:
watch()
exit(0)
result = offline_trainer(
policy, buffer, test_collector, args.epoch,
args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
pprint.pprint(result)
watch()
if __name__ == "__main__":
test_discrete_cql(get_args())


@@ -118,7 +118,7 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
pass_check = 0
break
total_pass += pass_check
- if sys.platform != "darwin":  # macOS cannot pass this check
+ if sys.platform == "linux":  # Windows/macOS may not pass this check
assert total_pass >= 2


@@ -1,6 +1,7 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
@@ -41,6 +42,9 @@ def get_args():
action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument(
'--save-buffer-name', type=str,
default="./expert_QRDQN_CartPole-v0.pkl")
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -130,6 +134,14 @@ def test_qrdqn(args=get_args()):
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
# save buffer in pickle format, for imitation learning unittest
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
policy.set_eps(0.9)  # with eps=0.9, only ~10% of actions follow the expert policy, as in the original paper
collector = Collector(policy, test_envs, buf, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
pickle.dump(buf, open(args.save_buffer_name, "wb"))
print(result["rews"].mean())
def test_pqrdqn(args=get_args()):
args.prioritized_replay = True


@@ -0,0 +1,110 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCQLPolicy
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=7e-4)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument('--num-quantiles', type=int, default=200)
parser.add_argument("--n-step", type=int, default=3)
parser.add_argument("--target-update-freq", type=int, default=320)
parser.add_argument("--min-q-weight", type=float, default=10.)
parser.add_argument("--epoch", type=int, default=5)
parser.add_argument("--update-per-epoch", type=int, default=1000)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[64, 64])
parser.add_argument("--test-num", type=int, default=100)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_QRDQN_CartPole-v0.pkl",
)
parser.add_argument(
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
args = parser.parse_known_args()[0]
return args
def test_discrete_cql(args=get_args()):
# envs
env = gym.make(args.task)
if args.task == 'CartPole-v0':
env.spec.reward_threshold = 190 # lower the goal
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device,
softmax=False, num_atoms=args.num_quantiles)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DiscreteCQLPolicy(
net, optim, args.gamma, args.num_quantiles, args.n_step,
args.target_update_freq, min_q_weight=args.min_q_weight
).to(args.device)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run test_dqn.py first to get expert's data buffer."
buffer = pickle.load(open(args.load_buffer_name, "rb"))
# collector
test_collector = Collector(policy, test_envs, exploration_noise=True)
log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
if __name__ == "__main__":
test_discrete_cql(get_args())


@@ -299,7 +299,7 @@ class BaseVectorEnv(gym.Env):
clip_max = 10.0 # this magic number is from openai baselines
# see baselines/common/vec_env/vec_normalize.py#L10
obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
obs = np.clip(obs, -clip_max, clip_max) # type: ignore
obs = np.clip(obs, -clip_max, clip_max)
return obs


@@ -14,6 +14,7 @@ from tianshou.policy.modelfree.sac import SACPolicy
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.policy.imitation.base import ImitationPolicy
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
from tianshou.policy.modelbased.psrl import PSRLPolicy
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
@@ -35,6 +36,7 @@ __all__ = [
"DiscreteSACPolicy",
"ImitationPolicy",
"DiscreteBCQPolicy",
"DiscreteCQLPolicy",
"PSRLPolicy",
"MultiAgentPolicyManager",
]


@@ -162,7 +162,7 @@ class BasePolicy(ABC, nn.Module):
isinstance(act, np.ndarray):
# currently this action mapping only supports np.ndarray action
if self.action_bound_method == "clip":
- act = np.clip(act, -1.0, 1.0)  # type: ignore
+ act = np.clip(act, -1.0, 1.0)
elif self.action_bound_method == "tanh":
act = np.tanh(act)
if self.action_scaling:


@@ -0,0 +1,78 @@
import torch
import numpy as np
import torch.nn.functional as F
from typing import Any, Dict
from tianshou.policy import QRDQNPolicy
from tianshou.data import Batch, to_torch
class DiscreteCQLPolicy(QRDQNPolicy):
"""Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.
:param torch.nn.Module model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param float discount_factor: in [0, 1].
:param int num_quantiles: the number of quantile midpoints in the inverse
cumulative distribution function of the value. Default to 200.
:param int estimation_step: the number of steps to look ahead. Default to 1.
:param int target_update_freq: the target network update frequency (0 if
you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param float min_q_weight: the weight for the CQL loss. Default to 10.
.. seealso::
Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed
explanation.
"""
def __init__(
self,
model: torch.nn.Module,
optim: torch.optim.Optimizer,
discount_factor: float = 0.99,
num_quantiles: int = 200,
estimation_step: int = 1,
target_update_freq: int = 0,
reward_normalization: bool = False,
min_q_weight: float = 10.0,
**kwargs: Any,
) -> None:
super().__init__(model, optim, discount_factor, num_quantiles, estimation_step,
target_update_freq, reward_normalization, **kwargs)
self._min_q_weight = min_q_weight
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
if self._target and self._iter % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
weight = batch.pop("weight", 1.0)
all_dist = self(batch).logits
act = to_torch(batch.act, dtype=torch.long, device=all_dist.device)
curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2)
target_dist = batch.returns.unsqueeze(1)
# calculate each element's difference between curr_dist and target_dist
u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
huber_loss = (u * (
self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()
).abs()).sum(-1).mean(1)
qr_loss = (huber_loss * weight).mean()
# ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
# blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
batch.weight = u.detach().abs().sum(-1).mean(1) # prio-buffer
# add CQL loss
q = self.compute_q_value(all_dist, None)
dataset_expec = q.gather(1, act.unsqueeze(1)).mean()
negative_sampling = q.logsumexp(1).mean()
min_q_loss = negative_sampling - dataset_expec
loss = qr_loss + min_q_loss * self._min_q_weight
loss.backward()
self.optim.step()
self._iter += 1
return {
"loss": loss.item(),
"loss/qr": qr_loss.item(),
"loss/cql": min_q_loss.item(),
}
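For intuition, here is a tiny self-contained sketch of the `min_q_loss` term computed in `learn()` above; the Q-values and actions are made up for illustration. The regularizer is `logsumexp` of the Q-values over all actions minus the Q-value of the dataset action, so it is always non-negative and pushes down the Q-values of actions the behavior policy never took:

```python
# Toy illustration of the CQL regularizer from learn() above.
# Q-values and actions are made up for demonstration only.
import torch

q = torch.tensor([[1.0, 3.0, 0.5],   # Q(s, .) for two states, three actions
                  [2.0, 2.0, 2.0]])
act = torch.tensor([1, 0])           # actions stored in the offline dataset

dataset_expec = q.gather(1, act.unsqueeze(1)).mean()  # mean Q of dataset actions
negative_sampling = q.logsumexp(1).mean()             # soft maximum over all actions
min_q_loss = negative_sampling - dataset_expec        # >= 0; ~0 if dataset actions dominate
print(min_q_loss.item())
```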


@@ -124,7 +124,7 @@ class TRPOPolicy(NPGPolicy):
" are poor and need to be changed.")
# optimize critic
- for _ in range(self._optim_critic_iters):  # type: ignore
+ for _ in range(self._optim_critic_iters):
value = self.critic(b.obs).flatten()
vf_loss = F.mse_loss(b.returns, value)
self.optim.zero_grad()


@@ -91,5 +91,5 @@ class RunningMeanStd(object):
m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
new_var = m_2 / total_count
- self.mean, self.var = new_mean, new_var  # type: ignore
+ self.mean, self.var = new_mean, new_var
self.count = total_count
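For reference, the update above is the standard parallel variance merge (Chan et al.'s algorithm); writing $n_a$, $n_b$ for `self.count` and `batch_count` (my notation, not the code's), and $\delta = \mu_b - \mu_a$:

```latex
M_2 = M_{2,a} + M_{2,b} + \delta^2 \, \frac{n_a n_b}{n_a + n_b},
\qquad
\sigma^2_{\text{new}} = \frac{M_2}{n_a + n_b}
```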