Add Timelimit trick to optimize policies (#296)

* consider timelimit.truncated in calculating returns by default
* remove ignore_done
This commit is contained in:
ChenDRAG 2021-02-26 13:23:18 +08:00 committed by GitHub
parent 9b61bc620c
commit 3108b9db0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 96 additions and 39 deletions

View File

@ -92,7 +92,7 @@ def test_sac(args=get_args()):
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]], action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha, tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm, ignore_done=True, reward_normalization=args.rew_norm,
exploration_noise=OUNoise(0.0, args.noise_std)) exploration_noise=OUNoise(0.0, args.noise_std))
# collector # collector
train_collector = Collector( train_collector = Collector(

View File

@ -73,7 +73,7 @@ def test_ddpg(args=get_args()):
action_range=[env.action_space.low[0], env.action_space.high[0]], action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise), exploration_noise=GaussianNoise(sigma=args.exploration_noise),
reward_normalization=True, ignore_done=True) reward_normalization=True)
# collector # collector
train_collector = Collector( train_collector = Collector(
policy, train_envs, policy, train_envs,

View File

@ -81,7 +81,7 @@ def test_td3(args=get_args()):
policy_noise=args.policy_noise, policy_noise=args.policy_noise,
update_actor_freq=args.update_actor_freq, update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip, noise_clip=args.noise_clip,
reward_normalization=True, ignore_done=True) reward_normalization=True)
# collector # collector
train_collector = Collector( train_collector = Collector(
policy, train_envs, policy, train_envs,

View File

@ -81,7 +81,7 @@ def test_sac(args=get_args()):
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
action_range=[env.action_space.low[0], env.action_space.high[0]], action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha, tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=True, ignore_done=True) reward_normalization=True)
# collector # collector
train_collector = Collector( train_collector = Collector(
policy, train_envs, policy, train_envs,

View File

@ -86,7 +86,7 @@ def test_td3(args=get_args()):
policy_noise=args.policy_noise, policy_noise=args.policy_noise,
update_actor_freq=args.update_actor_freq, update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip, noise_clip=args.noise_clip,
reward_normalization=True, ignore_done=True) reward_normalization=True)
# collector # collector
train_collector = Collector( train_collector = Collector(
policy, train_envs, policy, train_envs,

View File

@ -24,6 +24,8 @@ def test_episodic_returns(size=2560):
batch = Batch( batch = Batch(
done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]), done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]), rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
info=Batch({'TimeLimit.truncated':
np.array([False, False, False, False, False, True, False, False])})
) )
for b in batch: for b in batch:
b.obs = b.act = 1 b.obs = b.act = 1
@ -69,6 +71,24 @@ def test_episodic_returns(size=2560):
474.2876, 390.1027, 299.476, 202.]) 474.2876, 390.1027, 299.476, 202.])
assert np.allclose(ret.returns, returns) assert np.allclose(ret.returns, returns)
buf.reset() buf.reset()
batch = Batch(
done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
rew=np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]),
info=Batch({'TimeLimit.truncated':
np.array([False, False, False, True, False, False,
False, True, False, False, False, False])})
)
for b in batch:
b.obs = b.act = 1
buf.add(b)
v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
returns = np.array([
454.0109, 375.2386, 290.3669, 199.01,
462.9138, 381.3571, 293.5248, 199.02,
474.2876, 390.1027, 299.476, 202.])
assert np.allclose(ret.returns, returns)
if __name__ == '__main__': if __name__ == '__main__':
buf = ReplayBuffer(size) buf = ReplayBuffer(size)
batch = Batch( batch = Batch(
@ -106,15 +126,19 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice):
buf_len = len(buffer) buf_len = len(buffer)
for i in range(len(indice)): for i in range(len(indice)):
flag, r = False, 0. flag, r = False, 0.
real_step_n = nstep
for n in range(nstep): for n in range(nstep):
idx = (indice[i] + n) % buf_len idx = (indice[i] + n) % buf_len
r += buffer.rew[idx] * gamma ** n r += buffer.rew[idx] * gamma ** n
if buffer.done[idx]: if buffer.done[idx]:
flag = True if not (hasattr(buffer, 'info') and
buffer.info['TimeLimit.truncated'][idx]):
flag = True
real_step_n = n + 1
break break
if not flag: if not flag:
idx = (indice[i] + nstep - 1) % buf_len idx = (indice[i] + real_step_n - 1) % buf_len
r += to_numpy(target_q_fn(buffer, idx)) * gamma ** nstep r += to_numpy(target_q_fn(buffer, idx)) * gamma ** real_step_n
returns[i] = r returns[i] = r
return returns return returns
@ -161,10 +185,56 @@ def test_nstep_returns(size=10000):
).pop('returns')) ).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis]) assert np.allclose(returns_multidim, returns[:, np.newaxis])
def test_nstep_returns_with_timelimit(size=10000):
buf = ReplayBuffer(10)
for i in range(12):
buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3,
info={"TimeLimit.truncated": i == 3}))
batch, indice = buf.sample(0)
assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
# rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
# done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0]
# test nstep = 1
returns = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn, gamma=.1, n_step=1
).pop('returns').reshape(-1))
assert np.allclose(returns, [2.6, 3.6, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
r_ = compute_nstep_return_base(1, .1, buf, indice)
assert np.allclose(returns, r_), (r_, returns)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 2
returns = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn, gamma=.1, n_step=2
).pop('returns').reshape(-1))
assert np.allclose(returns, [3.36, 3.6, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
r_ = compute_nstep_return_base(2, .1, buf, indice)
assert np.allclose(returns, r_)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])
# test nstep = 10
returns = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn, gamma=.1, n_step=10
).pop('returns').reshape(-1))
assert np.allclose(returns, [3.36, 3.6, 5.678, 6.78,
7.8, 8, 10.122, 11.22, 12.2, 12])
r_ = compute_nstep_return_base(10, .1, buf, indice)
assert np.allclose(returns, r_)
returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10
).pop('returns'))
assert np.allclose(returns_multidim, returns[:, np.newaxis])
if __name__ == '__main__': if __name__ == '__main__':
buf = ReplayBuffer(size) buf = ReplayBuffer(size)
for i in range(int(size * 1.5)): for i in range(int(size * 1.5)):
buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0)) buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0,
info={"TimeLimit.truncated": i % 33 == 0}))
batch, indice = buf.sample(256) batch, indice = buf.sample(256)
def vanilla(): def vanilla():
@ -181,4 +251,5 @@ def test_nstep_returns(size=10000):
if __name__ == '__main__': if __name__ == '__main__':
test_nstep_returns() test_nstep_returns()
test_nstep_returns_with_timelimit()
test_episodic_returns() test_episodic_returns()

