From fc695a53941e8afe8abef61d976f5bffab01a378 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Fri, 13 Oct 2023 16:16:12 +0200 Subject: [PATCH] Use logging to report trainer epoch status --- tianshou/highlevel/agent.py | 2 ++ tianshou/trainer/base.py | 20 +++++++++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 6cc43d9..702252e 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -183,6 +183,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC): train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, + verbose=False, ) @@ -224,6 +225,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC): train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, + verbose=False, ) diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 1ad82b1..919ff89 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -1,3 +1,4 @@ +import logging import time from abc import ABC, abstractmethod from collections import defaultdict, deque @@ -19,6 +20,8 @@ from tianshou.utils import ( tqdm_config, ) +log = logging.getLogger(__name__) + class BaseTrainer(ABC): """An iterator base class for trainers. @@ -79,7 +82,9 @@ class BaseTrainer(ABC): e.g., the reward of agent 1 or the average reward over all agents. :param logger: A logger that logs statistics during training/testing/updating. Default to a logger that doesn't log anything. - :param verbose: whether to print the information. Default to True. + :param verbose: whether to print status information to stdout. + If set to False, status information will still be logged (provided that + logging is enabled via the `logging` module). :param show_progress: whether to display a progress bar when training. Default to True. :param test_in_train: whether to test in the training phase. @@ -376,13 +381,14 @@ class BaseTrainer(ABC): self.best_reward_std = rew_std if self.save_best_fn: self.save_best_fn(self.policy) + log_msg = ( + f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}," + f" best_reward: {self.best_reward:.6f} ± " + f"{self.best_reward_std:.6f} in #{self.best_epoch}" + ) + log.info(log_msg) if self.verbose: - print( - f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}," - f" best_reward: {self.best_reward:.6f} ± " - f"{self.best_reward_std:.6f} in #{self.best_epoch}", - flush=True, - ) + print(log_msg, flush=True) if not self.is_run: test_stat = { "test_reward": rew,