Use logging to report trainer epoch status
This commit is contained in:
parent
3bba192633
commit
fc695a5394
@ -183,6 +183,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
|
||||
train_fn=train_fn,
|
||||
test_fn=test_fn,
|
||||
stop_fn=stop_fn,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
|
||||
@ -224,6 +225,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
|
||||
train_fn=train_fn,
|
||||
test_fn=test_fn,
|
||||
stop_fn=stop_fn,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict, deque
|
||||
@ -19,6 +20,8 @@ from tianshou.utils import (
|
||||
tqdm_config,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseTrainer(ABC):
|
||||
"""An iterator base class for trainers.
|
||||
@ -79,7 +82,9 @@ class BaseTrainer(ABC):
|
||||
e.g., the reward of agent 1 or the average reward over all agents.
|
||||
:param logger: A logger that logs statistics during
|
||||
training/testing/updating. Default to a logger that doesn't log anything.
|
||||
:param verbose: whether to print the information. Default to True.
|
||||
:param verbose: whether to print status information to stdout.
|
||||
If set to False, status information will still be logged (provided that
|
||||
logging is enabled via the `logging` module).
|
||||
:param show_progress: whether to display a progress bar when training.
|
||||
Default to True.
|
||||
:param test_in_train: whether to test in the training phase.
|
||||
@ -376,13 +381,14 @@ class BaseTrainer(ABC):
|
||||
self.best_reward_std = rew_std
|
||||
if self.save_best_fn:
|
||||
self.save_best_fn(self.policy)
|
||||
log_msg = (
|
||||
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
|
||||
f" best_reward: {self.best_reward:.6f} ± "
|
||||
f"{self.best_reward_std:.6f} in #{self.best_epoch}"
|
||||
)
|
||||
log.info(log_msg)
|
||||
if self.verbose:
|
||||
print(
|
||||
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
|
||||
f" best_reward: {self.best_reward:.6f} ± "
|
||||
f"{self.best_reward_std:.6f} in #{self.best_epoch}",
|
||||
flush=True,
|
||||
)
|
||||
print(log_msg, flush=True)
|
||||
if not self.is_run:
|
||||
test_stat = {
|
||||
"test_reward": rew,
|
||||
|
Loading…
x
Reference in New Issue
Block a user