Force kwargs in PolicyWrapperFactoryIntrinsicCuriosity init

This commit is contained in:
Dominik Jain 2023-10-26 10:43:48 +02:00
parent 96298eafd8
commit da2194eff6
4 changed files with 19 additions and 18 deletions

View File

@ -88,12 +88,12 @@ def main(
if icm_lr_scale > 0: if icm_lr_scale > 0:
builder.with_policy_wrapper_factory( builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity( PolicyWrapperFactoryIntrinsicCuriosity(
IntermediateModuleFactoryAtariDQNFeatures(), feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
[512], hidden_sizes=[512],
lr, lr=lr,
icm_lr_scale, lr_scale=icm_lr_scale,
icm_reward_scale, reward_scale=icm_reward_scale,
icm_forward_loss_weight, forward_loss_weight=icm_forward_loss_weight,
), ),
) )

View File

@ -101,12 +101,12 @@ def main(
if icm_lr_scale > 0: if icm_lr_scale > 0:
builder.with_policy_wrapper_factory( builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity( PolicyWrapperFactoryIntrinsicCuriosity(
IntermediateModuleFactoryAtariDQNFeatures(), feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
[hidden_sizes], hidden_sizes=[hidden_sizes],
lr, lr=lr,
icm_lr_scale, lr_scale=icm_lr_scale,
icm_reward_scale, reward_scale=icm_reward_scale,
icm_forward_loss_weight, forward_loss_weight=icm_forward_loss_weight,
), ),
) )
experiment = builder.build() experiment = builder.build()

View File

@ -87,12 +87,12 @@ def main(
if icm_lr_scale > 0: if icm_lr_scale > 0:
builder.with_policy_wrapper_factory( builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity( PolicyWrapperFactoryIntrinsicCuriosity(
IntermediateModuleFactoryAtariDQNFeatures(), feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
[hidden_size], hidden_sizes=[hidden_size],
actor_lr, lr=actor_lr,
icm_lr_scale, lr_scale=icm_lr_scale,
icm_reward_scale, reward_scale=icm_reward_scale,
icm_forward_loss_weight, forward_loss_weight=icm_forward_loss_weight,
), ),
) )
experiment = builder.build() experiment = builder.build()

View File

@ -30,6 +30,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
): ):
def __init__( def __init__(
self, self,
*,
feature_net_factory: IntermediateModuleFactory, feature_net_factory: IntermediateModuleFactory,
hidden_sizes: Sequence[int], hidden_sizes: Sequence[int],
lr: float, lr: float,