diff --git a/README.md b/README.md
index 3459478..a8fb2f0 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,7 @@
 - [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
 - Vanilla Imitation Learning
 - [Batch-Constrained deep Q-Learning (BCQ)](https://arxiv.org/pdf/1812.02900.pdf)
+- [Conservative Q-Learning (CQL)](https://arxiv.org/pdf/2006.04779.pdf)
 - [Discrete Batch-Constrained deep Q-Learning (BCQ-Discrete)](https://arxiv.org/pdf/1910.01708.pdf)
 - [Discrete Conservative Q-Learning (CQL-Discrete)](https://arxiv.org/pdf/2006.04779.pdf)
 - [Discrete Critic Regularized Regression (CRR-Discrete)](https://arxiv.org/pdf/2006.15134.pdf)
diff --git a/docs/api/tianshou.policy.rst b/docs/api/tianshou.policy.rst
index b44598a..c306366 100644
--- a/docs/api/tianshou.policy.rst
+++ b/docs/api/tianshou.policy.rst
@@ -114,6 +114,11 @@ Imitation
     :undoc-members:
     :show-inheritance:
 
+.. autoclass:: tianshou.policy.CQLPolicy
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 .. autoclass:: tianshou.policy.DiscreteBCQPolicy
     :members:
     :undoc-members:
     :show-inheritance:
diff --git a/docs/index.rst b/docs/index.rst
index 7341ec0..13adaed 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -28,6 +28,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
 * :class:`~tianshou.policy.ImitationPolicy` Imitation Learning
 * :class:`~tianshou.policy.BCQPolicy` `Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1812.02900.pdf>`_
+* :class:`~tianshou.policy.CQLPolicy` `Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
 * :class:`~tianshou.policy.DiscreteBCQPolicy` `Discrete Batch-Constrained deep Q-Learning <https://arxiv.org/pdf/1910.01708.pdf>`_
 * :class:`~tianshou.policy.DiscreteCQLPolicy` `Discrete Conservative Q-Learning <https://arxiv.org/pdf/2006.04779.pdf>`_
 * :class:`~tianshou.policy.DiscreteCRRPolicy` `Critic Regularized Regression <https://arxiv.org/pdf/2006.15134.pdf>`_
diff --git a/examples/offline/README.md b/examples/offline/README.md
index c0a07fa..1ac98ec 100644
--- a/examples/offline/README.md
+++ b/examples/offline/README.md
@@ -2,10 +2,12 @@
 
 In offline reinforcement learning setting, the agent learns a policy from a fixed dataset which is collected once with any policy. And the agent does not interact with environment anymore.
 
-## Continous control
+## Continuous control
 
 Once the dataset is collected, it will not be changed during training. We use [d4rl](https://github.com/rail-berkeley/d4rl) datasets to train offline agent for continuous control. You can refer to [d4rl](https://github.com/rail-berkeley/d4rl) to see how to use d4rl datasets.
 
+We provide implementations of the BCQ and CQL algorithms for continuous control.
+
 ### Train
 
 Tianshou provides an `offline_trainer` for offline reinforcement learning. You can parse d4rl datasets into a `ReplayBuffer` , and set it as the parameter `buffer` of `offline_trainer`. `offline_bcq.py` is an example of offline RL using the d4rl dataset.
@@ -20,7 +22,7 @@ After 1M steps:
 
 ![halfcheetah-expert-v1_reward](results/bcq/halfcheetah-expert-v1_reward.png)
 
-`halfcheetah-expert-v1` is a mujoco environment. The setting of hyperparameters are similar to the offpolicy algorithms in mujoco environment.
+`halfcheetah-expert-v1` is a mujoco environment. The hyperparameter settings are similar to those of the off-policy algorithms in the mujoco environments.
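+
+As a reference, parsing a d4rl dataset into a Tianshou `ReplayBuffer` (as described in the Train section above, and as done in `offline_cql.py` in this directory) can look roughly like the following minimal sketch; the environment name is only an example:
+
+```python
+import d4rl
+import gym
+
+from tianshou.data import Batch, ReplayBuffer
+
+env = gym.make("halfcheetah-medium-v1")
+dataset = d4rl.qlearning_dataset(env)
+buffer = ReplayBuffer(dataset["rewards"].size)
+for i in range(dataset["rewards"].size):
+    buffer.add(
+        Batch(
+            obs=dataset["observations"][i],
+            act=dataset["actions"][i],
+            rew=dataset["rewards"][i],
+            done=dataset["terminals"][i],
+            obs_next=dataset["next_observations"][i],
+        )
+    )
+# `buffer` can now be passed as the `buffer` argument of `offline_trainer`.
+```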
## Results diff --git a/examples/offline/offline_cql.py b/examples/offline/offline_cql.py new file mode 100644 index 0000000..f494200 --- /dev/null +++ b/examples/offline/offline_cql.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +import argparse +import datetime +import os +import pprint + +import d4rl +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Batch, Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.env import SubprocVectorEnv +from tianshou.policy import CQLPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import BasicLogger +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorProb, Critic + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='halfcheetah-medium-v1') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--buffer-size', type=int, default=1000000) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256]) + parser.add_argument('--actor-lr', type=float, default=1e-4) + parser.add_argument('--critic-lr', type=float, default=3e-4) + parser.add_argument('--alpha', type=float, default=0.2) + parser.add_argument('--auto-alpha', default=True, action='store_true') + parser.add_argument('--alpha-lr', type=float, default=1e-4) + parser.add_argument('--cql-alpha-lr', type=float, default=3e-4) + parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument('--epoch', type=int, default=200) + parser.add_argument('--step-per-epoch', type=int, default=5000) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--batch-size', type=int, default=256) + + parser.add_argument("--tau", type=float, default=0.005) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--cql-weight", type=float, default=1.0) + parser.add_argument("--with-lagrange", type=bool, default=True) + parser.add_argument("--lagrange-threshold", type=float, default=10.0) + parser.add_argument("--gamma", type=float, default=0.99) + + parser.add_argument("--eval-freq", type=int, default=1) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=1 / 35) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only', + ) + return parser.parse_args() + + +def test_cql(): + args = get_args() + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] # float + print("device:", args.device) + print("Observations shape:", args.state_shape) + print("Actions shape:", args.action_shape) + print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + + args.state_dim = args.state_shape[0] + args.action_dim = args.action_shape[0] + print("Max_action", args.max_action) + + # train_envs = gym.make(args.task) + train_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.training_num)] + ) + # test_envs = 
gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + train_envs.seed(args.seed) + test_envs.seed(args.seed) + + # model + # actor network + net_a = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + ) + actor = ActorProb( + net_a, + action_shape=args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True, + conditioned_sigma=True + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + + # critic network + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + if args.auto_alpha: + target_entropy = -np.prod(env.action_space.shape) + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + args.alpha = (target_entropy, log_alpha, alpha_optim) + + policy = CQLPolicy( + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + cql_alpha_lr=args.cql_alpha_lr, + cql_weight=args.cql_weight, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + temperature=args.temperature, + with_lagrange=args.with_lagrange, + lagrange_threshold=args.lagrange_threshold, + min_action=np.min(env.action_space.low), + max_action=np.max(env.action_space.high), + device=args.device, + ) + + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + if args.training_num > 1: + buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) + else: + buffer = ReplayBuffer(args.buffer_size) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + train_collector.collect(n_step=args.start_timesteps, random=True) + # log + t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") + log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql' + log_path = os.path.join(args.logdir, args.task, 'cql', log_file) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = BasicLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def watch(): + if args.resume_path is None: + args.resume_path = os.path.join(log_path, 'policy.pth') + + policy.load_state_dict( + torch.load(args.resume_path, map_location=torch.device('cpu')) + ) + policy.eval() + collector = Collector(policy, env) + collector.collect(n_episode=1, render=1 / 35) + + if not args.watch: + dataset = d4rl.qlearning_dataset(env) + dataset_size = dataset['rewards'].size + + print("dataset_size", dataset_size) + replay_buffer = ReplayBuffer(dataset_size) + + for i in range(dataset_size): + replay_buffer.add( + Batch( + obs=dataset['observations'][i], + act=dataset['actions'][i], + rew=dataset['rewards'][i], + done=dataset['terminals'][i], + 
obs_next=dataset['next_observations'][i], + ) + ) + print("dataset loaded") + # trainer + result = offline_trainer( + policy, + replay_buffer, + test_collector, + args.epoch, + args.step_per_epoch, + args.test_num, + args.batch_size, + save_fn=save_fn, + logger=logger, + ) + pprint.pprint(result) + else: + watch() + + # Let's watch its performance! + policy.eval() + test_envs.seed(args.seed) + test_collector.reset() + result = test_collector.collect(n_episode=args.test_num, render=args.render) + print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') + + +if __name__ == '__main__': + test_cql() diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py new file mode 100644 index 0000000..ae1507a --- /dev/null +++ b/test/offline/test_cql.py @@ -0,0 +1,219 @@ +import argparse +import datetime +import os +import pickle +import pprint + +import gym +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter + +from tianshou.data import Collector +from tianshou.env import SubprocVectorEnv +from tianshou.policy import CQLPolicy +from tianshou.trainer import offline_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net +from tianshou.utils.net.continuous import ActorProb, Critic + +if __name__ == "__main__": + from gather_pendulum_data import gather_data +else: # pytest + from test.offline.gather_pendulum_data import gather_data + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, default='Pendulum-v0') + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[128, 128]) + parser.add_argument('--actor-lr', type=float, default=1e-3) + parser.add_argument('--critic-lr', type=float, default=1e-3) + parser.add_argument('--alpha', type=float, default=0.2) + parser.add_argument('--auto-alpha', default=True, action='store_true') + parser.add_argument('--alpha-lr', type=float, default=1e-3) + parser.add_argument('--cql-alpha-lr', type=float, default=1e-3) + parser.add_argument("--start-timesteps", type=int, default=10000) + parser.add_argument('--epoch', type=int, default=20) + parser.add_argument('--step-per-epoch', type=int, default=2000) + parser.add_argument('--n-step', type=int, default=3) + parser.add_argument('--batch-size', type=int, default=256) + + parser.add_argument("--tau", type=float, default=0.005) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--cql-weight", type=float, default=1.0) + parser.add_argument("--with-lagrange", type=bool, default=True) + parser.add_argument("--lagrange-threshold", type=float, default=10.0) + parser.add_argument("--gamma", type=float, default=0.99) + + parser.add_argument("--eval-freq", type=int, default=1) + parser.add_argument('--training-num', type=int, default=10) + parser.add_argument('--test-num', type=int, default=10) + parser.add_argument('--logdir', type=str, default='log') + parser.add_argument('--render', type=float, default=1 / 35) + parser.add_argument( + '--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu' + ) + parser.add_argument('--resume-path', type=str, default=None) + parser.add_argument( + '--watch', + default=False, + action='store_true', + help='watch the play of pre-trained policy only', + ) + parser.add_argument( + "--load-buffer-name", type=str, default="./expert_SAC_Pendulum-v0.pkl" + ) + args = parser.parse_known_args()[0] + return args + + +def 
test_cql(args=get_args()): + if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name): + buffer = pickle.load(open(args.load_buffer_name, "rb")) + else: + buffer = gather_data() + env = gym.make(args.task) + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + args.max_action = env.action_space.high[0] # float + if args.task == 'Pendulum-v0': + env.spec.reward_threshold = -1200 # too low? + + args.state_dim = args.state_shape[0] + args.action_dim = args.action_shape[0] + # test_envs = gym.make(args.task) + test_envs = SubprocVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + test_envs.seed(args.seed) + + # model + # actor network + net_a = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + device=args.device, + ) + actor = ActorProb( + net_a, + action_shape=args.action_shape, + max_action=args.max_action, + device=args.device, + unbounded=True, + conditioned_sigma=True + ).to(args.device) + actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) + + # critic network + net_c1 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + net_c2 = Net( + args.state_shape, + args.action_shape, + hidden_sizes=args.hidden_sizes, + concat=True, + device=args.device, + ) + critic1 = Critic(net_c1, device=args.device).to(args.device) + critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) + critic2 = Critic(net_c2, device=args.device).to(args.device) + critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr) + + if args.auto_alpha: + target_entropy = -np.prod(env.action_space.shape) + log_alpha = torch.zeros(1, requires_grad=True, device=args.device) + alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) + args.alpha = (target_entropy, log_alpha, alpha_optim) + + policy = CQLPolicy( + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + cql_alpha_lr=args.cql_alpha_lr, + cql_weight=args.cql_weight, + tau=args.tau, + gamma=args.gamma, + alpha=args.alpha, + temperature=args.temperature, + with_lagrange=args.with_lagrange, + lagrange_threshold=args.lagrange_threshold, + min_action=np.min(env.action_space.low), + max_action=np.max(env.action_space.high), + device=args.device, + ) + + # load a previous policy + if args.resume_path: + policy.load_state_dict(torch.load(args.resume_path, map_location=args.device)) + print("Loaded agent from: ", args.resume_path) + + # collector + # buffer has been gathered + # train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, test_envs) + # log + t0 = datetime.datetime.now().strftime("%m%d_%H%M%S") + log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_cql' + log_path = os.path.join(args.logdir, args.task, 'cql', log_file) + writer = SummaryWriter(log_path) + writer.add_text("args", str(args)) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def watch(): + policy.load_state_dict( + torch.load( + os.path.join(log_path, 'policy.pth'), map_location=torch.device('cpu') + ) + ) + policy.eval() + collector = Collector(policy, env) + 
collector.collect(n_episode=1, render=1 / 35) + + # trainer + result = offline_trainer( + policy, + buffer, + test_collector, + args.epoch, + args.step_per_epoch, + args.test_num, + args.batch_size, + save_fn=save_fn, + stop_fn=stop_fn, + logger=logger, + ) + assert stop_fn(result['best_reward']) + + # Let's watch its performance! + if __name__ == '__main__': + pprint.pprint(result) + env = gym.make(args.task) + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=args.render) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + + +if __name__ == '__main__': + test_cql() diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index f8c8441..ced11af 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -20,6 +20,7 @@ from tianshou.policy.modelfree.sac import SACPolicy from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy from tianshou.policy.imitation.base import ImitationPolicy from tianshou.policy.imitation.bcq import BCQPolicy +from tianshou.policy.imitation.cql import CQLPolicy from tianshou.policy.imitation.discrete_bcq import DiscreteBCQPolicy from tianshou.policy.imitation.discrete_cql import DiscreteCQLPolicy from tianshou.policy.imitation.discrete_crr import DiscreteCRRPolicy @@ -47,6 +48,7 @@ __all__ = [ "DiscreteSACPolicy", "ImitationPolicy", "BCQPolicy", + "CQLPolicy", "DiscreteBCQPolicy", "DiscreteCQLPolicy", "DiscreteCRRPolicy", diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py new file mode 100644 index 0000000..c2df77c --- /dev/null +++ b/tianshou/policy/imitation/cql.py @@ -0,0 +1,293 @@ +from typing import Any, Dict, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.utils import clip_grad_norm_ + +from tianshou.data import Batch, ReplayBuffer, to_torch +from tianshou.policy import SACPolicy +from tianshou.utils.net.continuous import ActorProb + + +class CQLPolicy(SACPolicy): + """Implementation of CQL algorithm. arXiv:2006.04779. + + :param ActorProb actor: the actor network following the rules in + :class:`~tianshou.policy.BasePolicy`. (s -> a) + :param torch.optim.Optimizer actor_optim: the optimizer for actor network. + :param torch.nn.Module critic1: the first critic network. (s, a -> Q(s, a)) + :param torch.optim.Optimizer critic1_optim: the optimizer for the first + critic network. + :param torch.nn.Module critic2: the second critic network. (s, a -> Q(s, a)) + :param torch.optim.Optimizer critic2_optim: the optimizer for the second + critic network. + :param float cql_alpha_lr: the learning rate of cql_log_alpha. Default to 1e-4. + :param float cql_weight: the value of alpha. Default to 1.0. + :param float tau: param for soft update of the target network. + Default to 0.005. + :param float gamma: discount factor, in [0, 1]. Default to 0.99. + :param (float, torch.Tensor, torch.optim.Optimizer) or float alpha: entropy + regularization coefficient. Default to 0.2. + If a tuple (target_entropy, log_alpha, alpha_optim) is provided, then + alpha is automatically tuned. + :param float temperature: the value of temperature. Default to 1.0. + :param bool with_lagrange: whether to use Lagrange. Default to True. + :param float lagrange_threshold: the value of tau in CQL(Lagrange). + Default to 10.0. + :param float min_action: The minimum value of each dimension of action. + Default to -1.0. 
+    :param float max_action: The maximum value of each dimension of action.
+        Default to 1.0.
+    :param int num_repeat_actions: the number of sampled actions per state used
+        to estimate the log-sum-exp term. Default to 10.
+    :param float alpha_min: lower bound for clipping cql_alpha. Default to 0.0.
+    :param float alpha_max: upper bound for clipping cql_alpha. Default to 1e6.
+    :param float clip_grad: norm threshold for gradient clipping when updating
+        the critic networks. Default to 1.0.
+    :param Union[str, torch.device] device: which device to create this model on.
+        Default to "cpu".
+
+    .. seealso::
+
+        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
+        explanation.
+    """
+
+    def __init__(
+        self,
+        actor: ActorProb,
+        actor_optim: torch.optim.Optimizer,
+        critic1: torch.nn.Module,
+        critic1_optim: torch.optim.Optimizer,
+        critic2: torch.nn.Module,
+        critic2_optim: torch.optim.Optimizer,
+        cql_alpha_lr: float = 1e-4,
+        cql_weight: float = 1.0,
+        tau: float = 0.005,
+        gamma: float = 0.99,
+        alpha: Union[float, Tuple[float, torch.Tensor, torch.optim.Optimizer]] = 0.2,
+        temperature: float = 1.0,
+        with_lagrange: bool = True,
+        lagrange_threshold: float = 10.0,
+        min_action: float = -1.0,
+        max_action: float = 1.0,
+        num_repeat_actions: int = 10,
+        alpha_min: float = 0.0,
+        alpha_max: float = 1e6,
+        clip_grad: float = 1.0,
+        device: Union[str, torch.device] = "cpu",
+        **kwargs: Any
+    ) -> None:
+        super().__init__(
+            actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, tau,
+            gamma, alpha, **kwargs
+        )
+        # There are _target_entropy, _log_alpha, _alpha_optim in SACPolicy.
+        self.device = device
+        self.temperature = temperature
+        self.with_lagrange = with_lagrange
+        self.lagrange_threshold = lagrange_threshold
+
+        self.cql_weight = cql_weight
+
+        # Create the Lagrange multiplier directly on the target device so that
+        # the leaf tensor updated by cql_alpha_optim is the same tensor read in
+        # learn(); creating it on CPU and moving it afterwards would leave the
+        # optimizer pointing at a stale copy.
+        self.cql_log_alpha = torch.zeros(1, requires_grad=True, device=device)
+        self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr)
+
+        self.min_action = min_action
+        self.max_action = max_action
+
+        self.num_repeat_actions = num_repeat_actions
+
+        self.alpha_min = alpha_min
+        self.alpha_max = alpha_max
+        self.clip_grad = clip_grad
+
+    def train(self, mode: bool = True) -> "CQLPolicy":
+        """Set the module in training mode, except for the target network."""
+        self.training = mode
+        self.actor.train(mode)
+        self.critic1.train(mode)
+        self.critic2.train(mode)
+        return self
+
+    def sync_weight(self) -> None:
+        """Soft-update the weight for the target network."""
+        for net, net_old in [
+            [self.critic1, self.critic1_old], [self.critic2, self.critic2_old]
+        ]:
+            for param, target_param in zip(net.parameters(), net_old.parameters()):
+                target_param.data.copy_(
+                    self._tau * param.data + (1 - self._tau) * target_param.data
+                )
+
+    def actor_pred(self, obs: torch.Tensor) -> \
+            Tuple[torch.Tensor, torch.Tensor]:
+        batch = Batch(obs=obs, info=None)
+        obs_result = self(batch)
+        return obs_result.act, obs_result.log_prob
+
+    def calc_actor_loss(self, obs: torch.Tensor) -> \
+            Tuple[torch.Tensor, torch.Tensor]:
+        act_pred, log_pi = self.actor_pred(obs)
+        q1 = self.critic1(obs, act_pred)
+        q2 = self.critic2(obs, act_pred)
+        min_Q = torch.min(q1, q2)
+        self._alpha: Union[float, torch.Tensor]
+        actor_loss = (self._alpha * log_pi - min_Q).mean()
+        # actor_loss.shape: (), log_pi.shape: (batch_size, 1)
+        return actor_loss, log_pi
+
+    def calc_pi_values(self, obs_pi: torch.Tensor, obs_to_pred: torch.Tensor) -> \
+            Tuple[torch.Tensor, torch.Tensor]:
+        act_pred, log_pi = self.actor_pred(obs_pi)
+
+        q1 = self.critic1(obs_to_pred, act_pred)
+        q2 = self.critic2(obs_to_pred, act_pred)
+
+        return q1 - log_pi.detach(), q2 - log_pi.detach()
+
+    def calc_random_values(self, obs: torch.Tensor, act: torch.Tensor) -> \
+            Tuple[torch.Tensor, torch.Tensor]:
+        random_value1 = self.critic1(obs, act)
+        # log-density of a uniform distribution over [-1, 1]^action_dim
+        random_log_prob1 = np.log(0.5**act.shape[-1])
+
+        random_value2 = self.critic2(obs, act)
+        random_log_prob2 = np.log(0.5**act.shape[-1])
+
+        return random_value1 - random_log_prob1, random_value2 - random_log_prob2
+
+    def process_fn(
+        self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray
+    ) -> Batch:
+        return batch
+
+    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+        batch: Batch = to_torch(  # type: ignore
+            batch, dtype=torch.float, device=self.device,
+        )
+        obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next
+        batch_size = obs.shape[0]
+
+        # compute actor loss and update actor
+        actor_loss, log_pi = self.calc_actor_loss(obs)
+        self.actor_optim.zero_grad()
+        actor_loss.backward()
+        self.actor_optim.step()
+
+        # compute alpha loss
+        if self._is_auto_alpha:
+            log_pi = log_pi + self._target_entropy
+            alpha_loss = -(self._log_alpha * log_pi.detach()).mean()
+            self._alpha_optim.zero_grad()
+            # update log_alpha
+            alpha_loss.backward()
+            self._alpha_optim.step()
+            # update alpha
+            self._alpha = self._log_alpha.detach().exp()
+
+        # compute target_Q
+        with torch.no_grad():
+            act_next, new_log_pi = self.actor_pred(obs_next)
+
+            target_Q1 = self.critic1_old(obs_next, act_next)
+            target_Q2 = self.critic2_old(obs_next, act_next)
+
+            target_Q = torch.min(target_Q1, target_Q2) - self._alpha * new_log_pi
+
+            target_Q = \
+                rew + self._gamma * (1 - batch.done) * target_Q.flatten()
+            # shape: (batch_size)
+
+        # compute critic loss
+        current_Q1 = self.critic1(obs, act).flatten()
+        current_Q2 = self.critic2(obs, act).flatten()
+        # shape: (batch_size)
+
+        critic1_loss = F.mse_loss(current_Q1, target_Q)
+        critic2_loss = F.mse_loss(current_Q2, target_Q)
+
+        # CQL
+        # sample actions uniformly from the action space [min_action, max_action]
+        random_actions = torch.FloatTensor(
+            batch_size * self.num_repeat_actions, act.shape[-1]
+        ).uniform_(self.min_action, self.max_action).to(self.device)
+        tmp_obs = obs.unsqueeze(1) \
+            .repeat(1, self.num_repeat_actions, 1) \
+            .view(batch_size * self.num_repeat_actions, obs.shape[-1])
+        tmp_obs_next = obs_next.unsqueeze(1) \
+            .repeat(1, self.num_repeat_actions, 1) \
+            .view(batch_size * self.num_repeat_actions, obs.shape[-1])
+        # tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim)
+
+        current_pi_value1, current_pi_value2 = self.calc_pi_values(tmp_obs, tmp_obs)
+        # actions sampled from the next observations are evaluated on the
+        # current observations, following the reference CQL implementation
+        next_pi_value1, next_pi_value2 = self.calc_pi_values(tmp_obs_next, tmp_obs)
+
+        random_value1, random_value2 = self.calc_random_values(tmp_obs, random_actions)
+
+        # reshape to (batch_size, num_repeat, 1); torch.Tensor.reshape is not
+        # in-place, so the results must be assigned back before concatenation
+        current_pi_value1, current_pi_value2, next_pi_value1, next_pi_value2, \
+            random_value1, random_value2 = [
+                value.reshape(batch_size, self.num_repeat_actions, 1)
+                for value in [
+                    current_pi_value1, current_pi_value2, next_pi_value1,
+                    next_pi_value2, random_value1, random_value2
+                ]
+            ]
+
+        # cat q values
+        cat_q1 = torch.cat([random_value1, current_pi_value1, next_pi_value1], 1)
+        cat_q2 = torch.cat([random_value2, current_pi_value2, next_pi_value2], 1)
+        # shape: (batch_size, 3 * num_repeat, 1)
+
+        cql1_scaled_loss = \
+            torch.logsumexp(cat_q1 / self.temperature, dim=1).mean() * \
+            self.cql_weight * self.temperature - current_Q1.mean() * \
+            self.cql_weight
+        cql2_scaled_loss = \
+            torch.logsumexp(cat_q2 / self.temperature, dim=1).mean() * \
+            self.cql_weight * self.temperature - current_Q2.mean() * \
+            self.cql_weight
+        # shape: ()
+
+        if self.with_lagrange:
+            cql_alpha = torch.clamp(
+                self.cql_log_alpha.exp(),
+                self.alpha_min,
+                self.alpha_max,
+            )
+            cql1_scaled_loss = \
+                cql_alpha * (cql1_scaled_loss - self.lagrange_threshold)
+            cql2_scaled_loss = \
+                cql_alpha * (cql2_scaled_loss - self.lagrange_threshold)
+
+            self.cql_alpha_optim.zero_grad()
+            cql_alpha_loss = -(cql1_scaled_loss + cql2_scaled_loss) * 0.5
+            cql_alpha_loss.backward(retain_graph=True)
+            self.cql_alpha_optim.step()
+
+        critic1_loss = critic1_loss + cql1_scaled_loss
+        critic2_loss = critic2_loss + cql2_scaled_loss
+
+        # update critic
+        self.critic1_optim.zero_grad()
+        critic1_loss.backward(retain_graph=True)
+        # clip gradients to stabilize training; gradient clipping guards against
+        # exploding gradients (it may not be strictly necessary here)
+        clip_grad_norm_(self.critic1.parameters(), self.clip_grad)
+        self.critic1_optim.step()
+
+        self.critic2_optim.zero_grad()
+        critic2_loss.backward()
+        clip_grad_norm_(self.critic2.parameters(), self.clip_grad)
+        self.critic2_optim.step()
+
+        self.sync_weight()
+
+        result = {
+            "loss/actor": actor_loss.item(),
+            "loss/critic1": critic1_loss.item(),
+            "loss/critic2": critic2_loss.item(),
+        }
+        if self._is_auto_alpha:
+            result["loss/alpha"] = alpha_loss.item()
+            result["alpha"] = self._alpha.item()  # type: ignore
+        if self.with_lagrange:
+            result["loss/cql_alpha"] = cql_alpha_loss.item()
+            result["cql_alpha"] = cql_alpha.item()
+        return result
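
Note on the critic regularizer: the conservative term that `CQLPolicy.learn()` adds to each critic loss above can be read in isolation as the following sketch (the names `cql_penalty`, `q_cat`, and `q_data` are illustrative placeholders, not identifiers from the diff):

```python
import torch


def cql_penalty(
    q_cat: torch.Tensor,   # Q-values of random / current-policy / next-policy
                           # actions, shape (batch_size, 3 * num_repeat, 1)
    q_data: torch.Tensor,  # Q-values of the dataset actions, shape (batch_size,)
    temperature: float = 1.0,
    cql_weight: float = 1.0,
) -> torch.Tensor:
    """Conservative penalty: push Q down on sampled actions, up on dataset actions."""
    logsumexp_term = torch.logsumexp(q_cat / temperature, dim=1).mean()
    return cql_weight * (temperature * logsumexp_term - q_data.mean())
```

With `with_lagrange=True`, this penalty is additionally scaled by the learned multiplier `cql_alpha` after subtracting `lagrange_threshold`, as in the `if self.with_lagrange:` branch above.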