From 6e563fe61a3ae16632de77c9bc827e855d981f25 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Tue, 17 Mar 2020 20:22:37 +0800
Subject: [PATCH] a2c

---
 test/test_a2c.py                   | 160 +++++++++++++++++++++++++++++
 test/test_pg.py                    |   6 +-
 tianshou/env/wrapper.py            |  46 +++++----
 tianshou/policy/__init__.py        |   2 +
 tianshou/policy/a2c.py             |  42 ++++++++
 tianshou/policy/policy_gradient.py |  20 ++--
 6 files changed, 239 insertions(+), 37 deletions(-)
 create mode 100644 test/test_a2c.py
 create mode 100644 tianshou/policy/a2c.py

diff --git a/test/test_a2c.py b/test/test_a2c.py
new file mode 100644
index 0000000..4033c4b
--- /dev/null
+++ b/test/test_a2c.py
@@ -0,0 +1,160 @@
+import gym
+import time
+import tqdm
+import torch
+import argparse
+import numpy as np
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import A2CPolicy
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils import tqdm_config, MovAvg
+from tianshou.data import Collector, ReplayBuffer
+
+
+class Net(nn.Module):
+    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
+        super().__init__()
+        self.device = device
+        self.model = [
+            nn.Linear(np.prod(state_shape), 128),
+            nn.ReLU(inplace=True)]
+        for i in range(layer_num):
+            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
+        self.actor = self.model + [nn.Linear(128, np.prod(action_shape))]
+        self.critic = self.model + [nn.Linear(128, 1)]
+        self.actor = nn.Sequential(*self.actor)
+        self.critic = nn.Sequential(*self.critic)
+
+    def forward(self, s, **kwargs):
+        s = torch.tensor(s, device=self.device, dtype=torch.float)
+        batch = s.shape[0]
+        s = s.view(batch, -1)
+        logits = self.actor(s)
+        value = self.critic(s)
+        return logits, value, None
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='CartPole-v0')
+    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--buffer-size', type=int, default=20000)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.9)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=320)
+    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--layer-num', type=int, default=2)
+    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    # a2c special
+    parser.add_argument('--vf-coef', type=float, default=0.5)
+    parser.add_argument('--entropy-coef', type=float, default=0.001)
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_a2c(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    # train_envs = gym.make(args.task)
+    train_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)],
+        reset_after_done=True)
+    # test_envs = gym.make(args.task)
+    test_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)],
+        reset_after_done=False)
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
+    net = net.to(args.device)
+    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+    dist = torch.distributions.Categorical
+    policy = A2CPolicy(
+        net, optim, dist, args.gamma,
+        vf_coef=args.vf_coef,
+        entropy_coef=args.entropy_coef)
+    # collector
+    training_collector = Collector(
+        policy, train_envs, ReplayBuffer(args.buffer_size))
+    test_collector = Collector(
+        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
+    # log
+    stat_loss = MovAvg()
+    global_step = 0
+    writer = SummaryWriter(args.logdir)
+    best_epoch = -1
+    best_reward = -1e10
+    start_time = time.time()
+    for epoch in range(1, 1 + args.epoch):
+        desc = f'Epoch #{epoch}'
+        # train
+        policy.train()
+        with tqdm.tqdm(
+                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
+            while t.n < t.total:
+                result = training_collector.collect(
+                    n_episode=args.collect_per_step)
+                losses = policy.learn(
+                    training_collector.sample(0), args.batch_size)
+                training_collector.reset_buffer()
+                global_step += len(losses)
+                t.update(len(losses))
+                stat_loss.add(losses)
+                writer.add_scalar(
+                    'reward', result['reward'], global_step=global_step)
+                writer.add_scalar(
+                    'length', result['length'], global_step=global_step)
+                writer.add_scalar(
+                    'loss', stat_loss.get(), global_step=global_step)
+                writer.add_scalar(
+                    'speed', result['speed'], global_step=global_step)
+                t.set_postfix(loss=f'{stat_loss.get():.6f}',
+                              reward=f'{result["reward"]:.6f}',
+                              length=f'{result["length"]:.2f}',
+                              speed=f'{result["speed"]:.2f}')
+        # eval
+        test_collector.reset_env()
+        test_collector.reset_buffer()
+        policy.eval()
+        result = test_collector.collect(n_episode=args.test_num)
+        if best_reward < result['reward']:
+            best_reward = result['reward']
+            best_epoch = epoch
+        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
+              f'best_reward: {best_reward:.6f} in #{best_epoch}')
+        if best_reward >= env.spec.reward_threshold:
+            break
+    assert best_reward >= env.spec.reward_threshold
+    training_collector.close()
+    test_collector.close()
+    if __name__ == '__main__':
+        train_cnt = training_collector.collect_step
+        test_cnt = test_collector.collect_step
+        duration = time.time() - start_time
+        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
+              f'in {duration:.2f}s, '
+              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        test_collector = Collector(policy, env)
+        result = test_collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
+        test_collector.close()
+
+
+if __name__ == '__main__':
+    test_a2c()
diff --git a/test/test_pg.py b/test/test_pg.py
index 32978c8..93c8603 100644
--- a/test/test_pg.py
+++ b/test/test_pg.py
@@ -26,8 +26,7 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
 
 
 def test_fn(size=2560):
-    policy = PGPolicy(
-        None, None, None, discount_factor=0.1, normalized_reward=False)
+    policy = PGPolicy(None, None, None, discount_factor=0.1)
     fn = policy.process_fn
     # fn = compute_return_base
     batch = Batch(
@@ -36,7 +35,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 1, 0.]),
@@ -44,7 +42,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 0, 1.]),
@@ -52,7 +49,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     if __name__ == '__main__':
         batch = Batch(
diff --git a/tianshou/env/wrapper.py b/tianshou/env/wrapper.py
index fa26dc2..f67be4f 100644
--- a/tianshou/env/wrapper.py
+++ b/tianshou/env/wrapper.py
@@ -143,27 +143,31 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
     parent.close()
     env = env_fn_wrapper.data()
     done = False
-    while True:
-        cmd, data = p.recv()
-        if cmd == 'step':
-            if reset_after_done or not done:
-                obs, rew, done, info = env.step(data)
-            if reset_after_done and done:
-                # s_ is useless when episode finishes
-                obs = env.reset()
-            p.send([obs, rew, done, info])
-        elif cmd == 'reset':
-            done = False
-            p.send(env.reset())
-        elif cmd == 'close':
-            p.close()
-            break
-        elif cmd == 'render':
-            p.send(env.render() if hasattr(env, 'render') else None)
-        elif cmd == 'seed':
-            p.send(env.seed(data) if hasattr(env, 'seed') else None)
-        else:
-            raise NotImplementedError
+    try:
+        while True:
+            cmd, data = p.recv()
+            if cmd == 'step':
+                if reset_after_done or not done:
+                    obs, rew, done, info = env.step(data)
+                if reset_after_done and done:
+                    # s_ is useless when episode finishes
+                    obs = env.reset()
+                p.send([obs, rew, done, info])
+            elif cmd == 'reset':
+                done = False
+                p.send(env.reset())
+            elif cmd == 'close':
+                p.close()
+                break
+            elif cmd == 'render':
+                p.send(env.render() if hasattr(env, 'render') else None)
+            elif cmd == 'seed':
+                p.send(env.seed(data) if hasattr(env, 'seed') else None)
+            else:
+                p.close()
+                raise NotImplementedError
+    except KeyboardInterrupt:
+        p.close()
 
 
 class SubprocVectorEnv(BaseVectorEnv):
diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py
index 882b316..4301913 100644
--- a/tianshou/policy/__init__.py
+++ b/tianshou/policy/__init__.py
@@ -1,9 +1,11 @@
 from tianshou.policy.base import BasePolicy
 from tianshou.policy.dqn import DQNPolicy
 from tianshou.policy.policy_gradient import PGPolicy
+from tianshou.policy.a2c import A2CPolicy
 
 __all__ = [
     'BasePolicy',
     'DQNPolicy',
     'PGPolicy',
+    'A2CPolicy',
 ]
diff --git a/tianshou/policy/a2c.py b/tianshou/policy/a2c.py
new file mode 100644
index 0000000..f546a62
--- /dev/null
+++ b/tianshou/policy/a2c.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch
+from tianshou.policy import PGPolicy
+
+
+class A2CPolicy(PGPolicy):
+    """docstring for A2CPolicy"""
+
+    def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
+                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01):
+        super().__init__(model, optim, dist_fn, discount_factor)
+        self._w_value = vf_coef
+        self._w_entropy = entropy_coef
+
+    def __call__(self, batch, state=None):
+        logits, value, h = self.model(batch.obs, state=state, info=batch.info)
+        logits = F.softmax(logits, dim=1)
+        dist = self.dist_fn(logits)
+        act = dist.sample().detach().cpu().numpy()
+        return Batch(logits=logits, act=act, state=h, dist=dist, value=value)
+
+    def learn(self, batch, batch_size=None):
+        losses = []
+        for b in batch.split(batch_size):
+            self.optim.zero_grad()
+            result = self(b)
+            dist = result.dist
+            v = result.value
+            a = torch.tensor(b.act, device=dist.logits.device)
+            r = torch.tensor(b.returns, device=dist.logits.device)
+            actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
+            critic_loss = (r - v).pow(2).mean()
+            entropy_loss = dist.entropy().mean()
+            loss = actor_loss \
+                + self._w_value * critic_loss \
+                - self._w_entropy * entropy_loss
+            loss.backward()
+            self.optim.step()
+            losses.append(loss.detach().cpu().numpy())
+        return losses
diff --git a/tianshou/policy/policy_gradient.py b/tianshou/policy/policy_gradient.py
index b352b5d..8b698bd 100644
--- a/tianshou/policy/policy_gradient.py
+++ b/tianshou/policy/policy_gradient.py
@@ -11,7 +11,7 @@ class PGPolicy(BasePolicy, nn.Module):
     """docstring for PGPolicy"""
 
     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
-                 discount_factor=0.99, normalized_reward=True):
+                 discount_factor=0.99):
         super().__init__()
         self.model = model
         self.optim = optim
@@ -19,15 +19,10 @@ class PGPolicy(BasePolicy, nn.Module):
         self._eps = np.finfo(np.float32).eps.item()
         assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
         self._gamma = discount_factor
-        self._rew_norm = normalized_reward
 
     def process_fn(self, batch, buffer, indice):
-        batch_size = len(batch.rew)
-        returns = self._vanilla_returns(batch, batch_size)
-        # returns = self._vectorized_returns(batch, batch_size)
-        returns = returns - returns.mean()
-        if self._rew_norm:
-            returns = returns / (returns.std() + self._eps)
+        returns = self._vanilla_returns(batch)
+        # returns = self._vectorized_returns(batch)
         batch.update(returns=returns)
         return batch
 
@@ -40,6 +35,8 @@ class PGPolicy(BasePolicy, nn.Module):
 
     def learn(self, batch, batch_size=None):
         losses = []
+        batch.returns = (batch.returns - batch.returns.mean()) \
+            / (batch.returns.std() + self._eps)
         for b in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(b).dist
@@ -51,21 +48,22 @@ class PGPolicy(BasePolicy, nn.Module):
             losses.append(loss.detach().cpu().numpy())
         return losses
 
-    def _vanilla_returns(self, batch, batch_size):
+    def _vanilla_returns(self, batch):
         returns = batch.rew[:]
         last = 0
-        for i in range(batch_size - 1, -1, -1):
+        for i in range(len(returns) - 1, -1, -1):
             if not batch.done[i]:
                 returns[i] += self._gamma * last
             last = returns[i]
         return returns
 
-    def _vectorized_returns(self, batch, batch_size):
+    def _vectorized_returns(self, batch):
         # according to my tests, it is slower than vanilla
         # import scipy.signal
         convolve = np.convolve
         # convolve = scipy.signal.convolve
         rew = batch.rew[::-1]
+        batch_size = len(rew)
         gammas = self._gamma ** np.arange(batch_size)
         c = convolve(rew, gammas)[:batch_size]
         T = np.where(batch.done[::-1])[0]
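
For reference, below is a minimal usage sketch (not part of the patch) showing how the new A2CPolicy is wired into a Collector, condensed from test/test_a2c.py above. The ActorCritic module and the hyper-parameter values are illustrative stand-ins for the test's Net class and argparse defaults.

import gym
import torch
import numpy as np
from torch import nn

from tianshou.policy import A2CPolicy
from tianshou.data import Collector, ReplayBuffer


class ActorCritic(nn.Module):
    """Shared body with separate actor/critic heads, as in the test above."""

    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(np.prod(state_shape), 128), nn.ReLU(inplace=True))
        self.actor = nn.Linear(128, np.prod(action_shape))
        self.critic = nn.Linear(128, 1)

    def forward(self, s, **kwargs):
        s = torch.tensor(s, dtype=torch.float)
        h = self.body(s.view(s.shape[0], -1))
        # A2CPolicy expects the model to return (logits, value, hidden_state)
        return self.actor(h), self.critic(h), None


env = gym.make('CartPole-v0')
net = ActorCritic(env.observation_space.shape, env.action_space.n)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = A2CPolicy(net, optim, torch.distributions.Categorical,
                   discount_factor=0.9, vf_coef=0.5, entropy_coef=0.001)
collector = Collector(policy, env, ReplayBuffer(20000))
collector.collect(n_episode=10)        # gather on-policy episodes
policy.learn(collector.sample(0), 64)  # one A2C update over the collected data
collector.reset_buffer()               # A2C is on-policy: discard used data
collector.close()

The (logits, value, hidden_state) return convention is what lets A2CPolicy build on PGPolicy: the actor term weights log_prob by the detached advantage (r - v), the value head is trained with the squared error scaled by vf_coef, and an entropy bonus scaled by entropy_coef is subtracted from the total loss.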