a2c

commit 6e563fe61a
parent fd621971e5

test/test_a2c.py (new file, 160 lines)
@@ -0,0 +1,160 @@
+import gym
+import time
+import tqdm
+import torch
+import argparse
+import numpy as np
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import A2CPolicy
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils import tqdm_config, MovAvg
+from tianshou.data import Collector, ReplayBuffer
+
+
+class Net(nn.Module):
+    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
+        super().__init__()
+        self.device = device
+        self.model = [
+            nn.Linear(np.prod(state_shape), 128),
+            nn.ReLU(inplace=True)]
+        for i in range(layer_num):
+            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
+        self.actor = self.model + [nn.Linear(128, np.prod(action_shape))]
+        self.critic = self.model + [nn.Linear(128, 1)]
+        self.actor = nn.Sequential(*self.actor)
+        self.critic = nn.Sequential(*self.critic)
+
+    def forward(self, s, **kwargs):
+        s = torch.tensor(s, device=self.device, dtype=torch.float)
+        batch = s.shape[0]
+        s = s.view(batch, -1)
+        logits = self.actor(s)
+        value = self.critic(s)
+        return logits, value, None
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='CartPole-v0')
+    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--buffer-size', type=int, default=20000)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.9)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=320)
+    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--layer-num', type=int, default=2)
+    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    # a2c special
+    parser.add_argument('--vf-coef', type=float, default=0.5)
+    parser.add_argument('--entropy-coef', type=float, default=0.001)
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_a2c(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    # train_envs = gym.make(args.task)
+    train_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)],
+        reset_after_done=True)
+    # test_envs = gym.make(args.task)
+    test_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)],
+        reset_after_done=False)
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
+    net = net.to(args.device)
+    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+    dist = torch.distributions.Categorical
+    policy = A2CPolicy(
+        net, optim, dist, args.gamma,
+        vf_coef=args.vf_coef,
+        entropy_coef=args.entropy_coef)
+    # collector
+    training_collector = Collector(
+        policy, train_envs, ReplayBuffer(args.buffer_size))
+    test_collector = Collector(
+        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
+    # log
+    stat_loss = MovAvg()
+    global_step = 0
+    writer = SummaryWriter(args.logdir)
+    best_epoch = -1
+    best_reward = -1e10
+    start_time = time.time()
+    for epoch in range(1, 1 + args.epoch):
+        desc = f'Epoch #{epoch}'
+        # train
+        policy.train()
+        with tqdm.tqdm(
+                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
+            while t.n < t.total:
+                result = training_collector.collect(
+                    n_episode=args.collect_per_step)
+                losses = policy.learn(
+                    training_collector.sample(0), args.batch_size)
+                training_collector.reset_buffer()
+                global_step += len(losses)
+                t.update(len(losses))
+                stat_loss.add(losses)
+                writer.add_scalar(
+                    'reward', result['reward'], global_step=global_step)
+                writer.add_scalar(
+                    'length', result['length'], global_step=global_step)
+                writer.add_scalar(
+                    'loss', stat_loss.get(), global_step=global_step)
+                writer.add_scalar(
+                    'speed', result['speed'], global_step=global_step)
+                t.set_postfix(loss=f'{stat_loss.get():.6f}',
+                              reward=f'{result["reward"]:.6f}',
+                              length=f'{result["length"]:.2f}',
+                              speed=f'{result["speed"]:.2f}')
+        # eval
+        test_collector.reset_env()
+        test_collector.reset_buffer()
+        policy.eval()
+        result = test_collector.collect(n_episode=args.test_num)
+        if best_reward < result['reward']:
+            best_reward = result['reward']
+            best_epoch = epoch
+        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
+              f'best_reward: {best_reward:.6f} in #{best_epoch}')
+        if best_reward >= env.spec.reward_threshold:
+            break
+    assert best_reward >= env.spec.reward_threshold
+    training_collector.close()
+    test_collector.close()
+    if __name__ == '__main__':
+        train_cnt = training_collector.collect_step
+        test_cnt = test_collector.collect_step
+        duration = time.time() - start_time
+        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
+              f'in {duration:.2f}s, '
+              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        test_collector = Collector(policy, env)
+        result = test_collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
+        test_collector.close()
+
+
+if __name__ == '__main__':
+    test_a2c()
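The new test doubles as a usage example. Below is a minimal quick-run sketch, not part of the commit, assuming test/test_a2c.py is importable (for instance with the test/ directory on sys.path); every attribute name comes from get_args() above, and the overridden values are arbitrary.

# Hypothetical quick-run sketch; argument names come from get_args() above.
from test_a2c import get_args, test_a2c  # assumes test/ is on sys.path

args = get_args()
args.epoch = 5          # instead of the default 100 epochs
args.training_num = 4   # fewer parallel training environments
args.test_num = 10      # fewer parallel test environments / episodes
test_a2c(args)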
test/test_pg.py
@@ -26,8 +26,7 @@ def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
 
 
 def test_fn(size=2560):
-    policy = PGPolicy(
-        None, None, None, discount_factor=0.1, normalized_reward=False)
+    policy = PGPolicy(None, None, None, discount_factor=0.1)
     fn = policy.process_fn
     # fn = compute_return_base
     batch = Batch(
@@ -36,7 +35,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 1, 0.]),
@@ -44,7 +42,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     batch = Batch(
         done=np.array([0, 1, 0, 1, 0, 0, 1.]),
@@ -52,7 +49,6 @@ def test_fn(size=2560):
     )
     batch = fn(batch, None, None)
     ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
-    ans -= ans.mean()
     assert abs(batch.returns - ans).sum() <= 1e-5
     if __name__ == '__main__':
         batch = Batch(
tianshou/env/wrapper.py (46 changed lines)
@@ -143,27 +143,31 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
     parent.close()
     env = env_fn_wrapper.data()
     done = False
-    while True:
-        cmd, data = p.recv()
-        if cmd == 'step':
-            if reset_after_done or not done:
-                obs, rew, done, info = env.step(data)
-            if reset_after_done and done:
-                # s_ is useless when episode finishes
-                obs = env.reset()
-            p.send([obs, rew, done, info])
-        elif cmd == 'reset':
-            done = False
-            p.send(env.reset())
-        elif cmd == 'close':
-            p.close()
-            break
-        elif cmd == 'render':
-            p.send(env.render() if hasattr(env, 'render') else None)
-        elif cmd == 'seed':
-            p.send(env.seed(data) if hasattr(env, 'seed') else None)
-        else:
-            raise NotImplementedError
+    try:
+        while True:
+            cmd, data = p.recv()
+            if cmd == 'step':
+                if reset_after_done or not done:
+                    obs, rew, done, info = env.step(data)
+                if reset_after_done and done:
+                    # s_ is useless when episode finishes
+                    obs = env.reset()
+                p.send([obs, rew, done, info])
+            elif cmd == 'reset':
+                done = False
+                p.send(env.reset())
+            elif cmd == 'close':
+                p.close()
+                break
+            elif cmd == 'render':
+                p.send(env.render() if hasattr(env, 'render') else None)
+            elif cmd == 'seed':
+                p.send(env.seed(data) if hasattr(env, 'seed') else None)
+            else:
+                p.close()
+                raise NotImplementedError
+    except KeyboardInterrupt:
+        p.close()
 
 
 class SubprocVectorEnv(BaseVectorEnv):
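To make the (cmd, data) pipe protocol served by the worker loop above concrete, here is a self-contained sketch. It is not Tianshou code: CounterEnv and toy_worker are made-up names, only the step/reset/close commands are mirrored, and a trivial counter stands in for a gym environment.

# Standalone sketch of the parent side of the pipe protocol.
from multiprocessing import Pipe, Process


class CounterEnv:
    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= 3, {}


def toy_worker(conn):
    env = CounterEnv()
    try:
        while True:
            cmd, data = conn.recv()
            if cmd == 'step':
                obs, rew, done, info = env.step(data)
                if done:
                    obs = env.reset()  # mirrors reset_after_done=True
                conn.send([obs, rew, done, info])
            elif cmd == 'reset':
                conn.send(env.reset())
            elif cmd == 'close':
                conn.close()
                break
    except KeyboardInterrupt:
        conn.close()


if __name__ == '__main__':
    parent, child = Pipe()
    Process(target=toy_worker, args=(child,), daemon=True).start()
    parent.send(['reset', None])
    print(parent.recv())   # 0
    parent.send(['step', 0])
    print(parent.recv())   # [1, 1.0, False, {}]
    parent.send(['close', None])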
| @ -1,9 +1,11 @@ | |||||||
| from tianshou.policy.base import BasePolicy | from tianshou.policy.base import BasePolicy | ||||||
| from tianshou.policy.dqn import DQNPolicy | from tianshou.policy.dqn import DQNPolicy | ||||||
| from tianshou.policy.policy_gradient import PGPolicy | from tianshou.policy.policy_gradient import PGPolicy | ||||||
|  | from tianshou.policy.a2c import A2CPolicy | ||||||
| 
 | 
 | ||||||
| __all__ = [ | __all__ = [ | ||||||
|     'BasePolicy', |     'BasePolicy', | ||||||
|     'DQNPolicy', |     'DQNPolicy', | ||||||
|     'PGPolicy', |     'PGPolicy', | ||||||
|  |     'A2CPolicy', | ||||||
| ] | ] | ||||||
|  | |||||||
							
								
								
									
tianshou/policy/a2c.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+import torch
+import torch.nn.functional as F
+
+from tianshou.data import Batch
+from tianshou.policy import PGPolicy
+
+
+class A2CPolicy(PGPolicy):
+    """Synchronous Advantage Actor-Critic (A2C) policy."""
+
+    def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
+                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01):
+        super().__init__(model, optim, dist_fn, discount_factor)
+        self._w_value = vf_coef
+        self._w_entropy = entropy_coef
+
+    def __call__(self, batch, state=None):
+        logits, value, h = self.model(batch.obs, state=state, info=batch.info)
+        logits = F.softmax(logits, dim=1)
+        dist = self.dist_fn(logits)
+        act = dist.sample().detach().cpu().numpy()
+        return Batch(logits=logits, act=act, state=h, dist=dist, value=value)
+
+    def learn(self, batch, batch_size=None):
+        losses = []
+        for b in batch.split(batch_size):
+            self.optim.zero_grad()
+            result = self(b)
+            dist = result.dist
+            v = result.value.flatten()  # (batch, 1) -> (batch,) to match r, not outer-broadcast
+            a = torch.tensor(b.act, device=dist.logits.device)
+            r = torch.tensor(b.returns, device=dist.logits.device)
+            actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
+            critic_loss = (r - v).pow(2).mean()
+            entropy_loss = dist.entropy().mean()
+            loss = actor_loss \
+                + self._w_value * critic_loss \
+                - self._w_entropy * entropy_loss
+            loss.backward()
+            self.optim.step()
+            losses.append(loss.detach().cpu().numpy())
+        return losses
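For reference, the quantity minimized per minibatch in A2CPolicy.learn above can be written as follows. This is a sketch in assumed standard notation: c_v and c_e stand for vf_coef and entropy_coef, R_t is the discounted return produced by process_fn, V_phi is the critic head, and the advantage R_t - V_phi(s_t) is detached in the policy term.

\mathcal{L} = -\frac{1}{N}\sum_{t=1}^{N} \log \pi_\theta(a_t \mid s_t)\,\big(R_t - V_\phi(s_t)\big)
            + c_v \, \frac{1}{N}\sum_{t=1}^{N} \big(R_t - V_\phi(s_t)\big)^2
            - c_e \, \frac{1}{N}\sum_{t=1}^{N} \mathcal{H}\!\big[\pi_\theta(\cdot \mid s_t)\big]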
tianshou/policy/policy_gradient.py
@@ -11,7 +11,7 @@ class PGPolicy(BasePolicy, nn.Module):
     """docstring for PGPolicy"""
 
     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
-                 discount_factor=0.99, normalized_reward=True):
+                 discount_factor=0.99):
         super().__init__()
         self.model = model
         self.optim = optim
@@ -19,15 +19,10 @@ class PGPolicy(BasePolicy, nn.Module):
         self._eps = np.finfo(np.float32).eps.item()
         assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
         self._gamma = discount_factor
-        self._rew_norm = normalized_reward
 
     def process_fn(self, batch, buffer, indice):
-        batch_size = len(batch.rew)
-        returns = self._vanilla_returns(batch, batch_size)
-        # returns = self._vectorized_returns(batch, batch_size)
-        returns = returns - returns.mean()
-        if self._rew_norm:
-            returns = returns / (returns.std() + self._eps)
+        returns = self._vanilla_returns(batch)
+        # returns = self._vectorized_returns(batch)
         batch.update(returns=returns)
         return batch
 
@@ -40,6 +35,8 @@ class PGPolicy(BasePolicy, nn.Module):
 
     def learn(self, batch, batch_size=None):
         losses = []
+        batch.returns = (batch.returns - batch.returns.mean()) \
+            / (batch.returns.std() + self._eps)
         for b in batch.split(batch_size):
             self.optim.zero_grad()
             dist = self(b).dist
@@ -51,21 +48,22 @@ class PGPolicy(BasePolicy, nn.Module):
             losses.append(loss.detach().cpu().numpy())
         return losses
 
-    def _vanilla_returns(self, batch, batch_size):
+    def _vanilla_returns(self, batch):
         returns = batch.rew[:]
         last = 0
-        for i in range(batch_size - 1, -1, -1):
+        for i in range(len(returns) - 1, -1, -1):
             if not batch.done[i]:
                 returns[i] += self._gamma * last
             last = returns[i]
         return returns
 
-    def _vectorized_returns(self, batch, batch_size):
+    def _vectorized_returns(self, batch):
         # according to my tests, it is slower than vanilla
         # import scipy.signal
         convolve = np.convolve
         # convolve = scipy.signal.convolve
         rew = batch.rew[::-1]
+        batch_size = len(rew)
         gammas = self._gamma ** np.arange(batch_size)
         c = convolve(rew, gammas)[:batch_size]
         T = np.where(batch.done[::-1])[0]
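As a quick sanity check of the per-episode discounted return that _vanilla_returns now derives from len(returns) instead of a passed-in batch_size, here is a standalone sketch. It is not Tianshou code, and the reward and done values are invented for illustration.

import numpy as np


def vanilla_returns(rew, done, gamma=0.9):
    # Walk backwards; the running return is only propagated across steps
    # that are not episode boundaries (done == False).
    returns = rew.astype(float)
    last = 0.0
    for i in range(len(returns) - 1, -1, -1):
        if not done[i]:
            returns[i] += gamma * last
        last = returns[i]
    return returns


# Three steps of reward 1 with the episode ending at the last step:
# expected [1 + 0.9 * (1 + 0.9 * 1), 1 + 0.9 * 1, 1] = [2.71, 1.9, 1.0]
print(vanilla_returns(np.array([1.0, 1.0, 1.0]), np.array([0, 0, 1])))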