Add generalised DQN network representation, adding specialised class for features_only=True
This commit is contained in:
parent
4b270eaa2d
commit
83048788a1
@ -6,7 +6,7 @@ from jsonargparse import CLI
|
||||
|
||||
from examples.atari.atari_network import (
|
||||
ActorFactoryAtariPlainDQN,
|
||||
IntermediateModuleFactoryAtariDQN,
|
||||
IntermediateModuleFactoryAtariDQNFeatures,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
@ -114,7 +114,7 @@ def main(
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||
IntermediateModuleFactoryAtariDQN(net_only=True),
|
||||
IntermediateModuleFactoryAtariDQNFeatures(),
|
||||
[512],
|
||||
lr,
|
||||
icm_lr_scale,
|
||||
|
@ -90,7 +90,7 @@ def main(
|
||||
num_cosines=num_cosines,
|
||||
),
|
||||
)
|
||||
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(net_only=False))
|
||||
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
|
||||
.with_trainer_epoch_callback_train(
|
||||
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||
)
|
||||
|
@ -260,7 +260,8 @@ class ActorFactoryAtariDQN(ActorFactory):
|
||||
|
||||
|
||||
class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
|
||||
def __init__(self, net_only: bool):
|
||||
def __init__(self, features_only: bool = False, net_only: bool = False):
|
||||
self.features_only = features_only
|
||||
self.net_only = net_only
|
||||
|
||||
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
|
||||
@ -268,6 +269,12 @@ class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
|
||||
*envs.get_observation_shape(),
|
||||
envs.get_action_shape(),
|
||||
device=device,
|
||||
features_only=True,
|
||||
)
|
||||
return IntermediateModule(dqn.net if self.net_only else dqn, dqn.output_dim)
|
||||
features_only=self.features_only,
|
||||
).to(device)
|
||||
module = dqn.net if self.net_only else dqn
|
||||
return IntermediateModule(module, dqn.output_dim)
|
||||
|
||||
|
||||
class IntermediateModuleFactoryAtariDQNFeatures(IntermediateModuleFactoryAtariDQN):
|
||||
def __init__(self):
|
||||
super().__init__(features_only=True, net_only=True)
|
||||
|
@ -8,7 +8,7 @@ from jsonargparse import CLI
|
||||
|
||||
from examples.atari.atari_network import (
|
||||
ActorFactoryAtariDQN,
|
||||
IntermediateModuleFactoryAtariDQN,
|
||||
IntermediateModuleFactoryAtariDQNFeatures,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
@ -103,7 +103,7 @@ def main(
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||
IntermediateModuleFactoryAtariDQN(net_only=True),
|
||||
IntermediateModuleFactoryAtariDQNFeatures(),
|
||||
[hidden_sizes],
|
||||
lr,
|
||||
icm_lr_scale,
|
||||
|
@ -6,7 +6,7 @@ from jsonargparse import CLI
|
||||
|
||||
from examples.atari.atari_network import (
|
||||
ActorFactoryAtariDQN,
|
||||
IntermediateModuleFactoryAtariDQN,
|
||||
IntermediateModuleFactoryAtariDQNFeatures,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
@ -95,7 +95,7 @@ def main(
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||
IntermediateModuleFactoryAtariDQN(net_only=True),
|
||||
IntermediateModuleFactoryAtariDQNFeatures(),
|
||||
[hidden_size],
|
||||
actor_lr,
|
||||
icm_lr_scale,
|
||||
|
Loading…
x
Reference in New Issue
Block a user