Add generalised DQN network representation, adding specialised class for feature_only=True
This commit is contained in:
parent
4b270eaa2d
commit
83048788a1
@ -6,7 +6,7 @@ from jsonargparse import CLI
|
|||||||
|
|
||||||
from examples.atari.atari_network import (
|
from examples.atari.atari_network import (
|
||||||
ActorFactoryAtariPlainDQN,
|
ActorFactoryAtariPlainDQN,
|
||||||
IntermediateModuleFactoryAtariDQN,
|
IntermediateModuleFactoryAtariDQNFeatures,
|
||||||
)
|
)
|
||||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
@ -114,7 +114,7 @@ def main(
|
|||||||
if icm_lr_scale > 0:
|
if icm_lr_scale > 0:
|
||||||
builder.with_policy_wrapper_factory(
|
builder.with_policy_wrapper_factory(
|
||||||
PolicyWrapperFactoryIntrinsicCuriosity(
|
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||||
IntermediateModuleFactoryAtariDQN(net_only=True),
|
IntermediateModuleFactoryAtariDQNFeatures(),
|
||||||
[512],
|
[512],
|
||||||
lr,
|
lr,
|
||||||
icm_lr_scale,
|
icm_lr_scale,
|
||||||
|
@ -90,7 +90,7 @@ def main(
|
|||||||
num_cosines=num_cosines,
|
num_cosines=num_cosines,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(net_only=False))
|
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
|
||||||
.with_trainer_epoch_callback_train(
|
.with_trainer_epoch_callback_train(
|
||||||
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
|
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||||
)
|
)
|
||||||
|
@ -260,7 +260,8 @@ class ActorFactoryAtariDQN(ActorFactory):
|
|||||||
|
|
||||||
|
|
||||||
class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
|
class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
|
||||||
def __init__(self, net_only: bool):
|
def __init__(self, features_only: bool = False, net_only: bool = False):
|
||||||
|
self.features_only = features_only
|
||||||
self.net_only = net_only
|
self.net_only = net_only
|
||||||
|
|
||||||
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
|
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
|
||||||
@ -268,6 +269,12 @@ class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
|
|||||||
*envs.get_observation_shape(),
|
*envs.get_observation_shape(),
|
||||||
envs.get_action_shape(),
|
envs.get_action_shape(),
|
||||||
device=device,
|
device=device,
|
||||||
features_only=True,
|
features_only=self.features_only,
|
||||||
)
|
).to(device)
|
||||||
return IntermediateModule(dqn.net if self.net_only else dqn, dqn.output_dim)
|
module = dqn.net if self.net_only else dqn
|
||||||
|
return IntermediateModule(module, dqn.output_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class IntermediateModuleFactoryAtariDQNFeatures(IntermediateModuleFactoryAtariDQN):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(features_only=True, net_only=True)
|
||||||
|
@ -8,7 +8,7 @@ from jsonargparse import CLI
|
|||||||
|
|
||||||
from examples.atari.atari_network import (
|
from examples.atari.atari_network import (
|
||||||
ActorFactoryAtariDQN,
|
ActorFactoryAtariDQN,
|
||||||
IntermediateModuleFactoryAtariDQN,
|
IntermediateModuleFactoryAtariDQNFeatures,
|
||||||
)
|
)
|
||||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
@ -103,7 +103,7 @@ def main(
|
|||||||
if icm_lr_scale > 0:
|
if icm_lr_scale > 0:
|
||||||
builder.with_policy_wrapper_factory(
|
builder.with_policy_wrapper_factory(
|
||||||
PolicyWrapperFactoryIntrinsicCuriosity(
|
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||||
IntermediateModuleFactoryAtariDQN(net_only=True),
|
IntermediateModuleFactoryAtariDQNFeatures(),
|
||||||
[hidden_sizes],
|
[hidden_sizes],
|
||||||
lr,
|
lr,
|
||||||
icm_lr_scale,
|
icm_lr_scale,
|
||||||
|
@ -6,7 +6,7 @@ from jsonargparse import CLI
|
|||||||
|
|
||||||
from examples.atari.atari_network import (
|
from examples.atari.atari_network import (
|
||||||
ActorFactoryAtariDQN,
|
ActorFactoryAtariDQN,
|
||||||
IntermediateModuleFactoryAtariDQN,
|
IntermediateModuleFactoryAtariDQNFeatures,
|
||||||
)
|
)
|
||||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
@ -95,7 +95,7 @@ def main(
|
|||||||
if icm_lr_scale > 0:
|
if icm_lr_scale > 0:
|
||||||
builder.with_policy_wrapper_factory(
|
builder.with_policy_wrapper_factory(
|
||||||
PolicyWrapperFactoryIntrinsicCuriosity(
|
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||||
IntermediateModuleFactoryAtariDQN(net_only=True),
|
IntermediateModuleFactoryAtariDQNFeatures(),
|
||||||
[hidden_size],
|
[hidden_size],
|
||||||
actor_lr,
|
actor_lr,
|
||||||
icm_lr_scale,
|
icm_lr_scale,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user