Log Environments

This commit is contained in:
Dominik Jain 2023-10-10 13:26:07 +02:00
parent a8ea6808c3
commit 73a6d15eee
2 changed files with 9 additions and 1 deletions

View File

@ -8,6 +8,7 @@ import gymnasium as gym
from tianshou.env import BaseVectorEnv from tianshou.env import BaseVectorEnv
from tianshou.highlevel.persistence import PersistableConfigProtocol from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.utils.net.common import TActionShape from tianshou.utils.net.common import TActionShape
from tianshou.utils.string import ToStringMixin
TObservationShape: TypeAlias = int | Sequence[int] TObservationShape: TypeAlias = int | Sequence[int]
@ -31,12 +32,18 @@ class EnvType(Enum):
raise AssertionError(f"{requiring_entity} requires discrete environments") raise AssertionError(f"{requiring_entity} requires discrete environments")
class Environments(ABC): class Environments(ToStringMixin, ABC):
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
self.env = env self.env = env
self.train_envs = train_envs self.train_envs = train_envs
self.test_envs = test_envs self.test_envs = test_envs
def _tostring_includes(self) -> list[str]:
return []
def _tostring_additional_entries(self) -> dict[str, Any]:
return self.info()
def info(self) -> dict[str, Any]: def info(self) -> dict[str, Any]:
return { return {
"action_shape": self.get_action_shape(), "action_shape": self.get_action_shape(),

View File

@ -102,6 +102,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
def run(self, log_name: str) -> None: def run(self, log_name: str) -> None:
self._set_seed() self._set_seed()
envs = self.env_factory(self.env_config) envs = self.env_factory(self.env_config)
log.info(f"Created {envs}")
full_config = self._build_config_dict() full_config = self._build_config_dict()
full_config.update(envs.info()) full_config.update(envs.info())