2023-09-19 18:53:11 +02:00
|
|
|
import os
|
2023-09-20 09:29:34 +02:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from collections.abc import Callable
|
2023-09-19 18:53:11 +02:00
|
|
|
|
|
|
|
import torch
|
|
|
|
|
2023-09-20 09:29:34 +02:00
|
|
|
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
2023-09-20 15:28:33 +02:00
|
|
|
from tianshou.highlevel.config import RLSamplingConfig
|
2023-09-20 15:45:09 +02:00
|
|
|
from tianshou.highlevel.env import Environments
|
2023-09-19 18:53:11 +02:00
|
|
|
from tianshou.highlevel.logger import Logger
|
2023-09-26 15:35:18 +02:00
|
|
|
from tianshou.highlevel.module import (
|
|
|
|
ActorCriticModuleOpt,
|
|
|
|
ActorFactory,
|
|
|
|
ActorModuleOptFactory,
|
|
|
|
CriticFactory,
|
|
|
|
CriticModuleOptFactory,
|
|
|
|
ModuleOpt,
|
|
|
|
TDevice,
|
|
|
|
)
|
2023-09-25 17:56:37 +02:00
|
|
|
from tianshou.highlevel.optim import OptimizerFactory
|
2023-09-26 15:35:18 +02:00
|
|
|
from tianshou.highlevel.params.policy_params import (
|
2023-09-26 17:43:16 +02:00
|
|
|
ParamTransformerData,
|
2023-09-26 15:35:18 +02:00
|
|
|
PPOParams,
|
|
|
|
SACParams,
|
|
|
|
TD3Params,
|
|
|
|
)
|
|
|
|
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
|
2023-09-25 17:56:37 +02:00
|
|
|
from tianshou.policy.modelfree.pg import TDistParams
|
2023-09-20 09:29:34 +02:00
|
|
|
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
2023-09-19 18:53:11 +02:00
|
|
|
from tianshou.utils.net.common import ActorCritic
|
|
|
|
|
|
|
|
# Keys of the checkpoint dict written by AgentFactory._create_save_best_fn and
# read back by AgentFactory.load_checkpoint.
CHECKPOINT_DICT_KEY_MODEL: str = "model"  # the policy's state_dict
CHECKPOINT_DICT_KEY_OBS_RMS: str = "obs_rms"  # observation running statistics of the training envs
|
|
|
|
|
|
|
|
|
|
|
|
class AgentFactory(ABC):
    """Abstract factory that builds a policy together with its collectors and trainer.

    Subclasses implement :meth:`create_policy` and :meth:`create_trainer`; this base class
    provides collector construction and helpers for saving/restoring checkpoints.
    """

    def __init__(self, sampling_config: RLSamplingConfig):
        self.sampling_config = sampling_config

    def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
        """Create the training and test collectors for ``policy``.

        The training collector writes into a replay buffer of size
        ``sampling_config.buffer_size``: a :class:`VectorReplayBuffer` when there is more
        than one training env, otherwise a plain :class:`ReplayBuffer`. It is created with
        ``exploration_noise=True``. When ``sampling_config.start_timesteps`` is positive,
        that many steps are collected with random actions before returning.

        :param policy: the policy both collectors act with
        :param envs: provides ``train_envs`` and ``test_envs``
        :return: the tuple ``(train_collector, test_collector)``
        """
        cfg = self.sampling_config
        capacity = cfg.buffer_size
        training_envs = envs.train_envs
        num_train_envs = len(training_envs)
        replay_buffer = (
            VectorReplayBuffer(capacity, num_train_envs)
            if num_train_envs > 1
            else ReplayBuffer(capacity)
        )
        train_collector = Collector(policy, training_envs, replay_buffer, exploration_noise=True)
        test_collector = Collector(policy, envs.test_envs)
        warmup_steps = cfg.start_timesteps
        if warmup_steps > 0:
            # Pre-fill the buffer with randomly acting steps before training starts.
            train_collector.collect(n_step=warmup_steps, random=True)
        return train_collector, test_collector

    @abstractmethod
    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        """Build the agent's policy for ``envs`` with its modules placed on ``device``."""

    @staticmethod
    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
        """Return a callback that checkpoints a policy into ``<log_path>/policy.pth``.

        The saved object is a dict holding the policy's ``state_dict`` (under
        ``CHECKPOINT_DICT_KEY_MODEL``) and the training envs' observation running
        statistics (under ``CHECKPOINT_DICT_KEY_OBS_RMS``).
        """

        def save_best_fn(pol: torch.nn.Module) -> None:
            checkpoint = {
                CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
                CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
            }
            target_file = os.path.join(log_path, "policy.pth")
            torch.save(checkpoint, target_file)

        return save_best_fn

    @staticmethod
    def load_checkpoint(policy: torch.nn.Module, path, envs: Environments, device: TDevice):
        """Restore ``policy`` weights and the envs' observation statistics from ``path``.

        :param policy: module whose ``state_dict`` is overwritten in place
        :param path: checkpoint file, expected to be in the dict format produced by
            :meth:`_create_save_best_fn`
        :param envs: its ``train_envs`` and ``test_envs`` (each only if truthy) receive the
            stored observation running statistics via ``set_obs_rms``
        :param device: forwarded to ``torch.load`` as ``map_location``
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=device)
        policy.load_state_dict(checkpoint[CHECKPOINT_DICT_KEY_MODEL])
        for vector_env in (envs.train_envs, envs.test_envs):
            if vector_env:
                vector_env.set_obs_rms(checkpoint[CHECKPOINT_DICT_KEY_OBS_RMS])
        print("Loaded agent and obs. running means from: ", path)  # TODO logging

    @abstractmethod
    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> BaseTrainer:
        """Build the trainer that optimizes ``policy`` using the given collectors and logger."""
|
|
|
|
|
|
|
|
|
|
|
|
class OnpolicyAgentFactory(AgentFactory, ABC):
    """Agent factory whose trainer is an :class:`OnpolicyTrainer`."""

    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> OnpolicyTrainer:
        """Create an on-policy trainer configured from ``self.sampling_config``.

        The best policy is checkpointed into ``logger.log_path``, the number of test
        episodes equals ``num_test_envs``, and ``test_in_train`` is switched off.
        """
        cfg = self.sampling_config
        trainer_kwargs = {
            "policy": policy,
            "train_collector": train_collector,
            "test_collector": test_collector,
            "max_epoch": cfg.num_epochs,
            "step_per_epoch": cfg.step_per_epoch,
            "repeat_per_collect": cfg.repeat_per_collect,
            "episode_per_test": cfg.num_test_envs,
            "batch_size": cfg.batch_size,
            "step_per_collect": cfg.step_per_collect,
            "save_best_fn": self._create_save_best_fn(envs, logger.log_path),
            "logger": logger.logger,
            "test_in_train": False,
        }
        return OnpolicyTrainer(**trainer_kwargs)
|
|
|
|
|
|
|
|
|
2023-09-20 09:29:34 +02:00
|
|
|
class OffpolicyAgentFactory(AgentFactory, ABC):
    """Agent factory whose trainer is an :class:`OffpolicyTrainer`."""

    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> OffpolicyTrainer:
        """Create an off-policy trainer configured from ``self.sampling_config``.

        The best policy is checkpointed into ``logger.log_path``, the number of gradient
        updates per collected step comes from ``update_per_step``, and ``test_in_train``
        is switched off.
        """
        cfg = self.sampling_config
        trainer_kwargs = {
            "policy": policy,
            "train_collector": train_collector,
            "test_collector": test_collector,
            "max_epoch": cfg.num_epochs,
            "step_per_epoch": cfg.step_per_epoch,
            "step_per_collect": cfg.step_per_collect,
            "episode_per_test": cfg.num_test_envs,
            "batch_size": cfg.batch_size,
            "save_best_fn": self._create_save_best_fn(envs, logger.log_path),
            "logger": logger.logger,
            "update_per_step": cfg.update_per_step,
            "test_in_train": False,
        }
        return OffpolicyTrainer(**trainer_kwargs)
|
|
|
|
|
|
|
|
|
2023-09-26 15:35:18 +02:00
|
|
|
class _ActorMixin:
    """Mixin that creates an actor module paired with an optimizer via ``ActorModuleOptFactory``."""

    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
        self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)

    def create_actor_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the actor for ``envs`` on ``device`` together with an optimizer using ``lr``."""
        module_opt_factory = self.actor_module_opt_factory
        return module_opt_factory.create_module_opt(envs, device, lr)
