diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index c48fa90..8f22236 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -8,6 +8,7 @@ import gymnasium as gym from tianshou.env import BaseVectorEnv from tianshou.highlevel.persistence import PersistableConfigProtocol from tianshou.utils.net.common import TActionShape +from tianshou.utils.string import ToStringMixin TObservationShape: TypeAlias = int | Sequence[int] @@ -31,12 +32,18 @@ class EnvType(Enum): raise AssertionError(f"{requiring_entity} requires discrete environments") -class Environments(ABC): +class Environments(ToStringMixin, ABC): def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): self.env = env self.train_envs = train_envs self.test_envs = test_envs + def _tostring_includes(self) -> list[str]: + return [] + + def _tostring_additional_entries(self) -> dict[str, Any]: + return self.info() + def info(self) -> dict[str, Any]: return { "action_shape": self.get_action_shape(), diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index b097157..77e94ec 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -102,6 +102,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): def run(self, log_name: str) -> None: self._set_seed() envs = self.env_factory(self.env_config) + log.info(f"Created {envs}") full_config = self._build_config_dict() full_config.update(envs.info())