add rllib result and fix pep8
This commit is contained in:
parent 77068af526
commit c42990c725

README.md (28 changes)
@@ -37,7 +37,7 @@ pip3 install tianshou

The tutorials and API documentation are hosted on [https://tianshou.readthedocs.io](https://tianshou.readthedocs.io). It is currently under construction.

The example scripts are under [test/discrete](/test/discrete) (CartPole) and [test/continuous](/test/continuous) (Pendulum).
The example scripts are under the [test/](/test/) and [examples/](/examples/) folders.

## Why Tianshou?
@@ -49,22 +49,26 @@ Tianshou is a lightweight but high-speed reinforcement learning platform. For ex
We selected some well-known (>1k stars) reinforcement learning platforms. Here are the benchmark results for the other platforms and their algorithms on toy scenarios:
(table before this commit)

| RL Platform      | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
| ---------------- | --- | --- | --- | --- | --- |
| GitHub Stars     | [](https://github.com/thu-ml/tianshou/stargazers) | [](https://github.com/openai/baselines/stargazers) | [](https://github.com/ray-project/ray/stargazers) | [](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [](https://github.com/astooke/rlpyt/stargazers) |
| Algo \ Task      | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
| PG - CartPole    | 9.03±4.18s | None | | None | |
| DQN - CartPole   | 20.94±11.38s | 1046.34±291.27s | | 175.55±53.81s | |
| A2C - CartPole   | 11.72±3.85s | *(~1612s) | | Runtime Error | |
| PPO - CartPole   | 35.25±16.47s | *(~1179s) | | 29.16±15.46s | |
| DDPG - Pendulum  | 46.95±24.31s | *(>1h) | | 652.83±471.28s | 172.18±62.48s |
| TD3 - Pendulum   | 48.39±7.22s | None | | 619.33±324.97s | 210.31±76.30s |
| SAC - Pendulum   | 38.92±2.09s | None | | 808.21±405.70s | 295.92±140.85s |
(table after this commit, with the RLlib results added)

| RL Platform     | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
| --------------- | --- | --- | --- | --- | --- |
| GitHub Stars    | [](https://github.com/thu-ml/tianshou/stargazers) | [](https://github.com/openai/baselines/stargazers) | [](https://github.com/ray-project/ray/stargazers) | [](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [](https://github.com/astooke/rlpyt/stargazers) |
| Algo - Task     | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
| PG - CartPole   | 9.03±4.18s | None | 15.77±6.28s | None | |
| DQN - CartPole  | 10.61±5.51s | 1046.34±291.27s | 40.16±12.79s | 175.55±53.81s | |
| A2C - CartPole  | 11.72±3.85s | *(~1612s) | 46.15±6.64s | Runtime Error | |
| PPO - CartPole  | 35.25±16.47s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | |
| DDPG - Pendulum | 46.95±24.31s | *(>1h) | 377.99±13.79s | 652.83±471.28s | 172.18±62.48s |
| TD3 - Pendulum  | 48.39±7.22s | None | 620.83±248.43s | 619.33±324.97s | 210.31±76.30s |
| SAC - Pendulum  | 38.92±2.09s | None | 92.68±4.48s | 808.21±405.70s | 295.92±140.85s |
*: Could not reach the target reward threshold within 1e6 steps in any of 10 runs; the total runtime is given in parentheses.
All of the platforms are tested with 10 different seeds, and trials in which training failed are discarded. The reward threshold is 195.0 for CartPole and -250.0 for Pendulum, measured as the mean return over 100 consecutive episodes.
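This stopping rule is easy to state in code. The sketch below is illustrative only (it is not a line from this commit); the helper name and defaults are hypothetical:

```python
import numpy as np


def reached_threshold(episode_returns, threshold=195.0, window=100):
    """Hypothetical helper: True once the mean return over any `window`
    consecutive episodes meets the threshold (195.0 for CartPole,
    -250.0 for Pendulum)."""
    returns = np.asarray(episode_returns, dtype=float)
    if returns.size < window:
        return False
    # rolling mean over each window of `window` consecutive episodes
    running_mean = np.convolve(returns, np.ones(window) / window, mode='valid')
    return bool(np.any(running_mean >= threshold))
```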
Tianshou's and RLlib's configurations are very similar; both use multiple workers for sampling. Indeed, both RLlib and rlpyt are excellent reinforcement learning platforms :)
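As a rough sketch of that worker-based sampling setup (illustrative, not part of this commit; it assumes the `SubprocVectorEnv` interface used by the test scripts later in this diff, and the task and worker count are arbitrary):

```python
import gym
from tianshou.env import SubprocVectorEnv

if __name__ == '__main__':
    # eight CartPole workers sampling in parallel subprocesses
    train_envs = SubprocVectorEnv(
        [lambda: gym.make('CartPole-v0') for _ in range(8)])
    train_envs.seed(0)
    obs = train_envs.reset()  # one observation per worker
    train_envs.close()
```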
We will add results for Atari Pong / MuJoCo soon.
### Reproducible
Tianshou has unit tests. Unlike other platforms, **the unit tests include the full agent training procedure for all of the implemented algorithms**. A test fails whenever it cannot train an agent to perform well enough within a limited number of epochs on a toy scenario. These unit tests secure the reproducibility of our platform.
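The assertion pattern these tests rely on is visible in the diffs below (`assert stop_fn(result['best_reward'])`). A minimal, self-contained sketch of it, with the trainer call replaced by a hypothetical stub:

```python
def stop_fn(mean_reward, threshold=195.0):
    # CartPole counts as solved once the mean reward reaches 195.0
    return mean_reward >= threshold


def run_training_stub():
    # hypothetical stand-in for a real call such as
    # onpolicy_trainer(policy, train_collector, test_collector, ...)
    return {'best_reward': 200.0}


def test_agent_reaches_threshold():
    result = run_training_stub()
    # the real unit tests fail here when training misses the threshold
    assert stop_fn(result['best_reward'])
```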
@@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv

if __name__ == '__main__':
    from continuous_net import Actor, Critic
else:  # pytest
    from test.continuous.net import Actor, Critic
from continuous_net import Actor, Critic


def get_args():

@@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv

if __name__ == '__main__':
    from continuous_net import ActorProb, Critic
else:  # pytest
    from test.continuous.net import ActorProb, Critic
from continuous_net import ActorProb, Critic


def get_args():
@@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv

if __name__ == '__main__':
    from continuous_net import Actor, Critic
else:  # pytest
    from test.continuous.net import Actor, Critic
from continuous_net import Actor, Critic


def get_args():

@@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv

if __name__ == '__main__':
    from continuous_net import Actor, Critic
else:  # pytest
    from test.continuous.net import Actor, Critic
from continuous_net import Actor, Critic


def get_args():
@@ -1,4 +1,3 @@
import gym
import torch
import pprint
import argparse
@@ -11,10 +10,7 @@ from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env.atari import create_atari_environment

if __name__ == '__main__':
    from discrete_net import Net, Actor, Critic
else:  # pytest
    from test.discrete.net import Net, Actor, Critic
from discrete_net import Net, Actor, Critic


def get_args():
@@ -48,17 +44,20 @@ def get_args():


def test_a2c(args=get_args()):
    env = create_atari_environment(args.task, max_episode_steps=args.max_episode_steps)
    env = create_atari_environment(
        args.task, max_episode_steps=args.max_episode_steps)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
         range(args.training_num)])
        [lambda: create_atari_environment(
            args.task, max_episode_steps=args.max_episode_steps)
            for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
         range(args.test_num)])
        [lambda: create_atari_environment(
            args.task, max_episode_steps=args.max_episode_steps)
            for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
@@ -91,7 +90,8 @@ def test_a2c(args=get_args()):
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
        task=args.task)
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
@@ -1,4 +1,3 @@
import gym
import torch
import pprint
import argparse
@@ -11,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env.atari import create_atari_environment

if __name__ == '__main__':
    from discrete_net import DQN
else:  # pytest
    from test.discrete.net import DQN
from discrete_net import DQN


def get_args():
@@ -49,18 +45,22 @@ def test_dqn(args=get_args()):
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: create_atari_environment(args.task) for _ in range(args.training_num)])
    train_envs = SubprocVectorEnv([
        lambda: create_atari_environment(args.task)
        for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: create_atari_environment(args.task) for _ in range(args.test_num)])
    test_envs = SubprocVectorEnv([
        lambda: create_atari_environment(
            args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = DQN(args.state_shape[0], args.state_shape[1], args.action_shape, args.device)
    net = DQN(
        args.state_shape[0], args.state_shape[1],
        args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(
@@ -1,4 +1,3 @@
import gym
import torch
import pprint
import argparse
@@ -11,10 +10,7 @@ from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env.atari import create_atari_environment

if __name__ == '__main__':
    from discrete_net import Net, Actor, Critic
else:  # pytest
    from test.discrete.net import Net, Actor, Critic
from discrete_net import Net, Actor, Critic


def get_args():
@@ -48,17 +44,18 @@ def get_args():


def test_ppo(args=get_args()):
    env = create_atari_environment(args.task, max_episode_steps=args.max_episode_steps)
    env = create_atari_environment(
        args.task, max_episode_steps=args.max_episode_steps)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space().shape or env.action_space().n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
         range(args.training_num)])
    train_envs = SubprocVectorEnv([lambda: create_atari_environment(
        args.task, max_episode_steps=args.max_episode_steps)
        for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
         range(args.test_num)])
    test_envs = SubprocVectorEnv([lambda: create_atari_environment(
        args.task, max_episode_steps=args.max_episode_steps)
        for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
@@ -95,7 +92,8 @@ def test_ppo(args=get_args()):
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
        task=args.task)
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
setup.py (2 changes)
@@ -55,7 +55,7 @@ setup(
        ],
        'atari': [
            'atari_py',
            'cv2'
            'cv2',
        ],
        'mujoco': [
            'mujoco_py',
@@ -97,7 +97,8 @@ def _test_ppo(args=get_args()):
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
        task=args.task)
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
@@ -51,7 +51,6 @@ class Critic(nn.Module):


class DQN(nn.Module):

    def __init__(self, h, w, action_shape, device='cpu'):
        super(DQN, self).__init__()
        self.device = device
@@ -73,7 +72,7 @@ class DQN(nn.Module):

    def forward(self, x, state=None, info={}):
        if not isinstance(x, torch.Tensor):
            s = torch.tensor(x, device=self.device, dtype=torch.float)
            x = torch.tensor(x, device=self.device, dtype=torch.float)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
@@ -33,7 +33,6 @@ def get_args():
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)

    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -84,7 +83,8 @@ def test_a2c(args=get_args()):
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
        task=args.task)
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
@@ -25,7 +25,7 @@ def get_args():
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--n-step', type=int, default=1)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=1000)
@@ -72,7 +72,6 @@ def test_dqn(args=get_args()):
    test_collector = Collector(policy, test_envs)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size)
    print(len(train_collector.buffer))
    # log
    writer = SummaryWriter(args.logdir + '/' + 'ppo')
@@ -131,7 +131,8 @@ def test_pg(args=get_args()):
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
        task=args.task)
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
@@ -88,7 +88,8 @@ def test_ppo(args=get_args()):
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
        task=args.task)
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
@@ -35,9 +35,10 @@ class Batch(object):
            elif isinstance(batch.__dict__[k], list):
                self.__dict__[k] += batch.__dict__[k]
            else:
                raise TypeError(
                    'No support for append with type {} in class Batch.'
                        .format(type(batch.__dict__[k])))
                s = 'No support for append with type'\
                    + str(type(batch.__dict__[k]))\
                    + 'in class Batch.'
                raise TypeError(s)

    def split(self, size=None, permute=True):
        length = min([
@@ -19,7 +19,7 @@ class Collector(object):
        self.collect_episode = 0
        self.collect_time = 0
        if buffer is None:
            self.buffer = ReplayBuffer(20000)
            self.buffer = ReplayBuffer(100)
        else:
            self.buffer = buffer
        self.policy = policy
@@ -100,7 +100,8 @@ class Collector(object):
        while True:
            if warning_count >= 100000:
                warnings.warn(
                    'There are already many steps in an episode. You should add a time limitation to your environment!',
                    'There are already many steps in an episode. '
                    'You should add a time limitation to your environment!',
                    Warning)
            if self._multi_env:
                batch_data = Batch(
@@ -13,8 +13,8 @@ class OUNoise(object):
    def __call__(self, size, mu=.1):
        if self.x is None or self.x.shape != size:
            self.x = 0
        self.x = self.x + self.alpha * (mu - self.x) + \
                 self.beta * np.random.normal(size=size)
        r = self.beta * np.random.normal(size=size)
        self.x = self.x + self.alpha * (mu - self.x) + r
        return self.x

    def reset(self):
@@ -34,9 +34,6 @@ class PGPolicy(BasePolicy):

    def learn(self, batch, batch_size=None, repeat=1):
        losses = []

        batch.returns = (batch.returns - batch.returns.mean()) \
                        / (batch.returns.std() + self._eps)
        r = batch.returns
        batch.returns = (r - r.mean()) / (r.std() + self._eps)
        for _ in range(repeat):
@@ -58,9 +58,6 @@ class PPOPolicy(PGPolicy):

    def learn(self, batch, batch_size=None, repeat=1):
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []

        batch.returns = (batch.returns - batch.returns.mean()) \
                        / (batch.returns.std() + self._eps)
        r = batch.returns
        batch.returns = (r - r.mean()) / (r.std() + self._eps)
        batch.act = torch.tensor(batch.act)
@@ -47,7 +47,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                        data[k] = f'{result[k]:.2f}'
                        if writer:
                            writer.add_scalar(
                                k + '_' + task, result[k], global_step=global_step)
                                k + '_' + task if task else k,
                                result[k], global_step=global_step)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
@@ -55,7 +56,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                        data[k] = f'{stat[k].get():.6f}'
                        if writer:
                            writer.add_scalar(
                                k + '_' + task, stat[k].get(), global_step=global_step)
                                k + '_' + task if task else k,
                                stat[k].get(), global_step=global_step)
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:

@@ -52,7 +52,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                    data[k] = f'{result[k]:.2f}'
                    if writer:
                        writer.add_scalar(
                            k + '_' + task, result[k], global_step=global_step)
                            k + '_' + task if task else k,
                            result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
@@ -60,7 +61,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                    data[k] = f'{stat[k].get():.6f}'
                    if writer and global_step:
                        writer.add_scalar(
                            k + '_' + task, stat[k].get(), global_step=global_step)
                            k + '_' + task if task else k,
                            stat[k].get(), global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total: