Remove obsolete mixin, improve class names
This commit is contained in:
parent
90eaacb606
commit
97e21b5ddf
@ -145,7 +145,7 @@ class AgentFactory(ABC, ToStringMixin):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class OnpolicyAgentFactory(AgentFactory, ABC):
|
class OnPolicyAgentFactory(AgentFactory, ABC):
|
||||||
def create_trainer(
|
def create_trainer(
|
||||||
self,
|
self,
|
||||||
world: World,
|
world: World,
|
||||||
@ -187,7 +187,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OffpolicyAgentFactory(AgentFactory, ABC):
|
class OffPolicyAgentFactory(AgentFactory, ABC):
|
||||||
def create_trainer(
|
def create_trainer(
|
||||||
self,
|
self,
|
||||||
world: World,
|
world: World,
|
||||||
@ -229,35 +229,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class _ActorCriticMixin: # TODO merge
|
class PGAgentFactory(OnPolicyAgentFactory):
|
||||||
"""Mixin for agents that use an ActorCritic module with a single optimizer."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
actor_factory: ActorFactory,
|
|
||||||
critic_factory: CriticFactory,
|
|
||||||
optim_factory: OptimizerFactory,
|
|
||||||
critic_use_action: bool,
|
|
||||||
):
|
|
||||||
self.actor_factory = actor_factory
|
|
||||||
self.critic_factory = critic_factory
|
|
||||||
self.optim_factory = optim_factory
|
|
||||||
self.critic_use_action = critic_use_action
|
|
||||||
|
|
||||||
def create_actor_critic_module_opt(
|
|
||||||
self,
|
|
||||||
envs: Environments,
|
|
||||||
device: TDevice,
|
|
||||||
lr: float,
|
|
||||||
) -> ActorCriticModuleOpt:
|
|
||||||
actor = self.actor_factory.create_module(envs, device)
|
|
||||||
critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
|
|
||||||
actor_critic = ActorCritic(actor, critic)
|
|
||||||
optim = self.optim_factory.create_optimizer(actor_critic, lr)
|
|
||||||
return ActorCriticModuleOpt(actor_critic, optim)
|
|
||||||
|
|
||||||
|
|
||||||
class PGAgentFactory(OnpolicyAgentFactory):
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: PGParams,
|
params: PGParams,
|
||||||
@ -296,7 +268,7 @@ class PGAgentFactory(OnpolicyAgentFactory):
|
|||||||
|
|
||||||
class ActorCriticAgentFactory(
|
class ActorCriticAgentFactory(
|
||||||
Generic[TActorCriticParams, TPolicy],
|
Generic[TActorCriticParams, TPolicy],
|
||||||
OnpolicyAgentFactory,
|
OnPolicyAgentFactory,
|
||||||
ABC,
|
ABC,
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -373,7 +345,7 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
|
|||||||
|
|
||||||
|
|
||||||
class DiscreteCriticOnlyAgentFactory(
|
class DiscreteCriticOnlyAgentFactory(
|
||||||
OffpolicyAgentFactory,
|
OffPolicyAgentFactory,
|
||||||
Generic[TDiscreteCriticOnlyParams, TPolicy],
|
Generic[TDiscreteCriticOnlyParams, TPolicy],
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -426,7 +398,7 @@ class IQNAgentFactory(DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]):
|
|||||||
return IQNPolicy
|
return IQNPolicy
|
||||||
|
|
||||||
|
|
||||||
class DDPGAgentFactory(OffpolicyAgentFactory):
|
class DDPGAgentFactory(OffPolicyAgentFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: DDPGParams,
|
params: DDPGParams,
|
||||||
@ -475,7 +447,7 @@ class DDPGAgentFactory(OffpolicyAgentFactory):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class REDQAgentFactory(OffpolicyAgentFactory):
|
class REDQAgentFactory(OffPolicyAgentFactory):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
params: REDQParams,
|
params: REDQParams,
|
||||||
@ -528,7 +500,7 @@ class REDQAgentFactory(OffpolicyAgentFactory):
|
|||||||
|
|
||||||
|
|
||||||
class ActorDualCriticsAgentFactory(
|
class ActorDualCriticsAgentFactory(
|
||||||
OffpolicyAgentFactory,
|
OffPolicyAgentFactory,
|
||||||
Generic[TActorDualCriticsParams, TPolicy],
|
Generic[TActorDualCriticsParams, TPolicy],
|
||||||
ABC,
|
ABC,
|
||||||
):
|
):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user