Add learning rate scheduler to BasePolicy (#598)

Alex Nikulkov 2022-04-17 08:52:30 -07:00 committed by GitHub
parent 6fc6857812
commit 92456cdb68
31 changed files with 136 additions and 26 deletions
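
The net effect of this commit: any policy that forwards its keyword arguments to BasePolicy can now take an optional lr_scheduler, and BasePolicy.update() steps it once per call. A minimal usage sketch (the network shapes, hyperparameters, and update budget below are illustrative assumptions, not taken from this commit):

import torch
from torch.optim.lr_scheduler import LambdaLR
from tianshou.policy import DQNPolicy
from tianshou.utils.net.common import Net

# illustrative shapes and hyperparameters
net = Net(state_shape=4, action_shape=2, hidden_sizes=[64, 64])
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
max_update_num = 1000  # assumed total number of policy.update() calls
lr_scheduler = LambdaLR(optim, lr_lambda=lambda n: 1 - n / max_update_num)

policy = DQNPolicy(
    net, optim, discount_factor=0.99, estimation_step=3,
    target_update_freq=320, lr_scheduler=lr_scheduler,
)
# lr_scheduler.step() is now called automatically at the end of each policy.update()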

View File

@@ -1,7 +1,7 @@
SHELL=/bin/bash
PROJECT_NAME=tianshou
PROJECT_PATH=${PROJECT_NAME}/
LINT_PATHS=${PROJECT_PATH} test/ docs/conf.py examples/ setup.py
PYTHON_FILES = $(shell find setup.py ${PROJECT_NAME} test docs/conf.py examples -type f -name "*.py")
check_install = python3 -c "import $(1)" || pip3 install $(1) --upgrade
check_install_extra = python3 -c "import $(1)" || pip3 install $(2) --upgrade
@@ -19,18 +19,18 @@ mypy:
lint:
$(call check_install, flake8)
$(call check_install_extra, bugbear, flake8_bugbear)
flake8 ${LINT_PATHS} --count --show-source --statistics
flake8 ${PYTHON_FILES} --count --show-source --statistics
format:
$(call check_install, isort)
isort ${LINT_PATHS}
isort ${PYTHON_FILES}
$(call check_install, yapf)
yapf -ir ${LINT_PATHS}
yapf -ir ${PYTHON_FILES}
check-codestyle:
$(call check_install, isort)
$(call check_install, yapf)
isort --check ${LINT_PATHS} && yapf -r -d ${LINT_PATHS}
isort --check ${PYTHON_FILES} && yapf -r -d ${PYTHON_FILES}
check-docstyle:
$(call check_install, pydocstyle)

View File

@@ -18,6 +18,7 @@ numpy
ndarray
stackoverflow
tensorboard
state_dict
len
tac
fqf

View File

@@ -10,7 +10,7 @@ pip install envpool
After that, `atari_wrapper` will automatically switch to EnvPool's Atari env. EnvPool's implementation is much faster (about 2\~3x faster for pure execution speed, 1.5x for the overall RL training pipeline) than the Python vectorized env implementation, and its behavior is consistent with that approach (OpenAI wrapper), which is described below.
For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://ppo-details.cleanrl.dev/2021/11/05/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool).
For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool).
## ALE-py
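
Referring back to the EnvPool paragraph above, this is roughly what the EnvPool path looks like when used directly; a minimal sketch assuming envpool is installed (the task id and env count are illustrative):

import envpool

# create 8 vectorized Atari envs running in C++ threads
envs = envpool.make_gym("Pong-v5", num_envs=8)
obs = envs.reset()  # batched observations for all 8 envs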

View File

@@ -2,7 +2,7 @@ import numpy as np
import torch
from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils import MovAvg, RunningMeanStd
from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
@@ -99,8 +99,48 @@ def test_net():
assert list(net(data, act).shape) == [bsz, 1]
def test_lr_schedulers():
initial_lr_1 = 10.0
step_size_1 = 1
gamma_1 = 0.5
net_1 = torch.nn.Linear(2, 3)
optim_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr_1)
sched_1 = torch.optim.lr_scheduler.StepLR(
optim_1, step_size=step_size_1, gamma=gamma_1
)
initial_lr_2 = 5.0
step_size_2 = 2
gamma_2 = 0.3
net_2 = torch.nn.Linear(3, 2)
optim_2 = torch.optim.Adam(net_2.parameters(), lr=initial_lr_2)
sched_2 = torch.optim.lr_scheduler.StepLR(
optim_2, step_size=step_size_2, gamma=gamma_2
)
schedulers = MultipleLRSchedulers(sched_1, sched_2)
for _ in range(10):
loss_1 = (torch.ones((1, 3)) - net_1(torch.ones((1, 2)))).sum()
optim_1.zero_grad()
loss_1.backward()
optim_1.step()
loss_2 = (torch.ones((1, 2)) - net_2(torch.ones((1, 3)))).sum()
optim_2.zero_grad()
loss_2.backward()
optim_2.step()
schedulers.step()
assert (
optim_1.state_dict()["param_groups"][0]["lr"] ==
(initial_lr_1 * gamma_1**(10 // step_size_1))
)
assert (
optim_2.state_dict()["param_groups"][0]["lr"] ==
(initial_lr_2 * gamma_2**(10 // step_size_2))
)
if __name__ == '__main__':
test_noise()
test_moving_average()
test_rms()
test_net()
test_lr_schedulers()

View File

@@ -9,6 +9,7 @@ from numba import njit
from torch import nn
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
from tianshou.utils import MultipleLRSchedulers
class BasePolicy(ABC, nn.Module):
@@ -64,6 +65,8 @@ class BasePolicy(ABC, nn.Module):
action_space: Optional[gym.Space] = None,
action_scaling: bool = False,
action_bound_method: str = "",
lr_scheduler: Optional[Union[torch.optim.lr_scheduler.LambdaLR,
MultipleLRSchedulers]] = None,
) -> None:
super().__init__()
self.observation_space = observation_space
@@ -79,6 +82,7 @@
# can be one of ("clip", "tanh", ""), empty string means no bounding
assert action_bound_method in ("", "clip", "tanh")
self.action_bound_method = action_bound_method
self.lr_scheduler = lr_scheduler
self._compile()
def set_agent_id(self, agent_id: int) -> None:
@@ -272,6 +276,8 @@
batch = self.process_fn(batch, buffer, indices)
result = self.learn(batch, **kwargs)
self.post_process_fn(batch, buffer, indices)
if self.lr_scheduler is not None:
self.lr_scheduler.step()
self.updating = False
return result

View File

@@ -15,6 +15,8 @@ class ImitationPolicy(BasePolicy):
:class:`~tianshou.policy.BasePolicy`. (s -> a)
:param torch.optim.Optimizer optim: for optimizing the model.
:param gym.Space action_space: env's action space.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -36,6 +36,8 @@ class BCQPolicy(BasePolicy):
:param int num_sampled_action: the number of sampled actions in calculating
target Q. The algorithm samples several actions using VAE, and perturbs
each action to get the target Q. Default to 10.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -46,6 +46,8 @@ class CQLPolicy(SACPolicy):
:param float clip_grad: clip_grad for updating critic network. Default to 1.0.
:param Union[str, torch.device] device: which device to create this model on.
Default to "cpu".
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -27,6 +27,8 @@ class DiscreteBCQPolicy(DQNPolicy):
logits. Default to 1e-2.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -23,6 +23,8 @@ class DiscreteCQLPolicy(QRDQNPolicy):
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param float min_q_weight: the weight for the cql loss.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::
Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed

View File

@@ -29,6 +29,8 @@ class DiscreteCRRPolicy(PGPolicy):
you do not use the target network). Default to 0.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::
Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed

View File

@@ -59,6 +59,8 @@ class GAILPolicy(PPOPolicy):
optimizer in each policy.update(). Default to None (no lr_scheduler).
:param bool deterministic_eval: whether to use deterministic action instead of
stochastic action sampled by the policy. Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -17,6 +17,8 @@ class ICMPolicy(BasePolicy):
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param float lr_scale: the scaling factor for ICM learning.
:param float forward_loss_weight: the weight for forward model loss.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -18,6 +18,8 @@ class PSRLModel(object):
of rewards, with shape (n_state, n_action).
:param float discount_factor: in [0, 1].
:param float epsilon: for precision control in value iteration.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
"""
def __init__(

View File

@@ -146,9 +146,6 @@ class A2CPolicy(PGPolicy):
vf_losses.append(vf_loss.item())
ent_losses.append(ent_loss.item())
losses.append(loss.item())
# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()
return {
"loss": losses,

View File

@@ -25,6 +25,8 @@ class C51Policy(DQNPolicy):
you do not use the target network). Default to 0.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -32,6 +32,8 @@ class DDPGPolicy(BasePolicy):
Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -28,6 +28,8 @@ class DiscreteSACPolicy(SACPolicy):
alpha is automatically tuned.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -26,6 +26,8 @@ class DQNPolicy(BasePolicy):
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param bool is_double: use double dqn. Default to True.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -27,6 +27,8 @@ class FQFPolicy(QRDQNPolicy):
you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -26,6 +26,8 @@ class IQNPolicy(QRDQNPolicy):
you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -127,10 +127,6 @@ class NPGPolicy(A2CPolicy):
vf_losses.append(vf_loss.item())
kls.append(kl.item())
# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()
return {
"loss/actor": actor_losses,
"loss/vf": vf_losses,

View File

@@ -44,7 +44,6 @@ class PGPolicy(BasePolicy):
reward_normalization: bool = False,
action_scaling: bool = True,
action_bound_method: str = "clip",
lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
deterministic_eval: bool = False,
**kwargs: Any,
) -> None:
@@ -55,7 +54,6 @@ class PGPolicy(BasePolicy):
)
self.actor = model
self.optim = optim
self.lr_scheduler = lr_scheduler
self.dist_fn = dist_fn
assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
self._gamma = discount_factor
@@ -137,8 +135,5 @@ class PGPolicy(BasePolicy):
loss.backward()
self.optim.step()
losses.append(loss.item())
# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()
return {"loss": losses}

View File

@@ -152,9 +152,6 @@ class PPOPolicy(A2CPolicy):
vf_losses.append(vf_loss.item())
ent_losses.append(ent_loss.item())
losses.append(loss.item())
# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()
return {
"loss": losses,

View File

@@ -23,6 +23,8 @@ class QRDQNPolicy(DQNPolicy):
you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -23,6 +23,8 @@ class RainbowPolicy(C51Policy):
you do not use the target network). Default to 0.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -42,6 +42,8 @@ class SACPolicy(DDPGPolicy):
Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -40,6 +40,8 @@ class TD3Policy(DDPGPolicy):
Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
.. seealso::

View File

@@ -146,10 +146,6 @@ class TRPOPolicy(NPGPolicy):
step_sizes.append(step_size.item())
kls.append(kl.item())
# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()
return {
"loss/actor": actor_losses,
"loss/vf": vf_losses,

View File

@@ -4,6 +4,7 @@ from tianshou.utils.config import tqdm_config
from tianshou.utils.logger.base import BaseLogger, LazyLogger
from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
from tianshou.utils.logger.wandb import WandbLogger
from tianshou.utils.lr_scheduler import MultipleLRSchedulers
from tianshou.utils.statistics import MovAvg, RunningMeanStd
from tianshou.utils.warning import deprecation
@@ -17,4 +18,5 @@ __all__ = [
"LazyLogger",
"WandbLogger",
"deprecation",
"MultipleLRSchedulers",
]

View File

@@ -0,0 +1,42 @@
from typing import Dict, List
import torch
class MultipleLRSchedulers:
"""A wrapper for multiple learning rate schedulers.
Every time :meth:`~tianshou.utils.MultipleLRSchedulers.step` is called,
it calls the step() method of each of the schedulers that it contains.
Example usage:
::
scheduler1 = ConstantLR(opt1, factor=0.1, total_iters=2)
scheduler2 = ExponentialLR(opt2, gamma=0.9)
scheduler = MultipleLRSchedulers(scheduler1, scheduler2)
policy = PPOPolicy(..., lr_scheduler=scheduler)
"""
def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR):
self.schedulers = args
def step(self) -> None:
"""Take a step in each of the learning rate schedulers."""
for scheduler in self.schedulers:
scheduler.step()
def state_dict(self) -> List[Dict]:
"""Get state_dict for each of the learning rate schedulers.
:return: A list of state_dict of learning rate schedulers.
"""
return [s.state_dict() for s in self.schedulers]
def load_state_dict(self, state_dict: List[Dict]) -> None:
"""Load states from state_dict.
:param List[Dict] state_dict: A list of learning rate scheduler
state_dict, in the same order as the schedulers.
"""
for (s, sd) in zip(self.schedulers, state_dict):
s.__dict__.update(sd)
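
Because the wrapper exposes state_dict() and load_state_dict(), it can be checkpointed and restored like a single scheduler. A minimal sketch (the networks, scheduler choices, and file name are illustrative assumptions):

import torch
from tianshou.utils import MultipleLRSchedulers

# two independent optimizers/schedulers, e.g. for an actor and a critic
opt_a = torch.optim.Adam(torch.nn.Linear(4, 2).parameters(), lr=1e-3)
opt_c = torch.optim.Adam(torch.nn.Linear(4, 1).parameters(), lr=1e-3)
sched = MultipleLRSchedulers(
    torch.optim.lr_scheduler.StepLR(opt_a, step_size=100, gamma=0.5),
    torch.optim.lr_scheduler.ExponentialLR(opt_c, gamma=0.99),
)

# save and restore alongside the rest of a training checkpoint
torch.save({"lr_schedulers": sched.state_dict()}, "checkpoint.pth")
ckpt = torch.load("checkpoint.pth")
sched.load_state_dict(ckpt["lr_schedulers"])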