From 3108b9db0d1eecb7cef37774ea85b6080edec7a2 Mon Sep 17 00:00:00 2001
From: ChenDRAG <40993476+ChenDRAG@users.noreply.github.com>
Date: Fri, 26 Feb 2021 13:23:18 +0800
Subject: [PATCH] Add Timelimit trick to optimize policies (#296)

* consider timelimit.truncated in calculating returns by default

* remove ignore_done
---
 examples/box2d/mcc_sac.py                     |  2 +-
 examples/mujoco/runnable/ant_v2_ddpg.py       |  2 +-
 examples/mujoco/runnable/ant_v2_td3.py        |  2 +-
 .../runnable/halfcheetahBullet_v0_sac.py      |  2 +-
 examples/mujoco/runnable/point_maze_td3.py    |  2 +-
 test/base/test_returns.py                     | 79 ++++++++++++++++++-
 test/continuous/test_ddpg.py                  |  1 -
 test/continuous/test_sac_with_il.py           |  1 -
 test/continuous/test_td3.py                   |  1 -
 test/discrete/test_c51.py                     |  4 +-
 test/discrete/test_sac.py                     |  3 +-
 tianshou/policy/base.py                       |  9 ++-
 tianshou/policy/modelfree/ddpg.py             |  6 --
 tianshou/policy/modelfree/discrete_sac.py     |  5 +-
 tianshou/policy/modelfree/sac.py              |  5 +-
 tianshou/policy/modelfree/td3.py              |  9 +--
 tianshou/utils/log_tools.py                   |  2 +-
 17 files changed, 96 insertions(+), 39 deletions(-)

diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py
index 7fe3dae..f22c784 100644
--- a/examples/box2d/mcc_sac.py
+++ b/examples/box2d/mcc_sac.py
@@ -92,7 +92,7 @@ def test_sac(args=get_args()):
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        reward_normalization=args.rew_norm, ignore_done=True,
+        reward_normalization=args.rew_norm,
         exploration_noise=OUNoise(0.0, args.noise_std))
     # collector
     train_collector = Collector(
diff --git a/examples/mujoco/runnable/ant_v2_ddpg.py b/examples/mujoco/runnable/ant_v2_ddpg.py
index bc75e1f..53e9ac4 100644
--- a/examples/mujoco/runnable/ant_v2_ddpg.py
+++ b/examples/mujoco/runnable/ant_v2_ddpg.py
@@ -73,7 +73,7 @@ def test_ddpg(args=get_args()):
         action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
diff --git a/examples/mujoco/runnable/ant_v2_td3.py b/examples/mujoco/runnable/ant_v2_td3.py
index 004b604..cbbd952 100644
--- a/examples/mujoco/runnable/ant_v2_td3.py
+++ b/examples/mujoco/runnable/ant_v2_td3.py
@@ -81,7 +81,7 @@ def test_td3(args=get_args()):
         policy_noise=args.policy_noise,
         update_actor_freq=args.update_actor_freq,
         noise_clip=args.noise_clip,
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
diff --git a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py
index 2ed0462..db0ce6e 100644
--- a/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py
+++ b/examples/mujoco/runnable/halfcheetahBullet_v0_sac.py
@@ -81,7 +81,7 @@ def test_sac(args=get_args()):
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
diff --git a/examples/mujoco/runnable/point_maze_td3.py b/examples/mujoco/runnable/point_maze_td3.py
index 8e5f37b..ed2ce0e 100644
--- a/examples/mujoco/runnable/point_maze_td3.py
+++ b/examples/mujoco/runnable/point_maze_td3.py
@@ -86,7 +86,7 @@ def test_td3(args=get_args()):
         policy_noise=args.policy_noise,
         update_actor_freq=args.update_actor_freq,
         noise_clip=args.noise_clip,
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
diff --git a/test/base/test_returns.py b/test/base/test_returns.py
index 5aba7f5..e8d70de 100644
--- a/test/base/test_returns.py
+++ b/test/base/test_returns.py
@@ -24,6 +24,8 @@ def test_episodic_returns(size=2560):
     batch = Batch(
         done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
         rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
+        info=Batch({'TimeLimit.truncated':
+                    np.array([False, False, False, False, False, True, False, False])})
     )
     for b in batch:
         b.obs = b.act = 1
@@ -69,6 +71,24 @@ def test_episodic_returns(size=2560):
                         474.2876, 390.1027, 299.476, 202.])
     assert np.allclose(ret.returns, returns)
     buf.reset()
+    batch = Batch(
+        done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
+        rew=np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]),
+        info=Batch({'TimeLimit.truncated':
+                    np.array([False, False, False, True, False, False,
+                              False, True, False, False, False, False])})
+    )
+    for b in batch:
+        b.obs = b.act = 1
+        buf.add(b)
+    v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
+    ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
+    returns = np.array([
+        454.0109, 375.2386, 290.3669, 199.01,
+        462.9138, 381.3571, 293.5248, 199.02,
+        474.2876, 390.1027, 299.476, 202.])
+    assert np.allclose(ret.returns, returns)
+
     if __name__ == '__main__':
         buf = ReplayBuffer(size)
         batch = Batch(
@@ -106,15 +126,19 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice):
     buf_len = len(buffer)
     for i in range(len(indice)):
         flag, r = False, 0.
+        real_step_n = nstep
         for n in range(nstep):
             idx = (indice[i] + n) % buf_len
             r += buffer.rew[idx] * gamma ** n
             if buffer.done[idx]:
-                flag = True
+                if not (hasattr(buffer, 'info') and
+                        buffer.info['TimeLimit.truncated'][idx]):
+                    flag = True
+                real_step_n = n + 1
                 break
         if not flag:
-            idx = (indice[i] + nstep - 1) % buf_len
-            r += to_numpy(target_q_fn(buffer, idx)) * gamma ** nstep
+            idx = (indice[i] + real_step_n - 1) % buf_len
+            r += to_numpy(target_q_fn(buffer, idx)) * gamma ** real_step_n
         returns[i] = r
     return returns

@@ -161,10 +185,56 @@ def test_nstep_returns(size=10000):
     ).pop('returns'))
     assert np.allclose(returns_multidim, returns[:, np.newaxis])

+
+def test_nstep_returns_with_timelimit(size=10000):
+    buf = ReplayBuffer(10)
+    for i in range(12):
+        buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3,
+                      info={"TimeLimit.truncated": i == 3}))
+    batch, indice = buf.sample(0)
+    assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
+    # rew:  [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
+    # done: [ 0,  1, 0, 1, 0, 0, 0, 1, 0, 0]
+    # test nstep = 1
+    returns = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn, gamma=.1, n_step=1
+    ).pop('returns').reshape(-1))
+    assert np.allclose(returns, [2.6, 3.6, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
+    r_ = compute_nstep_return_base(1, .1, buf, indice)
+    assert np.allclose(returns, r_), (r_, returns)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
+    # test nstep = 2
+    returns = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn, gamma=.1, n_step=2
+    ).pop('returns').reshape(-1))
+    assert np.allclose(returns, [3.36, 3.6, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
+    r_ = compute_nstep_return_base(2, .1, buf, indice)
+    assert np.allclose(returns, r_)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
+    # test nstep = 10
+    returns = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn, gamma=.1, n_step=10
+    ).pop('returns').reshape(-1))
+    assert np.allclose(returns, [3.36, 3.6, 5.678, 6.78,
+                                 7.8, 8, 10.122, 11.22, 12.2, 12])
+    r_ = compute_nstep_return_base(10, .1, buf, indice)
+    assert np.allclose(returns, r_)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
+
     if __name__ == '__main__':
         buf = ReplayBuffer(size)
         for i in range(int(size * 1.5)):
-            buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0))
+            buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0,
+                          info={"TimeLimit.truncated": i % 33 == 0}))
         batch, indice = buf.sample(256)

     def vanilla():
@@ -181,4 +251,5 @@ def test_nstep_returns(size=10000):

 if __name__ == '__main__':
     test_nstep_returns()
+    test_nstep_returns_with_timelimit()
     test_episodic_returns()
diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index 311aa65..232eef1 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -83,7 +83,6 @@ def test_ddpg(args=get_args()):
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         reward_normalization=args.rew_norm,
-        ignore_done=args.ignore_done,
         estimation_step=args.n_step)
     # collector
     train_collector = Collector(
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index 1b4a977..8d18428 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -91,7 +91,6 @@ def test_sac_with_il(args=get_args()):
         action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
-        ignore_done=args.ignore_done,
         estimation_step=args.n_step)
     # collector
     train_collector = Collector(
diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py
index d678181..c24741c 100644
--- a/test/continuous/test_td3.py
+++ b/test/continuous/test_td3.py
@@ -95,7 +95,6 @@ def test_td3(args=get_args()):
         update_actor_freq=args.update_actor_freq,
         noise_clip=args.noise_clip,
         reward_normalization=args.rew_norm,
-        ignore_done=args.ignore_done,
         estimation_step=args.n_step)
     # collector
     train_collector = Collector(
diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py
index 53768da..1d0c4cc 100644
--- a/test/discrete/test_c51.py
+++ b/test/discrete/test_c51.py
@@ -116,8 +116,8 @@ def test_c51(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
-        args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+        args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
+        test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])

     if __name__ == '__main__':
diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py
index d7f408f..b5871f6 100644
--- a/test/discrete/test_sac.py
+++ b/test/discrete/test_sac.py
@@ -88,8 +88,7 @@ def test_discrete_sac(args=get_args()):
     policy = DiscreteSACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         args.tau, args.gamma, args.alpha,
-        reward_normalization=args.rew_norm,
-        ignore_done=args.ignore_done)
+        reward_normalization=args.rew_norm)
     # collector
     train_collector = Collector(
         policy, train_envs,
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 99a2f8a..cf2678f 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -203,7 +203,12 @@ class BasePolicy(ABC, nn.Module):
         :return: A bool type numpy.ndarray in the same shape with indice.
             "True" means "obs_next" of that buffer[indice] is valid.
         """
-        return ~buffer.done[indice].astype(np.bool)
+        mask = ~buffer.done[indice].astype(np.bool)
+        # info['TimeLimit.truncated'] will be set to True if the 'done' flag is
+        # generated by the environment's time limit (see gym.wrappers.TimeLimit).
+        if hasattr(buffer, 'info') and 'TimeLimit.truncated' in buffer.info:
+            mask = mask | buffer.info['TimeLimit.truncated'][indice]
+        return mask

     @staticmethod
     def compute_episodic_return(
@@ -377,7 +382,7 @@ def _nstep_return(
     gammas = np.full(indices[0].shape, n_step)
     for n in range(n_step - 1, -1, -1):
         now = indices[n]
-        gammas[end_flag[now] > 0] = n
+        gammas[end_flag[now] > 0] = n + 1
         returns[end_flag[now] > 0] = 0.0
         returns = (rew[now].reshape(bsz, 1) - mean) / std + gamma * returns
     target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index efa9fb7..d91359a 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -26,8 +26,6 @@ class DDPGPolicy(BasePolicy):
         add to the action, defaults to ``GaussianNoise(sigma=0.1)``.
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         defaults to False.
-    :param bool ignore_done: ignore the done flag while training the policy,
-        defaults to False.
     :param int estimation_step: greater than 1, the number of steps to look
         ahead.

@@ -48,7 +46,6 @@ class DDPGPolicy(BasePolicy):
         gamma: float = 0.99,
         exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
         reward_normalization: bool = False,
-        ignore_done: bool = False,
         estimation_step: int = 1,
         **kwargs: Any,
     ) -> None:
@@ -73,7 +70,6 @@ class DDPGPolicy(BasePolicy):
         self._action_scale = (action_range[1] - action_range[0]) / 2.0
         # it is only a little difference to use GaussianNoise
         # self.noise = OUNoise()
-        self._rm_done = ignore_done
         self._rew_norm = reward_normalization
         assert estimation_step > 0, "estimation_step should be greater than 0"
         self._n_step = estimation_step
@@ -110,8 +106,6 @@ class DDPGPolicy(BasePolicy):
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        if self._rm_done:
-            batch.done = batch.done * 0.0
         batch = self.compute_nstep_return(
             batch, buffer, indice, self._target_q, self._gamma,
             self._n_step, self._rew_norm)
diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py
index 9c46fc4..fd67d47 100644
--- a/tianshou/policy/modelfree/discrete_sac.py
+++ b/tianshou/policy/modelfree/discrete_sac.py
@@ -28,8 +28,6 @@ class DiscreteSACPolicy(SACPolicy):
         alpha is automatatically tuned.
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         defaults to ``False``.
-    :param bool ignore_done: ignore the done flag while training the policy,
-        defaults to ``False``.

     .. seealso::

@@ -51,13 +49,12 @@ class DiscreteSACPolicy(SACPolicy):
             float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
         ] = 0.2,
         reward_normalization: bool = False,
-        ignore_done: bool = False,
         estimation_step: int = 1,
         **kwargs: Any,
     ) -> None:
         super().__init__(actor, actor_optim, critic1, critic1_optim, critic2,
                          critic2_optim, (-np.inf, np.inf), tau, gamma, alpha,
-                         reward_normalization, ignore_done, estimation_step,
+                         reward_normalization, estimation_step,
                          **kwargs)
         self._alpha: Union[float, torch.Tensor]
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 091d124..cb53fad 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -34,8 +34,6 @@ class SACPolicy(DDPGPolicy):
         alpha is automatatically tuned.
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         defaults to False.
-    :param bool ignore_done: ignore the done flag while training the policy,
-        defaults to False.
     :param BaseNoise exploration_noise: add a noise to action for exploration,
         defaults to None. This is useful when solving hard-exploration problem.
     :param bool deterministic_eval: whether to use deterministic action (mean
@@ -63,14 +61,13 @@ class SACPolicy(DDPGPolicy):
             float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
         ] = 0.2,
         reward_normalization: bool = False,
-        ignore_done: bool = False,
         estimation_step: int = 1,
         exploration_noise: Optional[BaseNoise] = None,
         deterministic_eval: bool = True,
         **kwargs: Any,
     ) -> None:
         super().__init__(None, None, None, None, action_range, tau, gamma,
-                         exploration_noise, reward_normalization, ignore_done,
+                         exploration_noise, reward_normalization,
                          estimation_step, **kwargs)
         self.actor, self.actor_optim = actor, actor_optim
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
index f79c2a0..23e16d8 100644
--- a/tianshou/policy/modelfree/td3.py
+++ b/tianshou/policy/modelfree/td3.py
@@ -37,8 +37,6 @@ class TD3Policy(DDPGPolicy):
         network, default to 0.5.
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         defaults to False.
-    :param bool ignore_done: ignore the done flag while training the policy,
-        defaults to False.

     .. seealso::

@@ -62,13 +60,12 @@ class TD3Policy(DDPGPolicy):
         update_actor_freq: int = 2,
         noise_clip: float = 0.5,
         reward_normalization: bool = False,
-        ignore_done: bool = False,
         estimation_step: int = 1,
         **kwargs: Any,
     ) -> None:
-        super().__init__(actor, actor_optim, None, None, action_range,
-                         tau, gamma, exploration_noise, reward_normalization,
-                         ignore_done, estimation_step, **kwargs)
+        super().__init__(actor, actor_optim, None, None, action_range, tau, gamma,
+                         exploration_noise, reward_normalization,
+                         estimation_step, **kwargs)
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
         self.critic1_optim = critic1_optim
diff --git a/tianshou/utils/log_tools.py b/tianshou/utils/log_tools.py
index 7605a27..c50c8eb 100644
--- a/tianshou/utils/log_tools.py
+++ b/tianshou/utils/log_tools.py
@@ -138,7 +138,7 @@ class BasicLogger(BaseLogger):
     def log_update_data(self, update_result: dict, step: int) -> None:
        if step - self.last_log_update_step >= self.update_interval:
            for k, v in update_result.items():
-                self.write("train/" + k, step, v)  # save in train/
+                self.write(k, step, v)
            self.last_log_update_step = step
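
Below is a minimal standalone sketch, not part of the patch, of the bootstrapping rule that the value_mask change in tianshou/policy/base.py implements: a transition whose done flag comes only from gym.wrappers.TimeLimit truncation keeps bootstrapping from obs_next, while a genuinely terminal transition does not. The helper names value_mask and one_step_target and the example arrays are illustrative assumptions, not Tianshou API.

import numpy as np


def value_mask(done: np.ndarray, truncated: np.ndarray) -> np.ndarray:
    # True where obs_next should still be bootstrapped: either the episode did
    # not end, or it ended only because the environment's time limit was reached.
    return ~done.astype(bool) | truncated.astype(bool)


def one_step_target(rew, done, truncated, q_next, gamma=0.99):
    # Plain TD(0) target r + gamma * mask * Q(s', a') using the mask above.
    return rew + gamma * value_mask(done, truncated) * q_next


if __name__ == "__main__":
    rew = np.array([1.0, 1.0, 1.0])
    done = np.array([False, True, True])         # steps 2 and 3 end an episode
    truncated = np.array([False, True, False])   # step 2 ended only via TimeLimit
    q_next = np.array([10.0, 10.0, 10.0])
    # The truncated step keeps bootstrapping (10.9); the true terminal does not (1.0).
    print(one_step_target(rew, done, truncated, q_next))  # roughly [10.9, 10.9, 1.0]

Widening the mask this way is also why the patch can drop ignore_done: time-limit endings no longer cut off bootstrapping, while genuine terminal transitions remain terminal, which is the behavior the new tests in test/base/test_returns.py check.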