Support IQN in high-level API

* Add example atari_iqn_hl
* Factor out trainer callbacks to new module atari_callbacks
* Extract base class for DQN-based agent factories
* Improve module factory interface design, achieving greater generality
Dominik Jain 2023-10-11 15:31:38 +02:00
parent 213e08a846
commit a8a367c42d
12 changed files with 309 additions and 41 deletions
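
For orientation, the new IQN support follows the same builder pattern as the existing high-level DQN support. A condensed usage sketch (parameter values are illustrative; env_factory stands for any EnvFactory instance, e.g. the AtariEnvFactory used in the full example added below):

from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import ExperimentConfig, IQNExperimentBuilder
from tianshou.highlevel.params.policy_params import IQNParams

# Condensed sketch only; env_factory is assumed to be an existing EnvFactory.
experiment = (
    IQNExperimentBuilder(env_factory, ExperimentConfig(), SamplingConfig(num_epochs=100))
    .with_iqn_params(IQNParams(discount_factor=0.99, lr=1e-4))
    .build()
)
experiment.run("iqn-sketch")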

View File

@ -0,0 +1,33 @@
from tianshou.highlevel.trainer import (
    TrainerEpochCallbackTest,
    TrainerEpochCallbackTrain,
    TrainingContext,
)
from tianshou.policy import DQNPolicy


class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
    def __init__(self, eps_test: float):
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy: DQNPolicy = context.policy
        policy.set_eps(self.eps_test)


class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
    def __init__(self, eps_train: float, eps_train_final: float):
        self.eps_train = eps_train
        self.eps_train_final = eps_train_final

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy: DQNPolicy = context.policy
        logger = context.logger.logger
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
        else:
            eps = self.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})
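
These callbacks are wired into an experiment via the builder's trainer-callback setters; a minimal sketch (builder and the eps_* values are assumed to exist, as in the example scripts of this commit):

# Sketch: registering the factored-out callbacks with an experiment builder.
builder.with_trainer_epoch_callback_train(
    TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
).with_trainer_epoch_callback_test(
    TestEpochCallbackDQNSetEps(eps_test),
)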

View File

@ -1,13 +1,12 @@
#!/usr/bin/env python3
import datetime
import os
from jsonargparse import CLI
from examples.atari.atari_network import (
ActorFactoryAtariPlainDQN,
FeatureNetFactoryDQN,
IntermediateModuleFactoryAtariDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
@ -26,6 +25,7 @@ from tianshou.highlevel.trainer import (
)
from tianshou.policy import DQNPolicy
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
def main(
@ -53,8 +53,7 @@ def main(
icm_reward_scale: float = 0.01,
icm_forward_loss_weight: float = 0.2,
):
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
log_name = os.path.join(task, "dqn", str(experiment_config.seed), datetime_tag())
sampling_config = SamplingConfig(
num_epochs=epoch,
@ -115,7 +114,7 @@ def main(
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity(
FeatureNetFactoryDQN(),
IntermediateModuleFactoryAtariDQN(net_only=True),
[512],
lr,
icm_lr_scale,

View File

@ -0,0 +1,105 @@
#!/usr/bin/env python3

import os
from collections.abc import Sequence

from jsonargparse import CLI

from examples.atari.atari_callbacks import (
    TestEpochCallbackDQNSetEps,
    TrainEpochCallbackNatureDQNEpsLinearDecay,
)
from examples.atari.atari_network import (
    IntermediateModuleFactoryAtariDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
    ExperimentConfig,
    IQNExperimentBuilder,
)
from tianshou.highlevel.params.policy_params import IQNParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
    experiment_config: ExperimentConfig,
    task: str = "PongNoFrameskip-v4",
    scale_obs: int = 0,
    eps_test: float = 0.005,
    eps_train: float = 1.0,
    eps_train_final: float = 0.05,
    buffer_size: int = 100000,
    lr: float = 0.0001,
    gamma: float = 0.99,
    sample_size: int = 32,
    online_sample_size: int = 8,
    target_sample_size: int = 8,
    num_cosines: int = 64,
    hidden_sizes: Sequence[int] = (512,),
    n_step: int = 3,
    target_update_freq: int = 500,
    epoch: int = 100,
    step_per_epoch: int = 100000,
    step_per_collect: int = 10,
    update_per_step: float = 0.1,
    batch_size: int = 32,
    training_num: int = 10,
    test_num: int = 10,
    frames_stack: int = 4,
    save_buffer_name: str | None = None,  # TODO support?
):
    log_name = os.path.join(task, "iqn", str(experiment_config.seed), datetime_tag())

    sampling_config = SamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        update_per_step=update_per_step,
        repeat_per_collect=None,
        replay_buffer_stack_num=frames_stack,
        replay_buffer_ignore_obs_next=True,
        replay_buffer_save_only_last_obs=True,
    )

    env_factory = AtariEnvFactory(
        task,
        experiment_config.seed,
        sampling_config,
        frames_stack,
        scale=scale_obs,
    )

    experiment = (
        IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_iqn_params(
            IQNParams(
                discount_factor=gamma,
                estimation_step=n_step,
                lr=lr,
                sample_size=sample_size,
                online_sample_size=online_sample_size,
                target_update_freq=target_update_freq,
                target_sample_size=target_sample_size,
                hidden_sizes=hidden_sizes,
                num_cosines=num_cosines,
            ),
        )
        .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(net_only=False))
        .with_trainer_epoch_callback_train(
            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
        )
        .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
        .with_trainer_stop_callback(AtariStopCallback(task))
        .build()
    )
    experiment.run(log_name)


if __name__ == "__main__":
    logging.run_main(lambda: CLI(main))

View File

@ -7,7 +7,11 @@ from torch import nn
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ActorFactory
from tianshou.highlevel.module.core import Module, ModuleFactory, TDevice
from tianshou.highlevel.module.core import (
IntermediateModule,
IntermediateModuleFactory,
TDevice,
)
from tianshou.utils.net.discrete import Actor, NoisyLinear
@ -253,12 +257,15 @@ class ActorFactoryAtariDQN(ActorFactory):
return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
class FeatureNetFactoryDQN(ModuleFactory):
def create_module(self, envs: Environments, device: TDevice) -> Module:
class IntermediateModuleFactoryAtariDQN(IntermediateModuleFactory):
def __init__(self, net_only: bool):
self.net_only = net_only
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
dqn = DQN(
*envs.get_observation_shape(),
envs.get_action_shape(),
device,
device=device,
features_only=True,
)
return Module(dqn.net, dqn.output_dim)
return IntermediateModule(dqn.net if self.net_only else dqn, dqn.output_dim)
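
The net_only flag selects between the two consumers of this factory in the examples: the intrinsic-curiosity policy wrapper only needs the convolutional feature net, while the IQN preprocess network needs the full DQN module. A sketch of both uses, mirroring the example scripts in this commit:

# Feature net for the intrinsic-curiosity policy wrapper (as in atari_dqn_hl.py):
icm_feature_factory = IntermediateModuleFactoryAtariDQN(net_only=True)

# Full DQN torso as the IQN preprocess network (as in atari_iqn_hl.py):
iqn_preprocess_factory = IntermediateModuleFactoryAtariDQN(net_only=False)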

View File

@ -8,7 +8,7 @@ from jsonargparse import CLI
from examples.atari.atari_network import (
ActorFactoryAtariDQN,
FeatureNetFactoryDQN,
IntermediateModuleFactoryAtariDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
@ -103,7 +103,7 @@ def main(
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity(
FeatureNetFactoryDQN(),
IntermediateModuleFactoryAtariDQN(net_only=True),
[hidden_sizes],
lr,
icm_lr_scale,

View File

@ -6,7 +6,7 @@ from jsonargparse import CLI
from examples.atari.atari_network import (
ActorFactoryAtariDQN,
FeatureNetFactoryDQN,
IntermediateModuleFactoryAtariDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
@ -26,7 +26,7 @@ from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4",
scale_obs: bool = False,
scale_obs: int = 0,
buffer_size: int = 100000,
actor_lr: float = 1e-5,
critic_lr: float = 1e-5,
@ -67,7 +67,9 @@ def main(
replay_buffer_save_only_last_obs=True,
)
env_factory = AtariEnvFactory(task, experiment_config.seed, sampling_config, frames_stack)
env_factory = AtariEnvFactory(
task, experiment_config.seed, sampling_config, frames_stack, scale=scale_obs,
)
builder = (
DiscreteSACExperimentBuilder(env_factory, experiment_config, sampling_config)
@ -82,14 +84,14 @@ def main(
estimation_step=n_step,
),
)
.with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs, features_only=True))
.with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True))
.with_common_critic_factory_use_actor()
.with_trainer_stop_callback(AtariStopCallback(task))
)
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(
PolicyWrapperFactoryIntrinsicCuriosity(
FeatureNetFactoryDQN(),
IntermediateModuleFactoryAtariDQN(net_only=True),
[hidden_size],
actor_lr,
icm_lr_scale,

View File

@ -13,7 +13,10 @@ from tianshou.highlevel.logger import Logger
from tianshou.highlevel.module.actor import (
ActorFactory,
)
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.core import (
ModuleFactory,
TDevice,
)
from tianshou.highlevel.module.critic import CriticEnsembleFactory, CriticFactory
from tianshou.highlevel.module.module_opt import (
ActorCriticModuleOpt,
@ -24,6 +27,7 @@ from tianshou.highlevel.params.policy_params import (
DDPGParams,
DiscreteSACParams,
DQNParams,
IQNParams,
NPGParams,
Params,
ParamsMixinActorAndDualCritics,
@ -44,6 +48,7 @@ from tianshou.policy import (
DDPGPolicy,
DiscreteSACPolicy,
DQNPolicy,
IQNPolicy,
NPGPolicy,
PGPolicy,
PPOPolicy,
@ -61,6 +66,9 @@ CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
TParams = TypeVar("TParams", bound=Params)
TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
TDiscreteCriticOnlyParams = TypeVar(
"TDiscreteCriticOnlyParams", bound=ParamsMixinLearningRateWithScheduler,
)
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
@ -394,21 +402,27 @@ class TRPOAgentFactory(ActorCriticAgentFactory[TRPOParams, TRPOPolicy]):
return TRPOPolicy
class DQNAgentFactory(OffpolicyAgentFactory):
class DiscreteCriticOnlyAgentFactory(
OffpolicyAgentFactory, Generic[TDiscreteCriticOnlyParams, TPolicy],
):
def __init__(
self,
params: DQNParams,
params: TDiscreteCriticOnlyParams,
sampling_config: SamplingConfig,
actor_factory: ActorFactory,
model_factory: ModuleFactory,
optim_factory: OptimizerFactory,
):
super().__init__(sampling_config, optim_factory)
self.params = params
self.actor_factory = actor_factory
self.model_factory = model_factory
self.optim_factory = optim_factory
def _create_policy(self, envs: Environments, device: TDevice) -> BasePolicy:
model = self.actor_factory.create_module(envs, device)
@abstractmethod
def _get_policy_class(self) -> type[TPolicy]:
pass
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
model = self.model_factory.create_module(envs, device)
optim = self.optim_factory.create_optimizer(model, self.params.lr)
kwargs = self.params.create_kwargs(
ParamTransformerData(
@ -420,7 +434,8 @@ class DQNAgentFactory(OffpolicyAgentFactory):
)
envs.get_type().assert_discrete(self)
action_space = cast(gymnasium.spaces.Discrete, envs.get_action_space())
return DQNPolicy(
policy_class = self._get_policy_class()
return policy_class(
model=model,
optim=optim,
action_space=action_space,
@ -429,6 +444,16 @@ class DQNAgentFactory(OffpolicyAgentFactory):
)
class DQNAgentFactory(DiscreteCriticOnlyAgentFactory[DQNParams, DQNPolicy]):
def _get_policy_class(self) -> type[DQNPolicy]:
return DQNPolicy
class IQNAgentFactory(DiscreteCriticOnlyAgentFactory[IQNParams, IQNPolicy]):
def _get_policy_class(self) -> type[IQNPolicy]:
return IQNPolicy
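
With the extracted base class, adding a further discrete critic-only algorithm mostly amounts to declaring its policy class. A hypothetical sketch (neither QRDQNParams nor QRDQNAgentFactory are part of this commit; they are named only to illustrate the extension pattern):

from dataclasses import dataclass

from tianshou.policy import QRDQNPolicy

# Hypothetical sketch only: a params class would have to be added analogously to IQNParams.
@dataclass
class QRDQNParams(DQNParams):
    num_quantiles: int = 200


class QRDQNAgentFactory(DiscreteCriticOnlyAgentFactory[QRDQNParams, QRDQNPolicy]):
    def _get_policy_class(self) -> type[QRDQNPolicy]:
        return QRDQNPolicy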
class DDPGAgentFactory(OffpolicyAgentFactory):
def __init__(
self,

View File

@ -15,6 +15,7 @@ from tianshou.highlevel.agent import (
DDPGAgentFactory,
DiscreteSACAgentFactory,
DQNAgentFactory,
IQNAgentFactory,
NPGAgentFactory,
PGAgentFactory,
PPOAgentFactory,
@ -33,6 +34,11 @@ from tianshou.highlevel.module.actor import (
ActorFuture,
ActorFutureProviderProtocol,
ContinuousActorType,
IntermediateModuleFactoryFromActorFactory,
)
from tianshou.highlevel.module.core import (
ImplicitQuantileNetworkFactory,
IntermediateModuleFactory,
)
from tianshou.highlevel.module.critic import (
CriticEnsembleFactory,
@ -47,6 +53,7 @@ from tianshou.highlevel.params.policy_params import (
DDPGParams,
DiscreteSACParams,
DQNParams,
IQNParams,
NPGParams,
PGParams,
PPOParams,
@ -641,6 +648,41 @@ class DQNExperimentBuilder(
)
class IQNExperimentBuilder(ExperimentBuilder):
def __init__(
self,
env_factory: EnvFactory,
experiment_config: ExperimentConfig | None = None,
sampling_config: SamplingConfig | None = None,
):
super().__init__(env_factory, experiment_config, sampling_config)
self._params: IQNParams = IQNParams()
self._preprocess_network_factory = IntermediateModuleFactoryFromActorFactory(
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
)
def with_iqn_params(self, params: IQNParams) -> Self:
self._params = params
return self
def with_preprocess_network_factory(self, module_factory: IntermediateModuleFactory) -> Self:
self._preprocess_network_factory = module_factory
return self
def _create_agent_factory(self) -> AgentFactory:
model_factory = ImplicitQuantileNetworkFactory(
self._preprocess_network_factory,
hidden_sizes=self._params.hidden_sizes,
num_cosines=self._params.num_cosines,
)
return IQNAgentFactory(
self._params,
self._sampling_config,
model_factory,
self._get_optim_factory(),
)
class DDPGExperimentBuilder(
ExperimentBuilder,
_BuilderMixinActorFactory_ContinuousDeterministic,

View File

@ -8,7 +8,13 @@ import torch
from torch import nn
from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.highlevel.module.core import (
IntermediateModule,
IntermediateModuleFactory,
ModuleFactory,
TDevice,
init_linear_orthogonal,
)
from tianshou.highlevel.module.module_opt import ModuleOpt
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net import continuous, discrete
@ -34,7 +40,7 @@ class ActorFutureProviderProtocol(Protocol):
pass
class ActorFactory(ToStringMixin, ABC):
class ActorFactory(ModuleFactory, ToStringMixin, ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:
pass
@ -212,3 +218,13 @@ class ActorFactoryTransientStorageDecorator(ActorFactory):
module = self.actor_factory.create_module(envs, device)
self._actor_future.actor = module
return module
class IntermediateModuleFactoryFromActorFactory(IntermediateModuleFactory):
def __init__(self, actor_factory: ActorFactory):
self.actor_factory = actor_factory
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
actor = self.actor_factory.create_module(envs, device)
assert isinstance(actor, BaseActor)
return IntermediateModule(actor, actor.get_output_dim())
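
This adapter is what lets IQNExperimentBuilder fall back to a default preprocess network: any actor factory that yields a BaseActor can serve as the provider. A sketch of using it explicitly (assuming the Atari actor factory from the examples, with the constructor arguments shown in the example scripts):

from examples.atari.atari_network import ActorFactoryAtariDQN
from tianshou.highlevel.module.actor import IntermediateModuleFactoryFromActorFactory

# Sketch: wrap an existing actor factory so it can provide the IQN preprocess network;
# the wrapped factory must produce a BaseActor, as the assertion above requires.
preprocess_factory = IntermediateModuleFactoryFromActorFactory(
    ActorFactoryAtariDQN(512, scale_obs=False, features_only=True),
)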

View File

@ -7,7 +7,7 @@ import numpy as np
import torch
from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import ImplicitQuantileNetwork
from tianshou.utils.string import ToStringMixin
TDevice: TypeAlias = str | torch.device
@ -24,24 +24,45 @@ def init_linear_orthogonal(module: torch.nn.Module) -> None:
torch.nn.init.zeros_(m.bias)
class ModuleFactory(ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
pass
@dataclass
class Module:
class IntermediateModule:
module: torch.nn.Module
output_dim: int
class ModuleFactory(ToStringMixin, ABC):
class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> Module:
def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
pass
def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
return self.create_intermediate_module(envs, device).module
class ModuleFactoryNet(
ModuleFactory,
): # TODO This is unused and broken; use it in ActorFactory* and so on?
def __init__(self, hidden_sizes: int | Sequence[int]):
class ImplicitQuantileNetworkFactory(ModuleFactory, ToStringMixin):
def __init__(
self,
preprocess_net_factory: IntermediateModuleFactory,
hidden_sizes: Sequence[int] = (),
num_cosines: int = 64,
):
self.preprocess_net_factory = preprocess_net_factory
self.hidden_sizes = hidden_sizes
self.num_cosines = num_cosines
def create_module(self, envs: Environments, device: TDevice) -> Module:
module = Net(envs.get_observation_shape())
return Module(module, module.output_dim)
def create_module(self, envs: Environments, device: TDevice) -> ImplicitQuantileNetwork:
preprocess_net = self.preprocess_net_factory.create_intermediate_module(envs, device)
return ImplicitQuantileNetwork(
preprocess_net=preprocess_net.module,
action_shape=envs.get_action_shape(),
hidden_sizes=self.hidden_sizes,
num_cosines=self.num_cosines,
preprocess_net_output_dim=preprocess_net.output_dim,
device=device,
).to(device)
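
For reference, this is the factory that IQNExperimentBuilder hands to the IQN agent factory. Used directly, it builds the full network; a sketch (envs and device are assumed to come from an existing experiment setup):

from examples.atari.atari_network import IntermediateModuleFactoryAtariDQN

# Sketch: building the IQN network by hand, mirroring what IQNAgentFactory does internally.
factory = ImplicitQuantileNetworkFactory(
    IntermediateModuleFactoryAtariDQN(net_only=False),
    hidden_sizes=(512,),
    num_cosines=64,
)
iqn_net = factory.create_module(envs, device)  # an ImplicitQuantileNetwork moved to `device`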

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import asdict, dataclass
from typing import Any, Literal, Protocol
@ -388,6 +389,23 @@ class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
return transformers
@dataclass
class IQNParams(DQNParams):
sample_size: int = 32
online_sample_size: int = 8
target_sample_size: int = 8
num_quantiles: int = 200
hidden_sizes: Sequence[int] = ()
"""hidden dimensions to use in the IQN network"""
num_cosines: int = 64
"""number of cosines to use in the IQN network"""
def _get_param_transformers(self) -> list[ParamTransformer]:
transformers = super()._get_param_transformers()
transformers.append(ParamTransformerDrop("hidden_sizes", "num_cosines"))
return transformers
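
The split of fields mirrors the module factory design: sample_size, online_sample_size and target_sample_size are forwarded to IQNPolicy, while hidden_sizes and num_cosines are consumed by ImplicitQuantileNetworkFactory when the network is built and are therefore dropped from the policy kwargs by the transformer above. A minimal sketch (values illustrative):

# Sketch: hidden_sizes/num_cosines configure the network factory, not IQNPolicy itself.
params = IQNParams(
    sample_size=32,
    online_sample_size=8,
    target_sample_size=8,
    hidden_sizes=(512,),
    num_cosines=64,
)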
@dataclass
class DDPGParams(Params, ParamsMixinActorAndCritic):
tau: float = 0.005

View File

@ -3,7 +3,7 @@ from collections.abc import Sequence
from typing import Generic, TypeVar
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.highlevel.module.core import IntermediateModuleFactory, TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
@ -31,7 +31,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
):
def __init__(
self,
feature_net_factory: ModuleFactory,
feature_net_factory: IntermediateModuleFactory,
hidden_sizes: Sequence[int],
lr: float,
lr_scale: float,
@ -52,7 +52,7 @@ class PolicyWrapperFactoryIntrinsicCuriosity(
optim_factory: OptimizerFactory,
device: TDevice,
) -> ICMPolicy:
feature_net = self.feature_net_factory.create_module(envs, device)
feature_net = self.feature_net_factory.create_intermediate_module(envs, device)
action_dim = envs.get_action_shape()
if not isinstance(action_dim, int):
raise ValueError(f"Environment action shape must be an integer, got {action_dim}")