Add NPG policy (#344)
parent c059f98abf, commit 1dcf65fe21
@@ -25,6 +25,7 @@
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Natural Policy Gradient (NPG)](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
@@ -43,6 +43,11 @@ On-policy
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.NPGPolicy
    :members:
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.A2CPolicy
    :members:
    :undoc-members:
@@ -15,6 +15,7 @@ Welcome to Tianshou!
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* :class:`~tianshou.policy.NPGPolicy` `Natural Policy Gradient <https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.TRPOPolicy` `Trust Region Policy Optimization <https://arxiv.org/pdf/1502.05477.pdf>`_
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
test/continuous/test_npg.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal

from tianshou.policy import NPGPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v0')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--buffer-size', type=int, default=50000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.95)
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--step-per-epoch', type=int, default=50000)
    parser.add_argument('--step-per-collect', type=int, default=2048)
    parser.add_argument('--repeat-per-collect', type=int,
                        default=2)  # theoretically it should be 1
    parser.add_argument('--batch-size', type=int, default=99999)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--training-num', type=int, default=16)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    # npg special
    parser.add_argument('--gae-lambda', type=float, default=0.95)
    parser.add_argument('--rew-norm', type=int, default=1)
    parser.add_argument('--norm-adv', type=int, default=1)
    parser.add_argument('--optim-critic-iters', type=int, default=5)
    parser.add_argument('--actor-step-size', type=float, default=0.5)
    args = parser.parse_known_args()[0]
    return args


def test_npg(args=get_args()):
    env = gym.make(args.task)
    if args.task == 'Pendulum-v0':
        env.spec.reward_threshold = -250
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    # you can also use tianshou.env.SubprocVectorEnv
    # train_envs = gym.make(args.task)
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
              activation=nn.Tanh, device=args.device)
    actor = ActorProb(net, args.action_shape, max_action=args.max_action,
                      unbounded=True, device=args.device).to(args.device)
    critic = Critic(Net(
        args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device,
        activation=nn.Tanh), device=args.device).to(args.device)
    # orthogonal initialization
    for m in list(actor.modules()) + list(critic.modules()):
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            torch.nn.init.zeros_(m.bias)
    optim = torch.optim.Adam(set(
        actor.parameters()).union(critic.parameters()), lr=args.lr)

    # replace DiagGaussian with Independent(Normal) which is equivalent
    # pass *logits to be consistent with policy.forward
    def dist(*logits):
        return Independent(Normal(*logits), 1)

    policy = NPGPolicy(
        actor, critic, optim, dist,
        discount_factor=args.gamma,
        reward_normalization=args.rew_norm,
        advantage_normalization=args.norm_adv,
        gae_lambda=args.gae_lambda,
        action_space=env.action_space,
        optim_critic_iters=args.optim_critic_iters,
        actor_step_size=args.actor_step_size)
    # collector
    train_collector = Collector(
        policy, train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)))
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'npg')
    writer = SummaryWriter(log_path)
    logger = BasicLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    # trainer
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
        step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn,
        logger=logger)
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == '__main__':
    test_npg()
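A note on the dist(*logits) helper in the test above (illustration only, not part of the commit): ActorProb produces a (mu, sigma) pair, which dist unpacks into Normal and wraps in Independent(..., 1) so that the action dimensions form a single event and log_prob returns one value per transition, which is what NPGPolicy.process_fn concatenates into logp_old. A minimal standalone check:

import torch
from torch.distributions import Independent, Normal

# Standalone illustration (not part of the commit): a batch of 4 transitions
# with a 1-dimensional action, matching Pendulum-v0.
mu = torch.zeros(4, 1)
sigma = torch.ones(4, 1)
dist = Independent(Normal(mu, sigma), 1)   # what dist(*logits) builds
act = dist.sample()
print(act.shape)                 # torch.Size([4, 1])
print(dist.log_prob(act).shape)  # torch.Size([4]) -- one log-prob per transition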
test/continuous/test_trpo.py
@@ -27,8 +27,7 @@ def get_args():
     parser.add_argument('--epoch', type=int, default=5)
     parser.add_argument('--step-per-epoch', type=int, default=50000)
     parser.add_argument('--step-per-collect', type=int, default=2048)
-    parser.add_argument('--repeat-per-collect', type=int,
-                        default=2)  # theoretically it should be 1
+    parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=99999)
     parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
     parser.add_argument('--training-num', type=int, default=16)
@@ -43,7 +42,7 @@ def get_args():
     parser.add_argument('--rew-norm', type=int, default=1)
     parser.add_argument('--norm-adv', type=int, default=1)
     parser.add_argument('--optim-critic-iters', type=int, default=5)
-    parser.add_argument('--max-kl', type=float, default=0.01)
+    parser.add_argument('--max-kl', type=float, default=0.005)
     parser.add_argument('--backtrack-coeff', type=float, default=0.8)
     parser.add_argument('--max-backtracks', type=int, default=10)
@@ -5,6 +5,7 @@ from tianshou.policy.modelfree.c51 import C51Policy
from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
from tianshou.policy.modelfree.pg import PGPolicy
from tianshou.policy.modelfree.a2c import A2CPolicy
from tianshou.policy.modelfree.npg import NPGPolicy
from tianshou.policy.modelfree.ddpg import DDPGPolicy
from tianshou.policy.modelfree.ppo import PPOPolicy
from tianshou.policy.modelfree.trpo import TRPOPolicy
@@ -25,6 +26,7 @@ __all__ = [
    "QRDQNPolicy",
    "PGPolicy",
    "A2CPolicy",
    "NPGPolicy",
    "DDPGPolicy",
    "PPOPolicy",
    "TRPOPolicy",
tianshou/policy/modelfree/npg.py (new file, 182 lines)
@@ -0,0 +1,182 @@
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, List, Type
from torch.distributions import kl_divergence

from tianshou.policy import A2CPolicy
from tianshou.data import Batch, ReplayBuffer


class NPGPolicy(A2CPolicy):
    """Implementation of Natural Policy Gradient.

    https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.nn.Module critic: the critic network. (s -> V(s))
    :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
    :param dist_fn: distribution class for computing the action.
    :type dist_fn: Type[torch.distributions.Distribution]
    :param bool advantage_normalization: whether to do per mini-batch advantage
        normalization. Default to True.
    :param int optim_critic_iters: Number of times to optimize critic network per
        update. Default to 5.
    :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
        Default to 0.95.
    :param bool reward_normalization: normalize estimated values to have std close to
        1. Default to False.
    :param int max_batchsize: the maximum size of the batch when computing GAE,
        depends on the size of available memory and the memory cost of the
        model; should be as large as possible within the memory constraint.
        Default to 256.
    :param bool action_scaling: whether to map actions from range [-1, 1] to range
        [action_spaces.low, action_spaces.high]. Default to True.
    :param str action_bound_method: method to bound action to range [-1, 1], can be
        either "clip" (for simply clipping the action), "tanh" (for applying tanh
        squashing) for now, or empty string for no bounding. Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
        to use option "action_scaling" or "action_bound_method". Default to None.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
        optimizer in each policy.update(). Default to None (no lr_scheduler).
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
        dist_fn: Type[torch.distributions.Distribution],
        advantage_normalization: bool = True,
        optim_critic_iters: int = 5,
        actor_step_size: float = 0.5,
        **kwargs: Any,
    ) -> None:
        super().__init__(actor, critic, optim, dist_fn, **kwargs)
        del self._weight_vf, self._weight_ent, self._grad_norm
        self._norm_adv = advantage_normalization
        self._optim_critic_iters = optim_critic_iters
        self._step_size = actor_step_size
        # adjusts Hessian-vector product calculation for numerical stability
        self._damping = 0.1

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> Batch:
        batch = super().process_fn(batch, buffer, indice)
        old_log_prob = []
        with torch.no_grad():
            for b in batch.split(self._batch, shuffle=False, merge_last=True):
                old_log_prob.append(self(b).dist.log_prob(b.act))
        batch.logp_old = torch.cat(old_log_prob, dim=0)
        if self._norm_adv:
            batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
        return batch

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        actor_losses, vf_losses, kls = [], [], []
        for step in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                # optimize actor
                # direction: calculate vanilla gradient
                dist = self(b).dist  # TODO could come from batch
                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                actor_loss = -(ratio * b.adv).mean()
                flat_grads = self._get_flat_grad(
                    actor_loss, self.actor, retain_graph=True).detach()

                # direction: calculate natural gradient
                with torch.no_grad():
                    old_dist = self(b).dist

                kl = kl_divergence(old_dist, dist).mean()
                # calculate first order gradient of kl with respect to theta
                flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
                search_direction = -self._conjugate_gradients(
                    flat_grads, flat_kl_grad, nsteps=10)

                # step
                with torch.no_grad():
                    flat_params = torch.cat([param.data.view(-1)
                                             for param in self.actor.parameters()])
                    new_flat_params = flat_params + self._step_size * search_direction
                    self._set_from_flat_params(self.actor, new_flat_params)
                    new_dist = self(b).dist
                    kl = kl_divergence(old_dist, new_dist).mean()

                # optimize critic
                for _ in range(self._optim_critic_iters):
                    value = self.critic(b.obs).flatten()
                    vf_loss = F.mse_loss(b.returns, value)
                    self.optim.zero_grad()
                    vf_loss.backward()
                    self.optim.step()

                actor_losses.append(actor_loss.item())
                vf_losses.append(vf_loss.item())
                kls.append(kl.item())

        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss/actor": actor_losses,
            "loss/vf": vf_losses,
            "kl": kls,
        }

    def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor:
        """Matrix vector product."""
        # calculate second order gradient of kl with respect to theta
        kl_v = (flat_kl_grad * v).sum()
        flat_kl_grad_grad = self._get_flat_grad(
            kl_v, self.actor, retain_graph=True).detach()
        return flat_kl_grad_grad + v * self._damping

    def _conjugate_gradients(
        self,
        b: torch.Tensor,
        flat_kl_grad: torch.Tensor,
        nsteps: int = 10,
        residual_tol: float = 1e-10
    ) -> torch.Tensor:
        x = torch.zeros_like(b)
        r, p = b.clone(), b.clone()
        # Note: should be 'r, p = b - MVP(x)', but for x=0, MVP(x)=0.
        # Change if doing warm start.
        rdotr = r.dot(r)
        for i in range(nsteps):
            z = self._MVP(p, flat_kl_grad)
            alpha = rdotr / p.dot(z)
            x += alpha * p
            r -= alpha * z
            new_rdotr = r.dot(r)
            if new_rdotr < residual_tol:
                break
            p = r + new_rdotr / rdotr * p
            rdotr = new_rdotr
        return x

    def _get_flat_grad(
        self, y: torch.Tensor, model: nn.Module, **kwargs: Any
    ) -> torch.Tensor:
        grads = torch.autograd.grad(y, model.parameters(), **kwargs)  # type: ignore
        return torch.cat([grad.reshape(-1) for grad in grads])

    def _set_from_flat_params(
        self, model: nn.Module, flat_params: torch.Tensor
    ) -> nn.Module:
        prev_ind = 0
        for param in model.parameters():
            flat_size = int(np.prod(list(param.size())))
            param.data.copy_(
                flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
            prev_ind += flat_size
        return model
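For reference, the update that NPGPolicy.learn implements can be summarized as follows (a sketch reconstructed from the code above, not quoted from the paper): let g be the gradient of the actor loss and

    \tilde{F} = \nabla_\theta^2 \, D_{\mathrm{KL}}(\pi_{\mathrm{old}} \,\|\, \pi_\theta) + 0.1\, I

the damped KL Hessian provided by _MVP. The actor then takes a fixed step against the natural gradient, \theta \leftarrow \theta - \alpha\, \tilde{F}^{-1} g, with \alpha equal to actor_step_size (0.5 in the test), which is an ascent step on the surrogate objective. \tilde{F}^{-1} g is never formed explicitly; _conjugate_gradients solves \tilde{F} d = g iteratively using only Hessian-vector products. The same recursion can be checked in isolation against a dense solve; the snippet below is a standalone illustration, not part of the commit, and assumes a PyTorch recent enough to provide torch.linalg.solve:

import torch

# Standalone check (not part of the commit): the conjugate-gradient recursion
# used in NPGPolicy._conjugate_gradients, with an explicit SPD matrix standing
# in for the damped KL Hessian.
def conjugate_gradients(mvp, b, nsteps=10, residual_tol=1e-10):
    x = torch.zeros_like(b)
    r, p = b.clone(), b.clone()
    rdotr = r.dot(r)
    for _ in range(nsteps):
        z = mvp(p)
        alpha = rdotr / p.dot(z)
        x += alpha * p
        r -= alpha * z
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + new_rdotr / rdotr * p
        rdotr = new_rdotr
    return x

F_mat = torch.tensor([[3.0, 1.0], [1.0, 2.0]])  # small SPD stand-in for the Fisher matrix
g = torch.tensor([1.0, 1.0])
x_cg = conjugate_gradients(lambda v: F_mat @ v, g)
x_direct = torch.linalg.solve(F_mat, g)
print(x_cg, x_direct)  # both approximately [0.2, 0.4]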
tianshou/policy/modelfree/trpo.py
@@ -1,56 +1,15 @@
 import torch
 import warnings
-import numpy as np
-from torch import nn
 import torch.nn.functional as F
+from typing import Any, Dict, List, Type
 from torch.distributions import kl_divergence
-from typing import Any, Dict, List, Type, Callable
 
-from tianshou.policy import A2CPolicy
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch
+from tianshou.policy import NPGPolicy
 
 
-def _conjugate_gradients(
-    Avp: Callable[[torch.Tensor], torch.Tensor],
-    b: torch.Tensor,
-    nsteps: int = 10,
-    residual_tol: float = 1e-10
-) -> torch.Tensor:
-    x = torch.zeros_like(b)
-    r, p = b.clone(), b.clone()
-    # Note: should be 'r, p = b - A(x)', but for x=0, A(x)=0.
-    # Change if doing warm start.
-    rdotr = r.dot(r)
-    for i in range(nsteps):
-        z = Avp(p)
-        alpha = rdotr / p.dot(z)
-        x += alpha * p
-        r -= alpha * z
-        new_rdotr = r.dot(r)
-        if new_rdotr < residual_tol:
-            break
-        p = r + new_rdotr / rdotr * p
-        rdotr = new_rdotr
-    return x
-
-
-def _get_flat_grad(y: torch.Tensor, model: nn.Module, **kwargs: Any) -> torch.Tensor:
-    grads = torch.autograd.grad(y, model.parameters(), **kwargs)  # type: ignore
-    return torch.cat([grad.reshape(-1) for grad in grads])
-
-
-def _set_from_flat_params(model: nn.Module, flat_params: torch.Tensor) -> nn.Module:
-    prev_ind = 0
-    for param in model.parameters():
-        flat_size = int(np.prod(list(param.size())))
-        param.data.copy_(
-            flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
-        prev_ind += flat_size
-    return model
-
-
-class TRPOPolicy(A2CPolicy):
+class TRPOPolicy(NPGPolicy):
     """Implementation of Trust Region Policy Optimization. arXiv:1502.05477.
 
     :param torch.nn.Module actor: the actor network following the rules in
@@ -94,35 +53,16 @@ class TRPOPolicy(A2CPolicy):
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
         dist_fn: Type[torch.distributions.Distribution],
-        advantage_normalization: bool = True,
-        optim_critic_iters: int = 5,
         max_kl: float = 0.01,
         backtrack_coeff: float = 0.8,
         max_backtracks: int = 10,
         **kwargs: Any,
     ) -> None:
         super().__init__(actor, critic, optim, dist_fn, **kwargs)
-        del self._weight_vf, self._weight_ent, self._grad_norm
-        self._norm_adv = advantage_normalization
-        self._optim_critic_iters = optim_critic_iters
+        del self._step_size
         self._max_backtracks = max_backtracks
         self._delta = max_kl
         self._backtrack_coeff = backtrack_coeff
-        # adjusts Hessian-vector product calculation for numerical stability
-        self.__damping = 0.1
 
-    def process_fn(
-        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
-    ) -> Batch:
-        batch = super().process_fn(batch, buffer, indice)
-        old_log_prob = []
-        with torch.no_grad():
-            for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                old_log_prob.append(self(b).dist.log_prob(b.act))
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
-        if self._norm_adv:
-            batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
-        return batch
-
     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
@@ -136,7 +76,7 @@ class TRPOPolicy(A2CPolicy):
                 ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                 actor_loss = -(ratio * b.adv).mean()
-                flat_grads = _get_flat_grad(
+                flat_grads = self._get_flat_grad(
                     actor_loss, self.actor, retain_graph=True).detach()
 
                 # direction: calculate natural gradient
@@ -145,20 +85,14 @@ class TRPOPolicy(A2CPolicy):
 
                 kl = kl_divergence(old_dist, dist).mean()
                 # calculate first order gradient of kl with respect to theta
-                flat_kl_grad = _get_flat_grad(kl, self.actor, create_graph=True)
-
-                def MVP(v: torch.Tensor) -> torch.Tensor:  # matrix vector product
-                    # calculate second order gradient of kl with respect to theta
-                    kl_v = (flat_kl_grad * v).sum()
-                    flat_kl_grad_grad = _get_flat_grad(
-                        kl_v, self.actor, retain_graph=True).detach()
-                    return flat_kl_grad_grad + v * self.__damping
-
-                search_direction = -_conjugate_gradients(MVP, flat_grads, nsteps=10)
+                flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
+                search_direction = -self._conjugate_gradients(
+                    flat_grads, flat_kl_grad, nsteps=10)
 
                 # stepsize: calculate max stepsize constrained by kl bound
                 step_size = torch.sqrt(2 * self._delta / (
-                    search_direction * MVP(search_direction)).sum(0, keepdim=True))
+                    search_direction * self._MVP(search_direction, flat_kl_grad)
+                ).sum(0, keepdim=True))
 
                 # stepsize: linesearch stepsize
                 with torch.no_grad():
@@ -166,7 +100,7 @@ class TRPOPolicy(A2CPolicy):
                                              for param in self.actor.parameters()])
                     for i in range(self._max_backtracks):
                         new_flat_params = flat_params + step_size * search_direction
-                        _set_from_flat_params(self.actor, new_flat_params)
+                        self._set_from_flat_params(self.actor, new_flat_params)
                         # calculate kl and check whether it is in bound and loss decreased
                         new_dist = self(b).dist
                         new_dratio = (
@@ -183,7 +117,7 @@ class TRPOPolicy(A2CPolicy):
                         elif i < self._max_backtracks - 1:
                             step_size = step_size * self._backtrack_coeff
                         else:
-                            _set_from_flat_params(self.actor, new_flat_params)
+                            self._set_from_flat_params(self.actor, new_flat_params)
                             step_size = torch.tensor([0.0])
                             warnings.warn("Line search failed! It seems hyperparameters"
                                           " are poor and need to be changed.")
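Summarizing the refactor above: NPGPolicy now owns the shared machinery (_get_flat_grad, _conjugate_gradients, _MVP, _set_from_flat_params, process_fn), and TRPOPolicy keeps only what is specific to the trust region, namely the choice of step size. As a sketch reconstructed from the code, not quoted from the papers: with search direction d = \tilde{F}^{-1} g, TRPO starts from the largest step permitted by the KL bound,

    \alpha_{\max} = \sqrt{\frac{2\delta}{d^{\top} \tilde{F} d}},

where \delta is max-kl (0.005 in the updated test) and \tilde{F} is the damped KL Hessian from _MVP. It then multiplies \alpha by backtrack-coeff (0.8) until the new policy satisfies D_{\mathrm{KL}}(\pi_{\mathrm{old}} \,\|\, \pi_{\mathrm{new}}) \le \delta and the surrogate loss improves, giving up after max-backtracks (10) attempts with the "Line search failed" warning. NPGPolicy, by contrast, takes a fixed step of actor_step_size along d with no line search.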