diff --git a/LICENSE b/LICENSE
index c94e299..6a7aa81 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
MIT License
-Copyright (c) 2020 TSAIL
+Copyright (c) 2020 Tianshou contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
diff --git a/README.md b/README.md
index 3cd6197..4c41822 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,172 @@
-# Tianshou
-
+
Tianshou
+
+
+
+[](https://tianshou.readthedocs.io/en/latest/?badge=latest)
+[](https://github.com/thu-ml/tianshou/stargazers)
+[](https://github.com/thu-ml/tianshou/network)
+[](https://github.com/thu-ml/tianshou/issues)
+[](https://github.com/thu-ml/tianshou/blob/master/LICENSE)
+
+**Tianshou**(天授) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly api, or slow-speed, Tianshou provides a fast-speed framework and pythonic api for building the deep reinforcement learning agent. The supported interface algorithms include:
+
+
+- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
+- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
+- [Double DQN (DDQN)](https://arxiv.org/pdf/1509.06461.pdf)
+- [Advantage Actor-Critic (A2C)](http://incompleteideas.net/book/RLbook2018.pdf)
+- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
+- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
+- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
+- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf)
+
+Tianshou supports parallel environment training for all algorithms as well.
+
+Tianshou is still under development. More algorithms are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out the [guidelines](/CONTRIBUTING.md).
## Installation
-`pip3 install .`
+Tianshou is currently hosted on [pypi](https://pypi.org/project/tianshou/). You can simply install Tianshou with the following command:
-## Run Demo
+```bash
+pip3 install tianshou
+```
-`python3 test/*.py`
+## Documentation
+
+The tutorials and api documentations are hosted on https://tianshou.readthedocs.io/en/latest/.
+
+The example scripts are under [test/discrete](/test/discrete) (CartPole) and [test/continuous](/test/continuous) (Pendulum).
+
+## Why Tianshou?
+
+Tianshou is a lightweight but high-speed reinforcement learning platform. For example, here is a test on a laptop (i7-8750H + GTX1060). It only use 3 seconds for training a policy gradient agent on CartPole-v0 task.
+
+
+
+Here is the table for other algorithms and platforms:
+
+TODO: a TABLE
+
+Tianshou also has unit tests. Different from other platforms, **the unit tests include the agent training procedure for all of the implemented algorithms**. It will be failed when it cannot train an agent to perform well enough on limited epochs on toy scenarios. The unit tests secure the reproducibility of our platform.
+
+## Quick start
+
+This is an example of Policy Gradient. You can also run the full script under [test/discrete/test_pg.py](/test/discrete/test_pg.py).
+
+First, import the relevant packages:
+
+```python
+import gym, torch, numpy as np, torch.nn as nn
+
+from tianshou.policy import PGPolicy
+from tianshou.env import SubprocVectorEnv
+from tianshou.trainer import onpolicy_trainer
+from tianshou.data import Collector, ReplayBuffer
+```
+
+Define some hyper-parameters:
+
+```python
+task = 'CartPole-v0'
+seed = 1626
+lr = 3e-4
+gamma = 0.9
+epoch = 10
+step_per_epoch = 1000
+collect_per_step = 10
+repeat_per_collect = 2
+batch_size = 64
+train_num = 8
+test_num = 100
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+```
+
+Define the network:
+
+```python
+class Net(nn.Module):
+ def __init__(self, layer_num, state_shape, action_shape=0, device='cpu'):
+ super().__init__()
+ self.device = device
+ self.model = [
+ nn.Linear(np.prod(state_shape), 128),
+ nn.ReLU(inplace=True)]
+ for i in range(layer_num):
+ self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
+ if action_shape:
+ self.model += [nn.Linear(128, np.prod(action_shape))]
+ self.model = nn.Sequential(*self.model)
+
+ def forward(self, s, state=None, info={}):
+ if not isinstance(s, torch.Tensor):
+ s = torch.tensor(s, device=self.device, dtype=torch.float)
+ batch = s.shape[0]
+ s = s.view(batch, -1)
+ logits = self.model(s)
+ return logits, state
+```
+
+Make envs and fix seed:
+
+```python
+env = gym.make(task)
+state_shape = env.observation_space.shape or env.observation_space.n
+action_shape = env.action_space.shape or env.action_space.n
+train_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
+test_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
+np.random.seed(seed)
+torch.manual_seed(seed)
+train_envs.seed(seed)
+test_envs.seed(seed)
+```
+
+Setup policy and collector:
+
+```python
+net = Net(3, state_shape, action_shape, device).to(device)
+optim = torch.optim.Adam(net.parameters(), lr=lr)
+policy = PGPolicy(net, optim, torch.distributions.Categorical, gamma)
+train_collector = Collector(policy, train_envs, ReplayBuffer(20000))
+test_collector = Collector(policy, test_envs)
+```
+
+Let's train it:
+
+```python
+result = onpolicy_trainer(policy, train_collector, test_collector, epoch, step_per_epoch, collect_per_step, repeat_per_collect, test_num, batch_size, stop_fn=lambda x: x >= env.spec.reward_threshold)
+```
+
+Saving / loading trained policy (it's the same as PyTorch nn.module):
+
+```python
+torch.save(policy.state_dict(), 'pg.pth')
+policy.load_state_dict(torch.load('pg.pth', map_location=device))
+```
+
+Watch the performance with 35 FPS:
+
+```python3
+collecter = Collector(policy, env)
+collecter.collect(n_episode=1, render=1/35)
+```
+
+## Citing Tianshou
+
+If you find Tianshou useful, please cite it in your publications.
+
+```
+@misc{tianshou,
+ author = {Jiayi Weng},
+ title = {Tianshou},
+ year = {2020},
+ publisher = {GitHub},
+ journal = {GitHub repository},
+ howpublished = {\url{https://github.com/thu-ml/tianshou}},
+}
+```
+
+## Miscellaneous
+
+Tianshou was [previously](https://github.com/thu-ml/tianshou/tree/priv) a reinforcement learning platform based on TensorFlow. You can checkout the branch `priv` for more detail.
\ No newline at end of file
diff --git a/docs/_static/images/testpg.gif b/docs/_static/images/testpg.gif
new file mode 100644
index 0000000..e780552
Binary files /dev/null and b/docs/_static/images/testpg.gif differ
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index 3867bec..3cfb9f2 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -83,8 +83,7 @@ def _test_ppo(args=get_args()):
action_range=[env.action_space.low[0], env.action_space.high[0]])
# collector
train_collector = Collector(
- policy, train_envs, ReplayBuffer(args.buffer_size),
- remove_done_flag=True)
+ policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.step_per_epoch)
# log
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index d2af6b5..13777ab 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -37,7 +37,7 @@ class Batch(object):
else:
raise TypeError(
'No support for append with type {} in class Batch.'
- .format(type(batch.__dict__[k])))
+ .format(type(batch.__dict__[k])))
def split(self, size=None, permute=True):
length = min([
diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py
index 12d92ea..011afbe 100644
--- a/tianshou/exploration/random.py
+++ b/tianshou/exploration/random.py
@@ -14,7 +14,7 @@ class OUNoise(object):
if self.x is None or self.x.shape != size:
self.x = 0
self.x = self.x + self.alpha * (mu - self.x) + \
- self.beta * np.random.normal(size=size)
+ self.beta * np.random.normal(size=size)
return self.x
def reset(self):
diff --git a/tianshou/policy/a2c.py b/tianshou/policy/a2c.py
index 4de99cf..93337d4 100644
--- a/tianshou/policy/a2c.py
+++ b/tianshou/policy/a2c.py
@@ -36,18 +36,16 @@ class A2CPolicy(PGPolicy):
v = self.critic(b.obs)
a = torch.tensor(b.act, device=dist.logits.device)
r = torch.tensor(b.returns, device=dist.logits.device)
- actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
+ a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
vf_loss = F.mse_loss(r[:, None], v)
ent_loss = dist.entropy().mean()
- loss = actor_loss \
- + self._w_vf * vf_loss \
- - self._w_ent * ent_loss
+ loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
loss.backward()
if self._grad_norm:
nn.utils.clip_grad_norm_(
self.model.parameters(), max_norm=self._grad_norm)
self.optim.step()
- actor_losses.append(actor_loss.detach().cpu().numpy())
+ actor_losses.append(a_loss.detach().cpu().numpy())
vf_losses.append(vf_loss.detach().cpu().numpy())
ent_losses.append(ent_loss.detach().cpu().numpy())
losses.append(loss.detach().cpu().numpy())
diff --git a/tianshou/policy/pg.py b/tianshou/policy/pg.py
index 2f52233..c5e3b70 100644
--- a/tianshou/policy/pg.py
+++ b/tianshou/policy/pg.py
@@ -34,8 +34,8 @@ class PGPolicy(BasePolicy):
def learn(self, batch, batch_size=None, repeat=1):
losses = []
- batch.returns = (batch.returns - batch.returns.mean()) \
- / (batch.returns.std() + self._eps)
+ r = batch.returns
+ batch.returns = (r - r.mean()) / (r.std() + self._eps)
for _ in range(repeat):
for b in batch.split(batch_size):
self.optim.zero_grad()
diff --git a/tianshou/policy/ppo.py b/tianshou/policy/ppo.py
index f972a26..01270ef 100644
--- a/tianshou/policy/ppo.py
+++ b/tianshou/policy/ppo.py
@@ -58,8 +58,8 @@ class PPOPolicy(PGPolicy):
def learn(self, batch, batch_size=None, repeat=1):
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
- batch.returns = (batch.returns - batch.returns.mean()) \
- / (batch.returns.std() + self._eps)
+ r = batch.returns
+ batch.returns = (r - r.mean()) / (r.std() + self._eps)
batch.act = torch.tensor(batch.act)
batch.returns = torch.tensor(batch.returns)[:, None]
for _ in range(repeat):
@@ -79,16 +79,15 @@ class PPOPolicy(PGPolicy):
clip_losses.append(clip_loss.detach().cpu().numpy())
vf_loss = F.smooth_l1_loss(self.critic(b.obs), target_v)
vf_losses.append(vf_loss.detach().cpu().numpy())
- ent_loss = dist.entropy().mean()
- ent_losses.append(ent_loss.detach().cpu().numpy())
- loss = clip_loss \
- + self._w_vf * vf_loss - self._w_ent * ent_loss
+ e_loss = dist.entropy().mean()
+ ent_losses.append(e_loss.detach().cpu().numpy())
+ loss = clip_loss + self._w_vf * vf_loss - self._w_ent * e_loss
losses.append(loss.detach().cpu().numpy())
self.optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(list(
self.actor.parameters()) + list(self.critic.parameters()),
- self._max_grad_norm)
+ self._max_grad_norm)
self.optim.step()
self.sync_weight()
return {