from pprint import pprint
from typing import Generic, TypeVar

import numpy as np
import torch

from tianshou.config import (
    BasicExperimentConfig,
)
from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import LoggerFactory
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer

TPolicy = TypeVar("TPolicy", bound=BasePolicy)
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)


class RLExperiment(Generic[TPolicy, TTrainer]):
    """Orchestrates a full RL experiment: seeding, environment creation,
    policy construction (with optional checkpoint resumption), training,
    and a final evaluation ("watch") phase."""

    def __init__(
        self,
        config: BasicExperimentConfig,
        env_factory: EnvFactory,
        logger_factory: LoggerFactory,
        agent_factory: AgentFactory,
    ):
        self.config = config
        self.env_factory = env_factory
        self.logger_factory = logger_factory
        self.agent_factory = agent_factory

    def _set_seed(self) -> None:
        # Seed the global numpy and torch RNGs for reproducibility.
        seed = self.config.seed
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _build_config_dict(self) -> dict:
        return {
            # TODO
        }

    def run(self, log_name: str) -> None:
        self._set_seed()

        envs = self.env_factory.create_envs()

        # Merge the experiment configuration with environment metadata
        # before handing it to the logger.
        full_config = self._build_config_dict()
        full_config.update(envs.info())

        run_id = self.config.resume_id
        logger = self.logger_factory.create_logger(
            log_name=log_name,
            run_id=run_id,
            config_dict=full_config,
        )

        policy = self.agent_factory.create_policy(envs, self.config.device)
        if self.config.resume_path:
            # Restore policy state from a previous run.
            self.agent_factory.load_checkpoint(
                policy,
                self.config.resume_path,
                envs,
                self.config.device,
            )

        train_collector, test_collector = self.agent_factory.create_train_test_collector(
            policy,
            envs,
        )

        # In watch-only mode, skip training and go straight to evaluation.
        if not self.config.watch:
            trainer = self.agent_factory.create_trainer(
                policy,
                train_collector,
                test_collector,
                envs,
                logger,
            )
            result = trainer.run()
            pprint(result)  # TODO logging

        self._watch_agent(
            self.config.watch_num_episodes,
            policy,
            test_collector,
            self.config.render,
        )

    @staticmethod
    def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render) -> None:
        # Put the policy in evaluation mode and roll out the requested
        # number of episodes, reporting mean reward and episode length.
        policy.eval()
        test_collector.reset()
        result = test_collector.collect(n_episode=num_episodes, render=render)
        print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
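
# Usage sketch (illustrative only): the concrete factory instances and the
# exact BasicExperimentConfig fields referenced below are assumptions about
# the surrounding package, not defined in this module.
#
#     config = BasicExperimentConfig(seed=42, device="cuda", watch=False)
#     experiment = RLExperiment(
#         config=config,
#         env_factory=my_env_factory,        # hypothetical EnvFactory instance
#         logger_factory=my_logger_factory,  # hypothetical LoggerFactory instance
#         agent_factory=my_agent_factory,    # hypothetical AgentFactory instance
#     )
#     experiment.run(log_name="experiment-1")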