581 lines
19 KiB
Python
581 lines
19 KiB
Python
import logging
|
|
from abc import abstractmethod
|
|
from collections.abc import Callable, Sequence
|
|
from dataclasses import dataclass
|
|
from pprint import pprint
|
|
from typing import Generic, Self, TypeVar
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from tianshou.data import Collector
|
|
from tianshou.highlevel.agent import (
|
|
A2CAgentFactory,
|
|
AgentFactory,
|
|
DDPGAgentFactory,
|
|
DQNAgentFactory,
|
|
PPOAgentFactory,
|
|
SACAgentFactory,
|
|
TD3AgentFactory,
|
|
)
|
|
from tianshou.highlevel.config import RLSamplingConfig
|
|
from tianshou.highlevel.env import EnvFactory, Environments
|
|
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
|
|
from tianshou.highlevel.module.actor import (
|
|
ActorFactory,
|
|
ActorFactoryDefault,
|
|
ContinuousActorType,
|
|
)
|
|
from tianshou.highlevel.module.critic import CriticFactory, CriticFactoryDefault
|
|
from tianshou.highlevel.optim import OptimizerFactory, OptimizerFactoryAdam
|
|
from tianshou.highlevel.params.policy_params import (
|
|
A2CParams,
|
|
DDPGParams,
|
|
DQNParams,
|
|
PPOParams,
|
|
SACParams,
|
|
TD3Params,
|
|
)
|
|
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
|
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
|
from tianshou.highlevel.trainer import (
|
|
TrainerCallbacks,
|
|
TrainerEpochCallback,
|
|
TrainerStopCallback,
|
|
)
|
|
from tianshou.policy import BasePolicy
|
|
from tianshou.trainer import BaseTrainer
|
|
from tianshou.utils.string import ToStringMixin
|
|
|
|
log = logging.getLogger(__name__)
|
|
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
|
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
|
|
|
|
|
|
@dataclass
|
|
class RLExperimentConfig:
|
|
"""Generic config for setting up the experiment, not RL or training specific."""
|
|
|
|
seed: int = 42
|
|
render: float | None = 0.0
|
|
"""Milliseconds between rendered frames; if None, no rendering"""
|
|
device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
|
resume_id: str | None = None
|
|
"""For restoring a model and running means of env-specifics from a checkpoint"""
|
|
resume_path: str | None = None
|
|
"""For restoring a model and running means of env-specifics from a checkpoint"""
|
|
watch: bool = False
|
|
"""If True, will not perform training and only watch the restored policy"""
|
|
watch_num_episodes = 10
|
|
|
|
|
|
class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
|
def __init__(
|
|
self,
|
|
config: RLExperimentConfig,
|
|
env_factory: EnvFactory | Callable[[PersistableConfigProtocol | None], Environments],
|
|
agent_factory: AgentFactory,
|
|
logger_factory: LoggerFactory | None = None,
|
|
env_config: PersistableConfigProtocol | None = None,
|
|
):
|
|
if logger_factory is None:
|
|
logger_factory = DefaultLoggerFactory()
|
|
self.config = config
|
|
self.env_factory = env_factory
|
|
self.agent_factory = agent_factory
|
|
self.logger_factory = logger_factory
|
|
self.env_config = env_config
|
|
|
|
def _set_seed(self) -> None:
|
|
seed = self.config.seed
|
|
np.random.seed(seed)
|
|
torch.manual_seed(seed)
|
|
|
|
def _build_config_dict(self) -> dict:
|
|
return {
|
|
# TODO
|
|
}
|
|
|
|
def run(self, log_name: str) -> None:
|
|
self._set_seed()
|
|
envs = self.env_factory(self.env_config)
|
|
|
|
full_config = self._build_config_dict()
|
|
full_config.update(envs.info())
|
|
|
|
run_id = self.config.resume_id
|
|
logger = self.logger_factory.create_logger(
|
|
log_name=log_name,
|
|
run_id=run_id,
|
|
config_dict=full_config,
|
|
)
|
|
|
|
policy = self.agent_factory.create_policy(envs, self.config.device)
|
|
if self.config.resume_path:
|
|
self.agent_factory.load_checkpoint(
|
|
policy,
|
|
self.config.resume_path,
|
|
envs,
|
|
self.config.device,
|
|
)
|
|
|
|
train_collector, test_collector = self.agent_factory.create_train_test_collector(
|
|
policy,
|
|
envs,
|
|
)
|
|
|
|
if not self.config.watch:
|
|
trainer = self.agent_factory.create_trainer(
|
|
policy,
|
|
train_collector,
|
|
test_collector,
|
|
envs,
|
|
logger,
|
|
)
|
|
result = trainer.run()
|
|
pprint(result) # TODO logging
|
|
|
|
self._watch_agent(
|
|
self.config.watch_num_episodes,
|
|
policy,
|
|
test_collector,
|
|
self.config.render,
|
|
)
|
|
|
|
@staticmethod
|
|
def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render) -> None:
|
|
policy.eval()
|
|
test_collector.reset()
|
|
result = test_collector.collect(n_episode=num_episodes, render=render)
|
|
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
|
|
|
|
|
TBuilder = TypeVar("TBuilder", bound="RLExperimentBuilder")
|
|
|
|
|
|
class RLExperimentBuilder:
|
|
def __init__(
|
|
self,
|
|
experiment_config: RLExperimentConfig,
|
|
env_factory: EnvFactory,
|
|
sampling_config: RLSamplingConfig,
|
|
):
|
|
self._config = experiment_config
|
|
self._env_factory = env_factory
|
|
self._sampling_config = sampling_config
|
|
self._logger_factory: LoggerFactory | None = None
|
|
self._optim_factory: OptimizerFactory | None = None
|
|
self._env_config: PersistableConfigProtocol | None = None
|
|
self._policy_wrapper_factory: PolicyWrapperFactory | None = None
|
|
self._trainer_callbacks: TrainerCallbacks = TrainerCallbacks()
|
|
|
|
def with_env_config(self, config: PersistableConfigProtocol) -> Self:
|
|
self._env_config = config
|
|
return self
|
|
|
|
def with_logger_factory(self: TBuilder, logger_factory: LoggerFactory) -> TBuilder:
|
|
self._logger_factory = logger_factory
|
|
return self
|
|
|
|
def with_policy_wrapper_factory(self, policy_wrapper_factory: PolicyWrapperFactory) -> Self:
|
|
self._policy_wrapper_factory = policy_wrapper_factory
|
|
return self
|
|
|
|
def with_optim_factory(self: TBuilder, optim_factory: OptimizerFactory) -> TBuilder:
|
|
self._optim_factory = optim_factory
|
|
return self
|
|
|
|
def with_optim_factory_default(
|
|
self: TBuilder,
|
|
betas=(0.9, 0.999),
|
|
eps=1e-08,
|
|
weight_decay=0,
|
|
) -> TBuilder:
|
|
"""Configures the use of the default optimizer, Adam, with the given parameters.
|
|
|
|
:param betas: coefficients used for computing running averages of gradient and its square
|
|
:param eps: term added to the denominator to improve numerical stability
|
|
:param weight_decay: weight decay (L2 penalty)
|
|
:return: the builder
|
|
"""
|
|
self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
|
|
return self
|
|
|
|
def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallback) -> Self:
|
|
self._trainer_callbacks.epoch_callback_train = callback
|
|
return self
|
|
|
|
def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallback) -> Self:
|
|
self._trainer_callbacks.epoch_callback_test = callback
|
|
return self
|
|
|
|
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
|
|
self._trainer_callbacks.stop_callback = callback
|
|
return self
|
|
|
|
@abstractmethod
|
|
def _create_agent_factory(self) -> AgentFactory:
|
|
pass
|
|
|
|
def _get_optim_factory(self) -> OptimizerFactory:
|
|
if self._optim_factory is None:
|
|
return OptimizerFactoryAdam()
|
|
else:
|
|
return self._optim_factory
|
|
|
|
def build(self) -> RLExperiment:
|
|
agent_factory = self._create_agent_factory()
|
|
agent_factory.set_trainer_callbacks(self._trainer_callbacks)
|
|
if self._policy_wrapper_factory:
|
|
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
|
|
experiment = RLExperiment(
|
|
self._config,
|
|
self._env_factory,
|
|
agent_factory,
|
|
self._logger_factory,
|
|
env_config=self._env_config,
|
|
)
|
|
log.info(f"Created experiment:\n{experiment.pprints()}")
|
|
return experiment
|
|
|
|
|
|
class _BuilderMixinActorFactory:
|
|
def __init__(self, continuous_actor_type: ContinuousActorType):
|
|
self._continuous_actor_type = continuous_actor_type
|
|
self._actor_factory: ActorFactory | None = None
|
|
|
|
def with_actor_factory(self: TBuilder, actor_factory: ActorFactory) -> TBuilder:
|
|
self: TBuilder | _BuilderMixinActorFactory
|
|
self._actor_factory = actor_factory
|
|
return self
|
|
|
|
def _with_actor_factory_default(
|
|
self: TBuilder,
|
|
hidden_sizes: Sequence[int],
|
|
continuous_unbounded=False,
|
|
continuous_conditioned_sigma=False,
|
|
) -> TBuilder:
|
|
self: TBuilder | _BuilderMixinActorFactory
|
|
self._actor_factory = ActorFactoryDefault(
|
|
self._continuous_actor_type,
|
|
hidden_sizes,
|
|
continuous_unbounded=continuous_unbounded,
|
|
continuous_conditioned_sigma=continuous_conditioned_sigma,
|
|
)
|
|
return self
|
|
|
|
def _get_actor_factory(self):
|
|
if self._actor_factory is None:
|
|
return ActorFactoryDefault(self._continuous_actor_type)
|
|
else:
|
|
return self._actor_factory
|
|
|
|
|
|
class _BuilderMixinActorFactory_ContinuousGaussian(_BuilderMixinActorFactory):
|
|
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
|
|
|
|
def __init__(self):
|
|
super().__init__(ContinuousActorType.GAUSSIAN)
|
|
|
|
def with_actor_factory_default(
|
|
self,
|
|
hidden_sizes: Sequence[int],
|
|
continuous_unbounded=False,
|
|
continuous_conditioned_sigma=False,
|
|
) -> Self:
|
|
return super()._with_actor_factory_default(
|
|
hidden_sizes,
|
|
continuous_unbounded=continuous_unbounded,
|
|
continuous_conditioned_sigma=continuous_conditioned_sigma,
|
|
)
|
|
|
|
|
|
class _BuilderMixinActorFactory_ContinuousDeterministic(_BuilderMixinActorFactory):
|
|
"""Specialization of the actor mixin where, in the continuous case, the actor uses a deterministic policy."""
|
|
|
|
def __init__(self):
|
|
super().__init__(ContinuousActorType.DETERMINISTIC)
|
|
|
|
def with_actor_factory_default(self, hidden_sizes: Sequence[int]) -> Self:
|
|
return super()._with_actor_factory_default(hidden_sizes)
|
|
|
|
|
|
class _BuilderMixinCriticsFactory:
|
|
def __init__(self, num_critics: int):
|
|
self._critic_factories: list[CriticFactory | None] = [None] * num_critics
|
|
|
|
def _with_critic_factory(self, idx: int, critic_factory: CriticFactory):
|
|
self._critic_factories[idx] = critic_factory
|
|
return self
|
|
|
|
def _with_critic_factory_default(self, idx: int, hidden_sizes: Sequence[int]):
|
|
self._critic_factories[idx] = CriticFactoryDefault(hidden_sizes)
|
|
return self
|
|
|
|
def _get_critic_factory(self, idx: int):
|
|
factory = self._critic_factories[idx]
|
|
if factory is None:
|
|
return CriticFactoryDefault()
|
|
else:
|
|
return factory
|
|
|
|
|
|
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
|
|
def __init__(self):
|
|
super().__init__(1)
|
|
|
|
def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
|
|
self._with_critic_factory(0, critic_factory)
|
|
return self
|
|
|
|
def with_critic_factory_default(
|
|
self,
|
|
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
|
) -> Self:
|
|
self._with_critic_factory_default(0, hidden_sizes)
|
|
return self
|
|
|
|
|
|
class _BuilderMixinSingleCriticCanUseActorFactory(_BuilderMixinSingleCriticFactory):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self._critic_use_actor_module = False
|
|
|
|
def with_critic_factory_use_actor(self) -> Self:
|
|
"""Makes the critic use the same network as the actor."""
|
|
self._critic_use_actor_module = True
|
|
return self
|
|
|
|
|
|
class _BuilderMixinDualCriticFactory(_BuilderMixinCriticsFactory):
|
|
def __init__(self):
|
|
super().__init__(2)
|
|
|
|
def with_common_critic_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
|
|
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
|
for i in range(len(self._critic_factories)):
|
|
self._with_critic_factory(i, critic_factory)
|
|
return self
|
|
|
|
def with_common_critic_factory_default(
|
|
self,
|
|
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
|
) -> TBuilder:
|
|
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
|
for i in range(len(self._critic_factories)):
|
|
self._with_critic_factory_default(i, hidden_sizes)
|
|
return self
|
|
|
|
def with_critic1_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
|
|
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
|
self._with_critic_factory(0, critic_factory)
|
|
return self
|
|
|
|
def with_critic1_factory_default(
|
|
self,
|
|
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
|
) -> TBuilder:
|
|
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
|
self._with_critic_factory_default(0, hidden_sizes)
|
|
return self
|
|
|
|
def with_critic2_factory(self: TBuilder, critic_factory: CriticFactory) -> TBuilder:
|
|
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
|
self._with_critic_factory(1, critic_factory)
|
|
return self
|
|
|
|
def with_critic2_factory_default(
|
|
self,
|
|
hidden_sizes: Sequence[int] = CriticFactoryDefault.DEFAULT_HIDDEN_SIZES,
|
|
) -> TBuilder:
|
|
self: TBuilder | "_BuilderMixinDualCriticFactory"
|
|
self._with_critic_factory_default(0, hidden_sizes)
|
|
return self
|
|
|
|
|
|
class A2CExperimentBuilder(
|
|
RLExperimentBuilder,
|
|
_BuilderMixinActorFactory_ContinuousGaussian,
|
|
_BuilderMixinSingleCriticCanUseActorFactory,
|
|
):
|
|
def __init__(
|
|
self,
|
|
experiment_config: RLExperimentConfig,
|
|
env_factory: EnvFactory,
|
|
sampling_config: RLSamplingConfig,
|
|
env_config: PersistableConfigProtocol | None = None,
|
|
):
|
|
super().__init__(experiment_config, env_factory, sampling_config)
|
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
|
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
|
self._params: A2CParams = A2CParams()
|
|
self._env_config = env_config
|
|
|
|
def with_a2c_params(self, params: A2CParams) -> Self:
|
|
self._params = params
|
|
return self
|
|
|
|
@abstractmethod
|
|
def _create_agent_factory(self) -> AgentFactory:
|
|
return A2CAgentFactory(
|
|
self._params,
|
|
self._sampling_config,
|
|
self._get_actor_factory(),
|
|
self._get_critic_factory(0),
|
|
self._get_optim_factory(),
|
|
self._critic_use_actor_module,
|
|
)
|
|
|
|
|
|
class PPOExperimentBuilder(
|
|
RLExperimentBuilder,
|
|
_BuilderMixinActorFactory_ContinuousGaussian,
|
|
_BuilderMixinSingleCriticCanUseActorFactory,
|
|
):
|
|
def __init__(
|
|
self,
|
|
experiment_config: RLExperimentConfig,
|
|
env_factory: EnvFactory,
|
|
sampling_config: RLSamplingConfig,
|
|
):
|
|
super().__init__(experiment_config, env_factory, sampling_config)
|
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
|
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
|
self._params: PPOParams = PPOParams()
|
|
|
|
def with_ppo_params(self, params: PPOParams) -> Self:
|
|
self._params = params
|
|
return self
|
|
|
|
@abstractmethod
|
|
def _create_agent_factory(self) -> AgentFactory:
|
|
return PPOAgentFactory(
|
|
self._params,
|
|
self._sampling_config,
|
|
self._get_actor_factory(),
|
|
self._get_critic_factory(0),
|
|
self._get_optim_factory(),
|
|
self._critic_use_actor_module,
|
|
)
|
|
|
|
|
|
class DQNExperimentBuilder(
|
|
RLExperimentBuilder,
|
|
_BuilderMixinActorFactory,
|
|
):
|
|
def __init__(
|
|
self,
|
|
experiment_config: RLExperimentConfig,
|
|
env_factory: EnvFactory,
|
|
sampling_config: RLSamplingConfig,
|
|
):
|
|
super().__init__(experiment_config, env_factory, sampling_config)
|
|
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
|
|
self._params: DQNParams = DQNParams()
|
|
|
|
def with_dqn_params(self, params: DQNParams) -> Self:
|
|
self._params = params
|
|
return self
|
|
|
|
@abstractmethod
|
|
def _create_agent_factory(self) -> AgentFactory:
|
|
return DQNAgentFactory(
|
|
self._params,
|
|
self._sampling_config,
|
|
self._get_actor_factory(),
|
|
self._get_optim_factory(),
|
|
)
|
|
|
|
|
|
class DDPGExperimentBuilder(
|
|
RLExperimentBuilder,
|
|
_BuilderMixinActorFactory_ContinuousDeterministic,
|
|
_BuilderMixinSingleCriticCanUseActorFactory,
|
|
):
|
|
def __init__(
|
|
self,
|
|
experiment_config: RLExperimentConfig,
|
|
env_factory: EnvFactory,
|
|
sampling_config: RLSamplingConfig,
|
|
):
|
|
super().__init__(experiment_config, env_factory, sampling_config)
|
|
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
|
_BuilderMixinSingleCriticCanUseActorFactory.__init__(self)
|
|
self._params: DDPGParams = DDPGParams()
|
|
|
|
def with_ddpg_params(self, params: DDPGParams) -> Self:
|
|
self._params = params
|
|
return self
|
|
|
|
@abstractmethod
|
|
def _create_agent_factory(self) -> AgentFactory:
|
|
return DDPGAgentFactory(
|
|
self._params,
|
|
self._sampling_config,
|
|
self._get_actor_factory(),
|
|
self._get_critic_factory(0),
|
|
self._get_optim_factory(),
|
|
)
|
|
|
|
|
|
class SACExperimentBuilder(
|
|
RLExperimentBuilder,
|
|
_BuilderMixinActorFactory_ContinuousGaussian,
|
|
_BuilderMixinDualCriticFactory,
|
|
):
|
|
def __init__(
|
|
self,
|
|
experiment_config: RLExperimentConfig,
|
|
env_factory: EnvFactory,
|
|
sampling_config: RLSamplingConfig,
|
|
):
|
|
super().__init__(experiment_config, env_factory, sampling_config)
|
|
_BuilderMixinActorFactory_ContinuousGaussian.__init__(self)
|
|
_BuilderMixinDualCriticFactory.__init__(self)
|
|
self._params: SACParams = SACParams()
|
|
|
|
def with_sac_params(self, params: SACParams) -> Self:
|
|
self._params = params
|
|
return self
|
|
|
|
def _create_agent_factory(self) -> AgentFactory:
|
|
return SACAgentFactory(
|
|
self._params,
|
|
self._sampling_config,
|
|
self._get_actor_factory(),
|
|
self._get_critic_factory(0),
|
|
self._get_critic_factory(1),
|
|
self._get_optim_factory(),
|
|
)
|
|
|
|
|
|
class TD3ExperimentBuilder(
|
|
RLExperimentBuilder,
|
|
_BuilderMixinActorFactory_ContinuousDeterministic,
|
|
_BuilderMixinDualCriticFactory,
|
|
):
|
|
def __init__(
|
|
self,
|
|
experiment_config: RLExperimentConfig,
|
|
env_factory: EnvFactory,
|
|
sampling_config: RLSamplingConfig,
|
|
):
|
|
super().__init__(experiment_config, env_factory, sampling_config)
|
|
_BuilderMixinActorFactory_ContinuousDeterministic.__init__(self)
|
|
_BuilderMixinDualCriticFactory.__init__(self)
|
|
self._params: TD3Params = TD3Params()
|
|
|
|
def with_td3_params(self, params: TD3Params) -> Self:
|
|
self._params = params
|
|
return self
|
|
|
|
def _create_agent_factory(self) -> AgentFactory:
|
|
return TD3AgentFactory(
|
|
self._params,
|
|
self._sampling_config,
|
|
self._get_actor_factory(),
|
|
self._get_critic_factory(0),
|
|
self._get_critic_factory(1),
|
|
self._get_optim_factory(),
|
|
)
|