Force keyword-only arguments in PolicyWrapperFactoryIntrinsicCuriosity.__init__
This commit is contained in:
parent
96298eafd8
commit
da2194eff6
@ -88,12 +88,12 @@ def main(
|
|||||||
if icm_lr_scale > 0:
|
if icm_lr_scale > 0:
|
||||||
builder.with_policy_wrapper_factory(
|
builder.with_policy_wrapper_factory(
|
||||||
PolicyWrapperFactoryIntrinsicCuriosity(
|
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||||
IntermediateModuleFactoryAtariDQNFeatures(),
|
feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
|
||||||
[512],
|
hidden_sizes=[512],
|
||||||
lr,
|
lr=lr,
|
||||||
icm_lr_scale,
|
lr_scale=icm_lr_scale,
|
||||||
icm_reward_scale,
|
reward_scale=icm_reward_scale,
|
||||||
icm_forward_loss_weight,
|
forward_loss_weight=icm_forward_loss_weight,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -101,12 +101,12 @@ def main(
|
|||||||
if icm_lr_scale > 0:
|
if icm_lr_scale > 0:
|
||||||
builder.with_policy_wrapper_factory(
|
builder.with_policy_wrapper_factory(
|
||||||
PolicyWrapperFactoryIntrinsicCuriosity(
|
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||||
IntermediateModuleFactoryAtariDQNFeatures(),
|
feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
|
||||||
[hidden_sizes],
|
hidden_sizes=[hidden_sizes],
|
||||||
lr,
|
lr=lr,
|
||||||
icm_lr_scale,
|
lr_scale=icm_lr_scale,
|
||||||
icm_reward_scale,
|
reward_scale=icm_reward_scale,
|
||||||
icm_forward_loss_weight,
|
forward_loss_weight=icm_forward_loss_weight,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
|
@ -87,12 +87,12 @@ def main(
|
|||||||
if icm_lr_scale > 0:
|
if icm_lr_scale > 0:
|
||||||
builder.with_policy_wrapper_factory(
|
builder.with_policy_wrapper_factory(
|
||||||
PolicyWrapperFactoryIntrinsicCuriosity(
|
PolicyWrapperFactoryIntrinsicCuriosity(
|
||||||
IntermediateModuleFactoryAtariDQNFeatures(),
|
feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
|
||||||
[hidden_size],
|
hidden_sizes=[hidden_size],
|
||||||
actor_lr,
|
lr=actor_lr,
|
||||||
icm_lr_scale,
|
lr_scale=icm_lr_scale,
|
||||||
icm_reward_scale,
|
reward_scale=icm_reward_scale,
|
||||||
icm_forward_loss_weight,
|
forward_loss_weight=icm_forward_loss_weight,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
experiment = builder.build()
|
experiment = builder.build()
|
||||||
|
@ -30,6 +30,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
|
|||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
*,
|
||||||
feature_net_factory: IntermediateModuleFactory,
|
feature_net_factory: IntermediateModuleFactory,
|
||||||
hidden_sizes: Sequence[int],
|
hidden_sizes: Sequence[int],
|
||||||
lr: float,
|
lr: float,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user