Add NPG policy (#344)
This commit is contained in: parent c059f98abf, commit 1dcf65fe21
@@ -25,6 +25,7 @@
 - [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
 - [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
 - [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
+- [Natural Policy Gradient (NPG)](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf)
 - [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
 - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf)
 - [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
@@ -43,6 +43,11 @@ On-policy
     :undoc-members:
     :show-inheritance:
 
+.. autoclass:: tianshou.policy.NPGPolicy
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 .. autoclass:: tianshou.policy.A2CPolicy
     :members:
     :undoc-members:
@@ -15,6 +15,7 @@ Welcome to Tianshou!
 * :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
 * :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
 * :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
+* :class:`~tianshou.policy.NPGPolicy` `Natural Policy Gradient <https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf>`_
 * :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
 * :class:`~tianshou.policy.TRPOPolicy` `Trust Region Policy Optimization <https://arxiv.org/pdf/1502.05477.pdf>`_
 * :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
test/continuous/test_npg.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal

from tianshou.policy import NPGPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v0')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--buffer-size', type=int, default=50000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.95)
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--step-per-epoch', type=int, default=50000)
    parser.add_argument('--step-per-collect', type=int, default=2048)
    parser.add_argument('--repeat-per-collect', type=int,
                        default=2)  # theoretically it should be 1
    parser.add_argument('--batch-size', type=int, default=99999)
    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
    parser.add_argument('--training-num', type=int, default=16)
    parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    # npg special
    parser.add_argument('--gae-lambda', type=float, default=0.95)
    parser.add_argument('--rew-norm', type=int, default=1)
    parser.add_argument('--norm-adv', type=int, default=1)
    parser.add_argument('--optim-critic-iters', type=int, default=5)
    parser.add_argument('--actor-step-size', type=float, default=0.5)
    args = parser.parse_known_args()[0]
    return args


def test_npg(args=get_args()):
    env = gym.make(args.task)
    if args.task == 'Pendulum-v0':
        env.spec.reward_threshold = -250
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    # you can also use tianshou.env.SubprocVectorEnv
    # train_envs = gym.make(args.task)
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
              activation=nn.Tanh, device=args.device)
    actor = ActorProb(net, args.action_shape, max_action=args.max_action,
                      unbounded=True, device=args.device).to(args.device)
    critic = Critic(Net(
        args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device,
        activation=nn.Tanh), device=args.device).to(args.device)
    # orthogonal initialization
    for m in list(actor.modules()) + list(critic.modules()):
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            torch.nn.init.zeros_(m.bias)
    optim = torch.optim.Adam(set(
        actor.parameters()).union(critic.parameters()), lr=args.lr)

    # replace DiagGaussian with Independent(Normal) which is equivalent
    # pass *logits to be consistent with policy.forward
    def dist(*logits):
        return Independent(Normal(*logits), 1)

    policy = NPGPolicy(
        actor, critic, optim, dist,
        discount_factor=args.gamma,
        reward_normalization=args.rew_norm,
        advantage_normalization=args.norm_adv,
        gae_lambda=args.gae_lambda,
        action_space=env.action_space,
        optim_critic_iters=args.optim_critic_iters,
        actor_step_size=args.actor_step_size)
    # collector
    train_collector = Collector(
        policy, train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)))
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'npg')
    writer = SummaryWriter(log_path)
    logger = BasicLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    # trainer
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
        step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn,
        logger=logger)
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == '__main__':
    test_npg()
@@ -27,8 +27,7 @@ def get_args():
     parser.add_argument('--epoch', type=int, default=5)
     parser.add_argument('--step-per-epoch', type=int, default=50000)
     parser.add_argument('--step-per-collect', type=int, default=2048)
-    parser.add_argument('--repeat-per-collect', type=int,
-                        default=2)  # theoretically it should be 1
+    parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=99999)
     parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
     parser.add_argument('--training-num', type=int, default=16)
@@ -43,7 +42,7 @@ def get_args():
     parser.add_argument('--rew-norm', type=int, default=1)
     parser.add_argument('--norm-adv', type=int, default=1)
     parser.add_argument('--optim-critic-iters', type=int, default=5)
-    parser.add_argument('--max-kl', type=float, default=0.01)
+    parser.add_argument('--max-kl', type=float, default=0.005)
     parser.add_argument('--backtrack-coeff', type=float, default=0.8)
     parser.add_argument('--max-backtracks', type=int, default=10)
 
@@ -5,6 +5,7 @@ from tianshou.policy.modelfree.c51 import C51Policy
 from tianshou.policy.modelfree.qrdqn import QRDQNPolicy
 from tianshou.policy.modelfree.pg import PGPolicy
 from tianshou.policy.modelfree.a2c import A2CPolicy
+from tianshou.policy.modelfree.npg import NPGPolicy
 from tianshou.policy.modelfree.ddpg import DDPGPolicy
 from tianshou.policy.modelfree.ppo import PPOPolicy
 from tianshou.policy.modelfree.trpo import TRPOPolicy
@@ -25,6 +26,7 @@ __all__ = [
     "QRDQNPolicy",
     "PGPolicy",
     "A2CPolicy",
+    "NPGPolicy",
     "DDPGPolicy",
     "PPOPolicy",
     "TRPOPolicy",
tianshou/policy/modelfree/npg.py (new file, 182 lines)
@@ -0,0 +1,182 @@
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from typing import Any, Dict, List, Type
from torch.distributions import kl_divergence

from tianshou.policy import A2CPolicy
from tianshou.data import Batch, ReplayBuffer


class NPGPolicy(A2CPolicy):
    """Implementation of Natural Policy Gradient.

    https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf

    :param torch.nn.Module actor: the actor network following the rules in
        :class:`~tianshou.policy.BasePolicy`. (s -> logits)
    :param torch.nn.Module critic: the critic network. (s -> V(s))
    :param torch.optim.Optimizer optim: the optimizer for actor and critic network.
    :param dist_fn: distribution class for computing the action.
    :type dist_fn: Type[torch.distributions.Distribution]
    :param bool advantage_normalization: whether to do per mini-batch advantage
        normalization. Default to True.
    :param int optim_critic_iters: number of times to optimize critic network per
        update. Default to 5.
    :param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
        Default to 0.95.
    :param bool reward_normalization: normalize estimated values to have std close to
        1. Default to False.
    :param int max_batchsize: the maximum size of the batch when computing GAE,
        depends on the size of available memory and the memory cost of the
        model; should be as large as possible within the memory constraint.
        Default to 256.
    :param bool action_scaling: whether to map actions from range [-1, 1] to range
        [action_spaces.low, action_spaces.high]. Default to True.
    :param str action_bound_method: method to bound action to range [-1, 1], can be
        either "clip" (for simply clipping the action), "tanh" (for applying tanh
        squashing) for now, or empty string for no bounding. Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
        to use option "action_scaling" or "action_bound_method". Default to None.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
        optimizer in each policy.update(). Default to None (no lr_scheduler).
    """

    def __init__(
        self,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
        dist_fn: Type[torch.distributions.Distribution],
        advantage_normalization: bool = True,
        optim_critic_iters: int = 5,
        actor_step_size: float = 0.5,
        **kwargs: Any,
    ) -> None:
        super().__init__(actor, critic, optim, dist_fn, **kwargs)
        del self._weight_vf, self._weight_ent, self._grad_norm
        self._norm_adv = advantage_normalization
        self._optim_critic_iters = optim_critic_iters
        self._step_size = actor_step_size
        # adjusts Hessian-vector product calculation for numerical stability
        self._damping = 0.1

    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> Batch:
        batch = super().process_fn(batch, buffer, indice)
        old_log_prob = []
        with torch.no_grad():
            for b in batch.split(self._batch, shuffle=False, merge_last=True):
                old_log_prob.append(self(b).dist.log_prob(b.act))
        batch.logp_old = torch.cat(old_log_prob, dim=0)
        if self._norm_adv:
            batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
        return batch

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
    ) -> Dict[str, List[float]]:
        actor_losses, vf_losses, kls = [], [], []
        for step in range(repeat):
            for b in batch.split(batch_size, merge_last=True):
                # optimize actor
                # direction: calculate vanilla gradient
                dist = self(b).dist  # TODO could come from batch
                ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                actor_loss = -(ratio * b.adv).mean()
                flat_grads = self._get_flat_grad(
                    actor_loss, self.actor, retain_graph=True).detach()

                # direction: calculate natural gradient
                with torch.no_grad():
                    old_dist = self(b).dist

                kl = kl_divergence(old_dist, dist).mean()
                # calculate first order gradient of kl with respect to theta
                flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
                search_direction = -self._conjugate_gradients(
                    flat_grads, flat_kl_grad, nsteps=10)

                # step
                with torch.no_grad():
                    flat_params = torch.cat([param.data.view(-1)
                                             for param in self.actor.parameters()])
                    new_flat_params = flat_params + self._step_size * search_direction
                    self._set_from_flat_params(self.actor, new_flat_params)
                    new_dist = self(b).dist
                    kl = kl_divergence(old_dist, new_dist).mean()

                # optimize critic
                for _ in range(self._optim_critic_iters):
                    value = self.critic(b.obs).flatten()
                    vf_loss = F.mse_loss(b.returns, value)
                    self.optim.zero_grad()
                    vf_loss.backward()
                    self.optim.step()

                actor_losses.append(actor_loss.item())
                vf_losses.append(vf_loss.item())
                kls.append(kl.item())

        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss/actor": actor_losses,
            "loss/vf": vf_losses,
            "kl": kls,
        }

    def _MVP(self, v: torch.Tensor, flat_kl_grad: torch.Tensor) -> torch.Tensor:
        """Matrix vector product."""
        # calculate second order gradient of kl with respect to theta
        kl_v = (flat_kl_grad * v).sum()
        flat_kl_grad_grad = self._get_flat_grad(
            kl_v, self.actor, retain_graph=True).detach()
        return flat_kl_grad_grad + v * self._damping

    def _conjugate_gradients(
        self,
        b: torch.Tensor,
        flat_kl_grad: torch.Tensor,
        nsteps: int = 10,
        residual_tol: float = 1e-10
    ) -> torch.Tensor:
        x = torch.zeros_like(b)
        r, p = b.clone(), b.clone()
        # Note: should be 'r, p = b - MVP(x)', but for x=0, MVP(x)=0.
        # Change if doing warm start.
        rdotr = r.dot(r)
        for i in range(nsteps):
            z = self._MVP(p, flat_kl_grad)
            alpha = rdotr / p.dot(z)
            x += alpha * p
            r -= alpha * z
            new_rdotr = r.dot(r)
            if new_rdotr < residual_tol:
                break
            p = r + new_rdotr / rdotr * p
            rdotr = new_rdotr
        return x

    def _get_flat_grad(
        self, y: torch.Tensor, model: nn.Module, **kwargs: Any
    ) -> torch.Tensor:
        grads = torch.autograd.grad(y, model.parameters(), **kwargs)  # type: ignore
        return torch.cat([grad.reshape(-1) for grad in grads])

    def _set_from_flat_params(
        self, model: nn.Module, flat_params: torch.Tensor
    ) -> nn.Module:
        prev_ind = 0
        for param in model.parameters():
            flat_size = int(np.prod(list(param.size())))
            param.data.copy_(
                flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
            prev_ind += flat_size
        return model
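The heart of the new policy is the natural-gradient step in learn: the vanilla policy gradient g becomes a search direction F^-1 g, where F is the Fisher information matrix approximated by the Hessian of the KL divergence plus damping, and _conjugate_gradients solves F x = g using only the Fisher-vector products from _MVP, so F is never materialized. The standalone sketch below (plain NumPy, illustrative names, not part of this commit) runs the same conjugate-gradient loop against an explicit matrix and may be handy when sanity-checking the routine above.

# Illustrative sketch only (assumption: NumPy stand-in, not code from the commit).
import numpy as np


def conjugate_gradients(mvp, b, nsteps=10, residual_tol=1e-10):
    """Solve mvp(x) = b for x, where mvp(v) returns F @ v for an SPD matrix F."""
    x = np.zeros_like(b)
    r, p = b.copy(), b.copy()  # residual and search direction, starting from x = 0
    rdotr = r @ r
    for _ in range(nsteps):
        z = mvp(p)
        alpha = rdotr / (p @ z)
        x += alpha * p
        r -= alpha * z
        new_rdotr = r @ r
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x


# Toy check: an explicit SPD matrix stands in for the damped Fisher matrix.
rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
F = A @ A.T + 0.1 * np.eye(5)  # damping keeps the system well conditioned
g = rng.normal(size=5)
x = conjugate_gradients(lambda v: F @ v, g, nsteps=50)
assert np.allclose(F @ x, g, atol=1e-6)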
tianshou/policy/modelfree/trpo.py
@@ -1,56 +1,15 @@
 import torch
 import warnings
-import numpy as np
-from torch import nn
 import torch.nn.functional as F
+from typing import Any, Dict, List, Type
 from torch.distributions import kl_divergence
-from typing import Any, Dict, List, Type, Callable
 
-from tianshou.policy import A2CPolicy
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch
+from tianshou.policy import NPGPolicy
 
 
-def _conjugate_gradients(
-    Avp: Callable[[torch.Tensor], torch.Tensor],
-    b: torch.Tensor,
-    nsteps: int = 10,
-    residual_tol: float = 1e-10
-) -> torch.Tensor:
-    x = torch.zeros_like(b)
-    r, p = b.clone(), b.clone()
-    # Note: should be 'r, p = b - A(x)', but for x=0, A(x)=0.
-    # Change if doing warm start.
-    rdotr = r.dot(r)
-    for i in range(nsteps):
-        z = Avp(p)
-        alpha = rdotr / p.dot(z)
-        x += alpha * p
-        r -= alpha * z
-        new_rdotr = r.dot(r)
-        if new_rdotr < residual_tol:
-            break
-        p = r + new_rdotr / rdotr * p
-        rdotr = new_rdotr
-    return x
-
-
-def _get_flat_grad(y: torch.Tensor, model: nn.Module, **kwargs: Any) -> torch.Tensor:
-    grads = torch.autograd.grad(y, model.parameters(), **kwargs)  # type: ignore
-    return torch.cat([grad.reshape(-1) for grad in grads])
-
-
-def _set_from_flat_params(model: nn.Module, flat_params: torch.Tensor) -> nn.Module:
-    prev_ind = 0
-    for param in model.parameters():
-        flat_size = int(np.prod(list(param.size())))
-        param.data.copy_(
-            flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
-        prev_ind += flat_size
-    return model
-
-
-class TRPOPolicy(A2CPolicy):
+class TRPOPolicy(NPGPolicy):
     """Implementation of Trust Region Policy Optimization. arXiv:1502.05477.
 
     :param torch.nn.Module actor: the actor network following the rules in
@@ -94,35 +53,16 @@ class TRPOPolicy(A2CPolicy):
         critic: torch.nn.Module,
         optim: torch.optim.Optimizer,
         dist_fn: Type[torch.distributions.Distribution],
-        advantage_normalization: bool = True,
-        optim_critic_iters: int = 5,
         max_kl: float = 0.01,
         backtrack_coeff: float = 0.8,
         max_backtracks: int = 10,
         **kwargs: Any,
     ) -> None:
         super().__init__(actor, critic, optim, dist_fn, **kwargs)
-        del self._weight_vf, self._weight_ent, self._grad_norm
-        self._norm_adv = advantage_normalization
-        self._optim_critic_iters = optim_critic_iters
+        del self._step_size
         self._max_backtracks = max_backtracks
         self._delta = max_kl
         self._backtrack_coeff = backtrack_coeff
-        # adjusts Hessian-vector product calculation for numerical stability
-        self.__damping = 0.1
-
-    def process_fn(
-        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
-    ) -> Batch:
-        batch = super().process_fn(batch, buffer, indice)
-        old_log_prob = []
-        with torch.no_grad():
-            for b in batch.split(self._batch, shuffle=False, merge_last=True):
-                old_log_prob.append(self(b).dist.log_prob(b.act))
-        batch.logp_old = torch.cat(old_log_prob, dim=0)
-        if self._norm_adv:
-            batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
-        return batch
-
     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
@@ -136,7 +76,7 @@ class TRPOPolicy(A2CPolicy):
                 ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
                 ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                 actor_loss = -(ratio * b.adv).mean()
-                flat_grads = _get_flat_grad(
+                flat_grads = self._get_flat_grad(
                     actor_loss, self.actor, retain_graph=True).detach()
 
                 # direction: calculate natural gradient
@@ -145,20 +85,14 @@ class TRPOPolicy(A2CPolicy):
 
                 kl = kl_divergence(old_dist, dist).mean()
                 # calculate first order gradient of kl with respect to theta
-                flat_kl_grad = _get_flat_grad(kl, self.actor, create_graph=True)
-
-                def MVP(v: torch.Tensor) -> torch.Tensor:  # matrix vector product
-                    # caculate second order gradient of kl with respect to theta
-                    kl_v = (flat_kl_grad * v).sum()
-                    flat_kl_grad_grad = _get_flat_grad(
-                        kl_v, self.actor, retain_graph=True).detach()
-                    return flat_kl_grad_grad + v * self.__damping
-
-                search_direction = -_conjugate_gradients(MVP, flat_grads, nsteps=10)
+                flat_kl_grad = self._get_flat_grad(kl, self.actor, create_graph=True)
+                search_direction = -self._conjugate_gradients(
+                    flat_grads, flat_kl_grad, nsteps=10)
 
                 # stepsize: calculate max stepsize constrained by kl bound
                 step_size = torch.sqrt(2 * self._delta / (
-                    search_direction * MVP(search_direction)).sum(0, keepdim=True))
+                    search_direction * self._MVP(search_direction, flat_kl_grad)
+                ).sum(0, keepdim=True))
 
                 # stepsize: linesearch stepsize
                 with torch.no_grad():
@@ -166,7 +100,7 @@ class TRPOPolicy(A2CPolicy):
                         for param in self.actor.parameters()])
                     for i in range(self._max_backtracks):
                         new_flat_params = flat_params + step_size * search_direction
-                        _set_from_flat_params(self.actor, new_flat_params)
+                        self._set_from_flat_params(self.actor, new_flat_params)
                         # calculate kl and if in bound, loss actually down
                         new_dist = self(b).dist
                         new_dratio = (
@@ -183,7 +117,7 @@ class TRPOPolicy(A2CPolicy):
                         elif i < self._max_backtracks - 1:
                             step_size = step_size * self._backtrack_coeff
                         else:
-                            _set_from_flat_params(self.actor, new_flat_params)
+                            self._set_from_flat_params(self.actor, new_flat_params)
                             step_size = torch.tensor([0.0])
                             warnings.warn("Line search failed! It seems hyperparamters"
                                           " are poor and need to be changed.")
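A note on the step-size hunk above, as a summary of the underlying math rather than text from the commit: under the usual quadratic approximation of the KL constraint, the largest admissible step along the conjugate-gradient direction s is

\[
\beta_{\max} \;=\; \sqrt{\frac{2\,\delta}{s^{\top} H s}},
\qquad
H \;\approx\; \nabla_{\theta}^{2}\, D_{\mathrm{KL}}\!\left(\pi_{\theta_{\mathrm{old}}} \,\|\, \pi_{\theta}\right) + \lambda I,
\]

which is what torch.sqrt(2 * self._delta / ...) computes, with self._MVP supplying the product H s (delta is max_kl, lambda is the damping constant). The backtracking loop then shrinks the step by backtrack_coeff, up to max_backtracks times, until the measured KL actually satisfies the bound.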