Use logging to report trainer epoch status

This commit is contained in:
Dominik Jain 2023-10-13 16:16:12 +02:00
parent 3bba192633
commit fc695a5394
2 changed files with 15 additions and 7 deletions

View File

@ -183,6 +183,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
verbose=False,
)
@ -224,6 +225,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
train_fn=train_fn,
test_fn=test_fn,
stop_fn=stop_fn,
verbose=False,
)

View File

@ -1,3 +1,4 @@
import logging
import time
from abc import ABC, abstractmethod
from collections import defaultdict, deque
@ -19,6 +20,8 @@ from tianshou.utils import (
tqdm_config,
)
log = logging.getLogger(__name__)
class BaseTrainer(ABC):
"""An iterator base class for trainers.
@ -79,7 +82,9 @@ class BaseTrainer(ABC):
e.g., the reward of agent 1 or the average reward over all agents.
:param logger: A logger that logs statistics during
training/testing/updating. Default to a logger that doesn't log anything.
:param verbose: whether to print the information. Default to True.
:param verbose: whether to print status information to stdout.
If set to False, status information will still be logged (provided that
logging is enabled via the `logging` module).
:param show_progress: whether to display a progress bar when training.
Default to True.
:param test_in_train: whether to test in the training phase.
@ -376,13 +381,14 @@ class BaseTrainer(ABC):
self.best_reward_std = rew_std
if self.save_best_fn:
self.save_best_fn(self.policy)
log_msg = (
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f} in #{self.best_epoch}"
)
log.info(log_msg)
if self.verbose:
print(
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
f" best_reward: {self.best_reward:.6f} ± "
f"{self.best_reward_std:.6f} in #{self.best_epoch}",
flush=True,
)
print(log_msg, flush=True)
if not self.is_run:
test_stat = {
"test_reward": rew,