diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 498a333..2703de7 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -3,7 +3,7 @@ import time from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass, field -from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast +from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast, Dict import gymnasium as gym import numpy as np @@ -240,6 +240,12 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): """ self._compile() + def __setstate__(self, state: Dict[str, Any]) -> None: + # TODO Use setstate function once merged + if "is_within_training_step" not in state: + state["is_within_training_step"] = False + self.__dict__ = state + @property def action_type(self) -> Literal["discrete", "continuous"]: return self._action_type