From a846b52063b09f80a3871fdf41d235de96919c3e Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Tue, 5 Dec 2023 12:04:18 +0100 Subject: [PATCH] Typing: fixed multiple typing issues --- tianshou/data/collector.py | 2 +- tianshou/env/worker/subproc.py | 2 +- tianshou/highlevel/module/actor.py | 2 +- tianshou/policy/modelbased/icm.py | 2 +- tianshou/policy/modelfree/bdq.py | 4 +- tianshou/policy/modelfree/ddpg.py | 2 +- tianshou/policy/modelfree/pg.py | 2 +- tianshou/utils/logging.py | 47 ++++++++++------- tianshou/utils/net/discrete.py | 8 +-- tianshou/utils/string.py | 81 ++++++++++++++++-------------- 10 files changed, 84 insertions(+), 68 deletions(-) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 0600422..da04420 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -68,7 +68,7 @@ class Collector: super().__init__() if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") - self.env = DummyVectorEnv([lambda: env]) # type: ignore + self.env = DummyVectorEnv([lambda: env]) else: self.env = env # type: ignore self.env_num = len(self.env) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index 54a081c..331e683 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -203,7 +203,7 @@ class SubprocEnvWorker(EnvWorker): obs = result[0] if self.share_memory: obs = self._decode_obs() - return (obs, *result[1:]) # type: ignore + return (obs, *result[1:]) obs = result if self.share_memory: obs = self._decode_obs() diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index dac3b79..faaaa68 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -78,7 +78,7 @@ class ActorFactory(ModuleFactory, ToStringMixin, ABC): # do last policy layer scaling, this will make initial actions have (close to) # 0 mean and std, and will help boost performances, # see https://arxiv.org/abs/2006.05990, Fig.24 for details - for m in actor.mu.modules(): # type: ignore + for m in actor.mu.modules(): if isinstance(m, torch.nn.Linear): m.weight.data.copy_(0.01 * m.weight.data) diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 016399e..8f05d6f 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -95,7 +95,7 @@ class ICMPolicy(BasePolicy): def set_eps(self, eps: float) -> None: """Set the eps for epsilon-greedy exploration.""" if hasattr(self.policy, "set_eps"): - self.policy.set_eps(eps) # type: ignore + self.policy.set_eps(eps) else: raise NotImplementedError diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py index 3623289..b78aa15 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdq.py @@ -78,11 +78,11 @@ class BranchingDQNPolicy(DQNPolicy): # but it collides with an attr of the same name in base class @property def _action_per_branch(self) -> int: - return self.model.action_per_branch # type: ignore + return self.model.action_per_branch @property def num_branches(self) -> int: - return self.model.num_branches # type: ignore + return self.model.num_branches def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index c00d199..d2f0987 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -78,7 +78,7 @@ class DDPGPolicy(BasePolicy): action_bound_method=action_bound_method, lr_scheduler=lr_scheduler, ) - if action_scaling and not np.isclose(actor.max_action, 1.0): # type: ignore + if action_scaling and not np.isclose(actor.max_action, 1.0): warnings.warn( "action_scaling and action_bound_method are only intended to deal" "with unbounded model action space, but find actor model bound" diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 1bbc095..09be46a 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -77,7 +77,7 @@ class PGPolicy(BasePolicy): action_bound_method=action_bound_method, lr_scheduler=lr_scheduler, ) - if action_scaling and not np.isclose(actor.max_action, 1.0): # type: ignore + if action_scaling and not np.isclose(actor.max_action, 1.0): warnings.warn( "action_scaling and action_bound_method are only intended" "to deal with unbounded model action space, but find actor model" diff --git a/tianshou/utils/logging.py b/tianshou/utils/logging.py index 6c12570..607f09c 100644 --- a/tianshou/utils/logging.py +++ b/tianshou/utils/logging.py @@ -9,9 +9,9 @@ from collections.abc import Callable from datetime import datetime from io import StringIO from logging import * -from typing import Any +from typing import Any, TypeVar, cast -log = getLogger(__name__) +log = getLogger(__name__) # type: ignore LOG_DEFAULT_FORMAT = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s - %(message)s" @@ -20,18 +20,18 @@ LOG_DEFAULT_FORMAT = "%(levelname)-5s %(asctime)-15s %(name)s:%(funcName)s - %(m _logFormat = LOG_DEFAULT_FORMAT -def remove_log_handlers(): +def remove_log_handlers() -> None: """Removes all current log handlers.""" logger = getLogger() while logger.hasHandlers(): logger.removeHandler(logger.handlers[0]) -def remove_log_handler(handler): +def remove_log_handler(handler: Handler) -> None: getLogger().removeHandler(handler) -def is_log_handler_active(handler): +def is_log_handler_active(handler: Handler) -> bool: """Checks whether the given handler is active. :param handler: a log handler @@ -41,7 +41,7 @@ def is_log_handler_active(handler): # noinspection PyShadowingBuiltins -def configure(format=LOG_DEFAULT_FORMAT, level=lg.DEBUG): +def configure(format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG) -> None: """Configures logging to stdout with the given format and log level, also configuring the default log levels of some overly verbose libraries as well as some pandas output options. @@ -56,8 +56,13 @@ def configure(format=LOG_DEFAULT_FORMAT, level=lg.DEBUG): getLogger("numba").setLevel(INFO) +T = TypeVar("T") + + # noinspection PyShadowingBuiltins -def run_main(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEBUG): +def run_main( + main_fn: Callable[[], T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG +) -> T | None: """Configures logging with the given parameters, ensuring that any exceptions that occur during the execution of the given function are logged. Logs two additional messages, one before the execution of the function, and one upon its completion. @@ -68,16 +73,19 @@ def run_main(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEB :return: the result of `main_fn` """ configure(format=format, level=level) - log.info("Starting") + log.info("Starting") # type: ignore try: result = main_fn() - log.info("Done") + log.info("Done") # type: ignore return result except Exception as e: - log.error("Exception during script execution", exc_info=e) + log.error("Exception during script execution", exc_info=e) # type: ignore + return None -def run_cli(main_fn: Callable[[], Any], format=LOG_DEFAULT_FORMAT, level=lg.DEBUG): +def run_cli( + main_fn: Callable[[], T], format: str = LOG_DEFAULT_FORMAT, level: int = lg.DEBUG +) -> T | None: """ Configures logging with the given parameters and runs the given main function as a CLI using `jsonargparse` (which is configured to also parse attribute docstrings, such @@ -107,14 +115,14 @@ _isAtExitReportFileLoggerRegistered = False _memoryLogStream: StringIO | None = None -def _at_exit_report_file_logger(): +def _at_exit_report_file_logger() -> None: for path in _fileLoggerPaths: print(f"A log file was saved to {path}") -def add_file_logger(path, register_atexit=True): +def add_file_logger(path: str, register_atexit: bool = True) -> FileHandler: global _isAtExitReportFileLoggerRegistered - log.info(f"Logging to {path} ...") + log.info(f"Logging to {path} ...") # type: ignore handler = FileHandler(path) handler.setFormatter(Formatter(_logFormat)) Logger.root.addHandler(handler) @@ -138,21 +146,22 @@ def add_memory_logger() -> None: Logger.root.addHandler(handler) -def get_memory_log(): +def get_memory_log() -> Any: """:return: the in-memory log (provided that `add_memory_logger` was called beforehand)""" + assert _memoryLogStream is not None, "This should not have happened and might be a bug." return _memoryLogStream.getvalue() class FileLoggerContext: - def __init__(self, path: str, enabled=True): + def __init__(self, path: str, enabled: bool = True): self.enabled = enabled self.path = path - self._log_handler = None + self._log_handler: Handler | None = None - def __enter__(self): + def __enter__(self) -> None: if self.enabled: self._log_handler = add_file_logger(self.path, register_atexit=False) - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self._log_handler is not None: remove_log_handler(self._log_handler) diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 1b5b019..ae77c5d 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -370,13 +370,13 @@ class NoisyLinear(nn.Module): # TODO: rename or change functionality? Usually sample is not an inplace operation... def sample(self) -> None: - self.eps_p.copy_(self.f(self.eps_p)) # type: ignore - self.eps_q.copy_(self.f(self.eps_q)) # type: ignore + self.eps_p.copy_(self.f(self.eps_p)) + self.eps_q.copy_(self.f(self.eps_q)) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.training: - weight = self.mu_W + self.sigma_W * (self.eps_q.ger(self.eps_p)) # type: ignore - bias = self.mu_bias + self.sigma_bias * self.eps_q.clone() # type: ignore + weight = self.mu_W + self.sigma_W * (self.eps_q.ger(self.eps_p)) + bias = self.mu_bias + self.sigma_bias * self.eps_q.clone() else: weight = self.mu_W bias = self.mu_bias diff --git a/tianshou/utils/string.py b/tianshou/utils/string.py index 445f2e9..dcc71d0 100644 --- a/tianshou/utils/string.py +++ b/tianshou/utils/string.py @@ -10,6 +10,8 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Mapping, Sequence from typing import ( Any, + Self, + cast, ) reCommaWhitespacePotentiallyBreaks = re.compile(r",\s+") @@ -23,11 +25,13 @@ class StringConverter(ABC): """Abstraction for a string conversion mechanism.""" @abstractmethod - def to_string(self, x) -> str: + def to_string(self, x: Any) -> str: pass -def dict_string(d: Mapping, brackets: str | None = None, converter: StringConverter = None): +def dict_string( + d: Mapping, brackets: str | None = None, converter: StringConverter | None = None +) -> str: """Converts a dictionary to a string of the form "=, =, ...", optionally enclosed by brackets. @@ -46,10 +50,10 @@ def dict_string(d: Mapping, brackets: str | None = None, converter: StringConver def list_string( l: Iterable[Any], - brackets="[]", + brackets: str | None = "[]", quote: str | None = None, - converter: StringConverter = None, -): + converter: StringConverter | None = None, +) -> str: """Converts a list or any other iterable to a string of the form "[, , ...]", optionally enclosed by different brackets or with the values quoted. @@ -61,7 +65,7 @@ def list_string( :return: the string representation """ - def item(x): + def item(x: Any) -> str: x = to_string(x, converter=converter, context="list") if quote is not None: return quote + x + quote @@ -76,11 +80,11 @@ def list_string( def to_string( - x, - converter: StringConverter = None, - apply_converter_to_non_complex_objects=True, - context=None, -): + x: Any, + converter: StringConverter | None = None, + apply_converter_to_non_complex_objects: bool = True, + context: Any = None, +) -> str: """Converts the given object to a string, with proper handling of lists, tuples and dictionaries, optionally using a converter. The conversion also removes unwanted line breaks (as present, in particular, in sklearn's string representations). @@ -118,7 +122,7 @@ def to_string( raise -def object_repr(obj, member_names_or_dict: list[str] | dict[str, Any]): +def object_repr(obj: Any, member_names_or_dict: list[str] | dict[str, Any]) -> str: """Creates a string representation for the given object based on the given members. The string takes the form "ClassName[attr1=value1, attr2=value2, ...]" @@ -130,7 +134,7 @@ def object_repr(obj, member_names_or_dict: list[str] | dict[str, Any]): return f"{obj.__class__.__name__}[{dict_string(members_dict)}]" -def or_regex_group(allowed_names: Sequence[str]): +def or_regex_group(allowed_names: Sequence[str]) -> str: """:param allowed_names: strings to include as literals in the regex :return: a regular expression string of the form (| ...|), which any of the given names """ @@ -185,7 +189,7 @@ class ToStringMixin: _TOSTRING_INCLUDE_ALL = "__all__" - def _tostring_class_name(self): + def _tostring_class_name(self) -> str: """:return: the string use for in the string representation ``"[ str: """Creates a string of the class attributes, with optional exclusions/inclusions/additions. Exclusions take precedence over inclusions. @@ -210,7 +214,7 @@ class ToStringMixin: :return: a string containing entry/property names and values """ - def mklist(x): + def mklist(x: Any) -> list[str]: if x is None: return [] if isinstance(x, str): @@ -222,7 +226,7 @@ class ToStringMixin: include_forced = mklist(include_forced) exclude_exceptions = mklist(exclude_exceptions) - def is_excluded(k): + def is_excluded(k: Any) -> bool: if k in include_forced or k in exclude_exceptions: return False if k in exclude: @@ -329,17 +333,17 @@ class ToStringMixin: """ return [] - def __str__(self): + def __str__(self) -> str: return f"{self._tostring_class_name()}[{self._tostring_object_info()}]" - def __repr__(self): + def __repr__(self) -> str: info = f"id={id(self)}" property_info = self._tostring_object_info() if len(property_info) > 0: info += ", " + property_info return f"{self._tostring_class_name()}[{info}]" - def pprint(self, file=sys.stdout): + def pprint(self, file: Any = sys.stdout) -> None: """Prints a prettily formatted string representation of the object (with line breaks and indentations) to ``stdout`` or the given file. @@ -364,7 +368,7 @@ class ToStringMixin: """:param handled_objects: objects which are initially assumed to have been handled already""" self._handled_to_string_mixin_ids = {id(o) for o in handled_objects} - def to_string(self, x) -> str: + def to_string(self, x: Any) -> str: if isinstance(x, ToStringMixin): oid = id(x) if oid in self._handled_to_string_mixin_ids: @@ -389,17 +393,17 @@ class ToStringMixin: # methods where we assume that they could transitively call _toStringProperties (others are assumed not to) TOSTRING_METHODS_TRANSITIVELY_CALLING_TOSTRINGPROPERTIES = {"_tostring_object_info"} - def __init__(self, x: "ToStringMixin", converter): + def __init__(self, x: "ToStringMixin", converter: Any) -> None: self.x = x self.converter = converter - def _tostring_properties(self, *args, **kwargs): - return self.x._tostring_properties(*args, **kwargs, converter=self.converter) + def _tostring_properties(self, *args: Any, **kwargs: Any) -> str: + return self.x._tostring_properties(*args, **kwargs, converter=self.converter) # type: ignore[misc] - def _tostring_class_name(self): + def _tostring_class_name(self) -> str: return self.x._tostring_class_name() - def __getattr__(self, attr: str): + def __getattr__(self, attr: str) -> Any: if attr.startswith( "_tostring", ): # ToStringMixin method which we may bind to use this proxy to ensure correct transitive call @@ -413,11 +417,13 @@ class ToStringMixin: else: return getattr(self.x, attr) - def __str__(self: "ToStringMixin"): - return ToStringMixin.__str__(self) + def __str__(self) -> str: + return ToStringMixin.__str__(self) # type: ignore[arg-type] -def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string=" "): +def pretty_string_repr( + s: Any, initial_indentation_level: int = 0, indentation_string: str = " " +) -> str: """Creates a pretty string representation (using indentations) from the given object/string representation (as generated, for example, via ToStringMixin). An indentation level is added for every opening bracket. @@ -432,16 +438,16 @@ def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string=" result = indentation_string * indent i = 0 - def nl(): + def nl() -> None: nonlocal result result += "\n" + (indentation_string * indent) - def take(cnt=1): + def take(cnt: int = 1) -> None: nonlocal result, i result += s[i : i + cnt] i += cnt - def find_matching(j): + def find_matching(j: int) -> int | None: start = j op = s[j] cl = {"[": "]", "(": ")", "'": "'"}[s[j]] @@ -492,17 +498,18 @@ def pretty_string_repr(s: Any, initial_indentation_level=0, indentation_string=" class TagBuilder: """Assists in building strings made up of components that are joined via a glue string.""" - def __init__(self, *initial_components: str, glue="_"): + def __init__(self, *initial_components: str, glue: str = "_"): """:param initial_components: initial components to always include at the beginning :param glue: the glue string which joins components """ self.glue = glue self.components = list(initial_components) - def with_component(self, component: str): + def with_component(self, component: str) -> Self: self.components.append(component) + return self - def with_conditional(self, cond: bool, component: str): + def with_conditional(self, cond: bool, component: str) -> Self: """Conditionally adds the given component. :param cond: the condition @@ -513,7 +520,7 @@ class TagBuilder: self.components.append(component) return self - def with_alternative(self, cond: bool, true_component: str, false_component: str): + def with_alternative(self, cond: bool, true_component: str, false_component: str) -> Self: """Adds a component depending on a condition. :param cond: the condition @@ -524,6 +531,6 @@ class TagBuilder: self.components.append(true_component if cond else false_component) return self - def build(self): + def build(self) -> str: """:return: the string (with all components joined)""" return self.glue.join(self.components)