a2c

commit 6e563fe61a (parent fd621971e5)

test/test_a2c.py | 160 (new file)
@@ -0,0 +1,160 @@
+import gym
+import time
+import tqdm
+import torch
+import argparse
+import numpy as np
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import A2CPolicy
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils import tqdm_config, MovAvg
+from tianshou.data import Collector, ReplayBuffer
+
+
+class Net(nn.Module):
+    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
+        super().__init__()
+        self.device = device
+        self.model = [
+            nn.Linear(np.prod(state_shape), 128),
+            nn.ReLU(inplace=True)]
+        for i in range(layer_num):
+            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
+        self.actor = self.model + [nn.Linear(128, np.prod(action_shape))]
+        self.critic = self.model + [nn.Linear(128, 1)]
+        self.actor = nn.Sequential(*self.actor)
+        self.critic = nn.Sequential(*self.critic)
+
+    def forward(self, s, **kwargs):
+        s = torch.tensor(s, device=self.device, dtype=torch.float)
+        batch = s.shape[0]
+        s = s.view(batch, -1)
+        logits = self.actor(s)
+        value = self.critic(s)
+        return logits, value, None
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='CartPole-v0')
+    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--buffer-size', type=int, default=20000)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.9)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=320)
+    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--layer-num', type=int, default=2)
+    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    # a2c special
+    parser.add_argument('--vf-coef', type=float, default=0.5)
+    parser.add_argument('--entropy-coef', type=float, default=0.001)
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_a2c(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    # train_envs = gym.make(args.task)
+    train_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)],
+        reset_after_done=True)
+    # test_envs = gym.make(args.task)
+    test_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)],
+        reset_after_done=False)
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
+    net = net.to(args.device)
+    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+    dist = torch.distributions.Categorical
+    policy = A2CPolicy(
+        net, optim, dist, args.gamma,
+        vf_coef=args.vf_coef,
+        entropy_coef=args.entropy_coef)
+    # collector
+    training_collector = Collector(
+        policy, train_envs, ReplayBuffer(args.buffer_size))
+    test_collector = Collector(
+        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
+    # log
+    stat_loss = MovAvg()
+    global_step = 0
+    writer = SummaryWriter(args.logdir)
+    best_epoch = -1
+    best_reward = -1e10
+    start_time = time.time()
+    for epoch in range(1, 1 + args.epoch):
+        desc = f'Epoch #{epoch}'
+        # train
+        policy.train()
+        with tqdm.tqdm(
+                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
+            while t.n < t.total:
+                result = training_collector.collect(
+                    n_episode=args.collect_per_step)
+                losses = policy.learn(
+                    training_collector.sample(0), args.batch_size)
+                training_collector.reset_buffer()
+                global_step += len(losses)
+                t.update(len(losses))
+                stat_loss.add(losses)
+                writer.add_scalar(
+                    'reward', result['reward'], global_step=global_step)
+                writer.add_scalar(
+                    'length', result['length'], global_step=global_step)
+                writer.add_scalar(
+                    'loss', stat_loss.get(), global_step=global_step)
+                writer.add_scalar(
+                    'speed', result['speed'], global_step=global_step)
+                t.set_postfix(loss=f'{stat_loss.get():.6f}',
+                              reward=f'{result["reward"]:.6f}',
+                              length=f'{result["length"]:.2f}',
+                              speed=f'{result["speed"]:.2f}')
+        # eval
+        test_collector.reset_env()
+        test_collector.reset_buffer()
+        policy.eval()
+        result = test_collector.collect(n_episode=args.test_num)
+        if best_reward < result['reward']:
+            best_reward = result['reward']
+            best_epoch = epoch
+        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
+              f'best_reward: {best_reward:.6f} in #{best_epoch}')
+        if best_reward >= env.spec.reward_threshold:
+            break
+    assert best_reward >= env.spec.reward_threshold
+    training_collector.close()
+    test_collector.close()
+    if __name__ == '__main__':
+        train_cnt = training_collector.collect_step
+        test_cnt = test_collector.collect_step
+        duration = time.time() - start_time
+        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
+              f'in {duration:.2f}s, '
+              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        test_collector = Collector(policy, env)
+        result = test_collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
+        test_collector.close()
+
+
+if __name__ == '__main__':
+    test_a2c()
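A side note on the Net class above: the actor and critic Sequentials are built from the same self.model list, so the backbone layers are shared between the two heads while only the output layers differ. A minimal check of that sharing, assuming the Net class above is in scope (this is illustrative only, not part of the commit):

    import numpy as np
    import torch

    net = Net(layer_num=2, state_shape=(4,), action_shape=2)
    # The first backbone layer is the very same module object in both heads,
    # so actor and critic gradients both flow into the shared trunk.
    assert net.actor[0] is net.critic[0]
    logits, value, _ = net(np.zeros((1, 4), dtype=np.float32))
    print(logits.shape, value.shape)  # torch.Size([1, 2]) torch.Size([1, 1])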
@@ -26,8 +26,7 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
 
 
 def test_fn(size=2560):
-    policy = PGPolicy(
-        None, None, None, discount_factor=0.1, normalized_reward=False)
+    policy = PGPolicy(None, None, None, discount_factor=0.1)
     fn = policy.process_fn
     # fn = compute_return_base
     batch = Batch(
@@ -36,7 +35,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 1, 0.]),
@@ -44,7 +42,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 0, 1.]),
@@ -52,7 +49,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     if __name__ == '__main__':
         batch = Batch(
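The dropped "ans -= ans.mean()" lines track a behaviour change in this commit: return standardization moves out of PGPolicy.process_fn and into PGPolicy.learn (see the PGPolicy hunks further down), so the expected arrays here are now raw discounted returns. A minimal standalone sketch of that standardization step, using made-up return values:

    import numpy as np

    returns = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7.])  # illustrative values only
    eps = np.finfo(np.float32).eps.item()                    # same eps as PGPolicy uses
    standardized = (returns - returns.mean()) / (returns.std() + eps)
    print(standardized.mean(), standardized.std())           # approximately 0 and 1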
							
								
								
									
tianshou/env/wrapper.py | 46 (vendored)
@@ -143,27 +143,31 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
     parent.close()
     env = env_fn_wrapper.data()
     done = False
-    while True:
-        cmd, data = p.recv()
-        if cmd == 'step':
-            if reset_after_done or not done:
-                obs, rew, done, info = env.step(data)
-            if reset_after_done and done:
-                # s_ is useless when episode finishes
-                obs = env.reset()
-            p.send([obs, rew, done, info])
-        elif cmd == 'reset':
-            done = False
-            p.send(env.reset())
-        elif cmd == 'close':
-            p.close()
-            break
-        elif cmd == 'render':
-            p.send(env.render() if hasattr(env, 'render') else None)
-        elif cmd == 'seed':
-            p.send(env.seed(data) if hasattr(env, 'seed') else None)
-        else:
-            raise NotImplementedError
+    try:
+        while True:
+            cmd, data = p.recv()
+            if cmd == 'step':
+                if reset_after_done or not done:
+                    obs, rew, done, info = env.step(data)
+                if reset_after_done and done:
+                    # s_ is useless when episode finishes
+                    obs = env.reset()
+                p.send([obs, rew, done, info])
+            elif cmd == 'reset':
+                done = False
+                p.send(env.reset())
+            elif cmd == 'close':
+                p.close()
+                break
+            elif cmd == 'render':
+                p.send(env.render() if hasattr(env, 'render') else None)
+            elif cmd == 'seed':
+                p.send(env.seed(data) if hasattr(env, 'seed') else None)
+            else:
+                p.close()
+                raise NotImplementedError
+    except KeyboardInterrupt:
+        p.close()
 
 
 class SubprocVectorEnv(BaseVectorEnv):
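For context, the worker above services commands sent over a multiprocessing Pipe, and the change wraps the loop in try/except so an interrupt closes the pipe cleanly instead of leaking it. A toy sketch of the same command protocol with a dummy environment (the names here are illustrative, not tianshou's API):

    import multiprocessing as mp

    def toy_worker(p):
        # Mirrors the structure of worker(): serve commands until 'close',
        # and close the pipe if the process is interrupted.
        try:
            while True:
                cmd, data = p.recv()
                if cmd == 'step':
                    p.send([data + 1, 1.0, False, {}])  # obs, rew, done, info
                elif cmd == 'reset':
                    p.send(0)
                elif cmd == 'close':
                    p.close()
                    break
        except KeyboardInterrupt:
            p.close()

    if __name__ == '__main__':
        parent, child = mp.Pipe()
        proc = mp.Process(target=toy_worker, args=(child,))
        proc.start()
        parent.send(['reset', None]); print(parent.recv())  # 0
        parent.send(['step', 0]); print(parent.recv())      # [1, 1.0, False, {}]
        parent.send(['close', None]); proc.join()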
@@ -1,9 +1,11 @@
 from tianshou.policy.base import BasePolicy
 from tianshou.policy.dqn import DQNPolicy
 from tianshou.policy.policy_gradient import PGPolicy
+from tianshou.policy.a2c import A2CPolicy
 
 __all__ = [
     'BasePolicy',
     'DQNPolicy',
     'PGPolicy',
+    'A2CPolicy',
 ]
							
								
								
									
tianshou/policy/a2c.py | 42 (new file)
@@ -0,0 +1,42 @@
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch
+from tianshou.policy import PGPolicy
+
+
+class A2CPolicy(PGPolicy):
+    """docstring for A2CPolicy"""
+
+    def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
+                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01):
+        super().__init__(model, optim, dist_fn, discount_factor)
+        self._w_value = vf_coef
+        self._w_entropy = entropy_coef
+
+    def __call__(self, batch, state=None):
+        logits, value, h = self.model(batch.obs, state=state, info=batch.info)
+        logits = F.softmax(logits, dim=1)
+        dist = self.dist_fn(logits)
+        act = dist.sample().detach().cpu().numpy()
+        return Batch(logits=logits, act=act, state=h, dist=dist, value=value)
+
+    def learn(self, batch, batch_size=None):
+        losses = []
+        for b in batch.split(batch_size):
+            self.optim.zero_grad()
+            result = self(b)
+            dist = result.dist
+            v = result.value
+            a = torch.tensor(b.act, device=dist.logits.device)
+            r = torch.tensor(b.returns, device=dist.logits.device)
+            actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
+            critic_loss = (r - v).pow(2).mean()
+            entropy_loss = dist.entropy().mean()
+            loss = actor_loss \
+                + self._w_value * critic_loss \
+                - self._w_entropy * entropy_loss
+            loss.backward()
+            self.optim.step()
+            losses.append(loss.detach().cpu().numpy())
+        return losses
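As a sanity check, the combined objective in learn() above can be reproduced by hand on toy tensors: policy-gradient term weighted by the detached advantage, squared-error critic term scaled by vf_coef, minus an entropy bonus scaled by entropy_coef. The numbers below are made up purely for illustration, with vf_coef and entropy_coef at their defaults of 0.5 and 0.01:

    import torch

    probs = torch.tensor([[0.7, 0.3], [0.4, 0.6]])  # softmax output, as in __call__
    dist = torch.distributions.Categorical(probs)
    v = torch.tensor([0.5, 1.0])                    # critic values
    a = torch.tensor([0, 1])                        # actions taken
    r = torch.tensor([1.0, 0.2])                    # discounted returns

    actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
    critic_loss = (r - v).pow(2).mean()
    entropy_loss = dist.entropy().mean()
    loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy_loss
    print(actor_loss.item(), critic_loss.item(), entropy_loss.item(), loss.item())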
@@ -11,7 +11,7 @@ class PGPolicy(BasePolicy, nn.Module):
     """docstring for PGPolicy"""
 
     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
-                 discount_factor=0.99, normalized_reward=True):
+                 discount_factor=0.99):
         super().__init__()
         self.model = model
         self.optim = optim
@@ -19,15 +19,10 @@ class PGPolicy(BasePolicy, nn.Module):
         self._eps = np.finfo(np.float32).eps.item()
         assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
         self._gamma = discount_factor
-        self._rew_norm = normalized_reward
 
     def process_fn(self, batch, buffer, indice):
-        batch_size = len(batch.rew)
-        returns = self._vanilla_returns(batch, batch_size)
-        # returns = self._vectorized_returns(batch, batch_size)
-        returns = returns - returns.mean()
-        if self._rew_norm:
-            returns = returns / (returns.std() + self._eps)
+        returns = self._vanilla_returns(batch)
+        # returns = self._vectorized_returns(batch)
         batch.update(returns=returns)
         return batch
 
@@ -40,6 +35,8 @@ class PGPolicy(BasePolicy, nn.Module):
 
     def learn(self, batch, batch_size=None):
         losses = []
+        batch.returns = (batch.returns - batch.returns.mean()) \
+            / (batch.returns.std() + self._eps)
         for b in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(b).dist
@@ -51,21 +48,22 @@ class PGPolicy(BasePolicy, nn.Module):
             losses.append(loss.detach().cpu().numpy())
         return losses
 
-    def _vanilla_returns(self, batch, batch_size):
+    def _vanilla_returns(self, batch):
         returns = batch.rew[:]
         last = 0
-        for i in range(batch_size - 1, -1, -1):
+        for i in range(len(returns) - 1, -1, -1):
             if not batch.done[i]:
                 returns[i] += self._gamma * last
             last = returns[i]
         return returns
 
-    def _vectorized_returns(self, batch, batch_size):
+    def _vectorized_returns(self, batch):
         # according to my tests, it is slower than vanilla
         # import scipy.signal
         convolve = np.convolve
         # convolve = scipy.signal.convolve
         rew = batch.rew[::-1]
+        batch_size = len(rew)
         gammas = self._gamma ** np.arange(batch_size)
         c = convolve(rew, gammas)[:batch_size]
         T = np.where(batch.done[::-1])[0]
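To make the backward recursion in _vanilla_returns concrete, here is a standalone computation with gamma = 0.1. The reward and done arrays are chosen here for illustration (the test's own inputs are truncated out of this diff), but they reproduce the first expected array in the test hunk above:

    import numpy as np

    gamma = 0.1
    rew = np.array([0, 1, 2, 3, 4, 5, 6, 7.])
    done = np.array([1, 0, 0, 1, 0, 1, 0, 1.])

    returns = rew.copy()
    last = 0
    for i in range(len(returns) - 1, -1, -1):
        if not done[i]:
            returns[i] += gamma * last  # accumulate only within an episode
        last = returns[i]

    print(returns)  # [0. 1.23 2.3 3. 4.5 5. 6.7 7.] (up to float rounding)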