update readme

This commit is contained in:
Trinkle23897 2020-03-26 11:42:34 +08:00
parent 3c0a09fefd
commit c505cd8205
9 changed files with 181 additions and 24 deletions

View File

@ -1,6 +1,6 @@
MIT License
Copyright (c) 2020 TSAIL
Copyright (c) 2020 Tianshou contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

171
README.md
View File

@ -1,11 +1,172 @@
# Tianshou
![Python package](https://github.com/Trinkle23897/tianshou/workflows/Python%20package/badge.svg)
<h1 align="center">Tianshou</h1>
![PyPI](https://img.shields.io/pypi/v/tianshou)
![Unittest](https://github.com/thu-ml/tianshou/workflows/Unittest/badge.svg)
[![Documentation Status](https://readthedocs.org/projects/tianshou/badge/?version=latest)](https://tianshou.readthedocs.io/en/latest/?badge=latest)
[![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers)
[![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network)
[![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues)
[![GitHub license](https://img.shields.io/github/license/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/blob/master/LICENSE)
**Tianshou**(天授) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly api, or slow-speed, Tianshou provides a fast-speed framework and pythonic api for building the deep reinforcement learning agent. The supported interface algorithms include:
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
- [Advantage Actor-Critic (A2C)](http://incompleteideas.net/book/RLbook2018.pdf)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
Tianshou supports parallel environment training for all algorithms as well.
Tianshou is still under development. More algorithms are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out the [guidelines](/CONTRIBUTING.md).
## Installation
`pip3 install .`
Tianshou is currently hosted on [pypi](https://pypi.org/project/tianshou/). You can simply install Tianshou with the following command:
## Run Demo
```bash
pip3 install tianshou
```
`python3 test/*.py`
## Documentation
The tutorials and api documentations are hosted on https://tianshou.readthedocs.io/en/latest/.
The example scripts are under [test/discrete](/test/discrete) (CartPole) and [test/continuous](/test/continuous) (Pendulum).
## Why Tianshou?
Tianshou is a lightweight but high-speed reinforcement learning platform. For example, here is a test on a laptop (i7-8750H + GTX1060). It only use 3 seconds for training a policy gradient agent on CartPole-v0 task.
![testpg](docs/_static/images/testpg.gif)
Here is the table for other algorithms and platforms:
TODO: a TABLE
Tianshou also has unit tests. Different from other platforms, **the unit tests include the agent training procedure for all of the implemented algorithms**. It will be failed when it cannot train an agent to perform well enough on limited epochs on toy scenarios. The unit tests secure the reproducibility of our platform.
## Quick start
This is an example of Policy Gradient. You can also run the full script under [test/discrete/test_pg.py](/test/discrete/test_pg.py).
First, import the relevant packages:
```python
import gym, torch, numpy as np, torch.nn as nn
from tianshou.policy import PGPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
```
Define some hyper-parameters:
```python
task = 'CartPole-v0'
seed = 1626
lr = 3e-4
gamma = 0.9
epoch = 10
step_per_epoch = 1000
collect_per_step = 10
repeat_per_collect = 2
batch_size = 64
train_num = 8
test_num = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'
```
Define the network:
```python
class Net(nn.Module):
def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
super().__init__()
self.device = device
self.model = [
nn.Linear(np.prod(state_shape), 128),
nn.ReLU(inplace=True)]
for i in range(layer_num):
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
if action_shape:
self.model += [nn.Linear(128, np.prod(action_shape))]
self.model = nn.Sequential(*self.model)
def forward(self, s, state=None, info={}):
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
batch = s.shape[0]
s = s.view(batch, -1)
logits = self.model(s)
return logits, state
```
Make envs and fix seed:
```python
env = gym.make(task)
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
train_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
test_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
np.random.seed(seed)
torch.manual_seed(seed)
train_envs.seed(seed)
test_envs.seed(seed)
```
Setup policy and collector:
```python
net = Net(3, state_shape, action_shape, device).to(device)
optim = torch.optim.Adam(net.parameters(), lr=lr)
policy = PGPolicy(net, optim, torch.distributions.Categorical, gamma)
train_collector = Collector(policy, train_envs, ReplayBuffer(20000))
test_collector = Collector(policy, test_envs)
```
Let's train it:
```python
result = onpolicy_trainer(policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, repeat_per_collect, test_num, batch_size, stop_fn=lambda x: x >= env.spec.reward_threshold)
```
Saving / loading trained policy (it's the same as PyTorch nn.module):
```python
torch.save(policy.state_dict(), 'pg.pth')
policy.load_state_dict(torch.load('pg.pth', map_location=device))
```
Watch the performance with 35 FPS:
```python3
collecter = Collector(policy, env)
collecter.collect(n_episode=1, render=1/35)
```
## Citing Tianshou
If you find Tianshou useful, please cite it in your publications.
```
@misc{tianshou,
author = {Jiayi Weng},
title = {Tianshou},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/thu-ml/tianshou}},
}
```
## Miscellaneous
Tianshou was [previously](https://github.com/thu-ml/tianshou/tree/priv) a reinforcement learning platform based on TensorFlow. You can checkout the branch `priv` for more detail.

BIN
docs/_static/images/testpg.gif vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 526 KiB

View File

@ -83,8 +83,7 @@ def _test_ppo(args=get_args()):
action_range=[env.action_space.low[0], env.action_space.high[0]])
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size),
remove_done_flag=True)
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.step_per_epoch)
# log

View File

@ -37,7 +37,7 @@ class Batch(object):
else:
raise TypeError(
'No support for append with type {} in class Batch.'
.format(type(batch.__dict__[k])))
.format(type(batch.__dict__[k])))
def split(self, size=None, permute=True):
length = min([

View File

@ -14,7 +14,7 @@ class OUNoise(object):
if self.x is None or self.x.shape != size:
self.x = 0
self.x = self.x + self.alpha * (mu - self.x) + \
self.beta * np.random.normal(size=size)
self.beta * np.random.normal(size=size)
return self.x
def reset(self):

View File

@ -36,18 +36,16 @@ class A2CPolicy(PGPolicy):
v = self.critic(b.obs)
a = torch.tensor(b.act, device=dist.logits.device)
r = torch.tensor(b.returns, device=dist.logits.device)
actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
vf_loss = F.mse_loss(r[:, None], v)
ent_loss = dist.entropy().mean()
loss = actor_loss \
+ self._w_vf * vf_loss \
- self._w_ent * ent_loss
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
loss.backward()
if self._grad_norm:
nn.utils.clip_grad_norm_(
self.model.parameters(), max_norm=self._grad_norm)
self.optim.step()
actor_losses.append(actor_loss.detach().cpu().numpy())
actor_losses.append(a_loss.detach().cpu().numpy())
vf_losses.append(vf_loss.detach().cpu().numpy())
ent_losses.append(ent_loss.detach().cpu().numpy())
losses.append(loss.detach().cpu().numpy())

View File

@ -34,8 +34,8 @@ class PGPolicy(BasePolicy):
def learn(self, batch, batch_size=None, repeat=1):
losses = []
batch.returns = (batch.returns - batch.returns.mean()) \
/ (batch.returns.std() + self._eps)
r = batch.returns
batch.returns = (r - r.mean()) / (r.std() + self._eps)
for _ in range(repeat):
for b in batch.split(batch_size):
self.optim.zero_grad()

View File

@ -58,8 +58,8 @@ class PPOPolicy(PGPolicy):
def learn(self, batch, batch_size=None, repeat=1):
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
batch.returns = (batch.returns - batch.returns.mean()) \
/ (batch.returns.std() + self._eps)
r = batch.returns
batch.returns = (r - r.mean()) / (r.std() + self._eps)
batch.act = torch.tensor(batch.act)
batch.returns = torch.tensor(batch.returns)[:, None]
for _ in range(repeat):
@ -79,16 +79,15 @@ class PPOPolicy(PGPolicy):
clip_losses.append(clip_loss.detach().cpu().numpy())
vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v)
vf_losses.append(vf_loss.detach().cpu().numpy())
ent_loss = dist.entropy().mean()
ent_losses.append(ent_loss.detach().cpu().numpy())
loss = clip_loss \
+ self._w_vf * vf_loss - self._w_ent * ent_loss
e_loss = dist.entropy().mean()
ent_losses.append(e_loss.detach().cpu().numpy())
loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
losses.append(loss.detach().cpu().numpy())
self.optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(list(
self.actor.parameters()) + list(self.critic.parameters()),
self._max_grad_norm)
self._max_grad_norm)
self.optim.step()
self.sync_weight()
return {