|
|
|
|
|
|
|
|
|
|
|
|
class _ActorCriticMixin:
    """Mixin for agents that use an ActorCritic module with a single optimizer."""

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        self.actor_factory = actor_factory
        self.critic_factory = critic_factory
        self.optim_factory = optim_factory
        # Forwarded as ``use_action`` when the critic module is created.
        self.critic_use_action = critic_use_action

    def create_actor_critic_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        lr: float,
    ) -> ActorCriticModuleOpt:
        """Create actor and critic, wrap them in an ``ActorCritic`` and build one optimizer over both.

        The actor is created before the critic, and the optimizer (learning rate ``lr``)
        is created over the combined module.
        """
        combined = ActorCritic(
            self.actor_factory.create_module(envs, device),
            self.critic_factory.create_module(envs, device, use_action=self.critic_use_action),
        )
        shared_optim = self.optim_factory.create_optimizer(combined, lr)
        return ActorCriticModuleOpt(combined, shared_optim)
|
|
|
|
|
|
|
|
|
|
|
|
class _ActorAndCriticMixin(_ActorMixin):
    """Mixin for agents with an actor and a critic, each paired with its own optimizer."""

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        # Cooperative call: resolves along the concrete class's MRO (normally to _ActorMixin).
        super().__init__(actor_factory, optim_factory)
        self.critic_module_opt_factory = CriticModuleOptFactory(critic_factory, optim_factory, critic_use_action)

    def create_critic_module_opt(self, envs: Environments, device: TDevice, lr: float) -> ModuleOpt:
        """Create the critic for ``envs`` on ``device`` together with an optimizer using ``lr``."""
        module_opt_factory = self.critic_module_opt_factory
        return module_opt_factory.create_module_opt(envs, device, lr)
