Establish backward compatibility by implementing __setstate__

This commit is contained in:
Dominik Jain 2024-05-02 18:47:42 +02:00
parent ca69e79b4a
commit c35be8d07e

View File

@ -3,7 +3,7 @@ import time
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast, Dict
import gymnasium as gym
import numpy as np
@ -240,6 +240,12 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
"""
self._compile()
def __setstate__(self, state: Dict[str, Any]) -> None:
# TODO Use setstate function once merged
if "is_within_training_step" not in state:
state["is_within_training_step"] = False
self.__dict__ = state
@property
def action_type(self) -> Literal["discrete", "continuous"]:
return self._action_type