Make mypy happy

This commit is contained in:
Dominik Jain 2023-10-13 12:25:28 +02:00
parent 76e870207d
commit 023b33c917
11 changed files with 85 additions and 51 deletions

View File

@ -1,4 +1,5 @@
import logging import logging
import typing
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar, cast from typing import Any, Generic, TypeVar, cast
@ -63,11 +64,17 @@ from tianshou.utils.string import ToStringMixin
CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
TParams = TypeVar("TParams", bound=Params) TParams = TypeVar("TParams", bound=Params)
TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler) TActorCriticParams = TypeVar(
TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics) "TActorCriticParams",
bound=Params | ParamsMixinLearningRateWithScheduler,
)
TActorDualCriticsParams = TypeVar(
"TActorDualCriticsParams",
bound=Params | ParamsMixinActorAndDualCritics,
)
TDiscreteCriticOnlyParams = TypeVar( TDiscreteCriticOnlyParams = TypeVar(
"TDiscreteCriticOnlyParams", "TDiscreteCriticOnlyParams",
bound=ParamsMixinLearningRateWithScheduler, bound=Params | ParamsMixinLearningRateWithScheduler,
) )
TPolicy = TypeVar("TPolicy", bound=BasePolicy) TPolicy = TypeVar("TPolicy", bound=BasePolicy)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -321,6 +328,7 @@ class ActorCriticAgentFactory(
optim = self.optim_factory.create_optimizer(actor_critic, lr) optim = self.optim_factory.create_optimizer(actor_critic, lr)
return ActorCriticModuleOpt(actor_critic, optim) return ActorCriticModuleOpt(actor_critic, optim)
@typing.no_type_check
def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]: def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr) actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
kwargs = self.params.create_kwargs( kwargs = self.params.create_kwargs(
@ -382,6 +390,7 @@ class DiscreteCriticOnlyAgentFactory(
def _get_policy_class(self) -> type[TPolicy]: def _get_policy_class(self) -> type[TPolicy]:
pass pass
@typing.no_type_check
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
model = self.model_factory.create_module(envs, device) model = self.model_factory.create_module(envs, device)
optim = self.optim_factory.create_optimizer(model, self.params.lr) optim = self.optim_factory.create_optimizer(model, self.params.lr)
@ -548,6 +557,7 @@ class ActorDualCriticsAgentFactory(
def _get_critic_use_action(envs: Environments) -> bool: def _get_critic_use_action(envs: Environments) -> bool:
return envs.get_type().is_continuous() return envs.get_type().is_continuous()
@typing.no_type_check
def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy: def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
actor = self.actor_factory.create_module_opt( actor = self.actor_factory.create_module_opt(
envs, envs,

View File

@ -37,7 +37,7 @@ class Environments(ToStringMixin, ABC):
self.env = env self.env = env
self.train_envs = train_envs self.train_envs = train_envs
self.test_envs = test_envs self.test_envs = test_envs
self.persistence = [] self.persistence: Sequence[Persistence] = []
def _tostring_includes(self) -> list[str]: def _tostring_includes(self) -> list[str]:
return [] return []
@ -51,7 +51,7 @@ class Environments(ToStringMixin, ABC):
"state_shape": self.get_observation_shape(), "state_shape": self.get_observation_shape(),
} }
def set_persistence(self, *p: Persistence): def set_persistence(self, *p: Persistence) -> None:
self.persistence = p self.persistence = p
@abstractmethod @abstractmethod

View File

@ -27,7 +27,7 @@ from tianshou.highlevel.agent import (
) )
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import EnvFactory, Environments from tianshou.highlevel.env import EnvFactory, Environments
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory, TLogger
from tianshou.highlevel.module.actor import ( from tianshou.highlevel.module.actor import (
ActorFactory, ActorFactory,
ActorFactoryDefault, ActorFactoryDefault,
@ -142,13 +142,18 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
self.env_config = env_config self.env_config = env_config
@classmethod @classmethod
def from_directory(cls, directory: str) -> Self: def from_directory(cls, directory: str, restore_policy: bool = True) -> Self:
"""Restores an experiment from a previously stored pickle. """Restores an experiment from a previously stored pickle.
:param directory: persistence directory of a previous run, in which a pickled experiment is found :param directory: persistence directory of a previous run, in which a pickled experiment is found
:param restore_policy: whether the experiment shall be configured to restore the policy that was
persisted in the given directory
""" """
with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f: with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f:
return pickle.load(f) experiment: Experiment = pickle.load(f)
if restore_policy:
experiment.config.policy_restore_directory = directory
return experiment
def _set_seed(self) -> None: def _set_seed(self) -> None:
seed = self.config.seed seed = self.config.seed
@ -159,7 +164,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
def _build_config_dict(self) -> dict: def _build_config_dict(self) -> dict:
return {"experiment": self.pprints()} return {"experiment": self.pprints()}
def save(self, directory: str): def save(self, directory: str) -> None:
path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME) path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME)
log.info( log.info(
f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')", f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')",
@ -187,7 +192,8 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
os.makedirs(persistence_dir, exist_ok=True) os.makedirs(persistence_dir, exist_ok=True)
with logging.FileLoggerContext( with logging.FileLoggerContext(
os.path.join(persistence_dir, self.LOG_FILENAME), enabled=use_persistence, os.path.join(persistence_dir, self.LOG_FILENAME),
enabled=use_persistence,
): ):
# log initial information # log initial information
log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}") log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}")
@ -209,6 +215,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
# initialize logger # initialize logger
full_config = self._build_config_dict() full_config = self._build_config_dict()
full_config.update(envs.info()) full_config.update(envs.info())
logger: TLogger
if use_persistence: if use_persistence:
logger = self.logger_factory.create_logger( logger = self.logger_factory.create_logger(
log_dir=persistence_dir, log_dir=persistence_dir,
@ -460,7 +467,7 @@ class _BuilderMixinCriticsFactory:
class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory): class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
def __init__(self, actor_future_provider: ActorFutureProviderProtocol = None) -> None: def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
super().__init__(1, actor_future_provider) super().__init__(1, actor_future_provider)
def with_critic_factory(self, critic_factory: CriticFactory) -> Self: def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
@ -553,7 +560,7 @@ class _BuilderMixinCriticEnsembleFactory:
self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes) self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
return self return self
def _get_critic_ensemble_factory(self): def _get_critic_ensemble_factory(self) -> CriticEnsembleFactory:
if self.critic_ensemble_factory is None: if self.critic_ensemble_factory is None:
return CriticEnsembleFactoryDefault() return CriticEnsembleFactoryDefault()
else: else:
@ -745,8 +752,10 @@ class IQNExperimentBuilder(ExperimentBuilder):
): ):
super().__init__(env_factory, experiment_config, sampling_config) super().__init__(env_factory, experiment_config, sampling_config)
self._params: IQNParams = IQNParams() self._params: IQNParams = IQNParams()
self._preprocess_network_factory = IntermediateModuleFactoryFromActorFactory( self._preprocess_network_factory: IntermediateModuleFactory = (
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED), IntermediateModuleFactoryFromActorFactory(
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
)
) )
def with_iqn_params(self, params: IQNParams) -> Self: def with_iqn_params(self, params: IQNParams) -> Self:

