Tianshou/tianshou/highlevel/experiment.py

from dataclasses import dataclass
from pprint import pprint
from typing import Generic, TypeVar

import numpy as np
import torch

from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import LoggerFactory
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer

TPolicy = TypeVar("TPolicy", bound=BasePolicy)
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)


@dataclass
class RLExperimentConfig:
    """Generic config for setting up the experiment, not RL or training specific."""

    seed: int = 42
    render: float | None = 0.0
    """Milliseconds between rendered frames; if None, no rendering"""
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    resume_id: str | None = None
    """For restoring a model and running means of env-specifics from a checkpoint"""
    resume_path: str | None = None
    """For restoring a model and running means of env-specifics from a checkpoint"""
    watch: bool = False
    """If True, will not perform training and only watch the restored policy"""
    watch_num_episodes: int = 10


class RLExperiment(Generic[TPolicy, TTrainer]):
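    """A reinforcement learning experiment, which creates the environments, policy, collectors
    and trainer from the given factories, runs training (unless configured to only watch) and
    finally watches the trained policy for a number of episodes.
    """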

    def __init__(
        self,
        config: RLExperimentConfig,
        env_factory: EnvFactory,
        logger_factory: LoggerFactory,
        agent_factory: AgentFactory,
    ):
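        """
        :param config: the experiment configuration
        :param env_factory: the factory which creates the (vectorized) environments
        :param logger_factory: the factory which creates the logger
        :param agent_factory: the factory which creates the policy, collectors and trainer
        """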
        self.config = config
        self.env_factory = env_factory
        self.logger_factory = logger_factory
        self.agent_factory = agent_factory

    def _set_seed(self) -> None:
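        """Seed the numpy and torch random number generators with the configured seed."""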
        seed = self.config.seed
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _build_config_dict(self) -> dict:
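        """Build the dictionary of experiment parameters which is passed to the logger
        as ``config_dict`` (currently still empty, see TODO below).
        """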
        return {
            # TODO
        }

    def run(self, log_name: str) -> None:
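        """Run the experiment: seed the RNGs, create the environments, logger, policy and
        collectors, optionally restore a checkpoint, run the trainer (unless in watch-only
        mode) and finally watch the trained policy.

        :param log_name: the name under which this run is logged
        """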
        self._set_seed()
        envs = self.env_factory.create_envs()
        full_config = self._build_config_dict()
        full_config.update(envs.info())
        run_id = self.config.resume_id
        logger = self.logger_factory.create_logger(
            log_name=log_name,
            run_id=run_id,
            config_dict=full_config,
        )
        policy = self.agent_factory.create_policy(envs, self.config.device)
        if self.config.resume_path:
            self.agent_factory.load_checkpoint(
                policy,
                self.config.resume_path,
                envs,
                self.config.device,
            )
        train_collector, test_collector = self.agent_factory.create_train_test_collector(
            policy,
            envs,
        )
        if not self.config.watch:
            trainer = self.agent_factory.create_trainer(
                policy,
                train_collector,
                test_collector,
                envs,
                logger,
            )
            result = trainer.run()
            pprint(result)  # TODO logging
        self._watch_agent(
            self.config.watch_num_episodes,
            policy,
            test_collector,
            self.config.render,
        )

    @staticmethod
    def _watch_agent(
        num_episodes: int,
        policy: BasePolicy,
        test_collector: Collector,
        render: float | None,
    ) -> None:
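        """Run the policy in the test environments for ``num_episodes`` episodes and print the
        mean reward and episode length.
        """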
        policy.eval()
        test_collector.reset()
        result = test_collector.collect(n_episode=num_episodes, render=render)
        print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
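

# The commented sketch below illustrates how this class is intended to be wired together; it is
# not part of the module. The concrete factories (`MyEnvFactory`, `MyLoggerFactory`,
# `MyAgentFactory`) are hypothetical placeholders for project-specific implementations of
# EnvFactory, LoggerFactory and AgentFactory.
#
#     config = RLExperimentConfig(seed=0, watch=False)
#     experiment = RLExperiment(
#         config=config,
#         env_factory=MyEnvFactory(),
#         logger_factory=MyLoggerFactory(),
#         agent_factory=MyAgentFactory(),
#     )
#     experiment.run(log_name="my-experiment")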