Add discrete Critic Regularized Regression (#367)
This commit is contained in:
parent
b5c3ddabfa
commit
8f7bc65ac7
@ -36,6 +36,7 @@
|
||||
- Vanilla Imitation Learning
|
||||
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
|
||||
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
|
||||
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
|
||||
- [Prioritized Experience Replay (PER)](https://arxiv.org/pdf/1511.05952.pdf)
|
||||
- [Generalized Advantage Estimator (GAE)](https://arxiv.org/pdf/1506.02438.pdf)
|
||||
- [Posterior Sampling Reinforcement Learning (PSRL)](https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf)
|
||||
|
@ -104,6 +104,11 @@ Imitation
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.DiscreteCRRPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Model-based
|
||||
-----------
|
||||
|
||||
|
@ -26,6 +26,7 @@ Welcome to Tianshou!
|
||||
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
|
||||
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
|
||||
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
|
||||
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
|
||||
* :class:`~tianshou.policy.PSRLPolicy` `Posterior Sampling Reinforcement Learning <https://www.ece.uvic.ca/~bctill/papers/learning/Strens_2000.pdf>`_
|
||||
* :class:`~tianshou.data.PrioritizedReplayBuffer` `Prioritized Experience Replay <https://arxiv.org/pdf/1511.05952.pdf>`_
|
||||
* :meth:`~tianshou.policy.BasePolicy.compute_episodic_return` `Generalized Advantage Estimator <https://arxiv.org/pdf/1506.02438.pdf>`_
|
||||
|
@ -99,3 +99,20 @@ Buffer size 10000:
|
||||
| ---------------------- | ---------- | ---------- | --------------------------------- | ------------------------------------------------------------ |
|
||||
| PongNoFrameskip-v4 | 20.5 | nan | 1.8 (epoch 5) | `python3 atari_cql.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 5 --min-q-weight 1` |
|
||||
| BreakoutNoFrameskip-v4 | 394.3 | 31.7 | 22.5 (epoch 12) | `python3 atari_cql.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.size_1e4.hdf5 --epoch 12 --min-q-weight 10` |
|
||||
|
||||
# CRR
|
||||
|
||||
To running CRR algorithm on Atari, you need to do the following things:
|
||||
|
||||
- Train an expert, by using the command listed in the above QRDQN section;
|
||||
- Generate buffer with noise: `python3 atari_qrdqn.py --task {your_task} --watch --resume-path log/{your_task}/qrdqn/policy.pth --eps-test 0.2 --buffer-size 1000000 --save-buffer-name expert.hdf5` (note that 1M Atari buffer cannot be saved as `.pkl` format because it is too large and will cause error);
|
||||
- Train CQL: `python3 atari_crr.py --task {your_task} --load-buffer-name expert.hdf5`.
|
||||
|
||||
We test our CRR implementation on two example tasks (different from author's version, we use v4 instead of v0; one epoch means 10k gradient step):
|
||||
|
||||
| Task | Online QRDQN | Behavioral | CRR | CRR w/ CQL | parameters |
|
||||
| ---------------------- | ---------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------ |
|
||||
| PongNoFrameskip-v4 | 20.5 | 6.8 | -21 (epoch 5) | 16.1 (epoch 5) | `python3 atari_crr.py --task "PongNoFrameskip-v4" --load-buffer-name log/PongNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 5` |
|
||||
| BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 26.4 (epoch 12) | 125.0 (epoch 12) | `python3 atari_crr.py --task "BreakoutNoFrameskip-v4" --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
|
||||
|
||||
Note that CRR itself does not work well in Atari tasks but adding CQL loss/regularizer helps.
|
||||
|
155
examples/atari/atari_crr.py
Normal file
155
examples/atari/atari_crr.py
Normal file
@ -0,0 +1,155 @@
|
||||
import os
|
||||
import torch
|
||||
import pickle
|
||||
import pprint
|
||||
import datetime
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.utils import BasicLogger
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.trainer import offline_trainer
|
||||
from tianshou.utils.net.discrete import Actor
|
||||
from tianshou.policy import DiscreteCRRPolicy
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
|
||||
from atari_network import DQN
|
||||
from atari_wrapper import wrap_deepmind
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
|
||||
parser.add_argument("--seed", type=int, default=1626)
|
||||
parser.add_argument("--lr", type=float, default=0.0001)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--policy-improvement-mode", type=str, default="exp")
|
||||
parser.add_argument("--ratio-upper-bound", type=float, default=20.)
|
||||
parser.add_argument("--beta", type=float, default=1.)
|
||||
parser.add_argument("--min-q-weight", type=float, default=10.)
|
||||
parser.add_argument("--target-update-freq", type=int, default=500)
|
||||
parser.add_argument("--epoch", type=int, default=100)
|
||||
parser.add_argument("--update-per-epoch", type=int, default=10000)
|
||||
parser.add_argument("--batch-size", type=int, default=32)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
parser.add_argument('--frames-stack', type=int, default=4)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=0.)
|
||||
parser.add_argument("--resume-path", type=str, default=None)
|
||||
parser.add_argument("--watch", default=False, action="store_true",
|
||||
help="watch the play of pre-trained policy only")
|
||||
parser.add_argument("--log-interval", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--load-buffer-name", type=str,
|
||||
default="./expert_DQN_PongNoFrameskip-v4.hdf5")
|
||||
parser.add_argument(
|
||||
"--device", type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu")
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
|
||||
def make_atari_env(args):
|
||||
return wrap_deepmind(args.task, frame_stack=args.frames_stack)
|
||||
|
||||
|
||||
def make_atari_env_watch(args):
|
||||
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
|
||||
episode_life=False, clip_rewards=False)
|
||||
|
||||
|
||||
def test_discrete_crr(args=get_args()):
|
||||
# envs
|
||||
env = make_atari_env(args)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# should be N_FRAMES x H x W
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
# make environments
|
||||
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
|
||||
for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
feature_net = DQN(*args.state_shape, args.action_shape,
|
||||
device=args.device, features_only=True).to(args.device)
|
||||
actor = Actor(feature_net, args.action_shape, device=args.device,
|
||||
hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
|
||||
critic = DQN(*args.state_shape, args.action_shape,
|
||||
device=args.device).to(args.device)
|
||||
optim = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()),
|
||||
lr=args.lr)
|
||||
# define policy
|
||||
policy = DiscreteCRRPolicy(
|
||||
actor, critic, optim, args.gamma,
|
||||
policy_improvement_mode=args.policy_improvement_mode,
|
||||
ratio_upper_bound=args.ratio_upper_bound, beta=args.beta,
|
||||
min_q_weight=args.min_q_weight,
|
||||
target_update_freq=args.target_update_freq
|
||||
).to(args.device)
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(
|
||||
args.resume_path, map_location=args.device))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
# buffer
|
||||
assert os.path.exists(args.load_buffer_name), \
|
||||
"Please run atari_qrdqn.py first to get expert's data buffer."
|
||||
if args.load_buffer_name.endswith('.pkl'):
|
||||
buffer = pickle.load(open(args.load_buffer_name, "rb"))
|
||||
elif args.load_buffer_name.endswith('.hdf5'):
|
||||
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
|
||||
else:
|
||||
print(f"Unknown buffer format: {args.load_buffer_name}")
|
||||
exit(0)
|
||||
|
||||
# collector
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
|
||||
# log
|
||||
log_path = os.path.join(
|
||||
args.logdir, args.task, 'crr',
|
||||
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = BasicLogger(writer, update_interval=args.log_interval)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return False
|
||||
|
||||
# watch agent's performance
|
||||
def watch():
|
||||
print("Setup test envs ...")
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
print("Testing agent ...")
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num,
|
||||
render=args.render)
|
||||
pprint.pprint(result)
|
||||
rew = result["rews"].mean()
|
||||
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
|
||||
|
||||
if args.watch:
|
||||
watch()
|
||||
exit(0)
|
||||
|
||||
result = offline_trainer(
|
||||
policy, buffer, test_collector, args.epoch,
|
||||
args.update_per_epoch, args.test_num, args.batch_size,
|
||||
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
|
||||
|
||||
pprint.pprint(result)
|
||||
watch()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_discrete_crr(get_args())
|
110
test/discrete/test_il_crr.py
Normal file
110
test/discrete/test_il_crr.py
Normal file
@ -0,0 +1,110 @@
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import pickle
|
||||
import pprint
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.utils import BasicLogger
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import offline_trainer
|
||||
from tianshou.policy import DiscreteCRRPolicy
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="CartPole-v0")
|
||||
parser.add_argument("--seed", type=int, default=1626)
|
||||
parser.add_argument("--lr", type=float, default=7e-4)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--n-step", type=int, default=3)
|
||||
parser.add_argument("--target-update-freq", type=int, default=320)
|
||||
parser.add_argument("--epoch", type=int, default=5)
|
||||
parser.add_argument("--update-per-epoch", type=int, default=1000)
|
||||
parser.add_argument("--batch-size", type=int, default=64)
|
||||
parser.add_argument('--hidden-sizes', type=int,
|
||||
nargs='*', default=[64, 64])
|
||||
parser.add_argument("--test-num", type=int, default=100)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=0.)
|
||||
parser.add_argument(
|
||||
"--load-buffer-name", type=str,
|
||||
default="./expert_DQN_CartPole-v0.pkl",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
|
||||
def test_discrete_crr(args=get_args()):
|
||||
# envs
|
||||
env = gym.make(args.task)
|
||||
if args.task == 'CartPole-v0':
|
||||
env.spec.reward_threshold = 190 # lower the goal
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
actor = Net(args.state_shape, args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes, device=args.device,
|
||||
softmax=False)
|
||||
critic = Net(args.state_shape, args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes, device=args.device,
|
||||
softmax=False)
|
||||
optim = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()),
|
||||
lr=args.lr)
|
||||
|
||||
policy = DiscreteCRRPolicy(
|
||||
actor, critic, optim, args.gamma,
|
||||
target_update_freq=args.target_update_freq,
|
||||
).to(args.device)
|
||||
# buffer
|
||||
assert os.path.exists(args.load_buffer_name), \
|
||||
"Please run test_dqn.py first to get expert's data buffer."
|
||||
buffer = pickle.load(open(args.load_buffer_name, "rb"))
|
||||
|
||||
# collector
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
|
||||
log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
|
||||
writer = SummaryWriter(log_path)
|
||||
logger = BasicLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
result = offline_trainer(
|
||||
policy, buffer, test_collector,
|
||||
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
|
||||
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
|
||||
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
env = gym.make(args.task)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_discrete_crr(get_args())
|
@ -71,7 +71,7 @@ def test_discrete_cql(args=get_args()):
|
||||
).to(args.device)
|
||||
# buffer
|
||||
assert os.path.exists(args.load_buffer_name), \
|
||||
"Please run test_dqn.py first to get expert's data buffer."
|
||||
"Please run test_qrdqn.py first to get expert's data buffer."
|
||||
buffer = pickle.load(open(args.load_buffer_name, "rb"))
|
||||
|
||||
# collector
|
||||
|
@ -15,6 +15,7 @@ from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
|
||||
from tianshou.policy.imitation.base import ImitationPolicy
|
||||
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
|
||||
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
|
||||
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
|
||||
from tianshou.policy.modelbased.psrl import PSRLPolicy
|
||||
from tianshou.policy.multiagent.mapolicy import MultiAgentPolicyManager
|
||||
|
||||
@ -37,6 +38,7 @@ __all__ = [
|
||||
"ImitationPolicy",
|
||||
"DiscreteBCQPolicy",
|
||||
"DiscreteCQLPolicy",
|
||||
"DiscreteCRRPolicy",
|
||||
"PSRLPolicy",
|
||||
"MultiAgentPolicyManager",
|
||||
]
|
||||
|
123
tianshou/policy/imitation/discrete_crr.py
Normal file
123
tianshou/policy/imitation/discrete_crr.py
Normal file
@ -0,0 +1,123 @@
|
||||
import torch
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions import Categorical
|
||||
|
||||
from tianshou.policy.modelfree.pg import PGPolicy
|
||||
from tianshou.data import Batch, to_torch, to_torch_as
|
||||
|
||||
|
||||
class DiscreteCRRPolicy(PGPolicy):
|
||||
r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.
|
||||
|
||||
:param torch.nn.Module actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.nn.Module critic: the action-value critic (i.e., Q function)
|
||||
network. (s -> Q(s, \*))
|
||||
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
|
||||
:param float discount_factor: in [0, 1]. Default to 0.99.
|
||||
:param str policy_improvement_mode: type of the weight function f. Possible
|
||||
values: "binary"/"exp"/"all". Default to "exp".
|
||||
:param float ratio_upper_bound: when policy_improvement_mode is "exp", the value
|
||||
of the exp function is upper-bounded by this parameter. Default to 20.
|
||||
:param float beta: when policy_improvement_mode is "exp", this is the denominator
|
||||
of the exp function. Default to 1.
|
||||
:param float min_q_weight: weight for CQL loss/regularizer. Default to 10.
|
||||
:param int target_update_freq: the target network update frequency (0 if
|
||||
you do not use the target network). Default to 0.
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
|
||||
.. seealso::
|
||||
Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
discount_factor: float = 0.99,
|
||||
policy_improvement_mode: str = "exp",
|
||||
ratio_upper_bound: float = 20.0,
|
||||
beta: float = 1.0,
|
||||
min_q_weight: float = 10.0,
|
||||
target_update_freq: int = 0,
|
||||
reward_normalization: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
actor,
|
||||
optim,
|
||||
lambda x: Categorical(logits=x), # type: ignore
|
||||
discount_factor,
|
||||
reward_normalization,
|
||||
**kwargs,
|
||||
)
|
||||
self.critic = critic
|
||||
self._target = target_update_freq > 0
|
||||
self._freq = target_update_freq
|
||||
self._iter = 0
|
||||
if self._target:
|
||||
self.actor_old = deepcopy(self.actor)
|
||||
self.actor_old.eval()
|
||||
self.critic_old = deepcopy(self.critic)
|
||||
self.critic_old.eval()
|
||||
else:
|
||||
self.actor_old = self.actor
|
||||
self.critic_old = self.critic
|
||||
assert policy_improvement_mode in ["exp", "binary", "all"]
|
||||
self._policy_improvement_mode = policy_improvement_mode
|
||||
self._ratio_upper_bound = ratio_upper_bound
|
||||
self._beta = beta
|
||||
self._min_q_weight = min_q_weight
|
||||
|
||||
def sync_weight(self) -> None:
|
||||
self.actor_old.load_state_dict(self.actor.state_dict()) # type: ignore
|
||||
self.critic_old.load_state_dict(self.critic.state_dict()) # type: ignore
|
||||
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # type: ignore
|
||||
if self._target and self._iter % self._freq == 0:
|
||||
self.sync_weight()
|
||||
self.optim.zero_grad()
|
||||
q_t, _ = self.critic(batch.obs)
|
||||
act = to_torch(batch.act, dtype=torch.long, device=q_t.device)
|
||||
qa_t = q_t.gather(1, act.unsqueeze(1))
|
||||
# Critic loss
|
||||
with torch.no_grad():
|
||||
target_a_t, _ = self.actor_old(batch.obs_next)
|
||||
target_m = Categorical(logits=target_a_t)
|
||||
q_t_target, _ = self.critic_old(batch.obs_next)
|
||||
rew = to_torch_as(batch.rew, q_t_target)
|
||||
expected_target_q = (q_t_target * target_m.probs).sum(-1, keepdim=True)
|
||||
expected_target_q[batch.done > 0] = 0.0
|
||||
target = rew.unsqueeze(1) + self._gamma * expected_target_q
|
||||
critic_loss = 0.5 * F.mse_loss(qa_t, target)
|
||||
# Actor loss
|
||||
a_t, _ = self.actor(batch.obs)
|
||||
m = Categorical(logits=a_t)
|
||||
expected_policy_q = (q_t * m.probs).sum(-1, keepdim=True)
|
||||
advantage = qa_t - expected_policy_q
|
||||
if self._policy_improvement_mode == "binary":
|
||||
actor_loss_coef = (advantage > 0).float()
|
||||
elif self._policy_improvement_mode == "exp":
|
||||
actor_loss_coef = (
|
||||
(advantage / self._beta).exp().clamp(0, self._ratio_upper_bound)
|
||||
)
|
||||
else:
|
||||
actor_loss_coef = 1.0 # effectively behavior cloning
|
||||
actor_loss = (-m.log_prob(act) * actor_loss_coef).mean()
|
||||
# CQL loss/regularizer
|
||||
min_q_loss = (q_t.logsumexp(1) - qa_t).mean()
|
||||
loss = actor_loss + critic_loss + self._min_q_weight * min_q_loss
|
||||
loss.backward()
|
||||
self.optim.step()
|
||||
self._iter += 1
|
||||
return {
|
||||
"loss": loss.item(),
|
||||
"loss/actor": actor_loss.item(),
|
||||
"loss/critic": critic_loss.item(),
|
||||
"loss/cql": min_q_loss.item(),
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user