import os
from abc import ABC, abstractmethod
from typing import Callable

import torch

from tianshou.config import NNConfig, PGConfig, PPOConfig, RLAgentConfig, RLSamplingConfig
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.trainer import BaseTrainer, OnpolicyTrainer
from tianshou.utils.net.common import ActorCritic

CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"


class AgentFactory(ABC):
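    """Abstract factory that creates the policy, the collectors, and the trainer of an RL agent."""
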
    @abstractmethod
    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        pass

    @staticmethod
    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
        def save_best_fn(pol: torch.nn.Module) -> None:
            # Persist both the policy weights and the observation running statistics
            # so that a restored agent normalizes observations consistently.
            state = {
                CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
                CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
            }
            torch.save(state, os.path.join(log_path, "policy.pth"))

        return save_best_fn

    @staticmethod
    def load_checkpoint(policy: torch.nn.Module, path: str, envs: Environments, device: TDevice) -> None:
        ckpt = torch.load(path, map_location=device)
        policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
        if envs.train_envs:
            envs.train_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        if envs.test_envs:
            envs.test_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        print("Loaded agent and obs. running means from: ", path)  # TODO logging

    @abstractmethod
    def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
        pass

    @abstractmethod
    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> BaseTrainer:
        pass


class OnpolicyAgentFactory(AgentFactory, ABC):
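    """Base factory for on-policy agents: builds collectors and an OnpolicyTrainer from the sampling config."""
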
    def __init__(self, sampling_config: RLSamplingConfig):
        self.sampling_config = sampling_config

    def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
        buffer_size = self.sampling_config.buffer_size
        train_envs = envs.train_envs
        # Use a vectorized replay buffer when collecting from multiple parallel training envs.
        if len(train_envs) > 1:
            buffer = VectorReplayBuffer(buffer_size, len(train_envs))
        else:
            buffer = ReplayBuffer(buffer_size)
        train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
        test_collector = Collector(policy, envs.test_envs)
        return train_collector, test_collector

    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> OnpolicyTrainer:
        sampling_config = self.sampling_config
        return OnpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=sampling_config.num_epochs,
            step_per_epoch=sampling_config.step_per_epoch,
            repeat_per_collect=sampling_config.repeat_per_collect,
            episode_per_test=sampling_config.num_test_envs,
            batch_size=sampling_config.batch_size,
            step_per_collect=sampling_config.step_per_collect,
            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
            logger=logger.logger,
            test_in_train=False,
        )


class PPOAgentFactory(OnpolicyAgentFactory):
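    """Factory for PPO agents: assembles actor, critic, optimizer, and LR scheduler into a PPOPolicy."""
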
    def __init__(
        self,
        general_config: RLAgentConfig,
        pg_config: PGConfig,
        ppo_config: PPOConfig,
        sampling_config: RLSamplingConfig,
        nn_config: NNConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        dist_fn,
        lr_scheduler_factory: LRSchedulerFactory,
    ):
        super().__init__(sampling_config)
        self.optimizer_factory = optimizer_factory
        self.critic_factory = critic_factory
        self.actor_factory = actor_factory
        self.ppo_config = ppo_config
        self.pg_config = pg_config
        self.general_config = general_config
        self.lr_scheduler_factory = lr_scheduler_factory
        self.dist_fn = dist_fn
        self.nn_config = nn_config

    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
        actor = self.actor_factory.create_module(envs, device)
        critic = self.critic_factory.create_module(envs, device)
        actor_critic = ActorCritic(actor, critic)
        optim = self.optimizer_factory.create_optimizer(actor_critic)
        lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
        return PPOPolicy(
            # networks, optimizer, and LR schedule
            actor,
            critic,
            optim,
            dist_fn=self.dist_fn,
            lr_scheduler=lr_scheduler,
            # environment-related settings
            action_space=envs.get_action_space(),
            action_scaling=True,
            # general_config
            discount_factor=self.general_config.gamma,
            gae_lambda=self.general_config.gae_lambda,
            reward_normalization=self.general_config.rew_norm,
            action_bound_method=self.general_config.action_bound_method,
            # pg_config
            max_grad_norm=self.pg_config.max_grad_norm,
            vf_coef=self.pg_config.vf_coef,
            ent_coef=self.pg_config.ent_coef,
            # ppo_config
            eps_clip=self.ppo_config.eps_clip,
            value_clip=self.ppo_config.value_clip,
            dual_clip=self.ppo_config.dual_clip,
            advantage_normalization=self.ppo_config.norm_adv,
            recompute_advantage=self.ppo_config.recompute_adv,
        )
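
# Usage sketch: the config objects and the concrete ActorFactory/CriticFactory/OptimizerFactory/
# LRSchedulerFactory instances below are placeholders the caller is assumed to provide; only the
# call sequence follows the methods defined in this module.
#
#     factory = PPOAgentFactory(
#         general_config, pg_config, ppo_config, sampling_config, nn_config,
#         actor_factory, critic_factory, optimizer_factory,
#         dist_fn=torch.distributions.Categorical,
#         lr_scheduler_factory=lr_scheduler_factory,
#     )
#     policy = factory.create_policy(envs, device)
#     train_collector, test_collector = factory.create_train_test_collector(policy, envs)
#     trainer = factory.create_trainer(policy, train_collector, test_collector, envs, logger)
#     result = trainer.run()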