add rllib result and fix pep8

Trinkle23897 2020-03-28 09:43:35 +08:00
parent 77068af526
commit c42990c725
22 changed files with 80 additions and 89 deletions

View File

@ -37,7 +37,7 @@ pip3 install tianshou
The tutorials and API documentation are hosted on [https://tianshou.readthedocs.io](https://tianshou.readthedocs.io). They are currently under construction.
The example scripts are under [test/discrete](/test/discrete) (CartPole) and [test/continuous](/test/continuous) (Pendulum).
The example scripts are under the [test/](/test/) and [examples/](/examples/) folders.
## Why Tianshou?
@ -49,22 +49,26 @@ Tianshou is a lightweight but high-speed reinforcement learning platform. For ex
We select several well-known (>1k stars) reinforcement learning platforms for comparison. Here are the benchmark results for other algorithms and platforms on toy scenarios:
| RL Platform | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
| ---------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| GitHub Stars | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch)](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) |
| Algo \ Task | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
| PG - CartPole | 9.03±4.18s | None | | None | |
| DQN - CartPole | 20.94±11.38s | 1046.34±291.27s | | 175.55±53.81s | |
| A2C - CartPole | 11.72±3.85s | *(~1612s) | | Runtime Error | |
| PPO - CartPole | 35.25±16.47s | *(~1179s) | | 29.16±15.46s | |
| DDPG - Pendulum | 46.95±24.31s | *(>1h) | | 652.83±471.28s | 172.18±62.48s |
| TD3 - Pendulum | 48.39±7.22s | None | | 619.33±324.97s | 210.31±76.30s |
| SAC - Pendulum | 38.92±2.09s | None | | 808.21±405.70s | 295.92±140.85s |
| RL Platform | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
| --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| GitHub Stars | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch)](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) |
| Algo - Task | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
| PG - CartPole | 9.03±4.18s | None | 15.77±6.28s | None | |
| DQN - CartPole | 10.61±5.51s | 1046.34±291.27s | 40.16±12.79s | 175.55±53.81s | |
| A2C - CartPole | 11.72±3.85s | *(~1612s) | 46.15±6.64s | Runtime Error | |
| PPO - CartPole | 35.25±16.47s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | |
| DDPG - Pendulum | 46.95±24.31s | *(>1h) | 377.99±13.79s | 652.83±471.28s | 172.18±62.48s |
| TD3 - Pendulum | 48.39±7.22s | None | 620.83±248.43s | 619.33±324.97s | 210.31±76.30s |
| SAC - Pendulum | 38.92±2.09s | None | 92.68±4.48s | 808.21±405.70s | 295.92±140.85s |
*: Could not reach the target reward threshold within 1e6 steps in any of 10 runs; the approximate total runtime is given in parentheses.
All of the platforms are tested with 10 different seeds, and trials that fail to train are excluded. The reward threshold is 195.0 for CartPole and -250.0 for Pendulum, measured on the mean return over 100 consecutive episodes.
Tianshou's and RLlib's configurations are very similar: both use multiple workers for sampling (a minimal sketch of this setup is given below). Indeed, both RLlib and rlpyt are excellent reinforcement learning platforms :)
We will add results for Atari Pong and MuJoCo soon.
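The multi-worker sampling mentioned above can be sketched as follows. This is a minimal, illustrative example only: it reuses the `SubprocVectorEnv` interface that appears in the scripts changed by this commit, while the task name, worker count, and threshold are placeholder values rather than the exact benchmark configuration.

```python
import gym

from tianshou.env import SubprocVectorEnv

# Illustrative values; the real benchmark scripts under test/ and examples/
# define their own worker counts and hyper-parameters.
task, num_workers = 'CartPole-v0', 8
reward_threshold = 195.0  # -250.0 for Pendulum, per the table above

# One environment per subprocess; all workers step in parallel.
train_envs = SubprocVectorEnv(
    [lambda: gym.make(task) for _ in range(num_workers)])
train_envs.seed(0)

obs = train_envs.reset()  # stacked observations, one row per worker
train_envs.close()
```

A `Collector` (see the test sketch in the next subsection) then gathers transitions from all of these environments at once; the wall-clock numbers above measure how quickly each platform reaches the reward threshold under such a setup.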
### Reproducible
Tianshou has unit tests. Unlike other platforms, **the unit tests include the full agent training procedure for all of the implemented algorithms**. A test fails if an algorithm cannot train an agent to perform well enough within a limited number of epochs on these toy scenarios. The unit tests secure the reproducibility of our platform.
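As a rough illustration of that test pattern, here is a minimal sketch built only from interfaces visible elsewhere in this commit (`VectorEnv`, `Collector`, `ReplayBuffer`, `onpolicy_trainer`, `stop_fn`, `result['best_reward']`). The `run_training_test` helper and all default values are hypothetical stand-ins: each real test under `test/` constructs its own policy, network, and hyper-parameters.

```python
import gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv
from tianshou.trainer import onpolicy_trainer


def run_training_test(policy, task='CartPole-v0', training_num=8,
                      test_num=100, buffer_size=20000, epoch=10,
                      step_per_epoch=1000, collect_per_step=10,
                      repeat_per_collect=1, batch_size=64, seed=0):
    """Train `policy` on `task` and assert it reaches the reward threshold."""
    env = gym.make(task)
    # vectorized training / testing environments
    train_envs = VectorEnv(
        [lambda: gym.make(task) for _ in range(training_num)])
    test_envs = VectorEnv(
        [lambda: gym.make(task) for _ in range(test_num)])
    # seed everything for reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)
    train_envs.seed(seed)
    test_envs.seed(seed)
    # collectors gather transitions from the vectorized environments
    train_collector = Collector(policy, train_envs, ReplayBuffer(buffer_size))
    test_collector = Collector(policy, test_envs)
    writer = SummaryWriter('log/unit_test')  # illustrative log directory

    def stop_fn(mean_return):
        # training succeeds once the mean test return reaches the threshold
        return mean_return >= env.spec.reward_threshold

    result = onpolicy_trainer(
        policy, train_collector, test_collector, epoch,
        step_per_epoch, collect_per_step, repeat_per_collect,
        test_num, batch_size, stop_fn=stop_fn, writer=writer)
    # the unit test fails if the agent never performed well enough
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
```

The actual tests (for example `test_pg` or `test_ppo` in the hunks below) follow this shape, with algorithm-specific policy construction and hyper-parameters small enough for each run to finish within minutes.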

View File

@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv
if __name__ == '__main__':
from continuous_net import Actor, Critic
else: # pytest
from test.continuous.net import Actor, Critic
from continuous_net import Actor, Critic
def get_args():

View File

@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv
if __name__ == '__main__':
from continuous_net import ActorProb, Critic
else: # pytest
from test.continuous.net import ActorProb, Critic
from continuous_net import ActorProb, Critic
def get_args():

View File

@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv
if __name__ == '__main__':
from continuous_net import Actor, Critic
else: # pytest
from test.continuous.net import Actor, Critic
from continuous_net import Actor, Critic
def get_args():

View File

@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv
if __name__ == '__main__':
from continuous_net import Actor, Critic
else: # pytest
from test.continuous.net import Actor, Critic
from continuous_net import Actor, Critic
def get_args():

View File

@ -1,4 +1,3 @@
import gym
import torch
import pprint
import argparse
@ -11,10 +10,7 @@ from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env.atari import create_atari_environment
if __name__ == '__main__':
from discrete_net import Net, Actor, Critic
else: # pytest
from test.discrete.net import Net, Actor, Critic
from discrete_net import Net, Actor, Critic
def get_args():
@ -48,17 +44,20 @@ def get_args():
def test_a2c(args=get_args()):
env = create_atari_environment(args.task, max_episode_steps=args.max_episode_steps)
env = create_atari_environment(
args.task, max_episode_steps=args.max_episode_steps)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.env.action_space.shape or env.env.action_space.n
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
range(args.training_num)])
[lambda: create_atari_environment(
args.task, max_episode_steps=args.max_episode_steps)
for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
range(args.test_num)])
[lambda: create_atari_environment(
args.task, max_episode_steps=args.max_episode_steps)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@ -91,7 +90,8 @@ def test_a2c(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
task=args.task)
train_collector.close()
test_collector.close()
if __name__ == '__main__':

View File

@ -1,4 +1,3 @@
import gym
import torch
import pprint
import argparse
@ -11,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env.atari import create_atari_environment
if __name__ == '__main__':
from discrete_net import DQN
else: # pytest
from test.discrete.net import DQN
from discrete_net import DQN
def get_args():
@ -49,18 +45,22 @@ def test_dqn(args=get_args()):
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.env.action_space.shape or env.env.action_space.n
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: create_atari_environment(args.task) for _ in range(args.training_num)])
train_envs = SubprocVectorEnv([
lambda: create_atari_environment(args.task)
for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: create_atari_environment(args.task) for _ in range(args.test_num)])
test_envs = SubprocVectorEnv([
lambda: create_atari_environment(
args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = DQN(args.state_shape[0], args.state_shape[1], args.action_shape, args.device)
net = DQN(
args.state_shape[0], args.state_shape[1],
args.action_shape, args.device)
net = net.to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(

View File

@ -1,4 +1,3 @@
import gym
import torch
import pprint
import argparse
@ -11,10 +10,7 @@ from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env.atari import create_atari_environment
if __name__ == '__main__':
from discrete_net import Net, Actor, Critic
else: # pytest
from test.discrete.net import Net, Actor, Critic
from discrete_net import Net, Actor, Critic
def get_args():
@ -48,17 +44,18 @@ def get_args():
def test_ppo(args=get_args()):
env = create_atari_environment(args.task, max_episode_steps=args.max_episode_steps)
env = create_atari_environment(
args.task, max_episode_steps=args.max_episode_steps)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space().shape or env.action_space().n
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
range(args.training_num)])
train_envs = SubprocVectorEnv([lambda: create_atari_environment(
args.task, max_episode_steps=args.max_episode_steps)
for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
range(args.test_num)])
test_envs = SubprocVectorEnv([lambda: create_atari_environment(
args.task, max_episode_steps=args.max_episode_steps)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@ -95,7 +92,8 @@ def test_ppo(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
task=args.task)
train_collector.close()
test_collector.close()
if __name__ == '__main__':

View File

@ -55,7 +55,7 @@ setup(
],
'atari': [
'atari_py',
'cv2'
'cv2',
],
'mujoco': [
'mujoco_py',

View File

@ -97,7 +97,8 @@ def _test_ppo(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
task=args.task)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -51,7 +51,6 @@ class Critic(nn.Module):
class DQN(nn.Module):
def __init__(self, h, w, action_shape, device='cpu'):
super(DQN, self).__init__()
self.device = device
@ -73,7 +72,7 @@ class DQN(nn.Module):
def forward(self, x, state=None, info={}):
if not isinstance(x, torch.Tensor):
s = torch.tensor(x, device=self.device, dtype=torch.float)
x = torch.tensor(x, device=self.device, dtype=torch.float)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))

View File

@ -33,7 +33,6 @@ def get_args():
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
@ -84,7 +83,8 @@ def test_a2c(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
task=args.task)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -25,7 +25,7 @@ def get_args():
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--n-step', type=int, default=1)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=1000)
@ -72,7 +72,6 @@ def test_dqn(args=get_args()):
test_collector = Collector(policy, test_envs)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size)
print(len(train_collector.buffer))
# log
writer = SummaryWriter(args.logdir + '/' + 'ppo')

View File

@ -131,7 +131,8 @@ def test_pg(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
task=args.task)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -88,7 +88,8 @@ def test_ppo(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
task=args.task)
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()

View File

@ -35,9 +35,10 @@ class Batch(object):
elif isinstance(batch.__dict__[k], list):
self.__dict__[k] += batch.__dict__[k]
else:
raise TypeError(
'No support for append with type {} in class Batch.'
.format(type(batch.__dict__[k])))
s = 'No support for append with type '\
+ str(type(batch.__dict__[k]))\
+ ' in class Batch.'
raise TypeError(s)
def split(self, size=None, permute=True):
length = min([

View File

@ -19,7 +19,7 @@ class Collector(object):
self.collect_episode = 0
self.collect_time = 0
if buffer is None:
self.buffer = ReplayBuffer(20000)
self.buffer = ReplayBuffer(100)
else:
self.buffer = buffer
self.policy = policy
@ -100,7 +100,8 @@ class Collector(object):
while True:
if warning_count >= 100000:
warnings.warn(
'There are already many steps in an episode. You should add a time limitation to your environment!',
'There are already many steps in an episode. '
'You should add a time limitation to your environment!',
Warning)
if self._multi_env:
batch_data = Batch(

View File

@ -13,8 +13,8 @@ class OUNoise(object):
def __call__(self, size, mu=.1):
if self.x is None or self.x.shape != size:
self.x = 0
self.x = self.x + self.alpha * (mu - self.x) + \
self.beta * np.random.normal(size=size)
r = self.beta * np.random.normal(size=size)
self.x = self.x + self.alpha * (mu - self.x) + r
return self.x
def reset(self):

View File

@ -34,9 +34,6 @@ class PGPolicy(BasePolicy):
def learn(self, batch, batch_size=None, repeat=1):
losses = []
batch.returns = (batch.returns - batch.returns.mean()) \
/ (batch.returns.std() + self._eps)
r = batch.returns
batch.returns = (r - r.mean()) / (r.std() + self._eps)
for _ in range(repeat):

View File

@ -58,9 +58,6 @@ class PPOPolicy(PGPolicy):
def learn(self, batch, batch_size=None, repeat=1):
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
batch.returns = (batch.returns - batch.returns.mean()) \
/ (batch.returns.std() + self._eps)
r = batch.returns
batch.returns = (r - r.mean()) / (r.std() + self._eps)
batch.act = torch.tensor(batch.act)

View File

@ -47,7 +47,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
data[k] = f'{result[k]:.2f}'
if writer:
writer.add_scalar(
k + '_' + task, result[k], global_step=global_step)
k + '_' + task if task else k,
result[k], global_step=global_step)
for k in losses.keys():
if stat.get(k) is None:
stat[k] = MovAvg()
@ -55,7 +56,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
data[k] = f'{stat[k].get():.6f}'
if writer:
writer.add_scalar(
k + '_' + task, stat[k].get(), global_step=global_step)
k + '_' + task if task else k,
stat[k].get(), global_step=global_step)
t.update(1)
t.set_postfix(**data)
if t.n <= t.total:

View File

@ -52,7 +52,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
data[k] = f'{result[k]:.2f}'
if writer:
writer.add_scalar(
k + '_' + task, result[k], global_step=global_step)
k + '_' + task if task else k,
result[k], global_step=global_step)
for k in losses.keys():
if stat.get(k) is None:
stat[k] = MovAvg()
@ -60,7 +61,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
data[k] = f'{stat[k].get():.6f}'
if writer and global_step:
writer.add_scalar(
k + '_' + task, stat[k].get(), global_step=global_step)
k + '_' + task if task else k,
stat[k].get(), global_step=global_step)
t.update(step)
t.set_postfix(**data)
if t.n <= t.total: