Use logging to report trainer epoch status

This commit is contained in:
Dominik Jain 2023-10-13 16:16:12 +02:00
parent 3bba192633
commit fc695a5394
2 changed files with 15 additions and 7 deletions

View File

@@ -183,6 +183,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
train_fn=train_fn, train_fn=train_fn,
test_fn=test_fn, test_fn=test_fn,
stop_fn=stop_fn, stop_fn=stop_fn,
verbose=False,
) )
@@ -224,6 +225,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
train_fn=train_fn, train_fn=train_fn,
test_fn=test_fn, test_fn=test_fn,
stop_fn=stop_fn, stop_fn=stop_fn,
verbose=False,
) )

View File

@@ -1,3 +1,4 @@
import logging
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict, deque from collections import defaultdict, deque
@@ -19,6 +20,8 @@ from tianshou.utils import (
tqdm_config, tqdm_config,
) )
log = logging.getLogger(__name__)
class BaseTrainer(ABC): class BaseTrainer(ABC):
"""An iterator base class for trainers. """An iterator base class for trainers.
@@ -79,7 +82,9 @@ class BaseTrainer(ABC):
e.g., the reward of agent 1 or the average reward over all agents. e.g., the reward of agent 1 or the average reward over all agents.
:param logger: A logger that logs statistics during :param logger: A logger that logs statistics during
training/testing/updating. Default to a logger that doesn't log anything. training/testing/updating. Default to a logger that doesn't log anything.
:param verbose: whether to print the information. Default to True. :param verbose: whether to print status information to stdout.
If set to False, status information will still be logged (provided that
logging is enabled via the `logging` module).
:param show_progress: whether to display a progress bar when training. :param show_progress: whether to display a progress bar when training.
Default to True. Default to True.
:param test_in_train: whether to test in the training phase. :param test_in_train: whether to test in the training phase.
@@ -376,13 +381,14 @@ class BaseTrainer(ABC):
self.best_reward_std = rew_std self.best_reward_std = rew_std
if self.save_best_fn: if self.save_best_fn:
self.save_best_fn(self.policy) self.save_best_fn(self.policy)
log_msg = (
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f} in #{self.best_epoch}"
)
log.info(log_msg)
if self.verbose: if self.verbose:
print( print(log_msg, flush=True)
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f} in #{self.best_epoch}",
flush=True,
)
if not self.is_run: if not self.is_run:
test_stat = { test_stat = {
"test_reward": rew, "test_reward": rew,