Add lr_scheduler option for Onpolicy algorithm (#318)
add lr_scheduler option in PGPolicy/A2CPolicy/PPOPolicy
This commit is contained in:
parent 4d92952a7b
commit 2c11b6e43b
@@ -38,6 +38,8 @@ class A2CPolicy(PGPolicy):
         squashing) for now, or empty string for no bounding. Default to "clip".
     :param Optional[gym.Space] action_space: env's action space, mandatory if you want
         to use option "action_scaling" or "action_bound_method". Default to None.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

@@ -142,6 +144,10 @@ class A2CPolicy(PGPolicy):
                 vf_losses.append(vf_loss.item())
                 ent_losses.append(ent_loss.item())
                 losses.append(loss.item())
+        # update learning rate if lr_scheduler is given
+        if self.lr_scheduler is not None:
+            self.lr_scheduler.step()
+
         return {
             "loss": losses,
             "loss/actor": actor_losses,
@@ -22,6 +22,8 @@ class PGPolicy(BasePolicy):
         squashing) for now, or empty string for no bounding. Default to "clip".
     :param Optional[gym.Space] action_space: env's action space, mandatory if you want
         to use option "action_scaling" or "action_bound_method". Default to None.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

@@ -38,6 +40,7 @@ class PGPolicy(BasePolicy):
         reward_normalization: bool = False,
         action_scaling: bool = True,
         action_bound_method: str = "clip",
+        lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(action_scaling=action_scaling,
@@ -45,6 +48,7 @@ class PGPolicy(BasePolicy):
         if model is not None:
             self.model: torch.nn.Module = model
         self.optim = optim
+        self.lr_scheduler = lr_scheduler
         self.dist_fn = dist_fn
         assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
         self._gamma = discount_factor
@@ -110,6 +114,10 @@ class PGPolicy(BasePolicy):
                 loss.backward()
                 self.optim.step()
                 losses.append(loss.item())
+        # update learning rate if lr_scheduler is given
+        if self.lr_scheduler is not None:
+            self.lr_scheduler.step()
+
         return {"loss": losses}

     # def _vanilla_returns(self, batch):
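
Below is a minimal sketch of how the new lr_scheduler argument might be wired up, using PGPolicy as an example. The Actor module, its layer sizes, and max_update_num are illustrative assumptions, not part of this commit; the only behavior taken from the diff is that the policy stores the scheduler and calls lr_scheduler.step() once per policy.update().

    import torch
    from torch.distributions import Categorical
    from torch.optim.lr_scheduler import LambdaLR

    from tianshou.policy import PGPolicy

    # Toy actor for illustration: maps observations to action probabilities
    # and follows the (obs, state, info) -> (probs, state) calling convention.
    class Actor(torch.nn.Module):
        def __init__(self, state_dim: int = 4, action_dim: int = 2) -> None:
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Linear(state_dim, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, action_dim),
            )

        def forward(self, obs, state=None, info={}):
            obs = torch.as_tensor(obs, dtype=torch.float32)
            return torch.softmax(self.net(obs), dim=-1), state

    actor = Actor()
    optim = torch.optim.Adam(actor.parameters(), lr=3e-4)

    # Assumed total number of policy.update() calls for this run; the scheduler
    # is stepped once per update, so this decays the lr linearly toward zero.
    max_update_num = 1000
    lr_scheduler = LambdaLR(optim, lr_lambda=lambda n: 1 - n / max_update_num)

    policy = PGPolicy(actor, optim, Categorical, lr_scheduler=lr_scheduler)
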
@@ -43,6 +43,8 @@ class PPOPolicy(PGPolicy):
         squashing) for now, or empty string for no bounding. Default to "clip".
     :param Optional[gym.Space] action_space: env's action space, mandatory if you want
         to use option "action_scaling" or "action_bound_method". Default to None.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

@@ -179,6 +181,10 @@ class PPOPolicy(PGPolicy):
                         list(self.actor.parameters()) + list(self.critic.parameters()),
                         self._max_grad_norm)
                 self.optim.step()
+        # update learning rate if lr_scheduler is given
+        if self.lr_scheduler is not None:
+            self.lr_scheduler.step()
+
         return {
             "loss": losses,
             "loss/clip": clip_losses,
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user