add rllib result and fix pep8

parent 77068af526
commit c42990c725

README.md (28 changes)
@@ -37,7 +37,7 @@ pip3 install tianshou
 
 The tutorials and API documentation are hosted on [https://tianshou.readthedocs.io](https://tianshou.readthedocs.io). It is currently under construction.
 
-The example scripts are under [test/discrete](/test/discrete) (CartPole) and [test/continuous](/test/continuous) (Pendulum).
+The example scripts are under the [test/](/test/) and [examples/](/examples/) folders.
 
 ## Why Tianshou?
 
@@ -49,22 +49,26 @@ Tianshou is a lightweight but high-speed reinforcement learning platform. For ex
 
 We selected some famous (>1k stars) reinforcement learning platforms. Here are the benchmark results for these platforms and algorithms on toy scenarios:
 
 | RL Platform | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
-| ---------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
+| --------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
 | GitHub Stars | [](https://github.com/thu-ml/tianshou/stargazers) | [](https://github.com/openai/baselines/stargazers) | [](https://github.com/ray-project/ray/stargazers) | [](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [](https://github.com/astooke/rlpyt/stargazers) |
-| Algo \ Task | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
+| Algo - Task | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
-| PG - CartPole | 9.03±4.18s | None | | None | |
+| PG - CartPole | 9.03±4.18s | None | 15.77±6.28s | None | |
-| DQN - CartPole | 20.94±11.38s | 1046.34±291.27s | | 175.55±53.81s | |
+| DQN - CartPole | 10.61±5.51s | 1046.34±291.27s | 40.16±12.79s | 175.55±53.81s | |
-| A2C - CartPole | 11.72±3.85s | *(~1612s) | | Runtime Error | |
+| A2C - CartPole | 11.72±3.85s | *(~1612s) | 46.15±6.64s | Runtime Error | |
-| PPO - CartPole | 35.25±16.47s | *(~1179s) | | 29.16±15.46s | |
+| PPO - CartPole | 35.25±16.47s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | |
-| DDPG - Pendulum | 46.95±24.31s | *(>1h) | | 652.83±471.28s | 172.18±62.48s |
+| DDPG - Pendulum | 46.95±24.31s | *(>1h) | 377.99±13.79s | 652.83±471.28s | 172.18±62.48s |
-| TD3 - Pendulum | 48.39±7.22s | None | | 619.33±324.97s | 210.31±76.30s |
+| TD3 - Pendulum | 48.39±7.22s | None | 620.83±248.43s | 619.33±324.97s | 210.31±76.30s |
-| SAC - Pendulum | 38.92±2.09s | None | | 808.21±405.70s | 295.92±140.85s |
+| SAC - Pendulum | 38.92±2.09s | None | 92.68±4.48s | 808.21±405.70s | 295.92±140.85s |
 
 *: Could not reach the target reward threshold within 1e6 steps in any of the 10 runs; the total runtime is shown in brackets.
 
 All of the platforms are tested with 10 different seeds, and trials that fail to train are discarded. The reward threshold is 195.0 for CartPole and -250.0 for Pendulum, measured as the mean return over 100 consecutive episodes.
 
+Tianshou's and RLlib's configurations are very similar: both use multiple workers for sampling. Indeed, both RLlib and rlpyt are excellent reinforcement learning platforms. :)
+
+We will add results for Atari Pong and MuJoCo soon.
+
 ### Reproducible
 
 Tianshou has unit tests. Unlike other platforms, **the unit tests include the full agent training procedure for all of the implemented algorithms**. A test fails if it cannot train an agent to perform well enough within a limited number of epochs on the toy scenarios. These unit tests secure the reproducibility of our platform.
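
To make the benchmark criterion above concrete, here is a minimal illustrative sketch (not code from this repository): a run counts as solved once the mean return over the last 100 episodes reaches the threshold (195.0 for CartPole, -250.0 for Pendulum). The test scripts changed below express the same idea as a `stop_fn` passed to the trainer and assert `stop_fn(result['best_reward'])` after training.

```python
import numpy as np


def solved(episode_returns, threshold, window=100):
    """Illustrative helper (not part of Tianshou): True once the mean return
    over the last `window` episodes reaches the benchmark threshold."""
    if len(episode_returns) < window:
        return False
    return float(np.mean(episode_returns[-window:])) >= threshold


# CartPole uses a threshold of 195.0; Pendulum uses -250.0.
assert solved([200.0] * 100, 195.0)
assert not solved([150.0] * 100, 195.0)
```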

@@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 
-if __name__ == '__main__':
-    from continuous_net import Actor, Critic
-else: # pytest
-    from test.continuous.net import Actor, Critic
+from continuous_net import Actor, Critic
 
 
 def get_args():

@@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 
-if __name__ == '__main__':
-    from continuous_net import ActorProb, Critic
-else: # pytest
-    from test.continuous.net import ActorProb, Critic
+from continuous_net import ActorProb, Critic
 
 
 def get_args():

@@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 
-if __name__ == '__main__':
-    from continuous_net import Actor, Critic
-else: # pytest
-    from test.continuous.net import Actor, Critic
+from continuous_net import Actor, Critic
 
 
 def get_args():

@@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 
-if __name__ == '__main__':
-    from continuous_net import Actor, Critic
-else: # pytest
-    from test.continuous.net import Actor, Critic
+from continuous_net import Actor, Critic
 
 
 def get_args():

@@ -1,4 +1,3 @@
-import gym
 import torch
 import pprint
 import argparse

@@ -11,10 +10,7 @@ from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env.atari import create_atari_environment
 
-if __name__ == '__main__':
-    from discrete_net import Net, Actor, Critic
-else: # pytest
-    from test.discrete.net import Net, Actor, Critic
+from discrete_net import Net, Actor, Critic
 
 
 def get_args():

@@ -48,17 +44,20 @@ def get_args():
 
 
 def test_a2c(args=get_args()):
-    env = create_atari_environment(args.task, max_episode_steps=args.max_episode_steps)
+    env = create_atari_environment(
+        args.task, max_episode_steps=args.max_episode_steps)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.env.action_space.shape or env.env.action_space.n
     # train_envs = gym.make(args.task)
     train_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
-         range(args.training_num)])
+        [lambda: create_atari_environment(
+            args.task, max_episode_steps=args.max_episode_steps)
+            for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
-         range(args.test_num)])
+        [lambda: create_atari_environment(
+            args.task, max_episode_steps=args.max_episode_steps)
+            for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)

@@ -91,7 +90,8 @@ def test_a2c(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
+        task=args.task)
     train_collector.close()
     test_collector.close()
     if __name__ == '__main__':

@@ -1,4 +1,3 @@
-import gym
 import torch
 import pprint
 import argparse

@@ -11,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env.atari import create_atari_environment
 
-if __name__ == '__main__':
-    from discrete_net import DQN
-else: # pytest
-    from test.discrete.net import DQN
+from discrete_net import DQN
 
 
 def get_args():

@@ -49,18 +45,22 @@ def test_dqn(args=get_args()):
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.env.action_space.shape or env.env.action_space.n
     # train_envs = gym.make(args.task)
-    train_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(args.task) for _ in range(args.training_num)])
+    train_envs = SubprocVectorEnv([
+        lambda: create_atari_environment(args.task)
+        for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
-    test_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(args.task) for _ in range(args.test_num)])
+    test_envs = SubprocVectorEnv([
+        lambda: create_atari_environment(
+            args.task) for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = DQN(args.state_shape[0], args.state_shape[1], args.action_shape, args.device)
+    net = DQN(
+        args.state_shape[0], args.state_shape[1],
+        args.action_shape, args.device)
     net = net.to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(

@@ -1,4 +1,3 @@
-import gym
 import torch
 import pprint
 import argparse

@@ -11,10 +10,7 @@ from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env.atari import create_atari_environment
 
-if __name__ == '__main__':
-    from discrete_net import Net, Actor, Critic
-else: # pytest
-    from test.discrete.net import Net, Actor, Critic
+from discrete_net import Net, Actor, Critic
 
 
 def get_args():

@@ -48,17 +44,18 @@ def get_args():
 
 
 def test_ppo(args=get_args()):
-    env = create_atari_environment(args.task, max_episode_steps=args.max_episode_steps)
+    env = create_atari_environment(
+        args.task, max_episode_steps=args.max_episode_steps)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space().shape or env.action_space().n
     # train_envs = gym.make(args.task)
-    train_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
-         range(args.training_num)])
+    train_envs = SubprocVectorEnv([lambda: create_atari_environment(
+        args.task, max_episode_steps=args.max_episode_steps)
+        for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
-    test_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
-         range(args.test_num)])
+    test_envs = SubprocVectorEnv([lambda: create_atari_environment(
+        args.task, max_episode_steps=args.max_episode_steps)
+        for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)

@@ -95,7 +92,8 @@ def test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
+        task=args.task)
     train_collector.close()
     test_collector.close()
     if __name__ == '__main__':

setup.py (2 changes)

@@ -55,7 +55,7 @@ setup(
         ],
         'atari': [
             'atari_py',
-            'cv2'
+            'cv2',
         ],
         'mujoco': [
             'mujoco_py',

@@ -97,7 +97,8 @@ def _test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
+        task=args.task)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()

@@ -51,7 +51,6 @@ class Critic(nn.Module):
 
-
 class DQN(nn.Module):
 
     def __init__(self, h, w, action_shape, device='cpu'):
         super(DQN, self).__init__()
         self.device = device

@@ -73,7 +72,7 @@ class DQN(nn.Module):
 
     def forward(self, x, state=None, info={}):
         if not isinstance(x, torch.Tensor):
-            s = torch.tensor(x, device=self.device, dtype=torch.float)
+            x = torch.tensor(x, device=self.device, dtype=torch.float)
         x = F.relu(self.bn1(self.conv1(x)))
         x = F.relu(self.bn2(self.conv2(x)))
         x = F.relu(self.bn3(self.conv3(x)))
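
The second hunk above fixes a real bug: the converted tensor was bound to the throwaway name `s`, so the convolution layers still received the unconverted input. A small hypothetical illustration of the difference (the observation shape here is made up for the example):

```python
import numpy as np
import torch

x = np.zeros((1, 4, 84, 84), dtype=np.float32)  # fake observation batch

s = torch.tensor(x)                      # old code: the result is discarded ...
assert not isinstance(x, torch.Tensor)   # ... and x is still a numpy array

x = torch.tensor(x)                      # new code: x is now the tensor the network expects
assert isinstance(x, torch.Tensor)
```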

@@ -33,7 +33,6 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
-
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')

@@ -84,7 +83,8 @@ def test_a2c(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
+        task=args.task)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()

@@ -25,7 +25,7 @@ def get_args():
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--n-step', type=int, default=1)
+    parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=320)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)

@@ -72,7 +72,6 @@ def test_dqn(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size)
-    print(len(train_collector.buffer))
     # log
     writer = SummaryWriter(args.logdir + '/' + 'ppo')
 

@@ -131,7 +131,8 @@ def test_pg(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
+        task=args.task)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()

@@ -88,7 +88,8 @@ def test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
+        task=args.task)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()

@@ -35,9 +35,10 @@ class Batch(object):
             elif isinstance(batch.__dict__[k], list):
                 self.__dict__[k] += batch.__dict__[k]
             else:
-                raise TypeError(
-                    'No support for append with type {} in class Batch.'
-                    .format(type(batch.__dict__[k])))
+                s = 'No support for append with type'\
+                    + str(type(batch.__dict__[k]))\
+                    + 'in class Batch.'
+                raise TypeError(s)
 
     def split(self, size=None, permute=True):
         length = min([

@@ -19,7 +19,7 @@ class Collector(object):
         self.collect_episode = 0
         self.collect_time = 0
         if buffer is None:
-            self.buffer = ReplayBuffer(20000)
+            self.buffer = ReplayBuffer(100)
         else:
             self.buffer = buffer
         self.policy = policy

@@ -100,7 +100,8 @@ class Collector(object):
         while True:
             if warning_count >= 100000:
                 warnings.warn(
-                    'There are already many steps in an episode. You should add a time limitation to your environment!',
+                    'There are already many steps in an episode. '
+                    'You should add a time limitation to your environment!',
                     Warning)
             if self._multi_env:
                 batch_data = Batch(

@@ -13,8 +13,8 @@ class OUNoise(object):
     def __call__(self, size, mu=.1):
         if self.x is None or self.x.shape != size:
             self.x = 0
-        self.x = self.x + self.alpha * (mu - self.x) + \
-            self.beta * np.random.normal(size=size)
+        r = self.beta * np.random.normal(size=size)
+        self.x = self.x + self.alpha * (mu - self.x) + r
         return self.x
 
     def reset(self):
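
The hunk above only reflows the Ornstein-Uhlenbeck update onto two lines; the process itself is unchanged. For context, a standalone sketch of that update (the `alpha` and `beta` defaults here are assumed for illustration, not taken from the repository):

```python
import numpy as np


class OUNoiseSketch:
    """Assumed-parameter sketch: alpha pulls the state toward mu, beta scales
    the Gaussian increment, giving temporally correlated exploration noise."""

    def __init__(self, alpha=0.15, beta=0.3):
        self.alpha, self.beta, self.x = alpha, beta, None

    def __call__(self, size, mu=0.1):
        if self.x is None or self.x.shape != size:
            self.x = np.zeros(size)
        r = self.beta * np.random.normal(size=size)
        self.x = self.x + self.alpha * (mu - self.x) + r
        return self.x


noise = OUNoiseSketch()
sample = noise((3,))  # correlated noise for a 3-dimensional action
```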

@@ -34,9 +34,6 @@ class PGPolicy(BasePolicy):
 
     def learn(self, batch, batch_size=None, repeat=1):
         losses = []
-
-        batch.returns = (batch.returns - batch.returns.mean()) \
-            / (batch.returns.std() + self._eps)
         r = batch.returns
         batch.returns = (r - r.mean()) / (r.std() + self._eps)
         for _ in range(repeat):

@@ -58,9 +58,6 @@ class PPOPolicy(PGPolicy):
 
     def learn(self, batch, batch_size=None, repeat=1):
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
-
-        batch.returns = (batch.returns - batch.returns.mean()) \
-            / (batch.returns.std() + self._eps)
         r = batch.returns
         batch.returns = (r - r.mean()) / (r.std() + self._eps)
         batch.act = torch.tensor(batch.act)
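
The two policy hunks above drop a duplicated normalization and keep the single `r`-based form. A hedged numpy illustration of what that line does (the epsilon below is chosen for the example; the policies use their own `self._eps`):

```python
import numpy as np

eps = np.finfo(np.float32).eps.item()       # example epsilon, not self._eps
returns = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
normalized = (returns - returns.mean()) / (returns.std() + eps)
print(normalized.mean(), normalized.std())  # approximately 0.0 and 1.0
```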

@@ -47,7 +47,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                data[k] = f'{result[k]:.2f}'
                if writer:
                    writer.add_scalar(
-                       k + '_' + task, result[k], global_step=global_step)
+                       k + '_' + task if task else k,
+                       result[k], global_step=global_step)
            for k in losses.keys():
                if stat.get(k) is None:
                    stat[k] = MovAvg()

@@ -55,7 +56,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                data[k] = f'{stat[k].get():.6f}'
                if writer:
                    writer.add_scalar(
-                       k + '_' + task, stat[k].get(), global_step=global_step)
+                       k + '_' + task if task else k,
+                       stat[k].get(), global_step=global_step)
            t.update(1)
            t.set_postfix(**data)
        if t.n <= t.total:

@@ -52,7 +52,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                data[k] = f'{result[k]:.2f}'
                if writer:
                    writer.add_scalar(
-                       k + '_' + task, result[k], global_step=global_step)
+                       k + '_' + task if task else k,
+                       result[k], global_step=global_step)
            for k in losses.keys():
                if stat.get(k) is None:
                    stat[k] = MovAvg()

@@ -60,7 +61,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                data[k] = f'{stat[k].get():.6f}'
                if writer and global_step:
                    writer.add_scalar(
-                       k + '_' + task, stat[k].get(), global_step=global_step)
+                       k + '_' + task if task else k,
+                       stat[k].get(), global_step=global_step)
            t.update(step)
            t.set_postfix(**data)
        if t.n <= t.total:
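
Both trainer hunks wrap the TensorBoard call and guard the tag with `if task else k`, so scalars are logged under the bare key when no task name is given. A tiny sketch of that expression (the helper name here is illustrative, not part of the library):

```python
def scalar_tag(k: str, task: str = '') -> str:
    # Mirrors the expression in the hunks above: "<key>_<task>" when a task
    # name is set, otherwise just the key.
    return k + '_' + task if task else k


assert scalar_tag('reward', 'Pong') == 'reward_Pong'
assert scalar_tag('reward') == 'reward'
```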