Add learning rate scheduler to BasePolicy (#598)

parent 6fc6857812
commit 92456cdb68

Makefile (10 changed lines)
@@ -1,7 +1,7 @@
 SHELL=/bin/bash
 PROJECT_NAME=tianshou
 PROJECT_PATH=${PROJECT_NAME}/
-LINT_PATHS=${PROJECT_PATH} test/ docs/conf.py examples/ setup.py
+PYTHON_FILES = $(shell find setup.py ${PROJECT_NAME} test docs/conf.py examples -type f -name "*.py")

 check_install = python3 -c "import $(1)" || pip3 install $(1) --upgrade
 check_install_extra = python3 -c "import $(1)" || pip3 install $(2) --upgrade

@@ -19,18 +19,18 @@ mypy:
 lint:
 	$(call check_install, flake8)
 	$(call check_install_extra, bugbear, flake8_bugbear)
-	flake8 ${LINT_PATHS} --count --show-source --statistics
+	flake8 ${PYTHON_FILES} --count --show-source --statistics

 format:
 	$(call check_install, isort)
-	isort ${LINT_PATHS}
+	isort ${PYTHON_FILES}
 	$(call check_install, yapf)
-	yapf -ir ${LINT_PATHS}
+	yapf -ir ${PYTHON_FILES}

 check-codestyle:
 	$(call check_install, isort)
 	$(call check_install, yapf)
-	isort --check ${LINT_PATHS} && yapf -r -d ${LINT_PATHS}
+	isort --check ${PYTHON_FILES} && yapf -r -d ${PYTHON_FILES}

 check-docstyle:
 	$(call check_install, pydocstyle)

@@ -18,6 +18,7 @@ numpy
 ndarray
 stackoverflow
 tensorboard
+state_dict
 len
 tac
 fqf

@@ -10,7 +10,7 @@ pip install envpool
 
 After that, `atari_wrapper` will automatically switch to envpool's Atari env. EnvPool's implementation is much faster (about 2\~3x faster for pure execution speed, 1.5x for overall RL training pipeline) than python vectorized env implementation, and it's behavior is consistent to that approach (OpenAI wrapper), which will describe below.
 
-For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://ppo-details.cleanrl.dev/2021/11/05/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool).
+For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool).
 
 ## ALE-py
 

@@ -2,7 +2,7 @@ import numpy as np
 import torch
 
 from tianshou.exploration import GaussianNoise, OUNoise
-from tianshou.utils import MovAvg, RunningMeanStd
+from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd
 from tianshou.utils.net.common import MLP, Net
 from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
 

@@ -99,8 +99,48 @@ def test_net():
     assert list(net(data, act).shape) == [bsz, 1]
 
 
+def test_lr_schedulers():
+    initial_lr_1 = 10.0
+    step_size_1 = 1
+    gamma_1 = 0.5
+    net_1 = torch.nn.Linear(2, 3)
+    optim_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr_1)
+    sched_1 = torch.optim.lr_scheduler.StepLR(
+        optim_1, step_size=step_size_1, gamma=gamma_1
+    )
+
+    initial_lr_2 = 5.0
+    step_size_2 = 2
+    gamma_2 = 0.3
+    net_2 = torch.nn.Linear(3, 2)
+    optim_2 = torch.optim.Adam(net_2.parameters(), lr=initial_lr_2)
+    sched_2 = torch.optim.lr_scheduler.StepLR(
+        optim_2, step_size=step_size_2, gamma=gamma_2
+    )
+    schedulers = MultipleLRSchedulers(sched_1, sched_2)
+    for _ in range(10):
+        loss_1 = (torch.ones((1, 3)) - net_1(torch.ones((1, 2)))).sum()
+        optim_1.zero_grad()
+        loss_1.backward()
+        optim_1.step()
+        loss_2 = (torch.ones((1, 2)) - net_2(torch.ones((1, 3)))).sum()
+        optim_2.zero_grad()
+        loss_2.backward()
+        optim_2.step()
+        schedulers.step()
+    assert (
+        optim_1.state_dict()["param_groups"][0]["lr"] ==
+        (initial_lr_1 * gamma_1**(10 // step_size_1))
+    )
+    assert (
+        optim_2.state_dict()["param_groups"][0]["lr"] ==
+        (initial_lr_2 * gamma_2**(10 // step_size_2))
+    )
+
+
 if __name__ == '__main__':
     test_noise()
     test_moving_average()
     test_rms()
     test_net()
+    test_lr_schedulers()

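For reference, the two assertions in test_lr_schedulers encode straightforward StepLR arithmetic: each scheduler multiplies its optimizer's learning rate by gamma once every step_size calls to step(). A quick standalone check of the expected values (plain Python, not part of the commit):

# Expected learning rates after the 10 scheduler steps in the test above.
initial_lr_1, gamma_1, step_size_1 = 10.0, 0.5, 1
initial_lr_2, gamma_2, step_size_2 = 5.0, 0.3, 2
expected_lr_1 = initial_lr_1 * gamma_1 ** (10 // step_size_1)  # 10.0 * 0.5**10 = 0.009765625
expected_lr_2 = initial_lr_2 * gamma_2 ** (10 // step_size_2)  # 5.0 * 0.3**5 ~= 0.01215
print(expected_lr_1, expected_lr_2)
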
@@ -9,6 +9,7 @@ from numba import njit
 from torch import nn
 
 from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
+from tianshou.utils import MultipleLRSchedulers
 
 
 class BasePolicy(ABC, nn.Module):

@@ -64,6 +65,8 @@ class BasePolicy(ABC, nn.Module):
         action_space: Optional[gym.Space] = None,
         action_scaling: bool = False,
         action_bound_method: str = "",
+        lr_scheduler: Optional[Union[torch.optim.lr_scheduler.LambdaLR,
+                                     MultipleLRSchedulers]] = None,
     ) -> None:
         super().__init__()
         self.observation_space = observation_space

@@ -79,6 +82,7 @@ class BasePolicy(ABC, nn.Module):
         # can be one of ("clip", "tanh", ""), empty string means no bounding
         assert action_bound_method in ("", "clip", "tanh")
         self.action_bound_method = action_bound_method
+        self.lr_scheduler = lr_scheduler
         self._compile()
 
     def set_agent_id(self, agent_id: int) -> None:

@@ -272,6 +276,8 @@ class BasePolicy(ABC, nn.Module):
         batch = self.process_fn(batch, buffer, indices)
         result = self.learn(batch, **kwargs)
         self.post_process_fn(batch, buffer, indices)
+        if self.lr_scheduler is not None:
+            self.lr_scheduler.step()
         self.updating = False
         return result
 

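With this change the scheduler is stored on BasePolicy and stepped once per policy.update(), so any policy, not only the PG family, can decay its learning rate. Below is a minimal sketch of the intended usage, assuming a PPO-style linear decay over the whole run; the network, the hyper-parameter values, and the commented-out PPOPolicy call are illustrative assumptions, not part of the commit:

import numpy as np
import torch

# Linearly decay the learning rate to zero over the expected number of policy
# updates, then hand the scheduler to a policy via the new ``lr_scheduler`` argument.
net = torch.nn.Linear(4, 2)  # stand-in for an actor/critic network
optim = torch.optim.Adam(net.parameters(), lr=3e-4)

step_per_epoch, step_per_collect, epoch = 30000, 1024, 100
max_update_num = np.ceil(step_per_epoch / step_per_collect) * epoch
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optim, lr_lambda=lambda update: 1 - update / max_update_num
)

# policy = PPOPolicy(actor, critic, optim, dist_fn, lr_scheduler=lr_scheduler)
# BasePolicy.update() now calls lr_scheduler.step() once per call.
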
@@ -15,6 +15,8 @@ class ImitationPolicy(BasePolicy):
         :class:`~tianshou.policy.BasePolicy`. (s -> a)
     :param torch.optim.Optimizer optim: for optimizing the model.
     :param gym.Space action_space: env's action space.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -36,6 +36,8 @@ class BCQPolicy(BasePolicy):
     :param int num_sampled_action: the number of sampled actions in calculating
         target Q. The algorithm samples several actions using VAE, and perturbs
         each action to get the target Q. Default to 10.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -46,6 +46,8 @@ class CQLPolicy(SACPolicy):
     :param float clip_grad: clip_grad for updating critic network. Default to 1.0.
     :param Union[str, torch.device] device: which device to create this model on.
         Default to "cpu".
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -27,6 +27,8 @@ class DiscreteBCQPolicy(DQNPolicy):
         logits. Default to 1e-2.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -23,6 +23,8 @@ class DiscreteCQLPolicy(QRDQNPolicy):
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
     :param float min_q_weight: the weight for the cql loss.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
         Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed

@@ -29,6 +29,8 @@ class DiscreteCRRPolicy(PGPolicy):
         you do not use the target network). Default to 0.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
         Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed

@@ -59,6 +59,8 @@ class GAILPolicy(PPOPolicy):
         optimizer in each policy.update(). Default to None (no lr_scheduler).
     :param bool deterministic_eval: whether to use deterministic action instead of
         stochastic action sampled by the policy. Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -17,6 +17,8 @@ class ICMPolicy(BasePolicy):
     :param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
     :param float lr_scale: the scaling factor for ICM learning.
     :param float forward_loss_weight: the weight for forward model loss.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -18,6 +18,8 @@ class PSRLModel(object):
         of rewards, with shape (n_state, n_action).
     :param float discount_factor: in [0, 1].
     :param float epsilon: for precision control in value iteration.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
     """
 
     def __init__(

@@ -146,9 +146,6 @@ class A2CPolicy(PGPolicy):
                 vf_losses.append(vf_loss.item())
                 ent_losses.append(ent_loss.item())
                 losses.append(loss.item())
-        # update learning rate if lr_scheduler is given
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
 
         return {
             "loss": losses,

@@ -25,6 +25,8 @@ class C51Policy(DQNPolicy):
         you do not use the target network). Default to 0.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -32,6 +32,8 @@ class DDPGPolicy(BasePolicy):
         Default to "clip".
     :param Optional[gym.Space] action_space: env's action space, mandatory if you want
         to use option "action_scaling" or "action_bound_method". Default to None.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -28,6 +28,8 @@ class DiscreteSACPolicy(SACPolicy):
         alpha is automatically tuned.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -26,6 +26,8 @@ class DQNPolicy(BasePolicy):
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
     :param bool is_double: use double dqn. Default to True.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -27,6 +27,8 @@ class FQFPolicy(QRDQNPolicy):
         you do not use the target network).
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -26,6 +26,8 @@ class IQNPolicy(QRDQNPolicy):
         you do not use the target network).
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -127,10 +127,6 @@ class NPGPolicy(A2CPolicy):
                 vf_losses.append(vf_loss.item())
                 kls.append(kl.item())
 
-        # update learning rate if lr_scheduler is given
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
-
         return {
             "loss/actor": actor_losses,
             "loss/vf": vf_losses,

@@ -44,7 +44,6 @@ class PGPolicy(BasePolicy):
         reward_normalization: bool = False,
         action_scaling: bool = True,
         action_bound_method: str = "clip",
-        lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
         deterministic_eval: bool = False,
         **kwargs: Any,
     ) -> None:

@@ -55,7 +54,6 @@ class PGPolicy(BasePolicy):
         )
         self.actor = model
         self.optim = optim
-        self.lr_scheduler = lr_scheduler
         self.dist_fn = dist_fn
         assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
         self._gamma = discount_factor

@@ -137,8 +135,5 @@ class PGPolicy(BasePolicy):
                 loss.backward()
                 self.optim.step()
                 losses.append(loss.item())
-        # update learning rate if lr_scheduler is given
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
 
         return {"loss": losses}

@@ -152,9 +152,6 @@ class PPOPolicy(A2CPolicy):
                 vf_losses.append(vf_loss.item())
                 ent_losses.append(ent_loss.item())
                 losses.append(loss.item())
-        # update learning rate if lr_scheduler is given
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
 
         return {
             "loss": losses,

@@ -23,6 +23,8 @@ class QRDQNPolicy(DQNPolicy):
         you do not use the target network).
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -23,6 +23,8 @@ class RainbowPolicy(C51Policy):
         you do not use the target network). Default to 0.
     :param bool reward_normalization: normalize the reward to Normal(0, 1).
         Default to False.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -42,6 +42,8 @@ class SACPolicy(DDPGPolicy):
         Default to "clip".
     :param Optional[gym.Space] action_space: env's action space, mandatory if you want
         to use option "action_scaling" or "action_bound_method". Default to None.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -40,6 +40,8 @@ class TD3Policy(DDPGPolicy):
         Default to "clip".
     :param Optional[gym.Space] action_space: env's action space, mandatory if you want
         to use option "action_scaling" or "action_bound_method". Default to None.
+    :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
+        optimizer in each policy.update(). Default to None (no lr_scheduler).
 
     .. seealso::
 

@@ -146,10 +146,6 @@ class TRPOPolicy(NPGPolicy):
                 step_sizes.append(step_size.item())
                 kls.append(kl.item())
 
-        # update learning rate if lr_scheduler is given
-        if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
-
         return {
             "loss/actor": actor_losses,
             "loss/vf": vf_losses,

@@ -4,6 +4,7 @@ from tianshou.utils.config import tqdm_config
 from tianshou.utils.logger.base import BaseLogger, LazyLogger
 from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
 from tianshou.utils.logger.wandb import WandbLogger
+from tianshou.utils.lr_scheduler import MultipleLRSchedulers
 from tianshou.utils.statistics import MovAvg, RunningMeanStd
 from tianshou.utils.warning import deprecation
 

@@ -17,4 +18,5 @@ __all__ = [
     "LazyLogger",
     "WandbLogger",
     "deprecation",
+    "MultipleLRSchedulers",
 ]

tianshou/utils/lr_scheduler.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+from typing import Dict, List
+
+import torch
+
+
+class MultipleLRSchedulers:
+    """A wrapper for multiple learning rate schedulers.
+
+    Every time :meth:`~tianshou.utils.MultipleLRSchedulers.step` is called,
+    it calls the step() method of each of the schedulers that it contains.
+    Example usage:
+    ::
+
+        scheduler1 = ConstantLR(opt1, factor=0.1, total_iters=2)
+        scheduler2 = ExponentialLR(opt2, gamma=0.9)
+        scheduler = MultipleLRSchedulers(scheduler1, scheduler2)
+        policy = PPOPolicy(..., lr_scheduler=scheduler)
+    """
+
+    def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR):
+        self.schedulers = args
+
+    def step(self) -> None:
+        """Take a step in each of the learning rate schedulers."""
+        for scheduler in self.schedulers:
+            scheduler.step()
+
+    def state_dict(self) -> List[Dict]:
+        """Get state_dict for each of the learning rate schedulers.
+
+        :return: A list of state_dict of learning rate schedulers.
+        """
+        return [s.state_dict() for s in self.schedulers]
+
+    def load_state_dict(self, state_dict: List[Dict]) -> None:
+        """Load states from state_dict.
+
+        :param List[Dict] state_dict: A list of learning rate scheduler
+            state_dict, in the same order as the schedulers.
+        """
+        for (s, sd) in zip(self.schedulers, state_dict):
+            s.__dict__.update(sd)
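Because MultipleLRSchedulers exposes state_dict() and load_state_dict(), it can be checkpointed and restored together with the policy. A minimal sketch of that pattern; the toy networks, optimizers, and file name are assumptions used only for illustration:

import torch
from tianshou.utils import MultipleLRSchedulers

# Two schedulers driving two separate optimizers, stepped together.
net_a, net_b = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
opt_a = torch.optim.Adam(net_a.parameters(), lr=1e-3)
opt_b = torch.optim.Adam(net_b.parameters(), lr=1e-3)
sched = MultipleLRSchedulers(
    torch.optim.lr_scheduler.StepLR(opt_a, step_size=5, gamma=0.1),
    torch.optim.lr_scheduler.ExponentialLR(opt_b, gamma=0.9),
)

for _ in range(3):
    sched.step()  # advances both wrapped schedulers

# Save and later restore the schedulers' progress.
torch.save({"lr_scheduler": sched.state_dict()}, "checkpoint.pth")
resumed = MultipleLRSchedulers(
    torch.optim.lr_scheduler.StepLR(opt_a, step_size=5, gamma=0.1),
    torch.optim.lr_scheduler.ExponentialLR(opt_b, gamma=0.9),
)
resumed.load_state_dict(torch.load("checkpoint.pth")["lr_scheduler"])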