diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index c9daeac..ed6f347 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -18,7 +18,7 @@ from tianshou.highlevel.module.core import ( ) from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory from tianshou.highlevel.module.module_opt import ( - ActorCriticModuleOpt, + ActorCriticOpt, ) from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.params.policy_params import ( @@ -297,12 +297,12 @@ class ActorCriticAgentFactory( envs: Environments, device: TDevice, lr: float, - ) -> ActorCriticModuleOpt: + ) -> ActorCriticOpt: actor = self.actor_factory.create_module(envs, device) critic = self.critic_factory.create_module(envs, device, use_action=self.critic_use_action) actor_critic = ActorCritic(actor, critic) optim = self.optim_factory.create_optimizer(actor_critic, lr) - return ActorCriticModuleOpt(actor_critic, optim) + return ActorCriticOpt(actor_critic, optim) @typing.no_type_check def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: diff --git a/tianshou/highlevel/module/module_opt.py b/tianshou/highlevel/module/module_opt.py index 222d680..558686a 100644 --- a/tianshou/highlevel/module/module_opt.py +++ b/tianshou/highlevel/module/module_opt.py @@ -14,7 +14,7 @@ class ModuleOpt: @dataclass -class ActorCriticModuleOpt: +class ActorCriticOpt: """Container for an :class:`ActorCritic` instance along with its optimizer.""" actor_critic_module: ActorCritic