Initial high-level interfaces, demonstrated in mujoco_ppo_hl
parent a54aade730
commit 16ed5fd2a5
examples/mujoco/mujoco_env.py
@@ -2,7 +2,9 @@ import warnings

 import gymnasium as gym

+from tianshou.config import RLSamplingConfig, BasicExperimentConfig
 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
+from tianshou.highlevel.env import EnvFactory, Environments, ContinuousEnvironments

 try:
     import envpool
@@ -38,3 +40,19 @@ def make_mujoco_env(
     test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
     test_envs.set_obs_rms(train_envs.get_obs_rms())
     return env, train_envs, test_envs
+
+
+class MujocoEnvFactory(EnvFactory):
+    def __init__(self, experiment_config: BasicExperimentConfig, sampling_config: RLSamplingConfig):
+        self.sampling_config = sampling_config
+        self.experiment_config = experiment_config
+
+    def create_envs(self) -> ContinuousEnvironments:
+        env, train_envs, test_envs = make_mujoco_env(
+            task=self.experiment_config.task,
+            seed=self.experiment_config.seed,
+            num_train_envs=self.sampling_config.num_train_envs,
+            num_test_envs=self.sampling_config.num_test_envs,
+            obs_norm=True,
+        )
+        return ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
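Not part of the diff: a minimal sketch of using the new factory directly, assuming BasicExperimentConfig and RLSamplingConfig can be constructed with their defaults; in the example script further below, RLExperiment calls create_envs() internally.

# Sketch only (not in the commit): direct use of MujocoEnvFactory.
# Assumes the config dataclasses are default-constructible.
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.config import BasicExperimentConfig, RLSamplingConfig

factory = MujocoEnvFactory(BasicExperimentConfig(), RLSamplingConfig())
envs = factory.create_envs()  # ContinuousEnvironments with obs-normalized train/test envs
print(envs.info())  # contains action_shape, state_shape and max_action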
examples/mujoco/mujoco_ppo_hl.py (new file, 60 lines)
@@ -0,0 +1,60 @@
#!/usr/bin/env python3

import datetime
import os

from jsonargparse import CLI
from torch.distributions import Independent, Normal

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.config import (
    BasicExperimentConfig,
    LoggerConfig,
    NNConfig,
    PGConfig,
    PPOConfig,
    RLAgentConfig,
    RLSamplingConfig,
)
from tianshou.highlevel.agent import PPOAgentFactory
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import ContinuousActorProbFactory, ContinuousNetCriticFactory
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
from tianshou.highlevel.experiment import RLExperiment


def main(
    experiment_config: BasicExperimentConfig,
    logger_config: LoggerConfig,
    sampling_config: RLSamplingConfig,
    general_config: RLAgentConfig,
    pg_config: PGConfig,
    ppo_config: PPOConfig,
    nn_config: NNConfig,
):
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(experiment_config.task, "ppo", str(experiment_config.seed), now)
    logger_factory = DefaultLoggerFactory(logger_config)

    env_factory = MujocoEnvFactory(experiment_config, sampling_config)

    def dist_fn(*logits):
        return Independent(Normal(*logits), 1)

    actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes)
    critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes)
    optim_factory = AdamOptimizerFactory(lr=nn_config.lr)
    lr_scheduler_factory = LinearLRSchedulerFactory(nn_config, sampling_config)
    agent_factory = PPOAgentFactory(general_config, pg_config, ppo_config, sampling_config, nn_config,
                                    actor_factory, critic_factory, optim_factory, dist_fn, lr_scheduler_factory)

    experiment = RLExperiment(experiment_config, logger_config, general_config, sampling_config,
                              env_factory,
                              logger_factory,
                              agent_factory)

    experiment.run(log_name)


if __name__ == "__main__":
    CLI(main)
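The script is wired through jsonargparse's CLI, so each config dataclass becomes a group of command-line options (exact flag names depend on jsonargparse's mapping). Below is a hedged sketch of calling main() programmatically instead, assuming every config dataclass provides defaults for all of its fields.

# Sketch only (not in the commit): invoking main() directly rather than via the CLI.
# Assumes all config dataclasses are default-constructible, which may not hold.
from examples.mujoco.mujoco_ppo_hl import main
from tianshou.config import (
    BasicExperimentConfig,
    LoggerConfig,
    NNConfig,
    PGConfig,
    PPOConfig,
    RLAgentConfig,
    RLSamplingConfig,
)

main(
    experiment_config=BasicExperimentConfig(task="HalfCheetah-v4", seed=0),
    logger_config=LoggerConfig(),
    sampling_config=RLSamplingConfig(),
    general_config=RLAgentConfig(),
    pg_config=PGConfig(),
    ppo_config=PPOConfig(),
    nn_config=NNConfig(),
)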
tianshou.config (modified)
@@ -14,8 +14,8 @@ class BasicExperimentConfig:
     seed: int = 42
     task: str = "Ant-v4"
     """Mujoco specific"""
-    render: float = 0.0
-    """Milliseconds between rendered frames"""
+    render: Optional[float] = 0.0
+    """Milliseconds between rendered frames; if None, no rendering"""
     device: str = "cuda" if torch.cuda.is_available() else "cpu"
     resume_id: Optional[int] = None
     """For restoring a model and running means of env-specifics from a checkpoint"""
@@ -23,6 +23,7 @@ class BasicExperimentConfig:
     """For restoring a model and running means of env-specifics from a checkpoint"""
     watch: bool = False
     """If True, will not perform training and only watch the restored policy"""
+    watch_num_episodes = 10


 @dataclass
tianshou/highlevel/__init__.py (new file, empty)

tianshou/highlevel/agent.py (new file, 146 lines)
@@ -0,0 +1,146 @@
import os
from abc import abstractmethod, ABC
from typing import Callable

import torch

from tianshou.config import RLSamplingConfig, PGConfig, PPOConfig, RLAgentConfig, NNConfig
from tianshou.data import VectorReplayBuffer, ReplayBuffer, Collector
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.optim import OptimizerFactory, LRSchedulerFactory
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.trainer import BaseTrainer, OnpolicyTrainer
from tianshou.utils.net.common import ActorCritic


CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"


class AgentFactory(ABC):
    @abstractmethod
    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        pass

    @staticmethod
    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
        def save_best_fn(pol: torch.nn.Module):
            state = {"model": pol.state_dict(), "obs_rms": envs.train_envs.get_obs_rms()}
            torch.save(state, os.path.join(log_path, "policy.pth"))

        return save_best_fn

    @staticmethod
    def load_checkpoint(policy: torch.nn.Module, path, envs: Environments, device: TDevice):
        ckpt = torch.load(path, map_location=device)
        policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
        if envs.train_envs:
            envs.train_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        if envs.test_envs:
            envs.test_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        print("Loaded agent and obs. running means from: ", path)  # TODO logging

    @abstractmethod
    def create_train_test_collector(self,
                                    policy: BasePolicy,
                                    envs: Environments):
        pass

    @abstractmethod
    def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
                       envs: Environments, logger: Logger) -> BaseTrainer:
        pass


class OnpolicyAgentFactory(AgentFactory, ABC):
    def __init__(self, sampling_config: RLSamplingConfig):
        self.sampling_config = sampling_config

    def create_train_test_collector(self,
                                    policy: BasePolicy,
                                    envs: Environments):
        buffer_size = self.sampling_config.buffer_size
        train_envs = envs.train_envs
        if len(train_envs) > 1:
            buffer = VectorReplayBuffer(buffer_size, len(train_envs))
        else:
            buffer = ReplayBuffer(buffer_size)
        train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
        test_collector = Collector(policy, envs.test_envs)
        return train_collector, test_collector

    def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
                       envs: Environments, logger: Logger) -> OnpolicyTrainer:
        sampling_config = self.sampling_config
        return OnpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=sampling_config.num_epochs,
            step_per_epoch=sampling_config.step_per_epoch,
            repeat_per_collect=sampling_config.repeat_per_collect,
            episode_per_test=sampling_config.num_test_envs,
            batch_size=sampling_config.batch_size,
            step_per_collect=sampling_config.step_per_collect,
            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
            logger=logger.logger,
            test_in_train=False,
        )


class PPOAgentFactory(OnpolicyAgentFactory):
    def __init__(self, general_config: RLAgentConfig,
                 pg_config: PGConfig,
                 ppo_config: PPOConfig,
                 sampling_config: RLSamplingConfig,
                 nn_config: NNConfig,
                 actor_factory: ActorFactory,
                 critic_factory: CriticFactory,
                 optimizer_factory: OptimizerFactory,
                 dist_fn,
                 lr_scheduler_factory: LRSchedulerFactory):
        super().__init__(sampling_config)
        self.optimizer_factory = optimizer_factory
        self.critic_factory = critic_factory
        self.actor_factory = actor_factory
        self.ppo_config = ppo_config
        self.pg_config = pg_config
        self.general_config = general_config
        self.lr_scheduler_factory = lr_scheduler_factory
        self.dist_fn = dist_fn
        self.nn_config = nn_config

    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
        actor = self.actor_factory.create_module(envs, device)
        critic = self.critic_factory.create_module(envs, device)
        actor_critic = ActorCritic(actor, critic)
        optim = self.optimizer_factory.create_optimizer(actor_critic)
        lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
        return PPOPolicy(
            # nn-stuff
            actor,
            critic,
            optim,
            dist_fn=self.dist_fn,
            lr_scheduler=lr_scheduler,
            # env-stuff
            action_space=envs.get_action_space(),
            action_scaling=True,
            # general_config
            discount_factor=self.general_config.gamma,
            gae_lambda=self.general_config.gae_lambda,
            reward_normalization=self.general_config.rew_norm,
            action_bound_method=self.general_config.action_bound_method,
            # pg_config
            max_grad_norm=self.pg_config.max_grad_norm,
            vf_coef=self.pg_config.vf_coef,
            ent_coef=self.pg_config.ent_coef,
            # ppo_config
            eps_clip=self.ppo_config.eps_clip,
            value_clip=self.ppo_config.value_clip,
            dual_clip=self.ppo_config.dual_clip,
            advantage_normalization=self.ppo_config.norm_adv,
            recompute_advantage=self.ppo_config.recompute_adv,
        )
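save_best_fn stores both the policy weights and the training envs' observation-normalization statistics, and load_checkpoint restores both. A hedged sketch of inspecting such a checkpoint follows; the path is purely illustrative, the actual location being <logdir>/<task>/ppo/<seed>/<timestamp>/policy.pth.

# Sketch only (not in the commit): inspecting a checkpoint written by save_best_fn.
# The path below is illustrative; substitute your own log directory.
import torch

ckpt = torch.load("log/Ant-v4/ppo/42/230101-120000/policy.pth", map_location="cpu")
print(list(ckpt.keys()))          # expected: ['model', 'obs_rms']
policy_state_dict = ckpt["model"]  # CHECKPOINT_DICT_KEY_MODEL
obs_rms = ckpt["obs_rms"]          # CHECKPOINT_DICT_KEY_OBS_RMS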
tianshou/highlevel/env.py (new file, 71 lines)
@@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict, Any, Union, Sequence

import gymnasium as gym

from tianshou.env import BaseVectorEnv

TShape = Union[int, Sequence[int]]


class Environments(ABC):
    def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        self.env = env
        self.train_envs = train_envs
        self.test_envs = test_envs

    def info(self) -> Dict[str, Any]:
        return {
            "action_shape": self.get_action_shape(),
            "state_shape": self.get_state_shape()
        }

    @abstractmethod
    def get_action_shape(self) -> TShape:
        pass

    @abstractmethod
    def get_state_shape(self) -> TShape:
        pass

    def get_action_space(self) -> gym.Space:
        return self.env.action_space


class ContinuousEnvironments(Environments):
    def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        super().__init__(env, train_envs, test_envs)
        self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)

    def info(self):
        d = super().info()
        d["max_action"] = self.max_action
        return d

    @staticmethod
    def _get_continuous_env_info(
        env: gym.Env,
    ) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]:
        if not isinstance(env.action_space, gym.spaces.Box):
            raise ValueError(
                "Only environments with continuous action space are supported here. "
                f"But got env with action space: {env.action_space.__class__}."
            )
        state_shape = env.observation_space.shape or env.observation_space.n
        if not state_shape:
            raise ValueError("Observation space shape is not defined")
        action_shape = env.action_space.shape
        max_action = env.action_space.high[0]
        return state_shape, action_shape, max_action

    def get_action_shape(self) -> TShape:
        return self.action_shape

    def get_state_shape(self) -> TShape:
        return self.state_shape


class EnvFactory(ABC):
    @abstractmethod
    def create_envs(self) -> Environments:
        pass
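Not part of the diff: a sketch of what a custom EnvFactory for another continuous-control task could look like, mirroring the observation-normalization wiring from make_mujoco_env; DummyVectorEnv and Pendulum-v1 are just illustrative choices.

# Sketch only (not in the commit): a hypothetical EnvFactory for gymnasium's Pendulum-v1.
import gymnasium as gym

from tianshou.env import DummyVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory


class PendulumEnvFactory(EnvFactory):
    def __init__(self, num_train_envs: int = 4, num_test_envs: int = 2):
        self.num_train_envs = num_train_envs
        self.num_test_envs = num_test_envs

    def create_envs(self) -> ContinuousEnvironments:
        # Single env kept around for shape/space inspection.
        env = gym.make("Pendulum-v1")
        train_envs = VectorEnvNormObs(
            DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(self.num_train_envs)])
        )
        test_envs = VectorEnvNormObs(
            DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(self.num_test_envs)]),
            update_obs_rms=False,
        )
        # Share the training-time normalization statistics with the test envs.
        test_envs.set_obs_rms(train_envs.get_obs_rms())
        return ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)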
tianshou/highlevel/experiment.py (new file, 76 lines)
@@ -0,0 +1,76 @@
from pprint import pprint
from typing import Generic, TypeVar

import numpy as np
import torch

from tianshou.config import BasicExperimentConfig, LoggerConfig, RLAgentConfig, RLSamplingConfig
from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import LoggerFactory
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer

TPolicy = TypeVar("TPolicy", bound=BasePolicy)
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)


class RLExperiment(Generic[TPolicy, TTrainer]):
    def __init__(self,
                 config: BasicExperimentConfig,
                 logger_config: LoggerConfig,
                 general_config: RLAgentConfig,
                 sampling_config: RLSamplingConfig,
                 env_factory: EnvFactory,
                 logger_factory: LoggerFactory,
                 agent_factory: AgentFactory):
        self.config = config
        self.logger_config = logger_config
        self.general_config = general_config
        self.sampling_config = sampling_config
        self.env_factory = env_factory
        self.logger_factory = logger_factory
        self.agent_factory = agent_factory

    def _set_seed(self):
        seed = self.config.seed
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _build_config_dict(self) -> dict:
        return {
            # TODO
        }

    def run(self, log_name: str):
        self._set_seed()

        envs = self.env_factory.create_envs()

        full_config = self._build_config_dict()
        full_config.update(envs.info())

        run_id = self.config.resume_id
        logger = self.logger_factory.create_logger(log_name=log_name, run_id=run_id, config_dict=full_config)

        policy = self.agent_factory.create_policy(envs, self.config.device)
        if self.config.resume_path:
            self.agent_factory.load_checkpoint(policy, self.config.resume_path, envs, self.config.device)

        train_collector, test_collector = self.agent_factory.create_train_test_collector(policy, envs)

        if not self.config.watch:
            trainer = self.agent_factory.create_trainer(policy, train_collector, test_collector, envs, logger)
            result = trainer.run()
            pprint(result)  # TODO logging

        self._watch_agent(self.config.watch_num_episodes, policy, test_collector, self.config.render)

    @staticmethod
    def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render):
        policy.eval()
        test_collector.reset()
        result = test_collector.collect(n_episode=num_episodes, render=render)
        print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
tianshou/highlevel/logger.py (new file, 49 lines)
@@ -0,0 +1,49 @@
from abc import ABC, abstractmethod
import os
from dataclasses import dataclass
from typing import Union, Optional

from torch.utils.tensorboard import SummaryWriter

from tianshou.config import LoggerConfig
from tianshou.utils import TensorboardLogger, WandbLogger


TLogger = Union[TensorboardLogger, WandbLogger]


@dataclass
class Logger:
    logger: TLogger
    log_path: str


class LoggerFactory(ABC):
    @abstractmethod
    def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
        pass


class DefaultLoggerFactory(LoggerFactory):
    def __init__(self, config: LoggerConfig):
        self.config = config

    def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
        writer = SummaryWriter(self.config.logdir)
        writer.add_text("args", str(self.config))
        if self.config.logger == "wandb":
            logger = WandbLogger(
                save_interval=1,
                name=log_name.replace(os.path.sep, "__"),
                run_id=run_id,
                config=config_dict,
                project=self.config.wandb_project,
            )
            logger.load(writer)
        elif self.config.logger == "tensorboard":
            logger = TensorboardLogger(writer)
        else:
            raise ValueError(f"Unknown logger: {self.config.logger}")
        log_path = os.path.join(self.config.logdir, log_name)
        os.makedirs(log_path, exist_ok=True)
        return Logger(logger=logger, log_path=log_path)
tianshou/highlevel/module.py (new file, 90 lines)
@@ -0,0 +1,90 @@
from abc import abstractmethod, ABC
from typing import Sequence

import torch
from torch import nn
import numpy as np

from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic as ContinuousCritic

TDevice = str | int | torch.device


def init_linear_orthogonal(m: torch.nn.Module):
    """
    Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0

    :param m: the module whose submodules are to be processed
    """
    for m in m.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)


class ActorFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        pass

    @staticmethod
    def _init_linear(actor: torch.nn.Module):
        """
        Initializes linear layers of an actor module using default mechanisms
        :param module: the actor module
        """
        init_linear_orthogonal(actor)
        if hasattr(actor, "mu"):
            # For continuous action spaces with Gaussian policies
            # do last policy layer scaling, this will make initial actions have (close to)
            # 0 mean and std, and will help boost performances,
            # see https://arxiv.org/abs/2006.05990, Fig.24 for details
            for m in actor.mu.modules():
                if isinstance(m, torch.nn.Linear):
                    m.weight.data.copy_(0.01 * m.weight.data)


class ContinuousActorFactory(ActorFactory, ABC):
    pass


class ContinuousActorProbFactory(ContinuousActorFactory):
    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        net_a = Net(
            envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device
        )
        actor = ActorProb(net_a, envs.get_action_shape(), unbounded=True, device=device).to(device)

        # init params
        torch.nn.init.constant_(actor.sigma_param, -0.5)
        self._init_linear(actor)

        return actor


class CriticFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        pass


class ContinuousCriticFactory(CriticFactory, ABC):
    pass


class ContinuousNetCriticFactory(ContinuousCriticFactory):
    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        net_c = Net(
            envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device
        )
        critic = ContinuousCritic(net_c, device=device).to(device)
        init_linear_orthogonal(critic)
        return critic
tianshou/highlevel/optim.py (new file, 54 lines)
@@ -0,0 +1,54 @@
from abc import ABC, abstractmethod
from typing import Union, Iterable, Dict, Any, Optional

import numpy as np
import torch
from torch import Tensor
from torch.optim import Adam
from torch.optim.lr_scheduler import LRScheduler, LambdaLR

from tianshou.config import RLSamplingConfig, NNConfig

TParams = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]


class OptimizerFactory(ABC):
    @abstractmethod
    def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
        pass


class TorchOptimizerFactory(OptimizerFactory):
    def __init__(self, optim_class, **kwargs):
        self.optim_class = optim_class
        self.kwargs = kwargs

    def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
        return self.optim_class(module.parameters(), **self.kwargs)


class AdamOptimizerFactory(OptimizerFactory):
    def __init__(self, lr):
        self.lr = lr

    def create_optimizer(self, module: torch.nn.Module) -> Adam:
        return Adam(module.parameters(), lr=self.lr)


class LRSchedulerFactory(ABC):
    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
        pass


class LinearLRSchedulerFactory(LRSchedulerFactory):
    def __init__(self, nn_config: NNConfig, sampling_config: RLSamplingConfig):
        self.nn_config = nn_config
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
        lr_scheduler = None
        if self.nn_config.lr_decay:
            max_update_num = np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect) * self.sampling_config.num_epochs
            lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
        return lr_scheduler
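TorchOptimizerFactory is not used by the PPO example above, but it allows any torch optimizer to be swapped in without writing a dedicated factory class. A brief sketch, with purely illustrative hyperparameter values:

# Sketch only (not in the commit): swapping in RMSprop via TorchOptimizerFactory.
import torch

from tianshou.highlevel.optim import TorchOptimizerFactory

optim_factory = TorchOptimizerFactory(torch.optim.RMSprop, lr=3e-4, alpha=0.99)
# Later, e.g. inside an AgentFactory:
# optim = optim_factory.create_optimizer(actor_critic)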