diff --git a/tianshou/highlevel/params/policy_wrapper.py b/tianshou/highlevel/params/policy_wrapper.py index 09e5065..8bf43cc 100644 --- a/tianshou/highlevel/params/policy_wrapper.py +++ b/tianshou/highlevel/params/policy_wrapper.py @@ -54,7 +54,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity( ) -> ICMPolicy: feature_net = self.feature_net_factory.create_module(envs, device) action_dim = envs.get_action_shape() - if type(action_dim) != int: + if not isinstance(action_dim, int): raise ValueError(f"Environment action shape must be an integer, got {action_dim}") feature_dim = feature_net.output_dim icm_net = IntrinsicCuriosityModule(