Typing: fixed multiple typing issues

This commit is contained in:
Michael Panchenko 2023-12-05 12:04:18 +01:00
parent 2e39a252e3
commit a846b52063
10 changed files with 84 additions and 68 deletions

View File

@ -68,7 +68,7 @@ class Collector:
super().__init__() super().__init__()
if isinstance(env, gym.Env) and not hasattr(env, "__len__"): if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
warnings.warn("Single environment detected, wrap to DummyVectorEnv.") warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
self.env = DummyVectorEnv([lambda: env]) # type: ignore self.env = DummyVectorEnv([lambda: env])
else: else:
self.env = env # type: ignore self.env = env # type: ignore
self.env_num = len(self.env) self.env_num = len(self.env)

View File

@ -203,7 +203,7 @@ class SubprocEnvWorker(EnvWorker):
obs = result[0] obs = result[0]
if self.share_memory: if self.share_memory:
obs = self._decode_obs() obs = self._decode_obs()
return (obs, *result[1:]) # type: ignore return (obs, *result[1:])
obs = result obs = result
if self.share_memory: if self.share_memory:
obs = self._decode_obs() obs = self._decode_obs()

View File

@ -78,7 +78,7 @@ class ActorFactory(ModuleFactory, ToStringMixin, ABC):
# do last policy layer scaling, this will make initial actions have (close to) # do last policy layer scaling, this will make initial actions have (close to)
# 0 mean and std, and will help boost performances, # 0 mean and std, and will help boost performances,
# see https://arxiv.org/abs/2006.05990, Fig.24 for details # see https://arxiv.org/abs/2006.05990, Fig.24 for details
for m in actor.mu.modules(): # type: ignore for m in actor.mu.modules():
if isinstance(m, torch.nn.Linear): if isinstance(m, torch.nn.Linear):
m.weight.data.copy_(0.01 * m.weight.data) m.weight.data.copy_(0.01 * m.weight.data)

View File

@ -95,7 +95,7 @@ class ICMPolicy(BasePolicy):
def set_eps(self, eps: float) -> None: def set_eps(self, eps: float) -> None:
"""Set the eps for epsilon-greedy exploration.""" """Set the eps for epsilon-greedy exploration."""
if hasattr(self.policy, "set_eps"): if hasattr(self.policy, "set_eps"):
self.policy.set_eps(eps) # type: ignore self.policy.set_eps(eps)
else: else:
raise NotImplementedError raise NotImplementedError

View File

