PolicyWrapperFactory: Remove unnecessary input type variable

This commit is contained in:
Dominik Jain 2023-10-16 12:07:51 +02:00
parent fc695a5394
commit 90eaacb606

View File

@ -9,15 +9,14 @@ from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
from tianshou.utils.string import ToStringMixin
TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy)
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)
class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ToStringMixin, ABC):
class PolicyWrapperFactory(Generic[TPolicyOut], ToStringMixin, ABC):
@abstractmethod
def create_wrapped_policy(
self,
policy: TPolicyIn,
policy: BasePolicy,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,
@ -26,8 +25,7 @@ class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ToStringMixin, ABC):
class PolicyWrapperFactoryIntrinsicCuriosity(
Generic[TPolicyIn],
PolicyWrapperFactory[TPolicyIn, ICMPolicy],
PolicyWrapperFactory[ICMPolicy],
):
def __init__(
self,
@ -47,7 +45,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
def create_wrapped_policy(
self,
policy: TPolicyIn,
policy: BasePolicy,
envs: Environments,
optim_factory: OptimizerFactory,
device: TDevice,