Add generalised DQN network representation and a specialised class for features_only=True

This commit is contained in:
Dominik Jain 2023-10-16 18:38:32 +02:00
parent 4b270eaa2d
commit 83048788a1
5 changed files with 18 additions and 11 deletions

View File

@ -6,7 +6,7 @@ from jsonargparse import CLI
from examples.atari.atari_network import ( from examples.atari.atari_network import (
ActorFactoryAtariPlainDQN, ActorFactoryAtariPlainDQN,
IntermediateModuleFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures,
) )
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
@ -114,7 +114,7 @@ def main(
if icm_lr_scale > 0: if icm_lr_scale > 0:
builder.with_policy_wrapper_factory( builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity( PolicyWrapperFactoryIntrinsicCuriosity(
IntermediateModuleFactoryAtariDQN(net_only=True), IntermediateModuleFactoryAtariDQNFeatures(),
[512], [512],
lr, lr,
icm_lr_scale, icm_lr_scale,

View File

@ -90,7 +90,7 @@ def main(
num_cosines=num_cosines, num_cosines=num_cosines,
), ),
) )
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(net_only=False)) .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
.with_trainer_epoch_callback_train( .with_trainer_epoch_callback_train(
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final), TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
) )

View File

@ -260,7 +260,8 @@ class ActorFactoryAtariDQN(ActorFactory):
class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory): class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
def __init__(self, net_only: bool): def __init__(self, features_only: bool = False, net_only: bool = False):
self.features_only = features_only
self.net_only = net_only self.net_only = net_only
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
@ -268,6 +269,12 @@ class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
*envs.get_observation_shape(), *envs.get_observation_shape(),
envs.get_action_shape(), envs.get_action_shape(),
device=device, device=device,
features_only=True, features_only=self.features_only,
) ).to(device)
return IntermediateModule(dqn.net if self.net_only else dqn, dqn.output_dim) module = dqn.net if self.net_only else dqn
return IntermediateModule(module, dqn.output_dim)
class IntermediateModuleFactoryAtariDQNFeatures(IntermediateModuleFactoryAtariDQN):
    """Preconfigured variant of ``IntermediateModuleFactoryAtariDQN``.

    Takes no constructor arguments; it always initialises the parent factory
    with ``features_only=True`` and ``net_only=True``.
    """

    def __init__(self):
        # Both parent options are pinned; callers cannot override them here.
        preset = {"features_only": True, "net_only": True}
        super().__init__(**preset)

View File

@ -8,7 +8,7 @@ from jsonargparse import CLI
from examples.atari.atari_network import ( from examples.atari.atari_network import (
ActorFactoryAtariDQN, ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures,
) )
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
@ -103,7 +103,7 @@ def main(
if icm_lr_scale > 0: if icm_lr_scale > 0:
builder.with_policy_wrapper_factory( builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity( PolicyWrapperFactoryIntrinsicCuriosity(
IntermediateModuleFactoryAtariDQN(net_only=True), IntermediateModuleFactoryAtariDQNFeatures(),
[hidden_sizes], [hidden_sizes],
lr, lr,
icm_lr_scale, icm_lr_scale,

View File

@ -6,7 +6,7 @@ from jsonargparse import CLI
from examples.atari.atari_network import ( from examples.atari.atari_network import (
ActorFactoryAtariDQN, ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQN, IntermediateModuleFactoryAtariDQNFeatures,
) )
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
@ -95,7 +95,7 @@ def main(
if icm_lr_scale > 0: if icm_lr_scale > 0:
builder.with_policy_wrapper_factory( builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity( PolicyWrapperFactoryIntrinsicCuriosity(
IntermediateModuleFactoryAtariDQN(net_only=True), IntermediateModuleFactoryAtariDQNFeatures(),
[hidden_size], [hidden_size],
actor_lr, actor_lr,
icm_lr_scale, icm_lr_scale,