Add learning rate scheduler to BasePolicy (#598)
This commit is contained in:
parent
6fc6857812
commit
92456cdb68
10
Makefile
10
Makefile
@ -1,7 +1,7 @@
|
||||
SHELL=/bin/bash
|
||||
PROJECT_NAME=tianshou
|
||||
PROJECT_PATH=${PROJECT_NAME}/
|
||||
LINT_PATHS=${PROJECT_PATH} test/ docs/conf.py examples/ setup.py
|
||||
PYTHON_FILES = $(shell find setup.py ${PROJECT_NAME} test docs/conf.py examples -type f -name "*.py")
|
||||
|
||||
check_install = python3 -c "import $(1)" || pip3 install $(1) --upgrade
|
||||
check_install_extra = python3 -c "import $(1)" || pip3 install $(2) --upgrade
|
||||
@ -19,18 +19,18 @@ mypy:
|
||||
lint:
|
||||
$(call check_install, flake8)
|
||||
$(call check_install_extra, bugbear, flake8_bugbear)
|
||||
flake8 ${LINT_PATHS} --count --show-source --statistics
|
||||
flake8 ${PYTHON_FILES} --count --show-source --statistics
|
||||
|
||||
format:
|
||||
$(call check_install, isort)
|
||||
isort ${LINT_PATHS}
|
||||
isort ${PYTHON_FILES}
|
||||
$(call check_install, yapf)
|
||||
yapf -ir ${LINT_PATHS}
|
||||
yapf -ir ${PYTHON_FILES}
|
||||
|
||||
check-codestyle:
|
||||
$(call check_install, isort)
|
||||
$(call check_install, yapf)
|
||||
isort --check ${LINT_PATHS} && yapf -r -d ${LINT_PATHS}
|
||||
isort --check ${PYTHON_FILES} && yapf -r -d ${PYTHON_FILES}
|
||||
|
||||
check-docstyle:
|
||||
$(call check_install, pydocstyle)
|
||||
|
@ -18,6 +18,7 @@ numpy
|
||||
ndarray
|
||||
stackoverflow
|
||||
tensorboard
|
||||
state_dict
|
||||
len
|
||||
tac
|
||||
fqf
|
||||
|
@ -10,7 +10,7 @@ pip install envpool
|
||||
|
||||
After that, `atari_wrapper` will automatically switch to envpool's Atari env. EnvPool's implementation is much faster (about 2\~3x faster for pure execution speed, 1.5x for overall RL training pipeline) than python vectorized env implementation, and it's behavior is consistent to that approach (OpenAI wrapper), which will describe below.
|
||||
|
||||
For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://ppo-details.cleanrl.dev/2021/11/05/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool).
|
||||
For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool).
|
||||
|
||||
## ALE-py
|
||||
|
||||
|
@ -2,7 +2,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from tianshou.exploration import GaussianNoise, OUNoise
|
||||
from tianshou.utils import MovAvg, RunningMeanStd
|
||||
from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd
|
||||
from tianshou.utils.net.common import MLP, Net
|
||||
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
|
||||
|
||||
@ -99,8 +99,48 @@ def test_net():
|
||||
assert list(net(data, act).shape) == [bsz, 1]
|
||||
|
||||
|
||||
def test_lr_schedulers():
|
||||
initial_lr_1 = 10.0
|
||||
step_size_1 = 1
|
||||
gamma_1 = 0.5
|
||||
net_1 = torch.nn.Linear(2, 3)
|
||||
optim_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr_1)
|
||||
sched_1 = torch.optim.lr_scheduler.StepLR(
|
||||
optim_1, step_size=step_size_1, gamma=gamma_1
|
||||
)
|
||||
|
||||
initial_lr_2 = 5.0
|
||||
step_size_2 = 2
|
||||
gamma_2 = 0.3
|
||||
net_2 = torch.nn.Linear(3, 2)
|
||||
optim_2 = torch.optim.Adam(net_2.parameters(), lr=initial_lr_2)
|
||||
sched_2 = torch.optim.lr_scheduler.StepLR(
|
||||
optim_2, step_size=step_size_2, gamma=gamma_2
|
||||
)
|
||||
schedulers = MultipleLRSchedulers(sched_1, sched_2)
|
||||
for _ in range(10):
|
||||
loss_1 = (torch.ones((1, 3)) - net_1(torch.ones((1, 2)))).sum()
|
||||
optim_1.zero_grad()
|
||||
loss_1.backward()
|
||||
optim_1.step()
|
||||
loss_2 = (torch.ones((1, 2)) - net_2(torch.ones((1, 3)))).sum()
|
||||
optim_2.zero_grad()
|
||||
loss_2.backward()
|
||||
optim_2.step()
|
||||
schedulers.step()
|
||||
assert (
|
||||
optim_1.state_dict()["param_groups"][0]["lr"] ==
|
||||
(initial_lr_1 * gamma_1**(10 // step_size_1))
|
||||
)
|
||||
assert (
|
||||
optim_2.state_dict()["param_groups"][0]["lr"] ==
|
||||
(initial_lr_2 * gamma_2**(10 // step_size_2))
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_noise()
|
||||
test_moving_average()
|
||||
test_rms()
|
||||
test_net()
|
||||
test_lr_schedulers()
|
||||
|
@ -9,6 +9,7 @@ from numba import njit
|
||||
from torch import nn
|
||||
|
||||
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
|
||||
from tianshou.utils import MultipleLRSchedulers
|
||||
|
||||
|
||||
class BasePolicy(ABC, nn.Module):
|
||||
@ -64,6 +65,8 @@ class BasePolicy(ABC, nn.Module):
|
||||
action_space: Optional[gym.Space] = None,
|
||||
action_scaling: bool = False,
|
||||
action_bound_method: str = "",
|
||||
lr_scheduler: Optional[Union[torch.optim.lr_scheduler.LambdaLR,
|
||||
MultipleLRSchedulers]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.observation_space = observation_space
|
||||
@ -79,6 +82,7 @@ class BasePolicy(ABC, nn.Module):
|
||||
# can be one of ("clip", "tanh", ""), empty string means no bounding
|
||||
assert action_bound_method in ("", "clip", "tanh")
|
||||
self.action_bound_method = action_bound_method
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self._compile()
|
||||
|
||||
def set_agent_id(self, agent_id: int) -> None:
|
||||
@ -272,6 +276,8 @@ class BasePolicy(ABC, nn.Module):
|
||||
batch = self.process_fn(batch, buffer, indices)
|
||||
result = self.learn(batch, **kwargs)
|
||||
self.post_process_fn(batch, buffer, indices)
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
self.updating = False
|
||||
return result
|
||||
|
||||
|
@ -15,6 +15,8 @@ class ImitationPolicy(BasePolicy):
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> a)
|
||||
:param torch.optim.Optimizer optim: for optimizing the model.
|
||||
:param gym.Space action_space: env's action space.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -36,6 +36,8 @@ class BCQPolicy(BasePolicy):
|
||||
:param int num_sampled_action: the number of sampled actions in calculating
|
||||
target Q. The algorithm samples several actions using VAE, and perturbs
|
||||
each action to get the target Q. Default to 10.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -46,6 +46,8 @@ class CQLPolicy(SACPolicy):
|
||||
:param float clip_grad: clip_grad for updating critic network. Default to 1.0.
|
||||
:param Union[str, torch.device] device: which device to create this model on.
|
||||
Default to "cpu".
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -27,6 +27,8 @@ class DiscreteBCQPolicy(DQNPolicy):
|
||||
logits. Default to 1e-2.
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -23,6 +23,8 @@ class DiscreteCQLPolicy(QRDQNPolicy):
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param float min_q_weight: the weight for the cql loss.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed
|
||||
|
@ -29,6 +29,8 @@ class DiscreteCRRPolicy(PGPolicy):
|
||||
you do not use the target network). Default to 0.
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed
|
||||
|
@ -59,6 +59,8 @@ class GAILPolicy(PPOPolicy):
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
:param bool deterministic_eval: whether to use deterministic action instead of
|
||||
stochastic action sampled by the policy. Default to False.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -17,6 +17,8 @@ class ICMPolicy(BasePolicy):
|
||||
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
|
||||
:param float lr_scale: the scaling factor for ICM learning.
|
||||
:param float forward_loss_weight: the weight for forward model loss.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -18,6 +18,8 @@ class PSRLModel(object):
|
||||
of rewards, with shape (n_state, n_action).
|
||||
:param float discount_factor: in [0, 1].
|
||||
:param float epsilon: for precision control in value iteration.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -146,9 +146,6 @@ class A2CPolicy(PGPolicy):
|
||||
vf_losses.append(vf_loss.item())
|
||||
ent_losses.append(ent_loss.item())
|
||||
losses.append(loss.item())
|
||||
# update learning rate if lr_scheduler is given
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
|
||||
return {
|
||||
"loss": losses,
|
||||
|
@ -25,6 +25,8 @@ class C51Policy(DQNPolicy):
|
||||
you do not use the target network). Default to 0.
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -32,6 +32,8 @@ class DDPGPolicy(BasePolicy):
|
||||
Default to "clip".
|
||||
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
|
||||
to use option "action_scaling" or "action_bound_method". Default to None.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -28,6 +28,8 @@ class DiscreteSACPolicy(SACPolicy):
|
||||
alpha is automatically tuned.
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -26,6 +26,8 @@ class DQNPolicy(BasePolicy):
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param bool is_double: use double dqn. Default to True.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -27,6 +27,8 @@ class FQFPolicy(QRDQNPolicy):
|
||||
you do not use the target network).
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -26,6 +26,8 @@ class IQNPolicy(QRDQNPolicy):
|
||||
you do not use the target network).
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -127,10 +127,6 @@ class NPGPolicy(A2CPolicy):
|
||||
vf_losses.append(vf_loss.item())
|
||||
kls.append(kl.item())
|
||||
|
||||
# update learning rate if lr_scheduler is given
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
|
||||
return {
|
||||
"loss/actor": actor_losses,
|
||||
"loss/vf": vf_losses,
|
||||
|
@ -44,7 +44,6 @@ class PGPolicy(BasePolicy):
|
||||
reward_normalization: bool = False,
|
||||
action_scaling: bool = True,
|
||||
action_bound_method: str = "clip",
|
||||
lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
|
||||
deterministic_eval: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -55,7 +54,6 @@ class PGPolicy(BasePolicy):
|
||||
)
|
||||
self.actor = model
|
||||
self.optim = optim
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.dist_fn = dist_fn
|
||||
assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
|
||||
self._gamma = discount_factor
|
||||
@ -137,8 +135,5 @@ class PGPolicy(BasePolicy):
|
||||
loss.backward()
|
||||
self.optim.step()
|
||||
losses.append(loss.item())
|
||||
# update learning rate if lr_scheduler is given
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
|
||||
return {"loss": losses}
|
||||
|
@ -152,9 +152,6 @@ class PPOPolicy(A2CPolicy):
|
||||
vf_losses.append(vf_loss.item())
|
||||
ent_losses.append(ent_loss.item())
|
||||
losses.append(loss.item())
|
||||
# update learning rate if lr_scheduler is given
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
|
||||
return {
|
||||
"loss": losses,
|
||||
|
@ -23,6 +23,8 @@ class QRDQNPolicy(DQNPolicy):
|
||||
you do not use the target network).
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -23,6 +23,8 @@ class RainbowPolicy(C51Policy):
|
||||
you do not use the target network). Default to 0.
|
||||
:param bool reward_normalization: normalize the reward to Normal(0, 1).
|
||||
Default to False.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -42,6 +42,8 @@ class SACPolicy(DDPGPolicy):
|
||||
Default to "clip".
|
||||
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
|
||||
to use option "action_scaling" or "action_bound_method". Default to None.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -40,6 +40,8 @@ class TD3Policy(DDPGPolicy):
|
||||
Default to "clip".
|
||||
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
|
||||
to use option "action_scaling" or "action_bound_method". Default to None.
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
optimizer in each policy.update(). Default to None (no lr_scheduler).
|
||||
|
||||
.. seealso::
|
||||
|
||||
|
@ -146,10 +146,6 @@ class TRPOPolicy(NPGPolicy):
|
||||
step_sizes.append(step_size.item())
|
||||
kls.append(kl.item())
|
||||
|
||||
# update learning rate if lr_scheduler is given
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
|
||||
return {
|
||||
"loss/actor": actor_losses,
|
||||
"loss/vf": vf_losses,
|
||||
|
@ -4,6 +4,7 @@ from tianshou.utils.config import tqdm_config
|
||||
from tianshou.utils.logger.base import BaseLogger, LazyLogger
|
||||
from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
|
||||
from tianshou.utils.logger.wandb import WandbLogger
|
||||
from tianshou.utils.lr_scheduler import MultipleLRSchedulers
|
||||
from tianshou.utils.statistics import MovAvg, RunningMeanStd
|
||||
from tianshou.utils.warning import deprecation
|
||||
|
||||
@ -17,4 +18,5 @@ __all__ = [
|
||||
"LazyLogger",
|
||||
"WandbLogger",
|
||||
"deprecation",
|
||||
"MultipleLRSchedulers",
|
||||
]
|
||||
|
42
tianshou/utils/lr_scheduler.py
Normal file
42
tianshou/utils/lr_scheduler.py
Normal file
@ -0,0 +1,42 @@
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MultipleLRSchedulers:
|
||||
"""A wrapper for multiple learning rate schedulers.
|
||||
|
||||
Every time :meth:`~tianshou.utils.MultipleLRSchedulers.step` is called,
|
||||
it calls the step() method of each of the schedulers that it contains.
|
||||
Example usage:
|
||||
::
|
||||
|
||||
scheduler1 = ConstantLR(opt1, factor=0.1, total_iters=2)
|
||||
scheduler2 = ExponentialLR(opt2, gamma=0.9)
|
||||
scheduler = MultipleLRSchedulers(scheduler1, scheduler2)
|
||||
policy = PPOPolicy(..., lr_scheduler=scheduler)
|
||||
"""
|
||||
|
||||
def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR):
|
||||
self.schedulers = args
|
||||
|
||||
def step(self) -> None:
|
||||
"""Take a step in each of the learning rate schedulers."""
|
||||
for scheduler in self.schedulers:
|
||||
scheduler.step()
|
||||
|
||||
def state_dict(self) -> List[Dict]:
|
||||
"""Get state_dict for each of the learning rate schedulers.
|
||||
|
||||
:return: A list of state_dict of learning rate schedulers.
|
||||
"""
|
||||
return [s.state_dict() for s in self.schedulers]
|
||||
|
||||
def load_state_dict(self, state_dict: List[Dict]) -> None:
|
||||
"""Load states from state_dict.
|
||||
|
||||
:param List[Dict] state_dict: A list of learning rate scheduler
|
||||
state_dict, in the same order as the schedulers.
|
||||
"""
|
||||
for (s, sd) in zip(self.schedulers, state_dict):
|
||||
s.__dict__.update(sd)
|
Loading…
x
Reference in New Issue
Block a user