View File

@ -83,7 +83,6 @@ def test_ddpg(args=get_args()):
tau=args.tau, gamma=args.gamma, tau=args.tau, gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise), exploration_noise=GaussianNoise(sigma=args.exploration_noise),
reward_normalization=args.rew_norm, reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step) estimation_step=args.n_step)
# collector # collector
train_collector = Collector( train_collector = Collector(

View File

@ -91,7 +91,6 @@ def test_sac_with_il(args=get_args()):
action_range=[env.action_space.low[0], env.action_space.high[0]], action_range=[env.action_space.low[0], env.action_space.high[0]],
tau=args.tau, gamma=args.gamma, alpha=args.alpha, tau=args.tau, gamma=args.gamma, alpha=args.alpha,
reward_normalization=args.rew_norm, reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step) estimation_step=args.n_step)
# collector # collector
train_collector = Collector( train_collector = Collector(

View File

@ -95,7 +95,6 @@ def test_td3(args=get_args()):
update_actor_freq=args.update_actor_freq, update_actor_freq=args.update_actor_freq,
noise_clip=args.noise_clip, noise_clip=args.noise_clip,
reward_normalization=args.rew_norm, reward_normalization=args.rew_norm,
ignore_done=args.ignore_done,
estimation_step=args.n_step) estimation_step=args.n_step)
# collector # collector
train_collector = Collector( train_collector = Collector(

View File

@ -116,8 +116,8 @@ def test_c51(args=get_args()):
result = offpolicy_trainer( result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch, policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num, args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn, args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
stop_fn=stop_fn, save_fn=save_fn, logger=logger) test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -88,8 +88,7 @@ def test_discrete_sac(args=get_args()):
policy = DiscreteSACPolicy( policy = DiscreteSACPolicy(
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim, actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.alpha, args.tau, args.gamma, args.alpha,
reward_normalization=args.rew_norm, reward_normalization=args.rew_norm)
ignore_done=args.ignore_done)
# collector # collector
train_collector = Collector( train_collector = Collector(
policy, train_envs, policy, train_envs,

View File

@ -203,7 +203,12 @@ class BasePolicy(ABC, nn.Module):
:return: A bool type numpy.ndarray in the same shape with indice. "True" means :return: A bool type numpy.ndarray in the same shape with indice. "True" means
"obs_next" of that buffer[indice] is valid. "obs_next" of that buffer[indice] is valid.
""" """
return ~buffer.done[indice].astype(np.bool) mask = ~buffer.done[indice].astype(np.bool)
# info['TimeLimit.truncated'] will be set to True if 'done' flag is generated
# because of timelimit of environments. Checkout gym.wrappers.TimeLimit.
if hasattr(buffer, 'info') and 'TimeLimit.truncated' in buffer.info:
mask = mask | buffer.info['TimeLimit.truncated'][indice]
return mask
@staticmethod @staticmethod
def compute_episodic_return( def compute_episodic_return(
@ -377,7 +382,7 @@ def _nstep_return(
gammas = np.full(indices[0].shape, n_step) gammas = np.full(indices[0].shape, n_step)
for n in range(n_step - 1, -1, -1): for n in range(n_step - 1, -1, -1):
now = indices[n] now = indices[n]
gammas[end_flag[now] > 0] = n gammas[end_flag[now] > 0] = n + 1
returns[end_flag[now] > 0] = 0.0 returns[end_flag[now] > 0] = 0.0
returns = (rew[now].reshape(bsz, 1) - mean) / std + gamma * returns returns = (rew[now].reshape(bsz, 1) - mean) / std + gamma * returns
target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns

View File

@ -26,8 +26,6 @@ class DDPGPolicy(BasePolicy):
add to the action, defaults to ``GaussianNoise(sigma=0.1)``. add to the action, defaults to ``GaussianNoise(sigma=0.1)``.
:param bool reward_normalization: normalize the reward to Normal(0, 1), :param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to False. defaults to False.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to False.
:param int estimation_step: greater than 1, the number of steps to look :param int estimation_step: greater than 1, the number of steps to look
ahead. ahead.
@ -48,7 +46,6 @@ class DDPGPolicy(BasePolicy):
gamma: float = 0.99, gamma: float = 0.99,
exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1), exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
reward_normalization: bool = False, reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1, estimation_step: int = 1,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
@ -73,7 +70,6 @@ class DDPGPolicy(BasePolicy):
self._action_scale = (action_range[1] - action_range[0]) / 2.0 self._action_scale = (action_range[1] - action_range[0]) / 2.0
# it is only a little difference to use GaussianNoise # it is only a little difference to use GaussianNoise
# self.noise = OUNoise() # self.noise = OUNoise()
self._rm_done = ignore_done
self._rew_norm = reward_normalization self._rew_norm = reward_normalization
assert estimation_step > 0, "estimation_step should be greater than 0" assert estimation_step > 0, "estimation_step should be greater than 0"
self._n_step = estimation_step self._n_step = estimation_step
@ -110,8 +106,6 @@ class DDPGPolicy(BasePolicy):
def process_fn( def process_fn(
self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
) -> Batch: ) -> Batch:
if self._rm_done:
batch.done = batch.done * 0.0
batch = self.compute_nstep_return( batch = self.compute_nstep_return(
batch, buffer, indice, self._target_q, batch, buffer, indice, self._target_q,
self._gamma, self._n_step, self._rew_norm) self._gamma, self._n_step, self._rew_norm)

View File

@ -28,8 +28,6 @@ class DiscreteSACPolicy(SACPolicy):
alpha is automatatically tuned. alpha is automatatically tuned.
:param bool reward_normalization: normalize the reward to Normal(0, 1), :param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to ``False``. defaults to ``False``.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to ``False``.
.. seealso:: .. seealso::
@ -51,13 +49,12 @@ class DiscreteSACPolicy(SACPolicy):
float, Tuple[float, torch.Tensor, torch.optim.Optimizer] float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
] = 0.2, ] = 0.2,
reward_normalization: bool = False, reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1, estimation_step: int = 1,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
super().__init__(actor, actor_optim, critic1, critic1_optim, critic2, super().__init__(actor, actor_optim, critic1, critic1_optim, critic2,
critic2_optim, (-np.inf, np.inf), tau, gamma, alpha, critic2_optim, (-np.inf, np.inf), tau, gamma, alpha,
reward_normalization, ignore_done, estimation_step, reward_normalization, estimation_step,
**kwargs) **kwargs)
self._alpha: Union[float, torch.Tensor] self._alpha: Union[float, torch.Tensor]

View File

@ -34,8 +34,6 @@ class SACPolicy(DDPGPolicy):
alpha is automatatically tuned. alpha is automatatically tuned.
:param bool reward_normalization: normalize the reward to Normal(0, 1), :param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to False. defaults to False.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to False.
:param BaseNoise exploration_noise: add a noise to action for exploration, :param BaseNoise exploration_noise: add a noise to action for exploration,
defaults to None. This is useful when solving hard-exploration problem. defaults to None. This is useful when solving hard-exploration problem.
:param bool deterministic_eval: whether to use deterministic action (mean :param bool deterministic_eval: whether to use deterministic action (mean
@ -63,14 +61,13 @@ class SACPolicy(DDPGPolicy):
float, Tuple[float, torch.Tensor, torch.optim.Optimizer] float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
] = 0.2, ] = 0.2,
reward_normalization: bool = False, reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1, estimation_step: int = 1,
exploration_noise: Optional[BaseNoise] = None, exploration_noise: Optional[BaseNoise] = None,
deterministic_eval: bool = True, deterministic_eval: bool = True,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
super().__init__(None, None, None, None, action_range, tau, gamma, super().__init__(None, None, None, None, action_range, tau, gamma,
exploration_noise, reward_normalization, ignore_done, exploration_noise, reward_normalization,
estimation_step, **kwargs) estimation_step, **kwargs)
self.actor, self.actor_optim = actor, actor_optim self.actor, self.actor_optim = actor, actor_optim
self.critic1, self.critic1_old = critic1, deepcopy(critic1) self.critic1, self.critic1_old = critic1, deepcopy(critic1)

View File

@ -37,8 +37,6 @@ class TD3Policy(DDPGPolicy):
network, default to 0.5. network, default to 0.5.
:param bool reward_normalization: normalize the reward to Normal(0, 1), :param bool reward_normalization: normalize the reward to Normal(0, 1),
defaults to False. defaults to False.
:param bool ignore_done: ignore the done flag while training the policy,
defaults to False.
.. seealso:: .. seealso::
@ -62,13 +60,12 @@ class TD3Policy(DDPGPolicy):
update_actor_freq: int = 2, update_actor_freq: int = 2,
noise_clip: float = 0.5, noise_clip: float = 0.5,
reward_normalization: bool = False, reward_normalization: bool = False,
ignore_done: bool = False,
estimation_step: int = 1, estimation_step: int = 1,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
super().__init__(actor, actor_optim, None, None, action_range, super().__init__(actor, actor_optim, None, None, action_range, tau, gamma,
tau, gamma, exploration_noise, reward_normalization, exploration_noise, reward_normalization,
ignore_done, estimation_step, **kwargs) estimation_step, **kwargs)
self.critic1, self.critic1_old = critic1, deepcopy(critic1) self.critic1, self.critic1_old = critic1, deepcopy(critic1)
self.critic1_old.eval() self.critic1_old.eval()
self.critic1_optim = critic1_optim self.critic1_optim = critic1_optim

View File

@ -138,7 +138,7 @@ class BasicLogger(BaseLogger):
def log_update_data(self, update_result: dict, step: int) -> None: def log_update_data(self, update_result: dict, step: int) -> None:
if step - self.last_log_update_step >= self.update_interval: if step - self.last_log_update_step >= self.update_interval:
for k, v in update_result.items(): for k, v in update_result.items():
self.write("train/" + k, step, v) # save in train/ self.write(k, step, v)
self.last_log_update_step = step self.last_log_update_step = step