Add discrete Conservative Q-Learning for offline RL (#359)
Co-authored-by: Yi Su <yi.su@antgroup.com>
Co-authored-by: Yi Su <yi.su@antfin.com>
parent 84f58636eb
commit b5c3ddabfa
@@ -35,6 +35,7 @@
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
@@ -99,6 +99,11 @@ Imitation
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.DiscreteCQLPolicy
    :members:
    :undoc-members:
    :show-inheritance:

Model-based
-----------
@@ -25,6 +25,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
@@ -68,3 +68,34 @@ We test our BCQ implementation on two example tasks (different from author's ver
| ---------------------- | ------------ | ---------- | --------------------------------- |
| PongNoFrameskip-v4     | 21           | 7.7        | 21 (epoch 5)                      |
| BreakoutNoFrameskip-v4 | 303          | 61         | 167.4 (epoch 12, could be higher) |

# CQL

To run the CQL algorithm on Atari, you need to do the following (a condensed sketch of the resulting training setup follows this list):

- Train an expert, by using the command listed in the above QRDQN section;
- Generate the buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that a 1M Atari buffer cannot be saved in `.pkl` format because it is too large and will cause an error);
- Train CQL: `python3 atari_cql.py --task {your_task} --load-buffer-name expert.hdf5`.
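
For reference, `examples/atari/atari_cql.py` (added below in this commit) wires these pieces together. The following is a condensed, illustrative sketch of its core offline-training flow; it hardcodes the script's default hyper-parameters, skips seeding/logging/checkpointing, and assumes `expert.hdf5` was produced by the step above:

```python
import torch

from atari_network import QRDQN
from atari_wrapper import wrap_deepmind
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import DiscreteCQLPolicy
from tianshou.trainer import offline_trainer

task = "PongNoFrameskip-v4"
device = "cuda" if torch.cuda.is_available() else "cpu"

# frame-stacked Atari env, only used here to read observation/action shapes
env = wrap_deepmind(task, frame_stack=4)
state_shape = env.observation_space.shape
action_shape = env.action_space.shape or env.action_space.n

# same QRDQN network/optimizer as the online QRDQN example
net = QRDQN(*state_shape, action_shape, 200, device)
optim = torch.optim.Adam(net.parameters(), lr=0.0001)

# discrete CQL = QRDQN quantile loss + min_q_weight * conservative regularizer
policy = DiscreteCQLPolicy(
    net, optim, 0.99, 200, 1, 500, min_q_weight=10.0,
).to(device)

# offline dataset collected by the noisy QRDQN expert
buffer = VectorReplayBuffer.load_hdf5("expert.hdf5")

# evaluation-only environments and collector (no online training data is gathered)
test_envs = SubprocVectorEnv([
    lambda: wrap_deepmind(task, frame_stack=4,
                          episode_life=False, clip_rewards=False)
    for _ in range(10)
])
test_collector = Collector(policy, test_envs, exploration_noise=True)

# 100 epochs x 10000 gradient steps on mini-batches of 32 transitions
result = offline_trainer(
    policy, buffer, test_collector, 100, 10000, 10, 32)
print(result)
```

The full script additionally supports `.pkl` buffers, resuming from a saved policy via `--resume-path`, and a `--watch` mode that only evaluates a trained policy.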

We test our CQL implementation on two example tasks (different from the author's version: we use v4 instead of v0, and one epoch means 10k gradient steps):

| Task                   | Online QRDQN | Behavioral | CQL              | parameters |
| ---------------------- | ------------ | ---------- | ---------------- | ---------- |
| PongNoFrameskip-v4     | 20.5         | 6.8        | 19.5 (epoch 5)   | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3        | 46.9       | 248.3 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |

We reduce the size of the offline data to 10% and 1% of the above and get:

Buffer size 100000:

| Task                   | Online QRDQN | Behavioral | CQL             | parameters |
| ---------------------- | ------------ | ---------- | --------------- | ---------- |
| PongNoFrameskip-v4     | 20.5         | 5.8        | 21 (epoch 5)    | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 5` |
| BreakoutNoFrameskip-v4 | 394.3        | 41.4       | 40.8 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e5.hdf5 --epoch 12 --min-q-weight 20` |

Buffer size 10000:

| Task                   | Online QRDQN | Behavioral | CQL             | parameters |
| ---------------------- | ------------ | ---------- | --------------- | ---------- |
| PongNoFrameskip-v4     | 20.5         | nan        | 1.8 (epoch 5)   | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
| BreakoutNoFrameskip-v4 | 394.3        | 31.7       | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
@@ -135,6 +135,8 @@ def test_discrete_bcq(args=get_args()):
        result = test_collector.collect(n_episode=args.test_num,
                                        render=args.render)
        pprint.pprint(result)
        rew = result["rews"].mean()
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

    if args.watch:
        watch()
examples/atari/atari_cql.py (new file, 147 lines)
@@ -0,0 +1,147 @@
import os
import torch
import pickle
import pprint
import datetime
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCQLPolicy
from tianshou.data import Collector, VectorReplayBuffer

from atari_network import QRDQN
from atari_wrapper import wrap_deepmind


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--eps-test", type=float, default=0.001)
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument('--num-quantiles', type=int, default=200)
    parser.add_argument("--n-step", type=int, default=1)
    parser.add_argument("--target-update-freq", type=int, default=500)
    parser.add_argument("--min-q-weight", type=float, default=10.)
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--update-per-epoch", type=int, default=10000)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument('--frames-stack', type=int, default=4)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--watch", default=False, action="store_true",
                        help="watch the play of pre-trained policy only")
    parser.add_argument("--log-interval", type=int, default=100)
    parser.add_argument(
        "--load-buffer-name", type=str,
        default="./expert_DQN_PongNoFrameskip-v4.hdf5")
    parser.add_argument(
        "--device", type=str,
        default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_known_args()[0]
    return args


def make_atari_env(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)


def make_atari_env_watch(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
                         episode_life=False, clip_rewards=False)


def test_discrete_cql(args=get_args()):
    # envs
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                                  for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = QRDQN(*args.state_shape, args.action_shape,
                args.num_quantiles, args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DiscreteCQLPolicy(
        net, optim, args.gamma, args.num_quantiles, args.n_step,
        args.target_update_freq, min_q_weight=args.min_q_weight
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(
            args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run atari_qrdqn.py first to get expert's data buffer."
    if args.load_buffer_name.endswith('.pkl'):
        buffer = pickle.load(open(args.load_buffer_name, "rb"))
    elif args.load_buffer_name.endswith('.hdf5'):
        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
    else:
        print(f"Unknown buffer format: {args.load_buffer_name}")
        exit(0)

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    # log
    log_path = os.path.join(
        args.logdir, args.task, 'cql',
        f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer, update_interval=args.log_interval)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return False

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        print("Testing agent ...")
        test_collector.reset()
        result = test_collector.collect(n_episode=args.test_num,
                                        render=args.render)
        pprint.pprint(result)
        rew = result["rews"].mean()
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')

    if args.watch:
        watch()
        exit(0)

    result = offline_trainer(
        policy, buffer, test_collector, args.epoch,
        args.update_per_epoch, args.test_num, args.batch_size,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger)

    pprint.pprint(result)
    watch()


if __name__ == "__main__":
    test_discrete_cql(get_args())
@@ -118,7 +118,7 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
                pass_check = 0
                break
        total_pass += pass_check
-    if sys.platform != "darwin":  # macOS cannot pass this check
+    if sys.platform == "linux":  # Windows/macOS may not pass this check
        assert total_pass >= 2
@@ -1,6 +1,7 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np

@@ -41,6 +42,9 @@ def get_args():
                        action="store_true", default=False)
    parser.add_argument('--alpha', type=float, default=0.6)
    parser.add_argument('--beta', type=float, default=0.4)
    parser.add_argument(
        '--save-buffer-name', type=str,
        default="./expert_QRDQN_CartPole-v0.pkl")
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')

@@ -130,6 +134,14 @@ def test_qrdqn(args=get_args()):
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")

    # save buffer in pickle format, for imitation learning unittest
    buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
    policy.set_eps(0.9)  # 10% of expert data as demonstrated in the original paper
    collector = Collector(policy, test_envs, buf, exploration_noise=True)
    result = collector.collect(n_step=args.buffer_size)
    pickle.dump(buf, open(args.save_buffer_name, "wb"))
    print(result["rews"].mean())


def test_pqrdqn(args=get_args()):
    args.prioritized_replay = True
test/discrete/test_qrdqn_il_cql.py (new file, 110 lines)
@@ -0,0 +1,110 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCQLPolicy


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="CartPole-v0")
    parser.add_argument("--seed", type=int, default=1626)
    parser.add_argument("--eps-test", type=float, default=0.001)
    parser.add_argument("--lr", type=float, default=7e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument('--num-quantiles', type=int, default=200)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--target-update-freq", type=int, default=320)
    parser.add_argument("--min-q-weight", type=float, default=10.)
    parser.add_argument("--epoch", type=int, default=5)
    parser.add_argument("--update-per-epoch", type=int, default=1000)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int,
                        nargs='*', default=[64, 64])
    parser.add_argument("--test-num", type=int, default=100)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.)
    parser.add_argument(
        "--load-buffer-name", type=str,
        default="./expert_QRDQN_CartPole-v0.pkl",
    )
    parser.add_argument(
        "--device", type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    args = parser.parse_known_args()[0]
    return args


def test_discrete_cql(args=get_args()):
    # envs
    env = gym.make(args.task)
    if args.task == 'CartPole-v0':
        env.spec.reward_threshold = 190  # lower the goal
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, args.action_shape,
              hidden_sizes=args.hidden_sizes, device=args.device,
              softmax=False, num_atoms=args.num_quantiles)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)

    policy = DiscreteCQLPolicy(
        net, optim, args.gamma, args.num_quantiles, args.n_step,
        args.target_update_freq, min_q_weight=args.min_q_weight
    ).to(args.device)
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run test_qrdqn.py first to get expert's data buffer."
    buffer = pickle.load(open(args.load_buffer_name, "rb"))

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
    writer = SummaryWriter(log_path)
    logger = BasicLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    result = offline_trainer(
        policy, buffer, test_collector,
        args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger)

    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == "__main__":
    test_discrete_cql(get_args())
tianshou/env/venvs.py
@@ -299,7 +299,7 @@ class BaseVectorEnv(gym.Env):
        clip_max = 10.0  # this magic number is from openai baselines
        # see baselines/common/vec_env/vec_normalize.py#L10
        obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
-        obs = np.clip(obs, -clip_max, clip_max)  # type: ignore
+        obs = np.clip(obs, -clip_max, clip_max)
        return obs
@@ -14,6 +14,7 @@ from tianshou.policy.modelfree.sac import SACPolicy
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.policy.imitation.base import ImitationPolicy
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
from tianshou.policy.modelbased.psrl import PSRLPolicy
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager

@@ -35,6 +36,7 @@ __all__ = [
    "DiscreteSACPolicy",
    "ImitationPolicy",
    "DiscreteBCQPolicy",
    "DiscreteCQLPolicy",
    "PSRLPolicy",
    "MultiAgentPolicyManager",
]
@@ -162,7 +162,7 @@ class BasePolicy(ABC, nn.Module):
                isinstance(act, np.ndarray):
            # currently this action mapping only supports np.ndarray action
            if self.action_bound_method == "clip":
-                act = np.clip(act, -1.0, 1.0)  # type: ignore
+                act = np.clip(act, -1.0, 1.0)
            elif self.action_bound_method == "tanh":
                act = np.tanh(act)
            if self.action_scaling:
tianshou/policy/imitation/discrete_cql.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import torch
import numpy as np
import torch.nn.functional as F
from typing import Any, Dict

from tianshou.policy import QRDQNPolicy
from tianshou.data import Batch, to_torch


class DiscreteCQLPolicy(QRDQNPolicy):
    """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.

    :param torch.nn.Module model: a model following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
    :param float discount_factor: in [0, 1].
    :param int num_quantiles: the number of quantile midpoints in the inverse
        cumulative distribution function of the value. Default to 200.
    :param int estimation_step: the number of steps to look ahead. Default to 1.
    :param int target_update_freq: the target network update frequency (0 if
        you do not use the target network).
    :param bool reward_normalization: normalize the reward to Normal(0, 1).
        Default to False.
    :param float min_q_weight: the weight for the CQL loss. Default to 10.0.

    .. seealso::

        Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        discount_factor: float = 0.99,
        num_quantiles: int = 200,
        estimation_step: int = 1,
        target_update_freq: int = 0,
        reward_normalization: bool = False,
        min_q_weight: float = 10.0,
        **kwargs: Any,
    ) -> None:
        super().__init__(model, optim, discount_factor, num_quantiles, estimation_step,
                         target_update_freq, reward_normalization, **kwargs)
        self._min_q_weight = min_q_weight

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
        if self._target and self._iter % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        weight = batch.pop("weight", 1.0)
        all_dist = self(batch).logits
        act = to_torch(batch.act, dtype=torch.long, device=all_dist.device)
        curr_dist = all_dist[np.arange(len(act)), act, :].unsqueeze(2)
        target_dist = batch.returns.unsqueeze(1)
        # calculate each element's difference between curr_dist and target_dist
        u = F.smooth_l1_loss(target_dist, curr_dist, reduction="none")
        huber_loss = (u * (
            self.tau_hat - (target_dist - curr_dist).detach().le(0.).float()
        ).abs()).sum(-1).mean(1)
        qr_loss = (huber_loss * weight).mean()
        # ref: https://github.com/ku2482/fqf-iqn-qrdqn.pytorch/
        # blob/master/fqf_iqn_qrdqn/agent/qrdqn_agent.py L130
        batch.weight = u.detach().abs().sum(-1).mean(1)  # prio-buffer
        # add CQL loss
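        # CQL(H) regularizer (arXiv:2006.04779): logsumexp of Q over all actions
        # minus the Q-value of the action taken in the dataset; minimizing it
        # pushes down Q-values of out-of-distribution actions, weighted by
        # self._min_q_weight in the total loss below.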
|
||||
q = self.compute_q_value(all_dist, None)
|
||||
dataset_expec = q.gather(1, act.unsqueeze(1)).mean()
|
||||
negative_sampling = q.logsumexp(1).mean()
|
||||
min_q_loss = negative_sampling - dataset_expec
|
||||
loss = qr_loss + min_q_loss * self._min_q_weight
|
||||
loss.backward()
|
||||
self.optim.step()
|
||||
self._iter += 1
|
||||
return {
|
||||
"loss": loss.item(),
|
||||
"loss/qr": qr_loss.item(),
|
||||
"loss/cql": min_q_loss.item(),
|
||||
}
|
@@ -124,7 +124,7 @@ class TRPOPolicy(NPGPolicy):
                            " are poor and need to be changed.")

            # optimize critic
-            for _ in range(self._optim_critic_iters):  # type: ignore
+            for _ in range(self._optim_critic_iters):
                value = self.critic(b.obs).flatten()
                vf_loss = F.mse_loss(b.returns, value)
                self.optim.zero_grad()
@@ -91,5 +91,5 @@ class RunningMeanStd(object):
        m_2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
        new_var = m_2 / total_count

-        self.mean, self.var = new_mean, new_var  # type: ignore
+        self.mean, self.var = new_mean, new_var
        self.count = total_count