2023-09-20 09:29:34 +02:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from collections.abc import Callable
|
2023-09-28 14:28:03 +02:00
|
|
|
from typing import Generic, TypeVar
|
2023-09-19 18:53:11 +02:00
|
|
|
|
2023-10-05 15:39:32 +02:00
|
|
|
import gymnasium
|
2023-09-19 18:53:11 +02:00
|
|
|
import torch
|
|
|
|
|
2023-09-20 09:29:34 +02:00
|
|
|
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
2023-09-20 15:28:33 +02:00
|
|
|
from tianshou.highlevel.config import RLSamplingConfig
|
2023-09-20 15:45:09 +02:00
|
|
|
from tianshou.highlevel.env import Environments
|
2023-09-19 18:53:11 +02:00
|
|
|
from tianshou.highlevel.logger import Logger
|
2023-09-28 20:07:52 +02:00
|
|
|
from tianshou.highlevel.module.actor import (
|
2023-09-26 15:35:18 +02:00
|
|
|
ActorFactory,
|
2023-09-28 20:07:52 +02:00
|
|
|
)
|
|
|
|
from tianshou.highlevel.module.core import TDevice
|
|
|
|
from tianshou.highlevel.module.critic import CriticFactory
|
|
|
|
from tianshou.highlevel.module.module_opt import (
|
|
|
|
ActorCriticModuleOpt,
|
2023-09-26 15:35:18 +02:00
|
|
|
ActorModuleOptFactory,
|
|
|
|
CriticModuleOptFactory,
|
|
|
|
ModuleOpt,
|
|
|
|
)
|
2023-09-25 17:56:37 +02:00
|
|
|
from tianshou.highlevel.optim import OptimizerFactory
|
2023-09-26 15:35:18 +02:00
|
|
|
from tianshou.highlevel.params.policy_params import (
|
2023-09-28 14:28:03 +02:00
|
|
|
A2CParams,
|
2023-10-03 20:26:39 +02:00
|
|
|
DDPGParams,
|
2023-10-05 15:39:32 +02:00
|
|
|
DQNParams,
|
2023-09-28 14:28:03 +02:00
|
|
|
Params,
|
2023-09-26 17:43:16 +02:00
|
|
|
ParamTransformerData,
|
2023-09-26 15:35:18 +02:00
|
|
|
PPOParams,
|
|
|
|
SACParams,
|
|
|
|
TD3Params,
|
|
|
|
)
|
2023-09-28 20:07:52 +02:00
|
|
|
from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
2023-10-05 15:39:32 +02:00
|
|
|
from tianshou.highlevel.trainer import TrainerCallbacks, TrainingContext
|
2023-10-03 20:26:39 +02:00
|
|
|
from tianshou.policy import (
|
|
|
|
A2CPolicy,
|
|
|
|
BasePolicy,
|
|
|
|
DDPGPolicy,
|
2023-10-05 15:39:32 +02:00
|
|
|
DQNPolicy,
|
2023-10-03 20:26:39 +02:00
|
|
|
PPOPolicy,
|
|
|
|
SACPolicy,
|
|
|
|
TD3Policy,
|
|
|
|
)
|
2023-09-20 09:29:34 +02:00
|
|
|
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
2023-09-28 20:07:52 +02:00
|
|
|
from tianshou.utils.net import continuous, discrete
|
2023-09-19 18:53:11 +02:00
|
|
|
from tianshou.utils.net.common import ActorCritic
|
2023-10-03 21:14:22 +02:00
|
|
|
from tianshou.utils.string import ToStringMixin
|
2023-09-19 18:53:11 +02:00
|
|
|
|
|
|
|
# Type variables parametrising the generic actor-critic agent factory.
TParams = TypeVar("TParams", bound=Params)
TPolicy = TypeVar("TPolicy", bound=BasePolicy)

