From c173f7bfbc7c13e3e145b610bcd2927e13f5c3ee Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Sat, 21 Mar 2020 15:31:31 +0800
Subject: [PATCH] fix ddpg

---
 test/continuous/test_ddpg.py | 22 +++++++++-------------
 tianshou/policy/ddpg.py      | 23 +++++++++++++++--------
 2 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index 4fcea0c..7b3bfbb 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -22,18 +22,16 @@ def get_args():
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-4)
-    parser.add_argument('--actor-wd', type=float, default=0)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
-    parser.add_argument('--critic-wd', type=float, default=1e-2)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--exploration-noise', type=float, default=0.1)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=2400)
-    parser.add_argument('--collect-per-step', type=int, default=1)
+    parser.add_argument('--collect-per-step', type=int, default=4)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--layer-num', type=int, default=1)
-    parser.add_argument('--training-num', type=int, default=1)
+    parser.add_argument('--training-num', type=int, default=8)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument(
@@ -45,6 +43,8 @@ def get_args():
 
 def test_ddpg(args=get_args()):
     env = gym.make(args.task)
+    if args.task == 'Pendulum-v0':
+        env.spec.reward_threshold = -250
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     args.max_action = env.action_space.high[0]
@@ -66,17 +66,16 @@ def test_ddpg(args=get_args()):
         args.layer_num, args.state_shape, args.action_shape,
         args.max_action, args.device
     ).to(args.device)
-    actor_optim = torch.optim.Adam(
-        actor.parameters(), lr=args.actor_lr, weight_decay=args.actor_wd)
+    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     critic = Critic(
         args.layer_num, args.state_shape, args.action_shape, args.device
     ).to(args.device)
-    critic_optim = torch.optim.Adam(
-        critic.parameters(), lr=args.critic_lr, weight_decay=args.critic_wd)
+    critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
     policy = DDPGPolicy(
         actor, actor_optim, critic, critic_optim,
         args.tau, args.gamma, args.exploration_noise,
-        [env.action_space.low[0], env.action_space.high[0]])
+        [env.action_space.low[0], env.action_space.high[0]],
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size), 1)
@@ -85,10 +84,7 @@ def test_ddpg(args=get_args()):
     writer = SummaryWriter(args.logdir)
 
     def stop_fn(x):
-        if args.task == 'Pendulum-v0':
-            return x >= -250
-        else:
-            return False
+        return x >= env.spec.reward_threshold
 
     # trainer
     result = offpolicy_trainer(
diff --git a/tianshou/policy/ddpg.py b/tianshou/policy/ddpg.py
index cf8002f..72dc633 100644
--- a/tianshou/policy/ddpg.py
+++ b/tianshou/policy/ddpg.py
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from copy import deepcopy
 import torch.nn.functional as F
 
@@ -12,7 +13,7 @@ class DDPGPolicy(BasePolicy):
 
     def __init__(self, actor, actor_optim, critic, critic_optim,
                  tau=0.005, gamma=0.99, exploration_noise=0.1,
-                 action_range=None):
+                 action_range=None, reward_normalization=True):
         super().__init__()
         self.actor, self.actor_old = actor, deepcopy(actor)
         self.actor_old.eval()
@@ -28,6 +29,8 @@ class DDPGPolicy(BasePolicy):
         self._eps = exploration_noise
         self._range = action_range
         # self.noise = OUNoise()
+        self._rew_norm = reward_normalization
+        self.__eps = np.finfo(np.float32).eps.item()
 
     def set_eps(self, eps):
         self._eps = eps
@@ -42,6 +45,9 @@ class DDPGPolicy(BasePolicy):
         self.actor.eval()
         self.critic.eval()
 
+    def process_fn(self, batch, buffer, indice):
+        return batch
+
     def sync_weight(self):
         for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
             o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
@@ -54,12 +60,12 @@ class DDPGPolicy(BasePolicy):
         model = getattr(self, model)
         obs = getattr(batch, input)
         logits, h = model(obs, state=state, info=batch.info)
-        # noise = np.random.normal(0, self._eps, size=logits.shape)
         if eps is None:
             eps = self._eps
-        logits += torch.randn(size=logits.shape, device=logits.device) * eps
-        # noise = self.noise(logits.shape, self._eps)
+        # noise = np.random.normal(0, eps, size=logits.shape)
+        # noise = self.noise(logits.shape, eps)
         # logits += torch.tensor(noise, device=logits.device)
+        logits += torch.randn(size=logits.shape, device=logits.device) * eps
         if self._range:
             logits = logits.clamp(self._range[0], self._range[1])
         return Batch(act=logits, state=h)
@@ -68,10 +74,11 @@ class DDPGPolicy(BasePolicy):
         target_q = self.critic_old(batch.obs_next, self(
             batch, model='actor_old', input='obs_next', eps=0).act)
         dev = target_q.device
-        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)
-        done = torch.tensor(batch.done, dtype=torch.float, device=dev)
-        target_q = rew[:, None] + ((
-            1. - done[:, None]) * self._gamma * target_q).detach()
+        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
+        if self._rew_norm:
+            rew = (rew - rew.mean()) / (rew.std() + self.__eps)
+        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
+        target_q = rew + ((1. - done) * self._gamma * target_q).detach()
         current_q = self.critic(batch.obs, batch.act)
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
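
Note for reviewers: the core of this fix is the reward-normalized TD target built in DDPGPolicy.learn. A minimal standalone sketch of that computation (the helper name and free-function signature are illustrative, not part of the patch):

import numpy as np
import torch


def ddpg_critic_target(rew, done, target_q, gamma=0.99, rew_norm=True):
    # rew, done: 1-D numpy arrays sampled from the replay buffer
    # target_q: (batch_size, 1) tensor from critic_old(obs_next, actor_old(obs_next))
    dev = target_q.device
    eps = np.finfo(np.float32).eps.item()
    rew = torch.tensor(rew, dtype=torch.float, device=dev)[:, None]
    if rew_norm:
        # normalize rewards within the sampled batch; eps guards against std == 0
        rew = (rew - rew.mean()) / (rew.std() + eps)
    done = torch.tensor(done, dtype=torch.float, device=dev)[:, None]
    # bootstrap only on non-terminal transitions; detach so no gradient
    # flows back through the target networks
    return rew + ((1. - done) * gamma * target_q).detach()

The critic is then regressed onto this target with F.mse_loss(current_q, target_q), exactly as in the last hunk above.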
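
The reordered noise block in forward reduces to adding zero-mean Gaussian noise scaled by eps and clamping to the action range. A sketch under the same caveat (the helper is illustrative; the (-2.0, 2.0) default mirrors Pendulum-v0's action bounds used by the test):

import torch


def explore(logits, eps=0.1, action_range=(-2.0, 2.0)):
    # zero-mean Gaussian exploration noise scaled by eps; passing eps=0
    # recovers the deterministic action used for the target-Q computation
    logits = logits + torch.randn(size=logits.shape, device=logits.device) * eps
    # clip to the environment's action bounds, as DDPGPolicy does with
    # [env.action_space.low[0], env.action_space.high[0]]
    return logits.clamp(*action_range)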