Initial high-level interfaces, demonstrated in mujoco_ppo_hl

Dominik Jain 2023-09-19 18:53:11 +02:00
parent a54aade730
commit 16ed5fd2a5
10 changed files with 567 additions and 2 deletions

examples/mujoco/mujoco_env.py
View File

@@ -2,7 +2,9 @@ import warnings

import gymnasium as gym

from tianshou.config import RLSamplingConfig, BasicExperimentConfig
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import EnvFactory, Environments, ContinuousEnvironments

try:
    import envpool

@@ -38,3 +40,19 @@ def make_mujoco_env(
    test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
    test_envs.set_obs_rms(train_envs.get_obs_rms())
    return env, train_envs, test_envs


class MujocoEnvFactory(EnvFactory):
    def __init__(self, experiment_config: BasicExperimentConfig, sampling_config: RLSamplingConfig):
        self.sampling_config = sampling_config
        self.experiment_config = experiment_config

    def create_envs(self) -> ContinuousEnvironments:
        env, train_envs, test_envs = make_mujoco_env(
            task=self.experiment_config.task,
            seed=self.experiment_config.seed,
            num_train_envs=self.sampling_config.num_train_envs,
            num_test_envs=self.sampling_config.num_test_envs,
            obs_norm=True,
        )
        return ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)

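For orientation (not part of the commit): a minimal sketch of how MujocoEnvFactory might be used on its own. It assumes the two config dataclasses can be constructed with only the fields referenced above, with all other fields left at their defaults.

# Hypothetical standalone use of MujocoEnvFactory; field names are taken from the
# usages above, and the remaining config fields are assumed to have defaults.
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.config import BasicExperimentConfig, RLSamplingConfig

experiment_config = BasicExperimentConfig(task="Ant-v4", seed=42)
sampling_config = RLSamplingConfig(num_train_envs=4, num_test_envs=2)

envs = MujocoEnvFactory(experiment_config, sampling_config).create_envs()
print(envs.info())  # {"action_shape": ..., "state_shape": ..., "max_action": ...}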
examples/mujoco/mujoco_ppo_hl.py (new file, 60 lines)
View File

@@ -0,0 +1,60 @@
#!/usr/bin/env python3

import datetime
import os

from jsonargparse import CLI
from torch.distributions import Independent, Normal

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.config import (
    BasicExperimentConfig,
    LoggerConfig,
    NNConfig,
    PGConfig,
    PPOConfig,
    RLAgentConfig,
    RLSamplingConfig,
)
from tianshou.highlevel.agent import PPOAgentFactory
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import ContinuousActorProbFactory, ContinuousNetCriticFactory
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
from tianshou.highlevel.experiment import RLExperiment


def main(
    experiment_config: BasicExperimentConfig,
    logger_config: LoggerConfig,
    sampling_config: RLSamplingConfig,
    general_config: RLAgentConfig,
    pg_config: PGConfig,
    ppo_config: PPOConfig,
    nn_config: NNConfig,
):
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    log_name = os.path.join(experiment_config.task, "ppo", str(experiment_config.seed), now)

    logger_factory = DefaultLoggerFactory(logger_config)

    env_factory = MujocoEnvFactory(experiment_config, sampling_config)

    def dist_fn(*logits):
        return Independent(Normal(*logits), 1)

    actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes)
    critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes)
    optim_factory = AdamOptimizerFactory(lr=nn_config.lr)
    lr_scheduler_factory = LinearLRSchedulerFactory(nn_config, sampling_config)
    agent_factory = PPOAgentFactory(general_config, pg_config, ppo_config, sampling_config, nn_config,
                                    actor_factory, critic_factory, optim_factory, dist_fn, lr_scheduler_factory)

    experiment = RLExperiment(experiment_config, logger_config, general_config, sampling_config,
                              env_factory,
                              logger_factory,
                              agent_factory)

    experiment.run(log_name)


if __name__ == "__main__":
    CLI(main)

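As a quick sanity check on the dist_fn defined above (not part of the commit): ActorProb emits (mu, sigma) pairs, and wrapping Normal in Independent(..., 1) folds the action dimensions into a single event, so each batch element gets one joint log-probability. The shapes below are made up.

import torch
from torch.distributions import Independent, Normal

mu, sigma = torch.zeros(4, 6), torch.ones(4, 6)  # batch of 4, 6-dim actions (hypothetical shapes)
dist = Independent(Normal(mu, sigma), 1)
actions = dist.sample()
assert actions.shape == (4, 6)
assert dist.log_prob(actions).shape == (4,)  # one joint log-prob per batch element, not per action dim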
View File

@@ -14,8 +14,8 @@ class BasicExperimentConfig:
    seed: int = 42
    task: str = "Ant-v4"
    """Mujoco specific"""
-   render: float = 0.0
-   """Milliseconds between rendered frames"""
+   render: Optional[float] = 0.0
+   """Milliseconds between rendered frames; if None, no rendering"""
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    resume_id: Optional[int] = None
    """For restoring a model and running means of env-specifics from a checkpoint"""

@@ -23,6 +23,7 @@ class BasicExperimentConfig:
    """For restoring a model and running means of env-specifics from a checkpoint"""
    watch: bool = False
    """If True, will not perform training and only watch the restored policy"""
+   watch_num_episodes: int = 10


@dataclass

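A small illustration of the fields touched by this hunk (not part of the commit; it assumes the remaining BasicExperimentConfig fields keep their defaults):

from tianshou.config import BasicExperimentConfig

# render=None is now permitted (no rendering); watch=True skips training and only
# evaluates a restored policy for watch_num_episodes episodes.
config = BasicExperimentConfig(task="Ant-v4", seed=42, render=None, watch=True)
assert config.watch_num_episodes == 10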

tianshou/highlevel/agent.py (new file, 146 lines)
View File

@@ -0,0 +1,146 @@
import os
from abc import abstractmethod, ABC
from typing import Callable

import torch

from tianshou.config import RLSamplingConfig, PGConfig, PPOConfig, RLAgentConfig, NNConfig
from tianshou.data import VectorReplayBuffer, ReplayBuffer, Collector
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.optim import OptimizerFactory, LRSchedulerFactory
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.trainer import BaseTrainer, OnpolicyTrainer
from tianshou.utils.net.common import ActorCritic

CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"


class AgentFactory(ABC):
    @abstractmethod
    def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
        pass

    @staticmethod
    def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
        def save_best_fn(pol: torch.nn.Module):
            state = {"model": pol.state_dict(), "obs_rms": envs.train_envs.get_obs_rms()}
            torch.save(state, os.path.join(log_path, "policy.pth"))

        return save_best_fn

    @staticmethod
    def load_checkpoint(policy: torch.nn.Module, path, envs: Environments, device: TDevice):
        ckpt = torch.load(path, map_location=device)
        policy.load_state_dict(ckpt[CHECKPOINT_DICT_KEY_MODEL])
        if envs.train_envs:
            envs.train_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        if envs.test_envs:
            envs.test_envs.set_obs_rms(ckpt[CHECKPOINT_DICT_KEY_OBS_RMS])
        print("Loaded agent and obs. running means from: ", path)  # TODO logging

    @abstractmethod
    def create_train_test_collector(self,
                                    policy: BasePolicy,
                                    envs: Environments):
        pass

    @abstractmethod
    def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
                       envs: Environments, logger: Logger) -> BaseTrainer:
        pass


class OnpolicyAgentFactory(AgentFactory, ABC):
    def __init__(self, sampling_config: RLSamplingConfig):
        self.sampling_config = sampling_config

    def create_train_test_collector(self,
                                    policy: BasePolicy,
                                    envs: Environments):
        buffer_size = self.sampling_config.buffer_size
        train_envs = envs.train_envs
        if len(train_envs) > 1:
            buffer = VectorReplayBuffer(buffer_size, len(train_envs))
        else:
            buffer = ReplayBuffer(buffer_size)
        train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
        test_collector = Collector(policy, envs.test_envs)
        return train_collector, test_collector

    def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
                       envs: Environments, logger: Logger) -> OnpolicyTrainer:
        sampling_config = self.sampling_config
        return OnpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=sampling_config.num_epochs,
            step_per_epoch=sampling_config.step_per_epoch,
            repeat_per_collect=sampling_config.repeat_per_collect,
            episode_per_test=sampling_config.num_test_envs,
            batch_size=sampling_config.batch_size,
            step_per_collect=sampling_config.step_per_collect,
            save_best_fn=self._create_save_best_fn(envs, logger.log_path),
            logger=logger.logger,
            test_in_train=False,
        )


class PPOAgentFactory(OnpolicyAgentFactory):
    def __init__(self, general_config: RLAgentConfig,
                 pg_config: PGConfig,
                 ppo_config: PPOConfig,
                 sampling_config: RLSamplingConfig,
                 nn_config: NNConfig,
                 actor_factory: ActorFactory,
                 critic_factory: CriticFactory,
                 optimizer_factory: OptimizerFactory,
                 dist_fn,
                 lr_scheduler_factory: LRSchedulerFactory):
        super().__init__(sampling_config)
        self.optimizer_factory = optimizer_factory
        self.critic_factory = critic_factory
        self.actor_factory = actor_factory
        self.ppo_config = ppo_config
        self.pg_config = pg_config
        self.general_config = general_config
        self.lr_scheduler_factory = lr_scheduler_factory
        self.dist_fn = dist_fn
        self.nn_config = nn_config

    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
        actor = self.actor_factory.create_module(envs, device)
        critic = self.critic_factory.create_module(envs, device)
        actor_critic = ActorCritic(actor, critic)
        optim = self.optimizer_factory.create_optimizer(actor_critic)
        lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
        return PPOPolicy(
            # nn-stuff
            actor,
            critic,
            optim,
            dist_fn=self.dist_fn,
            lr_scheduler=lr_scheduler,
            # env-stuff
            action_space=envs.get_action_space(),
            action_scaling=True,
            # general_config
            discount_factor=self.general_config.gamma,
            gae_lambda=self.general_config.gae_lambda,
            reward_normalization=self.general_config.rew_norm,
            action_bound_method=self.general_config.action_bound_method,
            # pg_config
            max_grad_norm=self.pg_config.max_grad_norm,
            vf_coef=self.pg_config.vf_coef,
            ent_coef=self.pg_config.ent_coef,
            # ppo_config
            eps_clip=self.ppo_config.eps_clip,
            value_clip=self.ppo_config.value_clip,
            dual_clip=self.ppo_config.dual_clip,
            advantage_normalization=self.ppo_config.norm_adv,
            recompute_advantage=self.ppo_config.recompute_adv,
        )

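To make the AgentFactory contract concrete, here is a hypothetical manual wiring of the pieces (not part of the commit). It mirrors what RLExperiment.run, further below, does; envs, logger, and agent_factory are assumed to come from the factories defined elsewhere in this commit.

# envs: Environments, logger: Logger, agent_factory: PPOAgentFactory -- assumed to be
# built via MujocoEnvFactory, DefaultLoggerFactory, and PPOAgentFactory as shown elsewhere.
policy = agent_factory.create_policy(envs, device="cpu")
train_collector, test_collector = agent_factory.create_train_test_collector(policy, envs)
trainer = agent_factory.create_trainer(policy, train_collector, test_collector, envs, logger)
result = trainer.run()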
tianshou/highlevel/env.py (new file, 71 lines)
View File

@@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict, Any, Union, Sequence

import gymnasium as gym

from tianshou.env import BaseVectorEnv

TShape = Union[int, Sequence[int]]


class Environments(ABC):
    def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        self.env = env
        self.train_envs = train_envs
        self.test_envs = test_envs

    def info(self) -> Dict[str, Any]:
        return {
            "action_shape": self.get_action_shape(),
            "state_shape": self.get_state_shape()
        }

    @abstractmethod
    def get_action_shape(self) -> TShape:
        pass

    @abstractmethod
    def get_state_shape(self) -> TShape:
        pass

    def get_action_space(self) -> gym.Space:
        return self.env.action_space


class ContinuousEnvironments(Environments):
    def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        super().__init__(env, train_envs, test_envs)
        self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)

    def info(self):
        d = super().info()
        d["max_action"] = self.max_action
        return d

    @staticmethod
    def _get_continuous_env_info(
        env: gym.Env,
    ) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]:
        if not isinstance(env.action_space, gym.spaces.Box):
            raise ValueError(
                "Only environments with continuous action space are supported here. "
                f"But got env with action space: {env.action_space.__class__}."
            )
        state_shape = env.observation_space.shape or env.observation_space.n
        if not state_shape:
            raise ValueError("Observation space shape is not defined")
        action_shape = env.action_space.shape
        max_action = env.action_space.high[0]
        return state_shape, action_shape, max_action

    def get_action_shape(self) -> TShape:
        return self.action_shape

    def get_state_shape(self) -> TShape:
        return self.state_shape


class EnvFactory(ABC):
    @abstractmethod
    def create_envs(self) -> Environments:
        pass

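As a usage note (not part of the commit): EnvFactory only requires create_envs, so other environment suites can be plugged in by following the MujocoEnvFactory pattern. A sketch under assumed names, without observation normalization:

import gymnasium as gym

from tianshou.env import ShmemVectorEnv
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory


class GymTaskEnvFactory(EnvFactory):  # hypothetical class, not part of the commit
    def __init__(self, task: str, num_train_envs: int, num_test_envs: int):
        self.task = task
        self.num_train_envs = num_train_envs
        self.num_test_envs = num_test_envs

    def create_envs(self) -> ContinuousEnvironments:
        # single env for shape inference, vectorized envs for collection
        env = gym.make(self.task)
        train_envs = ShmemVectorEnv([lambda: gym.make(self.task) for _ in range(self.num_train_envs)])
        test_envs = ShmemVectorEnv([lambda: gym.make(self.task) for _ in range(self.num_test_envs)])
        return ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)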
tianshou/highlevel/experiment.py (new file, 76 lines)
View File

@@ -0,0 +1,76 @@
from pprint import pprint
from typing import Generic, TypeVar

import numpy as np
import torch

from tianshou.config import BasicExperimentConfig, LoggerConfig, RLAgentConfig, RLSamplingConfig
from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.logger import LoggerFactory
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer

TPolicy = TypeVar("TPolicy", bound=BasePolicy)
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)


class RLExperiment(Generic[TPolicy, TTrainer]):
    def __init__(self,
                 config: BasicExperimentConfig,
                 logger_config: LoggerConfig,
                 general_config: RLAgentConfig,
                 sampling_config: RLSamplingConfig,
                 env_factory: EnvFactory,
                 logger_factory: LoggerFactory,
                 agent_factory: AgentFactory):
        self.config = config
        self.logger_config = logger_config
        self.general_config = general_config
        self.sampling_config = sampling_config
        self.env_factory = env_factory
        self.logger_factory = logger_factory
        self.agent_factory = agent_factory

    def _set_seed(self):
        seed = self.config.seed
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _build_config_dict(self) -> dict:
        return {
            # TODO
        }

    def run(self, log_name: str):
        self._set_seed()

        envs = self.env_factory.create_envs()

        full_config = self._build_config_dict()
        full_config.update(envs.info())

        run_id = self.config.resume_id
        logger = self.logger_factory.create_logger(log_name=log_name, run_id=run_id, config_dict=full_config)

        policy = self.agent_factory.create_policy(envs, self.config.device)
        if self.config.resume_path:
            self.agent_factory.load_checkpoint(policy, self.config.resume_path, envs, self.config.device)

        train_collector, test_collector = self.agent_factory.create_train_test_collector(policy, envs)

        if not self.config.watch:
            trainer = self.agent_factory.create_trainer(policy, train_collector, test_collector, envs, logger)
            result = trainer.run()
            pprint(result)  # TODO logging

        self._watch_agent(self.config.watch_num_episodes, policy, test_collector, self.config.render)

    @staticmethod
    def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render):
        policy.eval()
        test_collector.reset()
        result = test_collector.collect(n_episode=num_episodes, render=render)
        print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')

tianshou/highlevel/logger.py (new file, 49 lines)
View File

@@ -0,0 +1,49 @@
from abc import ABC, abstractmethod
import os
from dataclasses import dataclass
from typing import Union, Optional

from torch.utils.tensorboard import SummaryWriter

from tianshou.config import LoggerConfig
from tianshou.utils import TensorboardLogger, WandbLogger

TLogger = Union[TensorboardLogger, WandbLogger]


@dataclass
class Logger:
    logger: TLogger
    log_path: str


class LoggerFactory(ABC):
    @abstractmethod
    def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
        pass


class DefaultLoggerFactory(LoggerFactory):
    def __init__(self, config: LoggerConfig):
        self.config = config

    def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger:
        writer = SummaryWriter(self.config.logdir)
        writer.add_text("args", str(self.config))
        if self.config.logger == "wandb":
            logger = WandbLogger(
                save_interval=1,
                name=log_name.replace(os.path.sep, "__"),
                run_id=run_id,
                config=config_dict,
                project=self.config.wandb_project,
            )
            logger.load(writer)
        elif self.config.logger == "tensorboard":
            logger = TensorboardLogger(writer)
        else:
            raise ValueError(f"Unknown logger: {self.config.logger}")
        log_path = os.path.join(self.config.logdir, log_name)
        os.makedirs(log_path, exist_ok=True)
        return Logger(logger=logger, log_path=log_path)

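A usage sketch for DefaultLoggerFactory (not part of the commit). The LoggerConfig attribute names logdir, logger, and wandb_project are the ones read by the code above; how LoggerConfig is constructed is an assumption here.

from tianshou.config import LoggerConfig
from tianshou.highlevel.logger import DefaultLoggerFactory

# assumed constructor arguments; only logdir/logger/wandb_project are referenced above
logger_config = LoggerConfig(logdir="log", logger="tensorboard", wandb_project=None)
logger = DefaultLoggerFactory(logger_config).create_logger(
    log_name="Ant-v4/ppo/42/231001-120000", run_id=None, config_dict={}
)
print(logger.log_path)  # -> log/Ant-v4/ppo/42/231001-120000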
tianshou/highlevel/module.py (new file, 90 lines)
View File

@@ -0,0 +1,90 @@
from abc import abstractmethod, ABC
from typing import Sequence

import torch
from torch import nn
import numpy as np

from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic as ContinuousCritic

TDevice = str | int | torch.device


def init_linear_orthogonal(m: torch.nn.Module):
    """
    Applies orthogonal initialization to the linear layers of the given module and sets the bias weights to 0.

    :param m: the module whose submodules are to be processed
    """
    for m in m.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)


class ActorFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        pass

    @staticmethod
    def _init_linear(actor: torch.nn.Module):
        """
        Initializes the linear layers of an actor module using the default mechanism.

        :param actor: the actor module
        """
        init_linear_orthogonal(actor)
        if hasattr(actor, "mu"):
            # For continuous action spaces with Gaussian policies
            # do last policy layer scaling, this will make initial actions have (close to)
            # 0 mean and std, and will help boost performances,
            # see https://arxiv.org/abs/2006.05990, Fig.24 for details
            for m in actor.mu.modules():
                if isinstance(m, torch.nn.Linear):
                    m.weight.data.copy_(0.01 * m.weight.data)


class ContinuousActorFactory(ActorFactory, ABC):
    pass


class ContinuousActorProbFactory(ContinuousActorFactory):
    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        net_a = Net(
            envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device
        )
        actor = ActorProb(net_a, envs.get_action_shape(), unbounded=True, device=device).to(device)

        # init params
        torch.nn.init.constant_(actor.sigma_param, -0.5)
        self._init_linear(actor)

        return actor


class CriticFactory(ABC):
    @abstractmethod
    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        pass


class ContinuousCriticFactory(CriticFactory, ABC):
    pass


class ContinuousNetCriticFactory(ContinuousCriticFactory):
    def __init__(self, hidden_sizes: Sequence[int]):
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
        net_c = Net(
            envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device
        )
        critic = ContinuousCritic(net_c, device=device).to(device)
        init_linear_orthogonal(critic)
        return critic

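A self-contained check of init_linear_orthogonal (not part of the commit), using a plain MLP: after the call, each Linear weight is semi-orthogonal with gain sqrt(2) and the biases are zero.

import torch
from torch import nn

from tianshou.highlevel.module import init_linear_orthogonal

mlp = nn.Sequential(nn.Linear(8, 64), nn.Tanh(), nn.Linear(64, 2))
init_linear_orthogonal(mlp)

w, b = mlp[0].weight, mlp[0].bias
assert torch.allclose(w.T @ w, 2.0 * torch.eye(8), atol=1e-5)  # columns orthogonal, scaled by gain^2 = 2
assert torch.count_nonzero(b) == 0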
tianshou/highlevel/optim.py (new file, 54 lines)
View File

@@ -0,0 +1,54 @@
from abc import ABC, abstractmethod
from typing import Union, Iterable, Dict, Any, Optional

import numpy as np
import torch
from torch import Tensor
from torch.optim import Adam
from torch.optim.lr_scheduler import LRScheduler, LambdaLR

from tianshou.config import RLSamplingConfig, NNConfig

TParams = Union[Iterable[Tensor], Iterable[Dict[str, Any]]]


class OptimizerFactory(ABC):
    @abstractmethod
    def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
        pass


class TorchOptimizerFactory(OptimizerFactory):
    def __init__(self, optim_class, **kwargs):
        self.optim_class = optim_class
        self.kwargs = kwargs

    def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer:
        return self.optim_class(module.parameters(), **self.kwargs)


class AdamOptimizerFactory(OptimizerFactory):
    def __init__(self, lr):
        self.lr = lr

    def create_optimizer(self, module: torch.nn.Module) -> Adam:
        return Adam(module.parameters(), lr=self.lr)


class LRSchedulerFactory(ABC):
    @abstractmethod
    def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
        pass


class LinearLRSchedulerFactory(LRSchedulerFactory):
    def __init__(self, nn_config: NNConfig, sampling_config: RLSamplingConfig):
        self.nn_config = nn_config
        self.sampling_config = sampling_config

    def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]:
        lr_scheduler = None
        if self.nn_config.lr_decay:
            max_update_num = np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect) * self.sampling_config.num_epochs
            lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
        return lr_scheduler
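To illustrate the schedule built by LinearLRSchedulerFactory (not part of the commit): the learning rate decays linearly to zero over the total number of policy updates, counting one update per data collection step (step_per_epoch / step_per_collect per epoch). The sampling numbers below are made up.

import numpy as np
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

step_per_epoch, step_per_collect, num_epochs = 30000, 2048, 100  # hypothetical sampling config
max_update_num = np.ceil(step_per_epoch / step_per_collect) * num_epochs  # 15 * 100 = 1500 updates

optim = Adam(torch.nn.Linear(4, 2).parameters(), lr=3e-4)
scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

for _ in range(int(max_update_num)):
    optim.step()
    scheduler.step()
print(optim.param_groups[0]["lr"])  # ~0.0 after the final update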