Establish backward compatibility by implementing __setstate__
This commit is contained in:
parent
ca69e79b4a
commit
c35be8d07e
@ -3,7 +3,7 @@ import time
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
|
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast, Dict
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -240,6 +240,12 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
|||||||
"""
|
"""
|
||||||
self._compile()
|
self._compile()
|
||||||
|
|
||||||
|
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||||
|
# TODO Use setstate function once merged
|
||||||
|
if "is_within_training_step" not in state:
|
||||||
|
state["is_within_training_step"] = False
|
||||||
|
self.__dict__ = state
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def action_type(self) -> Literal["discrete", "continuous"]:
|
def action_type(self) -> Literal["discrete", "continuous"]:
|
||||||
return self._action_type
|
return self._action_type
|
||||||
|
Loading…
x
Reference in New Issue
Block a user