Log Environments
This commit is contained in:
parent
a8ea6808c3
commit
73a6d15eee
@ -8,6 +8,7 @@ import gymnasium as gym
|
|||||||
from tianshou.env import BaseVectorEnv
|
from tianshou.env import BaseVectorEnv
|
||||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||||
from tianshou.utils.net.common import TActionShape
|
from tianshou.utils.net.common import TActionShape
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
TObservationShape: TypeAlias = int | Sequence[int]
|
TObservationShape: TypeAlias = int | Sequence[int]
|
||||||
|
|
||||||
@ -31,12 +32,18 @@ class EnvType(Enum):
|
|||||||
raise AssertionError(f"{requiring_entity} requires discrete environments")
|
raise AssertionError(f"{requiring_entity} requires discrete environments")
|
||||||
|
|
||||||
|
|
||||||
class Environments(ABC):
|
class Environments(ToStringMixin, ABC):
|
||||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||||
self.env = env
|
self.env = env
|
||||||
self.train_envs = train_envs
|
self.train_envs = train_envs
|
||||||
self.test_envs = test_envs
|
self.test_envs = test_envs
|
||||||
|
|
||||||
|
def _tostring_includes(self) -> list[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _tostring_additional_entries(self) -> dict[str, Any]:
|
||||||
|
return self.info()
|
||||||
|
|
||||||
def info(self) -> dict[str, Any]:
|
def info(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"action_shape": self.get_action_shape(),
|
"action_shape": self.get_action_shape(),
|
||||||
|
@ -102,6 +102,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|||||||
def run(self, log_name: str) -> None:
|
def run(self, log_name: str) -> None:
|
||||||
self._set_seed()
|
self._set_seed()
|
||||||
envs = self.env_factory(self.env_config)
|
envs = self.env_factory(self.env_config)
|
||||||
|
log.info(f"Created {envs}")
|
||||||
|
|
||||||
full_config = self._build_config_dict()
|
full_config = self._build_config_dict()
|
||||||
full_config.update(envs.info())
|
full_config.update(envs.info())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user