- rename BasicLogger to TensorboardLogger - refactor logger code - add WandbLogger Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
45 lines
1.3 KiB
Python
45 lines
1.3 KiB
Python
from tianshou.utils import BaseLogger
|
|
from tianshou.utils.logger.base import LOG_DATA_TYPE
|
|
|
|
try:
|
|
import wandb
|
|
except ImportError:
|
|
pass
|
|
|
|
|
|
class WandBLogger(BaseLogger):
|
|
"""Weights and Biases logger that sends data to Weights and Biases.
|
|
|
|
Creates three panels with plots: train, test, and update.
|
|
Make sure to select the correct access for each panel in weights and biases:
|
|
|
|
- ``train/env_step`` for train plots
|
|
- ``test/env_step`` for test plots
|
|
- ``update/gradient_step`` for update plots
|
|
|
|
Example of usage:
|
|
::
|
|
|
|
with wandb.init(project="My Project"):
|
|
logger = WandBLogger()
|
|
result = onpolicy_trainer(policy, train_collector, test_collector,
|
|
logger=logger)
|
|
|
|
:param int train_interval: the log interval in log_train_data(). Default to 1000.
|
|
:param int test_interval: the log interval in log_test_data(). Default to 1.
|
|
:param int update_interval: the log interval in log_update_data().
|
|
Default to 1000.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
train_interval: int = 1000,
|
|
test_interval: int = 1,
|
|
update_interval: int = 1000,
|
|
) -> None:
|
|
super().__init__(train_interval, test_interval, update_interval)
|
|
|
|
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
|
|
data[step_type] = step
|
|
wandb.log(data)
|