Typing: fixed multiple typing issues

This commit is contained in:
Michael Panchenko 2023-12-05 12:04:18 +01:00
parent 2e39a252e3
commit a846b52063
10 changed files with 84 additions and 68 deletions

View File

@ -68,7 +68,7 @@ class Collector:
super().__init__()
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
self.env = DummyVectorEnv([lambda: env]) # type: ignore
self.env = DummyVectorEnv([lambda: env])
else:
self.env = env # type: ignore
self.env_num = len(self.env)

View File

@ -203,7 +203,7 @@ class SubprocEnvWorker(EnvWorker):
obs = result[0]
if self.share_memory:
obs = self._decode_obs()
return (obs, *result[1:]) # type: ignore
return (obs, *result[1:])
obs = result
if self.share_memory:
obs = self._decode_obs()

View File

@ -78,7 +78,7 @@ class ActorFactory(ModuleFactory, ToStringMixin, ABC):
# do last policy layer scaling, this will make initial actions have (close to)
# 0 mean and std, and will help boost performances,
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
for m in actor.mu.modules(): # type: ignore
for m in actor.mu.modules():
if isinstance(m, torch.nn.Linear):
m.weight.data.copy_(0.01 * m.weight.data)

View File

@ -95,7 +95,7 @@ class ICMPolicy(BasePolicy):
def set_eps(self, eps: float) -> None:
"""Set the eps for epsilon-greedy exploration."""
if hasattr(self.policy, "set_eps"):
self.policy.set_eps(eps) # type: ignore
self.policy.set_eps(eps)
else:
raise NotImplementedError

View File

