PolicyWrapperFactory: Remove unnecessary input type variable

This commit is contained in:
Dominik Jain 2023-10-16 12:07:51 +02:00
parent fc695a5394
commit 90eaacb606

View File

@ -9,15 +9,14 @@ from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule from tianshou.utils.net.discrete import IntrinsicCuriosityModule
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy)
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy) TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)
class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ToStringMixin, ABC): class PolicyWrapperFactory(Generic[TPolicyOut], ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_wrapped_policy( def create_wrapped_policy(
self, self,
policy: TPolicyIn, policy: BasePolicy,
envs: Environments, envs: Environments,
optim_factory: OptimizerFactory, optim_factory: OptimizerFactory,
device: TDevice, device: TDevice,
@ -26,8 +25,7 @@ class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ToStringMixin, ABC):
class PolicyWrapperFactoryIntrinsicCuriosity( class PolicyWrapperFactoryIntrinsicCuriosity(
Generic[TPolicyIn], PolicyWrapperFactory[ICMPolicy],
PolicyWrapperFactory[TPolicyIn, ICMPolicy],
): ):
def __init__( def __init__(
self, self,
@ -47,7 +45,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
def create_wrapped_policy( def create_wrapped_policy(
self, self,
policy: TPolicyIn, policy: BasePolicy,
envs: Environments, envs: Environments,
optim_factory: OptimizerFactory, optim_factory: OptimizerFactory,
device: TDevice, device: TDevice,