Refactoring, dropping package config

parent 316eb3c579
commit 997b520580

Changed paths: examples/mujoco, tianshou
@@ -2,9 +2,9 @@ import warnings

 import gymnasium as gym

-from tianshou.config import BasicExperimentConfig, RLSamplingConfig
 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
 from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
+from tianshou.highlevel.experiment import RLSamplingConfig

 try:
     import envpool
@@ -41,14 +41,15 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in


 class MujocoEnvFactory(EnvFactory):
-    def __init__(self, experiment_config: BasicExperimentConfig, sampling_config: RLSamplingConfig):
+    def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig):
+        self.task = task
         self.sampling_config = sampling_config
-        self.experiment_config = experiment_config
+        self.seed = seed

     def create_envs(self) -> ContinuousEnvironments:
         env, train_envs, test_envs = make_mujoco_env(
-            task=self.experiment_config.task,
-            seed=self.experiment_config.seed,
+            task=self.task,
+            seed=self.seed,
             num_train_envs=self.sampling_config.num_train_envs,
             num_test_envs=self.sampling_config.num_test_envs,
             obs_norm=True,
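A minimal usage sketch of the refactored factory, based only on the signatures visible in this hunk; the task name, seed, and sampling values are placeholders, not part of the commit:

    from examples.mujoco.mujoco_env import MujocoEnvFactory
    from tianshou.highlevel.experiment import RLSamplingConfig

    # The factory now takes the task and seed directly instead of a BasicExperimentConfig.
    sampling_config = RLSamplingConfig(num_train_envs=4, num_test_envs=2)
    env_factory = MujocoEnvFactory("Ant-v4", seed=42, sampling_config=sampling_config)
    envs = env_factory.create_envs()  # ContinuousEnvironments with obs normalization enabled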
@@ -9,17 +9,13 @@ from jsonargparse import CLI
 from torch.distributions import Independent, Normal

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.config import (
-    BasicExperimentConfig,
-    LoggerConfig,
-    PGConfig,
-    PPOConfig,
-    RLAgentConfig,
+from tianshou.highlevel.agent import PGConfig, PPOAgentFactory, PPOConfig, RLAgentConfig
+from tianshou.highlevel.experiment import (
+    RLExperiment,
+    RLExperimentConfig,
     RLSamplingConfig,
 )
-from tianshou.highlevel.agent import PPOAgentFactory
-from tianshou.highlevel.experiment import RLExperiment
-from tianshou.highlevel.logger import DefaultLoggerFactory
+from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig
 from tianshou.highlevel.module import (
     ContinuousActorProbFactory,
     ContinuousNetCriticFactory,
@@ -35,19 +31,20 @@ class NNConfig:


 def main(
-    experiment_config: BasicExperimentConfig,
+    experiment_config: RLExperimentConfig,
     logger_config: LoggerConfig,
     sampling_config: RLSamplingConfig,
     general_config: RLAgentConfig,
     pg_config: PGConfig,
     ppo_config: PPOConfig,
     nn_config: NNConfig,
+    task: str = "Ant-v4",
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(experiment_config.task, "ppo", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
     logger_factory = DefaultLoggerFactory(logger_config)

-    env_factory = MujocoEnvFactory(experiment_config, sampling_config)
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)

     def dist_fn(*logits):
         return Independent(Normal(*logits), 1)
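The `task` argument is now a plain parameter of `main` rather than a field of the experiment config, so the jsonargparse `CLI` the script already imports can expose it directly. A hedged sketch of the presumed entry-point wiring, which is not shown in this hunk:

    # Assumption: jsonargparse derives the CLI from main()'s signature, so --task
    # defaults to "Ant-v4" and no longer lives on experiment_config.
    if __name__ == "__main__":
        CLI(main)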
@@ -7,14 +7,13 @@ from collections.abc import Sequence
 from jsonargparse import CLI

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.config import (
-    BasicExperimentConfig,
-    LoggerConfig,
+from tianshou.highlevel.agent import SACAgentFactory, SACConfig
+from tianshou.highlevel.experiment import (
+    RLExperiment,
+    RLExperimentConfig,
     RLSamplingConfig,
 )
-from tianshou.highlevel.agent import SACAgentFactory
-from tianshou.highlevel.experiment import RLExperiment
-from tianshou.highlevel.logger import DefaultLoggerFactory
+from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig
 from tianshou.highlevel.module import (
     ContinuousActorProbFactory,
     ContinuousNetCriticFactory,
@@ -23,17 +22,18 @@ from tianshou.highlevel.optim import AdamOptimizerFactory


 def main(
-    experiment_config: BasicExperimentConfig,
+    experiment_config: RLExperimentConfig,
     logger_config: LoggerConfig,
     sampling_config: RLSamplingConfig,
-    sac_config: SACAgentFactory.Config,
+    sac_config: SACConfig,
     hidden_sizes: Sequence[int] = (256, 256),
+    task: str = "Ant-v4",
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(experiment_config.task, "sac", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
     logger_factory = DefaultLoggerFactory(logger_config)

-    env_factory = MujocoEnvFactory(experiment_config, sampling_config)
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)

     actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
     critic_factory = ContinuousNetCriticFactory(hidden_sizes)
@@ -1,10 +0,0 @@
-__all__ = ["PGConfig", "PPOConfig", "RLAgentConfig", "RLSamplingConfig", "BasicExperimentConfig", "LoggerConfig"]
-
-from .config import (
-    BasicExperimentConfig,
-    PGConfig,
-    PPOConfig,
-    RLAgentConfig,
-    RLSamplingConfig,
-    LoggerConfig,
-)
@@ -1,86 +0,0 @@
-from dataclasses import dataclass
-from typing import Literal
-
-import torch
-from jsonargparse import set_docstring_parse_options
-
-set_docstring_parse_options(attribute_docstrings=True)
-
-
-@dataclass
-class BasicExperimentConfig:
-    """Generic config for setting up the experiment, not RL or training specific."""
-
-    seed: int = 42
-    task: str = "Ant-v4"
-    """Mujoco specific"""
-    render: float | None = 0.0
-    """Milliseconds between rendered frames; if None, no rendering"""
-    device: str = "cuda" if torch.cuda.is_available() else "cpu"
-    resume_id: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    resume_path: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    watch: bool = False
-    """If True, will not perform training and only watch the restored policy"""
-    watch_num_episodes = 10
-
-
-@dataclass
-class LoggerConfig:
-    """Logging config."""
-
-    logdir: str = "log"
-    logger: Literal["tensorboard", "wandb"] = "tensorboard"
-    wandb_project: str = "mujoco.benchmark"
-    """Only used if logger is wandb."""
-
-
-@dataclass
-class RLSamplingConfig:
-    """Sampling, epochs, parallelization, buffers, collectors, and batching."""
-
-    num_epochs: int = 100
-    step_per_epoch: int = 30000
-    batch_size: int = 64
-    num_train_envs: int = 64
-    num_test_envs: int = 10
-    buffer_size: int = 4096
-    step_per_collect: int = 2048
-    repeat_per_collect: int = 10
-    update_per_step: int = 1
-
-
-@dataclass
-class RLAgentConfig:
-    """Config common to most RL algorithms."""
-
-    gamma: float = 0.99
-    """Discount factor"""
-    gae_lambda: float = 0.95
-    """For Generalized Advantage Estimate (equivalent to TD(lambda))"""
-    action_bound_method: Literal["clip", "tanh"] | None = "clip"
-    """How to map original actions in range (-inf, inf) to [-1, 1]"""
-    rew_norm: bool = True
-    """Whether to normalize rewards"""
-
-
-@dataclass
-class PGConfig:
-    """Config of general policy-gradient algorithms."""
-
-    ent_coef: float = 0.0
-    vf_coef: float = 0.25
-    max_grad_norm: float = 0.5
-
-
-@dataclass
-class PPOConfig:
-    """PPO specific config."""
-
-    value_clip: bool = False
-    norm_adv: bool = False
-    """Whether to normalize advantages"""
-    eps_clip: float = 0.2
-    dual_clip: float | None = None
-    recompute_adv: bool = True
@@ -0,0 +1,3 @@
+from jsonargparse import set_docstring_parse_options
+
+set_docstring_parse_options(attribute_docstrings=True)
@@ -2,13 +2,14 @@ import os
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
+from typing import Literal

 import torch

-from tianshou.config import PGConfig, PPOConfig, RLAgentConfig, RLSamplingConfig
 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.exploration import BaseNoise
 from tianshou.highlevel.env import Environments
+from tianshou.highlevel.experiment import RLSamplingConfig
 from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
 from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
@@ -124,6 +125,41 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
         )


+@dataclass
+class RLAgentConfig:
+    """Config common to most RL algorithms."""
+
+    gamma: float = 0.99
+    """Discount factor"""
+    gae_lambda: float = 0.95
+    """For Generalized Advantage Estimate (equivalent to TD(lambda))"""
+    action_bound_method: Literal["clip", "tanh"] | None = "clip"
+    """How to map original actions in range (-inf, inf) to [-1, 1]"""
+    rew_norm: bool = True
+    """Whether to normalize rewards"""
+
+
+@dataclass
+class PGConfig:
+    """Config of general policy-gradient algorithms."""
+
+    ent_coef: float = 0.0
+    vf_coef: float = 0.25
+    max_grad_norm: float = 0.5
+
+
+@dataclass
+class PPOConfig:
+    """PPO specific config."""
+
+    value_clip: bool = False
+    norm_adv: bool = False
+    """Whether to normalize advantages"""
+    eps_clip: float = 0.2
+    dual_clip: float | None = None
+    recompute_adv: bool = True
+
+
 class PPOAgentFactory(OnpolicyAgentFactory):
     def __init__(
         self,
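With the config dataclasses relocated here, callers import them from `tianshou.highlevel.agent`. A small construction sketch using only the fields listed in this hunk; the values shown are simply the defaults:

    from tianshou.highlevel.agent import PGConfig, PPOConfig, RLAgentConfig

    general_config = RLAgentConfig(gamma=0.99, gae_lambda=0.95, rew_norm=True)
    pg_config = PGConfig(ent_coef=0.0, vf_coef=0.25, max_grad_norm=0.5)
    ppo_config = PPOConfig(eps_clip=0.2, dual_clip=None, recompute_adv=True)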
@@ -186,10 +222,22 @@ class PPOAgentFactory(OnpolicyAgentFactory):
         )


+class SACConfig:
+    tau: float = 0.005
+    gamma: float = 0.99
+    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
+    reward_normalization: bool = False
+    estimation_step: int = 1
+    deterministic_eval: bool = True
+    actor_lr: float = 1e-3
+    critic1_lr: float = 1e-3
+    critic2_lr: float = 1e-3
+
+
 class SACAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
-        config: "SACAgentFactory.Config",
+        config: SACConfig,
         sampling_config: RLSamplingConfig,
         actor_factory: ActorFactory,
         critic1_factory: CriticFactory,
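`SACConfig` is now a module-level class instead of the nested `SACAgentFactory.Config`. As added in this hunk it carries no `@dataclass` decorator, so its attributes are class-level defaults; a sketch of how a script might adjust them, using only the attribute names shown above:

    from tianshou.highlevel.agent import SACConfig

    sac_config = SACConfig()
    sac_config.actor_lr = 3e-4           # override the class-level default per instance
    sac_config.deterministic_eval = True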
@@ -227,17 +275,3 @@ class SACAgentFactory(OffpolicyAgentFactory):
             deterministic_eval=self.config.deterministic_eval,
             exploration_noise=self.exploration_noise,
         )
-
-    @dataclass
-    class Config:
-        """SAC configuration."""
-
-        tau: float = 0.005
-        gamma: float = 0.99
-        alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
-        reward_normalization: bool = False
-        estimation_step: int = 1
-        deterministic_eval: bool = True
-        actor_lr: float = 1e-3
-        critic1_lr: float = 1e-3
-        critic2_lr: float = 1e-3
@@ -1,12 +1,10 @@
+from dataclasses import dataclass
 from pprint import pprint
 from typing import Generic, TypeVar

 import numpy as np
 import torch

-from tianshou.config import (
-    BasicExperimentConfig,
-)
 from tianshou.data import Collector
 from tianshou.highlevel.agent import AgentFactory
 from tianshou.highlevel.env import EnvFactory
@@ -18,10 +16,42 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)
 TTrainer = TypeVar("TTrainer", bound=BaseTrainer)


+@dataclass
+class RLExperimentConfig:
+    """Generic config for setting up the experiment, not RL or training specific."""
+
+    seed: int = 42
+    render: float | None = 0.0
+    """Milliseconds between rendered frames; if None, no rendering"""
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    resume_id: str | None = None
+    """For restoring a model and running means of env-specifics from a checkpoint"""
+    resume_path: str | None = None
+    """For restoring a model and running means of env-specifics from a checkpoint"""
+    watch: bool = False
+    """If True, will not perform training and only watch the restored policy"""
+    watch_num_episodes = 10
+
+
+@dataclass
+class RLSamplingConfig:
+    """Sampling, epochs, parallelization, buffers, collectors, and batching."""
+
+    num_epochs: int = 100
+    step_per_epoch: int = 30000
+    batch_size: int = 64
+    num_train_envs: int = 64
+    num_test_envs: int = 10
+    buffer_size: int = 4096
+    step_per_collect: int = 2048
+    repeat_per_collect: int = 10
+    update_per_step: int = 1
+
+
 class RLExperiment(Generic[TPolicy, TTrainer]):
     def __init__(
         self,
-        config: BasicExperimentConfig,
+        config: RLExperimentConfig,
         env_factory: EnvFactory,
         logger_factory: LoggerFactory,
         agent_factory: AgentFactory,
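Both configuration dataclasses are now importable from the experiment module alongside `RLExperiment`. A minimal construction sketch using only fields shown in this hunk; the specific values are placeholders:

    from tianshou.highlevel.experiment import RLExperimentConfig, RLSamplingConfig

    experiment_config = RLExperimentConfig(seed=0, watch=False)
    sampling_config = RLSamplingConfig(num_epochs=50, buffer_size=8192)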
|
@ -1,10 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.config import LoggerConfig
|
|
||||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||||
|
|
||||||
TLogger = TensorboardLogger | WandbLogger
|
TLogger = TensorboardLogger | WandbLogger
|
||||||
@@ -22,11 +22,21 @@ class LoggerFactory(ABC):
         pass


+@dataclass
+class LoggerConfig:
+    """Logging config."""
+
+    logdir: str = "log"
+    logger: Literal["tensorboard", "wandb"] = "tensorboard"
+    wandb_project: str = "mujoco.benchmark"
+    """Only used if logger is wandb."""
+
+
 class DefaultLoggerFactory(LoggerFactory):
     def __init__(self, config: LoggerConfig):
         self.config = config

-    def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
+    def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
         writer = SummaryWriter(self.config.logdir)
         writer.add_text("args", str(self.config))
         if self.config.logger == "wandb":
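`LoggerConfig` now lives next to the factory that consumes it, and `run_id` is typed as `str | None`, which matches string-valued wandb run ids. A usage sketch restricted to the names visible in this hunk; the log name and empty config dict are placeholders:

    from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig

    logger_config = LoggerConfig(logdir="log", logger="tensorboard")
    factory = DefaultLoggerFactory(logger_config)
    logger = factory.create_logger(log_name="Ant-v4/ppo/42", run_id=None, config_dict={})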
|
@ -66,7 +66,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
|
|||||||
actor = ActorProb(
|
actor = ActorProb(
|
||||||
net_a,
|
net_a,
|
||||||
envs.get_action_shape(),
|
envs.get_action_shape(),
|
||||||
unbounded=True,
|
unbounded=self.unbounded,
|
||||||
device=device,
|
device=device,
|
||||||
conditioned_sigma=self.conditioned_sigma,
|
conditioned_sigma=self.conditioned_sigma,
|
||||||
).to(device)
|
).to(device)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import Any, Type
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -8,7 +8,7 @@ from torch import Tensor
 from torch.optim import Adam
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler

-from tianshou.config import RLSamplingConfig
+from tianshou.highlevel.experiment import RLSamplingConfig

 TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
