2023-10-11 19:31:26 +02:00
|
|
|
import os
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from tianshou.data import Collector
|
|
|
|
from tianshou.highlevel.env import Environments
|
2023-10-12 17:40:16 +02:00
|
|
|
from tianshou.highlevel.logger import TLogger
|
2023-10-11 19:31:26 +02:00
|
|
|
from tianshou.policy import BasePolicy
|
|
|
|
from tianshou.trainer import BaseTrainer
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
class World:
|
|
|
|
envs: "Environments"
|
|
|
|
policy: "BasePolicy"
|
|
|
|
train_collector: "Collector"
|
|
|
|
test_collector: "Collector"
|
2023-10-12 17:40:16 +02:00
|
|
|
logger: "TLogger"
|
|
|
|
persist_directory: str
|
2023-10-13 12:25:28 +02:00
|
|
|
restore_directory: str | None
|
2023-10-11 19:31:26 +02:00
|
|
|
trainer: Optional["BaseTrainer"] = None
|
|
|
|
|
2023-10-12 15:01:49 +02:00
|
|
|
def persist_path(self, filename: str) -> str:
|
|
|
|
return os.path.join(self.persist_directory, filename)
|
|
|
|
|
|
|
|
def restore_path(self, filename: str) -> str:
|
|
|
|
return os.path.join(self.restore_directory, filename)
|