Add TimeLimit trick to optimize policies (#296)

* consider TimeLimit.truncated in calculating returns by default
* remove ignore_done
parent 9b61bc620c
commit 3108b9db0d
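The idea in one hedged sketch (not code from this commit; the td_target helper below is purely illustrative): gym.wrappers.TimeLimit marks episodes it cuts short by setting info['TimeLimit.truncated'] = True, and the Bellman target should keep bootstrapping through such a "done", because the underlying state is not actually terminal.

import gym

def td_target(rew, done, info, next_value, gamma=0.99):
    # gym.wrappers.TimeLimit sets info['TimeLimit.truncated'] = True when an
    # episode ends only because the step limit was hit, not because the task
    # reached a terminal state.
    truncated = info.get('TimeLimit.truncated', False)
    # Keep bootstrapping if the episode is not truly over: either it is not
    # done at all, or the "done" came from the time limit.
    bootstrap = (not done) or truncated
    return rew + gamma * next_value * float(bootstrap)

env = gym.make('Pendulum-v0')  # registered with a 200-step TimeLimit wrapper
obs = env.reset()
obs_next, rew, done, info = env.step(env.action_space.sample())
print(td_target(rew, done, info, next_value=0.0))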
@@ -92,7 +92,7 @@ def test_sac(args=get_args()):
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        reward_normalization=args.rew_norm, ignore_done=True,
+        reward_normalization=args.rew_norm,
         exploration_noise=OUNoise(0.0, args.noise_std))
     # collector
     train_collector = Collector(
@@ -73,7 +73,7 @@ def test_ddpg(args=get_args()):
         action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -81,7 +81,7 @@ def test_td3(args=get_args()):
         policy_noise=args.policy_noise,
         update_actor_freq=args.update_actor_freq,
         noise_clip=args.noise_clip,
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -81,7 +81,7 @@ def test_sac(args=get_args()):
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -86,7 +86,7 @@ def test_td3(args=get_args()):
         policy_noise=args.policy_noise,
         update_actor_freq=args.update_actor_freq,
         noise_clip=args.noise_clip,
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -24,6 +24,8 @@ def test_episodic_returns(size=2560):
     batch = Batch(
         done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
         rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
+        info=Batch({'TimeLimit.truncated':
+                    np.array([False, False, False, False, False, True, False, False])})
     )
     for b in batch:
         b.obs = b.act = 1
@@ -69,6 +71,24 @@ def test_episodic_returns(size=2560):
         474.2876, 390.1027, 299.476, 202.])
     assert np.allclose(ret.returns, returns)
+    buf.reset()
+    batch = Batch(
+        done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
+        rew=np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]),
+        info=Batch({'TimeLimit.truncated':
+                    np.array([False, False, False, True, False, False,
+                              False, True, False, False, False, False])})
+    )
+    for b in batch:
+        b.obs = b.act = 1
+        buf.add(b)
+    v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
+    ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
+    returns = np.array([
+        454.0109, 375.2386, 290.3669, 199.01,
+        462.9138, 381.3571, 293.5248, 199.02,
+        474.2876, 390.1027, 299.476, 202.])
+    assert np.allclose(ret.returns, returns)

     if __name__ == '__main__':
         buf = ReplayBuffer(size)
         batch = Batch(
@@ -106,15 +126,19 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice):
     buf_len = len(buffer)
     for i in range(len(indice)):
         flag, r = False, 0.
+        real_step_n = nstep
         for n in range(nstep):
             idx = (indice[i] + n) % buf_len
             r += buffer.rew[idx] * gamma ** n
             if buffer.done[idx]:
-                flag = True
+                if not (hasattr(buffer, 'info') and
+                        buffer.info['TimeLimit.truncated'][idx]):
+                    flag = True
+                real_step_n = n + 1
                 break
         if not flag:
-            idx = (indice[i] + nstep - 1) % buf_len
-            r += to_numpy(target_q_fn(buffer, idx)) * gamma ** nstep
+            idx = (indice[i] + real_step_n - 1) % buf_len
+            r += to_numpy(target_q_fn(buffer, idx)) * gamma ** real_step_n
         returns[i] = r
     return returns

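In other words, the reference computation above now matches the truncation-aware n-step target: sum the rewards that were actually observed, then bootstrap only if the episode did not genuinely terminate within those steps. A minimal standalone sketch (illustrative names, not part of the commit):

def nstep_target(rews, gamma, bootstrap_value, terminated):
    # rews: rewards collected before the record ends (length == real_step_n);
    # terminated: True only for a genuine terminal state, False for a
    # TimeLimit truncation or an unfinished episode.
    m = len(rews)
    ret = sum(r * gamma ** k for k, r in enumerate(rews))
    if not terminated:
        ret += bootstrap_value * gamma ** m
    return ret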
@@ -161,10 +185,56 @@ def test_nstep_returns(size=10000):
     ).pop('returns'))
     assert np.allclose(returns_multidim, returns[:, np.newaxis])

+
+def test_nstep_returns_with_timelimit(size=10000):
+    buf = ReplayBuffer(10)
+    for i in range(12):
+        buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3,
+                      info={"TimeLimit.truncated": i == 3}))
+    batch, indice = buf.sample(0)
+    assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
+    # rew:  [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
+    # done: [ 0,  1, 0, 1, 0, 0, 0, 1, 0,  0]
+    # test nstep = 1
+    returns = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn, gamma=.1, n_step=1
+    ).pop('returns').reshape(-1))
+    assert np.allclose(returns, [2.6, 3.6, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
+    r_ = compute_nstep_return_base(1, .1, buf, indice)
+    assert np.allclose(returns, r_), (r_, returns)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
+    # test nstep = 2
+    returns = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn, gamma=.1, n_step=2
+    ).pop('returns').reshape(-1))
+    assert np.allclose(returns, [3.36, 3.6, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
+    r_ = compute_nstep_return_base(2, .1, buf, indice)
+    assert np.allclose(returns, r_)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
+    # test nstep = 10
+    returns = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn, gamma=.1, n_step=10
+    ).pop('returns').reshape(-1))
+    assert np.allclose(returns, [3.36, 3.6, 5.678, 6.78,
+                                 7.8, 8, 10.122, 11.22, 12.2, 12])
+    r_ = compute_nstep_return_base(10, .1, buf, indice)
+    assert np.allclose(returns, r_)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
+
     if __name__ == '__main__':
         buf = ReplayBuffer(size)
         for i in range(int(size * 1.5)):
-            buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0))
+            buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0,
+                          info={"TimeLimit.truncated": i % 33 == 0}))
         batch, indice = buf.sample(256)

         def vanilla():
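Reading off the asserted arrays above: the genuine terminal at buffer index 7 (reward 8, done and not truncated) is never bootstrapped, so its return stays exactly 8 for n_step = 1, 2 and 10, while the truncated "done" at buffer index 3 still receives a discounted bootstrap term and therefore returns 3.6 rather than its raw reward 4.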
@@ -181,4 +251,5 @@ def test_nstep_returns(size=10000):

 if __name__ == '__main__':
     test_nstep_returns()
+    test_nstep_returns_with_timelimit()
     test_episodic_returns()
@@ -83,7 +83,6 @@ def test_ddpg(args=get_args()):
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         reward_normalization=args.rew_norm,
-        ignore_done=args.ignore_done,
         estimation_step=args.n_step)
     # collector
     train_collector = Collector(
@@ -91,7 +91,6 @@ def test_sac_with_il(args=get_args()):
         action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
-        ignore_done=args.ignore_done,
         estimation_step=args.n_step)
     # collector
     train_collector = Collector(
@@ -95,7 +95,6 @@ def test_td3(args=get_args()):
         update_actor_freq=args.update_actor_freq,
         noise_clip=args.noise_clip,
         reward_normalization=args.rew_norm,
-        ignore_done=args.ignore_done,
         estimation_step=args.n_step)
     # collector
     train_collector = Collector(
@@ -116,8 +116,8 @@ def test_c51(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
-        args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+        args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
+        test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)

     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@@ -88,8 +88,7 @@ def test_discrete_sac(args=get_args()):
     policy = DiscreteSACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         args.tau, args.gamma, args.alpha,
-        reward_normalization=args.rew_norm,
-        ignore_done=args.ignore_done)
+        reward_normalization=args.rew_norm)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -203,7 +203,12 @@ class BasePolicy(ABC, nn.Module):
         :return: A bool-type numpy.ndarray with the same shape as indice. "True" means
             "obs_next" of that buffer[indice] is valid.
         """
-        return ~buffer.done[indice].astype(np.bool)
+        mask = ~buffer.done[indice].astype(np.bool)
+        # info['TimeLimit.truncated'] is set to True if the 'done' flag was generated
+        # by the environment's time limit. Check out gym.wrappers.TimeLimit.
+        if hasattr(buffer, 'info') and 'TimeLimit.truncated' in buffer.info:
+            mask = mask | buffer.info['TimeLimit.truncated'][indice]
+        return mask

     @staticmethod
     def compute_episodic_return(
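A quick sanity check of the new mask (a sketch, assuming the helper shown above is Tianshou's BasePolicy.value_mask; the three-step buffer is made up):

import numpy as np
from tianshou.data import Batch, ReplayBuffer
from tianshou.policy import BasePolicy

buf = ReplayBuffer(3)
for i in range(3):
    # the last transition is "done", but only because the time limit was hit
    buf.add(Batch(obs=0, act=0, rew=1., done=i == 2,
                  info={'TimeLimit.truncated': i == 2}))
print(BasePolicy.value_mask(buf, np.arange(3)))
# -> [ True  True  True ]; with a genuine terminal (truncated=False) the last
#    entry would be False and its target value would be masked out.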
@@ -377,7 +382,7 @@ def _nstep_return(
     gammas = np.full(indices[0].shape, n_step)
     for n in range(n_step - 1, -1, -1):
         now = indices[n]
-        gammas[end_flag[now] > 0] = n
+        gammas[end_flag[now] > 0] = n + 1
         returns[end_flag[now] > 0] = 0.0
         returns = (rew[now].reshape(bsz, 1) - mean) / std + gamma * returns
     target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns
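Why n + 1: if the backward loop stops at 0-based offset n, the return already contains n + 1 discounted rewards, so the bootstrapped target of the state reached afterwards must carry gamma ** (n + 1) (the same quantity real_step_n computes in the test helper above). A tiny check with made-up numbers:

gamma, rews, v_boot = 0.1, [3.0, 4.0], 10.0   # truncated after two real steps
target = sum(r * gamma ** k for k, r in enumerate(rews)) + gamma ** len(rews) * v_boot
assert abs(target - 3.5) < 1e-9               # 3 + 0.1 * 4 + 0.01 * 10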
@@ -26,8 +26,6 @@ class DDPGPolicy(BasePolicy):
        add to the action, defaults to ``GaussianNoise(sigma=0.1)``.
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to False.
-   :param bool ignore_done: ignore the done flag while training the policy,
-       defaults to False.
    :param int estimation_step: greater than 1, the number of steps to look
        ahead.

@@ -48,7 +46,6 @@ class DDPGPolicy(BasePolicy):
        gamma: float = 0.99,
        exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
        reward_normalization: bool = False,
-       ignore_done: bool = False,
        estimation_step: int = 1,
        **kwargs: Any,
    ) -> None:
@@ -73,7 +70,6 @@ class DDPGPolicy(BasePolicy):
        self._action_scale = (action_range[1] - action_range[0]) / 2.0
        # it is only a little difference to use GaussianNoise
        # self.noise = OUNoise()
-       self._rm_done = ignore_done
        self._rew_norm = reward_normalization
        assert estimation_step > 0, "estimation_step should be greater than 0"
        self._n_step = estimation_step
@@ -110,8 +106,6 @@ class DDPGPolicy(BasePolicy):
    def process_fn(
        self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
    ) -> Batch:
-       if self._rm_done:
-           batch.done = batch.done * 0.0
        batch = self.compute_nstep_return(
            batch, buffer, indice, self._target_q,
            self._gamma, self._n_step, self._rew_norm)

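Since the value mask above already keeps the bootstrapped target alive for time-limit "done"s inside the return computation, process_fn no longer needs to zero out done flags by hand; the _rm_done / ignore_done path is therefore dropped here and in the SAC, TD3 and DiscreteSAC subclasses changed below.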
@@ -28,8 +28,6 @@ class DiscreteSACPolicy(SACPolicy):
        alpha is automatically tuned.
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to ``False``.
-   :param bool ignore_done: ignore the done flag while training the policy,
-       defaults to ``False``.

    .. seealso::

@@ -51,13 +49,12 @@ class DiscreteSACPolicy(SACPolicy):
            float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
        ] = 0.2,
        reward_normalization: bool = False,
-       ignore_done: bool = False,
        estimation_step: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(actor, actor_optim, critic1, critic1_optim, critic2,
                         critic2_optim, (-np.inf, np.inf), tau, gamma, alpha,
-                        reward_normalization, ignore_done, estimation_step,
+                        reward_normalization, estimation_step,
                         **kwargs)
        self._alpha: Union[float, torch.Tensor]

@@ -34,8 +34,6 @@ class SACPolicy(DDPGPolicy):
        alpha is automatically tuned.
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to False.
-   :param bool ignore_done: ignore the done flag while training the policy,
-       defaults to False.
    :param BaseNoise exploration_noise: add a noise to action for exploration,
        defaults to None. This is useful when solving hard-exploration problem.
    :param bool deterministic_eval: whether to use deterministic action (mean
@@ -63,14 +61,13 @@ class SACPolicy(DDPGPolicy):
            float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
        ] = 0.2,
        reward_normalization: bool = False,
-       ignore_done: bool = False,
        estimation_step: int = 1,
        exploration_noise: Optional[BaseNoise] = None,
        deterministic_eval: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(None, None, None, None, action_range, tau, gamma,
-                        exploration_noise, reward_normalization, ignore_done,
+                        exploration_noise, reward_normalization,
                         estimation_step, **kwargs)
        self.actor, self.actor_optim = actor, actor_optim
        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
@@ -37,8 +37,6 @@ class TD3Policy(DDPGPolicy):
        network, default to 0.5.
    :param bool reward_normalization: normalize the reward to Normal(0, 1),
        defaults to False.
-   :param bool ignore_done: ignore the done flag while training the policy,
-       defaults to False.

    .. seealso::

@@ -62,13 +60,12 @@ class TD3Policy(DDPGPolicy):
        update_actor_freq: int = 2,
        noise_clip: float = 0.5,
        reward_normalization: bool = False,
-       ignore_done: bool = False,
        estimation_step: int = 1,
        **kwargs: Any,
    ) -> None:
-       super().__init__(actor, actor_optim, None, None, action_range,
-                        tau, gamma, exploration_noise, reward_normalization,
-                        ignore_done, estimation_step, **kwargs)
+       super().__init__(actor, actor_optim, None, None, action_range, tau, gamma,
+                        exploration_noise, reward_normalization,
+                        estimation_step, **kwargs)
        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
        self.critic1_old.eval()
        self.critic1_optim = critic1_optim
@@ -138,7 +138,7 @@ class BasicLogger(BaseLogger):
    def log_update_data(self, update_result: dict, step: int) -> None:
        if step - self.last_log_update_step >= self.update_interval:
            for k, v in update_result.items():
-               self.write("train/" + k, step, v)  # save in train/
+               self.write(k, step, v)
            self.last_log_update_step = step
