PolicyWrapperFactory: Remove unnecessary input type variable
This commit is contained in:
parent
fc695a5394
commit
90eaacb606
@ -9,15 +9,14 @@ from tianshou.policy import BasePolicy, ICMPolicy
|
|||||||
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
|
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy)
|
|
||||||
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)
|
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)
|
||||||
|
|
||||||
|
|
||||||
class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ToStringMixin, ABC):
|
class PolicyWrapperFactory(Generic[TPolicyOut], ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_wrapped_policy(
|
def create_wrapped_policy(
|
||||||
self,
|
self,
|
||||||
policy: TPolicyIn,
|
policy: BasePolicy,
|
||||||
envs: Environments,
|
envs: Environments,
|
||||||
optim_factory: OptimizerFactory,
|
optim_factory: OptimizerFactory,
|
||||||
device: TDevice,
|
device: TDevice,
|
||||||
@ -26,8 +25,7 @@ class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ToStringMixin, ABC):
|
|||||||
|
|
||||||
|
|
||||||
class PolicyWrapperFactoryIntrinsicCuriosity(
|
class PolicyWrapperFactoryIntrinsicCuriosity(
|
||||||
Generic[TPolicyIn],
|
PolicyWrapperFactory[ICMPolicy],
|
||||||
PolicyWrapperFactory[TPolicyIn, ICMPolicy],
|
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -47,7 +45,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
|
|||||||
|
|
||||||
def create_wrapped_policy(
|
def create_wrapped_policy(
|
||||||
self,
|
self,
|
||||||
policy: TPolicyIn,
|
policy: BasePolicy,
|
||||||
envs: Environments,
|
envs: Environments,
|
||||||
optim_factory: OptimizerFactory,
|
optim_factory: OptimizerFactory,
|
||||||
device: TDevice,
|
device: TDevice,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user