BatchPolicy: check that self.is_within_training_step
is True on update
This commit is contained in:
parent
c35be8d07e
commit
6927eadaa7
@ -3,7 +3,7 @@ import time
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast, Dict
|
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -214,13 +214,13 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.observation_space = observation_space
|
self.observation_space = observation_space
|
||||||
self.action_space = action_space
|
self.action_space = action_space
|
||||||
self._action_type: Literal["discrete", "continuous"]
|
|
||||||
if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
|
if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
|
||||||
self._action_type = "discrete"
|
action_type = "discrete"
|
||||||
elif isinstance(action_space, Box):
|
elif isinstance(action_space, Box):
|
||||||
self._action_type = "continuous"
|
action_type = "continuous"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported action space: {action_space}.")
|
raise ValueError(f"Unsupported action space: {action_space}.")
|
||||||
|
self._action_type = cast(Literal["discrete", "continuous"], action_type)
|
||||||
self.agent_id = 0
|
self.agent_id = 0
|
||||||
self.updating = False
|
self.updating = False
|
||||||
self.action_scaling = action_scaling
|
self.action_scaling = action_scaling
|
||||||
@ -228,19 +228,22 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
|||||||
self.lr_scheduler = lr_scheduler
|
self.lr_scheduler = lr_scheduler
|
||||||
self.is_within_training_step = False
|
self.is_within_training_step = False
|
||||||
"""
|
"""
|
||||||
flag indicating whether we are currently within a training step, which encompasses data collection
|
flag indicating whether we are currently within a training step,
|
||||||
for training and the policy update (gradient steps).
|
which encompasses data collection for training (in online RL algorithms)
|
||||||
|
and the policy update (gradient steps).
|
||||||
|
|
||||||
It can be used, for example, to control whether a flag controlling deterministic evaluation should
|
It can be used, for example, to control whether a flag controlling deterministic evaluation should
|
||||||
indeed be applied, because within a training step, we typically always want to apply stochastic evaluation
|
indeed be applied, because within a training step, we typically always want to apply stochastic evaluation
|
||||||
(even if such a flag is enabled).
|
(even if such a flag is enabled), as well as stochastic action computation for q-targets (e.g. in
|
||||||
|
SAC-based algorithms).
|
||||||
|
|
||||||
This flag should normally remain False and should be set to True only by the algorithm which performs
|
This flag should normally remain False and should be set to True only by the algorithm which performs
|
||||||
training steps.
|
training steps. This is done automatically by the Trainer classes. If a policy is used outside of a Trainer,
|
||||||
|
the user should ensure that this flag is set correctly before calling update or learn.
|
||||||
"""
|
"""
|
||||||
self._compile()
|
self._compile()
|
||||||
|
|
||||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
def __setstate__(self, state: dict[str, Any]) -> None:
|
||||||
# TODO Use setstate function once merged
|
# TODO Use setstate function once merged
|
||||||
if "is_within_training_step" not in state:
|
if "is_within_training_step" not in state:
|
||||||
state["is_within_training_step"] = False
|
state["is_within_training_step"] = False
|
||||||
@ -524,6 +527,14 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
|||||||
"""
|
"""
|
||||||
# TODO: when does this happen?
|
# TODO: when does this happen?
|
||||||
# -> this never happens in practice, as update is either called with a collector buffer or preceded by an assert
|
# -> this never happens in practice, as update is either called with a collector buffer or preceded by an assert
|
||||||
|
|
||||||
|
if not self.is_within_training_step:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"update() was called outside of a training step as signalled by {self.is_within_training_step=} "
|
||||||
|
f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned "
|
||||||
|
f"flag yourself.",
|
||||||
|
)
|
||||||
|
|
||||||
if buffer is None:
|
if buffer is None:
|
||||||
return TrainingStats() # type: ignore[return-value]
|
return TrainingStats() # type: ignore[return-value]
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user