Remove obsolete mixin, improve class names

This commit is contained in:
Dominik Jain 2023-10-16 12:10:23 +02:00
parent 90eaacb606
commit 97e21b5ddf

View File

@@ -145,7 +145,7 @@ class AgentFactory(ABC, ToStringMixin):
pass pass
class OnpolicyAgentFactory(AgentFactory, ABC): class OnPolicyAgentFactory(AgentFactory, ABC):
def create_trainer( def create_trainer(
self, self,
world: World, world: World,
@@ -187,7 +187,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
) )
class OffpolicyAgentFactory(AgentFactory, ABC): class OffPolicyAgentFactory(AgentFactory, ABC):
def create_trainer( def create_trainer(
self, self,
world: World, world: World,
@@ -229,35 +229,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
) )
class _ActorCriticMixin: # TODO merge class PGAgentFactory(OnPolicyAgentFactory):
"""Mixin for agents that use an ActorCritic module with a single optimizer."""
def __init__(
self,
actor_factory: ActorFactory,
critic_factory: CriticFactory,
optim_factory: OptimizerFactory,
critic_use_action: bool,
):
self.actor_factory = actor_factory
self.critic_factory = critic_factory
self.optim_factory = optim_factory
self.critic_use_action = critic_use_action
def create_actor_critic_module_opt(
self,
envs: Environments,
device: TDevice,
lr: float,
) -> ActorCriticModuleOpt:
actor = self.actor_factory.create_module(envs, device)
critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action)
actor_critic = ActorCritic(actor, critic)
optim = self.optim_factory.create_optimizer(actor_critic, lr)
return ActorCriticModuleOpt(actor_critic, optim)
class PGAgentFactory(OnpolicyAgentFactory):
def __init__( def __init__(
self, self,
params: PGParams, params: PGParams,
@@ -296,7 +268,7 @@ class PGAgentFactory(OnpolicyAgentFactory):
class ActorCriticAgentFactory( class ActorCriticAgentFactory(
Generic[TActorCriticParams, TPolicy], Generic[TActorCriticParams, TPolicy],
OnpolicyAgentFactory, OnPolicyAgentFactory,
ABC, ABC,
): ):
def __init__( def __init__(
@@ -373,7 +345,7 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
class DiscreteCriticOnlyAgentFactory( class DiscreteCriticOnlyAgentFactory(
OffpolicyAgentFactory, OffPolicyAgentFactory,
Generic[TDiscreteCriticOnlyParams, TPolicy], Generic[TDiscreteCriticOnlyParams, TPolicy],
): ):
def __init__( def __init__(
@@ -426,7 +398,7 @@ class IQNAgentFactory(DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]):
return IQNPolicy return IQNPolicy
class DDPGAgentFactory(OffpolicyAgentFactory): class DDPGAgentFactory(OffPolicyAgentFactory):
def __init__( def __init__(
self, self,
params: DDPGParams, params: DDPGParams,
@@ -475,7 +447,7 @@ class DDPGAgentFactory(OffpolicyAgentFactory):
) )
class REDQAgentFactory(OffpolicyAgentFactory): class REDQAgentFactory(OffPolicyAgentFactory):
def __init__( def __init__(
self, self,
params: REDQParams, params: REDQParams,
@@ -528,7 +500,7 @@ class REDQAgentFactory(OffpolicyAgentFactory):
class ActorDualCriticsAgentFactory( class ActorDualCriticsAgentFactory(
OffpolicyAgentFactory, OffPolicyAgentFactory,
Generic[TActorDualCriticsParams, TPolicy], Generic[TActorDualCriticsParams, TPolicy],
ABC, ABC,
): ):