Use logging to report trainer epoch status
This commit is contained in:
parent
3bba192633
commit
fc695a5394
@ -183,6 +183,7 @@ class OnpolicyAgentFactory(AgentFactory, ABC):
|
|||||||
train_fn=train_fn,
|
train_fn=train_fn,
|
||||||
test_fn=test_fn,
|
test_fn=test_fn,
|
||||||
stop_fn=stop_fn,
|
stop_fn=stop_fn,
|
||||||
|
verbose=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -224,6 +225,7 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
|
|||||||
train_fn=train_fn,
|
train_fn=train_fn,
|
||||||
test_fn=test_fn,
|
test_fn=test_fn,
|
||||||
stop_fn=stop_fn,
|
stop_fn=stop_fn,
|
||||||
|
verbose=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
@ -19,6 +20,8 @@ from tianshou.utils import (
|
|||||||
tqdm_config,
|
tqdm_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BaseTrainer(ABC):
|
class BaseTrainer(ABC):
|
||||||
"""An iterator base class for trainers.
|
"""An iterator base class for trainers.
|
||||||
@ -79,7 +82,9 @@ class BaseTrainer(ABC):
|
|||||||
e.g., the reward of agent 1 or the average reward over all agents.
|
e.g., the reward of agent 1 or the average reward over all agents.
|
||||||
:param logger: A logger that logs statistics during
|
:param logger: A logger that logs statistics during
|
||||||
training/testing/updating. Default to a logger that doesn't log anything.
|
training/testing/updating. Default to a logger that doesn't log anything.
|
||||||
:param verbose: whether to print the information. Default to True.
|
:param verbose: whether to print status information to stdout.
|
||||||
|
If set to False, status information will still be logged (provided that
|
||||||
|
logging is enabled via the `logging` module).
|
||||||
:param show_progress: whether to display a progress bar when training.
|
:param show_progress: whether to display a progress bar when training.
|
||||||
Default to True.
|
Default to True.
|
||||||
:param test_in_train: whether to test in the training phase.
|
:param test_in_train: whether to test in the training phase.
|
||||||
@ -376,13 +381,14 @@ class BaseTrainer(ABC):
|
|||||||
self.best_reward_std = rew_std
|
self.best_reward_std = rew_std
|
||||||
if self.save_best_fn:
|
if self.save_best_fn:
|
||||||
self.save_best_fn(self.policy)
|
self.save_best_fn(self.policy)
|
||||||
|
log_msg = (
|
||||||
|
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
|
||||||
|
f" best_reward: {self.best_reward:.6f} ± "
|
||||||
|
f"{self.best_reward_std:.6f} in #{self.best_epoch}"
|
||||||
|
)
|
||||||
|
log.info(log_msg)
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(
|
print(log_msg, flush=True)
|
||||||
f"Epoch #{self.epoch}: test_reward: {rew:.6f} ± {rew_std:.6f},"
|
|
||||||
f" best_reward: {self.best_reward:.6f} ± "
|
|
||||||
f"{self.best_reward_std:.6f} in #{self.best_epoch}",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
if not self.is_run:
|
if not self.is_run:
|
||||||
test_stat = {
|
test_stat = {
|
||||||
"test_reward": rew,
|
"test_reward": rew,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user