Log Environments

This commit is contained in:
Dominik Jain 2023-10-10 13:26:07 +02:00
parent a8ea6808c3
commit 73a6d15eee
2 changed files with 9 additions and 1 deletions

View File

@ -8,6 +8,7 @@ import gymnasium as gym
from tianshou.env import BaseVectorEnv
from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.utils.net.common import TActionShape
from tianshou.utils.string import ToStringMixin
TObservationShape: TypeAlias = int | Sequence[int]
@ -31,12 +32,18 @@ class EnvType(Enum):
raise AssertionError(f"{requiring_entity} requires discrete environments")
class Environments(ABC):
class Environments(ToStringMixin, ABC):
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
self.env = env
self.train_envs = train_envs
self.test_envs = test_envs
def _tostring_includes(self) -> list[str]:
return []
def _tostring_additional_entries(self) -> dict[str, Any]:
return self.info()
def info(self) -> dict[str, Any]:
return {
"action_shape": self.get_action_shape(),

View File

@ -102,6 +102,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
def run(self, log_name: str) -> None:
self._set_seed()
envs = self.env_factory(self.env_config)
log.info(f"Created {envs}")
full_config = self._build_config_dict()
full_config.update(envs.info())