Add lr_scheduler option for on-policy algorithms (#318)

Add an lr_scheduler option to PGPolicy/A2CPolicy/PPOPolicy.
ChenDRAG 2021-03-22 16:57:24 +08:00 committed by GitHub
parent 4d92952a7b
commit 2c11b6e43b
3 changed files with 20 additions and 0 deletions
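
For context (not part of the diff): a minimal usage sketch of the new option, passing a LambdaLR built on the policy's optimizer via lr_scheduler. The small actor network, its sizes, the learning rate, and max_update_num below are illustrative assumptions, not values from this commit.

import torch
from torch.distributions import Categorical
from tianshou.policy import PGPolicy

# Illustrative actor following Tianshou's model protocol:
# forward(obs, state, info) returns (logits, state).
class MLPActor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(4, 64), torch.nn.ReLU(), torch.nn.Linear(64, 2))

    def forward(self, obs, state=None, info={}):
        logits = self.net(torch.as_tensor(obs, dtype=torch.float32))
        return logits, state

actor = MLPActor()
optim = torch.optim.Adam(actor.parameters(), lr=3e-4)

# Decay the learning rate linearly toward zero over an assumed total number
# of policy.update() calls; the policy steps the scheduler once per update().
max_update_num = 1000
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optim, lr_lambda=lambda n: 1 - n / max_update_num)

policy = PGPolicy(
    actor, optim, dist_fn=lambda logits: Categorical(logits=logits),
    lr_scheduler=lr_scheduler)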


@@ -38,6 +38,8 @@ class A2CPolicy(PGPolicy):
        squashing) for now, or empty string for no bounding. Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
        to use option "action_scaling" or "action_bound_method". Default to None.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate of
        the optimizer in each policy.update(). Default to None (no lr_scheduler).
    .. seealso::
@@ -142,6 +144,10 @@ class A2CPolicy(PGPolicy):
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return {
            "loss": losses,
            "loss/actor": actor_losses,

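A note on the placement above: the docstring ties the scheduler to policy.update(), so lr_scheduler.step() runs once per learn() call rather than once per gradient step. A standalone sketch of that counting behaviour with a dummy optimizer (every name and number here is illustrative):

import torch
from torch.optim.lr_scheduler import LambdaLR

# Dummy parameter/optimizer just to watch the learning-rate trajectory.
param = torch.nn.Parameter(torch.zeros(1))
optim = torch.optim.Adam([param], lr=1e-3)

max_update_num = 10
scheduler = LambdaLR(optim, lr_lambda=lambda n: 1 - n / max_update_num)

for update in range(max_update_num):
    # (inside policy.update(): several minibatch gradient steps happen here)
    loss = (param ** 2).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
    scheduler.step()  # the lr only changes here, once per update
    print(update, optim.param_groups[0]["lr"])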

@@ -22,6 +22,8 @@ class PGPolicy(BasePolicy):
        squashing) for now, or empty string for no bounding. Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
        to use option "action_scaling" or "action_bound_method". Default to None.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate of
        the optimizer in each policy.update(). Default to None (no lr_scheduler).
    .. seealso::
@@ -38,6 +40,7 @@ class PGPolicy(BasePolicy):
        reward_normalization: bool = False,
        action_scaling: bool = True,
        action_bound_method: str = "clip",
        lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(action_scaling=action_scaling,
@@ -45,6 +48,7 @@ class PGPolicy(BasePolicy):
        if model is not None:
            self.model: torch.nn.Module = model
        self.optim = optim
        self.lr_scheduler = lr_scheduler
        self.dist_fn = dist_fn
        assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
        self._gamma = discount_factor
@@ -110,6 +114,10 @@ class PGPolicy(BasePolicy):
                loss.backward()
                self.optim.step()
                losses.append(loss.item())
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return {"loss": losses}
    # def _vanilla_returns(self, batch):
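
The constructor above annotates the new argument as Optional[torch.optim.lr_scheduler.LambdaLR], but the only call the policy makes in the code shown is self.lr_scheduler.step(), so any PyTorch scheduler exposing a parameterless step() would presumably also work at runtime; this is an observation about the shown code, not a documented guarantee. For example, a step-wise decay (illustrative values):

import torch
from torch.optim.lr_scheduler import StepLR

# Illustrative optimizer; in practice this is the policy's own optimizer.
param = torch.nn.Parameter(torch.zeros(1))
optim = torch.optim.Adam([param], lr=7e-4)

# Halve the learning rate every 50 policy.update() calls; the policy only
# ever calls lr_scheduler.step(), which StepLR provides as well.
lr_scheduler = StepLR(optim, step_size=50, gamma=0.5)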


@@ -43,6 +43,8 @@ class PPOPolicy(PGPolicy):
        squashing) for now, or empty string for no bounding. Default to "clip".
    :param Optional[gym.Space] action_space: env's action space, mandatory if you want
        to use option "action_scaling" or "action_bound_method". Default to None.
    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate of
        the optimizer in each policy.update(). Default to None (no lr_scheduler).
    .. seealso::
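
Because one scheduler step corresponds to one policy.update() (see the docstring above), a decaying schedule should be sized by the expected total number of updates over training. A sketch of that calculation, assuming a setup where the trainer triggers one update per data-collection round; the names epoch and collects_per_epoch and all numbers are illustrative, not trainer arguments:

import torch
from torch.optim.lr_scheduler import LambdaLR

# Illustrative optimizer standing in for the PPO optimizer (actor + critic parameters).
param = torch.nn.Parameter(torch.zeros(1))
optim = torch.optim.Adam([param], lr=3e-4)

# One policy.update() per collect round -> total scheduler steps over training.
epoch, collects_per_epoch = 100, 30
max_update_num = epoch * collects_per_epoch

# Linear decay from the initial lr toward 0, clamped so it never turns negative.
lr_scheduler = LambdaLR(optim, lr_lambda=lambda n: max(0.0, 1.0 - n / max_update_num))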
@@ -179,6 +181,10 @@ class PPOPolicy(PGPolicy):
                    list(self.actor.parameters()) + list(self.critic.parameters()),
                    self._max_grad_norm)
                self.optim.step()
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return {
            "loss": losses,
            "loss/clip": clip_losses,