Add Timelimit trick to optimize policies (#296)
* consider TimeLimit.truncated in calculating returns by default
* remove ignore_done
parent 9b61bc620c
commit 3108b9db0d
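In short: gym's TimeLimit wrapper cuts an episode off after a fixed number of steps and marks that artificial ending with info['TimeLimit.truncated'] = True. A step that ends this way is not a true terminal state, so value targets should still bootstrap from the next observation rather than being zeroed, which previously had to be approximated with the blunt ignore_done switch. The snippet below is only an illustration of where the flag comes from; it is not part of this commit and assumes an old-style gym (4-tuple step API) with Pendulum-v0 registered:

import gym

# Pendulum never terminates on its own, so a 5-step limit always ends in a truncation.
env = gym.wrappers.TimeLimit(gym.make("Pendulum-v0").unwrapped, max_episode_steps=5)
env.reset()
done, info = False, {}
while not done:
    obs, rew, done, info = env.step(env.action_space.sample())
print(done, info.get("TimeLimit.truncated", False))  # True True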
@@ -92,7 +92,7 @@ def test_sac(args=get_args()):
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        reward_normalization=args.rew_norm, ignore_done=True,
+        reward_normalization=args.rew_norm,
         exploration_noise=OUNoise(0.0, args.noise_std))
     # collector
     train_collector = Collector(
@@ -73,7 +73,7 @@ def test_ddpg(args=get_args()):
         action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -81,7 +81,7 @@ def test_td3(args=get_args()):
         policy_noise=args.policy_noise,
         update_actor_freq=args.update_actor_freq,
         noise_clip=args.noise_clip,
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -81,7 +81,7 @@ def test_sac(args=get_args()):
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -86,7 +86,7 @@ def test_td3(args=get_args()):
         policy_noise=args.policy_noise,
         update_actor_freq=args.update_actor_freq,
         noise_clip=args.noise_clip,
-        reward_normalization=True, ignore_done=True)
+        reward_normalization=True)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -24,6 +24,8 @@ def test_episodic_returns(size=2560):
     batch = Batch(
         done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
         rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
+        info=Batch({'TimeLimit.truncated':
+                    np.array([False, False, False, False, False, True, False, False])})
     )
     for b in batch:
         b.obs = b.act = 1
@@ -69,6 +71,24 @@ def test_episodic_returns(size=2560):
         474.2876, 390.1027, 299.476, 202.])
     assert np.allclose(ret.returns, returns)
     buf.reset()
+    batch = Batch(
+        done=np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]),
+        rew=np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202]),
+        info=Batch({'TimeLimit.truncated':
+                    np.array([False, False, False, True, False, False,
+                              False, True, False, False, False, False])})
+    )
+    for b in batch:
+        b.obs = b.act = 1
+        buf.add(b)
+    v = np.array([2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10, -3])
+    ret = fn(batch, buf, buf.sample_index(0), v, gamma=0.99, gae_lambda=0.95)
+    returns = np.array([
+        454.0109, 375.2386, 290.3669, 199.01,
+        462.9138, 381.3571, 293.5248, 199.02,
+        474.2876, 390.1027, 299.476, 202.])
+    assert np.allclose(ret.returns, returns)
+
     if __name__ == '__main__':
         buf = ReplayBuffer(size)
         batch = Batch(
@@ -106,15 +126,19 @@ def compute_nstep_return_base(nstep, gamma, buffer, indice):
     buf_len = len(buffer)
     for i in range(len(indice)):
         flag, r = False, 0.
+        real_step_n = nstep
         for n in range(nstep):
             idx = (indice[i] + n) % buf_len
             r += buffer.rew[idx] * gamma ** n
             if buffer.done[idx]:
-                flag = True
+                if not (hasattr(buffer, 'info') and
+                        buffer.info['TimeLimit.truncated'][idx]):
+                    flag = True
+                real_step_n = n + 1
                 break
         if not flag:
-            idx = (indice[i] + nstep - 1) % buf_len
-            r += to_numpy(target_q_fn(buffer, idx)) * gamma ** nstep
+            idx = (indice[i] + real_step_n - 1) % buf_len
+            r += to_numpy(target_q_fn(buffer, idx)) * gamma ** real_step_n
         returns[i] = r
     return returns

@@ -161,10 +185,56 @@ def test_nstep_returns(size=10000):
     ).pop('returns'))
     assert np.allclose(returns_multidim, returns[:, np.newaxis])
+
+
+def test_nstep_returns_with_timelimit(size=10000):
+    buf = ReplayBuffer(10)
+    for i in range(12):
+        buf.add(Batch(obs=0, act=0, rew=i + 1, done=i % 4 == 3,
+                      info={"TimeLimit.truncated": i == 3}))
+    batch, indice = buf.sample(0)
+    assert np.allclose(indice, [2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
+    # rew: [11, 12, 3, 4, 5, 6, 7, 8, 9, 10]
+    # done: [ 0, 1, 0, 1, 0, 0, 0, 1, 0, 0]
+    # test nstep = 1
+    returns = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn, gamma=.1, n_step=1
+    ).pop('returns').reshape(-1))
+    assert np.allclose(returns, [2.6, 3.6, 4.4, 5.3, 6.2, 8, 8, 8.9, 9.8, 12])
+    r_ = compute_nstep_return_base(1, .1, buf, indice)
+    assert np.allclose(returns, r_), (r_, returns)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=1
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
+    # test nstep = 2
+    returns = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn, gamma=.1, n_step=2
+    ).pop('returns').reshape(-1))
+    assert np.allclose(returns, [3.36, 3.6, 5.53, 6.62, 7.8, 8, 9.89, 10.98, 12.2, 12])
+    r_ = compute_nstep_return_base(2, .1, buf, indice)
+    assert np.allclose(returns, r_)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=2
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])
+    # test nstep = 10
+    returns = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn, gamma=.1, n_step=10
+    ).pop('returns').reshape(-1))
+    assert np.allclose(returns, [3.36, 3.6, 5.678, 6.78,
+                                 7.8, 8, 10.122, 11.22, 12.2, 12])
+    r_ = compute_nstep_return_base(10, .1, buf, indice)
+    assert np.allclose(returns, r_)
+    returns_multidim = to_numpy(BasePolicy.compute_nstep_return(
+        batch, buf, indice, target_q_fn_multidim, gamma=.1, n_step=10
+    ).pop('returns'))
+    assert np.allclose(returns_multidim, returns[:, np.newaxis])

     if __name__ == '__main__':
         buf = ReplayBuffer(size)
         for i in range(int(size * 1.5)):
-            buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0))
+            buf.add(Batch(obs=0, act=0, rew=i + 1, done=np.random.randint(3) == 0,
+                          info={"TimeLimit.truncated": i % 33 == 0}))
         batch, indice = buf.sample(256)

         def vanilla():
@@ -181,4 +251,5 @@ def test_nstep_returns(size=10000):

 if __name__ == '__main__':
     test_nstep_returns()
+    test_nstep_returns_with_timelimit()
     test_episodic_returns()
@@ -83,7 +83,6 @@ def test_ddpg(args=get_args()):
         tau=args.tau, gamma=args.gamma,
         exploration_noise=GaussianNoise(sigma=args.exploration_noise),
         reward_normalization=args.rew_norm,
-        ignore_done=args.ignore_done,
         estimation_step=args.n_step)
     # collector
     train_collector = Collector(
@@ -91,7 +91,6 @@ def test_sac_with_il(args=get_args()):
         action_range=[env.action_space.low[0], env.action_space.high[0]],
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
         reward_normalization=args.rew_norm,
-        ignore_done=args.ignore_done,
         estimation_step=args.n_step)
     # collector
     train_collector = Collector(
@@ -95,7 +95,6 @@ def test_td3(args=get_args()):
         update_actor_freq=args.update_actor_freq,
         noise_clip=args.noise_clip,
         reward_normalization=args.rew_norm,
-        ignore_done=args.ignore_done,
         estimation_step=args.n_step)
     # collector
     train_collector = Collector(
@@ -116,8 +116,8 @@ def test_c51(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
-        args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, logger=logger)
+        args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
+        test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)

     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
@@ -88,8 +88,7 @@ def test_discrete_sac(args=get_args()):
     policy = DiscreteSACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         args.tau, args.gamma, args.alpha,
-        reward_normalization=args.rew_norm,
-        ignore_done=args.ignore_done)
+        reward_normalization=args.rew_norm)
     # collector
     train_collector = Collector(
         policy, train_envs,
@@ -203,7 +203,12 @@ class BasePolicy(ABC, nn.Module):
         :return: A bool type numpy.ndarray in the same shape with indice. "True" means
             "obs_next" of that buffer[indice] is valid.
         """
-        return ~buffer.done[indice].astype(np.bool)
+        mask = ~buffer.done[indice].astype(np.bool)
+        # info['TimeLimit.truncated'] will be set to True if 'done' flag is generated
+        # because of timelimit of environments. Checkout gym.wrappers.TimeLimit.
+        if hasattr(buffer, 'info') and 'TimeLimit.truncated' in buffer.info:
+            mask = mask | buffer.info['TimeLimit.truncated'][indice]
+        return mask

     @staticmethod
     def compute_episodic_return(
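For reference, a standalone NumPy paraphrase of the masking above (toy helper, not the library API): a transition may still be bootstrapped if it did not end the episode, or if the episode ended only because of the time limit.

import numpy as np

def value_mask(done, truncated=None):
    # True where obs_next is still a valid bootstrap target
    mask = ~done.astype(bool)
    if truncated is not None:
        # a time-limit ending keeps the next state valid
        mask = mask | truncated.astype(bool)
    return mask

done = np.array([0., 1., 0., 1.])
truncated = np.array([False, False, False, True])
print(value_mask(done, truncated))  # [ True False  True  True]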
@@ -377,7 +382,7 @@ def _nstep_return(
     gammas = np.full(indices[0].shape, n_step)
     for n in range(n_step - 1, -1, -1):
         now = indices[n]
-        gammas[end_flag[now] > 0] = n
+        gammas[end_flag[now] > 0] = n + 1
         returns[end_flag[now] > 0] = 0.0
         returns = (rew[now].reshape(bsz, 1) - mean) / std + gamma * returns
     target_q = target_q * gamma_buffer[gammas].reshape(bsz, 1) + returns
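The gammas change above makes the bootstrap term discounted by the number of real steps actually taken (n + 1 when the episode ends at loop index n), matching real_step_n = n + 1 in the test's reference implementation. A hand-checked toy version of the resulting n-step target (the function name and structure are mine, not the library's):

def nstep_target(rews, truncated, q_boot, gamma=0.1, n_step=3):
    # rews: rewards observed before the episode stopped (at most n_step of them)
    # truncated: the stop came from the time limit, so bootstrapping is still allowed
    m = len(rews)                                   # real environment steps taken
    ret = sum(r * gamma ** k for k, r in enumerate(rews))
    if truncated or m == n_step:                    # next state is still usable
        ret += gamma ** m * q_boot                  # discount by gamma ** m (m rewards consumed)
    return ret

# truncated after two steps: 1 + 0.1 * 2 + 0.1 ** 2 * 10 = 1.3
print(nstep_target([1.0, 2.0], truncated=True, q_boot=10.0))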
@@ -26,8 +26,6 @@ class DDPGPolicy(BasePolicy):
         add to the action, defaults to ``GaussianNoise(sigma=0.1)``.
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         defaults to False.
-    :param bool ignore_done: ignore the done flag while training the policy,
-        defaults to False.
     :param int estimation_step: greater than 1, the number of steps to look
         ahead.

@@ -48,7 +46,6 @@ class DDPGPolicy(BasePolicy):
         gamma: float = 0.99,
         exploration_noise: Optional[BaseNoise] = GaussianNoise(sigma=0.1),
         reward_normalization: bool = False,
-        ignore_done: bool = False,
         estimation_step: int = 1,
         **kwargs: Any,
     ) -> None:
@@ -73,7 +70,6 @@ class DDPGPolicy(BasePolicy):
         self._action_scale = (action_range[1] - action_range[0]) / 2.0
         # it is only a little difference to use GaussianNoise
         # self.noise = OUNoise()
-        self._rm_done = ignore_done
         self._rew_norm = reward_normalization
         assert estimation_step > 0, "estimation_step should be greater than 0"
         self._n_step = estimation_step
@@ -110,8 +106,6 @@ class DDPGPolicy(BasePolicy):
     def process_fn(
         self, batch: Batch, buffer: ReplayBuffer, indice: np.ndarray
     ) -> Batch:
-        if self._rm_done:
-            batch.done = batch.done * 0.0
         batch = self.compute_nstep_return(
             batch, buffer, indice, self._target_q,
             self._gamma, self._n_step, self._rew_norm)
@@ -28,8 +28,6 @@ class DiscreteSACPolicy(SACPolicy):
         alpha is automatatically tuned.
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         defaults to ``False``.
-    :param bool ignore_done: ignore the done flag while training the policy,
-        defaults to ``False``.

     .. seealso::

@@ -51,13 +49,12 @@ class DiscreteSACPolicy(SACPolicy):
             float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
         ] = 0.2,
         reward_normalization: bool = False,
-        ignore_done: bool = False,
         estimation_step: int = 1,
         **kwargs: Any,
     ) -> None:
         super().__init__(actor, actor_optim, critic1, critic1_optim, critic2,
                          critic2_optim, (-np.inf, np.inf), tau, gamma, alpha,
-                         reward_normalization, ignore_done, estimation_step,
+                         reward_normalization, estimation_step,
                          **kwargs)
         self._alpha: Union[float, torch.Tensor]

@@ -34,8 +34,6 @@ class SACPolicy(DDPGPolicy):
         alpha is automatatically tuned.
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         defaults to False.
-    :param bool ignore_done: ignore the done flag while training the policy,
-        defaults to False.
     :param BaseNoise exploration_noise: add a noise to action for exploration,
         defaults to None. This is useful when solving hard-exploration problem.
     :param bool deterministic_eval: whether to use deterministic action (mean
@@ -63,14 +61,13 @@ class SACPolicy(DDPGPolicy):
             float, Tuple[float, torch.Tensor, torch.optim.Optimizer]
         ] = 0.2,
         reward_normalization: bool = False,
-        ignore_done: bool = False,
         estimation_step: int = 1,
         exploration_noise: Optional[BaseNoise] = None,
         deterministic_eval: bool = True,
         **kwargs: Any,
     ) -> None:
         super().__init__(None, None, None, None, action_range, tau, gamma,
-                         exploration_noise, reward_normalization, ignore_done,
+                         exploration_noise, reward_normalization,
                          estimation_step, **kwargs)
         self.actor, self.actor_optim = actor, actor_optim
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
@@ -37,8 +37,6 @@ class TD3Policy(DDPGPolicy):
         network, default to 0.5.
     :param bool reward_normalization: normalize the reward to Normal(0, 1),
         defaults to False.
-    :param bool ignore_done: ignore the done flag while training the policy,
-        defaults to False.

     .. seealso::

@@ -62,13 +60,12 @@ class TD3Policy(DDPGPolicy):
         update_actor_freq: int = 2,
         noise_clip: float = 0.5,
         reward_normalization: bool = False,
-        ignore_done: bool = False,
         estimation_step: int = 1,
         **kwargs: Any,
     ) -> None:
-        super().__init__(actor, actor_optim, None, None, action_range,
-                         tau, gamma, exploration_noise, reward_normalization,
-                         ignore_done, estimation_step, **kwargs)
+        super().__init__(actor, actor_optim, None, None, action_range, tau, gamma,
+                         exploration_noise, reward_normalization,
+                         estimation_step, **kwargs)
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
         self.critic1_optim = critic1_optim
@@ -138,7 +138,7 @@ class BasicLogger(BaseLogger):
     def log_update_data(self, update_result: dict, step: int) -> None:
         if step - self.last_log_update_step >= self.update_interval:
             for k, v in update_result.items():
-                self.write("train/" + k, step, v)  # save in train/
+                self.write(k, step, v)
             self.last_log_update_step = step
