Establish backward compatibility by implementing __setstate__
This commit is contained in:
parent
ca69e79b4a
commit
c35be8d07e
@ -3,7 +3,7 @@ import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
|
||||
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast, Dict
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
@ -240,6 +240,12 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
||||
"""
|
||||
self._compile()
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||
# TODO Use setstate function once merged
|
||||
if "is_within_training_step" not in state:
|
||||
state["is_within_training_step"] = False
|
||||
self.__dict__ = state
|
||||
|
||||
@property
|
||||
def action_type(self) -> Literal["discrete", "continuous"]:
|
||||
return self._action_type
|
||||
|
Loading…
x
Reference in New Issue
Block a user