Initial high-level interfaces, demonstrated in mujoco_ppo_hl
parent a54aade730
commit 16ed5fd2a5

examples/mujoco/mujoco_env.py

@@ -2,7 +2,9 @@ import warnings

 import gymnasium as gym

+from tianshou.config import RLSamplingConfig, BasicExperimentConfig
 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
+from tianshou.highlevel.env import EnvFactory, Environments, ContinuousEnvironments

 try:
     import envpool
@@ -38,3 +40,19 @@ def make_mujoco_env(
     test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
     test_envs.set_obs_rms(train_envs.get_obs_rms())
     return env, train_envs, test_envs
+
+
+class MujocoEnvFactory(EnvFactory):
+    def __init__(self, experiment_config: BasicExperimentConfig, sampling_config: RLSamplingConfig):
+        self.sampling_config = sampling_config
+        self.experiment_config = experiment_config
+
+    def create_envs(self) -> ContinuousEnvironments:
+        env, train_envs, test_envs = make_mujoco_env(
+            task=self.experiment_config.task,
+            seed=self.experiment_config.seed,
+            num_train_envs=self.sampling_config.num_train_envs,
+            num_test_envs=self.sampling_config.num_test_envs,
+            obs_norm=True,
+        )
+        return ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)

examples/mujoco/mujoco_ppo_hl.py (new file, 60 lines)

#!/usr/bin/env python3

import datetime
import os

from jsonargparse import CLI
from torch.distributions import Independent, Normal

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.config import (
    BasicExperimentConfig,
    LoggerConfig,
    NNConfig,
    PGConfig,
    PPOConfig,
    RLAgentConfig,
    RLSamplingConfig,
)
from tianshou.highlevel.agent import PPOAgentFactory
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import ContinuousActorProbFactory, ContinuousNetCriticFactory
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
from tianshou.highlevel.experiment import RLExperiment


def main(
    experiment_config: BasicExperimentConfig,
    logger_config: LoggerConfig,
    sampling_config: RLSamplingConfig,
    general_config: RLAgentConfig,
    pg_config: PGConfig,
    ppo_config: PPOConfig,
    nn_config: NNConfig,
):
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(experiment_config.task, "ppo", str(experiment_config.seed), now)
    logger_factory = DefaultLoggerFactory(logger_config)

    env_factory = MujocoEnvFactory(experiment_config, sampling_config)

    def dist_fn(*logits):
        return Independent(Normal(*logits), 1)

    actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes)
    critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes)
    optim_factory = AdamOptimizerFactory(lr=nn_config.lr)
    lr_scheduler_factory = LinearLRSchedulerFactory(nn_config, sampling_config)
    agent_factory = PPOAgentFactory(general_config, pg_config, ppo_config, sampling_config, nn_config,
                                    actor_factory, critic_factory, optim_factory, dist_fn, lr_scheduler_factory)

    experiment = RLExperiment(experiment_config, logger_config, general_config, sampling_config,
                              env_factory,
                              logger_factory,
                              agent_factory)

    experiment.run(log_name)


if __name__ == "__main__":
    CLI(main)

@@ -14,8 +14,8 @@ class BasicExperimentConfig:
     seed: int = 42
     task: str = "Ant-v4"
     """Mujoco specific"""
-    render: float = 0.0
-    """Milliseconds between rendered frames"""
+    render: Optional[float] = 0.0
+    """Milliseconds between rendered frames; if None, no rendering"""
     device: str = "cuda" if torch.cuda.is_available() else "cpu"
     resume_id: Optional[int] = None
     """For restoring a model and running means of env-specifics from a checkpoint"""

@@ -23,6 +23,7 @@ class BasicExperimentConfig:
     """For restoring a model and running means of env-specifics from a checkpoint"""
     watch: bool = False
     """If True, will not perform training and only watch the restored policy"""
+    watch_num_episodes = 10


 @dataclass

tianshou/highlevel/__init__.py (new, empty file)

tianshou/highlevel/agent.py (new file, 146 lines)

import os
from abc import abstractmethod, ABC
from typing import Callable

import torch

from tianshou.config import RLSamplingConfig, PGConfig, PPOConfig, RLAgentConfig, NNConfig
from tianshou.data import VectorReplayBuffer, ReplayBuffer, Collector
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.optim import OptimizerFactory, LRSchedulerFactory
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.trainer import BaseTrainer, OnpolicyTrainer
from tianshou.utils.net.common import ActorCritic


CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"


class AgentFactory(ABC):
    @abstractmethod
    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        pass

    @staticmethod
    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
        def save_best_fn(pol: torch.nn.Module):
            state = {"model": pol.state_dict(), "obs_rms": envs.train_envs.get_obs_rms()}
            torch.save(state, os.path.join(log_path, "policy.pth"))

        return save_best_fn

    @staticmethod
    def load_checkpoint(policy: torch.nn.Module, path, envs: Environments, device: TDevice):
        ckpt = torch.load(path, map_location=device)
        policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
        if envs.train_envs:
            envs.train_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        if envs.test_envs:
            envs.test_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        print("Loaded agent and obs. running means from: ", path)  # TODO logging

    @abstractmethod
    def create_train_test_collector(self,
                                    policy: BasePolicy,
                                    envs: Environments):
        pass

    @abstractmethod
    def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
                       envs: Environments, logger: Logger) -> BaseTrainer:
        pass


class OnpolicyAgentFactory(AgentFactory, ABC):
    def __init__(self, sampling_config: RLSamplingConfig):
        self.sampling_config = sampling_config

    def create_train_test_collector(self,
                                    policy: BasePolicy,
                                    envs: Environments):
        buffer_size = self.sampling_config.buffer_size
        train_envs = envs.train_envs
        if len(train_envs) > 1:
            buffer = VectorReplayBuffer(buffer_size, len(train_envs))
        else:
            buffer = ReplayBuffer(buffer_size)
        train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
        test_collector = Collector(policy, envs.test_envs)
        return train_collector, test_collector

    def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
                       envs: Environments, logger: Logger) -> OnpolicyTrainer:
        sampling_config = self.sampling_config
        return OnpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=sampling_config.num_epochs,
            step_per_epoch=sampling_config.step_per_epoch,
            repeat_per_collect=sampling_config.repeat_per_collect,
            episode_per_test=sampling_config.num_test_envs,
            batch_size=sampling_config.batch_size,
            step_per_collect=sampling_config.step_per_collect,
            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
            logger=logger.logger,
            test_in_train=False,
        )


class PPOAgentFactory(OnpolicyAgentFactory):
    def __init__(self, general_config: RLAgentConfig,
                 pg_config: PGConfig,
                 ppo_config: PPOConfig,
                 sampling_config: RLSamplingConfig,
                 nn_config: NNConfig,
                 actor_factory: ActorFactory,
                 critic_factory: CriticFactory,
                 optimizer_factory: OptimizerFactory,
                 dist_fn,
                 lr_scheduler_factory: LRSchedulerFactory):
        super().__init__(sampling_config)
        self.optimizer_factory = optimizer_factory
        self.critic_factory = critic_factory
        self.actor_factory = actor_factory
        self.ppo_config = ppo_config
        self.pg_config = pg_config
        self.general_config = general_config
        self.lr_scheduler_factory = lr_scheduler_factory
        self.dist_fn = dist_fn
        self.nn_config = nn_config

    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
        actor = self.actor_factory.create_module(envs, device)
        critic = self.critic_factory.create_module(envs, device)
        actor_critic = ActorCritic(actor, critic)
        optim = self.optimizer_factory.create_optimizer(actor_critic)
        lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
        return PPOPolicy(
            # nn-stuff
            actor,
            critic,
            optim,
            dist_fn=self.dist_fn,
            lr_scheduler=lr_scheduler,
            # env-stuff
            action_space=envs.get_action_space(),
            action_scaling=True,
            # general_config
            discount_factor=self.general_config.gamma,
            gae_lambda=self.general_config.gae_lambda,
            reward_normalization=self.general_config.rew_norm,
            action_bound_method=self.general_config.action_bound_method,
            # pg_config
            max_grad_norm=self.pg_config.max_grad_norm,
            vf_coef=self.pg_config.vf_coef,
            ent_coef=self.pg_config.ent_coef,
            # ppo_config
            eps_clip=self.ppo_config.eps_clip,
            value_clip=self.ppo_config.value_clip,
            dual_clip=self.ppo_config.dual_clip,
            advantage_normalization=self.ppo_config.norm_adv,
            recompute_advantage=self.ppo_config.recompute_adv,
        )

tianshou/highlevel/env.py (new file, 71 lines)

from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict, Any, Union, Sequence

import gymnasium as gym

from tianshou.env import BaseVectorEnv

TShape = Union[int, Sequence[int]]


class Environments(ABC):
    def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        self.env = env
        self.train_envs = train_envs
        self.test_envs = test_envs

    def info(self) -> Dict[str, Any]:
        return {
            "action_shape": self.get_action_shape(),
            "state_shape": self.get_state_shape()
        }

    @abstractmethod
    def get_action_shape(self) -> TShape:
        pass

    @abstractmethod
    def get_state_shape(self) -> TShape:
        pass

    def get_action_space(self) -> gym.Space:
        return self.env.action_space


class ContinuousEnvironments(Environments):
    def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        super().__init__(env, train_envs, test_envs)
        self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)

    def info(self):
        d = super().info()
        d["max_action"] = self.max_action
        return d

    @staticmethod
    def _get_continuous_env_info(
        env: gym.Env,
    ) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]:
        if not isinstance(env.action_space, gym.spaces.Box):
            raise ValueError(
                "Only environments with continuous action space are supported here. "
                f"But got env with action space: {env.action_space.__class__}."
            )
        state_shape = env.observation_space.shape or env.observation_space.n
        if not state_shape:
            raise ValueError("Observation space shape is not defined")
        action_shape = env.action_space.shape
        max_action = env.action_space.high[0]
        return state_shape, action_shape, max_action

    def get_action_shape(self) -> TShape:
        return self.action_shape

    def get_state_shape(self) -> TShape:
        return self.state_shape


class EnvFactory(ABC):
    @abstractmethod
    def create_envs(self) -> Environments:
        pass

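To illustrate the EnvFactory/Environments abstraction above, here is a minimal sketch of a factory for a different continuous-control task. The PendulumEnvFactory name is hypothetical (not part of this commit); DummyVectorEnv is assumed to be available from tianshou.env.

import gymnasium as gym

from tianshou.env import DummyVectorEnv
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory


class PendulumEnvFactory(EnvFactory):
    """Hypothetical factory: builds Pendulum-v1 train/test environments."""

    def __init__(self, num_train_envs: int = 4, num_test_envs: int = 2):
        self.num_train_envs = num_train_envs
        self.num_test_envs = num_test_envs

    def create_envs(self) -> ContinuousEnvironments:
        # A single env instance is kept for inspecting the spaces, as MujocoEnvFactory does above
        env = gym.make("Pendulum-v1")
        train_envs = DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(self.num_train_envs)])
        test_envs = DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(self.num_test_envs)])
        return ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
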
tianshou/highlevel/experiment.py (new file, 76 lines)

from pprint import pprint
from typing import Generic, TypeVar

import numpy as np
import torch

from tianshou.config import BasicExperimentConfig, LoggerConfig, RLAgentConfig, RLSamplingConfig
from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import LoggerFactory
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer

TPolicy = TypeVar("TPolicy", bound=BasePolicy)
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)


class RLExperiment(Generic[TPolicy, TTrainer]):
    def __init__(self,
                 config: BasicExperimentConfig,
                 logger_config: LoggerConfig,
                 general_config: RLAgentConfig,
                 sampling_config: RLSamplingConfig,
                 env_factory: EnvFactory,
                 logger_factory: LoggerFactory,
                 agent_factory: AgentFactory):
        self.config = config
        self.logger_config = logger_config
        self.general_config = general_config
        self.sampling_config = sampling_config
        self.env_factory = env_factory
        self.logger_factory = logger_factory
        self.agent_factory = agent_factory

    def _set_seed(self):
        seed = self.config.seed
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _build_config_dict(self) -> dict:
        return {
            # TODO
        }

    def run(self, log_name: str):
        self._set_seed()

        envs = self.env_factory.create_envs()

        full_config = self._build_config_dict()
        full_config.update(envs.info())

        run_id = self.config.resume_id
        logger = self.logger_factory.create_logger(log_name=log_name, run_id=run_id, config_dict=full_config)

        policy = self.agent_factory.create_policy(envs, self.config.device)
        if self.config.resume_path:
            self.agent_factory.load_checkpoint(policy, self.config.resume_path, envs, self.config.device)

        train_collector, test_collector = self.agent_factory.create_train_test_collector(policy, envs)

        if not self.config.watch:
            trainer = self.agent_factory.create_trainer(policy, train_collector, test_collector, envs, logger)
            result = trainer.run()
            pprint(result)  # TODO logging

        self._watch_agent(self.config.watch_num_episodes, policy, test_collector, self.config.render)

    @staticmethod
    def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render):
        policy.eval()
        test_collector.reset()
        result = test_collector.collect(n_episode=num_episodes, render=render)
        print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')

tianshou/highlevel/logger.py (new file, 49 lines)

from abc import ABC, abstractmethod
import os
from dataclasses import dataclass
from typing import Union, Optional

from torch.utils.tensorboard import SummaryWriter

from tianshou.config import LoggerConfig
from tianshou.utils import TensorboardLogger, WandbLogger


TLogger = Union[TensorboardLogger, WandbLogger]


@dataclass
class Logger:
    logger: TLogger
    log_path: str


class LoggerFactory(ABC):
    @abstractmethod
    def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
        pass


class DefaultLoggerFactory(LoggerFactory):
    def __init__(self, config: LoggerConfig):
        self.config = config

    def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
        writer = SummaryWriter(self.config.logdir)
        writer.add_text("args", str(self.config))
        if self.config.logger == "wandb":
            logger = WandbLogger(
                save_interval=1,
                name=log_name.replace(os.path.sep, "__"),
                run_id=run_id,
                config=config_dict,
                project=self.config.wandb_project,
            )
            logger.load(writer)
        elif self.config.logger == "tensorboard":
            logger = TensorboardLogger(writer)
        else:
            raise ValueError(f"Unknown logger: {self.config.logger}")
        log_path = os.path.join(self.config.logdir, log_name)
        os.makedirs(log_path, exist_ok=True)
        return Logger(logger=logger, log_path=log_path)

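For reference, a minimal sketch that builds the same Logger pair by hand for the "tensorboard" case, mirroring what DefaultLoggerFactory does above; the paths used here are arbitrary examples, not values from this commit.

import os

from torch.utils.tensorboard import SummaryWriter

from tianshou.highlevel.logger import Logger
from tianshou.utils import TensorboardLogger

# Directory layout mirrors DefaultLoggerFactory: a base logdir plus a per-run log_name
log_path = os.path.join("log", "Ant-v4/ppo/42/230101-000000")
os.makedirs(log_path, exist_ok=True)

writer = SummaryWriter("log")
logger = Logger(logger=TensorboardLogger(writer), log_path=log_path)
print(logger.log_path)  # where save_best_fn will write policy.pth
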
tianshou/highlevel/module.py (new file, 90 lines)

from abc import abstractmethod, ABC
from typing import Sequence

import torch
from torch import nn
import numpy as np

from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic as ContinuousCritic

TDevice = str | int | torch.device


def init_linear_orthogonal(m: torch.nn.Module):
    """
    Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0

    :param m: the module whose submodules are to be processed
    """
    for m in m.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)


class ActorFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        pass

    @staticmethod
    def _init_linear(actor: torch.nn.Module):
        """
        Initializes linear layers of an actor module using default mechanisms

        :param actor: the actor module
        """
        init_linear_orthogonal(actor)
        if hasattr(actor, "mu"):
            # For continuous action spaces with Gaussian policies
            # do last policy layer scaling, this will make initial actions have (close to)
            # 0 mean and std, and will help boost performances,
            # see https://arxiv.org/abs/2006.05990, Fig.24 for details
            for m in actor.mu.modules():
                if isinstance(m, torch.nn.Linear):
                    m.weight.data.copy_(0.01 * m.weight.data)


class ContinuousActorFactory(ActorFactory, ABC):
    pass


class ContinuousActorProbFactory(ContinuousActorFactory):
    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        net_a = Net(
            envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device
        )
        actor = ActorProb(net_a, envs.get_action_shape(), unbounded=True, device=device).to(device)

        # init params
        torch.nn.init.constant_(actor.sigma_param, -0.5)
        self._init_linear(actor)

        return actor


class CriticFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        pass


class ContinuousCriticFactory(CriticFactory, ABC):
    pass


class ContinuousNetCriticFactory(ContinuousCriticFactory):
    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        net_c = Net(
            envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device
        )
        critic = ContinuousCritic(net_c, device=device).to(device)
        init_linear_orthogonal(critic)
        return critic

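To show what init_linear_orthogonal does in isolation, a self-contained snippet applying it to a small MLP (plain PyTorch, independent of the factories above):

import torch
from torch import nn

from tianshou.highlevel.module import init_linear_orthogonal

# A small MLP; only its nn.Linear submodules are touched by the initializer
mlp = nn.Sequential(nn.Linear(8, 64), nn.Tanh(), nn.Linear(64, 2))
init_linear_orthogonal(mlp)

# Weights get orthogonal columns scaled by gain sqrt(2), biases are zeroed
w = mlp[0].weight
print(torch.allclose(w.T @ w, 2 * torch.eye(8), atol=1e-5))  # True
print(mlp[0].bias.abs().sum().item())  # 0.0
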
tianshou/highlevel/optim.py (new file, 54 lines)

from abc import ABC, abstractmethod
from typing import Union, Iterable, Dict, Any, Optional

import numpy as np
import torch
from torch import Tensor
from torch.optim import Adam
from torch.optim.lr_scheduler import LRScheduler, LambdaLR

from tianshou.config import RLSamplingConfig, NNConfig

TParams = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]


class OptimizerFactory(ABC):
    @abstractmethod
    def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
        pass


class TorchOptimizerFactory(OptimizerFactory):
    def __init__(self, optim_class, **kwargs):
        self.optim_class = optim_class
        self.kwargs = kwargs

    def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
        return self.optim_class(module.parameters(), **self.kwargs)


class AdamOptimizerFactory(OptimizerFactory):
    def __init__(self, lr):
        self.lr = lr

    def create_optimizer(self, module: torch.nn.Module) -> Adam:
        return Adam(module.parameters(), lr=self.lr)


class LRSchedulerFactory(ABC):
    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
        pass


class LinearLRSchedulerFactory(LRSchedulerFactory):
    def __init__(self, nn_config: NNConfig, sampling_config: RLSamplingConfig):
        self.nn_config = nn_config
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
        lr_scheduler = None
        if self.nn_config.lr_decay:
            max_update_num = np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect) * self.sampling_config.num_epochs
            lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
        return lr_scheduler

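A short sketch of the generic TorchOptimizerFactory above, shown with a different torch optimizer (RMSprop is used here purely as an example, it is not part of this commit):

import torch
from torch import nn

from tianshou.highlevel.optim import TorchOptimizerFactory

# Any torch optimizer class plus its keyword arguments can be wrapped
optim_factory = TorchOptimizerFactory(torch.optim.RMSprop, lr=3e-4, alpha=0.95)

model = nn.Linear(4, 2)
optim = optim_factory.create_optimizer(model)  # RMSprop over model.parameters()
print(type(optim).__name__)  # "RMSprop"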