Policy objects are now parametrised by converting the parameter dataclass instances to kwargs, applying injectable `ParamTransformer` conversions along the way.
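The core mechanism, in miniature: a params dataclass is flattened to a kwargs dict, and each injected transformer mutates that dict before it reaches the policy constructor. The following self-contained sketch illustrates the flow with toy names (`DemoParams`, `DemoTransformer` are illustrative stand-ins, not part of this module; the real `ParamTransformer` subclasses below additionally use a `get(kwargs, key, drop=...)` helper). The module itself follows.

from dataclasses import asdict, dataclass
from typing import Any, Dict


class DemoTransformer:
    """Stand-in for ParamTransformer: renames "lr" to "learning_rate"."""

    def transform(self, kwargs: Dict[str, Any]) -> None:
        kwargs["learning_rate"] = kwargs.pop("lr")


@dataclass
class DemoParams:
    lr: float = 1e-3
    gamma: float = 0.99

    def create_kwargs(self, *transformers: DemoTransformer) -> Dict[str, Any]:
        kwargs = asdict(self)  # dataclass fields -> plain dict
        for t in transformers:
            t.transform(kwargs)  # each transformer mutates the dict in place
        return kwargs


assert DemoParams().create_kwargs(DemoTransformer()) == {"learning_rate": 1e-3, "gamma": 0.99}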
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Dict, Any, List, Tuple

import torch

from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.exploration import BaseNoise
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.highlevel.params.alpha import AutoAlphaFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.policy_params import PPOParams, ParamTransformer, SACParams
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.net.common import ActorCritic

CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"


class AgentFactory(ABC):
    def __init__(self, sampling_config: RLSamplingConfig):
        self.sampling_config = sampling_config

    def create_train_test_collector(
        self, policy: BasePolicy, envs: Environments
    ) -> Tuple[Collector, Collector]:
        buffer_size = self.sampling_config.buffer_size
        train_envs = envs.train_envs
        if len(train_envs) > 1:
            buffer = VectorReplayBuffer(buffer_size, len(train_envs))
        else:
            buffer = ReplayBuffer(buffer_size)
        train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
        test_collector = Collector(policy, envs.test_envs)
        if self.sampling_config.start_timesteps > 0:
            # Warm-start the buffer with random actions before training begins.
            train_collector.collect(n_step=self.sampling_config.start_timesteps, random=True)
        return train_collector, test_collector

    @abstractmethod
    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        pass

    @staticmethod
    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
        def save_best_fn(pol: torch.nn.Module) -> None:
            state = {
                CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
                CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
            }
            torch.save(state, os.path.join(log_path, "policy.pth"))

        return save_best_fn

    @staticmethod
    def load_checkpoint(policy: torch.nn.Module, path: str, envs: Environments, device: TDevice) -> None:
        ckpt = torch.load(path, map_location=device)
        policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
        if envs.train_envs:
            envs.train_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        if envs.test_envs:
            envs.test_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        print("Loaded agent and obs. running means from: ", path)  # TODO logging

    @abstractmethod
    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> BaseTrainer:
        pass


class OnpolicyAgentFactory(AgentFactory, ABC):
    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> OnpolicyTrainer:
        sampling_config = self.sampling_config
        return OnpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=sampling_config.num_epochs,
            step_per_epoch=sampling_config.step_per_epoch,
            repeat_per_collect=sampling_config.repeat_per_collect,
            episode_per_test=sampling_config.num_test_envs,
            batch_size=sampling_config.batch_size,
            step_per_collect=sampling_config.step_per_collect,
            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
            logger=logger.logger,
            test_in_train=False,
        )


class OffpolicyAgentFactory(AgentFactory, ABC):
    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> OffpolicyTrainer:
        sampling_config = self.sampling_config
        return OffpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=sampling_config.num_epochs,
            step_per_epoch=sampling_config.step_per_epoch,
            step_per_collect=sampling_config.step_per_collect,
            episode_per_test=sampling_config.num_test_envs,
            batch_size=sampling_config.batch_size,
            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
            logger=logger.logger,
            update_per_step=sampling_config.update_per_step,
            test_in_train=False,
        )


class ParamTransformerDrop(ParamTransformer):
    """Removes the given keys from the kwargs dict, e.g. params already consumed elsewhere."""

    def __init__(self, *keys: str):
        self.keys = keys

    def transform(self, kwargs: Dict[str, Any]) -> None:
        for k in self.keys:
            del kwargs[k]


class ParamTransformerLRScheduler(ParamTransformer):
    """Replaces the lr_scheduler_factory entry with the scheduler it creates for the given optimizer."""

    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim

    def transform(self, kwargs: Dict[str, Any]) -> None:
        factory: LRSchedulerFactory | None = self.get(kwargs, "lr_scheduler_factory", drop=True)
        kwargs["lr_scheduler"] = factory.create_scheduler(self.optim) if factory is not None else None


class PPOAgentFactory(OnpolicyAgentFactory):
    def __init__(
        self,
        params: PPOParams,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        dist_fn: Callable[[TDistParams], torch.distributions.Distribution],
    ):
        super().__init__(sampling_config)
        self.optimizer_factory = optimizer_factory
        self.critic_factory = critic_factory
        self.actor_factory = actor_factory
        self.params = params
        self.dist_fn = dist_fn

    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
        actor = self.actor_factory.create_module(envs, device)
        critic = self.critic_factory.create_module(envs, device, use_action=False)
        actor_critic = ActorCritic(actor, critic)
        optim = self.optimizer_factory.create_optimizer(actor_critic, self.params.lr)
        # "lr" is consumed by the optimizer above; the scheduler factory (if any)
        # is converted into an actual lr_scheduler kwarg.
        kwargs = self.params.create_kwargs(
            ParamTransformerDrop("lr"),
            ParamTransformerLRScheduler(optim),
        )
        return PPOPolicy(
            actor=actor,
            critic=critic,
            optim=optim,
            dist_fn=self.dist_fn,
            action_space=envs.get_action_space(),
            **kwargs,
        )


class ParamTransformerAlpha(ParamTransformer):
    """Resolves an AutoAlphaFactory entry in kwargs into the actual auto-tuned alpha."""

    def __init__(self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice):
        self.envs = envs
        self.optim_factory = optim_factory
        self.device = device

    def transform(self, kwargs: Dict[str, Any]) -> None:
        key = "alpha"
        alpha = self.get(kwargs, key)
        if isinstance(alpha, AutoAlphaFactory):
            kwargs[key] = alpha.create_auto_alpha(self.envs, self.optim_factory, self.device)


class ParamTransformerMultiLRScheduler(ParamTransformer):
    """Collapses several per-optimizer scheduler factories into a single lr_scheduler kwarg."""

    def __init__(self, optim_key_list: List[Tuple[torch.optim.Optimizer, str]]):
        self.optim_key_list = optim_key_list

    def transform(self, kwargs: Dict[str, Any]) -> None:
        lr_schedulers = []
        for optim, lr_scheduler_factory_key in self.optim_key_list:
            lr_scheduler_factory: LRSchedulerFactory | None = self.get(
                kwargs, lr_scheduler_factory_key, drop=True
            )
            if lr_scheduler_factory is not None:
                lr_schedulers.append(lr_scheduler_factory.create_scheduler(optim))
        match len(lr_schedulers):
            case 0:
                lr_scheduler = None
            case 1:
                lr_scheduler = lr_schedulers[0]
            case _:
                lr_scheduler = MultipleLRSchedulers(*lr_schedulers)
        kwargs["lr_scheduler"] = lr_scheduler


class SACAgentFactory(OffpolicyAgentFactory):
    def __init__(
        self,
        params: SACParams,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic1_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
    ):
        super().__init__(sampling_config)
        self.critic2_factory = critic2_factory
        self.critic1_factory = critic1_factory
        self.actor_factory = actor_factory
        self.optim_factory = optim_factory
        self.params = params

    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        actor = self.actor_factory.create_module(envs, device)
        critic1 = self.critic1_factory.create_module(envs, device, use_action=True)
        critic2 = self.critic2_factory.create_module(envs, device, use_action=True)
        actor_optim = self.optim_factory.create_optimizer(actor, lr=self.params.actor_lr)
        critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.params.critic1_lr)
        critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.params.critic2_lr)
        # The three learning rates are consumed by the optimizers above; scheduler
        # factories are merged into one lr_scheduler, and an AutoAlphaFactory (if
        # given) is resolved into the actual alpha parameter.
        kwargs = self.params.create_kwargs(
            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
            ParamTransformerMultiLRScheduler(
                [
                    (actor_optim, "actor_lr_scheduler_factory"),
                    (critic1_optim, "critic1_lr_scheduler_factory"),
                    (critic2_optim, "critic2_lr_scheduler_factory"),
                ]
            ),
            ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device),
        )
        return SACPolicy(
            actor=actor,
            actor_optim=actor_optim,
            critic=critic1,
            critic_optim=critic1_optim,
            critic2=critic2,
            critic2_optim=critic2_optim,
            action_space=envs.get_action_space(),
            observation_space=envs.get_observation_space(),
            **kwargs,
        )
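

# For orientation, a hedged wiring sketch of how a driver script might use these
# factories end to end. The factory instances (my_actor_factory, etc.) are
# hypothetical stand-ins; only the call sequence mirrors the classes above.
#
#   factory = PPOAgentFactory(
#       params=PPOParams(...),
#       sampling_config=RLSamplingConfig(...),
#       actor_factory=my_actor_factory,
#       critic_factory=my_critic_factory,
#       optimizer_factory=my_optim_factory,
#       dist_fn=my_dist_fn,
#   )
#   policy = factory.create_policy(envs, device)
#   train_collector, test_collector = factory.create_train_test_collector(policy, envs)
#   trainer = factory.create_trainer(policy, train_collector, test_collector, envs, logger)
#   result = trainer.run()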