Add learning rate scheduler to BasePolicy (#598)

Alex Nikulkov 2022-04-17 08:52:30 -07:00 committed by GitHub
parent 6fc6857812
commit 92456cdb68
31 changed files with 136 additions and 26 deletions

View File

@@ -1,7 +1,7 @@
 SHELL=/bin/bash
 PROJECT_NAME=tianshou
 PROJECT_PATH=${PROJECT_NAME}/
-LINT_PATHS=${PROJECT_PATH} test/ docs/conf.py examples/ setup.py
+PYTHON_FILES = $(shell find setup.py ${PROJECT_NAME} test docs/conf.py examples -type f -name "*.py")
 check_install = python3 -c "import $(1)" || pip3 install $(1) --upgrade
 check_install_extra = python3 -c "import $(1)" || pip3 install $(2) --upgrade

@@ -19,18 +19,18 @@ mypy:
 lint:
 	$(call check_install, flake8)
 	$(call check_install_extra, bugbear, flake8_bugbear)
-	flake8 ${LINT_PATHS} --count --show-source --statistics
+	flake8 ${PYTHON_FILES} --count --show-source --statistics

 format:
 	$(call check_install, isort)
-	isort ${LINT_PATHS}
+	isort ${PYTHON_FILES}
 	$(call check_install, yapf)
-	yapf -ir ${LINT_PATHS}
+	yapf -ir ${PYTHON_FILES}

 check-codestyle:
 	$(call check_install, isort)
 	$(call check_install, yapf)
-	isort --check ${LINT_PATHS} && yapf -r -d ${LINT_PATHS}
+	isort --check ${PYTHON_FILES} && yapf -r -d ${PYTHON_FILES}

 check-docstyle:
 	$(call check_install, pydocstyle)

View File

@@ -18,6 +18,7 @@ numpy
 ndarray
 stackoverflow
 tensorboard
+state_dict
 len
 tac
 fqf

View File

@@ -10,7 +10,7 @@ pip install envpool
 After that, `atari_wrapper` will automatically switch to envpool's Atari env. EnvPool's implementation is much faster (about 2\~3x faster for pure execution speed, 1.5x for overall RL training pipeline) than python vectorized env implementation, and it's behavior is consistent to that approach (OpenAI wrapper), which will describe below.
-For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://ppo-details.cleanrl.dev/2021/11/05/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool).
+For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool).

 ## ALE-py

View File

@@ -2,7 +2,7 @@ import numpy as np
 import torch

 from tianshou.exploration import GaussianNoise, OUNoise
-from tianshou.utils import MovAvg, RunningMeanStd
+from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd
 from tianshou.utils.net.common import MLP, Net
 from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic

@@ -99,8 +99,48 @@ def test_net():
     assert list(net(data, act).shape) == [bsz, 1]


+def test_lr_schedulers():
+    initial_lr_1 = 10.0
+    step_size_1 = 1
+    gamma_1 = 0.5
+    net_1 = torch.nn.Linear(2, 3)
+    optim_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr_1)
+    sched_1 = torch.optim.lr_scheduler.StepLR(
+        optim_1, step_size=step_size_1, gamma=gamma_1
+    )
+
+    initial_lr_2 = 5.0
+    step_size_2 = 2
+    gamma_2 = 0.3
+    net_2 = torch.nn.Linear(3, 2)
+    optim_2 = torch.optim.Adam(net_2.parameters(), lr=initial_lr_2)
+    sched_2 = torch.optim.lr_scheduler.StepLR(
+        optim_2, step_size=step_size_2, gamma=gamma_2
+    )
+
+    schedulers = MultipleLRSchedulers(sched_1, sched_2)
+
+    for _ in range(10):
+        loss_1 = (torch.ones((1, 3)) - net_1(torch.ones((1, 2)))).sum()
+        optim_1.zero_grad()
+        loss_1.backward()
+        optim_1.step()
+        loss_2 = (torch.ones((1, 2)) - net_2(torch.ones((1, 3)))).sum()
+        optim_2.zero_grad()
+        loss_2.backward()
+        optim_2.step()
+        schedulers.step()
+
+    assert (
+        optim_1.state_dict()["param_groups"][0]["lr"] ==
+        (initial_lr_1 * gamma_1**(10 // step_size_1))
+    )
+    assert (
+        optim_2.state_dict()["param_groups"][0]["lr"] ==
+        (initial_lr_2 * gamma_2**(10 // step_size_2))
+    )
+
+
 if __name__ == '__main__':
     test_noise()
     test_moving_average()
     test_rms()
     test_net()
+    test_lr_schedulers()
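The assertions in this new test rely on StepLR's closed form: after n calls to step(), the learning rate equals initial_lr * gamma ** (n // step_size). A standalone check of the first expected value, sketched with plain PyTorch and independent of the test above:

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=10.0)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.5)
for _ in range(10):
    opt.step()    # optimizer step first, then scheduler step (PyTorch's expected order)
    sched.step()
assert opt.param_groups[0]["lr"] == 10.0 * 0.5 ** (10 // 1)  # lr is halved on every step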

View File

@@ -9,6 +9,7 @@ from numba import njit
 from torch import nn

 from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
+from tianshou.utils import MultipleLRSchedulers


 class BasePolicy(ABC, nn.Module):
@@ -64,6 +65,8 @@ class BasePolicy(ABC, nn.Module):
         action_space: Optional[gym.Space] = None,
         action_scaling: bool = False,
         action_bound_method: str = "",
+        lr_scheduler: Optional[Union[torch.optim.lr_scheduler.LambdaLR,
+                                     MultipleLRSchedulers]] = None,
     ) -> None:
         super().__init__()
         self.observation_space = observation_space
@@ -79,6 +82,7 @@
         # can be one of ("clip", "tanh", ""), empty string means no bounding
         assert action_bound_method in ("", "clip", "tanh")
         self.action_bound_method = action_bound_method
+        self.lr_scheduler = lr_scheduler
         self._compile()

     def set_agent_id(self, agent_id: int) -> None:
@@ -272,6 +276,8 @@
         batch = self.process_fn(batch, buffer, indices)
         result = self.learn(batch, **kwargs)
         self.post_process_fn(batch, buffer, indices)
+        if self.lr_scheduler is not None:
+            self.lr_scheduler.step()
         self.updating = False
         return result
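With this change the scheduler lives in BasePolicy: any policy that forwards **kwargs to BasePolicy.__init__ accepts an lr_scheduler argument, and update() calls its step() once right after learn(). A minimal sketch of the intended wiring (the linear-decay schedule and the max_update_num value are illustrative, not part of the commit):

import torch
from torch.optim.lr_scheduler import LambdaLR

# Decay the learning rate linearly toward zero over a fixed number of policy updates.
model = torch.nn.Linear(4, 2)        # stand-in for the network a policy optimizes
optim = torch.optim.Adam(model.parameters(), lr=3e-4)
max_update_num = 1000                # illustrative; in practice derived from the trainer settings
lr_scheduler = LambdaLR(optim, lr_lambda=lambda n: 1 - n / max_update_num)

# The scheduler is then handed to the policy constructor, e.g.
#   policy = PPOPolicy(actor, critic, optim, dist_fn, lr_scheduler=lr_scheduler, ...)
# and BasePolicy.update() advances it once per update via lr_scheduler.step().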

View File

@@ -15,6 +15,8 @@ class ImitationPolicy(BasePolicy):
         :class:`~tianshou.policy.BasePolicy`. (s -> a)
     :param torch.optim.Optimizer optim: for optimizing the model.
     :param gym.Space action_space: env's action space.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -36,6 +36,8 @@ class BCQPolicy(BasePolicy):
     :param int num_sampled_action: the number of sampled actions in calculating
         target Q. The algorithm samples several actions using VAE, and perturbs
        each action to get the target Q. Default to 10.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -46,6 +46,8 @@ class CQLPolicy(SACPolicy):
     :param float clip_grad: clip_grad for updating critic network. Default to 1.0.
     :param Union[str, torch.device] device: which device to create this model on.
         Default to "cpu".
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -27,6 +27,8 @@ class DiscreteBCQPolicy(DQNPolicy):
         logits. Default to 1e-2.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -23,6 +23,8 @@ class DiscreteCQLPolicy(QRDQNPolicy):
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
     :param float min_q_weight: the weight for the cql loss.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

         Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed

View File

@@ -29,6 +29,8 @@ class DiscreteCRRPolicy(PGPolicy):
         you do not use the target network). Default to 0.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

         Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed

View File

@@ -59,6 +59,8 @@ class GAILPolicy(PPOPolicy):
         optimizer in each policy.update(). Default to None (no lr_scheduler).
     :param bool deterministic_eval: whether to use deterministic action instead of
         stochastic action sampled by the policy. Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -17,6 +17,8 @@ class ICMPolicy(BasePolicy):
     :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
     :param float lr_scale: the scaling factor for ICM learning.
     :param float forward_loss_weight: the weight for forward model loss.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -18,6 +18,8 @@ class PSRLModel(object):
         of rewards, with shape (n_state, n_action).
     :param float discount_factor: in [0, 1].
     :param float epsilon: for precision control in value iteration.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
     """

     def __init__(

View File

@@ -146,9 +146,6 @@ class A2CPolicy(PGPolicy):
                 vf_losses.append(vf_loss.item())
                 ent_losses.append(ent_loss.item())
                 losses.append(loss.item())
-        # update learning rate if lr_scheduler is given
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()

         return {
             "loss": losses,

View File

@@ -25,6 +25,8 @@ class C51Policy(DQNPolicy):
         you do not use the target network). Default to 0.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -32,6 +32,8 @@ class DDPGPolicy(BasePolicy):
         Default to "clip".
     :param Optional[gym.Space] action_space: env's action space, mandatory if you want
         to use option "action_scaling" or "action_bound_method". Default to None.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -28,6 +28,8 @@ class DiscreteSACPolicy(SACPolicy):
         alpha is automatically tuned.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -26,6 +26,8 @@ class DQNPolicy(BasePolicy):
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
     :param bool is_double: use double dqn. Default to True.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -27,6 +27,8 @@ class FQFPolicy(QRDQNPolicy):
         you do not use the target network).
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -26,6 +26,8 @@ class IQNPolicy(QRDQNPolicy):
         you do not use the target network).
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -127,10 +127,6 @@ class NPGPolicy(A2CPolicy):
                 vf_losses.append(vf_loss.item())
                 kls.append(kl.item())

-        # update learning rate if lr_scheduler is given
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
-
         return {
             "loss/actor": actor_losses,
             "loss/vf": vf_losses,

View File

@@ -44,7 +44,6 @@ class PGPolicy(BasePolicy):
         reward_normalization: bool = False,
         action_scaling: bool = True,
         action_bound_method: str = "clip",
-        lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
         deterministic_eval: bool = False,
         **kwargs: Any,
     ) -> None:
@@ -55,7 +54,6 @@ class PGPolicy(BasePolicy):
         )
         self.actor = model
         self.optim = optim
-        self.lr_scheduler = lr_scheduler
         self.dist_fn = dist_fn
         assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
         self._gamma = discount_factor
@@ -137,8 +135,5 @@ class PGPolicy(BasePolicy):
                 loss.backward()
                 self.optim.step()
                 losses.append(loss.item())
-        # update learning rate if lr_scheduler is given
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()

         return {"loss": losses}

View File

@@ -152,9 +152,6 @@ class PPOPolicy(A2CPolicy):
                 vf_losses.append(vf_loss.item())
                 ent_losses.append(ent_loss.item())
                 losses.append(loss.item())
-        # update learning rate if lr_scheduler is given
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()

         return {
             "loss": losses,

View File

@@ -23,6 +23,8 @@ class QRDQNPolicy(DQNPolicy):
         you do not use the target network).
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -23,6 +23,8 @@ class RainbowPolicy(C51Policy):
         you do not use the target network). Default to 0.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -42,6 +42,8 @@ class SACPolicy(DDPGPolicy):
         Default to "clip".
     :param Optional[gym.Space] action_space: env's action space, mandatory if you want
         to use option "action_scaling" or "action_bound_method". Default to None.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -40,6 +40,8 @@ class TD3Policy(DDPGPolicy):
         Default to "clip".
     :param Optional[gym.Space] action_space: env's action space, mandatory if you want
         to use option "action_scaling" or "action_bound_method". Default to None.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).

     .. seealso::

View File

@@ -146,10 +146,6 @@ class TRPOPolicy(NPGPolicy):
                 step_sizes.append(step_size.item())
                 kls.append(kl.item())

-        # update learning rate if lr_scheduler is given
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
-
         return {
             "loss/actor": actor_losses,
             "loss/vf": vf_losses,

View File

@@ -4,6 +4,7 @@ from tianshou.utils.config import tqdm_config
 from tianshou.utils.logger.base import BaseLogger, LazyLogger
 from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
 from tianshou.utils.logger.wandb import WandbLogger
+from tianshou.utils.lr_scheduler import MultipleLRSchedulers
 from tianshou.utils.statistics import MovAvg, RunningMeanStd
 from tianshou.utils.warning import deprecation

@@ -17,4 +18,5 @@ __all__ = [
     "LazyLogger",
     "WandbLogger",
     "deprecation",
+    "MultipleLRSchedulers",
 ]

View File

@@ -0,0 +1,42 @@
+from typing import Dict, List
+
+import torch
+
+
+class MultipleLRSchedulers:
+    """A wrapper for multiple learning rate schedulers.
+
+    Every time :meth:`~tianshou.utils.MultipleLRSchedulers.step` is called,
+    it calls the step() method of each of the schedulers that it contains.
+
+    Example usage:
+    ::
+
+        scheduler1 = ConstantLR(opt1, factor=0.1, total_iters=2)
+        scheduler2 = ExponentialLR(opt2, gamma=0.9)
+        scheduler = MultipleLRSchedulers(scheduler1, scheduler2)
+        policy = PPOPolicy(..., lr_scheduler=scheduler)
+    """
+
+    def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR):
+        self.schedulers = args
+
+    def step(self) -> None:
+        """Take a step in each of the learning rate schedulers."""
+        for scheduler in self.schedulers:
+            scheduler.step()
+
+    def state_dict(self) -> List[Dict]:
+        """Get state_dict for each of the learning rate schedulers.
+
+        :return: A list of state_dict of learning rate schedulers.
+        """
+        return [s.state_dict() for s in self.schedulers]
+
+    def load_state_dict(self, state_dict: List[Dict]) -> None:
+        """Load states from state_dict.
+
+        :param List[Dict] state_dict: A list of learning rate scheduler
+            state_dict, in the same order as the schedulers.
+        """
+        for (s, sd) in zip(self.schedulers, state_dict):
+            s.__dict__.update(sd)