# Keys of the checkpoint dictionary consumed by AgentFactory.load_checkpoint.
CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
2023-09-19 18:53:11 +02:00
|
|
|
|
|
|
|
|
2023-10-03 21:14:22 +02:00
|
|
|
class AgentFactory(ABC, ToStringMixin):
    """Base class for factories that assemble an RL agent.

    An agent factory creates the policy (optionally wrapped via a
    :class:`PolicyWrapperFactory`), the train/test collectors and the trainer.
    Subclasses implement :meth:`_create_policy` and :meth:`create_trainer`.
    """

    def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
        """
        :param sampling_config: buffer, collection and training-loop settings
        :param optim_factory: factory used to create optimizers for the agent's modules
        """
        self.sampling_config = sampling_config
        self.optim_factory = optim_factory
        # no wrapper unless one is set via set_policy_wrapper_factory
        self.policy_wrapper_factory: PolicyWrapperFactory | None = None
        self.trainer_callbacks: TrainerCallbacks = TrainerCallbacks()

    def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
        """Create the collectors used for training and testing ``policy``.

        Transitions of the training collector are stored in a ``VectorReplayBuffer``
        when there is more than one training environment and in a plain
        ``ReplayBuffer`` otherwise; both are configured from the sampling config.
        If ``start_timesteps`` is positive, that many steps are collected with
        ``random=True`` before returning.

        :param policy: the policy driving both collectors
        :param envs: supplies ``train_envs`` and ``test_envs``
        :return: the pair ``(train_collector, test_collector)``
        """
        cfg = self.sampling_config
        capacity = cfg.buffer_size
        train_envs = envs.train_envs
        n_train_envs = len(train_envs)
        buffer_options = {
            "stack_num": cfg.replay_buffer_stack_num,
            "save_only_last_obs": cfg.replay_buffer_save_only_last_obs,
            "ignore_obs_next": cfg.replay_buffer_ignore_obs_next,
        }
        buffer = (
            VectorReplayBuffer(capacity, n_train_envs, **buffer_options)
            if n_train_envs > 1
            else ReplayBuffer(capacity, **buffer_options)
        )
        train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
        test_collector = Collector(policy, envs.test_envs)
        warmup_steps = cfg.start_timesteps
        if warmup_steps > 0:
            # pre-fill the replay buffer by collecting with random=True
            train_collector.collect(n_step=warmup_steps, random=True)
        return train_collector, test_collector

    def set_policy_wrapper_factory(
        self,
        policy_wrapper_factory: PolicyWrapperFactory | None,
    ) -> None:
        """Set (or clear, with ``None``) the factory :meth:`create_policy` uses to wrap the policy."""
        self.policy_wrapper_factory = policy_wrapper_factory

    def set_trainer_callbacks(self, callbacks: TrainerCallbacks):
        """Replace the trainer callbacks held by this factory."""
        self.trainer_callbacks = callbacks

    @abstractmethod
    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        """Create the unwrapped policy for ``envs`` on ``device``."""

    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        """Create the policy and, if a wrapper factory is set, return the wrapped policy."""
        policy = self._create_policy(envs, device)
        wrapper_factory = self.policy_wrapper_factory
        if wrapper_factory is None:
            return policy
        return wrapper_factory.create_wrapped_policy(
            policy,
            envs,
            self.optim_factory,
            device,
        )

    @staticmethod
    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
        """Return the ``save_best_fn`` handed to trainers.

        The returned callback is currently a no-op; ``envs`` and ``log_path`` are
        unused until saving is fixed (see TODO below).
        """

        def save_best_fn(pol: torch.nn.Module) -> None:
            # TODO: Fix saving in general (code works only for mujoco)
            # state = {
            #     CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
            #     CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
            # }
            # torch.save(state, os.path.join(log_path, "policy.pth"))
            return None

        return save_best_fn

    @staticmethod
    def load_checkpoint(policy: torch.nn.Module, path, envs: Environments, device: TDevice):
        """Restore policy weights and observation running statistics from a checkpoint.

        :param policy: module receiving the state dict stored under ``CHECKPOINT_DICT_KEY_MODEL``
        :param path: checkpoint file passed to ``torch.load``
        :param envs: its ``train_envs``/``test_envs`` (each only if truthy) receive the
            statistics stored under ``CHECKPOINT_DICT_KEY_OBS_RMS`` via ``set_obs_rms``
        :param device: ``map_location`` for ``torch.load``
        """
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        ckpt = torch.load(path, map_location=device)
        policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
        for vector_env in (envs.train_envs, envs.test_envs):
            if vector_env:
                vector_env.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        print("Loaded agent and obs. running means from: ", path)  # TODO logging

    @abstractmethod
    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> BaseTrainer:
        """Create the trainer that optimizes ``policy`` using the given collectors."""
