Add SAC high-level interface

Dominik Jain 2023-09-20 09:29:34 +02:00
parent 2a1cc6bb55
commit 316eb3c579
14 changed files with 377 additions and 533 deletions
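For orientation, here is a minimal sketch of how the pieces introduced by this commit fit together, mirroring the new SAC example script shown further down. All class names (SACAgentFactory, RLExperiment, MujocoEnvFactory and the factory classes) come from the diff below; the concrete hyperparameter values and the log name are illustrative assumptions.

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.config import BasicExperimentConfig, LoggerConfig, RLSamplingConfig
from tianshou.highlevel.agent import SACAgentFactory
from tianshou.highlevel.experiment import RLExperiment
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import ContinuousActorProbFactory, ContinuousNetCriticFactory
from tianshou.highlevel.optim import AdamOptimizerFactory

# Configs (fields and defaults as shown in the diff; the values here are examples)
experiment_config = BasicExperimentConfig(task="Ant-v4", seed=42)
logger_config = LoggerConfig()
sampling_config = RLSamplingConfig()
sac_config = SACAgentFactory.Config(actor_lr=1e-3, critic1_lr=1e-3, critic2_lr=1e-3)

# Factories wire environments, networks, optimizers and the agent together
env_factory = MujocoEnvFactory(experiment_config, sampling_config)
actor_factory = ContinuousActorProbFactory((256, 256), conditioned_sigma=True)
critic_factory = ContinuousNetCriticFactory((256, 256))
agent_factory = SACAgentFactory(
    sac_config,
    sampling_config,
    actor_factory,
    critic_factory,  # critic 1
    critic_factory,  # critic 2 (the same factory reused, as in the example script)
    AdamOptimizerFactory(),
)

# The experiment handles seeding, env/logger/policy creation, training and watching
experiment = RLExperiment(experiment_config, env_factory, DefaultLoggerFactory(logger_config), agent_factory)
experiment.run("sac-example")  # log name is an arbitrary example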

View File

@ -2,9 +2,9 @@ import warnings
import gymnasium as gym import gymnasium as gym
from tianshou.config import RLSamplingConfig, BasicExperimentConfig from tianshou.config import BasicExperimentConfig, RLSamplingConfig
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import EnvFactory, Environments, ContinuousEnvironments from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
try: try:
import envpool import envpool
@ -12,9 +12,7 @@ except ImportError:
envpool = None envpool = None
def make_mujoco_env( def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool):
task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool
):
"""Wrapper function for Mujoco env. """Wrapper function for Mujoco env.
If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env. If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env.

View File

@ -1,359 +0,0 @@
#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint
from collections.abc import Sequence
from typing import Literal, Optional, Tuple, Union
import gymnasium as gym
import numpy as np
import torch
from jsonargparse import CLI
from torch import nn
from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from mujoco_env import make_mujoco_env
from tianshou.config import (
BasicExperimentConfig,
LoggerConfig,
NNConfig,
PGConfig,
PPOConfig,
RLAgentConfig,
RLSamplingConfig,
)
from tianshou.config.utils import collect_configs
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.env import VectorEnvNormObs
from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.trainer import OnpolicyTrainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
def set_seed(seed=42):
np.random.seed(seed)
torch.manual_seed(seed)
def get_logger_for_run(
algo_name: str,
task: str,
logger_config: LoggerConfig,
config: dict,
seed: int,
resume_id: Optional[Union[str, int]],
) -> Tuple[str, Union[WandbLogger, TensorboardLogger]]:
"""
:param algo_name:
:param task:
:param logger_config:
:param config: the experiment config
:param seed:
:param resume_id: used as run_id by wandb, unused for tensorboard
:return:
"""
"""Returns the log_path and logger."""
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, algo_name, str(seed), now)
log_path = os.path.join(logger_config.logdir, log_name)
logger = get_logger(
logger_config.logger,
log_path,
log_name=log_name,
run_id=resume_id,
config=config,
wandb_project=logger_config.wandb_project,
)
return log_path, logger
def get_continuous_env_info(
env: gym.Env,
) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]:
if not isinstance(env.action_space, gym.spaces.Box):
raise ValueError(
"Only environments with continuous action space are supported here. "
f"But got env with action space: {env.action_space.__class__}."
)
state_shape = env.observation_space.shape or env.observation_space.n
if not state_shape:
raise ValueError("Observation space shape is not defined")
action_shape = env.action_space.shape
max_action = env.action_space.high[0]
return state_shape, action_shape, max_action
def resume_from_checkpoint(
path: str,
policy: BasePolicy,
train_envs: VectorEnvNormObs | None = None,
test_envs: VectorEnvNormObs | None = None,
device: str | int | torch.device | None = None,
):
ckpt = torch.load(path, map_location=device)
policy.load_state_dict(ckpt["model"])
if train_envs:
train_envs.set_obs_rms(ckpt["obs_rms"])
if test_envs:
test_envs.set_obs_rms(ckpt["obs_rms"])
print("Loaded agent and obs. running means from: ", path)
def watch_agent(n_episode, policy: BasePolicy, test_collector: Collector, render=0.0):
policy.eval()
test_collector.reset()
result = test_collector.collect(n_episode=n_episode, render=render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
def get_train_test_collector(
buffer_size: int,
policy: BasePolicy,
train_envs: VectorEnvNormObs,
test_envs: VectorEnvNormObs,
):
if len(train_envs) > 1:
buffer = VectorReplayBuffer(buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
return test_collector, train_collector
TShape = Union[int, Sequence[int]]
def get_actor_critic(
state_shape: TShape,
hidden_sizes: Sequence[int],
action_shape: TShape,
device: str | int | torch.device = "cpu",
):
net_a = Net(
state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, device=device
)
actor = ActorProb(net_a, action_shape, unbounded=True, device=device).to(device)
net_c = Net(
state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, device=device
)
# TODO: twice device?
critic = Critic(net_c, device=device).to(device)
return actor, critic
def get_logger(
kind: Literal["wandb", "tensorboard"],
log_path: str,
log_name="",
run_id: Optional[Union[str, int]] = None,
config: Optional[Union[dict, argparse.Namespace]] = None,
wandb_project: Optional[str] = None,
):
writer = SummaryWriter(log_path)
writer.add_text("args", str(config))
if kind == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=run_id,
config=config,
project=wandb_project,
)
logger.load(writer)
elif kind == "tensorboard":
logger = TensorboardLogger(writer)
else:
raise ValueError(f"Unknown logger: {kind}")
return logger
def get_lr_scheduler(optim, step_per_epoch: int, step_per_collect: int, epochs: int):
"""Decay learning rate to 0 linearly."""
max_update_num = np.ceil(step_per_epoch / step_per_collect) * epochs
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
return lr_scheduler
def init_and_get_optim(actor: nn.Module, critic: nn.Module, lr: float):
"""Initializes layers of actor and critic.
:param actor:
:param critic:
:param lr:
:return:
"""
actor_critic = ActorCritic(actor, critic)
torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in actor_critic.modules():
if isinstance(m, torch.nn.Linear):
# orthogonal initialization
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
if hasattr(actor, "mu"):
# For continuous action spaces with Gaussian policies
# do last policy layer scaling, this will make initial actions have (close to)
# 0 mean and std, and will help boost performances,
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
for m in actor.mu.modules():
# TODO: seems like biases are initialized twice for the actor
if isinstance(m, torch.nn.Linear):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)
optim = torch.optim.Adam(actor_critic.parameters(), lr=lr)
return optim
def main(
experiment_config: BasicExperimentConfig,
logger_config: LoggerConfig,
sampling_config: RLSamplingConfig,
general_config: RLAgentConfig,
pg_config: PGConfig,
ppo_config: PPOConfig,
nn_config: NNConfig,
):
"""
Run the PPO test on the provided parameters.
:param experiment_config: BasicExperimentConfig - not ML or RL specific
:param logger_config: LoggerConfig
:param sampling_config: SamplingConfig -
sampling, epochs, parallelization, buffers, collectors, and batching.
:param general_config: RLAgentConfig - general RL agent config
:param pg_config: PGConfig: common to most policy gradient algorithms
:param ppo_config: PPOConfig - PPO specific config
:param nn_config: NNConfig - NN-training specific config
:return: None
"""
full_config = collect_configs(*locals().values())
set_seed(experiment_config.seed)
# create test and train envs, add env info to config
env, train_envs, test_envs = make_mujoco_env(
task=experiment_config.task,
seed=experiment_config.seed,
num_train_envs=sampling_config.num_train_envs,
num_test_envs=sampling_config.num_test_envs,
obs_norm=True,
)
# adding env_info to logged config
state_shape, action_shape, max_action = get_continuous_env_info(env)
full_config["env_info"] = {
"state_shape": state_shape,
"action_shape": action_shape,
"max_action": max_action,
}
log_path, logger = get_logger_for_run(
"ppo",
experiment_config.task,
logger_config,
full_config,
experiment_config.seed,
experiment_config.resume_id,
)
# Setup NNs
actor, critic = get_actor_critic(
state_shape, nn_config.hidden_sizes, action_shape, experiment_config.device
)
optim = init_and_get_optim(actor, critic, nn_config.lr)
lr_scheduler = None
if nn_config.lr_decay:
lr_scheduler = get_lr_scheduler(
optim,
sampling_config.step_per_epoch,
sampling_config.step_per_collect,
sampling_config.num_epochs,
)
# Create policy
def dist_fn(*logits):
return Independent(Normal(*logits), 1)
policy = PPOPolicy(
# nn-stuff
actor,
critic,
optim,
dist_fn=dist_fn,
lr_scheduler=lr_scheduler,
# env-stuff
action_space=train_envs.action_space,
action_scaling=True,
# general_config
discount_factor=general_config.gamma,
gae_lambda=general_config.gae_lambda,
reward_normalization=general_config.rew_norm,
action_bound_method=general_config.action_bound_method,
# pg_config
max_grad_norm=pg_config.max_grad_norm,
vf_coef=pg_config.vf_coef,
ent_coef=pg_config.ent_coef,
# ppo_config
eps_clip=ppo_config.eps_clip,
value_clip=ppo_config.value_clip,
dual_clip=ppo_config.dual_clip,
advantage_normalization=ppo_config.norm_adv,
recompute_advantage=ppo_config.recompute_adv,
)
if experiment_config.resume_path:
resume_from_checkpoint(
experiment_config.resume_path,
policy,
train_envs=train_envs,
test_envs=test_envs,
device=experiment_config.device,
)
test_collector, train_collector = get_train_test_collector(
sampling_config.buffer_size, policy, test_envs, train_envs
)
# TODO: test num is the number of test envs but used as episode_per_test
# here and in watch_agent
if not experiment_config.watch:
# RL training
def save_best_fn(pol: nn.Module):
state = {"model": pol.state_dict(), "obs_rms": train_envs.get_obs_rms()}
torch.save(state, os.path.join(log_path, "policy.pth"))
trainer = OnpolicyTrainer(
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
repeat_per_collect=sampling_config.repeat_per_collect,
episode_per_test=sampling_config.num_test_envs,
batch_size=sampling_config.batch_size,
step_per_collect=sampling_config.step_per_collect,
save_best_fn=save_best_fn,
logger=logger,
test_in_train=False,
)
result = trainer.run()
pprint.pprint(result)
watch_agent(
sampling_config.num_test_envs,
policy,
test_collector,
render=experiment_config.render,
)
if __name__ == "__main__":
CLI(main)

View File

@ -2,6 +2,8 @@
import datetime import datetime
import os import os
from collections.abc import Sequence
from dataclasses import dataclass
from jsonargparse import CLI from jsonargparse import CLI
from torch.distributions import Independent, Normal from torch.distributions import Independent, Normal
@ -10,17 +12,26 @@ from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.config import ( from tianshou.config import (
BasicExperimentConfig, BasicExperimentConfig,
LoggerConfig, LoggerConfig,
NNConfig,
PGConfig, PGConfig,
PPOConfig, PPOConfig,
RLAgentConfig, RLAgentConfig,
RLSamplingConfig, RLSamplingConfig,
) )
from tianshou.highlevel.agent import PPOAgentFactory from tianshou.highlevel.agent import PPOAgentFactory
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import ContinuousActorProbFactory, ContinuousNetCriticFactory
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
from tianshou.highlevel.experiment import RLExperiment from tianshou.highlevel.experiment import RLExperiment
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import (
ContinuousActorProbFactory,
ContinuousNetCriticFactory,
)
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory
@dataclass
class NNConfig:
hidden_sizes: Sequence[int] = (64, 64)
lr: float = 3e-4
lr_decay: bool = True
def main( def main(
@ -43,15 +54,22 @@ def main(
actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes) actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes)
critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes) critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes)
optim_factory = AdamOptimizerFactory(lr=nn_config.lr) optim_factory = AdamOptimizerFactory()
lr_scheduler_factory = LinearLRSchedulerFactory(nn_config, sampling_config) lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if nn_config.lr_decay else None
agent_factory = PPOAgentFactory(general_config, pg_config, ppo_config, sampling_config, nn_config, agent_factory = PPOAgentFactory(
actor_factory, critic_factory, optim_factory, dist_fn, lr_scheduler_factory) general_config,
pg_config,
ppo_config,
sampling_config,
actor_factory,
critic_factory,
optim_factory,
dist_fn,
nn_config.lr,
lr_scheduler_factory,
)
experiment = RLExperiment(experiment_config, logger_config, general_config, sampling_config, experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
env_factory,
logger_factory,
agent_factory)
experiment.run(log_name) experiment.run(log_name)

View File

@ -0,0 +1,56 @@
#!/usr/bin/env python3
import datetime
import os
from collections.abc import Sequence
from jsonargparse import CLI
from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.config import (
BasicExperimentConfig,
LoggerConfig,
RLSamplingConfig,
)
from tianshou.highlevel.agent import SACAgentFactory
from tianshou.highlevel.experiment import RLExperiment
from tianshou.highlevel.logger import DefaultLoggerFactory
from tianshou.highlevel.module import (
ContinuousActorProbFactory,
ContinuousNetCriticFactory,
)
from tianshou.highlevel.optim import AdamOptimizerFactory
def main(
experiment_config: BasicExperimentConfig,
logger_config: LoggerConfig,
sampling_config: RLSamplingConfig,
sac_config: SACAgentFactory.Config,
hidden_sizes: Sequence[int] = (256, 256),
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(experiment_config.task, "sac", str(experiment_config.seed), now)
logger_factory = DefaultLoggerFactory(logger_config)
env_factory = MujocoEnvFactory(experiment_config, sampling_config)
actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
critic_factory = ContinuousNetCriticFactory(hidden_sizes)
optim_factory = AdamOptimizerFactory()
agent_factory = SACAgentFactory(
sac_config,
sampling_config,
actor_factory,
critic_factory,
critic_factory,
optim_factory,
)
experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory)
experiment.run(log_name)
if __name__ == "__main__":
CLI(main)
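Because the script exposes main through jsonargparse's CLI, every nested config field becomes a dotted command-line option. For instance (the script name and option values are illustrative, not taken from the diff):

python mujoco_sac_hl.py --experiment_config.task HalfCheetah-v4 --sac_config.tau 0.01 --hidden_sizes "[256, 256]"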

View File

@ -63,7 +63,9 @@ envpool = ["envpool"]
optional = true optional = true
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
black = "^23.7.0" black = "^23.7.0"
docstring-parser = "^0.15"
jinja2 = "*" jinja2 = "*"
jsonargparse = "^4.24.1"
mypy = "^1.4.1" mypy = "^1.4.1"
# networkx is used in a test # networkx is used in a test
networkx = "*" networkx = "*"
@ -83,8 +85,6 @@ sphinx_rtd_theme = "*"
sphinxcontrib-bibtex = "*" sphinxcontrib-bibtex = "*"
sphinxcontrib-spelling = "^8.0.0" sphinxcontrib-spelling = "^8.0.0"
wandb = "^0.12.0" wandb = "^0.12.0"
jsonargparse = "^4.24.1"
docstring-parser = "^0.15"
[tool.mypy] [tool.mypy]
allow_redefinition = true allow_redefinition = true

View File

@ -1 +1,10 @@
from .config import * __all__ = ["PGConfig", "PPOConfig", "RLAgentConfig", "RLSamplingConfig", "BasicExperimentConfig", "LoggerConfig"]
from .config import (
BasicExperimentConfig,
PGConfig,
PPOConfig,
RLAgentConfig,
RLSamplingConfig,
LoggerConfig,
)

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal, Optional, Sequence from typing import Literal
import torch import torch
from jsonargparse import set_docstring_parse_options from jsonargparse import set_docstring_parse_options
@ -14,12 +14,12 @@ class BasicExperimentConfig:
seed: int = 42 seed: int = 42
task: str = "Ant-v4" task: str = "Ant-v4"
"""Mujoco specific""" """Mujoco specific"""
render: Optional[float] = 0.0 render: float | None = 0.0
"""Milliseconds between rendered frames; if None, no rendering""" """Milliseconds between rendered frames; if None, no rendering"""
device: str = "cuda" if torch.cuda.is_available() else "cpu" device: str = "cuda" if torch.cuda.is_available() else "cpu"
resume_id: Optional[int] = None resume_id: str | None = None
"""For restoring a model and running means of env-specifics from a checkpoint""" """For restoring a model and running means of env-specifics from a checkpoint"""
resume_path: str = None resume_path: str | None = None
"""For restoring a model and running means of env-specifics from a checkpoint""" """For restoring a model and running means of env-specifics from a checkpoint"""
watch: bool = False watch: bool = False
"""If True, will not perform training and only watch the restored policy""" """If True, will not perform training and only watch the restored policy"""
@ -28,7 +28,7 @@ class BasicExperimentConfig:
@dataclass @dataclass
class LoggerConfig: class LoggerConfig:
"""Logging config""" """Logging config."""
logdir: str = "log" logdir: str = "log"
logger: Literal["tensorboard", "wandb"] = "tensorboard" logger: Literal["tensorboard", "wandb"] = "tensorboard"
@ -48,17 +48,18 @@ class RLSamplingConfig:
buffer_size: int = 4096 buffer_size: int = 4096
step_per_collect: int = 2048 step_per_collect: int = 2048
repeat_per_collect: int = 10 repeat_per_collect: int = 10
update_per_step: int = 1
@dataclass @dataclass
class RLAgentConfig: class RLAgentConfig:
"""Config common to most RL algorithms""" """Config common to most RL algorithms."""
gamma: float = 0.99 gamma: float = 0.99
"""Discount factor""" """Discount factor"""
gae_lambda: float = 0.95 gae_lambda: float = 0.95
"""For Generalized Advantage Estimate (equivalent to TD(lambda))""" """For Generalized Advantage Estimate (equivalent to TD(lambda))"""
action_bound_method: Optional[Literal["clip", "tanh"]] = "clip" action_bound_method: Literal["clip", "tanh"] | None = "clip"
"""How to map original actions in range (-inf, inf) to [-1, 1]""" """How to map original actions in range (-inf, inf) to [-1, 1]"""
rew_norm: bool = True rew_norm: bool = True
"""Whether to normalize rewards""" """Whether to normalize rewards"""
@ -66,7 +67,7 @@ class RLAgentConfig:
@dataclass @dataclass
class PGConfig: class PGConfig:
"""Config of general policy-gradient algorithms""" """Config of general policy-gradient algorithms."""
ent_coef: float = 0.0 ent_coef: float = 0.0
vf_coef: float = 0.25 vf_coef: float = 0.25
@ -75,18 +76,11 @@ class PGConfig:
@dataclass @dataclass
class PPOConfig: class PPOConfig:
"""PPO specific config""" """PPO specific config."""
value_clip: bool = False value_clip: bool = False
norm_adv: bool = False norm_adv: bool = False
"""Whether to normalize advantages""" """Whether to normalize advantages"""
eps_clip: float = 0.2 eps_clip: float = 0.2
dual_clip: Optional[float] = None dual_clip: float | None = None
recompute_adv: bool = True recompute_adv: bool = True
@dataclass
class NNConfig:
hidden_sizes: Sequence[int] = (64, 64)
lr: float = 3e-4
lr_decay: bool = True

View File

@ -2,11 +2,10 @@ from dataclasses import asdict, is_dataclass
def collect_configs(*confs): def collect_configs(*confs):
""" """Collect instances of dataclasses to a single dict mapping the classname to the values.
Collect instances of dataclasses to a single dict mapping the
classname to the values. If any of the passed objects is not a If any of the passed objects is not a ddataclass or if two instances
dataclass or if two instances of the same config class are passed, of the same config class are passed, an error will be raised.
an error will be raised.
:param confs: dataclasses :param confs: dataclasses
:return: Dictionary mapping class names to their instances. :return: Dictionary mapping class names to their instances.

View File

@ -1,33 +1,51 @@
import os import os
from abc import abstractmethod, ABC from abc import ABC, abstractmethod
from typing import Callable from collections.abc import Callable
from dataclasses import dataclass
import torch import torch
from tianshou.config import RLSamplingConfig, PGConfig, PPOConfig, RLAgentConfig, NNConfig from tianshou.config import PGConfig, PPOConfig, RLAgentConfig, RLSamplingConfig
from tianshou.data import VectorReplayBuffer, ReplayBuffer, Collector from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.exploration import BaseNoise
from tianshou.highlevel.env import Environments from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import Logger from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
from tianshou.highlevel.optim import OptimizerFactory, LRSchedulerFactory from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
from tianshou.policy import BasePolicy, PPOPolicy from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy
from tianshou.trainer import BaseTrainer, OnpolicyTrainer from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.common import ActorCritic
CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
class AgentFactory(ABC): class AgentFactory(ABC):
def __init__(self, sampling_config: RLSamplingConfig):
self.sampling_config = sampling_config
def create_train_test_collector(self, policy: BasePolicy, envs: Environments):
buffer_size = self.sampling_config.buffer_size
train_envs = envs.train_envs
if len(train_envs) > 1:
buffer = VectorReplayBuffer(buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, envs.test_envs)
return train_collector, test_collector
@abstractmethod @abstractmethod
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
pass pass
@staticmethod @staticmethod
def _create_save_best_fn(envs: Environments, log_path: str) -> Callable: def _create_save_best_fn(envs: Environments, log_path: str) -> Callable:
def save_best_fn(pol: torch.nn.Module): def save_best_fn(pol: torch.nn.Module) -> None:
state = {"model": pol.state_dict(), "obs_rms": envs.train_envs.get_obs_rms()} state = {
CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(),
CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(),
}
torch.save(state, os.path.join(log_path, "policy.pth")) torch.save(state, os.path.join(log_path, "policy.pth"))
return save_best_fn return save_best_fn
@ -43,36 +61,26 @@ class AgentFactory(ABC):
print("Loaded agent and obs. running means from: ", path) # TODO logging print("Loaded agent and obs. running means from: ", path) # TODO logging
@abstractmethod @abstractmethod
def create_train_test_collector(self, def create_trainer(
policy: BasePolicy, self,
envs: Environments): policy: BasePolicy,
pass train_collector: Collector,
test_collector: Collector,
@abstractmethod envs: Environments,
def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector, logger: Logger,
envs: Environments, logger: Logger) -> BaseTrainer: ) -> BaseTrainer:
pass pass
class OnpolicyAgentFactory(AgentFactory, ABC): class OnpolicyAgentFactory(AgentFactory, ABC):
def __init__(self, sampling_config: RLSamplingConfig): def create_trainer(
self.sampling_config = sampling_config self,
policy: BasePolicy,
def create_train_test_collector(self, train_collector: Collector,
policy: BasePolicy, test_collector: Collector,
envs: Environments): envs: Environments,
buffer_size = self.sampling_config.buffer_size logger: Logger,
train_envs = envs.train_envs ) -> OnpolicyTrainer:
if len(train_envs) > 1:
buffer = VectorReplayBuffer(buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, envs.test_envs)
return train_collector, test_collector
def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector,
envs: Environments, logger: Logger) -> OnpolicyTrainer:
sampling_config = self.sampling_config sampling_config = self.sampling_config
return OnpolicyTrainer( return OnpolicyTrainer(
policy=policy, policy=policy,
@ -90,17 +98,46 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
) )
class OffpolicyAgentFactory(AgentFactory, ABC):
def create_trainer(
self,
policy: BasePolicy,
train_collector: Collector,
test_collector: Collector,
envs: Environments,
logger: Logger,
) -> OffpolicyTrainer:
sampling_config = self.sampling_config
return OffpolicyTrainer(
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
max_epoch=sampling_config.num_epochs,
step_per_epoch=sampling_config.step_per_epoch,
step_per_collect=sampling_config.step_per_collect,
episode_per_test=sampling_config.num_test_envs,
batch_size=sampling_config.batch_size,
save_best_fn=self._create_save_best_fn(envs, logger.log_path),
logger=logger.logger,
update_per_step=sampling_config.update_per_step,
test_in_train=False,
)
class PPOAgentFactory(OnpolicyAgentFactory): class PPOAgentFactory(OnpolicyAgentFactory):
def __init__(self, general_config: RLAgentConfig, def __init__(
pg_config: PGConfig, self,
ppo_config: PPOConfig, general_config: RLAgentConfig,
sampling_config: RLSamplingConfig, pg_config: PGConfig,
nn_config: NNConfig, ppo_config: PPOConfig,
actor_factory: ActorFactory, sampling_config: RLSamplingConfig,
critic_factory: CriticFactory, actor_factory: ActorFactory,
optimizer_factory: OptimizerFactory, critic_factory: CriticFactory,
dist_fn, optimizer_factory: OptimizerFactory,
lr_scheduler_factory: LRSchedulerFactory): dist_fn,
lr: float,
lr_scheduler_factory: LRSchedulerFactory | None = None,
):
super().__init__(sampling_config) super().__init__(sampling_config)
self.optimizer_factory = optimizer_factory self.optimizer_factory = optimizer_factory
self.critic_factory = critic_factory self.critic_factory = critic_factory
@ -108,16 +145,19 @@ class PPOAgentFactory(OnpolicyAgentFactory):
self.ppo_config = ppo_config self.ppo_config = ppo_config
self.pg_config = pg_config self.pg_config = pg_config
self.general_config = general_config self.general_config = general_config
self.lr = lr
self.lr_scheduler_factory = lr_scheduler_factory self.lr_scheduler_factory = lr_scheduler_factory
self.dist_fn = dist_fn self.dist_fn = dist_fn
self.nn_config = nn_config
def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy: def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
actor = self.actor_factory.create_module(envs, device) actor = self.actor_factory.create_module(envs, device)
critic = self.critic_factory.create_module(envs, device) critic = self.critic_factory.create_module(envs, device, use_action=False)
actor_critic = ActorCritic(actor, critic) actor_critic = ActorCritic(actor, critic)
optim = self.optimizer_factory.create_optimizer(actor_critic) optim = self.optimizer_factory.create_optimizer(actor_critic, self.lr)
lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim) if self.lr_scheduler_factory is not None:
lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim)
else:
lr_scheduler = None
return PPOPolicy( return PPOPolicy(
# nn-stuff # nn-stuff
actor, actor,
@ -144,3 +184,60 @@ class PPOAgentFactory(OnpolicyAgentFactory):
advantage_normalization=self.ppo_config.norm_adv, advantage_normalization=self.ppo_config.norm_adv,
recompute_advantage=self.ppo_config.recompute_adv, recompute_advantage=self.ppo_config.recompute_adv,
) )
class SACAgentFactory(OffpolicyAgentFactory):
def __init__(
self,
config: "SACAgentFactory.Config",
sampling_config: RLSamplingConfig,
actor_factory: ActorFactory,
critic1_factory: CriticFactory,
critic2_factory: CriticFactory,
optim_factory: OptimizerFactory,
exploration_noise: BaseNoise | None = None,
):
super().__init__(sampling_config)
self.critic2_factory = critic2_factory
self.critic1_factory = critic1_factory
self.actor_factory = actor_factory
self.exploration_noise = exploration_noise
self.optim_factory = optim_factory
self.config = config
def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
actor = self.actor_factory.create_module(envs, device)
critic1 = self.critic1_factory.create_module(envs, device, use_action=True)
critic2 = self.critic2_factory.create_module(envs, device, use_action=True)
actor_optim = self.optim_factory.create_optimizer(actor, lr=self.config.actor_lr)
critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.config.critic1_lr)
critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.config.critic2_lr)
return SACPolicy(
actor,
actor_optim,
critic1,
critic1_optim,
critic2,
critic2_optim,
tau=self.config.tau,
gamma=self.config.gamma,
alpha=self.config.alpha,
estimation_step=self.config.estimation_step,
action_space=envs.get_action_space(),
deterministic_eval=self.config.deterministic_eval,
exploration_noise=self.exploration_noise,
)
@dataclass
class Config:
"""SAC configuration."""
tau: float = 0.005
gamma: float = 0.99
alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
reward_normalization: bool = False
estimation_step: int = 1
deterministic_eval: bool = True
actor_lr: float = 1e-3
critic1_lr: float = 1e-3
critic2_lr: float = 1e-3
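The tuple form of the alpha field mirrors SACPolicy's support for automatic entropy tuning. A hedged sketch of how a caller might fill it in; the target-entropy heuristic and the learning rate are assumptions, not part of this commit:

import numpy as np
import torch

# alpha = (target_entropy, log_alpha, alpha_optimizer) enables auto-tuning in SACPolicy
target_entropy = -float(np.prod(action_shape))  # action_shape taken from the envs; a common heuristic
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
sac_config = SACAgentFactory.Config(alpha=(target_entropy, log_alpha, alpha_optim))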

View File

@ -1,24 +1,22 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict, Any, Union, Sequence from collections.abc import Sequence
from typing import Any
import gymnasium as gym import gymnasium as gym
from tianshou.env import BaseVectorEnv from tianshou.env import BaseVectorEnv
TShape = Union[int, Sequence[int]] TShape = int | Sequence[int]
class Environments(ABC): class Environments(ABC):
def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
self.env = env self.env = env
self.train_envs = train_envs self.train_envs = train_envs
self.test_envs = test_envs self.test_envs = test_envs
def info(self) -> Dict[str, Any]: def info(self) -> dict[str, Any]:
return { return {"action_shape": self.get_action_shape(), "state_shape": self.get_state_shape()}
"action_shape": self.get_action_shape(),
"state_shape": self.get_state_shape()
}
@abstractmethod @abstractmethod
def get_action_shape(self) -> TShape: def get_action_shape(self) -> TShape:
@ -33,7 +31,7 @@ class Environments(ABC):
class ContinuousEnvironments(Environments): class ContinuousEnvironments(Environments):
def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
super().__init__(env, train_envs, test_envs) super().__init__(env, train_envs, test_envs)
self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env) self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)
@ -44,12 +42,12 @@ class ContinuousEnvironments(Environments):
@staticmethod @staticmethod
def _get_continuous_env_info( def _get_continuous_env_info(
env: gym.Env, env: gym.Env,
) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]: ) -> tuple[tuple[int, ...], tuple[int, ...], float]:
if not isinstance(env.action_space, gym.spaces.Box): if not isinstance(env.action_space, gym.spaces.Box):
raise ValueError( raise ValueError(
"Only environments with continuous action space are supported here. " "Only environments with continuous action space are supported here. "
f"But got env with action space: {env.action_space.__class__}." f"But got env with action space: {env.action_space.__class__}.",
) )
state_shape = env.observation_space.shape or env.observation_space.n state_shape = env.observation_space.shape or env.observation_space.n
if not state_shape: if not state_shape:

View File

@ -4,7 +4,9 @@ from typing import Generic, TypeVar
import numpy as np import numpy as np
import torch import torch
from tianshou.config import BasicExperimentConfig, LoggerConfig, RLAgentConfig, RLSamplingConfig from tianshou.config import (
BasicExperimentConfig,
)
from tianshou.data import Collector from tianshou.data import Collector
from tianshou.highlevel.agent import AgentFactory from tianshou.highlevel.agent import AgentFactory
from tianshou.highlevel.env import EnvFactory from tianshou.highlevel.env import EnvFactory
@ -17,23 +19,19 @@ TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
class RLExperiment(Generic[TPolicy, TTrainer]): class RLExperiment(Generic[TPolicy, TTrainer]):
def __init__(self, def __init__(
config: BasicExperimentConfig, self,
logger_config: LoggerConfig, config: BasicExperimentConfig,
general_config: RLAgentConfig, env_factory: EnvFactory,
sampling_config: RLSamplingConfig, logger_factory: LoggerFactory,
env_factory: EnvFactory, agent_factory: AgentFactory,
logger_factory: LoggerFactory, ):
agent_factory: AgentFactory):
self.config = config self.config = config
self.logger_config = logger_config
self.general_config = general_config
self.sampling_config = sampling_config
self.env_factory = env_factory self.env_factory = env_factory
self.logger_factory = logger_factory self.logger_factory = logger_factory
self.agent_factory = agent_factory self.agent_factory = agent_factory
def _set_seed(self): def _set_seed(self) -> None:
seed = self.config.seed seed = self.config.seed
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
@ -43,7 +41,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
# TODO # TODO
} }
def run(self, log_name: str): def run(self, log_name: str) -> None:
self._set_seed() self._set_seed()
envs = self.env_factory.create_envs() envs = self.env_factory.create_envs()
@ -52,25 +50,47 @@ class RLExperiment(Generic[TPolicy, TTrainer]):
full_config.update(envs.info()) full_config.update(envs.info())
run_id = self.config.resume_id run_id = self.config.resume_id
logger = self.logger_factory.create_logger(log_name=log_name, run_id=run_id, config_dict=full_config) logger = self.logger_factory.create_logger(
log_name=log_name,
run_id=run_id,
config_dict=full_config,
)
policy = self.agent_factory.create_policy(envs, self.config.device) policy = self.agent_factory.create_policy(envs, self.config.device)
if self.config.resume_path: if self.config.resume_path:
self.agent_factory.load_checkpoint(policy, self.config.resume_path, envs, self.config.device) self.agent_factory.load_checkpoint(
policy,
self.config.resume_path,
envs,
self.config.device,
)
train_collector, test_collector = self.agent_factory.create_train_test_collector(policy, envs) train_collector, test_collector = self.agent_factory.create_train_test_collector(
policy,
envs,
)
if not self.config.watch: if not self.config.watch:
trainer = self.agent_factory.create_trainer(policy, train_collector, test_collector, envs, logger) trainer = self.agent_factory.create_trainer(
policy,
train_collector,
test_collector,
envs,
logger,
)
result = trainer.run() result = trainer.run()
pprint(result) # TODO logging pprint(result) # TODO logging
self._watch_agent(self.config.watch_num_episodes, policy, test_collector, self.config.render) self._watch_agent(
self.config.watch_num_episodes,
policy,
test_collector,
self.config.render,
)
@staticmethod @staticmethod
def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render): def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render) -> None:
policy.eval() policy.eval()
test_collector.reset() test_collector.reset()
result = test_collector.collect(n_episode=num_episodes, render=render) result = test_collector.collect(n_episode=num_episodes, render=render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')

View File

@ -1,15 +1,13 @@
from abc import ABC, abstractmethod
import os import os
from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union, Optional
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.config import LoggerConfig from tianshou.config import LoggerConfig
from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils import TensorboardLogger, WandbLogger
TLogger = TensorboardLogger | WandbLogger
TLogger = Union[TensorboardLogger, WandbLogger]
@dataclass @dataclass
@ -20,7 +18,7 @@ class Logger:
class LoggerFactory(ABC): class LoggerFactory(ABC):
@abstractmethod @abstractmethod
def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger: def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
pass pass
@ -28,7 +26,7 @@ class DefaultLoggerFactory(LoggerFactory):
def __init__(self, config: LoggerConfig): def __init__(self, config: LoggerConfig):
self.config = config self.config = config
def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger: def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
writer = SummaryWriter(self.config.logdir) writer = SummaryWriter(self.config.logdir)
writer.add_text("args", str(self.config)) writer.add_text("args", str(self.config))
if self.config.logger == "wandb": if self.config.logger == "wandb":

View File

@ -1,24 +1,24 @@
from abc import abstractmethod, ABC from abc import ABC, abstractmethod
from typing import Sequence from collections.abc import Sequence
import numpy as np
import torch import torch
from torch import nn from torch import nn
import numpy as np
from tianshou.highlevel.env import Environments from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic as ContinuousCritic from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.continuous import Critic as ContinuousCritic
TDevice = str | int | torch.device TDevice = str | int | torch.device
def init_linear_orthogonal(m: torch.nn.Module): def init_linear_orthogonal(module: torch.nn.Module):
""" """Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0.
Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0
:param m: the module whose submodules are to be processed :param module: the module whose submodules are to be processed
""" """
for m in m.modules(): for m in module.modules():
if isinstance(m, torch.nn.Linear): if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias) torch.nn.init.zeros_(m.bias)
@ -31,9 +31,9 @@ class ActorFactory(ABC):
@staticmethod @staticmethod
def _init_linear(actor: torch.nn.Module): def _init_linear(actor: torch.nn.Module):
""" """Initializes linear layers of an actor module using default mechanisms.
Initializes linear layers of an actor module using default mechanisms
:param module: the actor module :param module: the actor module.
""" """
init_linear_orthogonal(actor) init_linear_orthogonal(actor)
if hasattr(actor, "mu"): if hasattr(actor, "mu"):
@ -51,17 +51,29 @@ class ContinuousActorFactory(ActorFactory, ABC):
class ContinuousActorProbFactory(ContinuousActorFactory): class ContinuousActorProbFactory(ContinuousActorFactory):
def __init__(self, hidden_sizes: Sequence[int]): def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False):
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
self.unbounded = unbounded
self.conditioned_sigma = conditioned_sigma
def create_module(self, envs: Environments, device: TDevice) -> nn.Module: def create_module(self, envs: Environments, device: TDevice) -> nn.Module:
net_a = Net( net_a = Net(
envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device envs.get_state_shape(),
hidden_sizes=self.hidden_sizes,
activation=nn.Tanh,
device=device,
) )
actor = ActorProb(net_a, envs.get_action_shape(), unbounded=True, device=device).to(device) actor = ActorProb(
net_a,
envs.get_action_shape(),
unbounded=True,
device=device,
conditioned_sigma=self.conditioned_sigma,
).to(device)
# init params # init params
torch.nn.init.constant_(actor.sigma_param, -0.5) if not self.conditioned_sigma:
torch.nn.init.constant_(actor.sigma_param, -0.5)
self._init_linear(actor) self._init_linear(actor)
return actor return actor
@ -69,7 +81,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
class CriticFactory(ABC): class CriticFactory(ABC):
@abstractmethod @abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> nn.Module: def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
pass pass
@ -78,12 +90,19 @@ class ContinuousCriticFactory(CriticFactory, ABC):
class ContinuousNetCriticFactory(ContinuousCriticFactory): class ContinuousNetCriticFactory(ContinuousCriticFactory):
def __init__(self, hidden_sizes: Sequence[int]): def __init__(self, hidden_sizes: Sequence[int], action_shape=0):
self.action_shape = action_shape
self.hidden_sizes = hidden_sizes self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice) -> nn.Module: def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0
net_c = Net( net_c = Net(
envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device envs.get_state_shape(),
action_shape=action_shape,
hidden_sizes=self.hidden_sizes,
concat=use_action,
activation=nn.Tanh,
device=device,
) )
critic = ContinuousCritic(net_c, device=device).to(device) critic = ContinuousCritic(net_c, device=device).to(device)
init_linear_orthogonal(critic) init_linear_orthogonal(critic)
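The new use_action flag lets a single critic factory serve both training paradigms touched by this commit. A hedged illustration of the contract (variable names are placeholders; the calls themselves match those in PPOAgentFactory and SACAgentFactory above):

# V(s) critic for on-policy PPO: the action is not fed into the network
v_critic = critic_factory.create_module(envs, device, use_action=False)
# Q(s, a) critic for off-policy SAC: the action shape is concatenated to the input
q_critic = critic_factory.create_module(envs, device, use_action=True)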

View File

@ -1,54 +1,51 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Union, Iterable, Dict, Any, Optional from collections.abc import Iterable
from typing import Any, Type
import numpy as np import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
from torch.optim import Adam from torch.optim import Adam
from torch.optim.lr_scheduler import LRScheduler, LambdaLR from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from tianshou.config import RLSamplingConfig, NNConfig from tianshou.config import RLSamplingConfig
TParams = Union[Iterable[Tensor], Iterable[Dict[str, Any]]] TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
class OptimizerFactory(ABC): class OptimizerFactory(ABC):
@abstractmethod @abstractmethod
def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer: def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
pass pass
class TorchOptimizerFactory(OptimizerFactory): class TorchOptimizerFactory(OptimizerFactory):
def __init__(self, optim_class, **kwargs): def __init__(self, optim_class: Any, **kwargs):
self.optim_class = optim_class self.optim_class = optim_class
self.kwargs = kwargs self.kwargs = kwargs
def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer: def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
return self.optim_class(module.parameters(), **self.kwargs) return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
class AdamOptimizerFactory(OptimizerFactory): class AdamOptimizerFactory(OptimizerFactory):
def __init__(self, lr): def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
self.lr = lr return Adam(module.parameters(), lr=lr)
def create_optimizer(self, module: torch.nn.Module) -> Adam:
return Adam(module.parameters(), lr=self.lr)
class LRSchedulerFactory(ABC): class LRSchedulerFactory(ABC):
@abstractmethod @abstractmethod
def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]: def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
pass pass
class LinearLRSchedulerFactory(LRSchedulerFactory): class LinearLRSchedulerFactory(LRSchedulerFactory):
def __init__(self, nn_config: NNConfig, sampling_config: RLSamplingConfig): def __init__(self, sampling_config: RLSamplingConfig):
self.nn_config = nn_config
self.sampling_config = sampling_config self.sampling_config = sampling_config
def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]: def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
lr_scheduler = None max_update_num = (
if self.nn_config.lr_decay: np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect)
max_update_num = np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect) * self.sampling_config.num_epochs * self.sampling_config.num_epochs
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) )
return lr_scheduler return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
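With this change the learning rate moves from the factory constructor to create_optimizer, so one factory instance can build optimizers with different rates, as SACAgentFactory does for its actor and two critics. A hedged sketch; actor_module and critic_module are placeholders, and the RMSprop variant and rates are chosen purely for illustration:

import torch
from tianshou.highlevel.optim import AdamOptimizerFactory, TorchOptimizerFactory

optim_factory = AdamOptimizerFactory()
actor_optim = optim_factory.create_optimizer(actor_module, lr=1e-3)
critic_optim = optim_factory.create_optimizer(critic_module, lr=3e-4)

# TorchOptimizerFactory fixes the optimizer class and extra kwargs up front,
# while the learning rate stays a per-call argument
rmsprop_factory = TorchOptimizerFactory(torch.optim.RMSprop, alpha=0.99)
rmsprop_optim = rmsprop_factory.create_optimizer(actor_module, lr=1e-4)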