Support IQN in high-level API

* Add example atari_iqn_hl
* Factor out trainer callbacks to new module atari_callbacks
* Extract base class for DQN-based agent factories
* Improve the module factory interface design, achieving higher generality
Dominik Jain 2023-10-11 15:31:38 +02:00
parent 213e08a846
commit a8a367c42d
12 changed files with 309 additions and 41 deletions

View File

@@ -0,0 +1,33 @@
from tianshou.highlevel.trainer import (
    TrainerEpochCallbackTest,
    TrainerEpochCallbackTrain,
    TrainingContext,
)
from tianshou.policy import DQNPolicy


class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
    def __init__(self, eps_test: float):
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy: DQNPolicy = context.policy
        policy.set_eps(self.eps_test)


class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
    def __init__(self, eps_train: float, eps_train_final: float):
        self.eps_train = eps_train
        self.eps_train_final = eps_train_final

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy: DQNPolicy = context.policy
        logger = context.logger.logger
        # Nature DQN setting: linear decay of epsilon over the first 1M steps
        if env_step <= 1e6:
            eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
        else:
            eps = self.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})
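A hedged usage sketch (not part of the diff): these callbacks are attached to an experiment builder via with_trainer_epoch_callback_train / with_trainer_epoch_callback_test, exactly as the new atari_iqn_hl example below does. The helper function and epsilon values here are illustrative only.

from examples.atari.atari_callbacks import (
    TestEpochCallbackDQNSetEps,
    TrainEpochCallbackNatureDQNEpsLinearDecay,
)


def add_eps_schedule(builder, eps_train=1.0, eps_train_final=0.05, eps_test=0.005):
    # Register the Nature-DQN linear epsilon decay for training epochs and a
    # fixed epsilon for test epochs on any DQN-style experiment builder that
    # exposes these two methods.
    return (
        builder.with_trainer_epoch_callback_train(
            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
        ).with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
    )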

View File

@@ -1,13 +1,12 @@
 #!/usr/bin/env python3
-import datetime
 import os
 from jsonargparse import CLI
 from examples.atari.atari_network import (
     ActorFactoryAtariPlainDQN,
-    FeatureNetFactoryDQN,
+    IntermediateModuleFactoryAtariDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import SamplingConfig
@@ -26,6 +25,7 @@ from tianshou.highlevel.trainer import (
 )
 from tianshou.policy import DQNPolicy
 from tianshou.utils import logging
+from tianshou.utils.logging import datetime_tag
 def main(
@@ -53,8 +53,7 @@ def main(
     icm_reward_scale: float = 0.01,
     icm_forward_loss_weight: float = 0.2,
 ):
-    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "dqn", str(experiment_config.seed), datetime_tag())
     sampling_config = SamplingConfig(
         num_epochs=epoch,
@@ -115,7 +114,7 @@ def main(
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
             PolicyWrapperFactoryIntrinsicCuriosity(
-                FeatureNetFactoryDQN(),
+                IntermediateModuleFactoryAtariDQN(net_only=True),
                 [512],
                 lr,
                 icm_lr_scale,

View File

@@ -0,0 +1,105 @@
#!/usr/bin/env python3
import os
from collections.abc import Sequence

from jsonargparse import CLI

from examples.atari.atari_callbacks import (
    TestEpochCallbackDQNSetEps,
    TrainEpochCallbackNatureDQNEpsLinearDecay,
)
from examples.atari.atari_network import (
    IntermediateModuleFactoryAtariDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
    ExperimentConfig,
    IQNExperimentBuilder,
)
from tianshou.highlevel.params.policy_params import IQNParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
    experiment_config: ExperimentConfig,
    task: str = "PongNoFrameskip-v4",
    scale_obs: int = 0,
    eps_test: float = 0.005,
    eps_train: float = 1.0,
    eps_train_final: float = 0.05,
    buffer_size: int = 100000,
    lr: float = 0.0001,
    gamma: float = 0.99,
    sample_size: int = 32,
    online_sample_size: int = 8,
    target_sample_size: int = 8,
    num_cosines: int = 64,
    hidden_sizes: Sequence[int] = (512,),
    n_step: int = 3,
    target_update_freq: int = 500,
    epoch: int = 100,
    step_per_epoch: int = 100000,
    step_per_collect: int = 10,
    update_per_step: float = 0.1,
    batch_size: int = 32,
    training_num: int = 10,
    test_num: int = 10,
    frames_stack: int = 4,
    save_buffer_name: str | None = None,  # TODO support?
):
    log_name = os.path.join(task, "iqn", str(experiment_config.seed), datetime_tag())

    sampling_config = SamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        update_per_step=update_per_step,
        repeat_per_collect=None,
        replay_buffer_stack_num=frames_stack,
        replay_buffer_ignore_obs_next=True,
        replay_buffer_save_only_last_obs=True,
    )

    env_factory = AtariEnvFactory(
        task,
        experiment_config.seed,
        sampling_config,
        frames_stack,
        scale=scale_obs,
    )

    experiment = (
        IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_iqn_params(
            IQNParams(
                discount_factor=gamma,
                estimation_step=n_step,
                lr=lr,
                sample_size=sample_size,
                online_sample_size=online_sample_size,
                target_update_freq=target_update_freq,
                target_sample_size=target_sample_size,
                hidden_sizes=hidden_sizes,
                num_cosines=num_cosines,
            ),
        )
        .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(net_only=False))
        .with_trainer_epoch_callback_train(
            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
        )
        .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
        .with_trainer_stop_callback(AtariStopCallback(task))
        .build()
    )
    experiment.run(log_name)


if __name__ == "__main__":
    logging.run_main(lambda: CLI(main))
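The script is exposed through jsonargparse's CLI, so it is normally run from the command line. As a rough, hedged sketch, it can also be invoked programmatically; this assumes ExperimentConfig is constructible with its defaults, and the values are illustrative only.

from examples.atari.atari_iqn_hl import main
from tianshou.highlevel.experiment import ExperimentConfig

# Hedged sketch: bypass the CLI wrapper and call the example directly with a
# small budget; any keyword from main's signature above can be overridden.
main(ExperimentConfig(), task="PongNoFrameskip-v4", epoch=2, step_per_epoch=1000)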

View File

@@ -7,7 +7,11 @@ from torch import nn
 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.module.actor import ActorFactory
-from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
+from tianshou.highlevel.module.core import (
+    IntermediateModule,
+    IntermediateModuleFactory,
+    TDevice,
+)
 from tianshou.utils.net.discrete import Actor, NoisyLinear
@@ -253,12 +257,15 @@ class ActorFactoryAtariDQN(ActorFactory):
         return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)


-class FeatureNetFactoryDQN(ModuleFactory):
-    def create_module(self, envs: Environments, device: TDevice) -> Module:
+class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
+    def __init__(self, net_only: bool):
+        self.net_only = net_only
+
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
         dqn = DQN(
             *envs.get_observation_shape(),
             envs.get_action_shape(),
-            device,
+            device=device,
             features_only=True,
         )
-        return Module(dqn.net, dqn.output_dim)
+        return IntermediateModule(dqn.net if self.net_only else dqn, dqn.output_dim)

View File

@@ -8,7 +8,7 @@ from jsonargparse import CLI
 from examples.atari.atari_network import (
     ActorFactoryAtariDQN,
-    FeatureNetFactoryDQN,
+    IntermediateModuleFactoryAtariDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import SamplingConfig
@@ -103,7 +103,7 @@ def main(
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
             PolicyWrapperFactoryIntrinsicCuriosity(
-                FeatureNetFactoryDQN(),
+                IntermediateModuleFactoryAtariDQN(net_only=True),
                 [hidden_sizes],
                 lr,
                 icm_lr_scale,

View File

@@ -6,7 +6,7 @@ from jsonargparse import CLI
 from examples.atari.atari_network import (
     ActorFactoryAtariDQN,
-    FeatureNetFactoryDQN,
+    IntermediateModuleFactoryAtariDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import SamplingConfig
@@ -26,7 +26,7 @@ from tianshou.utils.logging import datetime_tag
 def main(
     experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
-    scale_obs: bool = False,
+    scale_obs: int = 0,
     buffer_size: int = 100000,
     actor_lr: float = 1e-5,
     critic_lr: float = 1e-5,
@@ -67,7 +67,9 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )
-    env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
+    env_factory = AtariEnvFactory(
+        task, experiment_config.seed, sampling_config, frames_stack, scale=scale_obs,
+    )
     builder = (
         DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -82,14 +84,14 @@ def main(
                 estimation_step=n_step,
             ),
         )
-        .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs, features_only=True))
+        .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True))
         .with_common_critic_factory_use_actor()
         .with_trainer_stop_callback(AtariStopCallback(task))
     )
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
             PolicyWrapperFactoryIntrinsicCuriosity(
-                FeatureNetFactoryDQN(),
+                IntermediateModuleFactoryAtariDQN(net_only=True),
                 [hidden_size],
                 actor_lr,
                 icm_lr_scale,

View File

@@ -13,7 +13,10 @@ from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
 )
-from tianshou.highlevel.module.core import TDevice
+from tianshou.highlevel.module.core import (
+    ModuleFactory,
+    TDevice,
+)
 from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory
 from tianshou.highlevel.module.module_opt import (
     ActorCriticModuleOpt,
@@ -24,6 +27,7 @@ from tianshou.highlevel.params.policy_params import (
     DDPGParams,
     DiscreteSACParams,
     DQNParams,
+    IQNParams,
     NPGParams,
     Params,
     ParamsMixinActorAndDualCritics,
@@ -44,6 +48,7 @@ from tianshou.policy import (
     DDPGPolicy,
     DiscreteSACPolicy,
     DQNPolicy,
+    IQNPolicy,
     NPGPolicy,
     PGPolicy,
     PPOPolicy,
@@ -61,6 +66,9 @@ CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
 TParams = TypeVar("TParams", bound=Params)
 TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
 TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
+TDiscreteCriticOnlyParams = TypeVar(
+    "TDiscreteCriticOnlyParams", bound=ParamsMixinLearningRateWithScheduler,
+)
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@@ -394,21 +402,27 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
         return TRPOPolicy


-class DQNAgentFactory(OffpolicyAgentFactory):
+class DiscreteCriticOnlyAgentFactory(
+    OffpolicyAgentFactory, Generic[TDiscreteCriticOnlyParams, TPolicy],
+):
     def __init__(
         self,
-        params: DQNParams,
+        params: TDiscreteCriticOnlyParams,
         sampling_config: SamplingConfig,
-        actor_factory: ActorFactory,
+        model_factory: ModuleFactory,
         optim_factory: OptimizerFactory,
     ):
         super().__init__(sampling_config, optim_factory)
         self.params = params
-        self.actor_factory = actor_factory
+        self.model_factory = model_factory
         self.optim_factory = optim_factory

-    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
-        model = self.actor_factory.create_module(envs, device)
+    @abstractmethod
+    def _get_policy_class(self) -> type[TPolicy]:
+        pass
+
+    def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
+        model = self.model_factory.create_module(envs, device)
         optim = self.optim_factory.create_optimizer(model, self.params.lr)
         kwargs = self.params.create_kwargs(
             ParamTransformerData(
@@ -420,7 +434,8 @@ class DQNAgentFactory(OffpolicyAgentFactory):
         )
         envs.get_type().assert_discrete(self)
         action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space())
-        return DQNPolicy(
+        policy_class = self._get_policy_class()
+        return policy_class(
             model=model,
             optim=optim,
             action_space=action_space,
@@ -429,6 +444,16 @@ class DQNAgentFactory(OffpolicyAgentFactory):
         )


+class DQNAgentFactory(DiscreteCriticOnlyAgentFactory[DQNParams, DQNPolicy]):
+    def _get_policy_class(self) -> type[DQNPolicy]:
+        return DQNPolicy
+
+
+class IQNAgentFactory(DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]):
+    def _get_policy_class(self) -> type[IQNPolicy]:
+        return IQNPolicy
+
+
 class DDPGAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
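To illustrate what the extracted base class buys (a hedged sketch, not part of this commit): adding another discrete critic-only algorithm reduces to naming its policy class. QRDQNPolicy exists in tianshou.policy, but the QRDQNParams class assumed here is hypothetical and would still need to be defined alongside the other high-level parameter dataclasses.

from tianshou.policy import QRDQNPolicy


class QRDQNAgentFactory(DiscreteCriticOnlyAgentFactory[QRDQNParams, QRDQNPolicy]):
    # Hypothetical sketch: QRDQNParams is assumed, not defined in this commit.
    # Model creation, optimizer setup and kwargs transformation are inherited
    # from DiscreteCriticOnlyAgentFactory._create_policy.
    def _get_policy_class(self) -> type[QRDQNPolicy]:
        return QRDQNPolicy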

View File

@@ -15,6 +15,7 @@ from tianshou.highlevel.agent import (
     DDPGAgentFactory,
     DiscreteSACAgentFactory,
     DQNAgentFactory,
+    IQNAgentFactory,
     NPGAgentFactory,
     PGAgentFactory,
     PPOAgentFactory,
@@ -33,6 +34,11 @@ from tianshou.highlevel.module.actor import (
     ActorFuture,
     ActorFutureProviderProtocol,
     ContinuousActorType,
+    IntermediateModuleFactoryFromActorFactory,
+)
+from tianshou.highlevel.module.core import (
+    ImplicitQuantileNetworkFactory,
+    IntermediateModuleFactory,
 )
 from tianshou.highlevel.module.critic import (
     CriticEnsembleFactory,
@@ -47,6 +53,7 @@ from tianshou.highlevel.params.policy_params import (
     DDPGParams,
     DiscreteSACParams,
     DQNParams,
+    IQNParams,
     NPGParams,
     PGParams,
     PPOParams,
@@ -641,6 +648,41 @@ class DQNExperimentBuilder(
     )


+class IQNExperimentBuilder(ExperimentBuilder):
+    def __init__(
+        self,
+        env_factory: EnvFactory,
+        experiment_config: ExperimentConfig | None = None,
+        sampling_config: SamplingConfig | None = None,
+    ):
+        super().__init__(env_factory, experiment_config, sampling_config)
+        self._params: IQNParams = IQNParams()
+        self._preprocess_network_factory = IntermediateModuleFactoryFromActorFactory(
+            ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
+        )
+
+    def with_iqn_params(self, params: IQNParams) -> Self:
+        self._params = params
+        return self
+
+    def with_preprocess_network_factory(self, module_factory: IntermediateModuleFactory) -> Self:
+        self._preprocess_network_factory = module_factory
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        model_factory = ImplicitQuantileNetworkFactory(
+            self._preprocess_network_factory,
+            hidden_sizes=self._params.hidden_sizes,
+            num_cosines=self._params.num_cosines,
+        )
+        return IQNAgentFactory(
+            self._params,
+            self._sampling_config,
+            model_factory,
+            self._get_optim_factory(),
+        )
+
+
 class DDPGExperimentBuilder(
     ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousDeterministic,

View File

@@ -8,7 +8,13 @@ import torch
 from torch import nn
 from tianshou.highlevel.env import Environments, EnvType
-from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
+from tianshou.highlevel.module.core import (
+    IntermediateModule,
+    IntermediateModuleFactory,
+    ModuleFactory,
+    TDevice,
+    init_linear_orthogonal,
+)
 from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net import continuous, discrete
@@ -34,7 +40,7 @@ class ActorFutureProviderProtocol(Protocol):
     pass


-class ActorFactory(ToStringMixin, ABC):
+class ActorFactory(ModuleFactory, ToStringMixin, ABC):
     @abstractmethod
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
         pass
@@ -212,3 +218,13 @@ class ActorFactoryTransientStorageDecorator(ActorFactory):
         module = self.actor_factory.create_module(envs, device)
         self._actor_future.actor = module
         return module
+
+
+class IntermediateModuleFactoryFromActorFactory(IntermediateModuleFactory):
+    def __init__(self, actor_factory: ActorFactory):
+        self.actor_factory = actor_factory
+
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
+        actor = self.actor_factory.create_module(envs, device)
+        assert isinstance(actor, BaseActor)
+        return IntermediateModule(actor, actor.get_output_dim())
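A hedged sketch of why the adapter exists (it mirrors the default set in IQNExperimentBuilder.__init__ above; the helper function is illustrative, not part of the commit, and the import locations are assumed): it lets a plain ActorFactory stand in wherever an IntermediateModuleFactory is required, i.e. wherever both the module and its output dimension are needed.

from tianshou.highlevel.module.actor import (
    ActorFactoryDefault,
    ContinuousActorType,
    IntermediateModuleFactoryFromActorFactory,
)


def make_default_preprocess_net(envs, device):
    # Wrap the default (discrete) actor factory so it satisfies the
    # IntermediateModuleFactory interface, as IQNExperimentBuilder does by default.
    factory = IntermediateModuleFactoryFromActorFactory(
        ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
    )
    intermediate = factory.create_intermediate_module(envs, device)
    return intermediate.module, intermediate.output_dim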

View File

@@ -7,7 +7,7 @@ import numpy as np
 import torch
 from tianshou.highlevel.env import Environments
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.discrete import ImplicitQuantileNetwork
 from tianshou.utils.string import ToStringMixin
 TDevice: TypeAlias = str | torch.device
@@ -24,24 +24,45 @@ def init_linear_orthogonal(module: torch.nn.Module) -> None:
         torch.nn.init.zeros_(m.bias)


+class ModuleFactory(ABC):
+    @abstractmethod
+    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
+        pass
+
+
 @dataclass
-class Module:
+class IntermediateModule:
     module: torch.nn.Module
     output_dim: int


-class ModuleFactory(ToStringMixin, ABC):
+class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
     @abstractmethod
-    def create_module(self, envs: Environments, device: TDevice) -> Module:
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
         pass

+    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
+        return self.create_intermediate_module(envs, device).module
+

-class ModuleFactoryNet(
-    ModuleFactory,
-):  # TODO This is unused and broken; use it in ActorFactory* and so on?
-    def __init__(self, hidden_sizes: int | Sequence[int]):
+class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
+    def __init__(
+        self,
+        preprocess_net_factory: IntermediateModuleFactory,
+        hidden_sizes: Sequence[int] = (),
+        num_cosines: int = 64,
+    ):
+        self.preprocess_net_factory = preprocess_net_factory
         self.hidden_sizes = hidden_sizes
+        self.num_cosines = num_cosines

-    def create_module(self, envs: Environments, device: TDevice) -> Module:
-        module = Net(envs.get_observation_shape())
-        return Module(module, module.output_dim)
+    def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
+        preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
+        return ImplicitQuantileNetwork(
+            preprocess_net=preprocess_net.module,
+            action_shape=envs.get_action_shape(),
+            hidden_sizes=self.hidden_sizes,
+            num_cosines=self.num_cosines,
+            preprocess_net_output_dim=preprocess_net.output_dim,
+            device=device,
+        ).to(device)
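The split between ModuleFactory (module only) and IntermediateModuleFactory (module plus output dimension) is what gives the interface its generality: any custom feature net can be plugged into ImplicitQuantileNetworkFactory or PolicyWrapperFactoryIntrinsicCuriosity. A hedged sketch, not part of this commit; the MLP factory below is purely illustrative.

import numpy as np
import torch

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import (
    IntermediateModule,
    IntermediateModuleFactory,
    TDevice,
)


class IntermediateModuleFactoryFlattenMLP(IntermediateModuleFactory):
    # Illustrative custom factory: a flatten-and-project feature extractor that
    # reports its output dimension, as the interface requires.
    def __init__(self, hidden_dim: int = 256):
        self.hidden_dim = hidden_dim

    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
        input_dim = int(np.prod(envs.get_observation_shape()))
        net = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(input_dim, self.hidden_dim),
            torch.nn.ReLU(),
        ).to(device)
        return IntermediateModule(net, self.hidden_dim)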

View File

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from dataclasses import asdict, dataclass
 from typing import Any, Literal, Protocol
@@ -388,6 +389,23 @@ class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
         return transformers


+@dataclass
+class IQNParams(DQNParams):
+    sample_size: int = 32
+    online_sample_size: int = 8
+    target_sample_size: int = 8
+    num_quantiles: int = 200
+    hidden_sizes: Sequence[int] = ()
+    """hidden dimensions to use in the IQN network"""
+    num_cosines: int = 64
+    """number of cosines to use in the IQN network"""
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.append(ParamTransformerDrop("hidden_sizes", "num_cosines"))
+        return transformers
+
+
 @dataclass
 class DDPGParams(Params, ParamsMixinActorAndCritic):
     tau: float = 0.005

View File

@@ -3,7 +3,7 @@ from collections.abc import Sequence
 from typing import Generic, TypeVar
 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.module.core import ModuleFactory, TDevice
+from tianshou.highlevel.module.core import IntermediateModuleFactory, TDevice
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.policy import BasePolicy, ICMPolicy
 from tianshou.utils.net.discrete import IntrinsicCuriosityModule
@@ -31,7 +31,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
 ):
     def __init__(
         self,
-        feature_net_factory: ModuleFactory,
+        feature_net_factory: IntermediateModuleFactory,
         hidden_sizes: Sequence[int],
         lr: float,
         lr_scale: float,
@@ -52,7 +52,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
         optim_factory: OptimizerFactory,
         device: TDevice,
     ) -> ICMPolicy:
-        feature_net = self.feature_net_factory.create_module(envs, device)
+        feature_net = self.feature_net_factory.create_intermediate_module(envs, device)
         action_dim = envs.get_action_shape()
         if not isinstance(action_dim, int):
             raise ValueError(f"Environment action shape must be an integer, got {action_dim}")