View File

@ -4,16 +4,20 @@ from typing import Literal, TypeAlias
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
TLogger: TypeAlias = TensorboardLogger | WandbLogger TLogger: TypeAlias = BaseLogger
class LoggerFactory(ToStringMixin, ABC): class LoggerFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_logger( def create_logger(
self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict, self,
log_dir: str,
experiment_name: str,
run_id: str | None,
config_dict: dict,
) -> TLogger: ) -> TLogger:
""":param log_dir: path to the directory in which log data is to be stored """:param log_dir: path to the directory in which log data is to be stored
:param experiment_name: the name of the job, which may contain os.path.sep :param experiment_name: the name of the job, which may contain os.path.sep
@ -35,7 +39,11 @@ class DefaultLoggerFactory(LoggerFactory):
self.wandb_project = wandb_project self.wandb_project = wandb_project
def create_logger( def create_logger(
self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict, self,
log_dir: str,
experiment_name: str,
run_id: str | None,
config_dict: dict,
) -> TLogger: ) -> TLogger:
writer = SummaryWriter(log_dir) writer = SummaryWriter(log_dir)
writer.add_text( writer.add_text(
@ -48,18 +56,18 @@ class DefaultLoggerFactory(LoggerFactory):
), ),
), ),
) )
logger: TLogger match self.logger_type:
if self.logger_type == "wandb": case "wandb":
logger = WandbLogger( wandb_logger = WandbLogger(
save_interval=1, save_interval=1,
name=experiment_name.replace(os.path.sep, "__"), name=experiment_name.replace(os.path.sep, "__"),
run_id=run_id, run_id=run_id,
config=config_dict, config=config_dict,
project=self.wandb_project, project=self.wandb_project,
) )
logger.load(writer) wandb_logger.load(writer)
elif self.logger_type == "tensorboard": return wandb_logger
logger = TensorboardLogger(writer) case "tensorboard":
else: return TensorboardLogger(writer)
raise ValueError(f"Unknown logger type '{self.logger_type}'") case _:
return logger raise ValueError(f"Unknown logger type '{self.logger_type}'")

