Implement CQLPolicy and offline_cql example (#506)

Bernard Tan 2022-01-16 05:30:21 +08:00 committed by GitHub
parent a59d96d041
commit bc53ead273
8 changed files with 761 additions and 2 deletions


@@ -37,6 +37,7 @@
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
- Vanilla Imitation Learning
- [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
- [Conservative Q-Learning (CQL)](https://arxiv.org/pdf/2006.04779.pdf)
- [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
- [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
- [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)


@@ -114,6 +114,11 @@ Imitation
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.CQLPolicy
:members:
:undoc-members:
:show-inheritance:
.. autoclass:: tianshou.policy.DiscreteBCQPolicy
:members:
:undoc-members:


@@ -28,6 +28,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
* :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
* :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
* :class:`~tianshou.policy.CQLPolicy` `Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
* :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
* :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_


@@ -2,10 +2,12 @@
In the offline reinforcement learning setting, the agent learns a policy from a fixed dataset that was collected once by an arbitrary policy; the agent no longer interacts with the environment.
## Continous control
## Continuous control
Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agents for continuous control. Refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use the d4rl datasets.
We provide implementations of the BCQ and CQL algorithms for continuous control.
### Train
Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse a d4rl dataset into a `ReplayBuffer` and pass it as the `buffer` parameter of `offline_trainer`, as sketched below. `offline_bcq.py` is an example of offline RL using the d4rl dataset.
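For reference, here is a minimal sketch of that conversion (it mirrors what the offline CQL example added in this commit does; it assumes `d4rl` is installed and that the task name is a valid d4rl environment):
```python
import d4rl  # noqa: F401  # importing d4rl registers its gym environments
import gym

from tianshou.data import Batch, ReplayBuffer

env = gym.make("halfcheetah-medium-v1")
dataset = d4rl.qlearning_dataset(env)
dataset_size = dataset["rewards"].size

# copy every transition of the fixed dataset into a Tianshou ReplayBuffer
replay_buffer = ReplayBuffer(dataset_size)
for i in range(dataset_size):
    replay_buffer.add(
        Batch(
            obs=dataset["observations"][i],
            act=dataset["actions"][i],
            rew=dataset["rewards"][i],
            done=dataset["terminals"][i],
            obs_next=dataset["next_observations"][i],
        )
    )
# replay_buffer can now be passed as the `buffer` argument of offline_trainer
```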
@@ -20,7 +22,7 @@ After 1M steps:
![halfcheetah-expert-v1_reward](results/bcq/halfcheetah-expert-v1_reward.png)
`halfcheetah-expert-v1` is a mujoco environment. The setting of hyperparameters are similar to the offpolicy algorithms in mujoco environment.
`halfcheetah-expert-v1` is a mujoco environment. The hyperparameter settings are similar to those of the off-policy algorithms in the mujoco environment.
## Results


@@ -0,0 +1,236 @@
#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint
import d4rl
import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import CQLPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import BasicLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='halfcheetah-medium-v1')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=1000000)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
parser.add_argument('--actor-lr', type=float, default=1e-4)
parser.add_argument('--critic-lr', type=float, default=3e-4)
parser.add_argument('--alpha', type=float, default=0.2)
parser.add_argument('--auto-alpha', default=True, action='store_true')
parser.add_argument('--alpha-lr', type=float, default=1e-4)
parser.add_argument('--cql-alpha-lr', type=float, default=3e-4)
parser.add_argument("--start-timesteps", type=int, default=10000)
parser.add_argument('--epoch', type=int, default=200)
parser.add_argument('--step-per-epoch', type=int, default=5000)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument("--tau", type=float, default=0.005)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--cql-weight", type=float, default=1.0)
parser.add_argument("--with-lagrange", type=bool, default=True)
parser.add_argument("--lagrange-threshold", type=float, default=10.0)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--eval-freq", type=int, default=1)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=1 / 35)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='watch the play of pre-trained policy only',
)
return parser.parse_args()
def test_cql():
args = get_args()
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0] # float
print("device:", args.device)
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
args.state_dim = args.state_shape[0]
args.action_dim = args.action_shape[0]
print("Max_action", args.max_action)
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)]
)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
# actor network
net_a = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device,
)
actor = ActorProb(
net_a,
action_shape=args.action_shape,
max_action=args.max_action,
device=args.device,
unbounded=True,
conditioned_sigma=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
# critic network
net_c1 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
net_c2 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
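# optional automatic entropy tuning: replace the fixed alpha with the tuple
# (target_entropy, log_alpha, alpha_optim) expected by SACPolicy/CQLPolicy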
if args.auto_alpha:
target_entropy = -np.prod(env.action_space.shape)
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
args.alpha = (target_entropy, log_alpha, alpha_optim)
policy = CQLPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
cql_alpha_lr=args.cql_alpha_lr,
cql_weight=args.cql_weight,
tau=args.tau,
gamma=args.gamma,
alpha=args.alpha,
temperature=args.temperature,
with_lagrange=args.with_lagrange,
lagrange_threshold=args.lagrange_threshold,
min_action=np.min(env.action_space.low),
max_action=np.max(env.action_space.high),
device=args.device,
)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# collector
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.start_timesteps, random=True)
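# note: train_collector and this buffer are not passed to offline_trainer below;
# training samples come solely from the d4rl replay_buffer built in the non-watch branch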
# log
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql'
log_path = os.path.join(args.logdir, args.task, 'cql', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def watch():
if args.resume_path is None:
args.resume_path = os.path.join(log_path, 'policy.pth')
policy.load_state_dict(
torch.load(args.resume_path, map_location=torch.device('cpu'))
)
policy.eval()
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
if not args.watch:
dataset = d4rl.qlearning_dataset(env)
dataset_size = dataset['rewards'].size
print("dataset_size", dataset_size)
replay_buffer = ReplayBuffer(dataset_size)
for i in range(dataset_size):
replay_buffer.add(
Batch(
obs=dataset['observations'][i],
act=dataset['actions'][i],
rew=dataset['rewards'][i],
done=dataset['terminals'][i],
obs_next=dataset['next_observations'][i],
)
)
print("dataset loaded")
# trainer
result = offline_trainer(
policy,
replay_buffer,
test_collector,
args.epoch,
args.step_per_epoch,
args.test_num,
args.batch_size,
save_fn=save_fn,
logger=logger,
)
pprint.pprint(result)
else:
watch()
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
if __name__ == '__main__':
test_cql()

test/offline/test_cql.py (new file, 219 lines)

@@ -0,0 +1,219 @@
import argparse
import datetime
import os
import pickle
import pprint
import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
from tianshou.env import SubprocVectorEnv
from tianshou.policy import CQLPolicy
from tianshou.trainer import offline_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
if __name__ == "__main__":
from gather_pendulum_data import gather_data
else: # pytest
from test.offline.gather_pendulum_data import gather_data
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128])
parser.add_argument('--actor-lr', type=float, default=1e-3)
parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--alpha', type=float, default=0.2)
parser.add_argument('--auto-alpha', default=True, action='store_true')
parser.add_argument('--alpha-lr', type=float, default=1e-3)
parser.add_argument('--cql-alpha-lr', type=float, default=1e-3)
parser.add_argument("--start-timesteps", type=int, default=10000)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--step-per-epoch', type=int, default=2000)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument("--tau", type=float, default=0.005)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--cql-weight", type=float, default=1.0)
parser.add_argument("--with-lagrange", type=bool, default=True)
parser.add_argument("--lagrange-threshold", type=float, default=10.0)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--eval-freq", type=int, default=1)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=1 / 35)
parser.add_argument(
'--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu'
)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument(
'--watch',
default=False,
action='store_true',
help='watch the play of pre-trained policy only',
)
parser.add_argument(
"--load-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl"
)
args = parser.parse_known_args()[0]
return args
def test_cql(args=get_args()):
if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
buffer = pickle.load(open(args.load_buffer_name, "rb"))
else:
buffer = gather_data()
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0] # float
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -1200 # too low?
args.state_dim = args.state_shape[0]
args.action_dim = args.action_shape[0]
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)]
)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
# actor network
net_a = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
device=args.device,
)
actor = ActorProb(
net_a,
action_shape=args.action_shape,
max_action=args.max_action,
device=args.device,
unbounded=True,
conditioned_sigma=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
# critic network
net_c1 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
net_c2 = Net(
args.state_shape,
args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
critic1 = Critic(net_c1, device=args.device).to(args.device)
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
if args.auto_alpha:
target_entropy = -np.prod(env.action_space.shape)
log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
args.alpha = (target_entropy, log_alpha, alpha_optim)
policy = CQLPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
cql_alpha_lr=args.cql_alpha_lr,
cql_weight=args.cql_weight,
tau=args.tau,
gamma=args.gamma,
alpha=args.alpha,
temperature=args.temperature,
with_lagrange=args.with_lagrange,
lagrange_threshold=args.lagrange_threshold,
min_action=np.min(env.action_space.low),
max_action=np.max(env.action_space.high),
device=args.device,
)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# collector
# buffer has been gathered
# train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
# log
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql'
log_path = os.path.join(args.logdir, args.task, 'cql', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def watch():
policy.load_state_dict(
torch.load(
os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu')
)
)
policy.eval()
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
# trainer
result = offline_trainer(
policy,
buffer,
test_collector,
args.epoch,
args.step_per_epoch,
args.test_num,
args.batch_size,
save_fn=save_fn,
stop_fn=stop_fn,
logger=logger,
)
assert stop_fn(result['best_reward'])
# Let's watch its performance!
if __name__ == '__main__':
pprint.pprint(result)
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
if __name__ == '__main__':
test_cql()


@@ -20,6 +20,7 @@ from tianshou.policy.modelfree.sac import SACPolicy
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
from tianshou.policy.imitation.base import ImitationPolicy
from tianshou.policy.imitation.bcq import BCQPolicy
from tianshou.policy.imitation.cql import CQLPolicy
from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy
from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy
from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy
@@ -47,6 +48,7 @@ __all__ = [
"DiscreteSACPolicy",
"ImitationPolicy",
"BCQPolicy",
"CQLPolicy",
"DiscreteBCQPolicy",
"DiscreteCQLPolicy",
"DiscreteCRRPolicy",


@@ -0,0 +1,293 @@
from typing import Any, Dict, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.policy import SACPolicy
from tianshou.utils.net.continuous import ActorProb
class CQLPolicy(SACPolicy):
"""Implementation of CQL algorithm. arXiv:2006.04779.
:param ActorProb actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> a)
:param torch.optim.Optimizer actor_optim: the optimizer for actor network.
:param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic1_optim: the optimizer for the first
critic network.
:param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a))
:param torch.optim.Optimizer critic2_optim: the optimizer for the second
critic network.
:param float cql_alpha_lr: the learning rate of cql_log_alpha. Default to 1e-4.
:param float cql_weight: the weight of the conservative (CQL) loss term. Default to 1.0.
:param float tau: param for soft update of the target network.
Default to 0.005.
:param float gamma: discount factor, in [0, 1]. Default to 0.99.
:param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy
regularization coefficient. Default to 0.2.
If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then
alpha is automatically tuned.
:param float temperature: the temperature used in the logsumexp term. Default to 1.0.
:param bool with_lagrange: whether to use CQL(Lagrange), which tunes cql_alpha automatically. Default to True.
:param float lagrange_threshold: the value of tau in CQL(Lagrange).
Default to 10.0.
:param float min_action: The minimum value of each dimension of action.
Default to -1.0.
:param float max_action: The maximum value of each dimension of action.
Default to 1.0.
:param int num_repeat_actions: The number of times the action is repeated
when calculating log-sum-exp. Default to 10.
:param float alpha_min: lower bound for clipping cql_alpha. Default to 0.0.
:param float alpha_max: upper bound for clipping cql_alpha. Default to 1e6.
:param float clip_grad: clip_grad for updating critic network. Default to 1.0.
:param Union[str, torch.device] device: which device to create this model on.
Default to "cpu".
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(
self,
actor: ActorProb,
actor_optim: torch.optim.Optimizer,
critic1: torch.nn.Module,
critic1_optim: torch.optim.Optimizer,
critic2: torch.nn.Module,
critic2_optim: torch.optim.Optimizer,
cql_alpha_lr: float = 1e-4,
cql_weight: float = 1.0,
tau: float = 0.005,
gamma: float = 0.99,
alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
temperature: float = 1.0,
with_lagrange: bool = True,
lagrange_threshold: float = 10.0,
min_action: float = -1.0,
max_action: float = 1.0,
num_repeat_actions: int = 10,
alpha_min: float = 0.0,
alpha_max: float = 1e6,
clip_grad: float = 1.0,
device: Union[str, torch.device] = "cpu",
**kwargs: Any
) -> None:
super().__init__(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau,
gamma, alpha, **kwargs
)
# There are _target_entropy, _log_alpha, _alpha_optim in SACPolicy.
self.device = device
self.temperature = temperature
self.with_lagrange = with_lagrange
self.lagrange_threshold = lagrange_threshold
self.cql_weight = cql_weight
# create cql_log_alpha directly on the target device so it stays a leaf tensor
# and the optimizer updates the same tensor that learn() reads
self.cql_log_alpha = torch.tensor([0.0], requires_grad=True, device=device)
self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr)
self.min_action = min_action
self.max_action = max_action
self.num_repeat_actions = num_repeat_actions
self.alpha_min = alpha_min
self.alpha_max = alpha_max
self.clip_grad = clip_grad
def train(self, mode: bool = True) -> "CQLPolicy":
"""Set the module in training mode, except for the target network."""
self.training = mode
self.actor.train(mode)
self.critic1.train(mode)
self.critic2.train(mode)
return self
def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""
for net, net_old in [
[self.critic1, self.critic1_old], [self.critic2, self.critic2_old]
]:
for param, target_param in zip(net.parameters(), net_old.parameters()):
target_param.data.copy_(
self._tau * param.data + (1 - self._tau) * target_param.data
)
def actor_pred(self, obs: torch.Tensor) -> \
Tuple[torch.Tensor, torch.Tensor]:
batch = Batch(obs=obs, info=None)
obs_result = self(batch)
return obs_result.act, obs_result.log_prob
def calc_actor_loss(self, obs: torch.Tensor) -> \
Tuple[torch.Tensor, torch.Tensor]:
act_pred, log_pi = self.actor_pred(obs)
q1 = self.critic1(obs, act_pred)
q2 = self.critic2(obs, act_pred)
min_Q = torch.min(q1, q2)
self._alpha: Union[float, torch.Tensor]
actor_loss = (self._alpha * log_pi - min_Q).mean()
# actor_loss.shape: (), log_pi.shape: (batch_size, 1)
return actor_loss, log_pi
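# Q-values of actions sampled from the current policy, with the (detached)
# log-probabilities subtracted as importance weights for the CQL logsumexp term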
def calc_pi_values(self, obs_pi: torch.Tensor, obs_to_pred: torch.Tensor) -> \
Tuple[torch.Tensor, torch.Tensor]:
act_pred, log_pi = self.actor_pred(obs_pi)
q1 = self.critic1(obs_to_pred, act_pred)
q2 = self.critic2(obs_to_pred, act_pred)
return q1 - log_pi.detach(), q2 - log_pi.detach()
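# Q-values of uniformly sampled actions; the density of U([-1, 1]^d) is (1/2)^d,
# so its log-probability is d * log(0.5), subtracted as the importance weight
# (this assumes an action range of width 2 per dimension)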
def calc_random_values(self, obs: torch.Tensor, act: torch.Tensor) -> \
Tuple[torch.Tensor, torch.Tensor]:
random_value1 = self.critic1(obs, act)
random_log_prob1 = np.log(0.5**act.shape[-1])
random_value2 = self.critic2(obs, act)
random_log_prob2 = np.log(0.5**act.shape[-1])
return random_value1 - random_log_prob1, random_value2 - random_log_prob2
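# override SACPolicy.process_fn: no n-step return pre-processing is needed here,
# the one-step target is computed inside learn()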
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
) -> Batch:
return batch
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
batch: Batch = to_torch( # type: ignore
batch, dtype=torch.float, device=self.device,
)
obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next
batch_size = obs.shape[0]
# compute actor loss and update actor
actor_loss, log_pi = self.calc_actor_loss(obs)
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()
# compute alpha loss
if self._is_auto_alpha:
log_pi = log_pi + self._target_entropy
alpha_loss = -(self._log_alpha * log_pi.detach()).mean()
self._alpha_optim.zero_grad()
# update log_alpha
alpha_loss.backward()
self._alpha_optim.step()
# update alpha
self._alpha = self._log_alpha.detach().exp()
# compute target_Q
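# entropy-regularized Bellman target:
# y = r + gamma * (1 - done) * (min(Q1', Q2')(s', a') - alpha * log pi(a'|s')),
# with a' sampled from the current policy at s'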
with torch.no_grad():
act_next, new_log_pi = self.actor_pred(obs_next)
target_Q1 = self.critic1_old(obs_next, act_next)
target_Q2 = self.critic2_old(obs_next, act_next)
target_Q = torch.min(target_Q1, target_Q2) - self._alpha * new_log_pi
target_Q = \
rew + self._gamma * (1 - batch.done) * target_Q.flatten()
# shape: (batch_size)
# compute critic loss
current_Q1 = self.critic1(obs, act).flatten()
current_Q2 = self.critic2(obs, act).flatten()
# shape: (batch_size)
critic1_loss = F.mse_loss(current_Q1, target_Q)
critic2_loss = F.mse_loss(current_Q2, target_Q)
# CQL
# actions sampled uniformly from [min_action, max_action] in each dimension
random_actions = torch.FloatTensor(
batch_size * self.num_repeat_actions, act.shape[-1]
).uniform_(self.min_action, self.max_action).to(self.device)
tmp_obs = obs.unsqueeze(1) \
.repeat(1, self.num_repeat_actions, 1) \
.view(batch_size * self.num_repeat_actions, obs.shape[-1])
tmp_obs_next = obs_next.unsqueeze(1) \
.repeat(1, self.num_repeat_actions, 1) \
.view(batch_size * self.num_repeat_actions, obs.shape[-1])
# tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim)
current_pi_value1, current_pi_value2 = self.calc_pi_values(tmp_obs, tmp_obs)
next_pi_value1, next_pi_value2 = self.calc_pi_values(tmp_obs_next, tmp_obs)
random_value1, random_value2 = self.calc_random_values(tmp_obs, random_actions)
# reshape to (batch_size, num_repeat, 1) so the logsumexp below runs over the
# sampled actions of each state
current_pi_value1 = current_pi_value1.reshape(batch_size, self.num_repeat_actions, 1)
current_pi_value2 = current_pi_value2.reshape(batch_size, self.num_repeat_actions, 1)
next_pi_value1 = next_pi_value1.reshape(batch_size, self.num_repeat_actions, 1)
next_pi_value2 = next_pi_value2.reshape(batch_size, self.num_repeat_actions, 1)
random_value1 = random_value1.reshape(batch_size, self.num_repeat_actions, 1)
random_value2 = random_value2.reshape(batch_size, self.num_repeat_actions, 1)
# cat q values
cat_q1 = torch.cat([random_value1, current_pi_value1, next_pi_value1], 1)
cat_q2 = torch.cat([random_value2, current_pi_value2, next_pi_value2], 1)
# shape: (batch_size, 3 * num_repeat, 1)
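# conservative regularizer: push down a logsumexp over Q-values of sampled actions
# (uniform + current-policy + next-state-policy, importance-weighted above) and
# push up Q-values of dataset actions; cf. arXiv:2006.04779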
cql1_scaled_loss = \
torch.logsumexp(cat_q1 / self.temperature, dim=1).mean() * \
self.cql_weight * self.temperature - current_Q1.mean() * \
self.cql_weight
cql2_scaled_loss = \
torch.logsumexp(cat_q2 / self.temperature, dim=1).mean() * \
self.cql_weight * self.temperature - current_Q2.mean() * \
self.cql_weight
# shape: (1)
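# CQL(Lagrange): tune cql_alpha automatically so that the conservative gap stays
# close to lagrange_threshold (tau in the paper); cql_alpha then scales the
# penalty added to the critic losses below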
if self.with_lagrange:
cql_alpha = torch.clamp(
self.cql_log_alpha.exp(),
self.alpha_min,
self.alpha_max,
)
cql1_scaled_loss = \
cql_alpha * (cql1_scaled_loss - self.lagrange_threshold)
cql2_scaled_loss = \
cql_alpha * (cql2_scaled_loss - self.lagrange_threshold)
self.cql_alpha_optim.zero_grad()
cql_alpha_loss = -(cql1_scaled_loss + cql2_scaled_loss) * 0.5
cql_alpha_loss.backward(retain_graph=True)
self.cql_alpha_optim.step()
critic1_loss = critic1_loss + cql1_scaled_loss
critic2_loss = critic2_loss + cql2_scaled_loss
# update critic
self.critic1_optim.zero_grad()
critic1_loss.backward(retain_graph=True)
# clip gradients to prevent them from exploding
# (it does not seem strictly necessary here)
clip_grad_norm_(self.critic1.parameters(), self.clip_grad)
self.critic1_optim.step()
self.critic2_optim.zero_grad()
critic2_loss.backward()
clip_grad_norm_(self.critic2.parameters(), self.clip_grad)
self.critic2_optim.step()
self.sync_weight()
result = {
"loss/actor": actor_loss.item(),
"loss/critic1": critic1_loss.item(),
"loss/critic2": critic2_loss.item(),
}
if self._is_auto_alpha:
result["loss/alpha"] = alpha_loss.item()
result["alpha"] = self._alpha.item() # type: ignore
if self.with_lagrange:
result["loss/cql_alpha"] = cql_alpha_loss.item()
result["cql_alpha"] = cql_alpha.item()
return result