|
|
|
|
|
|
|
|
|
|
|
|
class _ActorAndDualCriticsMixin(_ActorAndCriticMixin):
    """Mixin for agents with an actor and two critics, each paired with its own optimizer."""

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        # Sets up the actor and the first critic.
        super().__init__(actor_factory, critic_factory, optim_factory, critic_use_action)
        # The second critic shares the optimizer factory and the use_action setting.
        self.critic2_module_opt_factory = CriticModuleOptFactory(critic2_factory, optim_factory, critic_use_action)

    def create_critic2_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        lr: float,
    ) -> ModuleOpt:
        """Create the second critic for ``envs`` on ``device`` together with an optimizer using ``lr``."""
        module_opt_factory = self.critic2_module_opt_factory
        return module_opt_factory.create_module_opt(envs, device, lr)
|
|
|
|
|
|
|
|
|
|
|
|
class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
    """Factory for :class:`PPOPolicy` agents whose actor and critic share a single optimizer."""

    def __init__(
        self,
        params: PPOParams,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
    ):
        """
        :param params: PPO hyper-parameters; provides ``lr`` and the remaining policy kwargs
        :param sampling_config: sampling/training configuration
        :param actor_factory: builds the actor module
        :param critic_factory: builds the critic module (created with ``use_action=False``)
        :param optimizer_factory: builds the single optimizer over actor and critic
        :param dist_fn: maps the actor's distribution parameters to a torch distribution
        """
        super().__init__(sampling_config)
        _ActorCriticMixin.__init__(
            self,
            actor_factory,
            critic_factory,
            optimizer_factory,
            critic_use_action=False,
        )
        self.params = params
        self.dist_fn = dist_fn

    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
        """Build the actor-critic with its optimizer and assemble the :class:`PPOPolicy`."""
        module_opt = self.create_actor_critic_module_opt(envs, device, self.params.lr)
        transformer_data = ParamTransformerData(
            envs=envs,
            device=device,
            optim_factory=self.optim_factory,
            optim=module_opt.optim,
        )
        policy_kwargs = self.params.create_kwargs(transformer_data)
        return PPOPolicy(
            actor=module_opt.actor,
            critic=module_opt.critic,
            optim=module_opt.optim,
            dist_fn=self.dist_fn,
            action_space=envs.get_action_space(),
            **policy_kwargs,
        )