@ -78,11 +78,11 @@ class BranchingDQNPolicy(DQNPolicy):
# but it collides with an attr of the same name in base class # but it collides with an attr of the same name in base class
@property @property
def _action_per_branch(self) -> int: def _action_per_branch(self) -> int:
return self.model.action_per_branch # type: ignore return self.model.action_per_branch
@property @property
def num_branches(self) -> int: def num_branches(self) -> int:
return self.model.num_branches # type: ignore return self.model.num_branches
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
obs_next_batch = Batch( obs_next_batch = Batch(

View File

@ -78,7 +78,7 @@ class DDPGPolicy(BasePolicy):
action_bound_method=action_bound_method, action_bound_method=action_bound_method,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
) )
if action_scaling and not np.isclose(actor.max_action, 1.0): # type: ignore if action_scaling and not np.isclose(actor.max_action, 1.0):
warnings.warn( warnings.warn(
"action_scaling and action_bound_method are only intended to deal" "action_scaling and action_bound_method are only intended to deal"
"with unbounded model action space, but find actor model bound" "with unbounded model action space, but find actor model bound"

View File

@ -77,7 +77,7 @@ class PGPolicy(BasePolicy):
action_bound_method=action_bound_method, action_bound_method=action_bound_method,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
) )
if action_scaling and not np.isclose(actor.max_action, 1.0): # type: ignore if action_scaling and not np.isclose(actor.max_action, 1.0):
warnings.warn( warnings.warn(
"action_scaling and action_bound_method are only intended" "action_scaling and action_bound_method are only intended"
"to deal with unbounded model action space, but find actor model" "to deal with unbounded model action space, but find actor model"

View File

@ -9,9 +9,9 @@ from collections.abc import Callable
from datetime import datetime from datetime import datetime
from io import StringIO from io import StringIO
from logging import * from logging import *
from typing import Any from typing import Any, TypeVar, cast
log = getLogger(__name__) log = getLogger(__name__) # type: ignore
LOG_DEFAULT_FORMAT = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s - %(message)s" LOG_DEFAULT_FORMAT = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s - %(message)s"
@ -20,18 +20,18 @@ LOG_DEFAULT_FORMAT = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s - %(m
_logFormat = LOG_DEFAULT_FORMAT _logFormat = LOG_DEFAULT_FORMAT
def remove_log_handlers(): def remove_log_handlers() -> None:
"""Removes all current log handlers.""" """Removes all current log handlers."""
logger = getLogger() logger = getLogger()
while logger.hasHandlers(): while logger.hasHandlers():
logger.removeHandler(logger.handlers[0]) logger.removeHandler(logger.handlers[0])
def remove_log_handler(handler): def remove_log_handler(handler: Handler) -> None:
getLogger().removeHandler(handler) getLogger().removeHandler(handler)
def is_log_handler_active(handler): def is_log_handler_active(handler: Handler) -> bool:
"""Checks whether the given handler is active. """Checks whether the given handler is active.
:param handler: a log handler :param handler: a log handler
@ -41,7 +41,7 @@ def is_log_handler_active(handler):
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
def configure(format=LOG_DEFAULT_FORMAT, level=lg.DEBUG): def configure(format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG) -> None:
"""Configures logging to stdout with the given format and log level, """Configures logging to stdout with the given format and log level,
also configuring the default log levels of some overly verbose libraries as well as some pandas output options. also configuring the default log levels of some overly verbose libraries as well as some pandas output options.
@ -56,8 +56,13 @@ def configure(format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
getLogger("numba").setLevel(INFO) getLogger("numba").setLevel(INFO)
T = TypeVar("T")
# noinspection PyShadowingBuiltins # noinspection PyShadowingBuiltins
def run_main(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEBUG): def run_main(
main_fn: Callable[[], T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG
) -> T | None:
"""Configures logging with the given parameters, ensuring that any exceptions that occur during """Configures logging with the given parameters, ensuring that any exceptions that occur during
the execution of the given function are logged. the execution of the given function are logged.
Logs two additional messages, one before the execution of the function, and one upon its completion. Logs two additional messages, one before the execution of the function, and one upon its completion.
@ -68,16 +73,19 @@ def run_main(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEB
:return: the result of `main_fn` :return: the result of `main_fn`
""" """
configure(format=format, level=level) configure(format=format, level=level)
log.info("Starting") log.info("Starting") # type: ignore
try: try:
result = main_fn() result = main_fn()
log.info("Done") log.info("Done") # type: ignore
return result return result
except Exception as e: except Exception as e:
log.error("Exception during script execution", exc_info=e) log.error("Exception during script execution", exc_info=e) # type: ignore
return None
def run_cli(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEBUG): def run_cli(
main_fn: Callable[[], T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG
) -> T | None:
""" """
Configures logging with the given parameters and runs the given main function as a Configures logging with the given parameters and runs the given main function as a
CLI using `jsonargparse` (which is configured to also parse attribute docstrings, such CLI using `jsonargparse` (which is configured to also parse attribute docstrings, such
@ -107,14 +115,14 @@ _isAtExitReportFileLoggerRegistered = False
_memoryLogStream: StringIO | None = None _memoryLogStream: StringIO | None = None
def _at_exit_report_file_logger(): def _at_exit_report_file_logger() -> None:
for path in _fileLoggerPaths: for path in _fileLoggerPaths:
print(f"A log file was saved to {path}") print(f"A log file was saved to {path}")
def add_file_logger(path, register_atexit=True): def add_file_logger(path: str, register_atexit: bool = True) -> FileHandler:
global _isAtExitReportFileLoggerRegistered global _isAtExitReportFileLoggerRegistered
log.info(f"Logging to {path} ...") log.info(f"Logging to {path} ...") # type: ignore
handler = FileHandler(path) handler = FileHandler(path)
handler.setFormatter(Formatter(_logFormat)) handler.setFormatter(Formatter(_logFormat))
Logger.root.addHandler(handler) Logger.root.addHandler(handler)
@ -138,21 +146,22 @@ def add_memory_logger() -> None:
Logger.root.addHandler(handler) Logger.root.addHandler(handler)
def get_memory_log(): def get_memory_log() -> Any:
""":return: the in-memory log (provided that `add_memory_logger` was called beforehand)""" """:return: the in-memory log (provided that `add_memory_logger` was called beforehand)"""
assert _memoryLogStream is not None, "This should not have happened and might be a bug."
return _memoryLogStream.getvalue() return _memoryLogStream.getvalue()
class FileLoggerContext: class FileLoggerContext:
def __init__(self, path: str, enabled=True): def __init__(self, path: str, enabled: bool = True):
self.enabled = enabled self.enabled = enabled
self.path = path self.path = path
self._log_handler = None self._log_handler: Handler | None = None
def __enter__(self): def __enter__(self) -> None:
if self.enabled: if self.enabled:
self._log_handler = add_file_logger(self.path, register_atexit=False) self._log_handler = add_file_logger(self.path, register_atexit=False)
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self._log_handler is not None: if self._log_handler is not None:
remove_log_handler(self._log_handler) remove_log_handler(self._log_handler)

View File

@ -370,13 +370,13 @@ class NoisyLinear(nn.Module):
# TODO: rename or change functionality? Usually sample is not an inplace operation... # TODO: rename or change functionality? Usually sample is not an inplace operation...
def sample(self) -> None: def sample(self) -> None:
self.eps_p.copy_(self.f(self.eps_p)) # type: ignore self.eps_p.copy_(self.f(self.eps_p))
self.eps_q.copy_(self.f(self.eps_q)) # type: ignore self.eps_q.copy_(self.f(self.eps_q))
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.training: if self.training:
weight = self.mu_W + self.sigma_W * (self.eps_q.ger(self.eps_p)) # type: ignore weight = self.mu_W + self.sigma_W * (self.eps_q.ger(self.eps_p))
bias = self.mu_bias + self.sigma_bias * self.eps_q.clone() # type: ignore bias = self.mu_bias + self.sigma_bias * self.eps_q.clone()
else: else:
weight = self.mu_W weight = self.mu_W
bias = self.mu_bias bias = self.mu_bias

View File

@ -10,6 +10,8 @@ from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import ( from typing import (
Any, Any,
Self,
cast,
) )
reCommaWhitespacePotentiallyBreaks = re.compile(r",\s+") reCommaWhitespacePotentiallyBreaks = re.compile(r",\s+")
@ -23,11 +25,13 @@ class StringConverter(ABC):
"""Abstraction for a string conversion mechanism.""" """Abstraction for a string conversion mechanism."""
@abstractmethod @abstractmethod
def to_string(self, x) -> str: def to_string(self, x: Any) -> str:
pass pass
def dict_string(d: Mapping, brackets: str | None = None, converter: StringConverter = None): def dict_string(
d: Mapping, brackets: str | None = None, converter: StringConverter | None = None
) -> str:
"""Converts a dictionary to a string of the form "<key>=<value>, <key>=<value>, ...", optionally enclosed """Converts a dictionary to a string of the form "<key>=<value>, <key>=<value>, ...", optionally enclosed
by brackets. by brackets.
@ -46,10 +50,10 @@ def dict_string(d: Mapping, brackets: str | None = None, converter: StringConver
def list_string( def list_string(
l: Iterable[Any], l: Iterable[Any],
brackets="[]", brackets: str | None = "[]",
quote: str | None = None, quote: str | None = None,
converter: StringConverter = None, converter: StringConverter | None = None,
): ) -> str:
"""Converts a list or any other iterable to a string of the form "[<value>, <value>, ...]", optionally enclosed """Converts a list or any other iterable to a string of the form "[<value>, <value>, ...]", optionally enclosed
by different brackets or with the values quoted. by different brackets or with the values quoted.
@ -61,7 +65,7 @@ def list_string(
:return: the string representation :return: the string representation
""" """
def item(x): def item(x: Any) -> str:
x = to_string(x, converter=converter, context="list") x = to_string(x, converter=converter, context="list")
if quote is not None: if quote is not None:
return quote + x + quote return quote + x + quote
@ -76,11 +80,11 @@ def list_string(
def to_string( def to_string(
x, x: Any,
converter: StringConverter = None, converter: StringConverter | None = None,
apply_converter_to_non_complex_objects=True, apply_converter_to_non_complex_objects: bool = True,
context=None, context: Any = None,
): ) -> str:
"""Converts the given object to a string, with proper handling of lists, tuples and dictionaries, optionally using a converter. """Converts the given object to a string, with proper handling of lists, tuples and dictionaries, optionally using a converter.
The conversion also removes unwanted line breaks (as present, in particular, in sklearn's string representations). The conversion also removes unwanted line breaks (as present, in particular, in sklearn's string representations).
@ -118,7 +122,7 @@ def to_string(
raise raise
def object_repr(obj, member_names_or_dict: list[str] | dict[str, Any]): def object_repr(obj: Any, member_names_or_dict: list[str] | dict[str, Any]) -> str:
"""Creates a string representation for the given object based on the given members. """Creates a string representation for the given object based on the given members.
The string takes the form "ClassName[attr1=value1, attr2=value2, ...]" The string takes the form "ClassName[attr1=value1, attr2=value2, ...]"
@ -130,7 +134,7 @@ def object_repr(obj, member_names_or_dict: list[str] | dict[str, Any]):
return f"{obj.__class__.__name__}[{dict_string(members_dict)}]" return f"{obj.__class__.__name__}[{dict_string(members_dict)}]"
def or_regex_group(allowed_names: Sequence[str]): def or_regex_group(allowed_names: Sequence[str]) -> str:
""":param allowed_names: strings to include as literals in the regex """:param allowed_names: strings to include as literals in the regex
:return: a regular expression string of the form (<name1>| ...|<nameN>), which any of the given names :return: a regular expression string of the form (<name1>| ...|<nameN>), which any of the given names
""" """
@ -185,7 +189,7 @@ class ToStringMixin:
_TOSTRING_INCLUDE_ALL = "__all__" _TOSTRING_INCLUDE_ALL = "__all__"
def _tostring_class_name(self): def _tostring_class_name(self) -> str:
""":return: the string use for <class name> in the string representation ``"<class name>[<object info]"``""" """:return: the string use for <class name> in the string representation ``"<class name>[<object info]"``"""
return type(self).__qualname__ return type(self).__qualname__
@ -196,7 +200,7 @@ class ToStringMixin:
exclude_exceptions: list[str] | None = None, exclude_exceptions: list[str] | None = None,
include_forced: list[str] | None = None, include_forced: list[str] | None = None,
additional_entries: dict[str, Any] | None = None, additional_entries: dict[str, Any] | None = None,
converter: StringConverter = None, converter: StringConverter | None = None,
) -> str: ) -> str:
"""Creates a string of the class attributes, with optional exclusions/inclusions/additions. """Creates a string of the class attributes, with optional exclusions/inclusions/additions.
Exclusions take precedence over inclusions. Exclusions take precedence over inclusions.
@ -210,7 +214,7 @@ class ToStringMixin:
:return: a string containing entry/property names and values :return: a string containing entry/property names and values
""" """
def mklist(x): def mklist(x: Any) -> list[str]:
if x is None: if x is None:
return [] return []
if isinstance(x, str): if isinstance(x, str):
@ -222,7 +226,7 @@ class ToStringMixin:
include_forced = mklist(include_forced) include_forced = mklist(include_forced)
exclude_exceptions = mklist(exclude_exceptions) exclude_exceptions = mklist(exclude_exceptions)
def is_excluded(k): def is_excluded(k: Any) -> bool:
if k in include_forced or k in exclude_exceptions: if k in include_forced or k in exclude_exceptions:
return False return False
if k in exclude: if k in exclude:
@ -329,17 +333,17 @@ class ToStringMixin:
""" """
return [] return []
def __str__(self): def __str__(self) -> str:
return f"{self._tostring_class_name()}[{self._tostring_object_info()}]" return f"{self._tostring_class_name()}[{self._tostring_object_info()}]"
def __repr__(self): def __repr__(self) -> str:
info = f"id={id(self)}" info = f"id={id(self)}"
property_info = self._tostring_object_info() property_info = self._tostring_object_info()
if len(property_info) > 0: if len(property_info) > 0:
info += ", " + property_info info += ", " + property_info
return f"{self._tostring_class_name()}[{info}]" return f"{self._tostring_class_name()}[{info}]"
def pprint(self, file=sys.stdout): def pprint(self, file: Any = sys.stdout) -> None:
"""Prints a prettily formatted string representation of the object (with line breaks and indentations) """Prints a prettily formatted string representation of the object (with line breaks and indentations)
to ``stdout`` or the given file. to ``stdout`` or the given file.
@ -364,7 +368,7 @@ class ToStringMixin:
""":param handled_objects: objects which are initially assumed to have been handled already""" """:param handled_objects: objects which are initially assumed to have been handled already"""
self._handled_to_string_mixin_ids = {id(o) for o in handled_objects} self._handled_to_string_mixin_ids = {id(o) for o in handled_objects}
def to_string(self, x) -> str: def to_string(self, x: Any) -> str:
if isinstance(x, ToStringMixin): if isinstance(x, ToStringMixin):
oid = id(x) oid = id(x)
if oid in self._handled_to_string_mixin_ids: if oid in self._handled_to_string_mixin_ids:
@ -389,17 +393,17 @@ class ToStringMixin:
# methods where we assume that they could transitively call _toStringProperties (others are assumed not to) # methods where we assume that they could transitively call _toStringProperties (others are assumed not to)
TOSTRING_METHODS_TRANSITIVELY_CALLING_TOSTRINGPROPERTIES = {"_tostring_object_info"} TOSTRING_METHODS_TRANSITIVELY_CALLING_TOSTRINGPROPERTIES = {"_tostring_object_info"}
def __init__(self, x: "ToStringMixin", converter): def __init__(self, x: "ToStringMixin", converter: Any) -> None:
self.x = x self.x = x
self.converter = converter self.converter = converter
def _tostring_properties(self, *args, **kwargs): def _tostring_properties(self, *args: Any, **kwargs: Any) -> str:
return self.x._tostring_properties(*args, **kwargs, converter=self.converter) return self.x._tostring_properties(*args, **kwargs, converter=self.converter) # type: ignore[misc]
def _tostring_class_name(self): def _tostring_class_name(self) -> str:
return self.x._tostring_class_name() return self.x._tostring_class_name()
def __getattr__(self, attr: str): def __getattr__(self, attr: str) -> Any:
if attr.startswith( if attr.startswith(
"_tostring", "_tostring",
): # ToStringMixin method which we may bind to use this proxy to ensure correct transitive call ): # ToStringMixin method which we may bind to use this proxy to ensure correct transitive call
@ -413,11 +417,13 @@ class ToStringMixin:
else: else:
return getattr(self.x, attr) return getattr(self.x, attr)
def __str__(self: "ToStringMixin"): def __str__(self) -> str:
return ToStringMixin.__str__(self) return ToStringMixin.__str__(self) # type: ignore[arg-type]
def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string=" "): def pretty_string_repr(
s: Any, initial_indentation_level: int = 0, indentation_string: str = " "
) -> str:
"""Creates a pretty string representation (using indentations) from the given object/string representation (as generated, for example, via """Creates a pretty string representation (using indentations) from the given object/string representation (as generated, for example, via
ToStringMixin). An indentation level is added for every opening bracket. ToStringMixin). An indentation level is added for every opening bracket.
@ -432,16 +438,16 @@ def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string="
result = indentation_string * indent result = indentation_string * indent
i = 0 i = 0
def nl(): def nl() -> None:
nonlocal result nonlocal result
result += "\n" + (indentation_string * indent) result += "\n" + (indentation_string * indent)
def take(cnt=1): def take(cnt: int = 1) -> None:
nonlocal result, i nonlocal result, i
result += s[i : i + cnt] result += s[i : i + cnt]
i += cnt i += cnt
def find_matching(j): def find_matching(j: int) -> int | None:
start = j start = j
op = s[j] op = s[j]
cl = {"[": "]", "(": ")", "'": "'"}[s[j]] cl = {"[": "]", "(": ")", "'": "'"}[s[j]]
@ -492,17 +498,18 @@ def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string="
class TagBuilder: class TagBuilder:
"""Assists in building strings made up of components that are joined via a glue string.""" """Assists in building strings made up of components that are joined via a glue string."""
def __init__(self, *initial_components: str, glue="_"): def __init__(self, *initial_components: str, glue: str = "_"):
""":param initial_components: initial components to always include at the beginning """:param initial_components: initial components to always include at the beginning
:param glue: the glue string which joins components :param glue: the glue string which joins components
""" """
self.glue = glue self.glue = glue
self.components = list(initial_components) self.components = list(initial_components)
def with_component(self, component: str): def with_component(self, component: str) -> Self:
self.components.append(component) self.components.append(component)
return self
def with_conditional(self, cond: bool, component: str): def with_conditional(self, cond: bool, component: str) -> Self:
"""Conditionally adds the given component. """Conditionally adds the given component.
:param cond: the condition :param cond: the condition
@ -513,7 +520,7 @@ class TagBuilder:
self.components.append(component) self.components.append(component)
return self return self
def with_alternative(self, cond: bool, true_component: str, false_component: str): def with_alternative(self, cond: bool, true_component: str, false_component: str) -> Self:
"""Adds a component depending on a condition. """Adds a component depending on a condition.
:param cond: the condition :param cond: the condition
@ -524,6 +531,6 @@ class TagBuilder:
self.components.append(true_component if cond else false_component) self.components.append(true_component if cond else false_component)
return self return self
def build(self): def build(self) -> str:
""":return: the string (with all components joined)""" """:return: the string (with all components joined)"""
return self.glue.join(self.components) return self.glue.join(self.components)