Change high-level DQN interface to expect an actor instead of a critic, because that is what is functionally required.
This commit is contained in:
Dominik Jain 2023-10-05 19:21:08 +02:00
parent 1cba589bd4
commit b54fcd12cb
7 changed files with 51 additions and 34 deletions

View File

@@ -6,7 +6,7 @@ import os
from jsonargparse import CLI
from examples.atari.atari_network import (
CriticFactoryAtariDQN,
ActorFactoryAtariPlainDQN,
FeatureNetFactoryDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
@@ -103,7 +103,7 @@ def main(
target_update_freq=target_update_freq,
),
)
.with_critic_factory(CriticFactoryAtariDQN())
.with_actor_factory(ActorFactoryAtariPlainDQN())
.with_trainer_epoch_callback_train(TrainEpochCallback())
.with_trainer_epoch_callback_test(TestEpochCallback())
.with_trainer_stop_callback(AtariStopCallback(task))

View File

@@ -8,8 +8,6 @@ from torch import nn
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.utils.net.common import BaseActor
from tianshou.utils.net.discrete import Actor, NoisyLinear
@@ -227,14 +225,8 @@ class QRDQN(DQN):
return obs, state
class CriticFactoryAtariDQN(CriticFactory):
def create_module(
self,
envs: Environments,
device: TDevice,
use_action: bool,
) -> torch.nn.Module:
assert use_action
class ActorFactoryAtariPlainDQN(ActorFactory):
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
return DQN(
*envs.get_observation_shape(),
envs.get_action_shape(),
@@ -243,17 +235,18 @@ class CriticFactoryAtariDQN(CriticFactory):
class ActorFactoryAtariDQN(ActorFactory):
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool):
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_only: bool):
self.hidden_size = hidden_size
self.scale_obs = scale_obs
self.features_only = features_only
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
def create_module(self, envs: Environments, device: TDevice) -> Actor:
net_cls = scale_obs(DQN) if self.scale_obs else DQN
net = net_cls(
*envs.get_observation_shape(),
envs.get_action_shape(),
device=device,
features_only=True,
features_only=self.features_only,
output_dim=self.hidden_size,
layer_init=layer_init,
)

View File

@@ -96,7 +96,7 @@ def main(
else None,
),
)
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs))
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True))
.with_critic_factory_use_actor()
.with_trainer_stop_callback(AtariStopCallback(task))
)

View File

@@ -44,7 +44,7 @@ from tianshou.policy import (
)
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.common import ActorCritic, BaseActor
from tianshou.utils.string import ToStringMixin
CHECKPOINT_DICT_KEY_MODEL = "model"
@@ -285,6 +285,10 @@ class _ActorCriticMixin:
raise ValueError(
"The options critic_use_actor_module and critic_use_action are mutually exclusive",
)
if not isinstance(actor, BaseActor):
raise ValueError(
f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
)
if envs.get_type().is_discrete():
critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
elif envs.get_type().is_continuous():
@@ -444,17 +448,17 @@ class DQNAgentFactory(OffpolicyAgentFactory):
self,
params: DQNParams,
sampling_config: RLSamplingConfig,
critic_factory: CriticFactory,
actor_factory: ActorFactory,
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config, optim_factory)
self.params = params
self.critic_factory = critic_factory
self.actor_factory = actor_factory
self.optim_factory = optim_factory
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
critic = self.critic_factory.create_module(envs, device, use_action=True)
optim = self.optim_factory.create_optimizer(critic, self.params.lr)
model = self.actor_factory.create_module(envs, device)
optim = self.optim_factory.create_optimizer(model, self.params.lr)
kwargs = self.params.create_kwargs(
ParamTransformerData(
envs=envs,
@@ -467,7 +471,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
# noinspection PyTypeChecker
action_space: gymnasium.spaces.Discrete = envs.get_action_space()
return DQNPolicy(
model=critic,
model=model,
optim=optim,
action_space=action_space,
observation_space=envs.get_observation_space(),

View File

@@ -461,7 +461,7 @@ class PPOExperimentBuilder(
class DQNExperimentBuilder(
RLExperimentBuilder,
_BuilderMixinSingleCriticFactory,
_BuilderMixinActorFactory,
):
def __init__(
self,
@@ -470,7 +470,7 @@ class DQNExperimentBuilder(
sampling_config: RLSamplingConfig,
):
super().__init__(experiment_config, env_factory, sampling_config)
_BuilderMixinSingleCriticFactory.__init__(self)
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
self._params: DQNParams = DQNParams()
def with_dqn_params(self, params: DQNParams) -> Self:
@@ -482,7 +482,7 @@ class DQNExperimentBuilder(
return DQNAgentFactory(
self._params,
self._sampling_config,
self._get_critic_factory(0),
self._get_actor_factory(),
self._get_optim_factory(),
)

View File

@@ -14,11 +14,12 @@ from tianshou.utils.string import ToStringMixin
class ContinuousActorType:
GAUSSIAN = "gaussian"
DETERMINISTIC = "deterministic"
UNSUPPORTED = "unsupported"
class ActorFactory(ToStringMixin, ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
pass
@staticmethod
@@ -67,11 +68,14 @@ class ActorFactoryDefault(ActorFactory):
)
case ContinuousActorType.DETERMINISTIC:
factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes)
case ContinuousActorType.UNSUPPORTED:
raise ValueError("Continuous action spaces are not supported by the algorithm")
case _:
raise ValueError(self.continuous_actor_type)
return factory.create_module(envs, device)
elif env_type == EnvType.DISCRETE:
raise NotImplementedError
factory = ActorFactoryDiscreteNet(self.DEFAULT_HIDDEN_SIZES)
return factory.create_module(envs, device)
else:
raise ValueError(f"{env_type} not supported")

View File

@@ -5,7 +5,7 @@ from torch import nn
from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.utils.net import continuous
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import Net
from tianshou.utils.string import ToStringMixin
@@ -30,16 +30,13 @@ class CriticFactoryDefault(CriticFactory):
factory = CriticFactoryContinuousNet(self.hidden_sizes)
return factory.create_module(envs, device, use_action)
elif env_type == EnvType.DISCRETE:
raise NotImplementedError
factory = CriticFactoryDiscreteNet(self.hidden_sizes)
return factory.create_module(envs, device, use_action)
else:
raise ValueError(f"{env_type} not supported")
class CriticFactoryContinuous(CriticFactory, ABC):
pass
class CriticFactoryContinuousNet(CriticFactoryContinuous):
class CriticFactoryContinuousNet(CriticFactory):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
@@ -56,3 +53,22 @@ class CriticFactoryContinuousNet(CriticFactoryContinuous):
critic = continuous.Critic(net_c, device=device).to(device)
init_linear_orthogonal(critic)
return critic
class CriticFactoryDiscreteNet(CriticFactory):
def __init__(self, hidden_sizes: Sequence[int]):
self.hidden_sizes = hidden_sizes
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0
net_c = Net(
envs.get_observation_shape(),
action_shape=action_shape,
hidden_sizes=self.hidden_sizes,
concat=use_action,
activation=nn.Tanh,
device=device,
)
critic = discrete.Critic(net_c, device=device).to(device)
init_linear_orthogonal(critic)
return critic