fix ddpg: drop optimizer weight decay, add per-batch reward normalization, and stop on env.spec.reward_threshold
This commit is contained in:
parent 8bd8246b16
commit c173f7bfbc
@@ -22,18 +22,16 @@ def get_args():
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-4)
-    parser.add_argument('--actor-wd', type=float, default=0)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
-    parser.add_argument('--critic-wd', type=float, default=1e-2)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--exploration-noise', type=float, default=0.1)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
-    parser.add_argument('--collect-per-step', type=int, default=1)
+    parser.add_argument('--collect-per-step', type=int, default=4)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--layer-num', type=int, default=1)
-    parser.add_argument('--training-num', type=int, default=1)
+    parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument(
@@ -45,6 +43,8 @@ def get_args():
 
 def test_ddpg(args=get_args()):
     env = gym.make(args.task)
+    if args.task == 'Pendulum-v0':
+        env.spec.reward_threshold = -250
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     args.max_action = env.action_space.high[0]
@@ -66,17 +66,16 @@ def test_ddpg(args=get_args()):
         args.layer_num, args.state_shape, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
-    actor_optim = torch.optim.Adam(
-        actor.parameters(), lr=args.actor_lr, weight_decay=args.actor_wd)
+    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic = Critic(
         args.layer_num, args.state_shape, args.action_shape, args.device
     ).to(args.device)
-    critic_optim = torch.optim.Adam(
-        critic.parameters(), lr=args.critic_lr, weight_decay=args.critic_wd)
+    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
         args.tau, args.gamma, args.exploration_noise,
-        [env.action_space.low[0], env.action_space.high[0]])
+        [env.action_space.low[0], env.action_space.high[0]],
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size), 1)
@@ -85,10 +84,7 @@ def test_ddpg(args=get_args()):
     writer = SummaryWriter(args.logdir)
 
     def stop_fn(x):
-        if args.task == 'Pendulum-v0':
-            return x >= -250
-        else:
-            return False
+        return x >= env.spec.reward_threshold
 
     # trainer
     result = offpolicy_trainer(
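
These two hunks work together: the first gives Pendulum-v0 an explicit reward threshold (-250) on its spec, so stop_fn no longer needs a task-specific branch and can compare against env.spec.reward_threshold directly. A minimal sketch of that behaviour, assuming a gym version where the spec attribute is writable (the same assumption the test makes); the None check is added here only for illustration:

import gym

env = gym.make('Pendulum-v0')
if env.spec.reward_threshold is None:    # Pendulum-v0 ships without a threshold
    env.spec.reward_threshold = -250     # same value the test sets

def stop_fn(mean_reward):
    # mean_reward stands for the reward statistic the trainer reports
    return mean_reward >= env.spec.reward_threshold

print(stop_fn(-400), stop_fn(-100))      # False True

The remaining hunks change the DDPGPolicy implementation itself.
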
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from copy import deepcopy
 import torch.nn.functional as F
 
@@ -12,7 +13,7 @@ class DDPGPolicy(BasePolicy):
 
     def __init__(self, actor, actor_optim, critic, critic_optim,
                  tau=0.005, gamma=0.99, exploration_noise=0.1,
-                 action_range=None):
+                 action_range=None, reward_normalization=True):
         super().__init__()
         self.actor, self.actor_old = actor, deepcopy(actor)
         self.actor_old.eval()
@@ -28,6 +29,8 @@ class DDPGPolicy(BasePolicy):
         self._eps = exploration_noise
         self._range = action_range
         # self.noise = OUNoise()
+        self._rew_norm = reward_normalization
+        self.__eps = np.finfo(np.float32).eps.item()
 
     def set_eps(self, eps):
         self._eps = eps
@@ -42,6 +45,9 @@ class DDPGPolicy(BasePolicy):
         self.actor.eval()
         self.critic.eval()
 
+    def process_fn(self, batch, buffer, indice):
+        return batch
+
     def sync_weight(self):
         for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
@@ -54,12 +60,12 @@ class DDPGPolicy(BasePolicy):
         model = getattr(self, model)
         obs = getattr(batch, input)
         logits, h = model(obs, state=state, info=batch.info)
-        # noise = np.random.normal(0, self._eps, size=logits.shape)
         if eps is None:
             eps = self._eps
-        logits += torch.randn(size=logits.shape, device=logits.device) * eps
-        # noise = self.noise(logits.shape, self._eps)
+        # noise = np.random.normal(0, eps, size=logits.shape)
+        # noise = self.noise(logits.shape, eps)
         # logits += torch.tensor(noise, device=logits.device)
+        logits += torch.randn(size=logits.shape, device=logits.device) * eps
         if self._range:
             logits = logits.clamp(self._range[0], self._range[1])
         return Batch(act=logits, state=h)
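
The reshuffled block above keeps forward()'s behaviour: additive Gaussian noise scaled by eps, followed by a clamp to the action range; only the commented-out OU-noise lines are regrouped and switched to the local eps. A standalone sketch of that noise step (the tensors and the ±2.0 bounds for Pendulum-v0 below are illustrative values, not taken from the commit):

import torch

act = torch.tensor([[0.9], [-1.95]])   # raw actor outputs
eps = 0.1                              # exploration_noise
low, high = -2.0, 2.0                  # Pendulum-v0 action bounds

act = act + torch.randn(size=act.shape) * eps  # additive Gaussian exploration noise
act = act.clamp(low, high)                     # keep actions inside the range
print(act)
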
@@ -68,10 +74,11 @@ class DDPGPolicy(BasePolicy):
         target_q = self.critic_old(batch.obs_next, self(
             batch, model='actor_old', input='obs_next', eps=0).act)
         dev = target_q.device
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)
-        target_q = rew[:, None] + ((
-            1. - done[:, None]) * self._gamma * target_q).detach()
+        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
+        if self._rew_norm:
+            rew = (rew - rew.mean()) / (rew.std() + self.__eps)
+        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
+        target_q = rew + ((1. - done) * self._gamma * target_q).detach()
         current_q = self.critic(batch.obs, batch.act)
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
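
This last hunk is the core of the fix: rewards and done flags are reshaped to (batch, 1) up front, and, when reward_normalization is on, rewards are standardized within the batch before the Bellman target is built. A small sketch with made-up numbers (gamma and all tensors below are illustrative, not taken from the commit):

import numpy as np
import torch

gamma = 0.99
rew = torch.tensor([-1.0, -5.0, -0.5, -9.0])[:, None]              # shape (batch, 1)
done = torch.tensor([0.0, 0.0, 1.0, 0.0])[:, None]                 # shape (batch, 1)
target_q_next = torch.tensor([[-10.0], [-12.0], [-8.0], [-15.0]])  # stand-in for critic_old's output

eps = np.finfo(np.float32).eps.item()         # guards against division by a zero std
rew = (rew - rew.mean()) / (rew.std() + eps)  # zero-mean, unit-std within the batch
target_q = rew + (1.0 - done) * gamma * target_q_next
print(target_q.squeeze(1))                    # one target value per transition
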