Change the high-level DQN interface to expect an actor instead of a critic,
because that is what is functionally required.
This commit is contained in:
parent
1cba589bd4
commit
b54fcd12cb
@ -6,7 +6,7 @@ import os
|
|||||||
from jsonargparse import CLI
|
from jsonargparse import CLI
|
||||||
|
|
||||||
from examples.atari.atari_network import (
|
from examples.atari.atari_network import (
|
||||||
CriticFactoryAtariDQN,
|
ActorFactoryAtariPlainDQN,
|
||||||
FeatureNetFactoryDQN,
|
FeatureNetFactoryDQN,
|
||||||
)
|
)
|
||||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||||
@ -103,7 +103,7 @@ def main(
|
|||||||
target_update_freq=target_update_freq,
|
target_update_freq=target_update_freq,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_critic_factory(CriticFactoryAtariDQN())
|
.with_actor_factory(ActorFactoryAtariPlainDQN())
|
||||||
.with_trainer_epoch_callback_train(TrainEpochCallback())
|
.with_trainer_epoch_callback_train(TrainEpochCallback())
|
||||||
.with_trainer_epoch_callback_test(TestEpochCallback())
|
.with_trainer_epoch_callback_test(TestEpochCallback())
|
||||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||||
|
@ -8,8 +8,6 @@ from torch import nn
|
|||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.highlevel.module.actor import ActorFactory
|
from tianshou.highlevel.module.actor import ActorFactory
|
||||||
from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
|
from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
|
||||||
from tianshou.highlevel.module.critic import CriticFactory
|
|
||||||
from tianshou.utils.net.common import BaseActor
|
|
||||||
from tianshou.utils.net.discrete import Actor, NoisyLinear
|
from tianshou.utils.net.discrete import Actor, NoisyLinear
|
||||||
|
|
||||||
|
|
||||||
@ -227,14 +225,8 @@ class QRDQN(DQN):
|
|||||||
return obs, state
|
return obs, state
|
||||||
|
|
||||||
|
|
||||||
class CriticFactoryAtariDQN(CriticFactory):
|
class ActorFactoryAtariPlainDQN(ActorFactory):
|
||||||
def create_module(
|
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
|
||||||
self,
|
|
||||||
envs: Environments,
|
|
||||||
device: TDevice,
|
|
||||||
use_action: bool,
|
|
||||||
) -> torch.nn.Module:
|
|
||||||
assert use_action
|
|
||||||
return DQN(
|
return DQN(
|
||||||
*envs.get_observation_shape(),
|
*envs.get_observation_shape(),
|
||||||
envs.get_action_shape(),
|
envs.get_action_shape(),
|
||||||
@ -243,17 +235,18 @@ class CriticFactoryAtariDQN(CriticFactory):
|
|||||||
|
|
||||||
|
|
||||||
class ActorFactoryAtariDQN(ActorFactory):
|
class ActorFactoryAtariDQN(ActorFactory):
|
||||||
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool):
|
def __init__(self, hidden_size: int | Sequence[int], scale_obs: bool, features_only: bool):
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.scale_obs = scale_obs
|
self.scale_obs = scale_obs
|
||||||
|
self.features_only = features_only
|
||||||
|
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
def create_module(self, envs: Environments, device: TDevice) -> Actor:
|
||||||
net_cls = scale_obs(DQN) if self.scale_obs else DQN
|
net_cls = scale_obs(DQN) if self.scale_obs else DQN
|
||||||
net = net_cls(
|
net = net_cls(
|
||||||
*envs.get_observation_shape(),
|
*envs.get_observation_shape(),
|
||||||
envs.get_action_shape(),
|
envs.get_action_shape(),
|
||||||
device=device,
|
device=device,
|
||||||
features_only=True,
|
features_only=self.features_only,
|
||||||
output_dim=self.hidden_size,
|
output_dim=self.hidden_size,
|
||||||
layer_init=layer_init,
|
layer_init=layer_init,
|
||||||
)
|
)
|
||||||
|
@ -96,7 +96,7 @@ def main(
|
|||||||
else None,
|
else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs))
|
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True))
|
||||||
.with_critic_factory_use_actor()
|
.with_critic_factory_use_actor()
|
||||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||||
)
|
)
|
||||||
|
@ -44,7 +44,7 @@ from tianshou.policy import (
|
|||||||
)
|
)
|
||||||
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
||||||
from tianshou.utils.net import continuous, discrete
|
from tianshou.utils.net import continuous, discrete
|
||||||
from tianshou.utils.net.common import ActorCritic
|
from tianshou.utils.net.common import ActorCritic, BaseActor
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
CHECKPOINT_DICT_KEY_MODEL = "model"
|
CHECKPOINT_DICT_KEY_MODEL = "model"
|
||||||
@ -285,6 +285,10 @@ class _ActorCriticMixin:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The options critic_use_actor_module and critic_use_action are mutually exclusive",
|
"The options critic_use_actor_module and critic_use_action are mutually exclusive",
|
||||||
)
|
)
|
||||||
|
if not isinstance(actor, BaseActor):
|
||||||
|
raise ValueError(
|
||||||
|
f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
|
||||||
|
)
|
||||||
if envs.get_type().is_discrete():
|
if envs.get_type().is_discrete():
|
||||||
critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
|
critic = discrete.Critic(actor.get_preprocess_net(), device=device).to(device)
|
||||||
elif envs.get_type().is_continuous():
|
elif envs.get_type().is_continuous():
|
||||||
@ -444,17 +448,17 @@ class DQNAgentFactory(OffpolicyAgentFactory):
|
|||||||
self,
|
self,
|
||||||
params: DQNParams,
|
params: DQNParams,
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: RLSamplingConfig,
|
||||||
critic_factory: CriticFactory,
|
actor_factory: ActorFactory,
|
||||||
optim_factory: OptimizerFactory,
|
optim_factory: OptimizerFactory,
|
||||||
):
|
):
|
||||||
super().__init__(sampling_config, optim_factory)
|
super().__init__(sampling_config, optim_factory)
|
||||||
self.params = params
|
self.params = params
|
||||||
self.critic_factory = critic_factory
|
self.actor_factory = actor_factory
|
||||||
self.optim_factory = optim_factory
|
self.optim_factory = optim_factory
|
||||||
|
|
||||||
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
|
||||||
critic = self.critic_factory.create_module(envs, device, use_action=True)
|
model = self.actor_factory.create_module(envs, device)
|
||||||
optim = self.optim_factory.create_optimizer(critic, self.params.lr)
|
optim = self.optim_factory.create_optimizer(model, self.params.lr)
|
||||||
kwargs = self.params.create_kwargs(
|
kwargs = self.params.create_kwargs(
|
||||||
ParamTransformerData(
|
ParamTransformerData(
|
||||||
envs=envs,
|
envs=envs,
|
||||||
@ -467,7 +471,7 @@ class DQNAgentFactory(OffpolicyAgentFactory):
|
|||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
action_space: gymnasium.spaces.Discrete = envs.get_action_space()
|
action_space: gymnasium.spaces.Discrete = envs.get_action_space()
|
||||||
return DQNPolicy(
|
return DQNPolicy(
|
||||||
model=critic,
|
model=model,
|
||||||
optim=optim,
|
optim=optim,
|
||||||
action_space=action_space,
|
action_space=action_space,
|
||||||
observation_space=envs.get_observation_space(),
|
observation_space=envs.get_observation_space(),
|
||||||
|
@ -461,7 +461,7 @@ class PPOExperimentBuilder(
|
|||||||
|
|
||||||
class DQNExperimentBuilder(
|
class DQNExperimentBuilder(
|
||||||
RLExperimentBuilder,
|
RLExperimentBuilder,
|
||||||
_BuilderMixinSingleCriticFactory,
|
_BuilderMixinActorFactory,
|
||||||
):
|
):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -470,7 +470,7 @@ class DQNExperimentBuilder(
|
|||||||
sampling_config: RLSamplingConfig,
|
sampling_config: RLSamplingConfig,
|
||||||
):
|
):
|
||||||
super().__init__(experiment_config, env_factory, sampling_config)
|
super().__init__(experiment_config, env_factory, sampling_config)
|
||||||
_BuilderMixinSingleCriticFactory.__init__(self)
|
_BuilderMixinActorFactory.__init__(self, ContinuousActorType.UNSUPPORTED)
|
||||||
self._params: DQNParams = DQNParams()
|
self._params: DQNParams = DQNParams()
|
||||||
|
|
||||||
def with_dqn_params(self, params: DQNParams) -> Self:
|
def with_dqn_params(self, params: DQNParams) -> Self:
|
||||||
@ -482,7 +482,7 @@ class DQNExperimentBuilder(
|
|||||||
return DQNAgentFactory(
|
return DQNAgentFactory(
|
||||||
self._params,
|
self._params,
|
||||||
self._sampling_config,
|
self._sampling_config,
|
||||||
self._get_critic_factory(0),
|
self._get_actor_factory(),
|
||||||
self._get_optim_factory(),
|
self._get_optim_factory(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,11 +14,12 @@ from tianshou.utils.string import ToStringMixin
|
|||||||
class ContinuousActorType:
|
class ContinuousActorType:
|
||||||
GAUSSIAN = "gaussian"
|
GAUSSIAN = "gaussian"
|
||||||
DETERMINISTIC = "deterministic"
|
DETERMINISTIC = "deterministic"
|
||||||
|
UNSUPPORTED = "unsupported"
|
||||||
|
|
||||||
|
|
||||||
class ActorFactory(ToStringMixin, ABC):
|
class ActorFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -67,11 +68,14 @@ class ActorFactoryDefault(ActorFactory):
|
|||||||
)
|
)
|
||||||
case ContinuousActorType.DETERMINISTIC:
|
case ContinuousActorType.DETERMINISTIC:
|
||||||
factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes)
|
factory = ActorFactoryContinuousDeterministicNet(self.hidden_sizes)
|
||||||
|
case ContinuousActorType.UNSUPPORTED:
|
||||||
|
raise ValueError("Continuous action spaces are not supported by the algorithm")
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(self.continuous_actor_type)
|
raise ValueError(self.continuous_actor_type)
|
||||||
return factory.create_module(envs, device)
|
return factory.create_module(envs, device)
|
||||||
elif env_type == EnvType.DISCRETE:
|
elif env_type == EnvType.DISCRETE:
|
||||||
raise NotImplementedError
|
factory = ActorFactoryDiscreteNet(self.DEFAULT_HIDDEN_SIZES)
|
||||||
|
return factory.create_module(envs, device)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{env_type} not supported")
|
raise ValueError(f"{env_type} not supported")
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from torch import nn
|
|||||||
|
|
||||||
from tianshou.highlevel.env import Environments, EnvType
|
from tianshou.highlevel.env import Environments, EnvType
|
||||||
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
|
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
|
||||||
from tianshou.utils.net import continuous
|
from tianshou.utils.net import continuous, discrete
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
from tianshou.utils.string import ToStringMixin
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
@ -30,16 +30,13 @@ class CriticFactoryDefault(CriticFactory):
|
|||||||
factory = CriticFactoryContinuousNet(self.hidden_sizes)
|
factory = CriticFactoryContinuousNet(self.hidden_sizes)
|
||||||
return factory.create_module(envs, device, use_action)
|
return factory.create_module(envs, device, use_action)
|
||||||
elif env_type == EnvType.DISCRETE:
|
elif env_type == EnvType.DISCRETE:
|
||||||
raise NotImplementedError
|
factory = CriticFactoryDiscreteNet(self.hidden_sizes)
|
||||||
|
return factory.create_module(envs, device, use_action)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"{env_type} not supported")
|
raise ValueError(f"{env_type} not supported")
|
||||||
|
|
||||||
|
|
||||||
class CriticFactoryContinuous(CriticFactory, ABC):
|
class CriticFactoryContinuousNet(CriticFactory):
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class CriticFactoryContinuousNet(CriticFactoryContinuous):
|
|
||||||
def __init__(self, hidden_sizes: Sequence[int]):
|
def __init__(self, hidden_sizes: Sequence[int]):
|
||||||
self.hidden_sizes = hidden_sizes
|
self.hidden_sizes = hidden_sizes
|
||||||
|
|
||||||
@ -56,3 +53,22 @@ class CriticFactoryContinuousNet(CriticFactoryContinuous):
|
|||||||
critic = continuous.Critic(net_c, device=device).to(device)
|
critic = continuous.Critic(net_c, device=device).to(device)
|
||||||
init_linear_orthogonal(critic)
|
init_linear_orthogonal(critic)
|
||||||
return critic
|
return critic
|
||||||
|
|
||||||
|
|
||||||
|
class CriticFactoryDiscreteNet(CriticFactory):
    """Critic factory that wraps a ``Net`` in a ``discrete.Critic``."""

    def __init__(self, hidden_sizes: Sequence[int]):
        """Remember the hidden layer sizes for the critic's underlying ``Net``.

        :param hidden_sizes: value forwarded as ``hidden_sizes`` to ``Net``.
        """
        self.hidden_sizes = hidden_sizes

    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
        """Build the critic module for the given environments.

        :param envs: source of the observation shape and, if requested, the action shape.
        :param device: device handed to ``Net`` and ``discrete.Critic``; the resulting
            critic is also moved to it.
        :param use_action: if True, the action shape from ``envs`` is passed to ``Net``
            and ``concat`` is enabled; otherwise an action shape of 0 is used.
        :return: the ``discrete.Critic`` instance after ``init_linear_orthogonal`` has
            been applied to it.
        """
        # Resolve the action shape first; 0 is passed when the action is not used.
        if use_action:
            action_shape = envs.get_action_shape()
        else:
            action_shape = 0

        preprocess_net = Net(
            envs.get_observation_shape(),
            action_shape=action_shape,
            hidden_sizes=self.hidden_sizes,
            concat=use_action,
            activation=nn.Tanh,
            device=device,
        )

        module = discrete.Critic(preprocess_net, device=device)
        module = module.to(device)
        init_linear_orthogonal(module)
        return module
|
||||||
|
Loading…
x
Reference in New Issue
Block a user