diff --git a/tianshou/utils/logging.py b/tianshou/utils/logging.py index 36e8146..fbf8817 100644 --- a/tianshou/utils/logging.py +++ b/tianshou/utils/logging.py @@ -24,6 +24,10 @@ def remove_log_handlers(): logger.removeHandler(logger.handlers[0]) +def remove_log_handler(handler): + getLogger().removeHandler(handler) + + def is_log_handler_active(handler): """Checks whether the given handler is active. @@ -85,16 +89,17 @@ def _at_exit_report_file_logger(): print(f"A log file was saved to {path}") -def add_file_logger(path): +def add_file_logger(path, register_atexit=True): global _isAtExitReportFileLoggerRegistered log.info(f"Logging to {path} ...") handler = FileHandler(path) handler.setFormatter(Formatter(_logFormat)) Logger.root.addHandler(handler) _fileLoggerPaths.append(path) - if not _isAtExitReportFileLoggerRegistered: + if not _isAtExitReportFileLoggerRegistered and register_atexit: atexit.register(_at_exit_report_file_logger) _isAtExitReportFileLoggerRegistered = True + return handler def add_memory_logger() -> None: @@ -113,3 +118,18 @@ def add_memory_logger() -> None: def get_memory_log(): """:return: the in-memory log (provided that `add_memory_logger` was called beforehand)""" return _memoryLogStream.getvalue() + + +class FileLoggerContext: + def __init__(self, path: str, enabled=True): + self.enabled = enabled + self.path = path + self._log_handler = None + + def __enter__(self): + if self.enabled: + self._log_handler = add_file_logger(self.path, register_atexit=False) + + def __exit__(self, exc_type, exc_value, traceback): + if self._log_handler is not None: + remove_log_handler(self._log_handler)