Force kwargs in PolicyWrapperFactoryIntrinsicCuriosity init

This commit is contained in:
Dominik Jain 2023-10-26 10:43:48 +02:00
parent 96298eafd8
commit da2194eff6
4 changed files with 19 additions and 18 deletions

View File

@ -88,12 +88,12 @@ def main(
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity(
IntermediateModuleFactoryAtariDQNFeatures(),
[512],
lr,
icm_lr_scale,
icm_reward_scale,
icm_forward_loss_weight,
feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
hidden_sizes=[512],
lr=lr,
lr_scale=icm_lr_scale,
reward_scale=icm_reward_scale,
forward_loss_weight=icm_forward_loss_weight,
),
)

View File

@ -101,12 +101,12 @@ def main(
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity(
IntermediateModuleFactoryAtariDQNFeatures(),
[hidden_sizes],
lr,
icm_lr_scale,
icm_reward_scale,
icm_forward_loss_weight,
feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
hidden_sizes=[hidden_sizes],
lr=lr,
lr_scale=icm_lr_scale,
reward_scale=icm_reward_scale,
forward_loss_weight=icm_forward_loss_weight,
),
)
experiment = builder.build()

View File

@ -87,12 +87,12 @@ def main(
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity(
IntermediateModuleFactoryAtariDQNFeatures(),
[hidden_size],
actor_lr,
icm_lr_scale,
icm_reward_scale,
icm_forward_loss_weight,
feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
hidden_sizes=[hidden_size],
lr=actor_lr,
lr_scale=icm_lr_scale,
reward_scale=icm_reward_scale,
forward_loss_weight=icm_forward_loss_weight,
),
)
experiment = builder.build()

View File

@ -30,6 +30,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
):
def __init__(
self,
*,
feature_net_factory: IntermediateModuleFactory,
hidden_sizes: Sequence[int],
lr: float,