Add Weights and Biases Logger (#427)

- rename BasicLogger to TensorboardLogger
- refactor logger code
- add WandbLogger
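A minimal sketch of the renamed logger and the new W&B logger side by side (the log directory and interval values are illustrative, not taken from this commit):

```python
# Hedged sketch: TensorboardLogger is the old BasicLogger under a new name;
# WandBLogger is new in this commit. Paths/intervals below are placeholders.
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger, WandBLogger  # BasicLogger still importable, now deprecated

writer = SummaryWriter('log/dqn')
tb_logger = TensorboardLogger(writer)              # was: ts.utils.BasicLogger(writer)
wandb_logger = WandBLogger(update_interval=1000)   # logs to the active wandb run via wandb.log
```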

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
Andriy Drozdyuk 2021-08-30 10:35:02 -04:00 committed by GitHub
parent e4f4f0e144
commit 8a5e2190f7
57 changed files with 371 additions and 309 deletions

View File

@ -192,7 +192,7 @@ buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
logger = ts.utils.BasicLogger(writer)
logger = ts.utils.TensorboardLogger(writer)
```
Make environments:

View File

@ -40,8 +40,8 @@ This is related to `Issue 349 <https://github.com/thu-ml/tianshou/issues/349>`_.
To resume the training process from an existing checkpoint, you need to do the following things during training:
1. Make sure you write ``save_checkpoint_fn``, which saves everything needed to resume (i.e., policy, optim, buffer), and pass it to the trainer;
2. Use ``BasicLogger`` which contains a tensorboard;
3. To adjust the save frequency, specify ``save_interval`` when initializing BasicLogger.
2. Use ``TensorboardLogger``;
3. To adjust the save frequency, specify ``save_interval`` when initializing TensorboardLogger.
And to successfully resume from a checkpoint:
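A minimal sketch of that checkpoint setup, wrapped in a function so the assumed objects are explicit (``policy``, ``optim``, the collectors, and the trainer keyword ``resume_from_log`` follow the tianshou 0.4.x API and are not part of this diff):

```python
# Hedged sketch: save checkpoints through TensorboardLogger and let the trainer
# resume from them. All trainer hyperparameters below are placeholders.
import os
import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger

def train_with_checkpointing(policy, optim, train_collector, test_collector,
                             log_path='log/dqn'):
    logger = TensorboardLogger(SummaryWriter(log_path), save_interval=1)

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # persist everything needed to resume: policy and optimizer (buffer optional)
        torch.save({'model': policy.state_dict(), 'optim': optim.state_dict()},
                   os.path.join(log_path, 'checkpoint.pth'))

    return offpolicy_trainer(
        policy, train_collector, test_collector, max_epoch=10,
        step_per_epoch=10000, step_per_collect=10, episode_per_test=100,
        batch_size=64, logger=logger, save_checkpoint_fn=save_checkpoint_fn,
        resume_from_log=True)  # restores epoch/env_step/gradient_step via the logger
```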

View File

@ -148,9 +148,9 @@ The trainer supports `TensorBoard <https://www.tensorflow.org/tensorboard>`_ for
::
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
writer = SummaryWriter('log/dqn')
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
Pass the logger into the trainer, and the training result will be recorded into TensorBoard.
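For example, a hedged sketch of that hand-off (``policy`` and the collectors are assumed to exist; all other arguments are placeholders):

```python
# Hedged sketch: the trainer accepts the logger through its `logger` keyword.
from tianshou.trainer import offpolicy_trainer

result = offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=10000, step_per_collect=10,
    episode_per_test=100, batch_size=64, logger=logger)
```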

View File

@ -176,7 +176,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul
import numpy as np
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
@ -323,7 +323,7 @@ With the above preparation, we are close to the first learned agent. The followi
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
# ======== callback functions used during training =========

View File

@ -7,7 +7,7 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.utils.net.discrete import Actor
@ -116,7 +116,7 @@ def test_discrete_bcq(args=get_args()):
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=args.log_interval)
logger = TensorboardLogger(writer, update_interval=args.log_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -6,7 +6,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import C51Policy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@ -101,7 +101,7 @@ def test_c51(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCQLPolicy
@ -108,7 +108,7 @@ def test_discrete_cql(args=get_args()):
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=args.log_interval)
logger = TensorboardLogger(writer, update_interval=args.log_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.utils.net.discrete import Actor
@ -117,7 +117,7 @@ def test_discrete_crr(args=get_args()):
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=args.log_interval)
logger = TensorboardLogger(writer, update_interval=args.log_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -6,7 +6,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@ -96,7 +96,7 @@ def test_dqn(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -6,7 +6,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import FQFPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@ -112,7 +112,7 @@ def test_fqf(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'fqf')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -6,7 +6,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import IQNPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@ -109,7 +109,7 @@ def test_iqn(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'iqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -5,7 +5,7 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.policy import QRDQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
@ -99,7 +99,7 @@ def test_qrdqn(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import RainbowPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
@ -121,7 +121,7 @@ def test_rainbow(args=get_args()):
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -82,7 +82,7 @@ def test_dqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import SACPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
@ -134,7 +134,7 @@ def test_sac_bipedal(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@ -84,7 +84,7 @@ def test_dqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import SACPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.exploration import OUNoise
from tianshou.utils.net.common import Net
@ -104,7 +104,7 @@ def test_sac(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -13,7 +13,7 @@ from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
from tianshou.policy import A2CPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -141,7 +141,7 @@ def test_a2c(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'a2c', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=100, train_interval=100)
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -10,7 +10,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DDPGPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.exploration import GaussianNoise
@ -110,7 +110,7 @@ def test_ddpg(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'ddpg', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -13,7 +13,7 @@ from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
from tianshou.policy import NPGPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -142,7 +142,7 @@ def test_npg(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'npg', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=100, train_interval=100)
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -13,7 +13,7 @@ from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
from tianshou.policy import PPOPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -149,7 +149,7 @@ def test_ppo(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'ppo', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=100, train_interval=100)
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -13,7 +13,7 @@ from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
from tianshou.policy import PGPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -131,7 +131,7 @@ def test_reinforce(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'reinforce', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=10, train_interval=100)
logger = TensorboardLogger(writer, update_interval=10, train_interval=100)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -10,7 +10,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import SACPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -122,7 +122,7 @@ def test_sac(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'sac', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -10,7 +10,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import TD3Policy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.exploration import GaussianNoise
@ -123,7 +123,7 @@ def test_td3(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'td3', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -13,7 +13,7 @@ from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
from tianshou.policy import TRPOPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -147,7 +147,7 @@ def test_trpo(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'trpo', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=100, train_interval=100)
logger = TensorboardLogger(writer, update_interval=100, train_interval=100)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -6,7 +6,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import C51Policy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@ -101,7 +101,7 @@ def test_c51(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -63,6 +63,7 @@ setup(
"pytest",
"pytest-cov",
"ray>=1.0.0",
"wandb>=0.12.0",
"networkx",
"mypy",
"pydocstyle",

View File

@ -7,7 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DDPGPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -90,7 +90,7 @@ def test_ddpg(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'ddpg')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
from tianshou.policy import NPGPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -105,7 +105,7 @@ def test_npg(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'npg')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
from tianshou.policy import PPOPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -116,7 +116,7 @@ def test_ppo(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer, save_interval=args.save_interval)
logger = TensorboardLogger(writer, save_interval=args.save_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -6,7 +6,7 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -109,7 +109,7 @@ def test_sac_with_il(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import TD3Policy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -104,7 +104,7 @@ def test_td3(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'td3')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
from tianshou.policy import TRPOPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -109,7 +109,7 @@ def test_trpo(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'trpo')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -6,7 +6,7 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.data import Collector, VectorReplayBuffer
@ -91,7 +91,7 @@ def test_a2c_with_il(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'a2c')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -8,7 +8,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import C51Policy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -93,7 +93,7 @@ def test_c51(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer, save_interval=args.save_interval)
logger = TensorboardLogger(writer, save_interval=args.save_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -8,7 +8,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -92,7 +92,7 @@ def test_dqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.common import Recurrent
@ -78,7 +78,7 @@ def test_drqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'drqn')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import FQFPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -100,7 +100,7 @@ def test_fqf(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'fqf')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -8,7 +8,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offline_trainer
@ -87,7 +87,7 @@ def test_discrete_bcq(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer, save_interval=args.save_interval)
logger = TensorboardLogger(writer, save_interval=args.save_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -8,7 +8,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offline_trainer
@ -80,7 +80,7 @@ def test_discrete_crr(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import IQNPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -96,7 +96,7 @@ def test_iqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'iqn')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PGPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -78,7 +78,7 @@ def test_pg(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'pg')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PPOPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -102,7 +102,7 @@ def test_ppo(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,7 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.policy import QRDQNPolicy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
@ -94,7 +94,7 @@ def test_qrdqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -8,7 +8,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offline_trainer
@ -79,7 +79,7 @@ def test_discrete_cql(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'discrete_cql')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -8,7 +8,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import RainbowPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import NoisyLinear
@ -102,7 +102,7 @@ def test_rainbow(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'rainbow')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer, save_interval=args.save_interval)
logger = TensorboardLogger(writer, save_interval=args.save_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -6,7 +6,7 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.policy import DiscreteSACPolicy
@ -97,7 +97,7 @@ def test_discrete_sac(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'discrete_sac')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

View File

@ -7,7 +7,6 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PSRLPolicy
# from tianshou.utils import BasicLogger
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
@ -71,7 +70,6 @@ def test_psrl(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'psrl')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
# logger = BasicLogger(writer)
def stop_fn(mean_rewards):
if env.spec.reward_threshold:

View File

@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.env import DummyVectorEnv
from tianshou.data import Collector
from tianshou.policy import RandomPolicy
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tic_tac_toe_env import TicTacToeEnv
from tic_tac_toe import get_parser, get_agents, train_agent, watch
@ -33,7 +33,7 @@ def gomoku(args=get_args()):
# log
log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
writer = SummaryWriter(log_path)
args.logger = BasicLogger(writer)
args.logger = TensorboardLogger(writer)
opponent_pool = [agent_opponent]

View File

@ -6,7 +6,7 @@ from copy import deepcopy
from typing import Optional, Tuple
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.utils import TensorboardLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -135,7 +135,7 @@ def train_agent(
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
logger = TensorboardLogger(writer)
def save_fn(policy):
if hasattr(args, 'model_save_path'):

View File

@ -22,7 +22,7 @@ class PPOPolicy(A2CPolicy):
:param float dual_clip: a parameter c mentioned in arXiv:1912.09729 Equ. 5,
where c > 1 is a constant indicating the lower bound.
Default to 5.0 (set None if you do not want to use it).
:param bool value_clip: a parameter mentioned in arXiv:1811.02553 Sec. 4.1.
:param bool value_clip: a parameter mentioned in arXiv:1811.02553v3 Sec. 4.1.
Default to True.
:param bool advantage_normalization: whether to do per mini-batch advantage
normalization. Default to True.

View File

@ -1,12 +1,17 @@
from tianshou.utils.config import tqdm_config
from tianshou.utils.statistics import MovAvg, RunningMeanStd
from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger
from tianshou.utils.logger.base import BaseLogger, LazyLogger
from tianshou.utils.logger.tensorboard import TensorboardLogger, BasicLogger
from tianshou.utils.logger.wandb import WandBLogger
__all__ = [
"MovAvg",
"RunningMeanStd",
"tqdm_config",
"BaseLogger",
"TensorboardLogger",
"BasicLogger",
"LazyLogger",
"WandBLogger"
]

View File

@ -1,210 +0,0 @@
import numpy as np
from numbers import Number
from abc import ABC, abstractmethod
from torch.utils.tensorboard import SummaryWriter
from typing import Any, Tuple, Union, Callable, Optional
from tensorboard.backend.event_processing import event_accumulator
WRITE_TYPE = Union[int, Number, np.number, np.ndarray]
class BaseLogger(ABC):
"""The base class for any logger which is compatible with trainer."""
def __init__(self, writer: Any) -> None:
super().__init__()
self.writer = writer
@abstractmethod
def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
"""Specify how the writer is used to log data.
:param str key: namespace which the input data tuple belongs to.
:param int x: stands for the ordinate of the input data tuple.
:param y: stands for the abscissa of the input data tuple.
"""
pass
def log_train_data(self, collect_result: dict, step: int) -> None:
"""Use writer to log statistics generated during training.
:param collect_result: a dict containing information of data collected in
training stage, i.e., returns of collector.collect().
:param int step: stands for the timestep the collect_result being logged.
"""
pass
def log_update_data(self, update_result: dict, step: int) -> None:
"""Use writer to log statistics generated during updating.
:param update_result: a dict containing information of data collected in
updating stage, i.e., returns of policy.update().
:param int step: stands for the timestep the collect_result being logged.
"""
pass
def log_test_data(self, collect_result: dict, step: int) -> None:
"""Use writer to log statistics generated during evaluating.
:param collect_result: a dict containing information of data collected in
evaluating stage, i.e., returns of collector.collect().
:param int step: stands for the timestep the collect_result being logged.
"""
pass
def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
) -> None:
"""Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
:param int epoch: the epoch in trainer.
:param int env_step: the env_step in trainer.
:param int gradient_step: the gradient_step in trainer.
:param function save_checkpoint_fn: a hook defined by user, see trainer
documentation for detail.
"""
pass
def restore_data(self) -> Tuple[int, int, int]:
"""Return the metadata from existing log.
If it finds nothing or an error occurs during the recover process, it will
return the default parameters.
:return: epoch, env_step, gradient_step.
"""
pass
class BasicLogger(BaseLogger):
"""A loggger that relies on tensorboard SummaryWriter by default to visualize \
and log statistics.
You can also rewrite write() func to use your own writer.
:param SummaryWriter writer: the writer to log data.
:param int train_interval: the log interval in log_train_data(). Default to 1000.
:param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data(). Default to 1000.
:param int save_interval: the save interval in save_data(). Default to 1 (save at
the end of each epoch).
"""
def __init__(
self,
writer: SummaryWriter,
train_interval: int = 1000,
test_interval: int = 1,
update_interval: int = 1000,
save_interval: int = 1,
) -> None:
super().__init__(writer)
self.train_interval = train_interval
self.test_interval = test_interval
self.update_interval = update_interval
self.save_interval = save_interval
self.last_log_train_step = -1
self.last_log_test_step = -1
self.last_log_update_step = -1
self.last_save_step = -1
def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
self.writer.add_scalar(key, y, global_step=x)
def log_train_data(self, collect_result: dict, step: int) -> None:
"""Use writer to log statistics generated during training.
:param collect_result: a dict containing information of data collected in
training stage, i.e., returns of collector.collect().
:param int step: stands for the timestep the collect_result being logged.
.. note::
``collect_result`` will be modified in-place with "rew" and "len" keys.
"""
if collect_result["n/ep"] > 0:
collect_result["rew"] = collect_result["rews"].mean()
collect_result["len"] = collect_result["lens"].mean()
if step - self.last_log_train_step >= self.train_interval:
self.write("train/n/ep", step, collect_result["n/ep"])
self.write("train/rew", step, collect_result["rew"])
self.write("train/len", step, collect_result["len"])
self.last_log_train_step = step
def log_test_data(self, collect_result: dict, step: int) -> None:
"""Use writer to log statistics generated during evaluating.
:param collect_result: a dict containing information of data collected in
evaluating stage, i.e., returns of collector.collect().
:param int step: stands for the timestep the collect_result being logged.
.. note::
``collect_result`` will be modified in-place with "rew", "rew_std", "len",
and "len_std" keys.
"""
assert collect_result["n/ep"] > 0
rews, lens = collect_result["rews"], collect_result["lens"]
rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
if step - self.last_log_test_step >= self.test_interval:
self.write("test/rew", step, rew)
self.write("test/len", step, len_)
self.write("test/rew_std", step, rew_std)
self.write("test/len_std", step, len_std)
self.last_log_test_step = step
def log_update_data(self, update_result: dict, step: int) -> None:
if step - self.last_log_update_step >= self.update_interval:
for k, v in update_result.items():
self.write(k, step, v)
self.last_log_update_step = step
def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
) -> None:
if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
self.last_save_step = epoch
save_checkpoint_fn(epoch, env_step, gradient_step)
self.write("save/epoch", epoch, epoch)
self.write("save/env_step", env_step, env_step)
self.write("save/gradient_step", gradient_step, gradient_step)
def restore_data(self) -> Tuple[int, int, int]:
ea = event_accumulator.EventAccumulator(self.writer.log_dir)
ea.Reload()
try: # epoch / gradient_step
epoch = ea.scalars.Items("save/epoch")[-1].step
self.last_save_step = self.last_log_test_step = epoch
gradient_step = ea.scalars.Items("save/gradient_step")[-1].step
self.last_log_update_step = gradient_step
except KeyError:
epoch, gradient_step = 0, 0
try: # offline trainer doesn't have env_step
env_step = ea.scalars.Items("save/env_step")[-1].step
self.last_log_train_step = env_step
except KeyError:
env_step = 0
return epoch, env_step, gradient_step
class LazyLogger(BasicLogger):
"""A loggger that does nothing. Used as the placeholder in trainer."""
def __init__(self) -> None:
super().__init__(None) # type: ignore
def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
"""The LazyLogger writes nothing."""
pass

View File

View File

@ -0,0 +1,141 @@
import numpy as np
from numbers import Number
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Union, Callable, Optional
LOG_DATA_TYPE = Dict[str, Union[int, Number, np.number, np.ndarray]]
class BaseLogger(ABC):
"""The base class for any logger which is compatible with trainer.
Try to overwrite write() method to use your own writer.
:param int train_interval: the log interval in log_train_data(). Default to 1000.
:param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data(). Default to 1000.
"""
def __init__(
self,
train_interval: int = 1000,
test_interval: int = 1,
update_interval: int = 1000,
) -> None:
super().__init__()
self.train_interval = train_interval
self.test_interval = test_interval
self.update_interval = update_interval
self.last_log_train_step = -1
self.last_log_test_step = -1
self.last_log_update_step = -1
@abstractmethod
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
"""Specify how the writer is used to log data.
:param str step_type: namespace which the data dict belongs to.
:param int step: stands for the ordinate of the data dict.
:param dict data: the data to write with format ``{key: value}``.
"""
pass
def log_train_data(self, collect_result: dict, step: int) -> None:
"""Use writer to log statistics generated during training.
:param collect_result: a dict containing information of data collected in
training stage, i.e., returns of collector.collect().
:param int step: stands for the timestep at which the collect_result is logged.
.. note::
``collect_result`` will be modified in-place with "rew" and "len" keys.
"""
if collect_result["n/ep"] > 0:
collect_result["rew"] = collect_result["rews"].mean()
collect_result["len"] = collect_result["lens"].mean()
if step - self.last_log_train_step >= self.train_interval:
log_data = {
"train/episode": collect_result["n/ep"],
"train/reward": collect_result["rew"],
"train/length": collect_result["len"],
}
self.write("train/env_step", step, log_data)
self.last_log_train_step = step
def log_test_data(self, collect_result: dict, step: int) -> None:
"""Use writer to log statistics generated during evaluating.
:param collect_result: a dict containing information of data collected in
evaluating stage, i.e., returns of collector.collect().
:param int step: stands for the timestep at which the collect_result is logged.
.. note::
``collect_result`` will be modified in-place with "rew", "rew_std", "len",
and "len_std" keys.
"""
assert collect_result["n/ep"] > 0
rews, lens = collect_result["rews"], collect_result["lens"]
rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
if step - self.last_log_test_step >= self.test_interval:
log_data = {
"test/env_step": step,
"test/reward": rew,
"test/length": len_,
"test/reward_std": rew_std,
"test/length_std": len_std,
}
self.write("test/env_step", step, log_data)
self.last_log_test_step = step
def log_update_data(self, update_result: dict, step: int) -> None:
"""Use writer to log statistics generated during updating.
:param update_result: a dict containing information of data collected in
updating stage, i.e., returns of policy.update().
:param int step: stands for the timestep at which the update_result is logged.
"""
if step - self.last_log_update_step >= self.update_interval:
log_data = {f"update/{k}": v for k, v in update_result.items()}
self.write("update/gradient_step", step, log_data)
self.last_log_update_step = step
def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
) -> None:
"""Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
:param int epoch: the epoch in trainer.
:param int env_step: the env_step in trainer.
:param int gradient_step: the gradient_step in trainer.
:param function save_checkpoint_fn: a hook defined by user, see trainer
documentation for detail.
"""
pass
def restore_data(self) -> Tuple[int, int, int]:
"""Return the metadata from existing log.
If it finds nothing or an error occurs during the recover process, it will
return the default parameters.
:return: epoch, env_step, gradient_step.
"""
pass
class LazyLogger(BaseLogger):
"""A logger that does nothing. Used as the placeholder in trainer."""
def __init__(self) -> None:
super().__init__()
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
"""The LazyLogger writes nothing."""
pass
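For reference, a custom logger against this new interface can be as small as the following in-memory sketch (not part of the commit):

```python
# Hedged sketch: a BaseLogger subclass that keeps scalars in memory instead of
# writing to TensorBoard or Weights and Biases.
from typing import Dict, List, Tuple

from tianshou.utils.logger.base import BaseLogger, LOG_DATA_TYPE

class InMemoryLogger(BaseLogger):
    """Collect every logged scalar as (step, value) pairs keyed by metric name."""

    def __init__(self) -> None:
        super().__init__(train_interval=1, test_interval=1, update_interval=1)
        self.history: Dict[str, List[Tuple[int, float]]] = {}

    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        for key, value in data.items():
            self.history.setdefault(key, []).append((step, float(value)))
```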

View File

@ -0,0 +1,83 @@
import warnings
from typing import Any, Tuple, Callable, Optional
from torch.utils.tensorboard import SummaryWriter
from tensorboard.backend.event_processing import event_accumulator
from tianshou.utils.logger.base import BaseLogger, LOG_DATA_TYPE
class TensorboardLogger(BaseLogger):
"""A logger that relies on tensorboard SummaryWriter by default to visualize \
and log statistics.
:param SummaryWriter writer: the writer to log data.
:param int train_interval: the log interval in log_train_data(). Default to 1000.
:param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data(). Default to 1000.
:param int save_interval: the save interval in save_data(). Default to 1 (save at
the end of each epoch).
"""
def __init__(
self,
writer: SummaryWriter,
train_interval: int = 1000,
test_interval: int = 1,
update_interval: int = 1000,
save_interval: int = 1,
) -> None:
super().__init__(train_interval, test_interval, update_interval)
self.save_interval = save_interval
self.last_save_step = -1
self.writer = writer
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
for k, v in data.items():
self.writer.add_scalar(k, v, global_step=step)
def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
) -> None:
if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
self.last_save_step = epoch
save_checkpoint_fn(epoch, env_step, gradient_step)
self.write("save/epoch", epoch, {"save/epoch": epoch})
self.write("save/env_step", env_step, {"save/env_step": env_step})
self.write("save/gradient_step", gradient_step,
{"save/gradient_step": gradient_step})
def restore_data(self) -> Tuple[int, int, int]:
ea = event_accumulator.EventAccumulator(self.writer.log_dir)
ea.Reload()
try: # epoch / gradient_step
epoch = ea.scalars.Items("save/epoch")[-1].step
self.last_save_step = self.last_log_test_step = epoch
gradient_step = ea.scalars.Items("save/gradient_step")[-1].step
self.last_log_update_step = gradient_step
except KeyError:
epoch, gradient_step = 0, 0
try: # offline trainer doesn't have env_step
env_step = ea.scalars.Items("save/env_step")[-1].step
self.last_log_train_step = env_step
except KeyError:
env_step = 0
return epoch, env_step, gradient_step
class BasicLogger(TensorboardLogger):
"""BasicLogger has changed its name to TensorboardLogger in #427.
This class is for compatibility.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
"Deprecated soon: BasicLogger has renamed to TensorboardLogger in #427.")
super().__init__(*args, **kwargs)
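A short, runnable check of the compatibility shim (the log directory is illustrative): constructing ``BasicLogger`` still yields a ``TensorboardLogger``, just with a deprecation warning.

```python
# Hedged sketch: the old name keeps working through the shim above.
import warnings

from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger, TensorboardLogger

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = BasicLogger(SummaryWriter('log/compat'))  # warns, then builds a TensorboardLogger
assert isinstance(legacy, TensorboardLogger)
assert any("TensorboardLogger" in str(w.message) for w in caught)
```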

View File

@ -0,0 +1,44 @@
from tianshou.utils import BaseLogger
from tianshou.utils.logger.base import LOG_DATA_TYPE
try:
import wandb
except ImportError:
pass
class WandBLogger(BaseLogger):
"""Weights and Biases logger that sends data to Weights and Biases.
Creates three panels with plots: train, test, and update.
Make sure to select the correct x-axis (step metric) for each panel in Weights and Biases:
- ``train/env_step`` for train plots
- ``test/env_step`` for test plots
- ``update/gradient_step`` for update plots
Example of usage:
::
with wandb.init(project="My Project"):
logger = WandBLogger()
result = onpolicy_trainer(policy, train_collector, test_collector,
logger=logger)
:param int train_interval: the log interval in log_train_data(). Default to 1000.
:param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data().
Default to 1000.
"""
def __init__(
self,
train_interval: int = 1000,
test_interval: int = 1,
update_interval: int = 1000,
) -> None:
super().__init__(train_interval, test_interval, update_interval)
def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
data[step_type] = step
wandb.log(data)
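A runnable sketch of what ``write()`` sends (the project name is illustrative; ``mode="offline"`` keeps the run local so no W&B account is needed):

```python
# Hedged sketch: WandBLogger.write() adds the step under its step_type key and
# forwards the whole dict to wandb.log.
import wandb

from tianshou.utils import WandBLogger

wandb.init(project="tianshou-demo", mode="offline")  # offline run, stored locally
logger = WandBLogger(train_interval=1)
logger.write("train/env_step", 100, {"train/reward": 1.5, "train/length": 20.0})
wandb.finish()
```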