from pprint import pprint
from typing import Generic, TypeVar

import numpy as np
import torch

from tianshou.config import BasicExperimentConfig, LoggerConfig, RLAgentConfig, RLSamplingConfig
from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import LoggerFactory
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer

TPolicy = TypeVar("TPolicy", bound=BasePolicy)
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)


class RLExperiment(Generic[TPolicy, TTrainer]):
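    """Ties together the configuration objects and factories for a reinforcement
    learning experiment and runs the full pipeline: seeding, environment creation,
    logging, policy creation (or restoration from a checkpoint), training, and a
    final watch/evaluation phase.
    """
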
    def __init__(self,
                 config: BasicExperimentConfig,
                 logger_config: LoggerConfig,
                 general_config: RLAgentConfig,
                 sampling_config: RLSamplingConfig,
                 env_factory: EnvFactory,
                 logger_factory: LoggerFactory,
                 agent_factory: AgentFactory):
        self.config = config
        self.logger_config = logger_config
        self.general_config = general_config
        self.sampling_config = sampling_config
        self.env_factory = env_factory
        self.logger_factory = logger_factory
        self.agent_factory = agent_factory

    def _set_seed(self):
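        """Seed the numpy and torch global RNGs from ``config.seed``."""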
        seed = self.config.seed
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _build_config_dict(self) -> dict:
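        """Build a dict describing the experiment configuration, which is passed
        to the logger (currently a stub, see the TODO below)."""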
        return {
            # TODO
        }

    def run(self, log_name: str):
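        """Run the experiment: seed the RNGs, create the environments, logger, and
        policy via the configured factories, optionally restore a checkpoint from
        ``config.resume_path``, train unless ``config.watch`` is set, and finally
        watch the resulting agent for ``config.watch_num_episodes`` episodes.

        :param log_name: name passed to the logger factory to identify this run.
        """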
        self._set_seed()

        envs = self.env_factory.create_envs()

        # Merge the experiment config with information about the environments
        # so that the full setup is recorded by the logger.
        full_config = self._build_config_dict()
        full_config.update(envs.info())

        run_id = self.config.resume_id
        logger = self.logger_factory.create_logger(log_name=log_name, run_id=run_id, config_dict=full_config)

        # Create the policy and optionally restore it from a previous checkpoint.
        policy = self.agent_factory.create_policy(envs, self.config.device)
        if self.config.resume_path:
            self.agent_factory.load_checkpoint(policy, self.config.resume_path, envs, self.config.device)

        train_collector, test_collector = self.agent_factory.create_train_test_collector(policy, envs)

        # Train the policy unless the experiment is configured to only watch an agent.
        if not self.config.watch:
            trainer = self.agent_factory.create_trainer(policy, train_collector, test_collector, envs, logger)
            result = trainer.run()
            pprint(result)  # TODO logging

        self._watch_agent(self.config.watch_num_episodes, policy, test_collector, self.config.render)

    @staticmethod
    def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render):
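        """Let the policy act in the test environments for ``num_episodes`` episodes,
        optionally rendering them, and print the mean reward and episode length."""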
        policy.eval()
        test_collector.reset()
        result = test_collector.collect(n_episode=num_episodes, render=render)
        print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
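
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module). Assuming concrete
# factory implementations exist elsewhere (the ``My*Factory`` names below are
# placeholders, and the config constructor arguments are elided), an experiment
# would be assembled and run roughly like this:
#
#     experiment = RLExperiment(
#         config=BasicExperimentConfig(...),
#         logger_config=LoggerConfig(...),
#         general_config=RLAgentConfig(...),
#         sampling_config=RLSamplingConfig(...),
#         env_factory=MyEnvFactory(...),        # placeholder EnvFactory subclass
#         logger_factory=MyLoggerFactory(...),  # placeholder LoggerFactory subclass
#         agent_factory=MyAgentFactory(...),    # placeholder AgentFactory subclass
#     )
#     experiment.run(log_name="my-experiment")
# ---------------------------------------------------------------------------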