157 lines
7.2 KiB
Python
Raw Normal View History

2021-04-16 20:37:12 +08:00
import warnings
2021-04-21 09:52:15 +08:00
from typing import Any, Dict, List, Type
2021-04-16 20:37:12 +08:00
import torch
import torch.nn.functional as F
from torch.distributions import kl_divergence
2021-04-16 20:37:12 +08:00
2021-04-21 09:52:15 +08:00
from tianshou.data import Batch
from tianshou.policy import NPGPolicy
2021-04-16 20:37:12 +08:00
2021-04-21 09:52:15 +08:00
class TRPOPolicy(NPGPolicy):
    """Implementation of Trust Region Policy Optimization. arXiv:1502.05477.

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.nn.Module critic: the critic network. (s -> V(s))
    :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
    :param dist_fn: distribution class for computing the action.
    :type dist_fn: Type[torch.distributions.Distribution]
    :param bool advantage_normalization: whether to do per mini-batch advantage
        normalization. Default to True.
    :param int optim_critic_iters: Number of times to optimize critic network per
        update. Default to 5.
    :param float max_kl: max kl-divergence used to constrain each actor network
        update. Default to 0.01.
    :param float backtrack_coeff: Coefficient to be multiplied by step size when
        constraints are not met. Default to 0.8.
    :param int max_backtracks: Max number of backtracking times in linesearch.
        Default to 10.
    :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
        Default to 0.95.
    :param bool reward_normalization: normalize estimated values to have std close
        to 1. Default to False.
    :param int max_batchsize: the maximum size of the batch when computing GAE,
        depends on the size of available memory and the memory cost of the
        model; should be as large as possible within the memory constraint.
        Default to 256.
    :param bool action_scaling: whether to map actions from range [-1, 1] to range
        [action_spaces.low, action_spaces.high]. Default to True.
    :param str action_bound_method: method to bound action to range [-1, 1], can be
        either "clip" (for simply clipping the action), "tanh" (for applying tanh
        squashing) for now, or empty string for no bounding. Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you
        want to use option "action_scaling" or "action_bound_method".
        Default to None.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
        optimizer in each policy.update(). Default to None (no lr_scheduler).
    :param bool deterministic_eval: whether to use deterministic action instead of
        stochastic action sampled by the policy. Default to False.

    If the backtracking line search fails for a minibatch, the actor parameters
    are restored to their pre-update values, and that minibatch's update is
    reported with ``step_size`` 0 and ``kl`` 0.
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
        dist_fn: Type[torch.distributions.Distribution],
        max_kl: float = 0.01,
        backtrack_coeff: float = 0.8,
        max_backtracks: int = 10,
        **kwargs: Any,
    ) -> None:
        # Remaining options (critic iterations, GAE settings, ...) are handled
        # by the parent class via **kwargs.
        super().__init__(actor, critic, optim, dist_fn, **kwargs)
        self._max_backtracks = max_backtracks
        self._delta = max_kl
        self._backtrack_coeff = backtrack_coeff
        # Annotation only: this attribute is read in learn() but assigned by the
        # parent class, not here.
        self._optim_critic_iters: int

    @staticmethod
    def _surrogate_actor_loss(
        dist: torch.distributions.Distribution, minibatch: Batch
    ) -> torch.Tensor:
        """Return the negated mean of importance ratio times advantage.

        The ratio is ``exp(log_prob(act) - logp_old)`` under ``dist``, reshaped
        to ``(-1, batch)`` so that it broadcasts against ``minibatch.adv``.
        """
        ratio = (dist.log_prob(minibatch.act) - minibatch.logp_old).exp().float()
        ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
        return -(ratio * minibatch.adv).mean()

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        """Run ``repeat`` epochs of TRPO actor updates plus critic regression.

        :param batch: training data; this method reads ``act``, ``logp_old``,
            ``adv``, ``obs`` and ``returns`` from each minibatch.
        :param batch_size: minibatch size passed to ``batch.split``.
        :param repeat: number of passes over ``batch``.
        :return: dict with lists (one entry per minibatch update) under the keys
            ``"loss/actor"``, ``"loss/vf"``, ``"step_size"`` and ``"kl"``.
        """
        actor_losses, vf_losses, step_sizes, kls = [], [], [], []
        for _ in range(repeat):
            for minibatch in batch.split(batch_size, merge_last=True):
                # optimize actor
                # direction: calculate vanilla gradient
                dist = self(minibatch).dist  # TODO could come from batch
                actor_loss = self._surrogate_actor_loss(dist, minibatch)
                flat_grads = self._get_flat_grad(
                    actor_loss, self.actor, retain_graph=True
                ).detach()

                # direction: calculate natural gradient
                with torch.no_grad():
                    old_dist = self(minibatch).dist
                # KL between a detached copy of the current policy and the live
                # one; its gradient graph (create_graph=True) is presumably
                # differentiated again inside _MVP for curvature-vector products.
                kl = kl_divergence(old_dist, dist).mean()
                # calculate first order gradient of kl with respect to theta
                flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
                search_direction = -self._conjugate_gradients(
                    flat_grads, flat_kl_grad, nsteps=10
                )

                # stepsize: largest step along search_direction whose quadratic
                # KL estimate sqrt(2 * delta / (s^T M s)) meets the kl bound
                step_size = torch.sqrt(
                    2 * self._delta /
                    (search_direction *
                     self._MVP(search_direction, flat_kl_grad)).sum(0, keepdim=True)
                )

                # stepsize: linesearch stepsize
                with torch.no_grad():
                    # torch.cat copies, so this snapshot is not aliased with
                    # the parameters that are overwritten below.
                    flat_params = torch.cat(
                        [param.data.view(-1) for param in self.actor.parameters()]
                    )
                    for i in range(self._max_backtracks):
                        new_flat_params = flat_params + step_size * search_direction
                        self._set_from_flat_params(self.actor, new_flat_params)
                        # calculate kl and if in bound, loss actually down
                        new_dist = self(minibatch).dist
                        new_actor_loss = self._surrogate_actor_loss(
                            new_dist, minibatch
                        )
                        kl = kl_divergence(old_dist, new_dist).mean()
                        if kl < self._delta and new_actor_loss < actor_loss:
                            if i > 0:
                                warnings.warn(f"Backtracking to step {i}.")
                            break
                        elif i < self._max_backtracks - 1:
                            step_size = step_size * self._backtrack_coeff
                        else:
                            # Every candidate was rejected: restore the
                            # pre-update parameters (re-applying new_flat_params
                            # would keep the rejected step). The logged stats
                            # then describe the null update actually applied.
                            self._set_from_flat_params(self.actor, flat_params)
                            step_size = torch.tensor([0.0])
                            kl = torch.zeros_like(kl)
                            warnings.warn(
                                "Line search failed! It seems hyperparamters"
                                " are poor and need to be changed."
                            )

                # optimize critic
                for _ in range(self._optim_critic_iters):
                    value = self.critic(minibatch.obs).flatten()
                    vf_loss = F.mse_loss(minibatch.returns, value)
                    self.optim.zero_grad()
                    vf_loss.backward()
                    self.optim.step()

                actor_losses.append(actor_loss.item())
                vf_losses.append(vf_loss.item())
                step_sizes.append(step_size.item())
                kls.append(kl.item())

        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss/actor": actor_losses,
            "loss/vf": vf_losses,
            "step_size": step_sizes,
            "kl": kls,
        }