View File

@ -209,16 +209,16 @@ class ActorFactoryTransientStorageDecorator(ActorFactory):
self.actor_factory = actor_factory self.actor_factory = actor_factory
self._actor_future = actor_future self._actor_future = actor_future
def __getstate__(self): def __getstate__(self) -> dict:
d = dict(self.__dict__) d = dict(self.__dict__)
del d["_actor_future"] del d["_actor_future"]
return d return d
def __setstate__(self, state): def __setstate__(self, state: dict) -> None:
self.__dict__ = state self.__dict__ = state
self._actor_future = ActorFuture() self._actor_future = ActorFuture()
def _tostring_excludes(self): def _tostring_excludes(self) -> list[str]:
return [*super()._tostring_excludes(), "_actor_future"] return [*super()._tostring_excludes(), "_actor_future"]
def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module: def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
import numpy as np
from torch import nn from torch import nn
from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.env import Environments, EnvType
@ -61,7 +62,7 @@ class CriticFactoryDefault(CriticFactory):
envs: Environments, envs: Environments,
device: TDevice, device: TDevice,
use_action: bool, use_action: bool,
discrete_last_size_use_action_shape=False, discrete_last_size_use_action_shape: bool = False,
) -> nn.Module: ) -> nn.Module:
factory: CriticFactory factory: CriticFactory
env_type = envs.get_type() env_type = envs.get_type()
@ -89,7 +90,7 @@ class CriticFactoryContinuousNet(CriticFactory):
envs: Environments, envs: Environments,
device: TDevice, device: TDevice,
use_action: bool, use_action: bool,
discrete_last_size_use_action_shape=False, discrete_last_size_use_action_shape: bool = False,
) -> nn.Module: ) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0 action_shape = envs.get_action_shape() if use_action else 0
net_c = Net( net_c = Net(
@ -114,7 +115,7 @@ class CriticFactoryDiscreteNet(CriticFactory):
envs: Environments, envs: Environments,
device: TDevice, device: TDevice,
use_action: bool, use_action: bool,
discrete_last_size_use_action_shape=False, discrete_last_size_use_action_shape: bool = False,
) -> nn.Module: ) -> nn.Module:
action_shape = envs.get_action_shape() if use_action else 0 action_shape = envs.get_action_shape() if use_action else 0
net_c = Net( net_c = Net(
@ -125,7 +126,9 @@ class CriticFactoryDiscreteNet(CriticFactory):
activation=nn.Tanh, activation=nn.Tanh,
device=device, device=device,
) )
last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1 last_size = (
int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1
)
critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device) critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device)
init_linear_orthogonal(critic) init_linear_orthogonal(critic)
return critic return critic
@ -149,7 +152,7 @@ class CriticFactoryReuseActor(CriticFactory):
envs: Environments, envs: Environments,
device: TDevice, device: TDevice,
use_action: bool, use_action: bool,
discrete_last_size_use_action_shape=False, discrete_last_size_use_action_shape: bool = False,
) -> nn.Module: ) -> nn.Module:
actor = self.actor_future.actor actor = self.actor_future.actor
if not isinstance(actor, BaseActor): if not isinstance(actor, BaseActor):
@ -157,7 +160,10 @@ class CriticFactoryReuseActor(CriticFactory):
f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}", f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
) )
if envs.get_type().is_discrete(): if envs.get_type().is_discrete():
last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1 # TODO get rid of this prod pattern here and elsewhere
last_size = (
int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1
)
return discrete.Critic( return discrete.Critic(
actor.get_preprocess_net(), actor.get_preprocess_net(),
device=device, device=device,

View File

@ -45,12 +45,12 @@ class Persistence(ABC):
pass pass
@abstractmethod @abstractmethod
def restore(self, event: RestoreEvent, world: World): def restore(self, event: RestoreEvent, world: World) -> None:
pass pass
class PersistenceGroup(Persistence): class PersistenceGroup(Persistence):
def __init__(self, *p: Persistence, enabled=True): def __init__(self, *p: Persistence, enabled: bool = True):
self.items = p self.items = p
self.enabled = enabled self.enabled = enabled
@ -60,7 +60,7 @@ class PersistenceGroup(Persistence):
for item in self.items: for item in self.items:
item.persist(event, world) item.persist(event, world)
def restore(self, event: RestoreEvent, world: World): def restore(self, event: RestoreEvent, world: World) -> None:
for item in self.items: for item in self.items:
item.restore(event, world) item.restore(event, world)
@ -68,7 +68,7 @@ class PersistenceGroup(Persistence):
class PolicyPersistence: class PolicyPersistence:
FILENAME = "policy.dat" FILENAME = "policy.dat"
def __init__(self, additional_persistence: Persistence | None = None, enabled=True): def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
""":param additional_persistence: a persistence instance which is to be invoked whenever """:param additional_persistence: a persistence instance which is to be invoked whenever
this object is used to persist/restore data this object is used to persist/restore data
:param enabled: whether persistence is enabled (restoration is always enabled) :param enabled: whether persistence is enabled (restoration is always enabled)
@ -93,7 +93,7 @@ class PolicyPersistence:
if self.additional_persistence is not None: if self.additional_persistence is not None:
self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world) self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)
def get_save_best_fn(self, world) -> Callable[[torch.nn.Module], None]: def get_save_best_fn(self, world: World) -> Callable[[torch.nn.Module], None]:
def save_best_fn(pol: torch.nn.Module) -> None: def save_best_fn(pol: torch.nn.Module) -> None:
self.persist(pol, world) self.persist(pol, world)

