from pprint import pprint
from typing import Generic, TypeVar

import numpy as np
import torch

from tianshou.config import BasicExperimentConfig
from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import LoggerFactory
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer

TPolicy = TypeVar("TPolicy", bound=BasePolicy)
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)


class RLExperiment(Generic[TPolicy, TTrainer]):
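    """A generic reinforcement learning experiment.

    Wires together the environments, policy, collectors, trainer, and logger
    produced by the given factories, then runs training followed by a final
    evaluation ("watch") phase.
    """
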
    def __init__(
        self,
        config: BasicExperimentConfig,
        env_factory: EnvFactory,
        logger_factory: LoggerFactory,
        agent_factory: AgentFactory,
    ):
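        """
        :param config: the experiment-wide configuration (seed, device, watch settings, etc.).
        :param env_factory: factory for the (vectorized) environments.
        :param logger_factory: factory for the experiment logger.
        :param agent_factory: factory for the policy, collectors, and trainer.
        """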
        self.config = config
        self.env_factory = env_factory
        self.logger_factory = logger_factory
        self.agent_factory = agent_factory

    def _set_seed(self) -> None:
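        """Seed NumPy's and torch's global RNGs from the configured seed.

        Note: Python's built-in ``random`` module is not seeded here.
        """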
        seed = self.config.seed
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _build_config_dict(self) -> dict:
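        """Build the config dict that is passed to the logger (currently a stub)."""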
        return {
            # TODO
        }

    def run(self, log_name: str) -> None:
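        """Run the experiment: seed the RNGs, create the environments, logger,
        policy, and collectors, then train (unless watch-only mode is
        configured) and finally watch the trained agent.

        :param log_name: the name under which the logger records this run.
        """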
        self._set_seed()

        envs = self.env_factory.create_envs()

        # the logged config combines the experiment config with environment info
        full_config = self._build_config_dict()
        full_config.update(envs.info())

        run_id = self.config.resume_id
        logger = self.logger_factory.create_logger(
            log_name=log_name,
            run_id=run_id,
            config_dict=full_config,
        )

        policy = self.agent_factory.create_policy(envs, self.config.device)
        if self.config.resume_path:
            # restore policy parameters from a previously saved checkpoint
            self.agent_factory.load_checkpoint(
                policy,
                self.config.resume_path,
                envs,
                self.config.device,
            )

        train_collector, test_collector = self.agent_factory.create_train_test_collector(
            policy,
            envs,
        )

        # in watch-only mode, training is skipped entirely
        if not self.config.watch:
            trainer = self.agent_factory.create_trainer(
                policy,
                train_collector,
                test_collector,
                envs,
                logger,
            )
            result = trainer.run()
            pprint(result)  # TODO logging

        self._watch_agent(
            self.config.watch_num_episodes,
            policy,
            test_collector,
            self.config.render,
        )

    @staticmethod
    def _watch_agent(
        num_episodes: int, policy: BasePolicy, test_collector: Collector, render: float
    ) -> None:
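        """Collect ``num_episodes`` episodes with the policy in evaluation mode
        and print the resulting mean reward and episode length.
        """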
        policy.eval()
        test_collector.reset()
        result = test_collector.collect(n_episode=num_episodes, render=render)
        print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
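
# Example usage (a minimal sketch, not part of this module): `MyEnvFactory`,
# `MyLoggerFactory`, and `MyAgentFactory` are hypothetical concrete
# implementations of the factory interfaces imported above, and only the
# config fields that this module itself reads (seed, device, watch,
# watch_num_episodes, render, resume_id, resume_path) are assumed to exist:
#
#     config = BasicExperimentConfig(seed=42, watch=False, ...)
#     experiment = RLExperiment(
#         config=config,
#         env_factory=MyEnvFactory(),
#         logger_factory=MyLoggerFactory(),
#         agent_factory=MyAgentFactory(),
#     )
#     experiment.run(log_name="my_experiment")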