diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 2703de7..8753af6 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -3,7 +3,7 @@ import time
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass, field
-from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast, Dict
+from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
 
 import gymnasium as gym
 import numpy as np
@@ -214,13 +214,13 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         super().__init__()
         self.observation_space = observation_space
         self.action_space = action_space
-        self._action_type: Literal["discrete", "continuous"]
         if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
-            self._action_type = "discrete"
+            action_type = "discrete"
         elif isinstance(action_space, Box):
-            self._action_type = "continuous"
+            action_type = "continuous"
         else:
             raise ValueError(f"Unsupported action space: {action_space}.")
+        self._action_type = cast(Literal["discrete", "continuous"], action_type)
         self.agent_id = 0
         self.updating = False
         self.action_scaling = action_scaling
@@ -228,19 +228,22 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         self.lr_scheduler = lr_scheduler
         self.is_within_training_step = False
         """
-        flag indicating whether we are currently within a training step, which encompasses data collection
-        for training and the policy update (gradient steps).
-
-        It can be used, for example, to control whether a flag controlling deterministic evaluation should
+        flag indicating whether we are currently within a training step,
+        which encompasses data collection for training (in online RL algorithms)
+        and the policy update (gradient steps).
+
+        It can be used, for example, to control whether a flag controlling deterministic evaluation should
         indeed be applied, because within a training step, we typically always want to apply stochastic evaluation
-        (even if such a flag is enabled).
-
+        (even if such a flag is enabled), as well as stochastic action computation for q-targets (e.g. in SAC
+        based algorithms).
+
         This flag should normally remain False and should be set to True only by the algorithm which performs
-        training steps.
+        training steps. This is done automatically by the Trainer classes. If a policy is used outside of a Trainer,
+        the user should ensure that this flag is set correctly before calling update or learn.
         """
         self._compile()
 
-    def __setstate__(self, state: Dict[str, Any]) -> None:
+    def __setstate__(self, state: dict[str, Any]) -> None:
         # TODO Use setstate function once merged
         if "is_within_training_step" not in state:
             state["is_within_training_step"] = False
@@ -524,6 +527,14 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         """
         # TODO: when does this happen?
        # -> this happens never in practice as update is either called with a collector buffer or an assert before
+
+        if not self.is_within_training_step:
+            raise RuntimeError(
+                f"update() was called outside of a training step as signalled by {self.is_within_training_step=} "
+                f"If you want to update the policy without a Trainer, you will have to manage the above-mentioned "
+                f"flag yourself.",
+            )
+
         if buffer is None:
             return TrainingStats()  # type: ignore[return-value]
         start_time = time.time()
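
For reference, a minimal usage sketch of what the updated docstring asks of users who call update() without a Trainer. This snippet is illustrative and not part of the diff; `policy` stands for any concrete tianshou policy and `buffer` for a ReplayBuffer assumed to have been filled elsewhere (e.g. by a Collector).

    # Manually signal that we are inside a training step; otherwise the new
    # check in update() raises a RuntimeError.
    policy.is_within_training_step = True
    try:
        # sample_size=64 is an arbitrary example value
        stats = policy.update(sample_size=64, buffer=buffer)
    finally:
        # restore the flag so that subsequent evaluation behaves normally
        policy.is_within_training_step = False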