Remove obsolete mixin, improve class names

This commit is contained in:
Dominik Jain 2023-10-16 12:10:23 +02:00
parent 90eaacb606
commit 97e21b5ddf

View File

@ -145,7 +145,7 @@ class AgentFactory(ABC, ToStringMixin):
pass
class OnpolicyAgentFactory(AgentFactory, ABC):
class OnPolicyAgentFactory(AgentFactory, ABC):
def create_trainer(
self,
world: World,
@ -187,7 +187,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
)
class OffpolicyAgentFactory(AgentFactory, ABC):
class OffPolicyAgentFactory(AgentFactory, ABC):
def create_trainer(
self,
world: World,
@ -229,35 +229,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
)
class _ActorCriticMixin:  # TODO merge
    """Mixin for agents whose actor and critic share one ActorCritic module and one optimizer.

    The mixin only stores the factories it is given; the modules and the
    optimizer are built on demand by :meth:`create_actor_critic_module_opt`.
    """

    def __init__(
        self,
        actor_factory: ActorFactory,
        critic_factory: CriticFactory,
        optim_factory: OptimizerFactory,
        critic_use_action: bool,
    ):
        """Store the factories and the critic configuration flag.

        :param actor_factory: factory used to create the actor module
        :param critic_factory: factory used to create the critic module
        :param optim_factory: factory used to create the optimizer for the
            combined actor-critic module
        :param critic_use_action: forwarded as ``use_action`` to the critic
            factory when the critic module is created
        """
        self.critic_use_action = critic_use_action
        self.optim_factory = optim_factory
        self.critic_factory = critic_factory
        self.actor_factory = actor_factory

    def create_actor_critic_module_opt(
        self,
        envs: Environments,
        device: TDevice,
        lr: float,
    ) -> ActorCriticModuleOpt:
        """Build the combined actor-critic module together with its optimizer.

        :param envs: environments passed on to both module factories
        :param device: device passed on to both module factories
        :param lr: learning rate handed to the optimizer factory
        :return: the ActorCritic module bundled with the optimizer created for it
        """
        # Arguments are evaluated left to right, so the actor is still
        # created before the critic.
        joint_module = ActorCritic(
            self.actor_factory.create_module(envs, device),
            self.critic_factory.create_module(
                envs,
                device,
                use_action=self.critic_use_action,
            ),
        )
        # A single optimizer is created over the combined module.
        optimizer = self.optim_factory.create_optimizer(joint_module, lr)
        return ActorCriticModuleOpt(joint_module, optimizer)
class PGAgentFactory(OnpolicyAgentFactory):
class PGAgentFactory(OnPolicyAgentFactory):
def __init__(
self,
params: PGParams,
@ -296,7 +268,7 @@ class PGAgentFactory(OnpolicyAgentFactory):
class ActorCriticAgentFactory(
Generic[TActorCriticParams, TPolicy],
OnpolicyAgentFactory,
OnPolicyAgentFactory,
ABC,
):
def __init__(
@ -373,7 +345,7 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
class DiscreteCriticOnlyAgentFactory(
OffpolicyAgentFactory,
OffPolicyAgentFactory,
Generic[TDiscreteCriticOnlyParams, TPolicy],
):
def __init__(
@ -426,7 +398,7 @@ class IQNAgentFactory(DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]):
return IQNPolicy
class DDPGAgentFactory(OffpolicyAgentFactory):
class DDPGAgentFactory(OffPolicyAgentFactory):
def __init__(
self,
params: DDPGParams,
@ -475,7 +447,7 @@ class DDPGAgentFactory(OffpolicyAgentFactory):
)
class REDQAgentFactory(OffpolicyAgentFactory):
class REDQAgentFactory(OffPolicyAgentFactory):
def __init__(
self,
params: REDQParams,
@ -528,7 +500,7 @@ class REDQAgentFactory(OffpolicyAgentFactory):
class ActorDualCriticsAgentFactory(
OffpolicyAgentFactory,
OffPolicyAgentFactory,
Generic[TActorDualCriticsParams, TPolicy],
ABC,
):