Add TRPO policy (#337)
This commit is contained in:
parent
333b8fbd66
commit
5057b5c89e
@ -19,15 +19,16 @@
|
||||
**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed modularized framework and pythonic API for building the deep reinforcement learning agent with the least number of lines of code. The supported interface algorithms currently include:
|
||||
|
||||
|
||||
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
|
||||
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
|
||||
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
|
||||
- [Dueling DQN](https://arxiv.org/pdf/1511.06581.pdf)
|
||||
- [Categorical DQN (C51)](https://arxiv.org/pdf/1707.06887.pdf)
|
||||
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
|
||||
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
|
||||
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
|
||||
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
|
||||
- [Trust Region Policy Optimization](https://arxiv.org/pdf/1502.05477.pdf)
|
||||
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
|
||||
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
|
||||
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
|
||||
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
|
||||
- [Discrete Soft Actor-Critic (SAC-Discrete)](https://arxiv.org/pdf/1910.07207.pdf)
|
||||
|
@ -48,6 +48,11 @@ On-policy
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.TRPOPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
.. autoclass:: tianshou.policy.PPOPolicy
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
@ -9,15 +9,16 @@ Welcome to Tianshou!
|
||||
|
||||
**Tianshou** (`天授 <https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88>`_) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed framework and pythonic API for building the deep reinforcement learning agent. The supported interface algorithms include:
|
||||
|
||||
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
|
||||
* :class:`~tianshou.policy.DQNPolicy` `Deep Q-Network <https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf>`_
|
||||
* :class:`~tianshou.policy.DQNPolicy` `Double DQN <https://arxiv.org/pdf/1509.06461.pdf>`_
|
||||
* :class:`~tianshou.policy.DQNPolicy` `Dueling DQN <https://arxiv.org/pdf/1511.06581.pdf>`_
|
||||
* :class:`~tianshou.policy.C51Policy` `Categorical DQN <https://arxiv.org/pdf/1707.06887.pdf>`_
|
||||
* :class:`~tianshou.policy.QRDQNPolicy` `Quantile Regression DQN <https://arxiv.org/pdf/1710.10044.pdf>`_
|
||||
* :class:`~tianshou.policy.PGPolicy` `Policy Gradient <https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf>`_
|
||||
* :class:`~tianshou.policy.A2CPolicy` `Advantage Actor-Critic <https://openai.com/blog/baselines-acktr-a2c/>`_
|
||||
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
|
||||
* :class:`~tianshou.policy.TRPOPolicy` `Trust Region Policy Optimization <https://arxiv.org/pdf/1502.05477.pdf>`_
|
||||
* :class:`~tianshou.policy.PPOPolicy` `Proximal Policy Optimization <https://arxiv.org/pdf/1707.06347.pdf>`_
|
||||
* :class:`~tianshou.policy.DDPGPolicy` `Deep Deterministic Policy Gradient <https://arxiv.org/pdf/1509.02971.pdf>`_
|
||||
* :class:`~tianshou.policy.TD3Policy` `Twin Delayed DDPG <https://arxiv.org/pdf/1802.09477.pdf>`_
|
||||
* :class:`~tianshou.policy.SACPolicy` `Soft Actor-Critic <https://arxiv.org/pdf/1812.05905.pdf>`_
|
||||
* :class:`~tianshou.policy.DiscreteSACPolicy` `Discrete Soft Actor-Critic <https://arxiv.org/pdf/1910.07207.pdf>`_
|
||||
@ -84,7 +85,7 @@ Tianshou is still under development, you can also check out the documents in sta
|
||||
tutorials/concepts
|
||||
tutorials/batch
|
||||
tutorials/tictactoe
|
||||
tutorials/trick
|
||||
tutorials/benchmark
|
||||
tutorials/cheatsheet
|
||||
|
||||
|
||||
|
14
docs/tutorials/benchmark.rst
Normal file
14
docs/tutorials/benchmark.rst
Normal file
@ -0,0 +1,14 @@
|
||||
Benchmark
|
||||
=========
|
||||
|
||||
Mujoco Benchmark
|
||||
----------------
|
||||
|
||||
Tianshou's Mujoco benchmark contains state-of-the-art results (even better than `SpinningUp <https://spinningup.openai.com/en/latest/spinningup/bench.html>`_!).
|
||||
|
||||
Please refer to https://github.com/thu-ml/tianshou/tree/master/examples/mujoco
|
||||
|
||||
Atari Benchmark
|
||||
---------------
|
||||
|
||||
Please refer to https://github.com/thu-ml/tianshou/tree/master/examples/atari
|
@ -101,7 +101,7 @@ This is related to `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_.
|
||||
|
||||
If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer.
|
||||
|
||||
This function receives up to 7 keys ``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``, and ``policy``, as listed in :class:`~tianshou.data.Batch`. It returns the modified part within a :class:`~tianshou.data.Batch`. Only ``obs`` is defined at env.reset, while every key is specified for normal steps.
|
||||
It will receive with only "obs" when the collector resets the environment, and will receive five keys "obs_next", "rew", "done", "info", "policy" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values.
|
||||
|
||||
These variables are intended to gather all the information requires to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation.
|
||||
|
||||
@ -122,7 +122,7 @@ For example, you can write your hook as:
|
||||
def preprocess_fn(**kwargs):
|
||||
"""change reward to zero mean"""
|
||||
# if only obs exist -> reset
|
||||
# if obs/act/rew/done/... exist -> normal step
|
||||
# if obs_next/act/rew/done/policy exist -> normal step
|
||||
if 'rew' not in kwargs:
|
||||
# means that it is called after env.reset(), it can only process the obs
|
||||
return Batch() # none of the variables are needed to be updated
|
||||
@ -163,10 +163,10 @@ First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`, :cla
|
||||
|
||||
Then, change the network to recurrent-style, for example, :class:`~tianshou.utils.net.common.Recurrent`, :class:`~tianshou.utils.net.continuous.RecurrentActorProb` and :class:`~tianshou.utils.net.continuous.RecurrentCritic`.
|
||||
|
||||
The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.Wrapper`` to modify the state representation. For example, if we add a wrapper that map [s, a] pair to a new state:
|
||||
The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.Wrapper`` to modify the state representation. For example, if we add a wrapper that map ``[s, a]`` pair to a new state:
|
||||
|
||||
- Before: (s, a, s', r, d) stored in replay buffer, and get stacked s;
|
||||
- After applying wrapper: ([s, a], a, [s', a'], r, d) stored in replay buffer, and get both stacked s and a.
|
||||
- Before: ``(s, a, s', r, d)`` stored in replay buffer, and get stacked s;
|
||||
- After applying wrapper: ``([s, a], a, [s', a'], r, d)`` stored in replay buffer, and get both stacked s and a.
|
||||
|
||||
|
||||
.. _self_defined_env:
|
||||
|
@ -1,88 +0,0 @@
|
||||
Train a model-free RL agent within 30s
|
||||
======================================
|
||||
|
||||
This page summarizes some hyper-parameter tuning experience and code-level trick when training a model-free DRL agent.
|
||||
|
||||
You can also contribute to this page with your own tricks :)
|
||||
|
||||
|
||||
Avoid batch-size = 1
|
||||
--------------------
|
||||
|
||||
In the traditional RL training loop, we always use the policy to interact with only one environment for collecting data. That means most of the time the network use batch-size = 1. Quite inefficient!
|
||||
Here is an example of showing how inefficient it is:
|
||||
::
|
||||
|
||||
import torch, time
|
||||
from torch import nn
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = nn.Sequential(
|
||||
nn.Linear(3, 128), nn.ReLU(inplace=True),
|
||||
nn.Linear(128, 128), nn.ReLU(inplace=True),
|
||||
nn.Linear(128, 1))
|
||||
def forward(self, s):
|
||||
return self.model(s)
|
||||
|
||||
net = Net()
|
||||
cnt = 1000
|
||||
div = 128
|
||||
a = torch.randn([128, 3])
|
||||
|
||||
t = time.time()
|
||||
for i in range(cnt):
|
||||
b = net(a)
|
||||
t1 = (time.time() - t) / cnt
|
||||
print(t1)
|
||||
t = time.time()
|
||||
for i in range(cnt):
|
||||
for a_ in a.split(a.shape[0] // div):
|
||||
b = net(a_)
|
||||
t2 = (time.time() - t) / cnt
|
||||
print(t2)
|
||||
print(t2 / t1)
|
||||
|
||||
The first test uses batch-size 128, and the second test uses batch-size = 1 for 128 times. In our test, the first is 70-80 times faster than the second.
|
||||
|
||||
So how could we avoid the case of batch-size = 1? The answer is synchronize sampling: we create multiple independent environments and sample simultaneously. It is similar to A2C, but other algorithms can also use this method. In our experiments, sampling from more environments benefits not only the sample speed but also the converge speed of neural network (we guess it lowers the sample bias).
|
||||
|
||||
By the way, A2C is better than A3C in some cases: A3C needs to act independently and sync the gradient to master, but, in a single node, using A3C to act with batch-size = 1 is quite resource-consuming.
|
||||
|
||||
|
||||
Algorithm specific tricks
|
||||
-------------------------
|
||||
|
||||
Here is about the experience of hyper-parameter tuning on CartPole and Pendulum:
|
||||
|
||||
* :class:`~tianshou.policy.DQNPolicy`: use estimation_step = 3 or 4 and target network, also with a suitable size of replay buffer;
|
||||
* :class:`~tianshou.policy.PGPolicy`: TBD
|
||||
* :class:`~tianshou.policy.A2CPolicy`: TBD
|
||||
* :class:`~tianshou.policy.PPOPolicy`: TBD
|
||||
* :class:`~tianshou.policy.DDPGPolicy`, :class:`~tianshou.policy.TD3Policy`, and :class:`~tianshou.policy.SACPolicy`: We found two tricks. The first is to ignore the done flag. The second is to normalize reward to a standard normal distribution (it is against the theoretical analysis, but indeed works very well). The two tricks work amazingly on Mujoco tasks, typically with a faster converge speed (1M -> 200K).
|
||||
|
||||
* On-policy algorithms: increase the repeat-time (to 2 or 4 for trivial benchmark, 10 for mujoco) of the given batch in each training update will make the algorithm more stable.
|
||||
|
||||
|
||||
Code-level optimization
|
||||
-----------------------
|
||||
|
||||
Tianshou has many short-but-efficient lines of code. For example, when we want to compute :math:`V(s)` and :math:`V(s')` by the same network, the best way is to concatenate :math:`s` and :math:`s'` together instead of computing the value function using twice of network forward.
|
||||
|
||||
|
||||
Atari/Mujoco Task Specific
|
||||
--------------------------
|
||||
|
||||
Please refer to `Atari examples page <https://github.com/thu-ml/tianshou/tree/master/examples/atari>`_ and `Mujoco examples page <https://github.com/thu-ml/tianshou/tree/master/examples/mujoco>`_.
|
||||
|
||||
|
||||
Finally
|
||||
-------
|
||||
|
||||
With fast-speed sampling, we could use large batch-size and large learning rate for faster convergence.
|
||||
|
||||
RL algorithms are seed-sensitive. Try more seeds and pick the best. But for our demo, we just used seed = 0 and found it work surprisingly well on policy gradient, so we did not try other seed.
|
||||
|
||||
.. image:: /_static/images/testpg.gif
|
||||
:align: center
|
@ -52,7 +52,6 @@ def get_args():
|
||||
|
||||
|
||||
def test_ppo(args=get_args()):
|
||||
torch.set_num_threads(1) # we just need only one thread for NN
|
||||
env = gym.make(args.task)
|
||||
if args.task == 'Pendulum-v0':
|
||||
env.spec.reward_threshold = -250
|
||||
@ -110,8 +109,7 @@ def test_ppo(args=get_args()):
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=True)
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'ppo')
|
||||
|
141
test/continuous/test_trpo.py
Normal file
141
test/continuous/test_trpo.py
Normal file
@ -0,0 +1,141 @@
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import pprint
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.distributions import Independent, Normal
|
||||
|
||||
from tianshou.policy import TRPOPolicy
|
||||
from tianshou.utils import BasicLogger
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='Pendulum-v0')
|
||||
parser.add_argument('--seed', type=int, default=1)
|
||||
parser.add_argument('--buffer-size', type=int, default=50000)
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.95)
|
||||
parser.add_argument('--epoch', type=int, default=5)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=50000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=2048)
|
||||
parser.add_argument('--repeat-per-collect', type=int,
|
||||
default=2) # theoretically it should be 1
|
||||
parser.add_argument('--batch-size', type=int, default=99999)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
|
||||
parser.add_argument('--training-num', type=int, default=16)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
# trpo special
|
||||
parser.add_argument('--gae-lambda', type=float, default=0.95)
|
||||
parser.add_argument('--rew-norm', type=int, default=1)
|
||||
parser.add_argument('--norm-adv', type=int, default=1)
|
||||
parser.add_argument('--optim-critic-iters', type=int, default=5)
|
||||
parser.add_argument('--max-kl', type=float, default=0.01)
|
||||
parser.add_argument('--backtrack-coeff', type=float, default=0.8)
|
||||
parser.add_argument('--max-backtracks', type=int, default=10)
|
||||
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
|
||||
def test_trpo(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
if args.task == 'Pendulum-v0':
|
||||
env.spec.reward_threshold = -250
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||
activation=nn.Tanh, device=args.device)
|
||||
actor = ActorProb(net, args.action_shape, max_action=args.max_action,
|
||||
unbounded=True, device=args.device).to(args.device)
|
||||
critic = Critic(Net(
|
||||
args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device,
|
||||
activation=nn.Tanh), device=args.device).to(args.device)
|
||||
# orthogonal initialization
|
||||
for m in list(actor.modules()) + list(critic.modules()):
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.orthogonal_(m.weight)
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
optim = torch.optim.Adam(set(
|
||||
actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||
|
||||
# replace DiagGuassian with Independent(Normal) which is equivalent
|
||||
# pass *logits to be consistent with policy.forward
|
||||
def dist(*logits):
|
||||
return Independent(Normal(*logits), 1)
|
||||
|
||||
policy = TRPOPolicy(
|
||||
actor, critic, optim, dist,
|
||||
discount_factor=args.gamma,
|
||||
reward_normalization=args.rew_norm,
|
||||
advantage_normalization=args.norm_adv,
|
||||
gae_lambda=args.gae_lambda,
|
||||
action_space=env.action_space,
|
||||
optim_critic_iters=args.optim_critic_iters,
|
||||
max_kl=args.max_kl,
|
||||
backtrack_coeff=args.backtrack_coeff,
|
||||
max_backtracks=args.max_backtracks)
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'trpo')
|
||||
writer = SummaryWriter(log_path)
|
||||
logger = BasicLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
# trainer
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
|
||||
step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn,
|
||||
logger=logger)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
env = gym.make(args.task)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_trpo()
|
@ -86,8 +86,7 @@ def test_a2c_with_il(args=get_args()):
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=True)
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'a2c')
|
||||
|
@ -73,8 +73,7 @@ def test_pg(args=get_args()):
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=True)
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'pg')
|
||||
|
@ -93,8 +93,7 @@ def test_ppo(args=get_args()):
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=True)
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'ppo')
|
||||
|
@ -91,8 +91,7 @@ def test_discrete_sac(args=get_args()):
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=True)
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# train_collector.collect(n_step=args.buffer_size)
|
||||
# log
|
||||
|
@ -36,9 +36,10 @@ class Collector(object):
|
||||
|
||||
The "preprocess_fn" is a function called before the data has been added to the
|
||||
buffer with batch format. It will receive with only "obs" when the collector resets
|
||||
the environment, and will receive four keys "obs_next", "rew", "done", "info" in a
|
||||
normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with
|
||||
the modified keys and values. Examples are in "test/base/test_collector.py".
|
||||
the environment, and will receive five keys "obs_next", "rew", "done", "info", and
|
||||
"policy" in a normal env step. It returns either a dict or a
|
||||
:class:`~tianshou.data.Batch` with the modified keys and values. Examples are in
|
||||
"test/base/test_collector.py".
|
||||
|
||||
.. note::
|
||||
|
||||
@ -231,6 +232,7 @@ class Collector(object):
|
||||
rew=self.data.rew,
|
||||
done=self.data.done,
|
||||
info=self.data.info,
|
||||
policy=self.data.policy,
|
||||
))
|
||||
|
||||
if render:
|
||||
|
@ -7,6 +7,7 @@ from tianshou.policy.modelfree.pg import PGPolicy
|
||||
from tianshou.policy.modelfree.a2c import A2CPolicy
|
||||
from tianshou.policy.modelfree.ddpg import DDPGPolicy
|
||||
from tianshou.policy.modelfree.ppo import PPOPolicy
|
||||
from tianshou.policy.modelfree.trpo import TRPOPolicy
|
||||
from tianshou.policy.modelfree.td3 import TD3Policy
|
||||
from tianshou.policy.modelfree.sac import SACPolicy
|
||||
from tianshou.policy.modelfree.discrete_sac import DiscreteSACPolicy
|
||||
@ -26,6 +27,7 @@ __all__ = [
|
||||
"A2CPolicy",
|
||||
"DDPGPolicy",
|
||||
"PPOPolicy",
|
||||
"TRPOPolicy",
|
||||
"TD3Policy",
|
||||
"SACPolicy",
|
||||
"DiscreteSACPolicy",
|
||||
|
214
tianshou/policy/modelfree/trpo.py
Normal file
214
tianshou/policy/modelfree/trpo.py
Normal file
@ -0,0 +1,214 @@
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions import kl_divergence
|
||||
from typing import Any, Dict, List, Type, Callable
|
||||
|
||||
|
||||
from tianshou.policy import A2CPolicy
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
|
||||
|
||||
def _conjugate_gradients(
|
||||
Avp: Callable[[torch.Tensor], torch.Tensor],
|
||||
b: torch.Tensor,
|
||||
nsteps: int = 10,
|
||||
residual_tol: float = 1e-10
|
||||
) -> torch.Tensor:
|
||||
x = torch.zeros_like(b)
|
||||
r, p = b.clone(), b.clone()
|
||||
# Note: should be 'r, p = b - A(x)', but for x=0, A(x)=0.
|
||||
# Change if doing warm start.
|
||||
rdotr = r.dot(r)
|
||||
for i in range(nsteps):
|
||||
z = Avp(p)
|
||||
alpha = rdotr / p.dot(z)
|
||||
x += alpha * p
|
||||
r -= alpha * z
|
||||
new_rdotr = r.dot(r)
|
||||
if new_rdotr < residual_tol:
|
||||
break
|
||||
p = r + new_rdotr / rdotr * p
|
||||
rdotr = new_rdotr
|
||||
return x
|
||||
|
||||
|
||||
def _get_flat_grad(y: torch.Tensor, model: nn.Module, **kwargs: Any) -> torch.Tensor:
|
||||
grads = torch.autograd.grad(y, model.parameters(), **kwargs) # type: ignore
|
||||
return torch.cat([grad.reshape(-1) for grad in grads])
|
||||
|
||||
|
||||
def _set_from_flat_params(model: nn.Module, flat_params: torch.Tensor) -> nn.Module:
|
||||
prev_ind = 0
|
||||
for param in model.parameters():
|
||||
flat_size = int(np.prod(list(param.size())))
|
||||
param.data.copy_(
|
||||
flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
|
||||
prev_ind += flat_size
|
||||
return model
|
||||
|
||||
|
||||
class TRPOPolicy(A2CPolicy):
|
||||
"""Implementation of Trust Region Policy Optimization. arXiv:1502.05477.
|
||||
|
||||
:param torch.nn.Module actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param torch.nn.Module critic: the critic network. (s -> V(s))
|
||||
:param torch.optim.Optimizer optim: the optimizer for actor and critic network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
:type dist_fn: Type[torch.distributions.Distribution]
|
||||
:param bool advantage_normalization: whether to do per mini-batch advantage
|
||||
normalization. Default to True.
|
||||
:param int optim_critic_iters: Number of times to optimize critic network per
|
||||
update. Default to 5.
|
||||
:param int max_kl: max kl-divergence used to constrain each actor network update.
|
||||
Default to 0.01.
|
||||
:param float backtrack_coeff: Coefficient to be multiplied by step size when
|
||||
constraints are not met. Default to 0.8.
|
||||
:param int max_backtracks: Max number of backtracking times in linesearch. Default
|
||||
to 10.
|
||||
:param float gae_lambda: in [0, 1], param for Generalized Advantage Estimation.
|
||||
Default to 0.95.
|
||||
:param bool reward_normalization: normalize estimated values to have std close to
|
||||
1. Default to False.
|
||||
:param int max_batchsize: the maximum size of the batch when computing GAE,
|
||||
depends on the size of available memory and the memory cost of the
|
||||
model; should be as large as possible within the memory constraint.
|
||||
Default to 256.
|
||||
:param bool action_scaling: whether to map actions from range [-1, 1] to range
|
||||
[action_spaces.low, action_spaces.high]. Default to True.
|
||||
:param str action_bound_method: method to bound action to range [-1, 1], can be
|
||||
either "clip" (for simply clipping the action), "tanh" (for applying tanh
|
||||
squashing) for now, or empty string for no bounding. Default to "clip".
|
||||
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
|
||||
to use option "action_scaling" or "action_bound_method". Default to None.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: Type[torch.distributions.Distribution],
|
||||
advantage_normalization: bool = True,
|
||||
optim_critic_iters: int = 5,
|
||||
max_kl: float = 0.01,
|
||||
backtrack_coeff: float = 0.8,
|
||||
max_backtracks: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(actor, critic, optim, dist_fn, **kwargs)
|
||||
del self._weight_vf, self._weight_ent, self._grad_norm
|
||||
self._norm_adv = advantage_normalization
|
||||
self._optim_critic_iters = optim_critic_iters
|
||||
self._max_backtracks = max_backtracks
|
||||
self._delta = max_kl
|
||||
self._backtrack_coeff = backtrack_coeff
|
||||
# adjusts Hessian-vector product calculation for numerical stability
|
||||
self.__damping = 0.1
|
||||
|
||||
def process_fn(
|
||||
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
|
||||
) -> Batch:
|
||||
batch = super().process_fn(batch, buffer, indice)
|
||||
old_log_prob = []
|
||||
with torch.no_grad():
|
||||
for b in batch.split(self._batch, shuffle=False, merge_last=True):
|
||||
old_log_prob.append(self(b).dist.log_prob(b.act))
|
||||
batch.logp_old = torch.cat(old_log_prob, dim=0)
|
||||
if self._norm_adv:
|
||||
batch.adv = (batch.adv - batch.adv.mean()) / batch.adv.std()
|
||||
return batch
|
||||
|
||||
def learn( # type: ignore
|
||||
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any
|
||||
) -> Dict[str, List[float]]:
|
||||
actor_losses, vf_losses, step_sizes, kls = [], [], [], []
|
||||
for step in range(repeat):
|
||||
for b in batch.split(batch_size, merge_last=True):
|
||||
# optimize actor
|
||||
# direction: calculate villia gradient
|
||||
dist = self(b).dist # TODO could come from batch
|
||||
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
|
||||
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
|
||||
actor_loss = -(ratio * b.adv).mean()
|
||||
flat_grads = _get_flat_grad(
|
||||
actor_loss, self.actor, retain_graph=True).detach()
|
||||
|
||||
# direction: calculate natural gradient
|
||||
with torch.no_grad():
|
||||
old_dist = self(b).dist
|
||||
|
||||
kl = kl_divergence(old_dist, dist).mean()
|
||||
# calculate first order gradient of kl with respect to theta
|
||||
flat_kl_grad = _get_flat_grad(kl, self.actor, create_graph=True)
|
||||
|
||||
def MVP(v: torch.Tensor) -> torch.Tensor: # matrix vector product
|
||||
# caculate second order gradient of kl with respect to theta
|
||||
kl_v = (flat_kl_grad * v).sum()
|
||||
flat_kl_grad_grad = _get_flat_grad(
|
||||
kl_v, self.actor, retain_graph=True).detach()
|
||||
return flat_kl_grad_grad + v * self.__damping
|
||||
|
||||
search_direction = -_conjugate_gradients(MVP, flat_grads, nsteps=10)
|
||||
|
||||
# stepsize: calculate max stepsize constrained by kl bound
|
||||
step_size = torch.sqrt(2 * self._delta / (
|
||||
search_direction * MVP(search_direction)).sum(0, keepdim=True))
|
||||
|
||||
# stepsize: linesearch stepsize
|
||||
with torch.no_grad():
|
||||
flat_params = torch.cat([param.data.view(-1)
|
||||
for param in self.actor.parameters()])
|
||||
for i in range(self._max_backtracks):
|
||||
new_flat_params = flat_params + step_size * search_direction
|
||||
_set_from_flat_params(self.actor, new_flat_params)
|
||||
# calculate kl and if in bound, loss actually down
|
||||
new_dist = self(b).dist
|
||||
new_dratio = (
|
||||
new_dist.log_prob(b.act) - b.logp_old).exp().float()
|
||||
new_dratio = new_dratio.reshape(
|
||||
new_dratio.size(0), -1).transpose(0, 1)
|
||||
new_actor_loss = -(new_dratio * b.adv).mean()
|
||||
kl = kl_divergence(old_dist, new_dist).mean()
|
||||
|
||||
if kl < self._delta and new_actor_loss < actor_loss:
|
||||
if i > 0:
|
||||
warnings.warn(f"Backtracking to step {i}. "
|
||||
"Hyperparamters aren't good enough.")
|
||||
break
|
||||
elif i < self._max_backtracks - 1:
|
||||
step_size = step_size * self._backtrack_coeff
|
||||
else:
|
||||
_set_from_flat_params(self.actor, new_flat_params)
|
||||
step_size = torch.tensor([0.0])
|
||||
warnings.warn("Line search failed! It seems hyperparamters"
|
||||
" are poor and need to be changed.")
|
||||
|
||||
# optimize citirc
|
||||
for _ in range(self._optim_critic_iters):
|
||||
value = self.critic(b.obs).flatten()
|
||||
vf_loss = F.mse_loss(b.returns, value)
|
||||
self.optim.zero_grad()
|
||||
vf_loss.backward()
|
||||
self.optim.step()
|
||||
|
||||
actor_losses.append(actor_loss.item())
|
||||
vf_losses.append(vf_loss.item())
|
||||
step_sizes.append(step_size.item())
|
||||
kls.append(kl.item())
|
||||
|
||||
# update learning rate if lr_scheduler is given
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
|
||||
return {
|
||||
"loss/actor": actor_losses,
|
||||
"loss/vf": vf_losses,
|
||||
"step_size": step_sizes,
|
||||
"kl": kls,
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user