|
|
|
|
|
|
|
|
|
|
|
|
class OnpolicyAgentFactory(AgentFactory, ABC):
    """Agent factory whose trainer is an :class:`OnpolicyTrainer`."""

    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> OnpolicyTrainer:
        """Create an on-policy trainer configured from the sampling config.

        Train/test/stop functions are derived from the trainer callbacks; a callback
        that is unset (falsy) yields ``None``. The number of test episodes equals
        ``num_test_envs`` and testing during training is disabled.
        """
        cfg = self.sampling_config
        callbacks = self.trainer_callbacks
        context = TrainingContext(policy, envs, logger)

        def to_trainer_fn(callback):
            # unset callbacks map to "no function" for the trainer
            return callback.get_trainer_fn(context) if callback else None

        train_fn = to_trainer_fn(callbacks.epoch_callback_train)
        test_fn = to_trainer_fn(callbacks.epoch_callback_test)
        stop_fn = to_trainer_fn(callbacks.stop_callback)

        return OnpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=cfg.num_epochs,
            step_per_epoch=cfg.step_per_epoch,
            repeat_per_collect=cfg.repeat_per_collect,
            episode_per_test=cfg.num_test_envs,
            batch_size=cfg.batch_size,
            step_per_collect=cfg.step_per_collect,
            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
            logger=logger.logger,
            test_in_train=False,
            train_fn=train_fn,
            test_fn=test_fn,
            stop_fn=stop_fn,
        )
|
|
|
|
|
|
|
|
|
2023-09-20 09:29:34 +02:00
|
|
|
class OffpolicyAgentFactory(AgentFactory, ABC):
    """Agent factory whose trainer is an :class:`OffpolicyTrainer`."""

    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> OffpolicyTrainer:
        """Create an off-policy trainer configured from the sampling config.

        Train/test/stop functions are derived from the trainer callbacks; a callback
        that is unset (falsy) yields ``None``. The number of test episodes equals
        ``num_test_envs`` and testing during training is disabled.
        """
        cfg = self.sampling_config
        callbacks = self.trainer_callbacks
        context = TrainingContext(policy, envs, logger)

        def to_trainer_fn(callback):
            # unset callbacks map to "no function" for the trainer
            return callback.get_trainer_fn(context) if callback else None

        train_fn = to_trainer_fn(callbacks.epoch_callback_train)
        test_fn = to_trainer_fn(callbacks.epoch_callback_test)
        stop_fn = to_trainer_fn(callbacks.stop_callback)

        return OffpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=cfg.num_epochs,
            step_per_epoch=cfg.step_per_epoch,
            step_per_collect=cfg.step_per_collect,
            episode_per_test=cfg.num_test_envs,
            batch_size=cfg.batch_size,
            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
            logger=logger.logger,
            update_per_step=cfg.update_per_step,
            test_in_train=False,
            train_fn=train_fn,
            test_fn=test_fn,
            stop_fn=stop_fn,
        )
