Add generalised DQN network representation, adding specialised class for feature_only=True

This commit is contained in:
Dominik Jain 2023-10-16 18:38:32 +02:00
parent 4b270eaa2d
commit 83048788a1
5 changed files with 18 additions and 11 deletions

View File

@ -6,7 +6,7 @@ from jsonargparse import CLI
from examples.atari.atari_network import (
ActorFactoryAtariPlainDQN,
IntermediateModuleFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
@ -114,7 +114,7 @@ def main(
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity(
IntermediateModuleFactoryAtariDQN(net_only=True),
IntermediateModuleFactoryAtariDQNFeatures(),
[512],
lr,
icm_lr_scale,

View File

@ -90,7 +90,7 @@ def main(
num_cosines=num_cosines,
),
)
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(net_only=False))
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
.with_trainer_epoch_callback_train(
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
)

View File

@ -260,7 +260,8 @@ class ActorFactoryAtariDQN(ActorFactory):
class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
def __init__(self, net_only: bool):
def __init__(self, features_only: bool = False, net_only: bool = False):
self.features_only = features_only
self.net_only = net_only
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
@ -268,6 +269,12 @@ class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
*envs.get_observation_shape(),
envs.get_action_shape(),
device=device,
features_only=True,
)
return IntermediateModule(dqn.net if self.net_only else dqn, dqn.output_dim)
features_only=self.features_only,
).to(device)
module = dqn.net if self.net_only else dqn
return IntermediateModule(module, dqn.output_dim)
class IntermediateModuleFactoryAtariDQNFeatures(IntermediateModuleFactoryAtariDQN):
def __init__(self):
super().__init__(features_only=True, net_only=True)

View File

@ -8,7 +8,7 @@ from jsonargparse import CLI
from examples.atari.atari_network import (
ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
@ -103,7 +103,7 @@ def main(
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity(
IntermediateModuleFactoryAtariDQN(net_only=True),
IntermediateModuleFactoryAtariDQNFeatures(),
[hidden_sizes],
lr,
icm_lr_scale,

View File

@ -6,7 +6,7 @@ from jsonargparse import CLI
from examples.atari.atari_network import (
ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
@ -95,7 +95,7 @@ def main(
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity(
IntermediateModuleFactoryAtariDQN(net_only=True),
IntermediateModuleFactoryAtariDQNFeatures(),
[hidden_size],
actor_lr,
icm_lr_scale,