BasePolicy: check that self.is_within_training_step is True on update

Michael Panchenko 2024-05-05 15:14:59 +02:00
parent c35be8d07e
commit 6927eadaa7


@@ -3,7 +3,7 @@ import time
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass, field
-from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast, Dict
+from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast

 import gymnasium as gym
 import numpy as np
@@ -214,13 +214,13 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
-        self._action_type: Literal["discrete", "continuous"]
         if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
-            self._action_type = "discrete"
+            action_type = "discrete"
         elif isinstance(action_space, Box):
-            self._action_type = "continuous"
+            action_type = "continuous"
         else:
             raise ValueError(f"Unsupported action space: {action_space}.")
+        self._action_type = cast(Literal["discrete", "continuous"], action_type)
         self.agent_id = 0
         self.updating = False
         self.action_scaling = action_scaling
@@ -228,19 +228,22 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         self.lr_scheduler = lr_scheduler

         self.is_within_training_step = False
         """
-        flag indicating whether we are currently within a training step, which encompasses data collection
-        for training and the policy update (gradient steps).
-        It can be used, for example, to control whether a flag controlling deterministic evaluation should
+        flag indicating whether we are currently within a training step,
+        which encompasses data collection for training (in online RL algorithms)
+        and the policy update (gradient steps).
+        It can be used, for example, to control whether a flag controlling deterministic evaluation should
         indeed be applied, because within a training step, we typically always want to apply stochastic evaluation
-        (even if such a flag is enabled).
+        (even if such a flag is enabled), as well as stochastic action computation for q-targets (e.g. in SAC
+        based algorithms).
         This flag should normally remain False and should be set to True only by the algorithm which performs
-        training steps.
+        training steps. This is done automatically by the Trainer classes. If a policy is used outside of a Trainer,
+        the user should ensure that this flag is set correctly before calling update or learn.
         """
         self._compile()

-    def __setstate__(self, state: Dict[str, Any]) -> None:
+    def __setstate__(self, state: dict[str, Any]) -> None:
         # TODO Use setstate function once merged
         if "is_within_training_step" not in state:
             state["is_within_training_step"] = False
@@ -524,6 +527,14 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         """
         # TODO: when does this happen?
         # -> this happens never in practice as update is either called with a collector buffer or an assert before
+        if not self.is_within_training_step:
+            raise RuntimeError(
+                f"update() was called outside of a training step as signalled by {self.is_within_training_step=} "
+                f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned "
+                f"flag yourself.",
+            )
         if buffer is None:
             return TrainingStats()  # type: ignore[return-value]
         start_time = time.time()
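Note (illustrative, not part of the commit): with the guard above in place, calling update() while the flag is still at its default of False now fails fast instead of silently training with evaluation-time behaviour. A hedged sketch of that failure path, under the same `policy`/`buffer` assumptions as in the previous note:

# Illustrative only: the flag was never set, so the new check raises immediately.
policy.is_within_training_step = False
try:
    policy.update(sample_size=64, buffer=buffer)
except RuntimeError as err:
    print(err)  # "update() was called outside of a training step as signalled by ..."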