BasePolicy: check that self.is_within_training_step is True on update
This commit is contained in:
parent c35be8d07e
commit 6927eadaa7
@@ -3,7 +3,7 @@ import time
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass, field
-from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast, Dict
+from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
 
 import gymnasium as gym
 import numpy as np
@@ -214,13 +214,13 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
-        self._action_type: Literal["discrete", "continuous"]
         if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
-            self._action_type = "discrete"
+            action_type = "discrete"
         elif isinstance(action_space, Box):
-            self._action_type = "continuous"
+            action_type = "continuous"
         else:
             raise ValueError(f"Unsupported action space: {action_space}.")
+        self._action_type = cast(Literal["discrete", "continuous"], action_type)
         self.agent_id = 0
         self.updating = False
         self.action_scaling = action_scaling
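For illustration, here is a standalone sketch of the refactored branch above; resolve_action_type is a hypothetical helper written for this note, not part of the commit. It shows why the cast is needed: the local variable is inferred as plain str by the type checker, so the final cast restores the Literal type that the attribute annotation used to carry.

    from typing import Literal, cast

    from gymnasium.spaces import Box, Discrete, MultiBinary, MultiDiscrete, Space

    def resolve_action_type(space: Space) -> Literal["discrete", "continuous"]:
        # mirror the branch logic from BasePolicy.__init__
        if isinstance(space, Discrete | MultiDiscrete | MultiBinary):
            action_type = "discrete"
        elif isinstance(space, Box):
            action_type = "continuous"
        else:
            raise ValueError(f"Unsupported action space: {space}.")
        # the local is inferred as plain str, so a cast recovers the Literal type
        return cast(Literal["discrete", "continuous"], action_type)

    assert resolve_action_type(Discrete(4)) == "discrete"
    assert resolve_action_type(Box(low=-1.0, high=1.0, shape=(2,))) == "continuous"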
@@ -228,19 +228,22 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         self.lr_scheduler = lr_scheduler
         self.is_within_training_step = False
         """
-        flag indicating whether we are currently within a training step, which encompasses data collection
-        for training and the policy update (gradient steps).
-
-        It can be used, for example, to control whether a flag controlling deterministic evaluation should
+        flag indicating whether we are currently within a training step,
+        which encompasses data collection for training (in online RL algorithms)
+        and the policy update (gradient steps).
+
+        It can be used, for example, to control whether a flag controlling deterministic evaluation should
         indeed be applied, because within a training step, we typically always want to apply stochastic evaluation
-        (even if such a flag is enabled).
-
+        (even if such a flag is enabled), as well as stochastic action computation for q-targets (e.g. in SAC
+        based algorithms).
+
         This flag should normally remain False and should be set to True only by the algorithm which performs
-        training steps.
+        training steps. This is done automatically by the Trainer classes. If a policy is used outside of a Trainer,
+        the user should ensure that this flag is set correctly before calling update or learn.
         """
         self._compile()
 
-    def __setstate__(self, state: Dict[str, Any]) -> None:
+    def __setstate__(self, state: dict[str, Any]) -> None:
         # TODO Use setstate function once merged
         if "is_within_training_step" not in state:
             state["is_within_training_step"] = False
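The last paragraph of the revised docstring implies the following usage pattern when no Trainer is involved. This is a hedged sketch, not code from the commit: it assumes `policy` is an instance of a concrete BasePolicy subclass and `buffer` is a ReplayBuffer that has already collected transitions, neither of which is defined in this diff.

    # sketch only: `policy` and `buffer` are assumed to exist already
    policy.is_within_training_step = True  # signal that a training step is running
    try:
        train_stats = policy.update(sample_size=64, buffer=buffer)
    finally:
        # restore the default so evaluation-time calls keep their usual semantics
        policy.is_within_training_step = False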
@@ -524,6 +527,14 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         """
         # TODO: when does this happen?
+        # -> this happens never in practice as update is either called with a collector buffer or an assert before
 
+        if not self.is_within_training_step:
+            raise RuntimeError(
+                f"update() was called outside of a training step as signalled by {self.is_within_training_step=} "
+                f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned "
+                f"flag yourself.",
+            )
+
         if buffer is None:
             return TrainingStats()  # type: ignore[return-value]
         start_time = time.time()
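With this guard in place, calling update() while the flag is still at its default False now fails immediately instead of silently performing an update outside a training step. A hedged sketch of the new failure mode, with the same assumed `policy` and `buffer` placeholders as above:

    try:
        policy.update(sample_size=64, buffer=buffer)  # flag left at False
    except RuntimeError as err:
        # the message names is_within_training_step and how to manage it manually
        print(err)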