BatchPolicy: check that self.is_within_training_step is True on update

Michael Panchenko 2024-05-05 15:14:59 +02:00
parent c35be8d07e
commit 6927eadaa7


@@ -3,7 +3,7 @@ import time
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass, field
-from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast, Dict
+from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
 
 import gymnasium as gym
 import numpy as np
@@ -214,13 +214,13 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
-        self._action_type: Literal["discrete", "continuous"]
         if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
-            self._action_type = "discrete"
+            action_type = "discrete"
         elif isinstance(action_space, Box):
-            self._action_type = "continuous"
+            action_type = "continuous"
         else:
             raise ValueError(f"Unsupported action space: {action_space}.")
+        self._action_type = cast(Literal["discrete", "continuous"], action_type)
         self.agent_id = 0
         self.updating = False
         self.action_scaling = action_scaling
@@ -228,19 +228,22 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         self.lr_scheduler = lr_scheduler
         self.is_within_training_step = False
         """
-        flag indicating whether we are currently within a training step, which encompasses data collection
-        for training and the policy update (gradient steps).
+        flag indicating whether we are currently within a training step,
+        which encompasses data collection for training (in online RL algorithms)
+        and the policy update (gradient steps).
 
         It can be used, for example, to control whether a flag controlling deterministic evaluation should
         indeed be applied, because within a training step, we typically always want to apply stochastic evaluation
-        (even if such a flag is enabled).
+        (even if such a flag is enabled), as well as stochastic action computation for q-targets (e.g. in SAC
+        based algorithms).
 
         This flag should normally remain False and should be set to True only by the algorithm which performs
-        training steps.
+        training steps. This is done automatically by the Trainer classes. If a policy is used outside of a Trainer,
+        the user should ensure that this flag is set correctly before calling update or learn.
         """
         self._compile()
 
-    def __setstate__(self, state: Dict[str, Any]) -> None:
+    def __setstate__(self, state: dict[str, Any]) -> None:
         # TODO Use setstate function once merged
         if "is_within_training_step" not in state:
             state["is_within_training_step"] = False
@@ -524,6 +527,14 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         """
         # TODO: when does this happen?
         # -> this happens never in practice as update is either called with a collector buffer or an assert before
+        if not self.is_within_training_step:
+            raise RuntimeError(
+                f"update() was called outside of a training step as signalled by {self.is_within_training_step=} "
+                f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned "
+                f"flag yourself.",
+            )
+
         if buffer is None:
             return TrainingStats()  # type: ignore[return-value]
         start_time = time.time()
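
With this check in place, calling update() while the flag is still False now fails fast instead of silently proceeding. A short illustration of the new behaviour (again assuming a constructed policy and a filled buffer as in the sketch above; not part of the commit):

policy.is_within_training_step = False
try:
    policy.update(64, buffer)
except RuntimeError as err:
    # the message points to is_within_training_step and explains that the flag
    # must be managed manually when no Trainer is used
    print(err)
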