Michael Panchenko b900fdf6f2
Remove kwargs in policy init (#950)
Closes #947 

This removes all kwargs from all policy constructors. While doing that,
I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added the possibility to pass None as `critic2` and `critic2_optim`; the
resulting default behavior should cover the vast majority of cases.
2. Added a function called `clone_optimizer` as a temporary measure to
support passing `critic2_optim=None` (a rough sketch follows below this list).
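
The sketch below is purely illustrative (the actual location and exact signature of `clone_optimizer` in the codebase are not shown here): the helper only needs to build a second optimizer of the same class and hyperparameters over the second critic's parameters.

```python
from collections.abc import Iterable

import torch


def clone_optimizer(
    optim: torch.optim.Optimizer,
    new_params: Iterable[torch.nn.Parameter],
) -> torch.optim.Optimizer:
    """Create an optimizer of the same class and hyperparameters for another parameter set."""
    # `defaults` holds the hyperparameters (lr, betas, ...) the original optimizer was built with
    return type(optim)(new_params, **optim.defaults)


# hypothetical usage when only critic_optim was provided:
# critic2_optim = clone_optimizer(critic_optim, critic2.parameters())
```

Per-parameter-group overrides are not carried over, which is acceptable for the simple optimizers used in the examples.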

## Breaking changes:

1. `action_space` is no longer optional. In fact, it was already effectively
non-optional, since `BasePolicy.__init__` raised a ValueError without it.
Several examples were fixed to reflect that (see the construction sketch
after this list)
2. `reward_normalization` removed from DDPG and children. Passing it as
`True` was never allowed there; an error would have been raised in
`compute_n_step_reward`. Now it is removed from the interface
3. Renamed `critic1` and similar to `critic`, in order to have uniform
interfaces. Note that `critic` in DDPG was optional for the sole reason
that child classes used `critic1`. I removed this optionality (DDPG can't
do anything with `critic=None`)
4. Several renamings of fields (mostly private to public, so backwards
compatible)
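
As a minimal sketch of the new call style (the network construction mirrors the usual tianshou offline examples and is not part of this PR; the observation/action shapes are placeholders):

```python
import gymnasium as gym
import torch

from tianshou.policy import CQLPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic

state_shape, action_shape = (17,), (6,)  # placeholder dimensions
action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=action_shape)

# actor and first critic, built as in the existing tianshou examples
net_a = Net(state_shape, hidden_sizes=[256, 256])
actor = ActorProb(net_a, action_shape, unbounded=True, conditioned_sigma=True)
net_c = Net(state_shape, action_shape, hidden_sizes=[256, 256], concat=True)
critic = Critic(net_c)

policy = CQLPolicy(
    actor=actor,
    actor_optim=torch.optim.Adam(actor.parameters(), lr=1e-4),
    critic=critic,  # formerly `critic1`
    critic_optim=torch.optim.Adam(critic.parameters(), lr=3e-4),
    action_space=action_space,  # now always required
    critic2=None,  # defaults to a deepcopy of `critic`
    critic2_optim=None,  # defaults to a clone of `critic_optim`
)
```

Calling the constructor with positional arguments now fails with a TypeError, since all parameters after `self` are keyword-only.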

## Additional changes: 
1. Removed type and default declarations from docstrings. This kind of
duplication is really not necessary
2. Policy constructors are now only called using named arguments, not a
fragile mixture of positional and named arguments as before
3. Minor beautifications in typing and code
4. Generally shortened docstrings and made them uniform across all
policies (hopefully)

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy
become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
2023-10-08 08:57:03 -07:00


from typing import Any, Literal, Self, cast
import gymnasium as gym
import numpy as np
import torch
import torch.nn.functional as F
from overrides import override
from torch.nn.utils import clip_grad_norm_
from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.data.buffer.base import TBuffer
from tianshou.data.types import RolloutBatchProtocol
from tianshou.exploration import BaseNoise
from tianshou.policy import SACPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils.net.continuous import ActorProb
class CQLPolicy(SACPolicy):
"""Implementation of CQL algorithm. arXiv:2006.04779.
:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> a)
:param actor_optim: The optimizer for actor network.
:param critic: The first critic network.
:param critic_optim: The optimizer for the first critic network.
:param action_space: Env's action space.
:param critic2: the second critic network. (s, a -> Q(s, a)).
If None, use the same network as critic (via deepcopy).
:param critic2_optim: the optimizer for the second critic network.
If None, clone critic_optim to use for critic2.parameters().
:param cql_alpha_lr: The learning rate of cql_log_alpha.
    :param cql_weight: The weight of the CQL loss (regularization) term.
:param tau: Parameter for soft update of the target network.
:param gamma: Discount factor, in [0, 1].
:param alpha: Entropy regularization coefficient or a tuple
(target_entropy, log_alpha, alpha_optim) for automatic tuning.
    :param temperature: The temperature used in the logsumexp of the CQL loss.
    :param with_lagrange: Whether to automatically tune the CQL regularization weight
        via a Lagrange multiplier, as in CQL(Lagrange), keeping the conservative gap
        close to ``lagrange_threshold``.
:param lagrange_threshold: The value of tau in CQL(Lagrange).
:param min_action: The minimum value of each dimension of action.
:param max_action: The maximum value of each dimension of action.
:param num_repeat_actions: The number of times the action is repeated when calculating log-sum-exp.
:param alpha_min: Lower bound for clipping cql_alpha.
:param alpha_max: Upper bound for clipping cql_alpha.
    :param clip_grad: The maximum norm for gradient clipping when updating the critic networks.
    :param calibrated: calibrate Q-values as in the CalQL paper `arXiv:2303.05479`.
        Useful for offline pre-training followed by online training,
        and was also observed to achieve better results than vanilla CQL.
:param device: Which device to create this model on.
    :param estimation_step: The number of steps to look ahead.
:param exploration_noise: Type of exploration noise.
:param deterministic_eval: Flag for deterministic evaluation.
:param action_scaling: Flag for action scaling.
:param action_bound_method: Method for action bounding. Only used if the
action_space is continuous.
:param observation_space: Env's Observation space.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate
        in the optimizer in each policy.update().
.. seealso::
Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
explanation.
"""
def __init__(
self,
*,
actor: ActorProb,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer,
action_space: gym.spaces.Box,
critic2: torch.nn.Module | None = None,
critic2_optim: torch.optim.Optimizer | None = None,
cql_alpha_lr: float = 1e-4,
cql_weight: float = 1.0,
tau: float = 0.005,
gamma: float = 0.99,
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2,
temperature: float = 1.0,
with_lagrange: bool = True,
lagrange_threshold: float = 10.0,
min_action: float = -1.0,
max_action: float = 1.0,
num_repeat_actions: int = 10,
alpha_min: float = 0.0,
alpha_max: float = 1e6,
clip_grad: float = 1.0,
calibrated: bool = True,
# TODO: why does this one have device? Almost no other policies have it
device: str | torch.device = "cpu",
estimation_step: int = 1,
exploration_noise: BaseNoise | Literal["default"] | None = None,
deterministic_eval: bool = True,
action_scaling: bool = True,
action_bound_method: Literal["clip"] | None = "clip",
observation_space: gym.Space | None = None,
lr_scheduler: TLearningRateScheduler | None = None,
) -> None:
super().__init__(
actor=actor,
actor_optim=actor_optim,
critic=critic,
critic_optim=critic_optim,
action_space=action_space,
critic2=critic2,
critic2_optim=critic2_optim,
tau=tau,
gamma=gamma,
deterministic_eval=deterministic_eval,
alpha=alpha,
exploration_noise=exploration_noise,
estimation_step=estimation_step,
action_scaling=action_scaling,
action_bound_method=action_bound_method,
observation_space=observation_space,
lr_scheduler=lr_scheduler,
)
# There are _target_entropy, _log_alpha, _alpha_optim in SACPolicy.
self.device = device
self.temperature = temperature
self.with_lagrange = with_lagrange
self.lagrange_threshold = lagrange_threshold
self.cql_weight = cql_weight
self.cql_log_alpha = torch.tensor([0.0], requires_grad=True)
self.cql_alpha_optim = torch.optim.Adam([self.cql_log_alpha], lr=cql_alpha_lr)
self.cql_log_alpha = self.cql_log_alpha.to(device)
self.min_action = min_action
self.max_action = max_action
self.num_repeat_actions = num_repeat_actions
self.alpha_min = alpha_min
self.alpha_max = alpha_max
self.clip_grad = clip_grad
self.calibrated = calibrated
def train(self, mode: bool = True) -> Self:
"""Set the module in training mode, except for the target network."""
self.training = mode
self.actor.train(mode)
self.critic.train(mode)
self.critic2.train(mode)
return self
def sync_weight(self) -> None:
"""Soft-update the weight for the target network."""
self.soft_update(self.critic_old, self.critic, self.tau)
self.soft_update(self.critic2_old, self.critic2, self.tau)
def actor_pred(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
batch = Batch(obs=obs, info=None)
obs_result = self(batch)
return obs_result.act, obs_result.log_prob
def calc_actor_loss(self, obs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
act_pred, log_pi = self.actor_pred(obs)
q1 = self.critic(obs, act_pred)
q2 = self.critic2(obs, act_pred)
min_Q = torch.min(q1, q2)
# self.alpha: float | torch.Tensor
actor_loss = (self.alpha * log_pi - min_Q).mean()
# actor_loss.shape: (), log_pi.shape: (batch_size, 1)
return actor_loss, log_pi
def calc_pi_values(
self,
obs_pi: torch.Tensor,
obs_to_pred: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
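        # Q-values of actions sampled from the current policy; the detached log-probability
        # is subtracted as the importance-sampling correction used in the CQL estimator.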
act_pred, log_pi = self.actor_pred(obs_pi)
q1 = self.critic(obs_to_pred, act_pred)
q2 = self.critic2(obs_to_pred, act_pred)
return q1 - log_pi.detach(), q2 - log_pi.detach()
def calc_random_values(
self,
obs: torch.Tensor,
act: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
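        # The log-density of a uniform action in [-1, 1]^d is d * log(0.5); it is subtracted
        # from the Q-values as the importance-sampling correction (this assumes the default
        # [-1, 1] action range).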
random_value1 = self.critic(obs, act)
random_log_prob1 = np.log(0.5 ** act.shape[-1])
random_value2 = self.critic2(obs, act)
random_log_prob2 = np.log(0.5 ** act.shape[-1])
return random_value1 - random_log_prob1, random_value2 - random_log_prob2
@override
def process_buffer(self, buffer: TBuffer) -> TBuffer:
"""If `self.calibrated = True`, adds `calibration_returns` to buffer._meta.
:param buffer:
:return:
"""
if self.calibrated:
# otherwise _meta hack cannot work
assert isinstance(buffer, ReplayBuffer)
batch, indices = buffer.sample(0)
returns, _ = self.compute_episodic_return(
batch=batch,
buffer=buffer,
indices=indices,
gamma=self.gamma,
gae_lambda=1.0,
)
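            # with gae_lambda=1.0, compute_episodic_return reduces to the discounted
            # Monte-Carlo return-to-go, which Cal-QL uses as a lower bound on the Q-values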
# TODO: don't access _meta directly
buffer._meta = cast(
RolloutBatchProtocol,
Batch(**buffer._meta.__dict__, calibration_returns=returns),
)
return buffer
def process_fn(
self,
batch: RolloutBatchProtocol,
buffer: ReplayBuffer,
indices: np.ndarray,
) -> RolloutBatchProtocol:
# TODO: mypy rightly complains here b/c the design violates
# Liskov Substitution Principle
# DDPGPolicy.process_fn() results in a batch with returns but
# CQLPolicy.process_fn() doesn't add the returns.
# Should probably be fixed!
return batch
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> dict[str, float]:
batch: Batch = to_torch(batch, dtype=torch.float, device=self.device)
obs, act, rew, obs_next = batch.obs, batch.act, batch.rew, batch.obs_next
batch_size = obs.shape[0]
# compute actor loss and update actor
actor_loss, log_pi = self.calc_actor_loss(obs)
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()
# compute alpha loss
if self.is_auto_alpha:
log_pi = log_pi + self.target_entropy
alpha_loss = -(self.log_alpha * log_pi.detach()).mean()
self.alpha_optim.zero_grad()
# update log_alpha
alpha_loss.backward()
self.alpha_optim.step()
# update alpha
# TODO: it's probably a bad idea to track both alpha and log_alpha in different fields
self.alpha = self.log_alpha.detach().exp()
# compute target_Q
with torch.no_grad():
act_next, new_log_pi = self.actor_pred(obs_next)
target_Q1 = self.critic_old(obs_next, act_next)
target_Q2 = self.critic2_old(obs_next, act_next)
target_Q = torch.min(target_Q1, target_Q2) - self.alpha * new_log_pi
target_Q = rew + self.gamma * (1 - batch.done) * target_Q.flatten()
# shape: (batch_size)
# compute critic loss
current_Q1 = self.critic(obs, act).flatten()
current_Q2 = self.critic2(obs, act).flatten()
# shape: (batch_size)
critic1_loss = F.mse_loss(current_Q1, target_Q)
critic2_loss = F.mse_loss(current_Q2, target_Q)
# CQL
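        # Conservative (CQL) penalty: approximate the per-state log-sum-exp of Q over actions
        # by sampling num_repeat_actions uniform random actions plus actions from the current
        # policy at obs and obs_next, then subtract the mean Q of the dataset actions.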
random_actions = (
torch.FloatTensor(batch_size * self.num_repeat_actions, act.shape[-1])
            .uniform_(self.min_action, self.max_action)  # sample uniformly over the action range
.to(self.device)
)
obs_len = len(obs.shape)
repeat_size = [1, self.num_repeat_actions] + [1] * (obs_len - 1)
view_size = [batch_size * self.num_repeat_actions, *list(obs.shape[1:])]
tmp_obs = obs.unsqueeze(1).repeat(*repeat_size).view(*view_size)
tmp_obs_next = obs_next.unsqueeze(1).repeat(*repeat_size).view(*view_size)
# tmp_obs & tmp_obs_next: (batch_size * num_repeat, state_dim)
current_pi_value1, current_pi_value2 = self.calc_pi_values(tmp_obs, tmp_obs)
next_pi_value1, next_pi_value2 = self.calc_pi_values(tmp_obs_next, tmp_obs)
random_value1, random_value2 = self.calc_random_values(tmp_obs, random_actions)
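        # NOTE: Tensor.reshape is not in-place, so the loop below has no effect;
        # all six value tensors keep the shape (batch_size * num_repeat_actions, 1).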
for value in [
current_pi_value1,
current_pi_value2,
next_pi_value1,
next_pi_value2,
random_value1,
random_value2,
]:
value.reshape(batch_size, self.num_repeat_actions, 1)
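        # Cal-QL (arXiv:2303.05479): lower-bound the sampled Q-values by the Monte-Carlo
        # returns stored as calibration_returns in process_buffer, so the conservative
        # penalty cannot push Q-values below the behavior policy's returns.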
if self.calibrated:
returns = (
batch.calibration_returns.unsqueeze(1)
.repeat(
(1, self.num_repeat_actions),
)
.view(-1, 1)
)
random_value1 = torch.max(random_value1, returns)
random_value2 = torch.max(random_value2, returns)
current_pi_value1 = torch.max(current_pi_value1, returns)
current_pi_value2 = torch.max(current_pi_value2, returns)
next_pi_value1 = torch.max(next_pi_value1, returns)
next_pi_value2 = torch.max(next_pi_value2, returns)
# cat q values
cat_q1 = torch.cat([random_value1, current_pi_value1, next_pi_value1], 1)
cat_q2 = torch.cat([random_value2, current_pi_value2, next_pi_value2], 1)
        # shape: (batch_size * num_repeat_actions, 3)
cql1_scaled_loss = (
torch.logsumexp(cat_q1 / self.temperature, dim=1).mean()
* self.cql_weight
* self.temperature
- current_Q1.mean() * self.cql_weight
)
cql2_scaled_loss = (
torch.logsumexp(cat_q2 / self.temperature, dim=1).mean()
* self.cql_weight
* self.temperature
- current_Q2.mean() * self.cql_weight
)
# shape: (1)
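        # CQL(Lagrange): treat cql_alpha as a Lagrange multiplier; cql_log_alpha is trained
        # so that cql_alpha grows when the conservative gap exceeds lagrange_threshold and
        # shrinks when it falls below it.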
if self.with_lagrange:
cql_alpha = torch.clamp(
self.cql_log_alpha.exp(),
self.alpha_min,
self.alpha_max,
)
cql1_scaled_loss = cql_alpha * (cql1_scaled_loss - self.lagrange_threshold)
cql2_scaled_loss = cql_alpha * (cql2_scaled_loss - self.lagrange_threshold)
self.cql_alpha_optim.zero_grad()
cql_alpha_loss = -(cql1_scaled_loss + cql2_scaled_loss) * 0.5
cql_alpha_loss.backward(retain_graph=True)
self.cql_alpha_optim.step()
critic1_loss = critic1_loss + cql1_scaled_loss
critic2_loss = critic2_loss + cql2_scaled_loss
# update critic
self.critic_optim.zero_grad()
critic1_loss.backward(retain_graph=True)
        # clip gradients to stabilize training (guards against exploding gradients;
        # it doesn't seem strictly necessary)
clip_grad_norm_(self.critic.parameters(), self.clip_grad)
self.critic_optim.step()
self.critic2_optim.zero_grad()
critic2_loss.backward()
clip_grad_norm_(self.critic2.parameters(), self.clip_grad)
self.critic2_optim.step()
self.sync_weight()
result = {
"loss/actor": actor_loss.item(),
"loss/critic1": critic1_loss.item(),
"loss/critic2": critic2_loss.item(),
}
if self.is_auto_alpha:
self.alpha = cast(torch.Tensor, self.alpha)
result["loss/alpha"] = alpha_loss.item()
result["alpha"] = self.alpha.item()
if self.with_lagrange:
result["loss/cql_alpha"] = cql_alpha_loss.item()
result["cql_alpha"] = cql_alpha.item()
return result