Make mypy happy
parent 76e870207d
commit 023b33c917
@@ -1,4 +1,5 @@
 import logging
+import typing
 from abc import ABC, abstractmethod
 from typing import Any, Generic, TypeVar, cast
 
@@ -63,11 +64,17 @@ from tianshou.utils.string import ToStringMixin
 CHECKPOINT_DICT_KEY_MODEL = "model"
 CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
 TParams = TypeVar("TParams", bound=Params)
-TActorCriticParams = TypeVar("TActorCriticParams", bound=ParamsMixinLearningRateWithScheduler)
-TActorDualCriticsParams = TypeVar("TActorDualCriticsParams", bound=ParamsMixinActorAndDualCritics)
+TActorCriticParams = TypeVar(
+    "TActorCriticParams",
+    bound=Params | ParamsMixinLearningRateWithScheduler,
+)
+TActorDualCriticsParams = TypeVar(
+    "TActorDualCriticsParams",
+    bound=Params | ParamsMixinActorAndDualCritics,
+)
 TDiscreteCriticOnlyParams = TypeVar(
     "TDiscreteCriticOnlyParams",
-    bound=ParamsMixinLearningRateWithScheduler,
+    bound=Params | ParamsMixinLearningRateWithScheduler,
 )
 TPolicy = TypeVar("TPolicy", bound=BasePolicy)
 log = logging.getLogger(__name__)
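For context: a TypeVar bound may itself be a PEP 604 union, which is what the widened bounds above rely on; any concrete type argument then has to be assignable to that union. A minimal standalone sketch with hypothetical stand-in classes (not tianshou's real Params hierarchy):

    from typing import TypeVar

    class BaseParams: ...                   # hypothetical stand-in for Params
    class ScheduledParams(BaseParams): ...  # hypothetical stand-in for the LR-scheduler mixin

    # The bound is a union; both alternatives are acceptable upper bounds.
    TP = TypeVar("TP", bound=BaseParams | ScheduledParams)

    def passthrough(params: TP) -> TP:
        # mypy accepts calls with BaseParams, ScheduledParams, or subclasses thereof
        return params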
@@ -321,6 +328,7 @@ class ActorCriticAgentFactory(
         optim = self.optim_factory.create_optimizer(actor_critic, lr)
         return ActorCriticModuleOpt(actor_critic, optim)
 
+    @typing.no_type_check
     def _create_kwargs(self, envs: Environments, device: TDevice) -> dict[str, Any]:
         actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
         kwargs = self.params.create_kwargs(
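For reference: typing.no_type_check is the standard-library decorator used in the added lines here and in the hunks below; it marks a function so that static type checkers such as mypy skip it entirely rather than reporting errors they cannot resolve. A small illustration (hypothetical function, not from tianshou):

    import typing

    @typing.no_type_check
    def build_kwargs(params, envs):
        # mypy does not type-check this body, so dynamic attribute access
        # and untyped arguments are not reported here.
        return {"lr": params.lr, "action_space": envs.action_space}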
@@ -382,6 +390,7 @@ class DiscreteCriticOnlyAgentFactory(
     def _get_policy_class(self) -> type[TPolicy]:
         pass
 
+    @typing.no_type_check
     def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
         model = self.model_factory.create_module(envs, device)
         optim = self.optim_factory.create_optimizer(model, self.params.lr)
@@ -548,6 +557,7 @@ class ActorDualCriticsAgentFactory(
     def _get_critic_use_action(envs: Environments) -> bool:
         return envs.get_type().is_continuous()
 
+    @typing.no_type_check
     def _create_policy(self, envs: Environments, device: TDevice) -> TPolicy:
         actor = self.actor_factory.create_module_opt(
             envs,

@@ -37,7 +37,7 @@ class Environments(ToStringMixin, ABC):
         self.env = env
         self.train_envs = train_envs
         self.test_envs = test_envs
-        self.persistence = []
+        self.persistence: Sequence[Persistence] = []
 
     def _tostring_includes(self) -> list[str]:
         return []
@@ -51,7 +51,7 @@ class Environments(ToStringMixin, ABC):
             "state_shape": self.get_observation_shape(),
         }
 
-    def set_persistence(self, *p: Persistence):
+    def set_persistence(self, *p: Persistence) -> None:
         self.persistence = p
 
     @abstractmethod

@@ -27,7 +27,7 @@ from tianshou.highlevel.agent import (
 )
 from tianshou.highlevel.config import SamplingConfig
 from tianshou.highlevel.env import EnvFactory, Environments
-from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory
+from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerFactory, TLogger
 from tianshou.highlevel.module.actor import (
     ActorFactory,
     ActorFactoryDefault,
@@ -142,13 +142,18 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         self.env_config = env_config
 
     @classmethod
-    def from_directory(cls, directory: str) -> Self:
+    def from_directory(cls, directory: str, restore_policy: bool = True) -> Self:
         """Restores an experiment from a previously stored pickle.
 
         :param directory: persistence directory of a previous run, in which a pickled experiment is found
+        :param restore_policy: whether the experiment shall be configured to restore the policy that was
+            persisted in the given directory
         """
         with open(os.path.join(directory, cls.EXPERIMENT_PICKLE_FILENAME), "rb") as f:
-            return pickle.load(f)
+            experiment: Experiment = pickle.load(f)
+        if restore_policy:
+            experiment.config.policy_restore_directory = directory
+        return experiment
 
     def _set_seed(self) -> None:
         seed = self.config.seed
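A possible way to use the extended signature above (sketch only; the directory path is made up):

    # Reload a previously persisted experiment. With restore_policy=False the
    # experiment keeps a freshly initialized policy instead of configuring
    # itself to load the one stored in the run directory.
    experiment = Experiment.from_directory("log/previous_run", restore_policy=False)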
@@ -159,7 +164,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
     def _build_config_dict(self) -> dict:
         return {"experiment": self.pprints()}
 
-    def save(self, directory: str):
+    def save(self, directory: str) -> None:
         path = os.path.join(directory, self.EXPERIMENT_PICKLE_FILENAME)
         log.info(
             f"Saving serialized experiment in {path}; can be restored via Experiment.from_directory('{directory}')",
@@ -187,7 +192,8 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         os.makedirs(persistence_dir, exist_ok=True)
 
         with logging.FileLoggerContext(
-            os.path.join(persistence_dir, self.LOG_FILENAME), enabled=use_persistence,
+            os.path.join(persistence_dir, self.LOG_FILENAME),
+            enabled=use_persistence,
         ):
             # log initial information
             log.info(f"Running experiment (name='{experiment_name}'):\n{self.pprints()}")
@@ -209,6 +215,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin):
         # initialize logger
         full_config = self._build_config_dict()
         full_config.update(envs.info())
+        logger: TLogger
         if use_persistence:
             logger = self.logger_factory.create_logger(
                 log_dir=persistence_dir,
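As a general pattern (my reading, not stated in the commit): declaring the variable once with its common base type, as in the added logger: TLogger line, lets mypy check each branch's assignment against that declaration instead of inferring a narrower type from the first branch. A standalone sketch using the standard library:

    import logging

    verbose = True  # hypothetical flag

    # Annotate once with the shared base class, assign a concrete subclass per branch.
    handler: logging.Handler
    if verbose:
        handler = logging.StreamHandler()
    else:
        handler = logging.NullHandler()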
@@ -460,7 +467,7 @@ class _BuilderMixinCriticsFactory:
 
 
 class _BuilderMixinSingleCriticFactory(_BuilderMixinCriticsFactory):
-    def __init__(self, actor_future_provider: ActorFutureProviderProtocol = None) -> None:
+    def __init__(self, actor_future_provider: ActorFutureProviderProtocol) -> None:
         super().__init__(1, actor_future_provider)
 
     def with_critic_factory(self, critic_factory: CriticFactory) -> Self:
@@ -553,7 +560,7 @@ class _BuilderMixinCriticEnsembleFactory:
         self.critic_ensemble_factory = CriticEnsembleFactoryDefault(hidden_sizes)
         return self
 
-    def _get_critic_ensemble_factory(self):
+    def _get_critic_ensemble_factory(self) -> CriticEnsembleFactory:
         if self.critic_ensemble_factory is None:
             return CriticEnsembleFactoryDefault()
         else:
@@ -745,8 +752,10 @@ class IQNExperimentBuilder(ExperimentBuilder):
     ):
         super().__init__(env_factory, experiment_config, sampling_config)
         self._params: IQNParams = IQNParams()
-        self._preprocess_network_factory = IntermediateModuleFactoryFromActorFactory(
-            ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
+        self._preprocess_network_factory: IntermediateModuleFactory = (
+            IntermediateModuleFactoryFromActorFactory(
+                ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
+            )
         )
 
     def with_iqn_params(self, params: IQNParams) -> Self:

@@ -4,16 +4,20 @@ from typing import Literal, TypeAlias
 
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.utils import TensorboardLogger, WandbLogger
+from tianshou.utils import BaseLogger, TensorboardLogger, WandbLogger
 from tianshou.utils.string import ToStringMixin
 
-TLogger: TypeAlias = TensorboardLogger | WandbLogger
+TLogger: TypeAlias = BaseLogger
 
 
 class LoggerFactory(ToStringMixin, ABC):
     @abstractmethod
     def create_logger(
-        self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
+        self,
+        log_dir: str,
+        experiment_name: str,
+        run_id: str | None,
+        config_dict: dict,
     ) -> TLogger:
         """:param log_dir: path to the directory in which log data is to be stored
         :param experiment_name: the name of the job, which may contain os.path.sep
@@ -35,7 +39,11 @@ class DefaultLoggerFactory(LoggerFactory):
         self.wandb_project = wandb_project
 
     def create_logger(
-        self, log_dir: str, experiment_name: str, run_id: str | None, config_dict: dict,
+        self,
+        log_dir: str,
+        experiment_name: str,
+        run_id: str | None,
+        config_dict: dict,
     ) -> TLogger:
         writer = SummaryWriter(log_dir)
         writer.add_text(
@@ -48,18 +56,18 @@ class DefaultLoggerFactory(LoggerFactory):
                 ),
             ),
         )
-        logger: TLogger
-        if self.logger_type == "wandb":
-            logger = WandbLogger(
-                save_interval=1,
-                name=experiment_name.replace(os.path.sep, "__"),
-                run_id=run_id,
-                config=config_dict,
-                project=self.wandb_project,
-            )
-            logger.load(writer)
-        elif self.logger_type == "tensorboard":
-            logger = TensorboardLogger(writer)
-        else:
-            raise ValueError(f"Unknown logger type '{self.logger_type}'")
-        return logger
+        match self.logger_type:
+            case "wandb":
+                wandb_logger = WandbLogger(
+                    save_interval=1,
+                    name=experiment_name.replace(os.path.sep, "__"),
+                    run_id=run_id,
+                    config=config_dict,
+                    project=self.wandb_project,
+                )
+                wandb_logger.load(writer)
+                return wandb_logger
+            case "tensorboard":
+                return TensorboardLogger(writer)
+            case _:
+                raise ValueError(f"Unknown logger type '{self.logger_type}'")

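A general observation about the rewrite above: returning from each match arm avoids routing differently typed loggers through a single variable, so the checker only has to verify each return against the declared TLogger return type. A rough standalone sketch with standard-library types:

    import logging

    def make_handler(kind: str) -> logging.Handler:
        # Each arm returns directly; mypy checks every return against logging.Handler.
        match kind:
            case "stream":
                return logging.StreamHandler()
            case "null":
                return logging.NullHandler()
            case _:
                raise ValueError(f"unknown handler kind '{kind}'")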
@@ -209,16 +209,16 @@ class ActorFactoryTransientStorageDecorator(ActorFactory):
         self.actor_factory = actor_factory
         self._actor_future = actor_future
 
-    def __getstate__(self):
+    def __getstate__(self) -> dict:
         d = dict(self.__dict__)
         del d["_actor_future"]
         return d
 
-    def __setstate__(self, state):
+    def __setstate__(self, state: dict) -> None:
         self.__dict__ = state
         self._actor_future = ActorFuture()
 
-    def _tostring_excludes(self):
+    def _tostring_excludes(self) -> list[str]:
         return [*super()._tostring_excludes(), "_actor_future"]
 
     def create_module(self, envs: Environments, device: TDevice) -> BaseActor | nn.Module:

@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 
+import numpy as np
 from torch import nn
 
 from tianshou.highlevel.env import Environments, EnvType
@@ -61,7 +62,7 @@ class CriticFactoryDefault(CriticFactory):
         envs: Environments,
         device: TDevice,
         use_action: bool,
-        discrete_last_size_use_action_shape=False,
+        discrete_last_size_use_action_shape: bool = False,
     ) -> nn.Module:
         factory: CriticFactory
         env_type = envs.get_type()
@@ -89,7 +90,7 @@ class CriticFactoryContinuousNet(CriticFactory):
         envs: Environments,
         device: TDevice,
         use_action: bool,
-        discrete_last_size_use_action_shape=False,
+        discrete_last_size_use_action_shape: bool = False,
     ) -> nn.Module:
         action_shape = envs.get_action_shape() if use_action else 0
         net_c = Net(
@@ -114,7 +115,7 @@ class CriticFactoryDiscreteNet(CriticFactory):
         envs: Environments,
         device: TDevice,
         use_action: bool,
-        discrete_last_size_use_action_shape=False,
+        discrete_last_size_use_action_shape: bool = False,
     ) -> nn.Module:
         action_shape = envs.get_action_shape() if use_action else 0
         net_c = Net(
@@ -125,7 +126,9 @@ class CriticFactoryDiscreteNet(CriticFactory):
             activation=nn.Tanh,
             device=device,
         )
-        last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
+        last_size = (
+            int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1
+        )
         critic = discrete.Critic(net_c, device=device, last_size=last_size).to(device)
         init_linear_orthogonal(critic)
         return critic
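One detail worth noting (my annotation): get_action_shape may be a shape tuple rather than a plain integer, and np.prod returns a NumPy scalar, so the int(...) wrapper above produces the built-in int that last_size is expected to be. A tiny standalone check:

    import numpy as np

    action_shape = (3, 4)  # hypothetical discrete action shape
    last_size = int(np.prod(action_shape))  # np.prod gives a NumPy integer; int() makes it a plain int
    assert last_size == 12 and isinstance(last_size, int)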
@@ -149,7 +152,7 @@ class CriticFactoryReuseActor(CriticFactory):
         envs: Environments,
         device: TDevice,
         use_action: bool,
-        discrete_last_size_use_action_shape=False,
+        discrete_last_size_use_action_shape: bool = False,
     ) -> nn.Module:
         actor = self.actor_future.actor
         if not isinstance(actor, BaseActor):
@@ -157,7 +160,10 @@ class CriticFactoryReuseActor(CriticFactory):
                 f"Option critic_use_action can only be used if actor is of type {BaseActor.__class__.__name__}",
             )
         if envs.get_type().is_discrete():
-            last_size = envs.get_action_shape() if discrete_last_size_use_action_shape else 1
+            # TODO get rid of this prod pattern here and elsewhere
+            last_size = (
+                int(np.prod(envs.get_action_shape())) if discrete_last_size_use_action_shape else 1
+            )
             return discrete.Critic(
                 actor.get_preprocess_net(),
                 device=device,

@@ -45,12 +45,12 @@ class Persistence(ABC):
         pass
 
     @abstractmethod
-    def restore(self, event: RestoreEvent, world: World):
+    def restore(self, event: RestoreEvent, world: World) -> None:
         pass
 
 
 class PersistenceGroup(Persistence):
-    def __init__(self, *p: Persistence, enabled=True):
+    def __init__(self, *p: Persistence, enabled: bool = True):
         self.items = p
         self.enabled = enabled
 
@@ -60,7 +60,7 @@ class PersistenceGroup(Persistence):
         for item in self.items:
             item.persist(event, world)
 
-    def restore(self, event: RestoreEvent, world: World):
+    def restore(self, event: RestoreEvent, world: World) -> None:
         for item in self.items:
             item.restore(event, world)
 
@@ -68,7 +68,7 @@ class PersistenceGroup(Persistence):
 class PolicyPersistence:
     FILENAME = "policy.dat"
 
-    def __init__(self, additional_persistence: Persistence | None = None, enabled=True):
+    def __init__(self, additional_persistence: Persistence | None = None, enabled: bool = True):
         """:param additional_persistence: a persistence instance which is to be envoked whenever
             this object is used to persist/restore data
         :param enabled: whether persistence is enabled (restoration is always enabled)
@@ -93,7 +93,7 @@ class PolicyPersistence:
         if self.additional_persistence is not None:
             self.additional_persistence.restore(RestoreEvent.RESTORE_POLICY, world)
 
-    def get_save_best_fn(self, world) -> Callable[[torch.nn.Module], None]:
+    def get_save_best_fn(self, world: World) -> Callable[[torch.nn.Module], None]:
         def save_best_fn(pol: torch.nn.Module) -> None:
             self.persist(pol, world)
 

@@ -18,7 +18,7 @@ class World:
     test_collector: "Collector"
     logger: "TLogger"
     persist_directory: str
-    restore_directory: str
+    restore_directory: str | None
    trainer: Optional["BaseTrainer"] = None
 
     def persist_path(self, filename: str) -> str:

@@ -17,7 +17,8 @@ from tianshou.policy import BasePolicy
 from tianshou.policy.base import TLearningRateScheduler
 from tianshou.utils import RunningMeanStd
 
-TDistributionFunction: TypeAlias = Callable[[torch.Tensor, ...], torch.distributions.Distribution]
+# TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution]
+TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution]
 
 
 class PGPolicy(BasePolicy):

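Background on the change above: Callable accepts either an explicit parameter list or a bare ellipsis, but not a mix of the two, so mypy rejects Callable[[torch.Tensor, ...], X] while Callable[..., X] (arbitrary argument list) is accepted. Illustrated with plain types:

    from collections.abc import Callable

    # Accepted: arbitrary arguments, returns float.
    AnyToFloat = Callable[..., float]

    # Accepted: exactly one int argument.
    IntToFloat = Callable[[int], float]

    # Rejected by mypy: '...' cannot appear inside an explicit parameter list.
    # Broken = Callable[[int, ...], float]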
@@ -216,7 +216,7 @@ class ActorProb(BaseActor):
     def get_preprocess_net(self) -> nn.Module:
         return self.preprocess
 
-    def get_output_dim(self):
+    def get_output_dim(self) -> int:
         return self.output_dim
 
     def forward(

@@ -189,7 +189,7 @@ class ImplicitQuantileNetwork(Critic):
     def __init__(
         self,
         preprocess_net: nn.Module,
-        action_shape: Sequence[int],
+        action_shape: Sequence[int] | int,
         hidden_sizes: Sequence[int] = (),
         num_cosines: int = 64,
         preprocess_net_output_dim: int | None = None,