From 73a6d15eeeaa14865fa9a90302c1d62d4a6ff580 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 10 Oct 2023 13:26:07 +0200 Subject: [PATCH] Log Environments --- tianshou/highlevel/env.py | 9 ++++++++- tianshou/highlevel/experiment.py | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index c48fa90..8f22236 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -8,6 +8,7 @@ import gymnasium as gym from tianshou.env import BaseVectorEnv from tianshou.highlevel.persistence import PersistableConfigProtocol from tianshou.utils.net.common import TActionShape +from tianshou.utils.string import ToStringMixin TObservationShape: TypeAlias = int | Sequence[int] @@ -31,12 +32,18 @@ class EnvType(Enum): raise AssertionError(f"{requiring_entity} requires discrete environments") -class Environments(ABC): +class Environments(ToStringMixin, ABC): def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): self.env = env self.train_envs = train_envs self.test_envs = test_envs + def _tostring_includes(self) -> list[str]: + return [] + + def _tostring_additional_entries(self) -> dict[str, Any]: + return self.info() + def info(self) -> dict[str, Any]: return { "action_shape": self.get_action_shape(), diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index b097157..77e94ec 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -102,6 +102,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): def run(self, log_name: str) -> None: self._set_seed() envs = self.env_factory(self.env_config) + log.info(f"Created {envs}") full_config = self._build_config_dict() full_config.update(envs.info())