Log Environments
This commit is contained in:
parent
a8ea6808c3
commit
73a6d15eee
@ -8,6 +8,7 @@ import gymnasium as gym
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||
from tianshou.utils.net.common import TActionShape
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
TObservationShape: TypeAlias = int | Sequence[int]
|
||||
|
||||
@ -31,12 +32,18 @@ class EnvType(Enum):
|
||||
raise AssertionError(f"{requiring_entity} requires discrete environments")
|
||||
|
||||
|
||||
class Environments(ABC):
|
||||
class Environments(ToStringMixin, ABC):
|
||||
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
|
||||
self.env = env
|
||||
self.train_envs = train_envs
|
||||
self.test_envs = test_envs
|
||||
|
||||
def _tostring_includes(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def _tostring_additional_entries(self) -> dict[str, Any]:
|
||||
return self.info()
|
||||
|
||||
def info(self) -> dict[str, Any]:
|
||||
return {
|
||||
"action_shape": self.get_action_shape(),
|
||||
|
@ -102,6 +102,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||
def run(self, log_name: str) -> None:
|
||||
self._set_seed()
|
||||
envs = self.env_factory(self.env_config)
|
||||
log.info(f"Created {envs}")
|
||||
|
||||
full_config = self._build_config_dict()
|
||||
full_config.update(envs.info())
|
||||
|
Loading…
x
Reference in New Issue
Block a user