From da2194eff6c12d32d8ca0e717dedb480243a31ef Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Thu, 26 Oct 2023 10:43:48 +0200 Subject: [PATCH] Force kwargs in PolicyWrapperFactoryIntrinsicCuriosity init --- examples/atari/atari_dqn_hl.py | 12 ++++++------ examples/atari/atari_ppo_hl.py | 12 ++++++------ examples/atari/atari_sac_hl.py | 12 ++++++------ tianshou/highlevel/params/policy_wrapper.py | 1 + 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/examples/atari/atari_dqn_hl.py b/examples/atari/atari_dqn_hl.py index be981b5..f7aa331 100644 --- a/examples/atari/atari_dqn_hl.py +++ b/examples/atari/atari_dqn_hl.py @@ -88,12 +88,12 @@ def main( if icm_lr_scale > 0: builder.with_policy_wrapper_factory( PolicyWrapperFactoryIntrinsicCuriosity( - IntermediateModuleFactoryAtariDQNFeatures(), - [512], - lr, - icm_lr_scale, - icm_reward_scale, - icm_forward_loss_weight, + feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), + hidden_sizes=[512], + lr=lr, + lr_scale=icm_lr_scale, + reward_scale=icm_reward_scale, + forward_loss_weight=icm_forward_loss_weight, ), ) diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index 1c1f1ad..d388d6b 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -101,12 +101,12 @@ def main( if icm_lr_scale > 0: builder.with_policy_wrapper_factory( PolicyWrapperFactoryIntrinsicCuriosity( - IntermediateModuleFactoryAtariDQNFeatures(), - [hidden_sizes], - lr, - icm_lr_scale, - icm_reward_scale, - icm_forward_loss_weight, + feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), + hidden_sizes=[hidden_sizes], + lr=lr, + lr_scale=icm_lr_scale, + reward_scale=icm_reward_scale, + forward_loss_weight=icm_forward_loss_weight, ), ) experiment = builder.build() diff --git a/examples/atari/atari_sac_hl.py b/examples/atari/atari_sac_hl.py index c602550..8314bee 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -87,12 +87,12 @@ def main( if icm_lr_scale > 0: builder.with_policy_wrapper_factory( PolicyWrapperFactoryIntrinsicCuriosity( - IntermediateModuleFactoryAtariDQNFeatures(), - [hidden_size], - actor_lr, - icm_lr_scale, - icm_reward_scale, - icm_forward_loss_weight, + feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), + hidden_sizes=[hidden_size], + lr=actor_lr, + lr_scale=icm_lr_scale, + reward_scale=icm_reward_scale, + forward_loss_weight=icm_forward_loss_weight, ), ) experiment = builder.build() diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index c821f50..e795822 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -30,6 +30,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity( ): def __init__( self, + *, feature_net_factory: IntermediateModuleFactory, hidden_sizes: Sequence[int], lr: float,