# Tianshou/tianshou/highlevel/experiment.py

from pprint import pprint
from typing import Generic, TypeVar
import numpy as np
import torch
# Tianshou-internal imports.
from tianshou.config import (
BasicExperimentConfig,
)
from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import LoggerFactory
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer
# Type variable for the policy class an experiment is parameterized with;
# restricted to BasePolicy and its subclasses.
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
# Type variable for the trainer class an experiment is parameterized with;
# restricted to BaseTrainer and its subclasses.
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
class RLExperiment(Generic[TPolicy, TTrainer]):
    """Run a reinforcement-learning experiment assembled from factories.

    The experiment holds a configuration object plus three factories (for
    environments, for the logger and for the agent). :meth:`run` uses them to
    build everything, optionally restores a checkpoint, trains unless the
    configuration requests watch-only mode, and finally evaluates the policy
    on the test collector.
    """

    def __init__(
        self,
        config: BasicExperimentConfig,
        env_factory: EnvFactory,
        logger_factory: LoggerFactory,
        agent_factory: AgentFactory,
    ) -> None:
        """Store the configuration and factories; nothing is built yet.

        :param config: experiment settings; this class reads ``seed``,
            ``resume_id``, ``resume_path``, ``device``, ``watch``,
            ``watch_num_episodes`` and ``render`` from it.
        :param env_factory: provides ``create_envs()``.
        :param logger_factory: provides ``create_logger(...)``.
        :param agent_factory: provides ``create_policy``, ``load_checkpoint``,
            ``create_train_test_collector`` and ``create_trainer``.
        """
        self.config = config
        self.env_factory = env_factory
        self.logger_factory = logger_factory
        self.agent_factory = agent_factory

    def _set_seed(self) -> None:
        """Seed NumPy's global RNG and PyTorch with ``config.seed``."""
        seed = self.config.seed
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _build_config_dict(self) -> dict:
        """Return the base dictionary of configuration values handed to the logger.

        Currently a placeholder that returns a new, empty dict on every call.
        """
        return {
            # TODO
        }

    def run(self, log_name: str) -> None:
        """Execute the experiment end to end.

        Steps, in order: seed the RNGs, create the environments, create the
        logger, create the policy (restoring a checkpoint if
        ``config.resume_path`` is set), create the collectors, train unless
        ``config.watch`` is set, then evaluate the policy.

        :param log_name: forwarded to the logger factory as ``log_name``.
        """
        self._set_seed()

        envs = self.env_factory.create_envs()

        # The logged configuration is the base dict extended by whatever the
        # environments report about themselves.
        full_config = self._build_config_dict()
        full_config.update(envs.info())

        run_id = self.config.resume_id
        logger = self.logger_factory.create_logger(
            log_name=log_name,
            run_id=run_id,
            config_dict=full_config,
        )

        policy = self.agent_factory.create_policy(envs, self.config.device)
        if self.config.resume_path:
            self.agent_factory.load_checkpoint(
                policy,
                self.config.resume_path,
                envs,
                self.config.device,
            )

        train_collector, test_collector = self.agent_factory.create_train_test_collector(
            policy,
            envs,
        )

        # Training is skipped entirely in watch-only mode.
        if not self.config.watch:
            trainer = self.agent_factory.create_trainer(
                policy,
                train_collector,
                test_collector,
                envs,
                logger,
            )
            result = trainer.run()
            pprint(result)  # TODO logging

        # NOTE(review): block nesting was reconstructed from a flattened
        # listing; the final evaluation is assumed to run in both modes
        # (after training and in watch-only mode) -- confirm.
        self._watch_agent(
            self.config.watch_num_episodes,
            policy,
            test_collector,
            self.config.render,
        )

    @staticmethod
    def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render) -> None:
        """Evaluate ``policy`` with ``test_collector`` and print mean reward and length.

        :param num_episodes: passed to ``test_collector.collect`` as ``n_episode``.
        :param policy: the policy to evaluate; switched to eval mode first.
        :param test_collector: collector that is reset and then used for the rollouts.
        :param render: passed through to ``test_collector.collect`` as ``render``
            (presumably a per-frame render delay -- confirm against Collector).
        """
        policy.eval()
        test_collector.reset()
        result = test_collector.collect(n_episode=num_episodes, render=render)
        print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')