From c87fe3c18c076f1a4951a5a3f2f4e5a8fc00abb5 Mon Sep 17 00:00:00 2001 From: Trinkle23897 <463003665@qq.com> Date: Thu, 19 Mar 2020 17:23:46 +0800 Subject: [PATCH] add trainer --- test/test_a2c.py | 138 +++++++++++++++-------------------- test/test_ddpg.py | 98 +++++++------------------ test/test_dqn.py | 94 ++++++++---------------- test/test_pg.py | 83 ++++++--------------- tianshou/__init__.py | 4 +- tianshou/data/collector.py | 39 +++++----- tianshou/policy/__init__.py | 4 +- tianshou/policy/a2c.py | 15 ++-- tianshou/policy/base.py | 1 + tianshou/policy/ddpg.py | 9 ++- tianshou/policy/dqn.py | 2 +- tianshou/policy/pg.py | 2 +- tianshou/policy/ppo.py | 50 +++++++++++++ tianshou/trainer/__init__.py | 7 ++ tianshou/trainer/episodic.py | 68 +++++++++++++++++ tianshou/trainer/step.py | 66 +++++++++++++++++ 16 files changed, 371 insertions(+), 309 deletions(-) create mode 100644 tianshou/policy/ppo.py create mode 100644 tianshou/trainer/__init__.py create mode 100644 tianshou/trainer/episodic.py create mode 100644 tianshou/trainer/step.py diff --git a/test/test_a2c.py b/test/test_a2c.py index 7cb2358..2b04b0c 100644 --- a/test/test_a2c.py +++ b/test/test_a2c.py @@ -1,6 +1,4 @@ import gym -import time -import tqdm import torch import argparse import numpy as np @@ -9,12 +7,12 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import A2CPolicy from tianshou.env import SubprocVectorEnv -from tianshou.utils import tqdm_config, MovAvg +from tianshou.trainer import episodic_trainer from tianshou.data import Collector, ReplayBuffer class Net(nn.Module): - def __init__(self, layer_num, state_shape, action_shape, device='cpu'): + def __init__(self, layer_num, state_shape, device='cpu'): super().__init__() self.device = device self.model = [ @@ -22,18 +20,40 @@ class Net(nn.Module): nn.ReLU(inplace=True)] for i in range(layer_num): self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)] - self.actor = self.model + [nn.Linear(128, np.prod(action_shape))] - self.critic = self.model + [nn.Linear(128, 1)] - self.actor = nn.Sequential(*self.actor) - self.critic = nn.Sequential(*self.critic) + self.model = nn.Sequential(*self.model) - def forward(self, s, **kwargs): + def forward(self, s): s = torch.tensor(s, device=self.device, dtype=torch.float) batch = s.shape[0] s = s.view(batch, -1) - logits = self.actor(s) - value = self.critic(s) - return logits, value, None + logits = self.model(s) + return logits + + +class Actor(nn.Module): + def __init__(self, preprocess_net, action_shape): + super().__init__() + self.model = nn.Sequential(*[ + preprocess_net, + nn.Linear(128, np.prod(action_shape)), + ]) + + def forward(self, s, **kwargs): + logits = self.model(s) + return logits, None + + +class Critic(nn.Module): + def __init__(self, preprocess_net): + super().__init__() + self.model = nn.Sequential(*[ + preprocess_net, + nn.Linear(128, 1), + ]) + + def forward(self, s): + logits = self.model(s) + return logits def get_args(): @@ -80,83 +100,45 @@ def test_a2c(args=get_args()): train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.layer_num, args.state_shape, args.action_shape, args.device) - net = net.to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.lr) + net = Net(args.layer_num, args.state_shape, args.device) + actor = Actor(net, args.action_shape).to(args.device) + critic = Critic(net).to(args.device) + optim = torch.optim.Adam(list( + actor.parameters()) + list(critic.parameters()), lr=args.lr) dist = torch.distributions.Categorical 
policy = A2CPolicy( - net, optim, dist, args.gamma, - vf_coef=args.vf_coef, - entropy_coef=args.entropy_coef, - max_grad_norm=args.max_grad_norm) + actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef, + entropy_coef=args.entropy_coef, max_grad_norm=args.max_grad_norm) # collector - training_collector = Collector( + train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) test_collector = Collector(policy, test_envs, stat_size=args.test_num) # log - stat_loss = MovAvg() - global_step = 0 writer = SummaryWriter(args.logdir) - best_epoch = -1 - best_reward = -1e10 - start_time = time.time() - for epoch in range(1, 1 + args.epoch): - desc = f'Epoch #{epoch}' - # train - policy.train() - with tqdm.tqdm( - total=args.step_per_epoch, desc=desc, **tqdm_config) as t: - while t.n < t.total: - result = training_collector.collect( - n_episode=args.collect_per_step) - losses = policy.learn( - training_collector.sample(0), args.batch_size) - training_collector.reset_buffer() - global_step += len(losses) - t.update(len(losses)) - stat_loss.add(losses) - writer.add_scalar( - 'reward', result['reward'], global_step=global_step) - writer.add_scalar( - 'length', result['length'], global_step=global_step) - writer.add_scalar( - 'loss', stat_loss.get(), global_step=global_step) - writer.add_scalar( - 'speed', result['speed'], global_step=global_step) - t.set_postfix(loss=f'{stat_loss.get():.6f}', - reward=f'{result["reward"]:.6f}', - length=f'{result["length"]:.2f}', - speed=f'{result["speed"]:.2f}') - if t.n <= t.total: - t.update() - # eval - test_collector.reset_env() - test_collector.reset_buffer() - policy.eval() - result = test_collector.collect(n_episode=args.test_num) - if best_reward < result['reward']: - best_reward = result['reward'] - best_epoch = epoch - print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, ' - f'best_reward: {best_reward:.6f} in #{best_epoch}') - if best_reward >= env.spec.reward_threshold: - break - assert best_reward >= env.spec.reward_threshold - training_collector.close() + + def stop_fn(x): + return x >= env.spec.reward_threshold + + # trainer + train_step, train_episode, test_step, test_episode, best_rew, duration = \ + episodic_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, args.test_num, + args.batch_size, stop_fn=stop_fn, writer=writer) + assert stop_fn(best_rew) + train_collector.close() test_collector.close() if __name__ == '__main__': - train_cnt = training_collector.collect_step - test_cnt = test_collector.collect_step - duration = time.time() - start_time - print(f'Collect {train_cnt} training frame and {test_cnt} test frame ' - f'in {duration:.2f}s, ' - f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s') + print(f'Collect {train_step} frame / {train_episode} episode during ' + f'training and {test_step} frame / {test_episode} episode during' + f' test in {duration:.2f}s, best_reward: {best_rew}, speed: ' + f'{(train_step + test_step) / duration:.2f}it/s') # Let's watch its performance! 
env = gym.make(args.task) - test_collector = Collector(policy, env) - result = test_collector.collect(n_episode=1, render=1 / 35) - print(f'Final reward: {result["reward"]}, length: {result["length"]}') - test_collector.close() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=1 / 35) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + collector.close() if __name__ == '__main__': diff --git a/test/test_ddpg.py b/test/test_ddpg.py index 1f808e9..e3237d5 100644 --- a/test/test_ddpg.py +++ b/test/test_ddpg.py @@ -1,6 +1,4 @@ import gym -import time -import tqdm import torch import argparse import numpy as np @@ -8,7 +6,7 @@ from torch import nn from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DDPGPolicy -from tianshou.utils import tqdm_config, MovAvg +from tianshou.trainer import step_trainer from tianshou.data import Collector, ReplayBuffer from tianshou.env import VectorEnv, SubprocVectorEnv @@ -121,85 +119,39 @@ def test_ddpg(args=get_args()): [env.action_space.low[0], env.action_space.high[0]], args.tau, args.gamma, args.exploration_noise) # collector - training_collector = Collector( + train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size), 1) test_collector = Collector(policy, test_envs, stat_size=args.test_num) # log - stat_a_loss = MovAvg() - stat_c_loss = MovAvg() - global_step = 0 writer = SummaryWriter(args.logdir) - best_epoch = -1 - best_reward = -1e10 - start_time = time.time() - # training_collector.collect(n_step=1000) - for epoch in range(1, 1 + args.epoch): - desc = f'Epoch #{epoch}' - # train - policy.train() - with tqdm.tqdm( - total=args.step_per_epoch, desc=desc, **tqdm_config) as t: - while t.n < t.total: - result = training_collector.collect( - n_step=args.collect_per_step) - for i in range(min( - result['n_step'] // args.collect_per_step, - t.total - t.n)): - t.update(1) - global_step += 1 - actor_loss, critic_loss = policy.learn( - training_collector.sample(args.batch_size)) - policy.sync_weight() - stat_a_loss.add(actor_loss) - stat_c_loss.add(critic_loss) - writer.add_scalar( - 'reward', result['reward'], global_step=global_step) - writer.add_scalar( - 'length', result['length'], global_step=global_step) - writer.add_scalar( - 'actor_loss', stat_a_loss.get(), - global_step=global_step) - writer.add_scalar( - 'critic_loss', stat_a_loss.get(), - global_step=global_step) - writer.add_scalar( - 'speed', result['speed'], global_step=global_step) - t.set_postfix(actor_loss=f'{stat_a_loss.get():.6f}', - critic_loss=f'{stat_c_loss.get():.6f}', - reward=f'{result["reward"]:.6f}', - length=f'{result["length"]:.2f}', - speed=f'{result["speed"]:.2f}') - if t.n <= t.total: - t.update() - # eval - test_collector.reset_env() - test_collector.reset_buffer() - policy.eval() - result = test_collector.collect(n_episode=args.test_num) - if best_reward < result['reward']: - best_reward = result['reward'] - best_epoch = epoch - print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, ' - f'best_reward: {best_reward:.6f} in #{best_epoch}') - if args.task == 'Pendulum-v0' and best_reward >= -250: - break + + def stop_fn(x): + if args.task == 'Pendulum-v0': + return x >= -250 + else: + return False + + # trainer + train_step, train_episode, test_step, test_episode, best_rew, duration = \ + step_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, args.test_num, + args.batch_size, stop_fn=stop_fn, writer=writer) if args.task 
== 'Pendulum-v0': - assert best_reward >= -250 - training_collector.close() + assert stop_fn(best_rew) + train_collector.close() test_collector.close() if __name__ == '__main__': - train_cnt = training_collector.collect_step - test_cnt = test_collector.collect_step - duration = time.time() - start_time - print(f'Collect {train_cnt} training frame and {test_cnt} test frame ' - f'in {duration:.2f}s, ' - f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s') + print(f'Collect {train_step} frame / {train_episode} episode during ' + f'training and {test_step} frame / {test_episode} episode during' + f' test in {duration:.2f}s, best_reward: {best_rew}, speed: ' + f'{(train_step + test_step) / duration:.2f}it/s') # Let's watch its performance! env = gym.make(args.task) - test_collector = Collector(policy, env) - result = test_collector.collect(n_episode=1, render=1 / 35) - print(f'Final reward: {result["reward"]}, length: {result["length"]}') - test_collector.close() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=1 / 35) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + collector.close() if __name__ == '__main__': diff --git a/test/test_dqn.py b/test/test_dqn.py index e2450d6..efcce77 100644 --- a/test/test_dqn.py +++ b/test/test_dqn.py @@ -1,6 +1,4 @@ import gym -import time -import tqdm import torch import argparse import numpy as np @@ -9,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import DQNPolicy from tianshou.env import SubprocVectorEnv -from tianshou.utils import tqdm_config, MovAvg +from tianshou.trainer import step_trainer from tianshou.data import Collector, ReplayBuffer @@ -80,79 +78,45 @@ def test_dqn(args=get_args()): optim = torch.optim.Adam(net.parameters(), lr=args.lr) policy = DQNPolicy(net, optim, args.gamma, args.n_step) # collector - training_collector = Collector( + train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) test_collector = Collector(policy, test_envs, stat_size=args.test_num) - training_collector.collect(n_step=args.batch_size) + train_collector.collect(n_step=args.batch_size) # log - stat_loss = MovAvg() - global_step = 0 writer = SummaryWriter(args.logdir) - best_epoch = -1 - best_reward = -1e10 - start_time = time.time() - for epoch in range(1, 1 + args.epoch): - desc = f"Epoch #{epoch}" - # train - policy.train() + + def stop_fn(x): + return x >= env.spec.reward_threshold + + def train_fn(x): policy.sync_weight() policy.set_eps(args.eps_train) - with tqdm.tqdm( - total=args.step_per_epoch, desc=desc, **tqdm_config) as t: - while t.n < t.total: - result = training_collector.collect( - n_step=args.collect_per_step) - for i in range(min( - result['n_step'] // args.collect_per_step, - t.total - t.n)): - t.update(1) - global_step += 1 - loss = policy.learn( - training_collector.sample(args.batch_size)) - stat_loss.add(loss) - writer.add_scalar( - 'reward', result['reward'], global_step=global_step) - writer.add_scalar( - 'length', result['length'], global_step=global_step) - writer.add_scalar( - 'loss', stat_loss.get(), global_step=global_step) - writer.add_scalar( - 'speed', result['speed'], global_step=global_step) - t.set_postfix(loss=f'{stat_loss.get():.6f}', - reward=f'{result["reward"]:.6f}', - length=f'{result["length"]:.2f}', - speed=f'{result["speed"]:.2f}') - if t.n <= t.total: - t.update() - # eval - test_collector.reset_env() - test_collector.reset_buffer() - policy.eval() + + def test_fn(x): policy.set_eps(args.eps_test) - result = 
test_collector.collect(n_episode=args.test_num) - if best_reward < result['reward']: - best_reward = result['reward'] - best_epoch = epoch - print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, ' - f'best_reward: {best_reward:.6f} in #{best_epoch}') - if best_reward >= env.spec.reward_threshold: - break - assert best_reward >= env.spec.reward_threshold - training_collector.close() + + # trainer + train_step, train_episode, test_step, test_episode, best_rew, duration = \ + step_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, args.test_num, + args.batch_size, train_fn=train_fn, test_fn=test_fn, + stop_fn=stop_fn, writer=writer) + + assert stop_fn(best_rew) + train_collector.close() test_collector.close() if __name__ == '__main__': - train_cnt = training_collector.collect_step - test_cnt = test_collector.collect_step - duration = time.time() - start_time - print(f'Collect {train_cnt} training frame and {test_cnt} test frame ' - f'in {duration:.2f}s, ' - f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s') + print(f'Collect {train_step} frame / {train_episode} episode during ' + f'training and {test_step} frame / {test_episode} episode during' + f' test in {duration:.2f}s, best_reward: {best_rew}, speed: ' + f'{(train_step + test_step) / duration:.2f}it/s') # Let's watch its performance! env = gym.make(args.task) - test_collector = Collector(policy, env) - result = test_collector.collect(n_episode=1, render=1 / 35) - print(f'Final reward: {result["reward"]}, length: {result["length"]}') - test_collector.close() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=1 / 35) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + collector.close() if __name__ == '__main__': diff --git a/test/test_pg.py b/test/test_pg.py index e2f45a6..5a476c2 100644 --- a/test/test_pg.py +++ b/test/test_pg.py @@ -1,6 +1,5 @@ import gym import time -import tqdm import torch import argparse import numpy as np @@ -9,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.policy import PGPolicy from tianshou.env import SubprocVectorEnv -from tianshou.utils import tqdm_config, MovAvg +from tianshou.trainer import episodic_trainer from tianshou.data import Batch, Collector, ReplayBuffer @@ -131,73 +130,35 @@ def test_pg(args=get_args()): dist = torch.distributions.Categorical policy = PGPolicy(net, optim, dist, args.gamma) # collector - training_collector = Collector( + train_collector = Collector( policy, train_envs, ReplayBuffer(args.buffer_size)) test_collector = Collector(policy, test_envs, stat_size=args.test_num) # log - stat_loss = MovAvg() - global_step = 0 writer = SummaryWriter(args.logdir) - best_epoch = -1 - best_reward = -1e10 - start_time = time.time() - for epoch in range(1, 1 + args.epoch): - desc = f"Epoch #{epoch}" - # train - policy.train() - with tqdm.tqdm( - total=args.step_per_epoch, desc=desc, **tqdm_config) as t: - while t.n < t.total: - result = training_collector.collect( - n_episode=args.collect_per_step) - losses = policy.learn( - training_collector.sample(0), args.batch_size) - training_collector.reset_buffer() - global_step += len(losses) - t.update(len(losses)) - stat_loss.add(losses) - writer.add_scalar( - 'reward', result['reward'], global_step=global_step) - writer.add_scalar( - 'length', result['length'], global_step=global_step) - writer.add_scalar( - 'loss', stat_loss.get(), global_step=global_step) - writer.add_scalar( - 'speed', result['speed'], 
global_step=global_step) - t.set_postfix(loss=f'{stat_loss.get():.6f}', - reward=f'{result["reward"]:.6f}', - length=f'{result["length"]:.2f}', - speed=f'{result["speed"]:.2f}') - if t.n <= t.total: - t.update() - # eval - test_collector.reset_env() - test_collector.reset_buffer() - policy.eval() - result = test_collector.collect(n_episode=args.test_num) - if best_reward < result['reward']: - best_reward = result['reward'] - best_epoch = epoch - print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, ' - f'best_reward: {best_reward:.6f} in #{best_epoch}') - if best_reward >= env.spec.reward_threshold: - break - assert best_reward >= env.spec.reward_threshold - training_collector.close() + + def stop_fn(x): + return x >= env.spec.reward_threshold + + # trainer + train_step, train_episode, test_step, test_episode, best_rew, duration = \ + episodic_trainer( + policy, train_collector, test_collector, args.epoch, + args.step_per_epoch, args.collect_per_step, args.test_num, + args.batch_size, stop_fn=stop_fn, writer=writer) + assert stop_fn(best_rew) + train_collector.close() test_collector.close() if __name__ == '__main__': - train_cnt = training_collector.collect_step - test_cnt = test_collector.collect_step - duration = time.time() - start_time - print(f'Collect {train_cnt} training frame and {test_cnt} test frame ' - f'in {duration:.2f}s, ' - f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s') + print(f'Collect {train_step} frame / {train_episode} episode during ' + f'training and {test_step} frame / {test_episode} episode during' + f' test in {duration:.2f}s, best_reward: {best_rew}, speed: ' + f'{(train_step + test_step) / duration:.2f}it/s') # Let's watch its performance! env = gym.make(args.task) - test_collector = Collector(policy, env) - result = test_collector.collect(n_episode=1, render=1 / 35) - print(f'Final reward: {result["reward"]}, length: {result["length"]}') - test_collector.close() + collector = Collector(policy, env) + result = collector.collect(n_episode=1, render=1 / 35) + print(f'Final reward: {result["rew"]}, length: {result["len"]}') + collector.close() if __name__ == '__main__': diff --git a/tianshou/__init__.py b/tianshou/__init__.py index 5d71a93..3528010 100644 --- a/tianshou/__init__.py +++ b/tianshou/__init__.py @@ -1,4 +1,5 @@ -from tianshou import data, env, utils, policy, exploration +from tianshou import data, env, utils, policy, trainer,\ + exploration __version__ = '0.2.0' __all__ = [ @@ -6,5 +7,6 @@ __all__ = [ 'data', 'utils', 'policy', + 'trainer', 'exploration', ] diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 938769a..ae2980d 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -16,6 +16,7 @@ class Collector(object): self.env = env self.env_num = 1 self.collect_step = 0 + self.collect_episode = 0 self.buffer = buffer self.policy = policy self.process_fn = policy.process_fn @@ -39,9 +40,8 @@ class Collector(object): self.reset_buffer() # state over batch is either a list, an np.ndarray, or a torch.Tensor self.state = None - self.stat_reward = MovAvg(stat_size) - self.stat_length = MovAvg(stat_size) - self.stat_speed = MovAvg(stat_size) + self.step_speed = MovAvg(stat_size) + self.episode_speed = MovAvg(stat_size) def reset_buffer(self): if self._multi_buf: @@ -81,11 +81,12 @@ class Collector(object): def collect(self, n_step=0, n_episode=0, render=0): start_time = time.time() - start_step = self.collect_step assert sum([(n_step > 0), (n_episode > 0)]) == 1,\ "One and only one collection number 
specification permitted!" cur_step = 0 cur_episode = np.zeros(self.env_num) if self._multi_env else 0 + reward_sum = 0 + length_sum = 0 while True: if self._multi_env: batch_data = Batch( @@ -126,20 +127,17 @@ class Collector(object): elif self._multi_buf: self.buffer[i].add(**data) cur_step += 1 - self.collect_step += 1 else: self.buffer.add(**data) cur_step += 1 - self.collect_step += 1 if self._done[i]: cur_episode[i] += 1 - self.stat_reward.add(self.reward[i]) - self.stat_length.add(self.length[i]) + reward_sum += self.reward[i] + length_sum += self.length[i] self.reward[i], self.length[i] = 0, 0 if self._cached_buf: self.buffer.update(self._cached_buf[i]) cur_step += len(self._cached_buf[i]) - self.collect_step += len(self._cached_buf[i]) self._cached_buf[i].reset() if isinstance(self.state, list): self.state[i] = None @@ -158,11 +156,10 @@ class Collector(object): self._obs, self._act[0], self._rew, self._done, obs_next, self._info) cur_step += 1 - self.collect_step += 1 if self._done: cur_episode += 1 - self.stat_reward.add(self.reward) - self.stat_length.add(self.length) + reward_sum += self.reward + length_sum += self.length self.reward, self.length = 0, 0 self.state = None self._obs = self.env.reset() @@ -172,16 +169,20 @@ class Collector(object): break self._obs = obs_next self._obs = obs_next - self.stat_speed.add((self.collect_step - start_step) / ( - time.time() - start_time)) if self._multi_env: cur_episode = sum(cur_episode) + duration = time.time() - start_time + self.step_speed.add(cur_step / duration) + self.episode_speed.add(cur_episode / duration) + self.collect_step += cur_step + self.collect_episode += cur_episode return { - 'reward': self.stat_reward.get(), - 'length': self.stat_length.get(), - 'speed': self.stat_speed.get(), - 'n_episode': cur_episode, - 'n_step': cur_step, + 'n/ep': cur_episode, + 'n/st': cur_step, + 'speed/st': self.step_speed.get(), + 'speed/ep': self.episode_speed.get(), + 'rew': reward_sum / cur_episode, + 'len': length_sum / cur_episode, } def sample(self, batch_size): diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py index aba6a73..a979443 100644 --- a/tianshou/policy/__init__.py +++ b/tianshou/policy/__init__.py @@ -3,11 +3,13 @@ from tianshou.policy.dqn import DQNPolicy from tianshou.policy.pg import PGPolicy from tianshou.policy.a2c import A2CPolicy from tianshou.policy.ddpg import DDPGPolicy +from tianshou.policy.ppo import PPOPolicy __all__ = [ 'BasePolicy', 'DQNPolicy', 'PGPolicy', 'A2CPolicy', - 'DDPGPolicy' + 'DDPGPolicy', + 'PPOPolicy', ] diff --git a/tianshou/policy/a2c.py b/tianshou/policy/a2c.py index 1dd1cc8..4bed68a 100644 --- a/tianshou/policy/a2c.py +++ b/tianshou/policy/a2c.py @@ -9,20 +9,23 @@ from tianshou.policy import PGPolicy class A2CPolicy(PGPolicy): """docstring for A2CPolicy""" - def __init__(self, model, optim, dist_fn=torch.distributions.Categorical, + def __init__(self, actor, critic, optim, + dist_fn=torch.distributions.Categorical, discount_factor=0.99, vf_coef=.5, entropy_coef=.01, max_grad_norm=None): - super().__init__(model, optim, dist_fn, discount_factor) + super().__init__(None, optim, dist_fn, discount_factor) + self.actor = actor + self.critic = critic self._w_value = vf_coef self._w_entropy = entropy_coef self._grad_norm = max_grad_norm def __call__(self, batch, state=None): - logits, value, h = self.model(batch.obs, state=state, info=batch.info) + logits, h = self.actor(batch.obs, state=state, info=batch.info) logits = F.softmax(logits, dim=1) dist = self.dist_fn(logits) act = 
dist.sample() - return Batch(logits=logits, act=act, state=h, dist=dist, value=value) + return Batch(logits=logits, act=act, state=h, dist=dist) def learn(self, batch, batch_size=None): losses = [] @@ -30,7 +33,7 @@ class A2CPolicy(PGPolicy): self.optim.zero_grad() result = self(b) dist = result.dist - v = result.value + v = self.critic(b.obs) a = torch.tensor(b.act, device=dist.logits.device) r = torch.tensor(b.returns, device=dist.logits.device) actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean() @@ -45,4 +48,4 @@ class A2CPolicy(PGPolicy): self.model.parameters(), max_norm=self._grad_norm) self.optim.step() losses.append(loss.detach().cpu().numpy()) - return losses + return {'loss': losses} diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index f4f4f83..487d6c8 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -18,6 +18,7 @@ class BasePolicy(ABC, nn.Module): @abstractmethod def learn(self, batch, batch_size=None): + # return a dict which includes loss and its name pass def sync_weight(self): diff --git a/tianshou/policy/ddpg.py b/tianshou/policy/ddpg.py index 63b32db..4d545b5 100644 --- a/tianshou/policy/ddpg.py +++ b/tianshou/policy/ddpg.py @@ -28,7 +28,7 @@ class DDPGPolicy(BasePolicy): self._tau = tau assert 0 < gamma <= 1, 'gamma should in (0, 1]' self._gamma = gamma - assert 0 <= exploration_noise, 'noise should greater than zero' + assert 0 <= exploration_noise, 'noise should not be negative' self._eps = exploration_noise self._range = action_range # self.noise = OUNoise() @@ -87,5 +87,8 @@ class DDPGPolicy(BasePolicy): self.actor_optim.zero_grad() actor_loss.backward() self.actor_optim.step() - return actor_loss.detach().cpu().numpy(),\ - critic_loss.detach().cpu().numpy() + self.sync_weight() + return { + 'loss/actor': actor_loss.detach().cpu().numpy(), + 'loss/critic': critic_loss.detach().cpu().numpy(), + } diff --git a/tianshou/policy/dqn.py b/tianshou/policy/dqn.py index a4dec21..c0f4754 100644 --- a/tianshou/policy/dqn.py +++ b/tianshou/policy/dqn.py @@ -93,4 +93,4 @@ class DQNPolicy(BasePolicy): loss = F.mse_loss(q, r) loss.backward() self.optim.step() - return loss.detach().cpu().numpy() + return {'loss': loss.detach().cpu().numpy()} diff --git a/tianshou/policy/pg.py b/tianshou/policy/pg.py index 430442b..e270470 100644 --- a/tianshou/policy/pg.py +++ b/tianshou/policy/pg.py @@ -45,7 +45,7 @@ class PGPolicy(BasePolicy): loss.backward() self.optim.step() losses.append(loss.detach().cpu().numpy()) - return losses + return {'loss': losses} def _vanilla_returns(self, batch): returns = batch.rew[:] diff --git a/tianshou/policy/ppo.py b/tianshou/policy/ppo.py new file mode 100644 index 0000000..6754bbb --- /dev/null +++ b/tianshou/policy/ppo.py @@ -0,0 +1,50 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from tianshou.data import Batch +from tianshou.policy import PGPolicy + + +class PPOPolicy(PGPolicy): + """docstring for PPOPolicy""" + + def __init__(self, actor, actor_optim, + critic, critic_optim, + dist_fn=torch.distributions.Categorical, + discount_factor=0.99, vf_coef=.5, entropy_coef=.01, + eps_clip=None): + super().__init__(None, None, dist_fn, discount_factor) + self._w_value = vf_coef + self._w_entropy = entropy_coef + self._eps_clip = eps_clip + + def __call__(self, batch, state=None): + logits, value, h = self.model(batch.obs, state=state, info=batch.info) + logits = F.softmax(logits, dim=1) + dist = self.dist_fn(logits) + act = dist.sample() + return Batch(logits=logits, act=act, state=h, 
dist=dist, value=value) + + def learn(self, batch, batch_size=None): + losses = [] + for b in batch.split(batch_size): + self.optim.zero_grad() + result = self(b) + dist = result.dist + v = result.value + a = torch.tensor(b.act, device=dist.logits.device) + r = torch.tensor(b.returns, device=dist.logits.device) + actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean() + critic_loss = F.mse_loss(r[:, None], v) + entropy_loss = dist.entropy().mean() + loss = actor_loss \ + + self._w_value * critic_loss \ + - self._w_entropy * entropy_loss + loss.backward() + if self._grad_norm: + nn.utils.clip_grad_norm_( + self.model.parameters(), max_norm=self._grad_norm) + self.optim.step() + losses.append(loss.detach().cpu().numpy()) + return losses diff --git a/tianshou/trainer/__init__.py b/tianshou/trainer/__init__.py new file mode 100644 index 0000000..24ab402 --- /dev/null +++ b/tianshou/trainer/__init__.py @@ -0,0 +1,7 @@ +from tianshou.trainer.episodic import episodic_trainer +from tianshou.trainer.step import step_trainer + +__all__ = [ + 'episodic_trainer', + 'step_trainer', +] diff --git a/tianshou/trainer/episodic.py b/tianshou/trainer/episodic.py new file mode 100644 index 0000000..46b09f1 --- /dev/null +++ b/tianshou/trainer/episodic.py @@ -0,0 +1,68 @@ +import time +import tqdm + +from tianshou.utils import tqdm_config, MovAvg + + +def episodic_trainer(policy, train_collector, test_collector, max_epoch, + step_per_epoch, collect_per_step, episode_per_test, + batch_size, train_fn=None, test_fn=None, stop_fn=None, + writer=None, verbose=True): + global_step = 0 + best_epoch, best_reward = -1, -1 + stat = {} + start_time = time.time() + for epoch in range(1, 1 + max_epoch): + # train + policy.train() + if train_fn: + train_fn(epoch) + with tqdm.tqdm( + total=step_per_epoch, desc=f'Epoch #{epoch}', + **tqdm_config) as t: + while t.n < t.total: + result = train_collector.collect(n_episode=collect_per_step) + losses = policy.learn(train_collector.sample(0), batch_size) + train_collector.reset_buffer() + step = 1 + data = {} + for k in losses.keys(): + if isinstance(losses[k], list): + step = max(step, len(losses[k])) + global_step += step + for k in result.keys(): + data[k] = f'{result[k]:.2f}' + if writer: + writer.add_scalar( + k, result[k], global_step=global_step) + for k in losses.keys(): + if stat.get(k) is None: + stat[k] = MovAvg() + stat[k].add(losses[k]) + data[k] = f'{stat[k].get():.6f}' + if writer: + writer.add_scalar( + k, stat[k].get(), global_step=global_step) + t.update(step) + t.set_postfix(**data) + if t.n <= t.total: + t.update() + # eval + test_collector.reset_env() + test_collector.reset_buffer() + policy.eval() + if test_fn: + test_fn(epoch) + result = test_collector.collect(n_episode=episode_per_test) + if best_epoch == -1 or best_reward < result['rew']: + best_reward = result['rew'] + best_epoch = epoch + if verbose: + print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, ' + f'best_reward: {best_reward:.6f} in #{best_epoch}') + if stop_fn(best_reward): + break + duration = time.time() - start_time + return train_collector.collect_step, train_collector.collect_episode,\ + test_collector.collect_step, test_collector.collect_episode,\ + best_reward, duration diff --git a/tianshou/trainer/step.py b/tianshou/trainer/step.py new file mode 100644 index 0000000..95dfdc8 --- /dev/null +++ b/tianshou/trainer/step.py @@ -0,0 +1,66 @@ +import time +import tqdm + +from tianshou.utils import tqdm_config, MovAvg + + +def step_trainer(policy, train_collector, test_collector, 
max_epoch, + step_per_epoch, collect_per_step, episode_per_test, + batch_size, train_fn=None, test_fn=None, stop_fn=None, + writer=None, verbose=True): + global_step = 0 + best_epoch, best_reward = -1, -1 + stat = {} + start_time = time.time() + for epoch in range(1, 1 + max_epoch): + # train + policy.train() + if train_fn: + train_fn(epoch) + with tqdm.tqdm( + total=step_per_epoch, desc=f'Epoch #{epoch}', + **tqdm_config) as t: + while t.n < t.total: + result = train_collector.collect(n_step=collect_per_step) + for i in range(min( + result['n/st'] // collect_per_step, + t.total - t.n)): + global_step += 1 + losses = policy.learn(train_collector.sample(batch_size)) + data = {} + for k in result.keys(): + data[k] = f'{result[k]:.2f}' + if writer: + writer.add_scalar( + k, result[k], global_step=global_step) + for k in losses.keys(): + if stat.get(k) is None: + stat[k] = MovAvg() + stat[k].add(losses[k]) + data[k] = f'{stat[k].get():.6f}' + if writer: + writer.add_scalar( + k, stat[k].get(), global_step=global_step) + t.update(1) + t.set_postfix(**data) + if t.n <= t.total: + t.update() + # eval + test_collector.reset_env() + test_collector.reset_buffer() + policy.eval() + if test_fn: + test_fn(epoch) + result = test_collector.collect(n_episode=episode_per_test) + if best_epoch == -1 or best_reward < result['rew']: + best_reward = result['rew'] + best_epoch = epoch + if verbose: + print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, ' + f'best_reward: {best_reward:.6f} in #{best_epoch}') + if stop_fn(best_reward): + break + duration = time.time() - start_time + return train_collector.collect_step, train_collector.collect_episode,\ + test_collector.collect_step, test_collector.collect_episode,\ + best_reward, duration
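
Below is a minimal usage sketch of the trainer interface this patch adds, following the pattern of test/test_dqn.py above. Only the DQNPolicy, Collector, and step_trainer calls mirror code shown in the diff; the Net class, the vector-env construction, and every numeric hyper-parameter are illustrative assumptions, not part of the patch. Both trainers share the same argument list and return (train_step, train_episode, test_step, test_episode, best_reward, duration).

import gym
import torch
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
from tianshou.trainer import step_trainer
from tianshou.env import SubprocVectorEnv
from tianshou.data import Collector, ReplayBuffer


class Net(nn.Module):
    # illustrative Q-network; returns (logits, state) as the policies expect
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True),
            nn.Linear(128, 128), nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape)))

    def forward(self, s, state=None, **kwargs):
        s = torch.tensor(s, dtype=torch.float)
        return self.model(s.view(s.shape[0], -1)), state


task = 'CartPole-v0'
env = gym.make(task)
# exact vector-env constructor arguments at this commit are assumed
train_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(8)])
test_envs = SubprocVectorEnv([lambda: gym.make(task) for _ in range(8)])

net = Net(env.observation_space.shape, env.action_space.n)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, 0.9, 3)  # gamma, n_step (values illustrative)

train_collector = Collector(policy, train_envs, ReplayBuffer(20000))
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=64)  # pre-fill the replay buffer
writer = SummaryWriter('log/dqn')


def stop_fn(mean_reward):
    return mean_reward >= env.spec.reward_threshold


def train_fn(epoch):
    policy.sync_weight()
    policy.set_eps(0.1)


def test_fn(epoch):
    policy.set_eps(0.05)


train_step, train_episode, test_step, test_episode, best_reward, duration = \
    step_trainer(
        policy, train_collector, test_collector, max_epoch=100,
        step_per_epoch=1000, collect_per_step=10, episode_per_test=8,
        batch_size=64, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, writer=writer)
print(f'Finished in {duration:.2f}s, best_reward: {best_reward}')
train_collector.close()
test_collector.close()

episodic_trainer takes the same arguments; the difference is that it collects n_episode=collect_per_step episodes per update and trains on the whole on-policy buffer, as in test_a2c.py and test_pg.py above.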