|
2023-09-20 09:29:34 +02:00
|
|
|
|
|
|
|
|
2023-09-26 15:35:18 +02:00
|
|
|
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
    """Factory for :class:`SACPolicy` agents: an actor and two critics, each with its own optimizer."""

    def __init__(
        self,
        params: SACParams,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic1_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        """
        :param params: SAC hyper-parameters; provides ``actor_lr``, ``critic1_lr``,
            ``critic2_lr`` and the remaining policy kwargs via ``create_kwargs``
        :param sampling_config: sampling/training configuration
        :param actor_factory: builds the actor module
        :param critic1_factory: builds the first critic
        :param critic2_factory: builds the second critic
        :param optim_factory: builds the actor/critic optimizers and is also forwarded
            to the parameter transformers
        """
        super().__init__(sampling_config)
        _ActorAndDualCriticsMixin.__init__(
            self,
            actor_factory,
            critic1_factory,
            critic2_factory,
            optim_factory,
            # critics are created with use_action=True (presumably Q(obs, action) networks)
            critic_use_action=True,
        )
        self.params = params
        # The actor/critic mixins do not store the optimizer factory themselves, but
        # create_policy forwards it to ParamTransformerData.
        self.optim_factory = optim_factory

    def create_policy(self, envs: Environments, device: TDevice) -> SACPolicy:
        """Create actor and both critics (each with its own optimizer) and assemble the :class:`SACPolicy`.

        The return annotation is narrowed from ``BasePolicy`` to ``SACPolicy`` (consistent
        with ``PPOAgentFactory.create_policy``); callers expecting ``BasePolicy`` are unaffected.
        """
        actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
        critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
        critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
        kwargs = self.params.create_kwargs(
            ParamTransformerData(
                envs=envs,
                device=device,
                optim_factory=self.optim_factory,
                actor=actor,
                critic1=critic1,
                critic2=critic2,
            ),
        )
        return SACPolicy(
            actor=actor.module,
            actor_optim=actor.optim,
            critic=critic1.module,
            critic_optim=critic1.optim,
            critic2=critic2.module,
            critic2_optim=critic2.optim,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
            **kwargs,
        )
|
|
|
|
|
|
|
|
|
|
|
|
class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
    """Factory for :class:`TD3Policy` agents: an actor and two critics, each with its own optimizer."""

    def __init__(
        self,
        params: TD3Params,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic1_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        """
        :param params: TD3 hyper-parameters; provides ``actor_lr``, ``critic1_lr``,
            ``critic2_lr`` and the remaining policy kwargs via ``create_kwargs``
        :param sampling_config: sampling/training configuration
        :param actor_factory: builds the actor module
        :param critic1_factory: builds the first critic
        :param critic2_factory: builds the second critic
        :param optim_factory: builds the actor/critic optimizers and is also forwarded
            to the parameter transformers
        """
        super().__init__(sampling_config)
        _ActorAndDualCriticsMixin.__init__(
            self,
            actor_factory,
            critic1_factory,
            critic2_factory,
            optim_factory,
            # critics are created with use_action=True (presumably Q(obs, action) networks)
            critic_use_action=True,
        )
        self.params = params
        # The actor/critic mixins do not store the optimizer factory themselves, but
        # create_policy forwards it to ParamTransformerData.
        self.optim_factory = optim_factory

    def create_policy(self, envs: Environments, device: TDevice) -> TD3Policy:
        """Create actor and both critics (each with its own optimizer) and assemble the :class:`TD3Policy`.

        The return annotation is narrowed from ``BasePolicy`` to ``TD3Policy`` (consistent
        with ``PPOAgentFactory.create_policy``); callers expecting ``BasePolicy`` are unaffected.
        """
        actor = self.create_actor_module_opt(envs, device, self.params.actor_lr)
        critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
        critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
        kwargs = self.params.create_kwargs(
            ParamTransformerData(
                envs=envs,
                device=device,
                optim_factory=self.optim_factory,
                actor=actor,
                critic1=critic1,
                critic2=critic2,
            ),
        )
        return TD3Policy(
            actor=actor.module,
            actor_optim=actor.optim,
            critic=critic1.module,
            critic_optim=critic1.optim,
            critic2=critic2.module,
            critic2_optim=critic2.optim,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
            **kwargs,
        )
|