Typing: fixed multiple typing issues
This commit is contained in:
parent
2e39a252e3
commit
a846b52063
@ -68,7 +68,7 @@ class Collector:
|
||||
super().__init__()
|
||||
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
|
||||
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
|
||||
self.env = DummyVectorEnv([lambda: env]) # type: ignore
|
||||
self.env = DummyVectorEnv([lambda: env])
|
||||
else:
|
||||
self.env = env # type: ignore
|
||||
self.env_num = len(self.env)
|
||||
|
||||
2
tianshou/env/worker/subproc.py
vendored
2
tianshou/env/worker/subproc.py
vendored
@ -203,7 +203,7 @@ class SubprocEnvWorker(EnvWorker):
|
||||
obs = result[0]
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs()
|
||||
return (obs, *result[1:]) # type: ignore
|
||||
return (obs, *result[1:])
|
||||
obs = result
|
||||
if self.share_memory:
|
||||
obs = self._decode_obs()
|
||||
|
||||
@ -78,7 +78,7 @@ class ActorFactory(ModuleFactory, ToStringMixin, ABC):
|
||||
# do last policy layer scaling, this will make initial actions have (close to)
|
||||
# 0 mean and std, and will help boost performances,
|
||||
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
|
||||
for m in actor.mu.modules(): # type: ignore
|
||||
for m in actor.mu.modules():
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
m.weight.data.copy_(0.01 * m.weight.data)
|
||||
|
||||
|
||||
@ -95,7 +95,7 @@ class ICMPolicy(BasePolicy):
|
||||
def set_eps(self, eps: float) -> None:
|
||||
"""Set the eps for epsilon-greedy exploration."""
|
||||
if hasattr(self.policy, "set_eps"):
|
||||
self.policy.set_eps(eps) # type: ignore
|
||||
self.policy.set_eps(eps)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -78,11 +78,11 @@ class BranchingDQNPolicy(DQNPolicy):
|
||||
# but it collides with an attr of the same name in base class
|
||||
@property
|
||||
def _action_per_branch(self) -> int:
|
||||
return self.model.action_per_branch # type: ignore
|
||||
return self.model.action_per_branch
|
||||
|
||||
@property
|
||||
def num_branches(self) -> int:
|
||||
return self.model.num_branches # type: ignore
|
||||
return self.model.num_branches
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
|
||||
obs_next_batch = Batch(
|
||||
|
||||
@ -78,7 +78,7 @@ class DDPGPolicy(BasePolicy):
|
||||
action_bound_method=action_bound_method,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
if action_scaling and not np.isclose(actor.max_action, 1.0): # type: ignore
|
||||
if action_scaling and not np.isclose(actor.max_action, 1.0):
|
||||
warnings.warn(
|
||||
"action_scaling and action_bound_method are only intended to deal"
|
||||
"with unbounded model action space, but find actor model bound"
|
||||
|
||||
@ -77,7 +77,7 @@ class PGPolicy(BasePolicy):
|
||||
action_bound_method=action_bound_method,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
if action_scaling and not np.isclose(actor.max_action, 1.0): # type: ignore
|
||||
if action_scaling and not np.isclose(actor.max_action, 1.0):
|
||||
warnings.warn(
|
||||
"action_scaling and action_bound_method are only intended"
|
||||
"to deal with unbounded model action space, but find actor model"
|
||||
|
||||
@ -9,9 +9,9 @@ from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from io import StringIO
|
||||
from logging import *
|
||||
from typing import Any
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
log = getLogger(__name__)
|
||||
log = getLogger(__name__) # type: ignore
|
||||
|
||||
LOG_DEFAULT_FORMAT = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s - %(message)s"
|
||||
|
||||
@ -20,18 +20,18 @@ LOG_DEFAULT_FORMAT = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s - %(m
|
||||
_logFormat = LOG_DEFAULT_FORMAT
|
||||
|
||||
|
||||
def remove_log_handlers():
|
||||
def remove_log_handlers() -> None:
|
||||
"""Removes all current log handlers."""
|
||||
logger = getLogger()
|
||||
while logger.hasHandlers():
|
||||
logger.removeHandler(logger.handlers[0])
|
||||
|
||||
|
||||
def remove_log_handler(handler):
|
||||
def remove_log_handler(handler: Handler) -> None:
|
||||
getLogger().removeHandler(handler)
|
||||
|
||||
|
||||
def is_log_handler_active(handler):
|
||||
def is_log_handler_active(handler: Handler) -> bool:
|
||||
"""Checks whether the given handler is active.
|
||||
|
||||
:param handler: a log handler
|
||||
@ -41,7 +41,7 @@ def is_log_handler_active(handler):
|
||||
|
||||
|
||||
# noinspection PyShadowingBuiltins
|
||||
def configure(format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
|
||||
def configure(format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG) -> None:
|
||||
"""Configures logging to stdout with the given format and log level,
|
||||
also configuring the default log levels of some overly verbose libraries as well as some pandas output options.
|
||||
|
||||
@ -56,8 +56,13 @@ def configure(format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
|
||||
getLogger("numba").setLevel(INFO)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
# noinspection PyShadowingBuiltins
|
||||
def run_main(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
|
||||
def run_main(
|
||||
main_fn: Callable[[], T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG
|
||||
) -> T | None:
|
||||
"""Configures logging with the given parameters, ensuring that any exceptions that occur during
|
||||
the execution of the given function are logged.
|
||||
Logs two additional messages, one before the execution of the function, and one upon its completion.
|
||||
@ -68,16 +73,19 @@ def run_main(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEB
|
||||
:return: the result of `main_fn`
|
||||
"""
|
||||
configure(format=format, level=level)
|
||||
log.info("Starting")
|
||||
log.info("Starting") # type: ignore
|
||||
try:
|
||||
result = main_fn()
|
||||
log.info("Done")
|
||||
log.info("Done") # type: ignore
|
||||
return result
|
||||
except Exception as e:
|
||||
log.error("Exception during script execution", exc_info=e)
|
||||
log.error("Exception during script execution", exc_info=e) # type: ignore
|
||||
return None
|
||||
|
||||
|
||||
def run_cli(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEBUG):
|
||||
def run_cli(
|
||||
main_fn: Callable[[], T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG
|
||||
) -> T | None:
|
||||
"""
|
||||
Configures logging with the given parameters and runs the given main function as a
|
||||
CLI using `jsonargparse` (which is configured to also parse attribute docstrings, such
|
||||
@ -107,14 +115,14 @@ _isAtExitReportFileLoggerRegistered = False
|
||||
_memoryLogStream: StringIO | None = None
|
||||
|
||||
|
||||
def _at_exit_report_file_logger():
|
||||
def _at_exit_report_file_logger() -> None:
|
||||
for path in _fileLoggerPaths:
|
||||
print(f"A log file was saved to {path}")
|
||||
|
||||
|
||||
def add_file_logger(path, register_atexit=True):
|
||||
def add_file_logger(path: str, register_atexit: bool = True) -> FileHandler:
|
||||
global _isAtExitReportFileLoggerRegistered
|
||||
log.info(f"Logging to {path} ...")
|
||||
log.info(f"Logging to {path} ...") # type: ignore
|
||||
handler = FileHandler(path)
|
||||
handler.setFormatter(Formatter(_logFormat))
|
||||
Logger.root.addHandler(handler)
|
||||
@ -138,21 +146,22 @@ def add_memory_logger() -> None:
|
||||
Logger.root.addHandler(handler)
|
||||
|
||||
|
||||
def get_memory_log():
|
||||
def get_memory_log() -> Any:
|
||||
""":return: the in-memory log (provided that `add_memory_logger` was called beforehand)"""
|
||||
assert _memoryLogStream is not None, "This should not have happened and might be a bug."
|
||||
return _memoryLogStream.getvalue()
|
||||
|
||||
|
||||
class FileLoggerContext:
|
||||
def __init__(self, path: str, enabled=True):
|
||||
def __init__(self, path: str, enabled: bool = True):
|
||||
self.enabled = enabled
|
||||
self.path = path
|
||||
self._log_handler = None
|
||||
self._log_handler: Handler | None = None
|
||||
|
||||
def __enter__(self):
|
||||
def __enter__(self) -> None:
|
||||
if self.enabled:
|
||||
self._log_handler = add_file_logger(self.path, register_atexit=False)
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
if self._log_handler is not None:
|
||||
remove_log_handler(self._log_handler)
|
||||
|
||||
@ -370,13 +370,13 @@ class NoisyLinear(nn.Module):
|
||||
|
||||
# TODO: rename or change functionality? Usually sample is not an inplace operation...
|
||||
def sample(self) -> None:
|
||||
self.eps_p.copy_(self.f(self.eps_p)) # type: ignore
|
||||
self.eps_q.copy_(self.f(self.eps_q)) # type: ignore
|
||||
self.eps_p.copy_(self.f(self.eps_p))
|
||||
self.eps_q.copy_(self.f(self.eps_q))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.training:
|
||||
weight = self.mu_W + self.sigma_W * (self.eps_q.ger(self.eps_p)) # type: ignore
|
||||
bias = self.mu_bias + self.sigma_bias * self.eps_q.clone() # type: ignore
|
||||
weight = self.mu_W + self.sigma_W * (self.eps_q.ger(self.eps_p))
|
||||
bias = self.mu_bias + self.sigma_bias * self.eps_q.clone()
|
||||
else:
|
||||
weight = self.mu_W
|
||||
bias = self.mu_bias
|
||||
|
||||
@ -10,6 +10,8 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from typing import (
|
||||
Any,
|
||||
Self,
|
||||
cast,
|
||||
)
|
||||
|
||||
reCommaWhitespacePotentiallyBreaks = re.compile(r",\s+")
|
||||
@ -23,11 +25,13 @@ class StringConverter(ABC):
|
||||
"""Abstraction for a string conversion mechanism."""
|
||||
|
||||
@abstractmethod
|
||||
def to_string(self, x) -> str:
|
||||
def to_string(self, x: Any) -> str:
|
||||
pass
|
||||
|
||||
|
||||
def dict_string(d: Mapping, brackets: str | None = None, converter: StringConverter = None):
|
||||
def dict_string(
|
||||
d: Mapping, brackets: str | None = None, converter: StringConverter | None = None
|
||||
) -> str:
|
||||
"""Converts a dictionary to a string of the form "<key>=<value>, <key>=<value>, ...", optionally enclosed
|
||||
by brackets.
|
||||
|
||||
@ -46,10 +50,10 @@ def dict_string(d: Mapping, brackets: str | None = None, converter: StringConver
|
||||
|
||||
def list_string(
|
||||
l: Iterable[Any],
|
||||
brackets="[]",
|
||||
brackets: str | None = "[]",
|
||||
quote: str | None = None,
|
||||
converter: StringConverter = None,
|
||||
):
|
||||
converter: StringConverter | None = None,
|
||||
) -> str:
|
||||
"""Converts a list or any other iterable to a string of the form "[<value>, <value>, ...]", optionally enclosed
|
||||
by different brackets or with the values quoted.
|
||||
|
||||
@ -61,7 +65,7 @@ def list_string(
|
||||
:return: the string representation
|
||||
"""
|
||||
|
||||
def item(x):
|
||||
def item(x: Any) -> str:
|
||||
x = to_string(x, converter=converter, context="list")
|
||||
if quote is not None:
|
||||
return quote + x + quote
|
||||
@ -76,11 +80,11 @@ def list_string(
|
||||
|
||||
|
||||
def to_string(
|
||||
x,
|
||||
converter: StringConverter = None,
|
||||
apply_converter_to_non_complex_objects=True,
|
||||
context=None,
|
||||
):
|
||||
x: Any,
|
||||
converter: StringConverter | None = None,
|
||||
apply_converter_to_non_complex_objects: bool = True,
|
||||
context: Any = None,
|
||||
) -> str:
|
||||
"""Converts the given object to a string, with proper handling of lists, tuples and dictionaries, optionally using a converter.
|
||||
The conversion also removes unwanted line breaks (as present, in particular, in sklearn's string representations).
|
||||
|
||||
@ -118,7 +122,7 @@ def to_string(
|
||||
raise
|
||||
|
||||
|
||||
def object_repr(obj, member_names_or_dict: list[str] | dict[str, Any]):
|
||||
def object_repr(obj: Any, member_names_or_dict: list[str] | dict[str, Any]) -> str:
|
||||
"""Creates a string representation for the given object based on the given members.
|
||||
|
||||
The string takes the form "ClassName[attr1=value1, attr2=value2, ...]"
|
||||
@ -130,7 +134,7 @@ def object_repr(obj, member_names_or_dict: list[str] | dict[str, Any]):
|
||||
return f"{obj.__class__.__name__}[{dict_string(members_dict)}]"
|
||||
|
||||
|
||||
def or_regex_group(allowed_names: Sequence[str]):
|
||||
def or_regex_group(allowed_names: Sequence[str]) -> str:
|
||||
""":param allowed_names: strings to include as literals in the regex
|
||||
:return: a regular expression string of the form (<name1>| ...|<nameN>), which any of the given names
|
||||
"""
|
||||
@ -185,7 +189,7 @@ class ToStringMixin:
|
||||
|
||||
_TOSTRING_INCLUDE_ALL = "__all__"
|
||||
|
||||
def _tostring_class_name(self):
|
||||
def _tostring_class_name(self) -> str:
|
||||
""":return: the string use for <class name> in the string representation ``"<class name>[<object info]"``"""
|
||||
return type(self).__qualname__
|
||||
|
||||
@ -196,7 +200,7 @@ class ToStringMixin:
|
||||
exclude_exceptions: list[str] | None = None,
|
||||
include_forced: list[str] | None = None,
|
||||
additional_entries: dict[str, Any] | None = None,
|
||||
converter: StringConverter = None,
|
||||
converter: StringConverter | None = None,
|
||||
) -> str:
|
||||
"""Creates a string of the class attributes, with optional exclusions/inclusions/additions.
|
||||
Exclusions take precedence over inclusions.
|
||||
@ -210,7 +214,7 @@ class ToStringMixin:
|
||||
:return: a string containing entry/property names and values
|
||||
"""
|
||||
|
||||
def mklist(x):
|
||||
def mklist(x: Any) -> list[str]:
|
||||
if x is None:
|
||||
return []
|
||||
if isinstance(x, str):
|
||||
@ -222,7 +226,7 @@ class ToStringMixin:
|
||||
include_forced = mklist(include_forced)
|
||||
exclude_exceptions = mklist(exclude_exceptions)
|
||||
|
||||
def is_excluded(k):
|
||||
def is_excluded(k: Any) -> bool:
|
||||
if k in include_forced or k in exclude_exceptions:
|
||||
return False
|
||||
if k in exclude:
|
||||
@ -329,17 +333,17 @@ class ToStringMixin:
|
||||
"""
|
||||
return []
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
return f"{self._tostring_class_name()}[{self._tostring_object_info()}]"
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
info = f"id={id(self)}"
|
||||
property_info = self._tostring_object_info()
|
||||
if len(property_info) > 0:
|
||||
info += ", " + property_info
|
||||
return f"{self._tostring_class_name()}[{info}]"
|
||||
|
||||
def pprint(self, file=sys.stdout):
|
||||
def pprint(self, file: Any = sys.stdout) -> None:
|
||||
"""Prints a prettily formatted string representation of the object (with line breaks and indentations)
|
||||
to ``stdout`` or the given file.
|
||||
|
||||
@ -364,7 +368,7 @@ class ToStringMixin:
|
||||
""":param handled_objects: objects which are initially assumed to have been handled already"""
|
||||
self._handled_to_string_mixin_ids = {id(o) for o in handled_objects}
|
||||
|
||||
def to_string(self, x) -> str:
|
||||
def to_string(self, x: Any) -> str:
|
||||
if isinstance(x, ToStringMixin):
|
||||
oid = id(x)
|
||||
if oid in self._handled_to_string_mixin_ids:
|
||||
@ -389,17 +393,17 @@ class ToStringMixin:
|
||||
# methods where we assume that they could transitively call _toStringProperties (others are assumed not to)
|
||||
TOSTRING_METHODS_TRANSITIVELY_CALLING_TOSTRINGPROPERTIES = {"_tostring_object_info"}
|
||||
|
||||
def __init__(self, x: "ToStringMixin", converter):
|
||||
def __init__(self, x: "ToStringMixin", converter: Any) -> None:
|
||||
self.x = x
|
||||
self.converter = converter
|
||||
|
||||
def _tostring_properties(self, *args, **kwargs):
|
||||
return self.x._tostring_properties(*args, **kwargs, converter=self.converter)
|
||||
def _tostring_properties(self, *args: Any, **kwargs: Any) -> str:
|
||||
return self.x._tostring_properties(*args, **kwargs, converter=self.converter) # type: ignore[misc]
|
||||
|
||||
def _tostring_class_name(self):
|
||||
def _tostring_class_name(self) -> str:
|
||||
return self.x._tostring_class_name()
|
||||
|
||||
def __getattr__(self, attr: str):
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if attr.startswith(
|
||||
"_tostring",
|
||||
): # ToStringMixin method which we may bind to use this proxy to ensure correct transitive call
|
||||
@ -413,11 +417,13 @@ class ToStringMixin:
|
||||
else:
|
||||
return getattr(self.x, attr)
|
||||
|
||||
def __str__(self: "ToStringMixin"):
|
||||
return ToStringMixin.__str__(self)
|
||||
def __str__(self) -> str:
|
||||
return ToStringMixin.__str__(self) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string=" "):
|
||||
def pretty_string_repr(
|
||||
s: Any, initial_indentation_level: int = 0, indentation_string: str = " "
|
||||
) -> str:
|
||||
"""Creates a pretty string representation (using indentations) from the given object/string representation (as generated, for example, via
|
||||
ToStringMixin). An indentation level is added for every opening bracket.
|
||||
|
||||
@ -432,16 +438,16 @@ def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string="
|
||||
result = indentation_string * indent
|
||||
i = 0
|
||||
|
||||
def nl():
|
||||
def nl() -> None:
|
||||
nonlocal result
|
||||
result += "\n" + (indentation_string * indent)
|
||||
|
||||
def take(cnt=1):
|
||||
def take(cnt: int = 1) -> None:
|
||||
nonlocal result, i
|
||||
result += s[i : i + cnt]
|
||||
i += cnt
|
||||
|
||||
def find_matching(j):
|
||||
def find_matching(j: int) -> int | None:
|
||||
start = j
|
||||
op = s[j]
|
||||
cl = {"[": "]", "(": ")", "'": "'"}[s[j]]
|
||||
@ -492,17 +498,18 @@ def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string="
|
||||
class TagBuilder:
|
||||
"""Assists in building strings made up of components that are joined via a glue string."""
|
||||
|
||||
def __init__(self, *initial_components: str, glue="_"):
|
||||
def __init__(self, *initial_components: str, glue: str = "_"):
|
||||
""":param initial_components: initial components to always include at the beginning
|
||||
:param glue: the glue string which joins components
|
||||
"""
|
||||
self.glue = glue
|
||||
self.components = list(initial_components)
|
||||
|
||||
def with_component(self, component: str):
|
||||
def with_component(self, component: str) -> Self:
|
||||
self.components.append(component)
|
||||
return self
|
||||
|
||||
def with_conditional(self, cond: bool, component: str):
|
||||
def with_conditional(self, cond: bool, component: str) -> Self:
|
||||
"""Conditionally adds the given component.
|
||||
|
||||
:param cond: the condition
|
||||
@ -513,7 +520,7 @@ class TagBuilder:
|
||||
self.components.append(component)
|
||||
return self
|
||||
|
||||
def with_alternative(self, cond: bool, true_component: str, false_component: str):
|
||||
def with_alternative(self, cond: bool, true_component: str, false_component: str) -> Self:
|
||||
"""Adds a component depending on a condition.
|
||||
|
||||
:param cond: the condition
|
||||
@ -524,6 +531,6 @@ class TagBuilder:
|
||||
self.components.append(true_component if cond else false_component)
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
def build(self) -> str:
|
||||
""":return: the string (with all components joined)"""
|
||||
return self.glue.join(self.components)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user