Support IQN in high-level API
* Add example atari_iqn_hl
* Factor out trainer callbacks to new module atari_callbacks
* Extract base class for DQN-based agent factories
* Improved module factory interface design, achieving higher generality
This commit is contained in:
parent 213e08a846
commit a8a367c42d
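In essence, the changes below let an IQN agent be assembled through the high-level experiment builder. As a rough editorial sketch (not part of the commit; condensed from the new examples/atari/atari_iqn_hl.py further down, with illustrative values, and assuming ExperimentConfig and the remaining SamplingConfig fields can rely on their defaults), usage looks roughly like this:

# Editorial sketch only; condensed from examples/atari/atari_iqn_hl.py in this commit.
from examples.atari.atari_network import IntermediateModuleFactoryAtariDQN
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ExperimentConfig, IQNExperimentBuilder
from tianshou.highlevel.params.policy_params import IQNParams

task = "PongNoFrameskip-v4"
experiment_config = ExperimentConfig()  # assumed to be default-constructible
sampling_config = SamplingConfig(num_epochs=100, step_per_epoch=100000, batch_size=32)
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, 4, scale=0)

experiment = (
    IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
    .with_iqn_params(IQNParams(discount_factor=0.99, estimation_step=3, lr=0.0001))
    .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(net_only=False))
    .with_trainer_stop_callback(AtariStopCallback(task))
    .build()
)
experiment.run("PongNoFrameskip-v4/iqn/0")  # illustrative log_name; the example derives it from task/seed/timestamp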
examples/atari/atari_callbacks.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from tianshou.highlevel.trainer import (
    TrainerEpochCallbackTest,
    TrainerEpochCallbackTrain,
    TrainingContext,
)
from tianshou.policy import DQNPolicy


class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
    def __init__(self, eps_test: float):
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy: DQNPolicy = context.policy
        policy.set_eps(self.eps_test)


class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
    def __init__(self, eps_train: float, eps_train_final: float):
        self.eps_train = eps_train
        self.eps_train_final = eps_train_final

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy: DQNPolicy = context.policy
        logger = context.logger.logger
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
        else:
            eps = self.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})
@@ -1,13 +1,12 @@
 #!/usr/bin/env python3

-import datetime
 import os

 from jsonargparse import CLI

 from examples.atari.atari_network import (
     ActorFactoryAtariPlainDQN,
-    FeatureNetFactoryDQN,
+    IntermediateModuleFactoryAtariDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import SamplingConfig
@@ -26,6 +25,7 @@ from tianshou.highlevel.trainer import (
 )
 from tianshou.policy import DQNPolicy
 from tianshou.utils import logging
+from tianshou.utils.logging import datetime_tag


 def main(
@@ -53,8 +53,7 @@ def main(
     icm_reward_scale: float = 0.01,
     icm_forward_loss_weight: float = 0.2,
 ):
-    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "dqn", str(experiment_config.seed), datetime_tag())

     sampling_config = SamplingConfig(
         num_epochs=epoch,
@@ -115,7 +114,7 @@ def main(
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
             PolicyWrapperFactoryIntrinsicCuriosity(
-                FeatureNetFactoryDQN(),
+                IntermediateModuleFactoryAtariDQN(net_only=True),
                 [512],
                 lr,
                 icm_lr_scale,
examples/atari/atari_iqn_hl.py (new file, 105 lines)
@@ -0,0 +1,105 @@
#!/usr/bin/env python3

import os
from collections.abc import Sequence

from jsonargparse import CLI

from examples.atari.atari_callbacks import (
    TestEpochCallbackDQNSetEps,
    TrainEpochCallbackNatureDQNEpsLinearDecay,
)
from examples.atari.atari_network import (
    IntermediateModuleFactoryAtariDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
    ExperimentConfig,
    IQNExperimentBuilder,
)
from tianshou.highlevel.params.policy_params import IQNParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
    experiment_config: ExperimentConfig,
    task: str = "PongNoFrameskip-v4",
    scale_obs: int = 0,
    eps_test: float = 0.005,
    eps_train: float = 1.0,
    eps_train_final: float = 0.05,
    buffer_size: int = 100000,
    lr: float = 0.0001,
    gamma: float = 0.99,
    sample_size: int = 32,
    online_sample_size: int = 8,
    target_sample_size: int = 8,
    num_cosines: int = 64,
    hidden_sizes: Sequence[int] = (512,),
    n_step: int = 3,
    target_update_freq: int = 500,
    epoch: int = 100,
    step_per_epoch: int = 100000,
    step_per_collect: int = 10,
    update_per_step: float = 0.1,
    batch_size: int = 32,
    training_num: int = 10,
    test_num: int = 10,
    frames_stack: int = 4,
    save_buffer_name: str | None = None,  # TODO support?
):
    log_name = os.path.join(task, "iqn", str(experiment_config.seed), datetime_tag())

    sampling_config = SamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        update_per_step=update_per_step,
        repeat_per_collect=None,
        replay_buffer_stack_num=frames_stack,
        replay_buffer_ignore_obs_next=True,
        replay_buffer_save_only_last_obs=True,
    )

    env_factory = AtariEnvFactory(
        task,
        experiment_config.seed,
        sampling_config,
        frames_stack,
        scale=scale_obs,
    )

    experiment = (
        IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_iqn_params(
            IQNParams(
                discount_factor=gamma,
                estimation_step=n_step,
                lr=lr,
                sample_size=sample_size,
                online_sample_size=online_sample_size,
                target_update_freq=target_update_freq,
                target_sample_size=target_sample_size,
                hidden_sizes=hidden_sizes,
                num_cosines=num_cosines,
            ),
        )
        .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(net_only=False))
        .with_trainer_epoch_callback_train(
            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
        )
        .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
        .with_trainer_stop_callback(AtariStopCallback(task))
        .build()
    )
    experiment.run(log_name)


if __name__ == "__main__":
    logging.run_main(lambda: CLI(main))
@@ -7,7 +7,11 @@ from torch import nn

 from tianshou.highlevel.env import Environments
 from tianshou.highlevel.module.actor import ActorFactory
-from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
+from tianshou.highlevel.module.core import (
+    IntermediateModule,
+    IntermediateModuleFactory,
+    TDevice,
+)
 from tianshou.utils.net.discrete import Actor, NoisyLinear


@@ -253,12 +257,15 @@ class ActorFactoryAtariDQN(ActorFactory):
         return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)


-class FeatureNetFactoryDQN(ModuleFactory):
-    def create_module(self, envs: Environments, device: TDevice) -> Module:
+class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
+    def __init__(self, net_only: bool):
+        self.net_only = net_only
+
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
         dqn = DQN(
             *envs.get_observation_shape(),
             envs.get_action_shape(),
-            device,
+            device=device,
             features_only=True,
         )
-        return Module(dqn.net, dqn.output_dim)
+        return IntermediateModule(dqn.net if self.net_only else dqn, dqn.output_dim)
@@ -8,7 +8,7 @@ from jsonargparse import CLI

 from examples.atari.atari_network import (
     ActorFactoryAtariDQN,
-    FeatureNetFactoryDQN,
+    IntermediateModuleFactoryAtariDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import SamplingConfig
@@ -103,7 +103,7 @@ def main(
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
             PolicyWrapperFactoryIntrinsicCuriosity(
-                FeatureNetFactoryDQN(),
+                IntermediateModuleFactoryAtariDQN(net_only=True),
                 [hidden_sizes],
                 lr,
                 icm_lr_scale,
@@ -6,7 +6,7 @@ from jsonargparse import CLI

 from examples.atari.atari_network import (
     ActorFactoryAtariDQN,
-    FeatureNetFactoryDQN,
+    IntermediateModuleFactoryAtariDQN,
 )
 from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
 from tianshou.highlevel.config import SamplingConfig
@@ -26,7 +26,7 @@ from tianshou.utils.logging import datetime_tag
 def main(
     experiment_config: ExperimentConfig,
     task: str = "PongNoFrameskip-v4",
-    scale_obs: bool = False,
+    scale_obs: int = 0,
     buffer_size: int = 100000,
     actor_lr: float = 1e-5,
     critic_lr: float = 1e-5,
@@ -67,7 +67,9 @@ def main(
         replay_buffer_save_only_last_obs=True,
     )

-    env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
+    env_factory = AtariEnvFactory(
+        task, experiment_config.seed, sampling_config, frames_stack, scale=scale_obs,
+    )

     builder = (
         DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
@@ -82,14 +84,14 @@ def main(
                 estimation_step=n_step,
             ),
         )
-        .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs, features_only=True))
+        .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True))
         .with_common_critic_factory_use_actor()
         .with_trainer_stop_callback(AtariStopCallback(task))
     )
     if icm_lr_scale > 0:
         builder.with_policy_wrapper_factory(
             PolicyWrapperFactoryIntrinsicCuriosity(
-                FeatureNetFactoryDQN(),
+                IntermediateModuleFactoryAtariDQN(net_only=True),
                 [hidden_size],
                 actor_lr,
                 icm_lr_scale,
@@ -13,7 +13,10 @@ from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
 )
-from tianshou.highlevel.module.core import TDevice
+from tianshou.highlevel.module.core import (
+    ModuleFactory,
+    TDevice,
+)
 from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory
 from tianshou.highlevel.module.module_opt import (
     ActorCriticModuleOpt,
@@ -24,6 +27,7 @@ from tianshou.highlevel.params.policy_params import (
     DDPGParams,
     DiscreteSACParams,
     DQNParams,
+    IQNParams,
     NPGParams,
     Params,
     ParamsMixinActorAndDualCritics,
@@ -44,6 +48,7 @@ from tianshou.policy import (
     DDPGPolicy,
     DiscreteSACPolicy,
     DQNPolicy,
+    IQNPolicy,
     NPGPolicy,
     PGPolicy,
     PPOPolicy,
@@ -61,6 +66,9 @@ CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
 TParams = TypeVar("TParams", bound=Params)
 TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
 TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
+TDiscreteCriticOnlyParams = TypeVar(
+    "TDiscreteCriticOnlyParams", bound=ParamsMixinLearningRateWithScheduler,
+)
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)


@@ -394,21 +402,27 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
         return TRPOPolicy


-class DQNAgentFactory(OffpolicyAgentFactory):
+class DiscreteCriticOnlyAgentFactory(
+    OffpolicyAgentFactory, Generic[TDiscreteCriticOnlyParams, TPolicy],
+):
     def __init__(
         self,
-        params: DQNParams,
+        params: TDiscreteCriticOnlyParams,
         sampling_config: SamplingConfig,
-        actor_factory: ActorFactory,
+        model_factory: ModuleFactory,
         optim_factory: OptimizerFactory,
     ):
         super().__init__(sampling_config, optim_factory)
         self.params = params
-        self.actor_factory = actor_factory
+        self.model_factory = model_factory
         self.optim_factory = optim_factory

-    def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
-        model = self.actor_factory.create_module(envs, device)
+    @abstractmethod
+    def _get_policy_class(self) -> type[TPolicy]:
+        pass
+
+    def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
+        model = self.model_factory.create_module(envs, device)
         optim = self.optim_factory.create_optimizer(model, self.params.lr)
         kwargs = self.params.create_kwargs(
             ParamTransformerData(
@@ -420,7 +434,8 @@ class DQNAgentFactory(OffpolicyAgentFactory):
         )
         envs.get_type().assert_discrete(self)
         action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space())
-        return DQNPolicy(
+        policy_class = self._get_policy_class()
+        return policy_class(
             model=model,
             optim=optim,
             action_space=action_space,
@@ -429,6 +444,16 @@ class DQNAgentFactory(OffpolicyAgentFactory):
         )


+class DQNAgentFactory(DiscreteCriticOnlyAgentFactory[DQNParams, DQNPolicy]):
+    def _get_policy_class(self) -> type[DQNPolicy]:
+        return DQNPolicy
+
+
+class IQNAgentFactory(DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]):
+    def _get_policy_class(self) -> type[IQNPolicy]:
+        return IQNPolicy
+
+
 class DDPGAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
@@ -15,6 +15,7 @@ from tianshou.highlevel.agent import (
     DDPGAgentFactory,
     DiscreteSACAgentFactory,
     DQNAgentFactory,
+    IQNAgentFactory,
     NPGAgentFactory,
     PGAgentFactory,
     PPOAgentFactory,
@@ -33,6 +34,11 @@ from tianshou.highlevel.module.actor import (
     ActorFuture,
     ActorFutureProviderProtocol,
     ContinuousActorType,
+    IntermediateModuleFactoryFromActorFactory,
 )
+from tianshou.highlevel.module.core import (
+    ImplicitQuantileNetworkFactory,
+    IntermediateModuleFactory,
+)
 from tianshou.highlevel.module.critic import (
     CriticEnsembleFactory,
@@ -47,6 +53,7 @@ from tianshou.highlevel.params.policy_params import (
     DDPGParams,
     DiscreteSACParams,
     DQNParams,
+    IQNParams,
     NPGParams,
     PGParams,
     PPOParams,
@@ -641,6 +648,41 @@ class DQNExperimentBuilder(
         )


+class IQNExperimentBuilder(ExperimentBuilder):
+    def __init__(
+        self,
+        env_factory: EnvFactory,
+        experiment_config: ExperimentConfig | None = None,
+        sampling_config: SamplingConfig | None = None,
+    ):
+        super().__init__(env_factory, experiment_config, sampling_config)
+        self._params: IQNParams = IQNParams()
+        self._preprocess_network_factory = IntermediateModuleFactoryFromActorFactory(
+            ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
+        )
+
+    def with_iqn_params(self, params: IQNParams) -> Self:
+        self._params = params
+        return self
+
+    def with_preprocess_network_factory(self, module_factory: IntermediateModuleFactory) -> Self:
+        self._preprocess_network_factory = module_factory
+        return self
+
+    def _create_agent_factory(self) -> AgentFactory:
+        model_factory = ImplicitQuantileNetworkFactory(
+            self._preprocess_network_factory,
+            hidden_sizes=self._params.hidden_sizes,
+            num_cosines=self._params.num_cosines,
+        )
+        return IQNAgentFactory(
+            self._params,
+            self._sampling_config,
+            model_factory,
+            self._get_optim_factory(),
+        )
+
+
 class DDPGExperimentBuilder(
     ExperimentBuilder,
     _BuilderMixinActorFactory_ContinuousDeterministic,
@@ -8,7 +8,13 @@ import torch
 from torch import nn

 from tianshou.highlevel.env import Environments, EnvType
-from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
+from tianshou.highlevel.module.core import (
+    IntermediateModule,
+    IntermediateModuleFactory,
+    ModuleFactory,
+    TDevice,
+    init_linear_orthogonal,
+)
 from tianshou.highlevel.module.module_opt import ModuleOpt
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.utils.net import continuous, discrete
@@ -34,7 +40,7 @@ class ActorFutureProviderProtocol(Protocol):
     pass


-class ActorFactory(ToStringMixin, ABC):
+class ActorFactory(ModuleFactory, ToStringMixin, ABC):
     @abstractmethod
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
         pass
@@ -212,3 +218,13 @@ class ActorFactoryTransientStorageDecorator(ActorFactory):
         module = self.actor_factory.create_module(envs, device)
         self._actor_future.actor = module
         return module
+
+
+class IntermediateModuleFactoryFromActorFactory(IntermediateModuleFactory):
+    def __init__(self, actor_factory: ActorFactory):
+        self.actor_factory = actor_factory
+
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
+        actor = self.actor_factory.create_module(envs, device)
+        assert isinstance(actor, BaseActor)
+        return IntermediateModule(actor, actor.get_output_dim())
@@ -7,7 +7,7 @@ import numpy as np
 import torch

 from tianshou.highlevel.env import Environments
-from tianshou.utils.net.common import Net
+from tianshou.utils.net.discrete import ImplicitQuantileNetwork
 from tianshou.utils.string import ToStringMixin

 TDevice: TypeAlias = str | torch.device
@@ -24,24 +24,45 @@ def init_linear_orthogonal(module: torch.nn.Module) -> None:
         torch.nn.init.zeros_(m.bias)


+class ModuleFactory(ABC):
+    @abstractmethod
+    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
+        pass
+
+
 @dataclass
-class Module:
+class IntermediateModule:
     module: torch.nn.Module
     output_dim: int


-class ModuleFactory(ToStringMixin, ABC):
+class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
     @abstractmethod
-    def create_module(self, envs: Environments, device: TDevice) -> Module:
+    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
         pass

+    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
+        return self.create_intermediate_module(envs, device).module
+

-class ModuleFactoryNet(
-    ModuleFactory,
-):  # TODO This is unused and broken; use it in ActorFactory* and so on?
-    def __init__(self, hidden_sizes: int | Sequence[int]):
+class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
+    def __init__(
+        self,
+        preprocess_net_factory: IntermediateModuleFactory,
+        hidden_sizes: Sequence[int] = (),
+        num_cosines: int = 64,
+    ):
+        self.preprocess_net_factory = preprocess_net_factory
+        self.hidden_sizes = hidden_sizes
+        self.num_cosines = num_cosines

-    def create_module(self, envs: Environments, device: TDevice) -> Module:
-        module = Net(envs.get_observation_shape())
-        return Module(module, module.output_dim)
+    def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
+        preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
+        return ImplicitQuantileNetwork(
+            preprocess_net=preprocess_net.module,
+            action_shape=envs.get_action_shape(),
+            hidden_sizes=self.hidden_sizes,
+            num_cosines=self.num_cosines,
+            preprocess_net_output_dim=preprocess_net.output_dim,
+            device=device,
+        ).to(device)
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from dataclasses import asdict, dataclass
 from typing import Any, Literal, Protocol

@@ -388,6 +389,23 @@ class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
         return transformers


+@dataclass
+class IQNParams(DQNParams):
+    sample_size: int = 32
+    online_sample_size: int = 8
+    target_sample_size: int = 8
+    num_quantiles: int = 200
+    hidden_sizes: Sequence[int] = ()
+    """hidden dimensions to use in the IQN network"""
+    num_cosines: int = 64
+    """number of cosines to use in the IQN network"""
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.append(ParamTransformerDrop("hidden_sizes", "num_cosines"))
+        return transformers
+
+
 @dataclass
 class DDPGParams(Params, ParamsMixinActorAndCritic):
     tau: float = 0.005
@@ -3,7 +3,7 @@ from collections.abc import Sequence
 from typing import Generic, TypeVar

 from tianshou.highlevel.env import Environments
-from tianshou.highlevel.module.core import ModuleFactory, TDevice
+from tianshou.highlevel.module.core import IntermediateModuleFactory, TDevice
 from tianshou.highlevel.optim import OptimizerFactory
 from tianshou.policy import BasePolicy, ICMPolicy
 from tianshou.utils.net.discrete import IntrinsicCuriosityModule
@@ -31,7 +31,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
 ):
     def __init__(
         self,
-        feature_net_factory: ModuleFactory,
+        feature_net_factory: IntermediateModuleFactory,
         hidden_sizes: Sequence[int],
         lr: float,
         lr_scale: float,
@@ -52,7 +52,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
         optim_factory: OptimizerFactory,
         device: TDevice,
     ) -> ICMPolicy:
-        feature_net = self.feature_net_factory.create_module(envs, device)
+        feature_net = self.feature_net_factory.create_intermediate_module(envs, device)
         action_dim = envs.get_action_shape()
         if not isinstance(action_dim, int):
             raise ValueError(f"Environment action shape must be an integer, got {action_dim}")