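"""Agent factories for the high-level API: they assemble policies, collectors, and trainers
from configuration objects and module factories."""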
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Literal

import numpy as np
import torch

from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.exploration import BaseNoise
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net.common import ActorCritic

CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"


class AgentFactory(ABC):
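    """Base class for agent factories: creates the policy, the train/test collectors, and the trainer."""
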
    def __init__(self, sampling_config: RLSamplingConfig):
        self.sampling_config = sampling_config

    def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
        buffer_size = self.sampling_config.buffer_size
        train_envs = envs.train_envs
        if len(train_envs) > 1:
            buffer = VectorReplayBuffer(buffer_size, len(train_envs))
        else:
            buffer = ReplayBuffer(buffer_size)
        train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
        test_collector = Collector(policy, envs.test_envs)
        if self.sampling_config.start_timesteps > 0:
            train_collector.collect(n_step=self.sampling_config.start_timesteps, random=True)
        return train_collector, test_collector

    @abstractmethod
    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        pass

    @staticmethod
    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
        def save_best_fn(pol: torch.nn.Module) -> None:
            state = {
                CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
                CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
            }
            torch.save(state, os.path.join(log_path, "policy.pth"))

        return save_best_fn

    @staticmethod
    def load_checkpoint(policy: torch.nn.Module, path, envs: Environments, device: TDevice):
        ckpt = torch.load(path, map_location=device)
        policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
        if envs.train_envs:
            envs.train_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        if envs.test_envs:
            envs.test_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        print("Loaded agent and obs. running means from: ", path)  # TODO logging

    @abstractmethod
    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> BaseTrainer:
        pass


class OnpolicyAgentFactory(AgentFactory, ABC):
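    """Base class for on-policy agent factories; creates an OnpolicyTrainer from the sampling config."""
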
    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> OnpolicyTrainer:
        sampling_config = self.sampling_config
        return OnpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=sampling_config.num_epochs,
            step_per_epoch=sampling_config.step_per_epoch,
            repeat_per_collect=sampling_config.repeat_per_collect,
            episode_per_test=sampling_config.num_test_envs,
            batch_size=sampling_config.batch_size,
            step_per_collect=sampling_config.step_per_collect,
            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
            logger=logger.logger,
            test_in_train=False,
        )


class OffpolicyAgentFactory(AgentFactory, ABC):
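    """Base class for off-policy agent factories; creates an OffpolicyTrainer from the sampling config."""
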
    def create_trainer(
        self,
        policy: BasePolicy,
        train_collector: Collector,
        test_collector: Collector,
        envs: Environments,
        logger: Logger,
    ) -> OffpolicyTrainer:
        sampling_config = self.sampling_config
        return OffpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=sampling_config.num_epochs,
            step_per_epoch=sampling_config.step_per_epoch,
            step_per_collect=sampling_config.step_per_collect,
            episode_per_test=sampling_config.num_test_envs,
            batch_size=sampling_config.batch_size,
            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
            logger=logger.logger,
            update_per_step=sampling_config.update_per_step,
            test_in_train=False,
        )


@dataclass
class RLAgentConfig:
"""Config common to most RL algorithms."""
gamma: float = 0.99
"""Discount factor"""
gae_lambda: float = 0.95
"""For Generalized Advantage Estimate (equivalent to TD(lambda))"""
action_bound_method: Literal["clip", "tanh"] | None = "clip"
"""How to map original actions in range (-inf, inf) to [-1, 1]"""
rew_norm: bool = True
"""Whether to normalize rewards"""
@dataclass
class PGConfig:
"""Config of general policy-gradient algorithms."""
ent_coef: float = 0.0
vf_coef: float = 0.25
max_grad_norm: float = 0.5
@dataclass
class PPOConfig:
"""PPO specific config."""
value_clip: bool = False
norm_adv: bool = False
"""Whether to normalize advantages"""
eps_clip: float = 0.2
dual_clip: float | None = None
recompute_adv: bool = True
class PPOAgentFactory(OnpolicyAgentFactory):
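    """Factory for PPO agents (on-policy): builds the actor-critic modules, a shared optimizer, and the PPOPolicy."""
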
    def __init__(
        self,
        general_config: RLAgentConfig,
        pg_config: PGConfig,
        ppo_config: PPOConfig,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optimizer_factory: OptimizerFactory,
        dist_fn,
        lr: float,
        lr_scheduler_factory: LRSchedulerFactory | None = None,
    ):
        super().__init__(sampling_config)
        self.optimizer_factory = optimizer_factory
        self.critic_factory = critic_factory
        self.actor_factory = actor_factory
        self.ppo_config = ppo_config
        self.pg_config = pg_config
        self.general_config = general_config
        self.lr = lr
        self.lr_scheduler_factory = lr_scheduler_factory
        self.dist_fn = dist_fn

    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
        actor = self.actor_factory.create_module(envs, device)
        critic = self.critic_factory.create_module(envs, device, use_action=False)
        actor_critic = ActorCritic(actor, critic)
        optim = self.optimizer_factory.create_optimizer(actor_critic, self.lr)
        if self.lr_scheduler_factory is not None:
            lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
        else:
            lr_scheduler = None
        return PPOPolicy(
            # nn-stuff
            actor,
            critic,
            optim,
            dist_fn=self.dist_fn,
            lr_scheduler=lr_scheduler,
            # env-stuff
            action_space=envs.get_action_space(),
            action_scaling=True,
            # general_config
            discount_factor=self.general_config.gamma,
            gae_lambda=self.general_config.gae_lambda,
            reward_normalization=self.general_config.rew_norm,
            action_bound_method=self.general_config.action_bound_method,
            # pg_config
            max_grad_norm=self.pg_config.max_grad_norm,
            vf_coef=self.pg_config.vf_coef,
            ent_coef=self.pg_config.ent_coef,
            # ppo_config
            eps_clip=self.ppo_config.eps_clip,
            value_clip=self.ppo_config.value_clip,
            dual_clip=self.ppo_config.dual_clip,
            advantage_normalization=self.ppo_config.norm_adv,
            recompute_advantage=self.ppo_config.recompute_adv,
        )


class AutoAlphaFactory(ABC):
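    """Factory for the automatic entropy-coefficient (alpha) tuning mechanism used by SAC."""
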
    @abstractmethod
    def create_auto_alpha(
        self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice,
    ):
        pass


class DefaultAutoAlphaFactory(AutoAlphaFactory): # TODO better name?
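    """Default auto-alpha setup: target entropy -dim(action_space), a learnable log_alpha, and an Adam optimizer for it."""
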
    def __init__(self, lr: float = 3e-4):
        self.lr = lr

    def create_auto_alpha(
        self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice,
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        target_entropy = -np.prod(envs.get_action_shape())
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
        return target_entropy, log_alpha, alpha_optim


@dataclass
class SACConfig:
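    """SAC-specific config."""
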
    tau: float = 0.005
    gamma: float = 0.99
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
    reward_normalization: bool = False
    estimation_step: int = 1
    deterministic_eval: bool = True
    actor_lr: float = 1e-3
    critic1_lr: float = 1e-3
    critic2_lr: float = 1e-3


class SACAgentFactory(OffpolicyAgentFactory):
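    """Factory for SAC agents (off-policy): builds the actor, both critics, their optimizers, and the SACPolicy."""
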
    def __init__(
        self,
        config: SACConfig,
        sampling_config: RLSamplingConfig,
        actor_factory: ActorFactory,
        critic1_factory: CriticFactory,
        critic2_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        exploration_noise: BaseNoise | None = None,
    ):
        super().__init__(sampling_config)
        self.critic2_factory = critic2_factory
        self.critic1_factory = critic1_factory
        self.actor_factory = actor_factory
        self.exploration_noise = exploration_noise
        self.optim_factory = optim_factory
        self.config = config

    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        actor = self.actor_factory.create_module(envs, device)
        critic1 = self.critic1_factory.create_module(envs, device, use_action=True)
        critic2 = self.critic2_factory.create_module(envs, device, use_action=True)
        actor_optim = self.optim_factory.create_optimizer(actor, lr=self.config.actor_lr)
        critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.config.critic1_lr)
        critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.config.critic2_lr)
        if isinstance(self.config.alpha, AutoAlphaFactory):
            alpha = self.config.alpha.create_auto_alpha(envs, self.optim_factory, device)
        else:
            alpha = self.config.alpha
        return SACPolicy(
            actor,
            actor_optim,
            critic1,
            critic1_optim,
            critic2,
            critic2_optim,
            tau=self.config.tau,
            gamma=self.config.gamma,
            alpha=alpha,
            estimation_step=self.config.estimation_step,
            action_space=envs.get_action_space(),
            deterministic_eval=self.config.deterministic_eval,
            exploration_noise=self.exploration_noise,
        )
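

# Usage sketch: how these factories are typically wired together. The names
# `sampling_config`, `envs`, `logger`, `actor_factory`, `critic_factory`,
# `optimizer_factory`, `dist_fn`, and `device` below are placeholders for objects
# constructed elsewhere in the high-level API (RLSamplingConfig, Environments,
# Logger, ActorFactory, CriticFactory, OptimizerFactory, a torch distribution
# constructor, and a torch device, respectively).
#
#   factory = PPOAgentFactory(
#       RLAgentConfig(), PGConfig(), PPOConfig(), sampling_config,
#       actor_factory, critic_factory, optimizer_factory,
#       dist_fn=dist_fn, lr=3e-4,
#   )
#   policy = factory.create_policy(envs, device)
#   train_collector, test_collector = factory.create_train_test_collector(policy, envs)
#   trainer = factory.create_trainer(policy, train_collector, test_collector, envs, logger)
#   result = trainer.run()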