|
|
|
|
|
|
|
|
|
2023-09-26 15:35:18 +02:00
|
|
|
class _ActorMixin:
    """Mixin that creates an actor module and its optimizer via an :class:`ActorModuleOptFactory`."""

    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
        """
        :param actor_factory: creates the actor module
        :param optim_factory: optimizer factory passed on to the ``ActorModuleOptFactory``
        """
        module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)
        self.actor_module_opt_factory = module_opt_factory

    def create_actor_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Delegate to the ``ActorModuleOptFactory`` with the given learning rate ``lr``."""
        factory = self.actor_module_opt_factory
        return factory.create_module_opt(envs, device, lr)
|
|
|
|
|
|
|
|
|
2023-10-05 15:39:32 +02:00
|
|
|
class _CriticMixin:
    """Mixin that creates a critic module and its optimizer via a :class:`CriticModuleOptFactory`."""

    def __init__(
        self,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        """
        :param critic_factory: creates the critic module
        :param optim_factory: optimizer factory passed on to the ``CriticModuleOptFactory``
        :param critic_use_action: forwarded unchanged to the ``CriticModuleOptFactory``
        """
        module_opt_factory = CriticModuleOptFactory(critic_factory, optim_factory, critic_use_action)
        self.critic_module_opt_factory = module_opt_factory

    def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Delegate to the ``CriticModuleOptFactory`` with the given learning rate ``lr``."""
        factory = self.critic_module_opt_factory
        return factory.create_module_opt(envs, device, lr)
|
|
|
|
|
|
|
|
|
2023-09-26 15:35:18 +02:00
|
|
|
class _ActorCriticMixin:
    """Mixin for agents that use an ActorCritic module with a single optimizer."""

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
        critic_use_actor_module: bool,
    ):
        """
        :param actor_factory: creates the actor module
        :param critic_factory: creates the critic module (unused when
            ``critic_use_actor_module`` is set)
        :param optim_factory: creates the single optimizer over actor and critic
        :param critic_use_action: passed as ``use_action`` to the critic factory
        :param critic_use_actor_module: if set, build the critic on the actor's
            preprocessing network instead of using the critic factory
        """
        self.actor_factory = actor_factory
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        self.critic_use_action = critic_use_action
        self.critic_use_actor_module = critic_use_actor_module

    def create_actor_critic_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        lr: float,
    ) -> ActorCriticModuleOpt:
        """Create actor and critic, combine them into an ``ActorCritic`` and create one optimizer.

        When ``critic_use_actor_module`` is set, the critic is a ``discrete.Critic`` or
        ``continuous.Critic`` (chosen by the environment type) on top of the actor's
        preprocessing net. Otherwise it comes from the critic factory.

        :param envs: the environments the modules are created for
        :param device: device on which the modules are created
        :param lr: learning rate of the shared optimizer
        :return: the actor-critic module together with its optimizer
        :raises ValueError: if both ``critic_use_actor_module`` and ``critic_use_action``
            are set, or if the actor module is to be shared but the environment is neither
            discrete nor continuous
        """
        # Validate the option combination before any module is built.
        if self.critic_use_actor_module and self.critic_use_action:
            raise ValueError(
                "The options critic_use_actor_module and critic_use_action are mutually exclusive",
            )
        actor = self.actor_factory.create_module(envs, device)
        if self.critic_use_actor_module:
            env_type = envs.get_type()
            if env_type.is_discrete():
                critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
            elif env_type.is_continuous():
                critic = continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
            else:
                raise ValueError(
                    "Cannot build a critic on the actor's preprocessing network for "
                    f"environment type {env_type}: expected a discrete or continuous environment",
                )
        else:
            critic = self.critic_factory.create_module(
                envs,
                device,
                use_action=self.critic_use_action,
            )
        actor_critic = ActorCritic(actor, critic)
        optim = self.optim_factory.create_optimizer(actor_critic, lr)
        return ActorCriticModuleOpt(actor_critic, optim)
