Add TRPO policy (#337)

This commit is contained in:
ChenDRAG 2021-04-16 20:37:12 +08:00 committed by GitHub
parent 333b8fbd66
commit 5057b5c89e
15 changed files with 398 additions and 112 deletions

View File

@ -19,15 +19,16 @@
**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly APIs, or run slowly, Tianshou provides a fast, modularized framework and pythonic API for building deep reinforcement learning agents with the least number of lines of code. The supported interface algorithms currently include:
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)

View File

@ -48,6 +48,11 @@ On-policy
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.TRPOPolicy
    :members:
    :undoc-members:
    :show-inheritance:

.. autoclass:: tianshou.policy.PPOPolicy
    :members:
    :undoc-members:

View File

@ -9,15 +9,16 @@ Welcome to Tianshou!
**Tianshou** (`天授 <https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88>`_) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly APIs, or run slowly, Tianshou provides a fast framework and pythonic API for building deep reinforcement learning agents. The supported interface algorithms include:
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
* :class:`~tianshou.policy.TRPOPolicy` `Trust Region Policy Optimization <https://arxiv.org/pdf/1502.05477.pdf>`_
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
@ -84,7 +85,7 @@ Tianshou is still under development, you can also check out the documents in sta
tutorials/concepts
tutorials/batch
tutorials/tictactoe
tutorials/trick
tutorials/benchmark
tutorials/cheatsheet

View File

@ -0,0 +1,14 @@
Benchmark
=========
Mujoco Benchmark
----------------
Tianshou's Mujoco benchmark contains state-of-the-art results (even better than `SpinningUp <https://spinningup.openai.com/en/latest/spinningup/bench.html>`_!).
Please refer to https://github.com/thu-ml/tianshou/tree/master/examples/mujoco
Atari Benchmark
---------------
Please refer to https://github.com/thu-ml/tianshou/tree/master/examples/atari

View File

@ -101,7 +101,7 @@ This is related to `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_.
If you want to gather logging statistics from the data stream / pre-process batch images / modify the reward with the given env info, use ``preprocess_fn`` in :class:`~tianshou.data.Collector`. This is a hook that will be called before the data is added into the buffer.
This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`. It returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env.reset, while every key is specified for normal steps.
It will receive only "obs" when the collector resets the environment, and will receive the five keys "obs_next", "rew", "done", "info", and "policy" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values.
These variables are intended to gather all the information required to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation.
@ -122,7 +122,7 @@ For example, you can write your hook as:
def preprocess_fn(**kwargs):
"""change reward to zero mean"""
# if only obs exist -> reset
# if obs/act/rew/done/... exist -> normal step
# if obs_next/act/rew/done/policy exist -> normal step
if 'rew' not in kwargs:
# means that it is called after env.reset(), it can only process the obs
return Batch() # none of the variables are needed to be updated
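A complete minimal sketch of such a hook (the reward scaling below is an arbitrary illustration, not part of this commit):

::

    import numpy as np
    from tianshou.data import Batch

    def preprocess_fn(**kwargs):
        """Scale the reward; the 1/100 factor is only an illustration."""
        if 'rew' not in kwargs:
            # called right after env.reset(): only "obs" is present, change nothing
            return Batch()
        # normal step: "obs_next", "rew", "done", "info" and "policy" are present
        return Batch(rew=np.asarray(kwargs['rew']) / 100.0)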
@ -163,10 +163,10 @@ First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`, :cla
Then, change the network to recurrent-style, for example, :class:`~tianshou.utils.net.common.Recurrent`, :class:`~tianshou.utils.net.continuous.RecurrentActorProb` and :class:`~tianshou.utils.net.continuous.RecurrentCritic`.
The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.Wrapper`` to modify the state representation. For example, if we add a wrapper that maps the [s, a] pair to a new state:
The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.Wrapper`` to modify the state representation. For example, if we add a wrapper that maps the ``[s, a]`` pair to a new state (a sketch follows the bullets below):
- Before: (s, a, s', r, d) stored in replay buffer, and get stacked s;
- After applying wrapper: ([s, a], a, [s', a'], r, d) stored in replay buffer, and get both stacked s and a.
- Before: ``(s, a, s', r, d)`` stored in replay buffer, and get stacked s;
- After applying wrapper: ``([s, a], a, [s', a'], r, d)`` stored in replay buffer, and get both stacked s and a.
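A minimal sketch of such a wrapper, assuming flat Box observation and action spaces (the zero-padding of the initial action is an illustration, not something defined in this commit):

::

    import gym
    import numpy as np

    class StateActionWrapper(gym.Wrapper):
        """Augment the observation with the last action, i.e. map s to [s, a]."""

        def reset(self):
            obs = self.env.reset()
            # no action has been taken yet, so pad with zeros
            return np.concatenate([obs, np.zeros(self.env.action_space.shape)])

        def step(self, action):
            obs, rew, done, info = self.env.step(action)
            obs = np.concatenate([obs, np.asarray(action).reshape(-1)])
            return obs, rew, done, info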
.. _self_defined_env:

View File

@ -1,88 +0,0 @@
Train a model-free RL agent within 30s
======================================
This page summarizes some hyper-parameter tuning experience and code-level tricks for training a model-free DRL agent.
You can also contribute to this page with your own tricks :)
Avoid batch-size = 1
--------------------
In the traditional RL training loop, we always use the policy to interact with only one environment for collecting data. That means most of the time the network uses batch-size = 1. Quite inefficient!
Here is an example showing how inefficient it is:
::
import torch, time
from torch import nn
class Net(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(3, 128), nn.ReLU(inplace=True),
nn.Linear(128, 128), nn.ReLU(inplace=True),
nn.Linear(128, 1))
def forward(self, s):
return self.model(s)
net = Net()
cnt = 1000
div = 128
a = torch.randn([128, 3])
t = time.time()
for i in range(cnt):
b = net(a)
t1 = (time.time() - t) / cnt
print(t1)
t = time.time()
for i in range(cnt):
for a_ in a.split(a.shape[0] // div):
b = net(a_)
t2 = (time.time() - t) / cnt
print(t2)
print(t2 / t1)
The first test uses batch-size 128, and the second test uses batch-size = 1 for 128 times. In our test, the first is 70-80 times faster than the second.
So how could we avoid the case of batch-size = 1? The answer is synchronized sampling: we create multiple independent environments and sample from them simultaneously. It is similar to A2C, but other algorithms can also use this method. In our experiments, sampling from more environments benefits not only the sampling speed but also the convergence speed of the neural network (we guess it lowers the sample bias).
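In Tianshou this amounts to a vectorized environment plus a collector. Below is a minimal sketch (assuming ``policy`` is an already-constructed policy and CartPole-v0 as the task; the numbers are arbitrary):

::

    import gym
    from tianshou.env import DummyVectorEnv
    from tianshou.data import Collector, VectorReplayBuffer

    # 16 environments stepped in lockstep -> the network always sees batch-size 16
    train_envs = DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(16)])
    buffer = VectorReplayBuffer(20000, len(train_envs))
    train_collector = Collector(policy, train_envs, buffer)
    train_collector.collect(n_step=2048)  # one round of synchronized sampling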
By the way, A2C is better than A3C in some cases: A3C needs each worker to act independently and sync gradients to a master, and, on a single node, acting with batch-size = 1 in A3C is quite resource-consuming.
Algorithm specific tricks
-------------------------
Here is some experience of hyper-parameter tuning on CartPole and Pendulum:
* :class:`~tianshou.policy.DQNPolicy`: use estimation_step = 3 or 4 and a target network, along with a suitably sized replay buffer;
* :class:`~tianshou.policy.PGPolicy`: TBD
* :class:`~tianshou.policy.A2CPolicy`: TBD
* :class:`~tianshou.policy.PPOPolicy`: TBD
* :class:`~tianshou.policy.DDPGPolicy`, :class:`~tianshou.policy.TD3Policy`, and :class:`~tianshou.policy.SACPolicy`: We found two tricks. The first is to ignore the done flag. The second is to normalize the reward to a standard normal distribution (it goes against the theoretical analysis, but works very well in practice). The two tricks work amazingly well on Mujoco tasks, typically with faster convergence (1M -> 200K).
* On-policy algorithms: increasing the repeat-time (2 or 4 for trivial benchmarks, 10 for mujoco) of the given batch in each training update makes the algorithm more stable.
Code-level optimization
-----------------------
Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V(s)` and :math:`V(s')` with the same network, the best way is to concatenate :math:`s` and :math:`s'` and do a single forward pass instead of two separate ones.
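A minimal sketch of this trick (``critic``, ``obs`` and ``obs_next`` stand for an existing value network and batched observation tensors; the names are placeholders):

::

    import torch

    # one forward pass over the concatenation of s and s' instead of two passes
    values = critic(torch.cat([obs, obs_next], dim=0)).flatten()
    v_s, v_s_prime = values.split(obs.shape[0])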
Atari/Mujoco Task Specific
--------------------------
Please refer to `Atari examples page <https://github.com/thu-ml/tianshou/tree/master/examples/atari>`_ and `Mujoco examples page <https://github.com/thu-ml/tianshou/tree/master/examples/mujoco>`_.
Finally
-------
With fast sampling, we can use a large batch-size and a large learning rate for faster convergence.
RL algorithms are seed-sensitive. Try more seeds and pick the best. For our demo, we just used seed = 0 and found it works surprisingly well on policy gradient, so we did not try other seeds.
.. image:: /_static/images/testpg.gif
:align: center

View File

@ -52,7 +52,6 @@ def get_args():
def test_ppo(args=get_args()):
torch.set_num_threads(1)  # we need only one thread for NN
env = gym.make(args.task)
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -250
@ -110,8 +109,7 @@ def test_ppo(args=get_args()):
# collector
train_collector = Collector(
policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True)
VectorReplayBuffer(args.buffer_size, len(train_envs)))
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')

View File

@ -0,0 +1,141 @@
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
from tianshou.policy import TRPOPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--buffer-size', type=int, default=50000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.95)
parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=50000)
parser.add_argument('--step-per-collect', type=int, default=2048)
parser.add_argument('--repeat-per-collect', type=int,
default=2) # theoretically it should be 1
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
# trpo special
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--rew-norm', type=int, default=1)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--optim-critic-iters', type=int, default=5)
parser.add_argument('--max-kl', type=float, default=0.01)
parser.add_argument('--backtrack-coeff', type=float, default=0.8)
parser.add_argument('--max-backtracks', type=int, default=10)
args = parser.parse_known_args()[0]
return args
def test_trpo(args=get_args()):
env = gym.make(args.task)
if args.task == 'Pendulum-v0':
env.spec.reward_threshold = -250
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
# you can also use tianshou.env.SubprocVectorEnv
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
activation=nn.Tanh, device=args.device)
actor = ActorProb(net, args.action_shape, max_action=args.max_action,
unbounded=True, device=args.device).to(args.device)
critic = Critic(Net(
args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device,
activation=nn.Tanh), device=args.device).to(args.device)
# orthogonal initialization
for m in list(actor.modules()) + list(critic.modules()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
# replace DiagGaussian with Independent(Normal), which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits):
return Independent(Normal(*logits), 1)
policy = TRPOPolicy(
actor, critic, optim, dist,
discount_factor=args.gamma,
reward_normalization=args.rew_norm,
advantage_normalization=args.norm_adv,
gae_lambda=args.gae_lambda,
action_space=env.action_space,
optim_critic_iters=args.optim_critic_iters,
max_kl=args.max_kl,
backtrack_coeff=args.backtrack_coeff,
max_backtracks=args.max_backtracks)
# collector
train_collector = Collector(
policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)))
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'trpo')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
if __name__ == '__main__':
test_trpo()

View File

@ -86,8 +86,7 @@ def test_a2c_with_il(args=get_args()):
# collector
train_collector = Collector(
policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True)
VectorReplayBuffer(args.buffer_size, len(train_envs)))
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'a2c')

View File

@ -73,8 +73,7 @@ def test_pg(args=get_args()):
# collector
train_collector = Collector(
policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True)
VectorReplayBuffer(args.buffer_size, len(train_envs)))
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'pg')

View File

@ -93,8 +93,7 @@ def test_ppo(args=get_args()):
# collector
train_collector = Collector(
policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True)
VectorReplayBuffer(args.buffer_size, len(train_envs)))
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')

View File

@ -91,8 +91,7 @@ def test_discrete_sac(args=get_args()):
# collector
train_collector = Collector(
policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True)
VectorReplayBuffer(args.buffer_size, len(train_envs)))
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log

View File

@ -36,9 +36,10 @@ class Collector(object):
The "preprocess_fn" is a function called before the data has been added to the
buffer with batch format. It will receive with only "obs" when the collector resets
the environment, and will receive four keys "obs_next", "rew", "done", "info" in a
normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with
the modified keys and values. Examples are in "test/base/test_collector.py".
the environment, and will receive five keys "obs_next", "rew", "done", "info", and
"policy" in a normal env step. It returns either a dict or a
:class:`~tianshou.data.Batch` with the modified keys and values. Examples are in
"test/base/test_collector.py".
.. note::
@ -231,6 +232,7 @@ class Collector(object):
rew=self.data.rew,
done=self.data.done,
info=self.data.info,
policy=self.data.policy,
))
if render:

View File

@ -7,6 +7,7 @@ from tianshou.policy.modelfree.pg import PGPolicy
from tianshou.policy.modelfree.a2c import A2CPolicy
from tianshou.policy.modelfree.ddpg import DDPGPolicy
from tianshou.policy.modelfree.ppo import PPOPolicy
from tianshou.policy.modelfree.trpo import TRPOPolicy
from tianshou.policy.modelfree.td3 import TD3Policy
from tianshou.policy.modelfree.sac import SACPolicy
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
@ -26,6 +27,7 @@ __all__ = [
"A2CPolicy",
"DDPGPolicy",
"PPOPolicy",
"TRPOPolicy",
"TD3Policy",
"SACPolicy",
"DiscreteSACPolicy",

View File

@ -0,0 +1,214 @@
import torch
import warnings
import numpy as np
from torch import nn
import torch.nn.functional as F
from torch.distributions import kl_divergence
from typing import Any, Dict, List, Type, Callable
from tianshou.policy import A2CPolicy
from tianshou.data import Batch, ReplayBuffer
def _conjugate_gradients(
Avp: Callable[[torch.Tensor], torch.Tensor],
b: torch.Tensor,
nsteps: int = 10,
residual_tol: float = 1e-10
) -> torch.Tensor:
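    # solve A x = b for x with the conjugate gradient method; the matrix A is
    # accessed only through the matrix-vector product Avp (here: the damped KL Hessian)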
x = torch.zeros_like(b)
r, p = b.clone(), b.clone()
# Note: should be 'r, p = b - A(x)', but for x=0, A(x)=0.
# Change if doing warm start.
rdotr = r.dot(r)
for i in range(nsteps):
z = Avp(p)
alpha = rdotr / p.dot(z)
x += alpha * p
r -= alpha * z
new_rdotr = r.dot(r)
if new_rdotr < residual_tol:
break
p = r + new_rdotr / rdotr * p
rdotr = new_rdotr
return x
def _get_flat_grad(y: torch.Tensor, model: nn.Module, **kwargs: Any) -> torch.Tensor:
grads = torch.autograd.grad(y, model.parameters(), **kwargs) # type: ignore
return torch.cat([grad.reshape(-1) for grad in grads])
def _set_from_flat_params(model: nn.Module, flat_params: torch.Tensor) -> nn.Module:
prev_ind = 0
for param in model.parameters():
flat_size = int(np.prod(list(param.size())))
param.data.copy_(
flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
prev_ind += flat_size
return model
class TRPOPolicy(A2CPolicy):
"""Implementation of Trust Region Policy Optimization. arXiv:1502.05477.
:param torch.nn.Module actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param torch.nn.Module critic: the critic network. (s -> V(s))
:param torch.optim.Optimizer optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action.
:type dist_fn: Type[torch.distributions.Distribution]
:param bool advantage_normalization: whether to do per mini-batch advantage
normalization. Default to True.
:param int optim_critic_iters: Number of times to optimize critic network per
update. Default to 5.
:param float max_kl: max KL-divergence used to constrain each actor network update.
Default to 0.01.
:param float backtrack_coeff: Coefficient to be multiplied by step size when
constraints are not met. Default to 0.8.
:param int max_backtracks: Max number of backtracking times in linesearch. Default
to 10.
:param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
Default to 0.95.
:param bool reward_normalization: normalize estimated values to have std close to
1. Default to False.
:param int max_batchsize: the maximum size of the batch when computing GAE,
depends on the size of available memory and the memory cost of the
model; should be as large as possible within the memory constraint.
Default to 256.
:param bool action_scaling: whether to map actions from range [-1, 1] to range
[action_spaces.low, action_spaces.high]. Default to True.
:param str action_bound_method: method to bound action to range [-1, 1], can be
either "clip" (for simply clipping the action), "tanh" (for applying tanh
squashing) for now, or empty string for no bounding. Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
"""
def __init__(
self,
actor: torch.nn.Module,
critic: torch.nn.Module,
optim: torch.optim.Optimizer,
dist_fn: Type[torch.distributions.Distribution],
advantage_normalization: bool = True,
optim_critic_iters: int = 5,
max_kl: float = 0.01,
backtrack_coeff: float = 0.8,
max_backtracks: int = 10,
**kwargs: Any,
) -> None:
super().__init__(actor, critic, optim, dist_fn, **kwargs)
del self._weight_vf, self._weight_ent, self._grad_norm
self._norm_adv = advantage_normalization
self._optim_critic_iters = optim_critic_iters
self._max_backtracks = max_backtracks
self._delta = max_kl
self._backtrack_coeff = backtrack_coeff
# adjusts Hessian-vector product calculation for numerical stability
self.__damping = 0.1
def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch:
batch = super().process_fn(batch, buffer, indice)
old_log_prob = []
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False, merge_last=True):
old_log_prob.append(self(b).dist.log_prob(b.act))
batch.logp_old = torch.cat(old_log_prob, dim=0)
if self._norm_adv:
batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
return batch
def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
) -> Dict[str, List[float]]:
actor_losses, vf_losses, step_sizes, kls = [], [], [], []
for step in range(repeat):
for b in batch.split(batch_size, merge_last=True):
# optimize actor
# direction: calculate vanilla gradient
dist = self(b).dist # TODO could come from batch
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
actor_loss = -(ratio * b.adv).mean()
flat_grads = _get_flat_grad(
actor_loss, self.actor, retain_graph=True).detach()
# direction: calculate natural gradient
with torch.no_grad():
old_dist = self(b).dist
kl = kl_divergence(old_dist, dist).mean()
# calculate first order gradient of kl with respect to theta
flat_kl_grad = _get_flat_grad(kl, self.actor, create_graph=True)
def MVP(v: torch.Tensor) -> torch.Tensor: # matrix vector product
# calculate second order gradient of kl with respect to theta
kl_v = (flat_kl_grad * v).sum()
flat_kl_grad_grad = _get_flat_grad(
kl_v, self.actor, retain_graph=True).detach()
return flat_kl_grad_grad + v * self.__damping
search_direction = -_conjugate_gradients(MVP, flat_grads, nsteps=10)
# stepsize: calculate max stepsize constrained by kl bound
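# under the quadratic approximation of the KL constraint, the largest step
# along search direction s with KL <= delta is sqrt(2 * delta / (s^T H s)),
# where H is the Hessian of the KL (available through MVP above)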
step_size = torch.sqrt(2 * self._delta / (
search_direction * MVP(search_direction)).sum(0, keepdim=True))
# stepsize: linesearch stepsize
with torch.no_grad():
flat_params = torch.cat([param.data.view(-1)
for param in self.actor.parameters()])
for i in range(self._max_backtracks):
new_flat_params = flat_params + step_size * search_direction
_set_from_flat_params(self.actor, new_flat_params)
# accept the new params if KL is within the bound and the actor loss actually decreased
new_dist = self(b).dist
new_dratio = (
new_dist.log_prob(b.act) - b.logp_old).exp().float()
new_dratio = new_dratio.reshape(
new_dratio.size(0), -1).transpose(0, 1)
new_actor_loss = -(new_dratio * b.adv).mean()
kl = kl_divergence(old_dist, new_dist).mean()
if kl < self._delta and new_actor_loss < actor_loss:
if i > 0:
warnings.warn(f"Backtracking to step {i}. "
"Hyperparamters aren't good enough.")
break
elif i < self._max_backtracks - 1:
step_size = step_size * self._backtrack_coeff
else:
_set_from_flat_params(self.actor, new_flat_params)
step_size = torch.tensor([0.0])
warnings.warn("Line search failed! It seems hyperparamters"
" are poor and need to be changed.")
# optimize critic
for _ in range(self._optim_critic_iters):
value = self.critic(b.obs).flatten()
vf_loss = F.mse_loss(b.returns, value)
self.optim.zero_grad()
vf_loss.backward()
self.optim.step()
actor_losses.append(actor_loss.item())
vf_losses.append(vf_loss.item())
step_sizes.append(step_size.item())
kls.append(kl.item())
# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()
return {
"loss/actor": actor_losses,
"loss/vf": vf_losses,
"step_size": step_sizes,
"kl": kls,
}