View File

@ -18,7 +18,7 @@ class World:
test_collector: "Collector" test_collector: "Collector"
logger: "TLogger" logger: "TLogger"
persist_directory: str persist_directory: str
restore_directory: str restore_directory: str | None
trainer: Optional["BaseTrainer"] = None trainer: Optional["BaseTrainer"] = None
def persist_path(self, filename: str) -> str: def persist_path(self, filename: str) -> str:

View File

@ -17,7 +17,8 @@ from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.base import TLearningRateScheduler
from tianshou.utils import RunningMeanStd from tianshou.utils import RunningMeanStd
TDistributionFunction: TypeAlias = Callable[[torch.Tensor, ...], torch.distributions.Distribution] # TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution]
TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution]
class PGPolicy(BasePolicy): class PGPolicy(BasePolicy):

View File

@ -216,7 +216,7 @@ class ActorProb(BaseActor):
def get_preprocess_net(self) -> nn.Module: def get_preprocess_net(self) -> nn.Module:
return self.preprocess return self.preprocess
def get_output_dim(self): def get_output_dim(self) -> int:
return self.output_dim return self.output_dim
def forward( def forward(

View File

@ -189,7 +189,7 @@ class ImplicitQuantileNetwork(Critic):
def __init__( def __init__(
self, self,
preprocess_net: nn.Module, preprocess_net: nn.Module,
action_shape: Sequence[int], action_shape: Sequence[int] | int,
hidden_sizes: Sequence[int] = (), hidden_sizes: Sequence[int] = (),
num_cosines: int = 64, num_cosines: int = 64,
preprocess_net_output_dim: int | None = None, preprocess_net_output_dim: int | None = None,