|
|
|
|
|
|
|
|
|
2023-10-05 15:39:32 +02:00
|
|
|
class _ActorAndCriticMixin(_ActorMixin, _CriticMixin):
    """Combines actor and critic creation, each through its own module/optimizer factory."""

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        # Both parent mixins are initialised explicitly; neither calls super().__init__.
        _ActorMixin.__init__(self, actor_factory=actor_factory, optim_factory=optim_factory)
        _CriticMixin.__init__(
            self,
            critic_factory=critic_factory,
            optim_factory=optim_factory,
            critic_use_action=critic_use_action,
        )
|
2023-09-26 15:35:18 +02:00
|
|
|
|
|
|
|
|
|
|
|
class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
    """Actor-and-critic mixin extended by a second critic with its own module/optimizer factory."""

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        """
        :param actor_factory: creates the actor module
        :param critic_factory: creates the first critic module
        :param critic2_factory: creates the second critic module
        :param optim_factory: optimizer factory used for actor and both critics
        :param critic_use_action: forwarded to both critics' ``CriticModuleOptFactory``
        """
        # Call the parent mixin explicitly, as the other mixins in this module do. The
        # mixins are not cooperative, so super() would dispatch to whichever class
        # follows this one in the concrete subclass's MRO.
        _ActorAndCriticMixin.__init__(
            self,
            actor_factory,
            critic_factory,
            optim_factory,
            critic_use_action,
        )
        self.critic2_module_opt_factory = CriticModuleOptFactory(
            critic2_factory,
            optim_factory,
            critic_use_action,
        )

    def create_critic2_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        lr: float,
    ) -> ModuleOpt:
        """Delegate to the second critic's ``CriticModuleOptFactory`` with learning rate ``lr``."""
        return self.critic2_module_opt_factory.create_module_opt(envs, device, lr)
