Force keyword-only arguments in PolicyWrapperFactoryIntrinsicCuriosity.__init__
This commit is contained in:
parent
96298eafd8
commit
da2194eff6
@ -88,12 +88,12 @@ def main(
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||
IntermediateModuleFactoryAtariDQNFeatures(),
|
||||
[512],
|
||||
lr,
|
||||
icm_lr_scale,
|
||||
icm_reward_scale,
|
||||
icm_forward_loss_weight,
|
||||
feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
|
||||
hidden_sizes=[512],
|
||||
lr=lr,
|
||||
lr_scale=icm_lr_scale,
|
||||
reward_scale=icm_reward_scale,
|
||||
forward_loss_weight=icm_forward_loss_weight,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -101,12 +101,12 @@ def main(
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||
IntermediateModuleFactoryAtariDQNFeatures(),
|
||||
[hidden_sizes],
|
||||
lr,
|
||||
icm_lr_scale,
|
||||
icm_reward_scale,
|
||||
icm_forward_loss_weight,
|
||||
feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
|
||||
hidden_sizes=[hidden_sizes],
|
||||
lr=lr,
|
||||
lr_scale=icm_lr_scale,
|
||||
reward_scale=icm_reward_scale,
|
||||
forward_loss_weight=icm_forward_loss_weight,
|
||||
),
|
||||
)
|
||||
experiment = builder.build()
|
||||
|
@ -87,12 +87,12 @@ def main(
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||
IntermediateModuleFactoryAtariDQNFeatures(),
|
||||
[hidden_size],
|
||||
actor_lr,
|
||||
icm_lr_scale,
|
||||
icm_reward_scale,
|
||||
icm_forward_loss_weight,
|
||||
feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
|
||||
hidden_sizes=[hidden_size],
|
||||
lr=actor_lr,
|
||||
lr_scale=icm_lr_scale,
|
||||
reward_scale=icm_reward_scale,
|
||||
forward_loss_weight=icm_forward_loss_weight,
|
||||
),
|
||||
)
|
||||
experiment = builder.build()
|
||||
|
@ -30,6 +30,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
feature_net_factory: IntermediateModuleFactory,
|
||||
hidden_sizes: Sequence[int],
|
||||
lr: float,
|
||||
|
Loading…
x
Reference in New Issue
Block a user