diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index 61a8fdf..466520e 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -6,7 +6,7 @@ from jsonargparse import CLI from examples.atari.atari_network import ( ActorFactoryAtariPlainDQN, - IntermediateModuleFactoryAtariDQN, + IntermediateModuleFactoryAtariDQNFeatures, ) from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from tianshou.highlevel.config import SamplingConfig @@ -114,7 +114,7 @@ def main( if icm_lr_scale > 0: builder.with_policy_wrapper_factory( PolicyWrapperFactoryIntrinsicCuriosity( - IntermediateModuleFactoryAtariDQN(net_only=True), + IntermediateModuleFactoryAtariDQNFeatures(), [512], lr, icm_lr_scale, diff --git a/examples/atari/atari_iqn_hl.py b/examples/atari/atari_iqn_hl.py index 4d1d6c1..a4d3dd2 100644 --- a/examples/atari/atari_iqn_hl.py +++ b/examples/atari/atari_iqn_hl.py @@ -90,7 +90,7 @@ def main( num_cosines=num_cosines, ), ) - .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(net_only=False)) + .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True)) .with_trainer_epoch_callback_train( TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final), ) diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index bd16d4f..a371621 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -260,7 +260,8 @@ class ActorFactoryAtariDQN(ActorFactory): class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory): - def __init__(self, net_only: bool): + def __init__(self, features_only: bool = False, net_only: bool = False): + self.features_only = features_only self.net_only = net_only def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: @@ -268,6 +269,12 @@ class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory): *envs.get_observation_shape(), envs.get_action_shape(), device=device, - features_only=True, - ) - return IntermediateModule(dqn.net if self.net_only else dqn, dqn.output_dim) + features_only=self.features_only, + ).to(device) + module = dqn.net if self.net_only else dqn + return IntermediateModule(module, dqn.output_dim) + + +class IntermediateModuleFactoryAtariDQNFeatures(IntermediateModuleFactoryAtariDQN): + def __init__(self): + super().__init__(features_only=True, net_only=True) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 9a5777b..25fd0a5 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -8,7 +8,7 @@ from jsonargparse import CLI from examples.atari.atari_network import ( ActorFactoryAtariDQN, - IntermediateModuleFactoryAtariDQN, + IntermediateModuleFactoryAtariDQNFeatures, ) from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from tianshou.highlevel.config import SamplingConfig @@ -103,7 +103,7 @@ def main( if icm_lr_scale > 0: builder.with_policy_wrapper_factory( PolicyWrapperFactoryIntrinsicCuriosity( - IntermediateModuleFactoryAtariDQN(net_only=True), + IntermediateModuleFactoryAtariDQNFeatures(), [hidden_sizes], lr, icm_lr_scale, diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index ac9ceb2..9f743eb 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -6,7 +6,7 @@ from jsonargparse import CLI from examples.atari.atari_network import ( ActorFactoryAtariDQN, - IntermediateModuleFactoryAtariDQN, + IntermediateModuleFactoryAtariDQNFeatures, ) from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from tianshou.highlevel.config import SamplingConfig @@ -95,7 +95,7 @@ def main( if icm_lr_scale > 0: builder.with_policy_wrapper_factory( PolicyWrapperFactoryIntrinsicCuriosity( - IntermediateModuleFactoryAtariDQN(net_only=True), + IntermediateModuleFactoryAtariDQNFeatures(), [hidden_size], actor_lr, icm_lr_scale,