@ -78,11 +78,11 @@ class BranchingDQNPolicy(DQNPolicy):
# but it collides with an attr of the same name in base class
@property
def _action_per_branch(self) -> int:
return self.model.action_per_branch # type: ignore
return self.model.action_per_branch
@property
def num_branches(self) -> int:
return self.model.num_branches # type: ignore
return self.model.num_branches
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
obs_next_batch = Batch(

View File

@ -78,7 +78,7 @@ class DDPGPolicy(BasePolicy):
action_bound_method=action_bound_method,
lr_scheduler=lr_scheduler,
)
if action_scaling and not np.isclose(actor.max_action, 1.0): # type: ignore
if action_scaling and not np.isclose(actor.max_action, 1.0):
warnings.warn(
"action_scaling and action_bound_method are only intended to deal"
"with unbounded model action space, but find actor model bound"

View File

@ -77,7 +77,7 @@ class PGPolicy(BasePolicy):
action_bound_method=action_bound_method,
lr_scheduler=lr_scheduler,
)
if action_scaling and not np.isclose(actor.max_action, 1.0): # type: ignore
if action_scaling and not np.isclose(actor.max_action, 1.0):
warnings.warn(
"action_scaling and action_bound_method are only intended"
"to deal with unbounded model action space, but find actor model"

View File

@ -9,9 +9,9 @@ from collections.abc import Callable
from datetime import datetime
from io import StringIO
from logging import *
from typing import Any
from typing import Any, TypeVar, cast
log = getLogger(__name__)
log = getLogger(__name__) # type: ignore
LOG_DEFAULT_FORMAT = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s - %(message)s"
@ -20,18 +20,18 @@ LOG_DEFAULT_FORMAT = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s - %(m
_logFormat = LOG_DEFAULT_FORMAT
def remove_log_handlers():
def remove_log_handlers() -> None:
"""Removes all current log handlers."""
logger = getLogger()
while logger.hasHandlers():
logger.removeHandler(logger.handlers[0])
def remove_log_handler(handler):
def remove_log_handler(handler: Handler) -> None:
getLogger().removeHandler(handler)
def is_log_handler_active(handler):
def is_log_handler_active(handler: Handler) -> bool:
"""Checks whether the given handler is active.
:param handler: a log handler
@ -41,7 +41,7 @@ def is_log_handler_active(handler):
# noinspection PyShadowingBuiltins
def configure(format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
def configure(format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG) -> None:
"""Configures logging to stdout with the given format and log level,
also configuring the default log levels of some overly verbose libraries as well as some pandas output options.
@ -56,8 +56,13 @@ def configure(format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
getLogger("numba").setLevel(INFO)
T = TypeVar("T")
# noinspection PyShadowingBuiltins
def run_main(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
def run_main(
main_fn: Callable[[], T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG
) -> T | None:
"""Configures logging with the given parameters, ensuring that any exceptions that occur during
the execution of the given function are logged.
Logs two additional messages, one before the execution of the function, and one upon its completion.
@ -68,16 +73,19 @@ def run_main(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEB
:return: the result of `main_fn`
"""
configure(format=format, level=level)
log.info("Starting")
log.info("Starting") # type: ignore
try:
result = main_fn()
log.info("Done")
log.info("Done") # type: ignore
return result
except Exception as e:
log.error("Exception during script execution", exc_info=e)
log.error("Exception during script execution", exc_info=e) # type: ignore
return None
def run_cli(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
def run_cli(
main_fn: Callable[[], T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG
) -> T | None:
"""
Configures logging with the given parameters and runs the given main function as a
CLI using `jsonargparse` (which is configured to also parse attribute docstrings, such
@ -107,14 +115,14 @@ _isAtExitReportFileLoggerRegistered = False
_memoryLogStream: StringIO | None = None
def _at_exit_report_file_logger():
def _at_exit_report_file_logger() -> None:
for path in _fileLoggerPaths:
print(f"A log file was saved to {path}")
def add_file_logger(path, register_atexit=True):
def add_file_logger(path: str, register_atexit: bool = True) -> FileHandler:
global _isAtExitReportFileLoggerRegistered
log.info(f"Logging to {path} ...")
log.info(f"Logging to {path} ...") # type: ignore
handler = FileHandler(path)
handler.setFormatter(Formatter(_logFormat))
Logger.root.addHandler(handler)
@ -138,21 +146,22 @@ def add_memory_logger() -> None:
Logger.root.addHandler(handler)
def get_memory_log():
def get_memory_log() -> Any:
""":return: the in-memory log (provided that `add_memory_logger` was called beforehand)"""
assert _memoryLogStream is not None, "This should not have happened and might be a bug."
return _memoryLogStream.getvalue()
class FileLoggerContext:
def __init__(self, path: str, enabled=True):
def __init__(self, path: str, enabled: bool = True):
self.enabled = enabled
self.path = path
self._log_handler = None
self._log_handler: Handler | None = None
def __enter__(self):
def __enter__(self) -> None:
if self.enabled:
self._log_handler = add_file_logger(self.path, register_atexit=False)
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self._log_handler is not None:
remove_log_handler(self._log_handler)

View File

@ -370,13 +370,13 @@ class NoisyLinear(nn.Module):
# TODO: rename or change functionality? Usually sample is not an inplace operation...
def sample(self) -> None:
self.eps_p.copy_(self.f(self.eps_p)) # type: ignore
self.eps_q.copy_(self.f(self.eps_q)) # type: ignore
self.eps_p.copy_(self.f(self.eps_p))
self.eps_q.copy_(self.f(self.eps_q))
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.training:
weight = self.mu_W + self.sigma_W * (self.eps_q.ger(self.eps_p)) # type: ignore
bias = self.mu_bias + self.sigma_bias * self.eps_q.clone() # type: ignore
weight = self.mu_W + self.sigma_W * (self.eps_q.ger(self.eps_p))
bias = self.mu_bias + self.sigma_bias * self.eps_q.clone()
else:
weight = self.mu_W
bias = self.mu_bias

View File

@ -10,6 +10,8 @@ from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import (
Any,
Self,
cast,
)
reCommaWhitespacePotentiallyBreaks = re.compile(r",\s+")
@ -23,11 +25,13 @@ class StringConverter(ABC):
"""Abstraction for a string conversion mechanism."""
@abstractmethod
def to_string(self, x) -> str:
def to_string(self, x: Any) -> str:
pass
def dict_string(d: Mapping, brackets: str | None = None, converter: StringConverter = None):
def dict_string(
d: Mapping, brackets: str | None = None, converter: StringConverter | None = None
) -> str:
"""Converts a dictionary to a string of the form "<key>=<value>, <key>=<value>, ...", optionally enclosed
by brackets.
@ -46,10 +50,10 @@ def dict_string(d: Mapping, brackets: str | None = None, converter: StringConver
def list_string(
l: Iterable[Any],
brackets="[]",
brackets: str | None = "[]",
quote: str | None = None,
converter: StringConverter = None,
):
converter: StringConverter | None = None,
) -> str:
"""Converts a list or any other iterable to a string of the form "[<value>, <value>, ...]", optionally enclosed
by different brackets or with the values quoted.
@ -61,7 +65,7 @@ def list_string(
:return: the string representation
"""
def item(x):
def item(x: Any) -> str:
x = to_string(x, converter=converter, context="list")
if quote is not None:
return quote + x + quote
@ -76,11 +80,11 @@ def list_string(
def to_string(
x,
converter: StringConverter = None,
apply_converter_to_non_complex_objects=True,
context=None,
):
x: Any,
converter: StringConverter | None = None,
apply_converter_to_non_complex_objects: bool = True,
context: Any = None,
) -> str:
"""Converts the given object to a string, with proper handling of lists, tuples and dictionaries, optionally using a converter.
The conversion also removes unwanted line breaks (as present, in particular, in sklearn's string representations).
@ -118,7 +122,7 @@ def to_string(
raise
def object_repr(obj, member_names_or_dict: list[str] | dict[str, Any]):
def object_repr(obj: Any, member_names_or_dict: list[str] | dict[str, Any]) -> str:
"""Creates a string representation for the given object based on the given members.
The string takes the form "ClassName[attr1=value1, attr2=value2, ...]"
@ -130,7 +134,7 @@ def object_repr(obj, member_names_or_dict: list[str] | dict[str, Any]):
return f"{obj.__class__.__name__}[{dict_string(members_dict)}]"
def or_regex_group(allowed_names: Sequence[str]):
def or_regex_group(allowed_names: Sequence[str]) -> str:
""":param allowed_names: strings to include as literals in the regex
:return: a regular expression string of the form (<name1>| ...|<nameN>), which any of the given names
"""
@ -185,7 +189,7 @@ class ToStringMixin:
_TOSTRING_INCLUDE_ALL = "__all__"
def _tostring_class_name(self):
def _tostring_class_name(self) -> str:
""":return: the string use for <class name> in the string representation ``"<class name>[<object info]"``"""
return type(self).__qualname__
@ -196,7 +200,7 @@ class ToStringMixin:
exclude_exceptions: list[str] | None = None,
include_forced: list[str] | None = None,
additional_entries: dict[str, Any] | None = None,
converter: StringConverter = None,
converter: StringConverter | None = None,
) -> str:
"""Creates a string of the class attributes, with optional exclusions/inclusions/additions.
Exclusions take precedence over inclusions.
@ -210,7 +214,7 @@ class ToStringMixin:
:return: a string containing entry/property names and values
"""
def mklist(x):
def mklist(x: Any) -> list[str]:
if x is None:
return []
if isinstance(x, str):
@ -222,7 +226,7 @@ class ToStringMixin:
include_forced = mklist(include_forced)
exclude_exceptions = mklist(exclude_exceptions)
def is_excluded(k):
def is_excluded(k: Any) -> bool:
if k in include_forced or k in exclude_exceptions:
return False
if k in exclude:
@ -329,17 +333,17 @@ class ToStringMixin:
"""
return []
def __str__(self):
def __str__(self) -> str:
return f"{self._tostring_class_name()}[{self._tostring_object_info()}]"
def __repr__(self):
def __repr__(self) -> str:
info = f"id={id(self)}"
property_info = self._tostring_object_info()
if len(property_info) > 0:
info += ", " + property_info
return f"{self._tostring_class_name()}[{info}]"
def pprint(self, file=sys.stdout):
def pprint(self, file: Any = sys.stdout) -> None:
"""Prints a prettily formatted string representation of the object (with line breaks and indentations)
to ``stdout`` or the given file.
@ -364,7 +368,7 @@ class ToStringMixin:
""":param handled_objects: objects which are initially assumed to have been handled already"""
self._handled_to_string_mixin_ids = {id(o) for o in handled_objects}
def to_string(self, x) -> str:
def to_string(self, x: Any) -> str:
if isinstance(x, ToStringMixin):
oid = id(x)
if oid in self._handled_to_string_mixin_ids:
@ -389,17 +393,17 @@ class ToStringMixin:
# methods where we assume that they could transitively call _toStringProperties (others are assumed not to)
TOSTRING_METHODS_TRANSITIVELY_CALLING_TOSTRINGPROPERTIES = {"_tostring_object_info"}
def __init__(self, x: "ToStringMixin", converter):
def __init__(self, x: "ToStringMixin", converter: Any) -> None:
self.x = x
self.converter = converter
def _tostring_properties(self, *args, **kwargs):
return self.x._tostring_properties(*args, **kwargs, converter=self.converter)
def _tostring_properties(self, *args: Any, **kwargs: Any) -> str:
return self.x._tostring_properties(*args, **kwargs, converter=self.converter) # type: ignore[misc]
def _tostring_class_name(self):
def _tostring_class_name(self) -> str:
return self.x._tostring_class_name()
def __getattr__(self, attr: str):
def __getattr__(self, attr: str) -> Any:
if attr.startswith(
"_tostring",
): # ToStringMixin method which we may bind to use this proxy to ensure correct transitive call
@ -413,11 +417,13 @@ class ToStringMixin:
else:
return getattr(self.x, attr)
def __str__(self: "ToStringMixin"):
return ToStringMixin.__str__(self)
def __str__(self) -> str:
return ToStringMixin.__str__(self) # type: ignore[arg-type]
def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string=" "):
def pretty_string_repr(
s: Any, initial_indentation_level: int = 0, indentation_string: str = " "
) -> str:
"""Creates a pretty string representation (using indentations) from the given object/string representation (as generated, for example, via
ToStringMixin). An indentation level is added for every opening bracket.
@ -432,16 +438,16 @@ def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string="
result = indentation_string * indent
i = 0
def nl():
def nl() -> None:
nonlocal result
result += "\n" + (indentation_string * indent)
def take(cnt=1):
def take(cnt: int = 1) -> None:
nonlocal result, i
result += s[i : i + cnt]
i += cnt
def find_matching(j):
def find_matching(j: int) -> int | None:
start = j
op = s[j]
cl = {"[": "]", "(": ")", "'": "'"}[s[j]]
@ -492,17 +498,18 @@ def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string="
class TagBuilder:
"""Assists in building strings made up of components that are joined via a glue string."""
def __init__(self, *initial_components: str, glue="_"):
def __init__(self, *initial_components: str, glue: str = "_"):
""":param initial_components: initial components to always include at the beginning
:param glue: the glue string which joins components
"""
self.glue = glue
self.components = list(initial_components)
def with_component(self, component: str):
def with_component(self, component: str) -> Self:
self.components.append(component)
return self
def with_conditional(self, cond: bool, component: str):
def with_conditional(self, cond: bool, component: str) -> Self:
"""Conditionally adds the given component.
:param cond: the condition
@ -513,7 +520,7 @@ class TagBuilder:
self.components.append(component)
return self
def with_alternative(self, cond: bool, true_component: str, false_component: str):
def with_alternative(self, cond: bool, true_component: str, false_component: str) -> Self:
"""Adds a component depending on a condition.
:param cond: the condition
@ -524,6 +531,6 @@ class TagBuilder:
self.components.append(true_component if cond else false_component)
return self
def build(self):
def build(self) -> str:
""":return: the string (with all components joined)"""
return self.glue.join(self.components)