fix ddpg
parent 8bd8246b16
commit c173f7bfbc
@@ -22,18 +22,16 @@ def get_args():
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-4)
-    parser.add_argument('--actor-wd', type=float, default=0)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
-    parser.add_argument('--critic-wd', type=float, default=1e-2)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--exploration-noise', type=float, default=0.1)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
-    parser.add_argument('--collect-per-step', type=int, default=1)
+    parser.add_argument('--collect-per-step', type=int, default=4)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--layer-num', type=int, default=1)
-    parser.add_argument('--training-num', type=int, default=1)
+    parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument(
@@ -45,6 +43,8 @@ def get_args():
 
 def test_ddpg(args=get_args()):
     env = gym.make(args.task)
+    if args.task == 'Pendulum-v0':
+        env.spec.reward_threshold = -250
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     args.max_action = env.action_space.high[0]
@@ -66,17 +66,16 @@ def test_ddpg(args=get_args()):
         args.layer_num, args.state_shape, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
-    actor_optim = torch.optim.Adam(
-        actor.parameters(), lr=args.actor_lr, weight_decay=args.actor_wd)
+    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic = Critic(
         args.layer_num, args.state_shape, args.action_shape, args.device
     ).to(args.device)
-    critic_optim = torch.optim.Adam(
-        critic.parameters(), lr=args.critic_lr, weight_decay=args.critic_wd)
+    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
         args.tau, args.gamma, args.exploration_noise,
-        [env.action_space.low[0], env.action_space.high[0]])
+        [env.action_space.low[0], env.action_space.high[0]],
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size), 1)
@@ -85,10 +84,7 @@ def test_ddpg(args=get_args()):
     writer = SummaryWriter(args.logdir)
 
     def stop_fn(x):
-        if args.task == 'Pendulum-v0':
-            return x >= -250
-        else:
-            return False
+        return x >= env.spec.reward_threshold
 
     # trainer
     result = offpolicy_trainer(
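The test-script hunks above tie the stopping criterion to `env.spec.reward_threshold` (patched to -250 for Pendulum-v0), enable `reward_normalization`, and drop the weight-decay options now that the optimizers no longer use them. A minimal, standalone sketch of the resulting stop condition, with a made-up mean reward value:

```python
import gym

env = gym.make('Pendulum-v0')
# Pendulum-v0 defines no reward threshold, so the test patches one in.
env.spec.reward_threshold = -250


def stop_fn(x):
    # x is the mean test-episode reward reported by the trainer
    return x >= env.spec.reward_threshold


print(stop_fn(-240.0))  # True  -> training stops
print(stop_fn(-600.0))  # False -> keep training
```

The remaining hunks change `DDPGPolicy` itself.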
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from copy import deepcopy
 import torch.nn.functional as F
 
@@ -12,7 +13,7 @@ class DDPGPolicy(BasePolicy):
 
     def __init__(self, actor, actor_optim, critic, critic_optim,
                  tau=0.005, gamma=0.99, exploration_noise=0.1,
-                 action_range=None):
+                 action_range=None, reward_normalization=True):
        super().__init__()
         self.actor, self.actor_old = actor, deepcopy(actor)
         self.actor_old.eval()
@@ -28,6 +29,8 @@ class DDPGPolicy(BasePolicy):
         self._eps = exploration_noise
         self._range = action_range
         # self.noise = OUNoise()
+        self._rew_norm = reward_normalization
+        self.__eps = np.finfo(np.float32).eps.item()
 
     def set_eps(self, eps):
         self._eps = eps
@@ -42,6 +45,9 @@ class DDPGPolicy(BasePolicy):
         self.actor.eval()
         self.critic.eval()
 
+    def process_fn(self, batch, buffer, indice):
+        return batch
+
     def sync_weight(self):
         for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
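For context, `sync_weight` above applies a Polyak (soft) update: with the default `tau=0.005`, each call nudges a target parameter 0.5% of the way toward its online counterpart. A tiny numeric sketch of that rule on plain tensors (made-up values, outside the policy class):

```python
import torch

tau = 0.005
o = torch.tensor([1.0, 2.0])  # target-network parameter
n = torch.tensor([3.0, 4.0])  # online-network parameter

# o <- o * (1 - tau) + n * tau, as in sync_weight
o.copy_(o * (1 - tau) + n * tau)
print(o)  # tensor([1.0100, 2.0100])
```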
@@ -54,12 +60,12 @@
         model = getattr(self, model)
         obs = getattr(batch, input)
         logits, h = model(obs, state=state, info=batch.info)
-        # noise = np.random.normal(0, self._eps, size=logits.shape)
         if eps is None:
             eps = self._eps
-        logits += torch.randn(size=logits.shape, device=logits.device) * eps
-        # noise = self.noise(logits.shape, self._eps)
+        # noise = np.random.normal(0, eps, size=logits.shape)
+        # noise = self.noise(logits.shape, eps)
         # logits += torch.tensor(noise, device=logits.device)
+        logits += torch.randn(size=logits.shape, device=logits.device) * eps
         if self._range:
             logits = logits.clamp(self._range[0], self._range[1])
         return Batch(act=logits, state=h)
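This hunk only reshuffles the exploration path: the commented-out alternatives now refer to the local `eps`, and the live Gaussian-noise line moves below them. The behaviour stays the same: add zero-mean Gaussian noise scaled by `eps` to the deterministic action, then clamp to the action range. A standalone sketch with made-up values:

```python
import torch

eps = 0.1                   # exploration-noise scale (see --exploration-noise)
action_range = (-2.0, 2.0)  # e.g. Pendulum-v0 action bounds

logits = torch.tensor([[1.95], [-0.30]])  # deterministic actor output
logits += torch.randn(size=logits.shape, device=logits.device) * eps
logits = logits.clamp(action_range[0], action_range[1])  # keep actions valid
print(logits)
```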
@@ -68,10 +74,11 @@
         target_q = self.critic_old(batch.obs_next, self(
             batch, model='actor_old', input='obs_next', eps=0).act)
         dev = target_q.device
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)
-        target_q = rew[:, None] + ((
-            1. - done[:, None]) * self._gamma * target_q).detach()
+        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
+        if self._rew_norm:
+            rew = (rew - rew.mean()) / (rew.std() + self.__eps)
+        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
+        target_q = rew + ((1. - done) * self._gamma * target_q).detach()
         current_q = self.critic(batch.obs, batch.act)
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
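This last hunk is the substantive fix: with `reward_normalization` enabled, rewards are standardized per sampled batch (zero mean, unit variance, with `np.finfo(np.float32).eps` guarding against a zero std) before building the one-step TD target r + (1 - done) * gamma * Q_old(s', a'). A minimal sketch of that computation with made-up batch values:

```python
import numpy as np
import torch

gamma = 0.99
eps = np.finfo(np.float32).eps.item()  # guards against division by zero

rew = torch.tensor([-1.0, -8.0, -3.0])[:, None]         # batch rewards
done = torch.tensor([0.0, 0.0, 1.0])[:, None]           # episode-end flags
next_q = torch.tensor([-50.0, -40.0, -60.0])[:, None]   # critic_old(s', actor_old(s'))

# per-batch reward standardization (the reward_normalization branch)
rew = (rew - rew.mean()) / (rew.std() + eps)

# one-step TD target; terminal transitions drop the bootstrapped term
target_q = rew + (1.0 - done) * gamma * next_q
print(target_q)
```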