Remove obsolete mixin, improve class names
This commit is contained in:
parent
90eaacb606
commit
97e21b5ddf
@ -145,7 +145,7 @@ class AgentFactory(ABC, ToStringMixin):
|
||||
pass
|
||||
|
||||
|
||||
class OnpolicyAgentFactory(AgentFactory, ABC):
|
||||
class OnPolicyAgentFactory(AgentFactory, ABC):
|
||||
def create_trainer(
|
||||
self,
|
||||
world: World,
|
||||
@ -187,7 +187,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
|
||||
)
|
||||
|
||||
|
||||
class OffpolicyAgentFactory(AgentFactory, ABC):
|
||||
class OffPolicyAgentFactory(AgentFactory, ABC):
|
||||
def create_trainer(
|
||||
self,
|
||||
world: World,
|
||||
@ -229,35 +229,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
|
||||
)
|
||||
|
||||
|
||||
class _ActorCriticMixin: # TODO merge
|
||||
"""Mixin for agents that use an ActorCritic module with a single optimizer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor_factory: ActorFactory,
|
||||
critic_factory: CriticFactory,
|
||||
optim_factory: OptimizerFactory,
|
||||
critic_use_action: bool,
|
||||
):
|
||||
self.actor_factory = actor_factory
|
||||
self.critic_factory = critic_factory
|
||||
self.optim_factory = optim_factory
|
||||
self.critic_use_action = critic_use_action
|
||||
|
||||
def create_actor_critic_module_opt(
|
||||
self,
|
||||
envs: Environments,
|
||||
device: TDevice,
|
||||
lr: float,
|
||||
) -> ActorCriticModuleOpt:
|
||||
actor = self.actor_factory.create_module(envs, device)
|
||||
critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
|
||||
actor_critic = ActorCritic(actor, critic)
|
||||
optim = self.optim_factory.create_optimizer(actor_critic, lr)
|
||||
return ActorCriticModuleOpt(actor_critic, optim)
|
||||
|
||||
|
||||
class PGAgentFactory(OnpolicyAgentFactory):
|
||||
class PGAgentFactory(OnPolicyAgentFactory):
|
||||
def __init__(
|
||||
self,
|
||||
params: PGParams,
|
||||
@ -296,7 +268,7 @@ class PGAgentFactory(OnpolicyAgentFactory):
|
||||
|
||||
class ActorCriticAgentFactory(
|
||||
Generic[TActorCriticParams, TPolicy],
|
||||
OnpolicyAgentFactory,
|
||||
OnPolicyAgentFactory,
|
||||
ABC,
|
||||
):
|
||||
def __init__(
|
||||
@ -373,7 +345,7 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
|
||||
|
||||
|
||||
class DiscreteCriticOnlyAgentFactory(
|
||||
OffpolicyAgentFactory,
|
||||
OffPolicyAgentFactory,
|
||||
Generic[TDiscreteCriticOnlyParams, TPolicy],
|
||||
):
|
||||
def __init__(
|
||||
@ -426,7 +398,7 @@ class IQNAgentFactory(DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]):
|
||||
return IQNPolicy
|
||||
|
||||
|
||||
class DDPGAgentFactory(OffpolicyAgentFactory):
|
||||
class DDPGAgentFactory(OffPolicyAgentFactory):
|
||||
def __init__(
|
||||
self,
|
||||
params: DDPGParams,
|
||||
@ -475,7 +447,7 @@ class DDPGAgentFactory(OffpolicyAgentFactory):
|
||||
)
|
||||
|
||||
|
||||
class REDQAgentFactory(OffpolicyAgentFactory):
|
||||
class REDQAgentFactory(OffPolicyAgentFactory):
|
||||
def __init__(
|
||||
self,
|
||||
params: REDQParams,
|
||||
@ -528,7 +500,7 @@ class REDQAgentFactory(OffpolicyAgentFactory):
|
||||
|
||||
|
||||
class ActorDualCriticsAgentFactory(
|
||||
OffpolicyAgentFactory,
|
||||
OffPolicyAgentFactory,
|
||||
Generic[TActorDualCriticsParams, TPolicy],
|
||||
ABC,
|
||||
):
|
||||
|
Loading…
x
Reference in New Issue
Block a user