|
|
|
|
|
|
|
|
|
2023-09-28 14:28:03 +02:00
|
|
|
class ActorCriticAgentFactory(
    Generic[TParams, TPolicy],
    OnpolicyAgentFactory,
    _ActorCriticMixin,
    ABC,
):
    """On-policy agent factory for policies built from an actor-critic module with one optimizer.

    Subclasses implement :meth:`_create_actor_critic`; the policy is then created as
    ``policy_class(**kwargs)``, with kwargs from :meth:`_create_kwargs`.
    """

    def __init__(
        self,
        params: TParams,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        policy_class: type[TPolicy],
        critic_use_actor_module: bool,
    ):
        """
        :param params: policy parameters converted to constructor kwargs
        :param sampling_config: buffer, collection and training-loop settings
        :param actor_factory: creates the actor module
        :param critic_factory: creates the critic module
        :param optimizer_factory: creates the shared optimizer
        :param policy_class: the policy type to instantiate
        :param critic_use_actor_module: build the critic on the actor's preprocessing net
        """
        super().__init__(sampling_config, optim_factory=optimizer_factory)
        # the critic never receives the action as input for these agents
        _ActorCriticMixin.__init__(
            self,
            actor_factory,
            critic_factory,
            optimizer_factory,
            critic_use_action=False,
            critic_use_actor_module=critic_use_actor_module,
        )
        self.params = params
        self.policy_class = policy_class

    @abstractmethod
    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
        """Create the actor-critic module and its optimizer."""

    def _create_kwargs(self, envs: Environments, device: TDevice):
        """Build the keyword arguments for ``policy_class``.

        The kwargs produced from ``params`` are extended by ``actor``, ``critic``,
        ``optim`` and ``action_space``, which take precedence over same-named entries.
        """
        module_opt = self._create_actor_critic(envs, device)
        transformer_data = ParamTransformerData(
            envs=envs,
            device=device,
            optim_factory=self.optim_factory,
            optim=module_opt.optim,
        )
        kwargs = self.params.create_kwargs(transformer_data)
        kwargs.update(
            actor=module_opt.actor,
            critic=module_opt.critic,
            optim=module_opt.optim,
            action_space=envs.get_action_space(),
        )
        return kwargs

    def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
        policy_kwargs = self._create_kwargs(envs, device)
        return self.policy_class(**policy_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
class A2CAgentFactory(ActorCriticAgentFactory[A2CParams, A2CPolicy]):
    """Agent factory creating an :class:`A2CPolicy`."""

    def __init__(
        self,
        params: A2CParams,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        critic_use_actor_module: bool,
    ):
        super().__init__(
            params=params,
            sampling_config=sampling_config,
            actor_factory=actor_factory,
            critic_factory=critic_factory,
            optimizer_factory=optimizer_factory,
            policy_class=A2CPolicy,
            critic_use_actor_module=critic_use_actor_module,
        )

    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
        """Create the actor-critic module with an optimizer using ``params.lr``."""
        learning_rate = self.params.lr
        return self.create_actor_critic_module_opt(envs, device, learning_rate)
|
|
|
|
|
|
|
|
|
|
|
|
class PPOAgentFactory(ActorCriticAgentFactory[PPOParams, PPOPolicy]):
    """Agent factory creating a :class:`PPOPolicy`."""

    def __init__(
        self,
        params: PPOParams,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        critic_use_actor_module: bool,
    ):
        super().__init__(
            params=params,
            sampling_config=sampling_config,
            actor_factory=actor_factory,
            critic_factory=critic_factory,
            optimizer_factory=optimizer_factory,
            policy_class=PPOPolicy,
            critic_use_actor_module=critic_use_actor_module,
        )

    def _create_actor_critic(self, envs: Environments, device: TDevice) -> ActorCriticModuleOpt:
        """Create the actor-critic module with an optimizer using ``params.lr``."""
        learning_rate = self.params.lr
        return self.create_actor_critic_module_opt(envs, device, learning_rate)
|
|
|
|
|
2023-09-20 09:29:34 +02:00
|
|
|
|
2023-10-05 15:39:32 +02:00
|
|
|
class DQNAgentFactory(OffpolicyAgentFactory):
    """Off-policy agent factory creating a :class:`DQNPolicy` (discrete action spaces only)."""

    def __init__(
        self,
        params: DQNParams,
        sampling_config: RLSamplingConfig,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        """
        :param params: DQN parameters converted to policy constructor kwargs
        :param sampling_config: buffer, collection and training-loop settings
        :param critic_factory: creates the policy's model
        :param optim_factory: creates the model's optimizer
        """
        # AgentFactory.__init__ already stores optim_factory on self
        super().__init__(sampling_config, optim_factory)
        self.params = params
        self.critic_factory = critic_factory

    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        """Create the DQN policy.

        The model is created by the critic factory with ``use_action=True`` and is
        optimized with learning rate ``params.lr``.
        """
        # Check the environment type first so that no network or optimizer is built
        # for an unsupported environment (assert_discrete presumably raises otherwise).
        envs.get_type().assert_discrete(self)
        critic = self.critic_factory.create_module(envs, device, use_action=True)
        optim = self.optim_factory.create_optimizer(critic, self.params.lr)
        kwargs = self.params.create_kwargs(
            ParamTransformerData(
                envs=envs,
                device=device,
                optim=optim,
                optim_factory=self.optim_factory,
            ),
        )
        # noinspection PyTypeChecker
        action_space: gymnasium.spaces.Discrete = envs.get_action_space()
        return DQNPolicy(
            model=critic,
            optim=optim,
            action_space=action_space,
            observation_space=envs.get_observation_space(),
            **kwargs,
        )
|
|
|
|
|
|
|
|
|
2023-10-03 20:26:39 +02:00
|
|
|
class DDPGAgentFactory(OffpolicyAgentFactory, _ActorAndCriticMixin):
    """Off-policy agent factory creating a :class:`DDPGPolicy` (one actor, one critic)."""

    def __init__(
        self,
        params: DDPGParams,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        # AgentFactory.__init__ stores sampling_config and optim_factory on self
        super().__init__(sampling_config, optim_factory)
        # the critic takes the action as input
        _ActorAndCriticMixin.__init__(
            self,
            actor_factory,
            critic_factory,
            optim_factory,
            critic_use_action=True,
        )
        self.params = params

    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        """Create actor and critic with their optimizers and assemble the DDPG policy."""
        params = self.params
        actor = self.create_actor_module_opt(envs, device, params.actor_lr)
        critic = self.create_critic_module_opt(envs, device, params.critic_lr)
        transformer_data = ParamTransformerData(
            envs=envs,
            device=device,
            optim_factory=self.optim_factory,
            actor=actor,
            critic1=critic,
        )
        kwargs = params.create_kwargs(transformer_data)
        return DDPGPolicy(
            actor=actor.module,
            actor_optim=actor.optim,
            critic=critic.module,
            critic_optim=critic.optim,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
            **kwargs,
        )
|
|
|
|
|
|
|
|
|
2023-09-26 15:35:18 +02:00
|
|
|
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
    """Off-policy agent factory creating a :class:`SACPolicy` (one actor, two critics)."""

    def __init__(
        self,
        params: SACParams,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic1_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        # AgentFactory.__init__ stores sampling_config and optim_factory on self
        super().__init__(sampling_config, optim_factory)
        # both critics take the action as input
        _ActorAndDualCriticsMixin.__init__(
            self,
            actor_factory,
            critic1_factory,
            critic2_factory,
            optim_factory,
            critic_use_action=True,
        )
        self.params = params

    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        """Create the actor and both critics with their optimizers and assemble the SAC policy."""
        params = self.params
        actor = self.create_actor_module_opt(envs, device, params.actor_lr)
        critic1 = self.create_critic_module_opt(envs, device, params.critic1_lr)
        critic2 = self.create_critic2_module_opt(envs, device, params.critic2_lr)
        transformer_data = ParamTransformerData(
            envs=envs,
            device=device,
            optim_factory=self.optim_factory,
            actor=actor,
            critic1=critic1,
            critic2=critic2,
        )
        kwargs = params.create_kwargs(transformer_data)
        return SACPolicy(
            actor=actor.module,
            actor_optim=actor.optim,
            critic=critic1.module,
            critic_optim=critic1.optim,
            critic2=critic2.module,
            critic2_optim=critic2.optim,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
            **kwargs,
        )
|
|
|
|
|
|
|
|
|
|
|
|
class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
    """Off-policy agent factory creating a :class:`TD3Policy` (one actor, two critics)."""

    def __init__(
        self,
        params: TD3Params,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic1_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        # AgentFactory.__init__ stores sampling_config and optim_factory on self
        super().__init__(sampling_config, optim_factory)
        # both critics take the action as input
        _ActorAndDualCriticsMixin.__init__(
            self,
            actor_factory,
            critic1_factory,
            critic2_factory,
            optim_factory,
            critic_use_action=True,
        )
        self.params = params

    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        """Create the actor and both critics with their optimizers and assemble the TD3 policy."""
        params = self.params
        actor = self.create_actor_module_opt(envs, device, params.actor_lr)
        critic1 = self.create_critic_module_opt(envs, device, params.critic1_lr)
        critic2 = self.create_critic2_module_opt(envs, device, params.critic2_lr)
        transformer_data = ParamTransformerData(
            envs=envs,
            device=device,
            optim_factory=self.optim_factory,
            actor=actor,
            critic1=critic1,
            critic2=critic2,
        )
        kwargs = params.create_kwargs(transformer_data)
        return TD3Policy(
            actor=actor.module,
            actor_optim=actor.optim,
            critic=critic1.module,
            critic_optim=critic1.optim,
            critic2=critic2.module,
            critic2_optim=critic2.optim,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
            **kwargs,
        )
|