add logger (#295)

This PR focuses on refactoring the logging method to fix the NaN-reward bug and the log-interval bug. After these two PRs, the fundamental changes to tianshou/data should be finished, and we can finally concentrate on building benchmarks for Tianshou.

Things changed:

1. the trainer now accepts a logger (BasicLogger or LazyLogger) instead of a writer;
2. utils.SummaryWriter is removed.
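The new logger is what the diffs below call via `logger.write('train/eps', env_step, eps)`: a thin wrapper around a `SummaryWriter` that also owns the log interval, which is how the log-interval handling moves out of the trainer. A minimal sketch of the idea, with a hypothetical `DemoLogger` standing in for tianshou's `BasicLogger` (names and defaults here are illustrative, not tianshou's code):

```python
# Illustrative stand-in for the logger abstraction this PR introduces: the
# trainer calls write(key, step, value) and the logger decides when to
# actually emit a scalar, so interval bookkeeping lives in one place.
class DemoLogger:
    def __init__(self, writer, update_interval=1000):
        self.writer = writer                   # anything with add_scalar(key, value, step)
        self.update_interval = update_interval
        self.last_log_step = -update_interval  # so the first call always logs

    def write(self, key, step, value):
        # Skip writes that arrive sooner than `update_interval` steps
        # after the last emitted one.
        if step - self.last_log_step >= self.update_interval:
            self.writer.add_scalar(key, value, step)
            self.last_log_step = step
```

A trainer holding such a logger can call `logger.write('train/eps', env_step, eps)` on every step without flooding TensorBoard.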
ChenDRAG 2021-02-24 14:48:42 +08:00
parent e99e1b0fdd
commit 9b61bc620c
45 changed files with 406 additions and 249 deletions

View File

@@ -197,6 +197,7 @@ buffer_size = 20000
 eps_train, eps_test = 0.1, 0.05
 step_per_epoch, step_per_collect = 10000, 10
 writer = SummaryWriter('log/dqn')  # tensorboard is also supported!
+logger = ts.utils.BasicLogger(writer)
 ```
 Make environments:
@@ -237,7 +238,7 @@ result = ts.trainer.offpolicy_trainer(
     train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
     test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
     stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
-    writer=writer)
+    logger=logger)
 print(f'Finished training! Use {result["duration"]}')
 ```

View File

@@ -7,3 +7,4 @@ We always welcome contributions to help make Tianshou better. Below are an incom
 * Minghao Zhang (`Mehooz <https://github.com/Mehooz>`_)
 * Alexis Duburcq (`duburcqa <https://github.com/duburcqa>`_)
 * Kaichao You (`youkaichao <https://github.com/youkaichao>`_)
+* Huayu Chen (`ChenDRAG <https://github.com/ChenDRAG>`_)

View File

@@ -130,7 +130,7 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t
         train_fn=lambda epoch, env_step: policy.set_eps(0.1),
         test_fn=lambda epoch, env_step: policy.set_eps(0.05),
         stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
-        writer=None)
+        logger=None)
     print(f'Finished training! Use {result["duration"]}')
 The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`):
@@ -143,15 +143,17 @@ The meaning of each parameter is as follows (full description can be found at :f
 * ``train_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
 * ``test_fn``: A function receives the current number of epoch and step index, and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
 * ``stop_fn``: A function receives the average undiscounted returns of the testing result, return a boolean which indicates whether reaching the goal.
-* ``writer``: See below.
+* ``logger``: See below.
 The trainer supports `TensorBoard <https://www.tensorflow.org/tensorboard>`_ for logging. It can be used as:
 ::
     from torch.utils.tensorboard import SummaryWriter
+    from tianshou.utils import BasicLogger
     writer = SummaryWriter('log/dqn')
+    logger = BasicLogger(writer)
-Pass the writer into the trainer, and the training result will be recorded into the TensorBoard.
+Pass the logger into the trainer, and the training result will be recorded into the TensorBoard.
 The returned result is a dictionary as follows:
 ::

View File

@@ -176,6 +176,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul
     import numpy as np
     from copy import deepcopy
     from torch.utils.tensorboard import SummaryWriter
+    from tianshou.utils import BasicLogger
     from tianshou.env import DummyVectorEnv
     from tianshou.utils.net.common import Net
@@ -319,11 +320,10 @@ With the above preparation, we are close to the first learned agent. The followi
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # ======== tensorboard logging setup =========
-    if not hasattr(args, 'writer'):
-        log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
-        writer = SummaryWriter(log_path)
-    else:
-        writer = args.writer
+    log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
+    writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    logger = BasicLogger(writer)
     # ======== callback functions used during training =========
@@ -359,7 +359,7 @@ With the above preparation, we are close to the first learned agent. The followi
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
         stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
-        writer=writer, test_in_train=False, reward_metric=reward_metric)
+        logger=logger, test_in_train=False, reward_metric=reward_metric)
     agent = policy.policies[args.agent_id - 1]
     # let's watch the match!

View File

@@ -2,10 +2,12 @@ import os
 import torch
 import pickle
 import pprint
+import datetime
 import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offline_trainer
 from tianshou.utils.net.discrete import Actor
@@ -39,7 +41,7 @@ def get_args():
 parser.add_argument("--resume-path", type=str, default=None)
 parser.add_argument("--watch", default=False, action="store_true",
                     help="watch the play of pre-trained policy only")
-parser.add_argument("--log-interval", type=int, default=1000)
+parser.add_argument("--log-interval", type=int, default=100)
 parser.add_argument(
     "--load-buffer-name", type=str,
     default="./expert_DQN_PongNoFrameskip-v4.hdf5",
@@ -113,8 +115,13 @@ def test_discrete_bcq(args=get_args()):
 # collector
 test_collector = Collector(policy, test_envs, exploration_noise=True)
-log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
+# log
+log_path = os.path.join(
+    args.logdir, args.task, 'bcq',
+    f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
 writer = SummaryWriter(log_path)
+writer.add_text("args", str(args))
+logger = BasicLogger(writer, update_interval=args.log_interval)
 def save_fn(policy):
     torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -141,7 +148,7 @@ def test_discrete_bcq(args=get_args()):
 result = offline_trainer(
     policy, buffer, test_collector,
     args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
-    stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+    stop_fn=stop_fn, save_fn=save_fn, logger=logger,
     log_interval=args.log_interval,
 )

View File

@@ -6,6 +6,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 from tianshou.policy import C51Policy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer
@@ -98,6 +99,8 @@ def test_c51(args=get_args()):
 # log
 log_path = os.path.join(args.logdir, args.task, 'c51')
 writer = SummaryWriter(log_path)
+writer.add_text("args", str(args))
+logger = BasicLogger(writer)
 def save_fn(policy):
     torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -118,7 +121,7 @@ def test_c51(args=get_args()):
 else:
     eps = args.eps_train_final
 policy.set_eps(eps)
-writer.add_scalar('train/eps', eps, global_step=env_step)
+logger.write('train/eps', env_step, eps)
 def test_fn(epoch, env_step):
     policy.set_eps(args.eps_test)
@@ -144,7 +147,7 @@ def test_c51(args=get_args()):
     policy, train_collector, test_collector, args.epoch,
     args.step_per_epoch, args.step_per_collect, args.test_num,
     args.batch_size, train_fn=train_fn, test_fn=test_fn,
-    stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+    stop_fn=stop_fn, save_fn=save_fn, logger=logger,
     update_per_step=args.update_per_step, test_in_train=False)
 pprint.pprint(result)

View File

@@ -6,6 +6,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 from tianshou.policy import DQNPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer
@@ -94,6 +95,8 @@ def test_dqn(args=get_args()):
 # log
 log_path = os.path.join(args.logdir, args.task, 'dqn')
 writer = SummaryWriter(log_path)
+writer.add_text("args", str(args))
+logger = BasicLogger(writer)
 def save_fn(policy):
     torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -114,7 +117,7 @@ def test_dqn(args=get_args()):
 else:
     eps = args.eps_train_final
 policy.set_eps(eps)
-writer.add_scalar('train/eps', eps, global_step=env_step)
+logger.write('train/eps', env_step, eps)
 def test_fn(epoch, env_step):
     policy.set_eps(args.eps_test)
@@ -154,7 +157,7 @@ def test_dqn(args=get_args()):
     policy, train_collector, test_collector, args.epoch,
     args.step_per_epoch, args.step_per_collect, args.test_num,
     args.batch_size, train_fn=train_fn, test_fn=test_fn,
-    stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+    stop_fn=stop_fn, save_fn=save_fn, logger=logger,
     update_per_step=args.update_per_step, test_in_train=False)
 pprint.pprint(result)

View File

@@ -5,6 +5,7 @@ import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
+from tianshou.utils import BasicLogger
 from tianshou.policy import QRDQNPolicy
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
@@ -96,6 +97,8 @@ def test_qrdqn(args=get_args()):
 # log
 log_path = os.path.join(args.logdir, args.task, 'qrdqn')
 writer = SummaryWriter(log_path)
+writer.add_text("args", str(args))
+logger = BasicLogger(writer)
 def save_fn(policy):
     torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -116,7 +119,7 @@ def test_qrdqn(args=get_args()):
 else:
     eps = args.eps_train_final
 policy.set_eps(eps)
-writer.add_scalar('train/eps', eps, global_step=env_step)
+logger.write('train/eps', env_step, eps)
 def test_fn(epoch, env_step):
     policy.set_eps(args.eps_test)
@@ -142,7 +145,7 @@ def test_qrdqn(args=get_args()):
     policy, train_collector, test_collector, args.epoch,
     args.step_per_epoch, args.step_per_collect, args.test_num,
     args.batch_size, train_fn=train_fn, test_fn=test_fn,
-    stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+    stop_fn=stop_fn, save_fn=save_fn, logger=logger,
     update_per_step=args.update_per_step, test_in_train=False)
 pprint.pprint(result)

View File

@@ -4,7 +4,9 @@ import pprint
 import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 from tianshou.policy import A2CPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -79,7 +81,9 @@ def test_a2c(args=get_args()):
     preprocess_fn=preprocess_fn, exploration_noise=True)
 test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
 # log
-writer = SummaryWriter(os.path.join(args.logdir, args.task, 'a2c'))
+log_path = os.path.join(args.logdir, args.task, 'a2c')
+writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)
 def stop_fn(mean_rewards):
     if env.env.spec.reward_threshold:
@@ -91,7 +95,7 @@ def test_a2c(args=get_args()):
 result = onpolicy_trainer(
     policy, train_collector, test_collector, args.epoch,
     args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
-    episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer)
+    episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, logger=logger)
 if __name__ == '__main__':
     pprint.pprint(result)
 # Let's watch its performance!

View File

@@ -6,11 +6,12 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 from tianshou.policy import PPOPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.data import Collector, VectorReplayBuffer
 from atari import create_atari_environment, preprocess_fn
@@ -84,7 +85,9 @@ def test_ppo(args=get_args()):
     preprocess_fn=preprocess_fn, exploration_noise=True)
 test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
 # log
-writer = SummaryWriter(os.path.join(args.logdir, args.task, 'ppo'))
+log_path = os.path.join(args.logdir, args.task, 'ppo')
+writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)
 def stop_fn(mean_rewards):
     if env.env.spec.reward_threshold:
@@ -96,7 +99,8 @@ def test_ppo(args=get_args()):
 result = onpolicy_trainer(
     policy, train_collector, test_collector, args.epoch,
     args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
-    episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer)
+    episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, logger=logger)
 if __name__ == '__main__':
     pprint.pprint(result)
 # Let's watch its performance!

View File

@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 from tianshou.policy import DQNPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -81,6 +82,7 @@ def test_dqn(args=get_args()):
 # log
 log_path = os.path.join(args.logdir, args.task, 'dqn')
 writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)
 def save_fn(policy):
     torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -106,7 +108,7 @@ def test_dqn(args=get_args()):
     policy, train_collector, test_collector, args.epoch,
     args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
     update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn,
-    stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+    stop_fn=stop_fn, save_fn=save_fn, logger=logger)
 assert stop_fn(result['best_reward'])
 if __name__ == '__main__':

View File

@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 from tianshou.policy import SACPolicy
+from tianshou.utils import BasicLogger
 from tianshou.utils.net.common import Net
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
@@ -134,6 +135,7 @@ def test_sac_bipedal(args=get_args()):
 # log
 log_path = os.path.join(args.logdir, args.task, 'sac')
 writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)
 def save_fn(policy):
     torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -146,7 +148,7 @@ def test_sac_bipedal(args=get_args()):
     policy, train_collector, test_collector, args.epoch,
     args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
     update_per_step=args.update_per_step, test_in_train=False,
-    stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+    stop_fn=stop_fn, save_fn=save_fn, logger=logger)
 if __name__ == '__main__':
     pprint.pprint(result)

View File

@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 from tianshou.policy import DQNPolicy
+from tianshou.utils import BasicLogger
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer
@@ -83,6 +84,7 @@ def test_dqn(args=get_args()):
 # log
 log_path = os.path.join(args.logdir, args.task, 'dqn')
 writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)
 def save_fn(policy):
     torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -102,7 +104,7 @@ def test_dqn(args=get_args()):
     policy, train_collector, test_collector, args.epoch,
     args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
     update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn,
-    test_fn=test_fn, save_fn=save_fn, writer=writer)
+    test_fn=test_fn, save_fn=save_fn, logger=logger)
 assert stop_fn(result['best_reward'])
 if __name__ == '__main__':
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -7,11 +7,12 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 from tianshou.policy import SACPolicy
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.exploration import OUNoise
 from tianshou.utils.net.common import Net
+from tianshou.trainer import offpolicy_trainer
+from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -103,6 +104,7 @@ def test_sac(args=get_args()):
 # log
 log_path = os.path.join(args.logdir, args.task, 'sac')
 writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)
 def save_fn(policy):
     torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -115,7 +117,7 @@ def test_sac(args=get_args()):
     policy, train_collector, test_collector, args.epoch,
     args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
     update_per_step=args.update_per_step, stop_fn=stop_fn,
-    save_fn=save_fn, writer=writer)
+    save_fn=save_fn, logger=logger)
 assert stop_fn(result['best_reward'])
 if __name__ == '__main__':

View File

@@ -2,11 +2,13 @@ import os
 import gym
 import torch
 import pprint
+import datetime
 import argparse
 import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 from tianshou.policy import SACPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -114,8 +116,11 @@ def test_sac(args=get_args()):
     exploration_noise=True)
 test_collector = Collector(policy, test_envs)
 # log
-log_path = os.path.join(args.logdir, args.task, 'sac')
+log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(
+    args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
 writer = SummaryWriter(log_path)
+writer.add_text("args", str(args))
+logger = BasicLogger(writer, train_interval=args.log_interval)
 def watch():
     # watch agent's performance
@@ -141,8 +146,8 @@ def test_sac(args=get_args()):
 result = offpolicy_trainer(
     policy, train_collector, test_collector, args.epoch,
     args.step_per_epoch, args.step_per_collect, args.test_num,
-    args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
-    update_per_step=args.update_per_step, log_interval=args.log_interval)
+    args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
+    update_per_step=args.update_per_step)
 pprint.pprint(result)
 watch()

View File

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -6,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 from tianshou.policy import DDPGPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -79,7 +81,9 @@ def test_ddpg(args=get_args()):
     exploration_noise=True)
 test_collector = Collector(policy, test_envs)
 # log
-writer = SummaryWriter(args.logdir + '/' + 'ddpg')
+log_path = os.path.join(args.logdir, args.task, 'ddpg')
+writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)
 def stop_fn(mean_rewards):
     return mean_rewards >= env.spec.reward_threshold
@@ -88,7 +92,7 @@ def test_ddpg(args=get_args()):
 result = offpolicy_trainer(
     policy, train_collector, test_collector, args.epoch,
     args.step_per_epoch, args.step_per_collect, args.test_num,
-    args.batch_size, stop_fn=stop_fn, writer=writer)
+    args.batch_size, stop_fn=stop_fn, logger=logger)
 assert stop_fn(result['best_reward'])
 if __name__ == '__main__':
     pprint.pprint(result)

View File

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -6,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 from tianshou.policy import TD3Policy
+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.exploration import GaussianNoise
@@ -88,7 +90,9 @@ def test_td3(args=get_args()):
 test_collector = Collector(policy, test_envs)
 # train_collector.collect(n_step=args.buffer_size)
 # log
-writer = SummaryWriter(args.logdir + '/' + 'td3')
+log_path = os.path.join(args.logdir, args.task, 'td3')
+writer = SummaryWriter(log_path)
+logger = BasicLogger(writer)
 def stop_fn(mean_rewards):
     return mean_rewards >= env.spec.reward_threshold
@@ -97,7 +101,7 @@ def test_td3(args=get_args()):
 result = offpolicy_trainer(
     policy, train_collector, test_collector, args.epoch,
     args.step_per_epoch, args.step_per_collect, args.test_num,
-    args.batch_size, stop_fn=stop_fn, writer=writer)
+    args.batch_size, stop_fn=stop_fn, logger=logger)
 assert stop_fn(result['best_reward'])
 if __name__ == '__main__':
     pprint.pprint(result)

View File

@@ -2,12 +2,14 @@ import os
 import gym
 import torch
 import pprint
+import datetime
 import argparse
 import numpy as np
 import pybullet_envs
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import SACPolicy
+from tianshou.utils import BasicLogger
 from tianshou.utils.net.common import Net
 from tianshou.env import SubprocVectorEnv
 from tianshou.trainer import offpolicy_trainer
@@ -88,8 +90,10 @@ def test_sac(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
+    log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(
+        args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer, train_interval=args.log_interval)

     def stop_fn(mean_rewards):
         return mean_rewards >= env.spec.reward_threshold
@@ -99,7 +103,7 @@ def test_sac(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, stop_fn=stop_fn,
-        writer=writer, log_interval=args.log_interval)
+        logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
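The hunk above also replaces the `args.run_id` path component with a per-run directory keyed by seed and launch time. A minimal sketch of that naming scheme (`make_log_path` is a hypothetical helper, not part of tianshou; the task name below is only an example):

```python
import os
import datetime


def make_log_path(logdir, task, algo, seed):
    # mirrors the diff's pattern: one directory per run, keyed by seed + timestamp,
    # so repeated runs never overwrite each other's tensorboard event files
    stamp = datetime.datetime.now().strftime('%m%d-%H%M%S')
    return os.path.join(logdir, task, algo, 'seed_' + str(seed) + '_' + stamp)


print(make_log_path('log', 'HalfCheetahBulletEnv-v0', 'sac', 0))
# e.g. log/HalfCheetahBulletEnv-v0/sac/seed_0_0224-1448
```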

View File

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -6,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import TD3Policy
+from tianshou.utils import BasicLogger
 from tianshou.utils.net.common import Net
 from tianshou.env import SubprocVectorEnv
 from tianshou.exploration import GaussianNoise
@@ -93,7 +95,9 @@ def test_td3(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # train_collector.collect(n_step=args.buffer_size)
     # log
-    writer = SummaryWriter(args.logdir + '/' + 'td3')
+    log_path = os.path.join(args.logdir, args.task, 'td3')
+    writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
@@ -105,7 +109,7 @@ def test_td3(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
-        args.batch_size, stop_fn=stop_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)

View File

@@ -2,7 +2,6 @@ import torch
 import numpy as np

 from tianshou.utils import MovAvg
-from tianshou.utils import SummaryWriter
 from tianshou.utils.net.common import MLP, Net
 from tianshou.exploration import GaussianNoise, OUNoise
 from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
@@ -77,25 +76,7 @@ def test_net():
     assert list(net(data, act).shape) == [bsz, 1]


-def test_summary_writer():
-    # get first instance by key of `default` or your own key
-    writer1 = SummaryWriter.get_instance(
-        key="first", log_dir="log/test_sw/first")
-    assert writer1.log_dir == "log/test_sw/first"
-    writer2 = SummaryWriter.get_instance()
-    assert writer1 is writer2
-    # create new instance by specify a new key
-    writer3 = SummaryWriter.get_instance(
-        key="second", log_dir="log/test_sw/second")
-    assert writer3.log_dir == "log/test_sw/second"
-    writer4 = SummaryWriter.get_instance(key="second")
-    assert writer3 is writer4
-    assert writer1 is not writer3
-    assert writer1.log_dir != writer4.log_dir
-
-
 if __name__ == '__main__':
     test_noise()
     test_moving_average()
     test_net()
-    test_summary_writer()

View File

@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DDPGPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -93,6 +94,7 @@ def test_ddpg(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'ddpg')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -105,7 +107,7 @@ def test_ddpg(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
         update_per_step=args.update_per_step, stop_fn=stop_fn,
-        save_fn=save_fn, writer=writer)
+        save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)

View File

@@ -8,6 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
 from torch.distributions import Independent, Normal

 from tianshou.policy import PPOPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -111,6 +112,7 @@ def test_ppo(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'ppo')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -123,7 +125,7 @@ def test_ppo(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
         episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
-        writer=writer)
+        logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)

View File

@@ -6,6 +6,7 @@ import argparse
 import numpy as np

 from torch.utils.tensorboard import SummaryWriter
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -102,6 +103,7 @@ def test_sac_with_il(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'sac')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -114,7 +116,7 @@ def test_sac_with_il(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
         update_per_step=args.update_per_step, stop_fn=stop_fn,
-        save_fn=save_fn, writer=writer)
+        save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -146,7 +148,7 @@ def test_sac_with_il(args=get_args()):
     result = offpolicy_trainer(
         il_policy, train_collector, il_test_collector, args.epoch,
         args.il_step_per_epoch, args.step_per_collect, args.test_num,
-        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)

View File

@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import TD3Policy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -106,6 +107,7 @@ def test_td3(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'td3')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -118,7 +120,7 @@ def test_td3(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
         update_per_step=args.update_per_step, stop_fn=stop_fn,
-        save_fn=save_fn, writer=writer)
+        save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)

View File

@@ -6,6 +6,7 @@ import argparse
 import numpy as np

 from torch.utils.tensorboard import SummaryWriter
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.data import Collector, VectorReplayBuffer
@@ -89,6 +90,7 @@ def test_a2c_with_il(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'a2c')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -101,7 +103,7 @@ def test_a2c_with_il(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
         episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
-        writer=writer)
+        logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)
@@ -130,7 +132,7 @@ def test_a2c_with_il(args=get_args()):
     result = offpolicy_trainer(
         il_policy, train_collector, il_test_collector, args.epoch,
         args.il_step_per_epoch, args.step_per_collect, args.test_num,
-        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)

View File

@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import C51Policy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -89,6 +90,7 @@ def test_c51(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'c51')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -115,7 +117,7 @@ def test_c51(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        stop_fn=stop_fn, save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':

View File

@@ -8,6 +8,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DQNPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
@@ -91,6 +92,7 @@ def test_dqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'dqn')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -117,7 +119,7 @@ def test_dqn(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
-        test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])

View File

@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DQNPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.trainer import offpolicy_trainer
 from tianshou.utils.net.common import Recurrent
@@ -77,6 +78,7 @@ def test_drqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'drqn')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -96,7 +98,7 @@ def test_drqn(args=get_args()):
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, update_per_step=args.update_per_step,
         train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn,
-        save_fn=save_fn, writer=writer)
+        save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':

View File

@@ -8,6 +8,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offline_trainer
@@ -82,6 +83,7 @@ def test_discrete_bcq(args=get_args()):
     log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -92,7 +94,7 @@ def test_discrete_bcq(args=get_args()):
     result = offline_trainer(
         policy, buffer, test_collector,
         args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer)
+        stop_fn=stop_fn, save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])

View File

@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import PGPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -72,6 +73,7 @@ def test_pg(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'pg')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -84,7 +86,7 @@ def test_pg(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
         episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
-        writer=writer)
+        logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)

View File

@@ -7,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import PPOPolicy
+from tianshou.utils import BasicLogger
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import onpolicy_trainer
@@ -98,6 +99,7 @@ def test_ppo(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'ppo')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -110,7 +112,7 @@ def test_ppo(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
         episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
-        writer=writer)
+        logger=logger)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':
         pprint.pprint(result)

View File

@@ -6,6 +6,7 @@ import argparse
 import numpy as np

 from torch.utils.tensorboard import SummaryWriter
+from tianshou.utils import BasicLogger
 from tianshou.policy import QRDQNPolicy
 from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
@@ -87,6 +88,7 @@ def test_qrdqn(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'qrdqn')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -113,7 +115,7 @@ def test_qrdqn(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, train_fn=train_fn, test_fn=test_fn,
-        stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+        stop_fn=stop_fn, save_fn=save_fn, logger=logger,
         update_per_step=args.update_per_step)
     assert stop_fn(result['best_reward'])

View File

@@ -6,12 +6,13 @@ import argparse
 import numpy as np

 from torch.utils.tensorboard import SummaryWriter

+from tianshou.utils import BasicLogger
 from tianshou.env import SubprocVectorEnv
 from tianshou.utils.net.common import Net
-from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.policy import DiscreteSACPolicy
+from tianshou.trainer import offpolicy_trainer
 from tianshou.utils.net.discrete import Actor, Critic
+from tianshou.data import Collector, VectorReplayBuffer

 def get_args():
@@ -99,6 +100,7 @@ def test_discrete_sac(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, args.task, 'discrete_sac')
     writer = SummaryWriter(log_path)
+    logger = BasicLogger(writer)

     def save_fn(policy):
         torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -110,7 +112,7 @@ def test_discrete_sac(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.step_per_collect, args.test_num,
-        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
+        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
         update_per_step=args.update_per_step, test_in_train=False)
     assert stop_fn(result['best_reward'])
     if __name__ == '__main__':

View File

@@ -1,3 +1,4 @@
+import os
 import gym
 import torch
 import pprint
@@ -6,6 +7,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import PSRLPolicy
+# from tianshou.utils import BasicLogger
 from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv, SubprocVectorEnv
@@ -66,7 +68,10 @@ def test_psrl(args=get_args()):
         exploration_noise=True)
     test_collector = Collector(policy, test_envs)
     # log
-    writer = SummaryWriter(args.logdir + '/' + args.task)
+    log_path = os.path.join(args.logdir, args.task, 'psrl')
+    writer = SummaryWriter(log_path)
+    writer.add_text("args", str(args))
+    # logger = BasicLogger(writer)

     def stop_fn(mean_rewards):
         if env.spec.reward_threshold:
@@ -75,11 +80,12 @@ def test_psrl(args=get_args()):
         return False

     train_collector.collect(n_step=args.buffer_size, random=True)
-    # trainer
+    # trainer, test it without logger
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, 1, args.test_num, 0,
-        episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer,
+        episode_per_collect=args.episode_per_collect, stop_fn=stop_fn,
+        # logger=logger,
         test_in_train=False)

     if __name__ == '__main__':

View File

@@ -7,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter
 from tianshou.env import DummyVectorEnv
 from tianshou.data import Collector
 from tianshou.policy import RandomPolicy
+from tianshou.utils import BasicLogger

 from tic_tac_toe_env import TicTacToeEnv
 from tic_tac_toe import get_parser, get_agents, train_agent, watch
@@ -31,7 +32,8 @@ def gomoku(args=get_args()):
     # log
     log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
-    args.writer = SummaryWriter(log_path)
+    writer = SummaryWriter(log_path)
+    args.logger = BasicLogger(writer)

     opponent_pool = [agent_opponent]

View File

@ -6,6 +6,7 @@ from copy import deepcopy
from typing import Optional, Tuple from typing import Optional, Tuple
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer from tianshou.trainer import offpolicy_trainer
@ -131,12 +132,10 @@ def train_agent(
# policy.set_eps(1) # policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num) train_collector.collect(n_step=args.batch_size * args.training_num)
# log # log
if not hasattr(args, 'writer'): log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn') writer = SummaryWriter(log_path)
writer = SummaryWriter(log_path) writer.add_text("args", str(args))
args.writer = writer logger = BasicLogger(writer)
else:
writer = args.writer
def save_fn(policy): def save_fn(policy):
if hasattr(args, 'model_save_path'): if hasattr(args, 'model_save_path'):
@ -166,7 +165,7 @@ def train_agent(
args.step_per_epoch, args.step_per_collect, args.test_num, args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn, args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step, stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
writer=writer, test_in_train=False, reward_metric=reward_metric) logger=logger, test_in_train=False, reward_metric=reward_metric)
return result, policy.policies[args.agent_id - 1] return result, policy.policies[args.agent_id - 1]


@@ -92,8 +92,6 @@ class ReplayBuffer:
("buffer.__getattr__" is customized). ("buffer.__getattr__" is customized).
""" """
self.__dict__.update(state) self.__dict__.update(state)
# compatible with version == 0.3.1's HDF5 data format
self._indices = np.arange(self.maxsize)
def __setattr__(self, key: str, value: Any) -> None: def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value.""" """Set self.key = value."""


@@ -184,9 +184,9 @@ class Collector(object):
* ``n/ep`` collected number of episodes. * ``n/ep`` collected number of episodes.
* ``n/st`` collected number of steps. * ``n/st`` collected number of steps.
* ``rews`` list of episode reward over collected episodes. * ``rews`` array of episode reward over collected episodes.
* ``lens`` list of episode length over collected episodes. * ``lens`` array of episode length over collected episodes.
* ``idxs`` list of episode start index in buffer over collected episodes. * ``idxs`` array of episode start index in buffer over collected episodes.
""" """
assert not self.env.is_async, "Please use AsyncCollector if using async venv." assert not self.env.is_async, "Please use AsyncCollector if using async venv."
if n_step is not None: if n_step is not None:
@@ -379,9 +379,9 @@ class AsyncCollector(Collector):
* ``n/ep`` collected number of episodes. * ``n/ep`` collected number of episodes.
* ``n/st`` collected number of steps. * ``n/st`` collected number of steps.
* ``rews`` list of episode reward over collected episodes. * ``rews`` array of episode reward over collected episodes.
* ``lens`` list of episode length over collected episodes. * ``lens`` array of episode length over collected episodes.
* ``idxs`` list of episode start index in buffer over collected episodes. * ``idxs`` array of episode start index in buffer over collected episodes.
""" """
# collect at least n_step or n_episode # collect at least n_step or n_episode
if n_step is not None: if n_step is not None:


@@ -4,7 +4,7 @@ import numpy as np
from torch import nn from torch import nn
from numba import njit from numba import njit
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, List, Union, Mapping, Optional, Callable from typing import Any, Dict, Union, Optional, Callable
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@@ -124,12 +124,10 @@ class BasePolicy(ABC, nn.Module):
return batch return batch
@abstractmethod @abstractmethod
def learn( def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
self, batch: Batch, **kwargs: Any
) -> Mapping[str, Union[float, List[float]]]:
"""Update policy with a given batch of data. """Update policy with a given batch of data.
:return: A dict which includes loss and its corresponding label. :return: A dict, including the data needed to be logged (e.g., loss).
.. note:: .. note::
@@ -162,18 +160,20 @@ class BasePolicy(ABC, nn.Module):
def update( def update(
self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any
) -> Mapping[str, Union[float, List[float]]]: ) -> Dict[str, Any]:
"""Update the policy network and replay buffer. """Update the policy network and replay buffer.
It includes 3 function steps: process_fn, learn, and post_process_fn. It includes 3 function steps: process_fn, learn, and post_process_fn. In
In addition, this function will change the value of ``self.updating``: addition, this function will change the value of ``self.updating``: it will be
it will be False before this function and will be True when executing False before this function and will be True when executing :meth:`update`.
:meth:`update`. Please refer to :ref:`policy_state` for more detailed Please refer to :ref:`policy_state` for more detailed explanation.
explanation.
:param int sample_size: 0 means it will extract all the data from the :param int sample_size: 0 means it will extract all the data from the buffer,
buffer, otherwise it will sample a batch with given sample_size. otherwise it will sample a batch with given sample_size.
:param ReplayBuffer buffer: the corresponding replay buffer. :param ReplayBuffer buffer: the corresponding replay buffer.
:return: A dict, including the data needed to be logged (e.g., loss) from
``policy.learn()``.
""" """
if buffer is None: if buffer is None:
return {} return {}

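The `Dict[str, Any]` contract above (`learn()` returning the statistics that the trainer hands to the logger) can be exercised with a toy policy outside tianshou; the class and its stand-in "loss" computation below are purely illustrative:

```python
# Illustrative sketch: a toy policy whose learn() returns a dict of
# statistics, matching the Dict[str, Any] return contract above.
class ToyPolicy:
    def learn(self, batch, **kwargs):
        # stand-in "loss": the mean of the batch values
        loss = sum(batch) / len(batch)
        return {"loss": loss}

policy = ToyPolicy()
losses = policy.learn([1.0, 2.0, 3.0])
print(losses)  # {'loss': 2.0}
```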

@@ -2,11 +2,10 @@ import time
import tqdm import tqdm
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, Union, Callable, Optional from typing import Dict, Union, Callable, Optional
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.data import Collector, ReplayBuffer from tianshou.data import Collector, ReplayBuffer
from tianshou.trainer import test_episode, gather_info from tianshou.trainer import test_episode, gather_info
@@ -23,8 +22,7 @@ def offline_trainer(
stop_fn: Optional[Callable[[float], bool]] = None, stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
writer: Optional[SummaryWriter] = None, logger: BaseLogger = LazyLogger(),
log_interval: int = 1,
verbose: bool = True, verbose: bool = True,
) -> Dict[str, Union[float, str]]: ) -> Dict[str, Union[float, str]]:
"""A wrapper for offline trainer procedure. """A wrapper for offline trainer procedure.
@@ -55,9 +53,8 @@ def offline_trainer(
result to monitor training in the multi-agent RL setting. This function result to monitor training in the multi-agent RL setting. This function
specifies what is the desired metric, e.g., the reward of agent 1 or the specifies what is the desired metric, e.g., the reward of agent 1 or the
average reward over all agents. average reward over all agents.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; :param BaseLogger logger: A logger that logs statistics during updating/testing.
if None is given, it will not write logs to TensorBoard. Default to None. Default to a logger that doesn't log anything.
:param int log_interval: the log interval of the writer. Default to 1.
:param bool verbose: whether to print the information. Default to True. :param bool verbose: whether to print the information. Default to True.
:return: See :func:`~tianshou.trainer.gather_info`. :return: See :func:`~tianshou.trainer.gather_info`.
@@ -67,10 +64,9 @@ def offline_trainer(
start_time = time.time() start_time = time.time()
test_collector.reset_stat() test_collector.reset_stat()
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
writer, gradient_step, reward_metric) logger, gradient_step, reward_metric)
best_epoch = 0 best_epoch = 0
best_reward = test_result["rews"].mean() best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
best_reward_std = test_result["rews"].std()
for epoch in range(1, 1 + max_epoch): for epoch in range(1, 1 + max_epoch):
policy.train() policy.train()
with tqdm.trange( with tqdm.trange(
@@ -82,27 +78,23 @@ def offline_trainer(
data = {"gradient_step": str(gradient_step)} data = {"gradient_step": str(gradient_step)}
for k in losses.keys(): for k in losses.keys():
stat[k].add(losses[k]) stat[k].add(losses[k])
data[k] = f"{stat[k].get():.6f}" losses[k] = stat[k].get()
if writer and gradient_step % log_interval == 0: data[k] = f"{losses[k]:.6f}"
writer.add_scalar( logger.log_update_data(losses, gradient_step)
"train/" + k, stat[k].get(),
global_step=gradient_step)
t.set_postfix(**data) t.set_postfix(**data)
# test # test
test_result = test_episode(policy, test_collector, test_fn, epoch, test_result = test_episode(
episode_per_test, writer, gradient_step, policy, test_collector, test_fn, epoch, episode_per_test,
reward_metric) logger, gradient_step, reward_metric)
if best_epoch == -1 or best_reward < test_result["rews"].mean(): rew, rew_std = test_result["rew"], test_result["rew_std"]
best_reward = test_result["rews"].mean() if best_epoch == -1 or best_reward < rew:
best_reward_std = test_result['rews'].std() best_reward, best_reward_std = rew, rew_std
best_epoch = epoch best_epoch = epoch
if save_fn: if save_fn:
save_fn(policy) save_fn(policy)
if verbose: if verbose:
print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± " print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
f"{best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward): if stop_fn and stop_fn(best_reward):
break break
return gather_info(start_time, None, test_collector, return gather_info(start_time, None, test_collector, best_reward, best_reward_std)
best_reward, best_reward_std)


@@ -2,12 +2,11 @@ import time
import tqdm import tqdm
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, Union, Callable, Optional from typing import Dict, Union, Callable, Optional
from tianshou.data import Collector from tianshou.data import Collector
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.trainer import test_episode, gather_info from tianshou.trainer import test_episode, gather_info
@@ -26,8 +25,7 @@ def offpolicy_trainer(
stop_fn: Optional[Callable[[float], bool]] = None, stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
writer: Optional[SummaryWriter] = None, logger: BaseLogger = LazyLogger(),
log_interval: int = 1,
verbose: bool = True, verbose: bool = True,
test_in_train: bool = True, test_in_train: bool = True,
) -> Dict[str, Union[float, str]]: ) -> Dict[str, Union[float, str]]:
@@ -70,25 +68,24 @@ def offpolicy_trainer(
result to monitor training in the multi-agent RL setting. This function result to monitor training in the multi-agent RL setting. This function
specifies what is the desired metric, e.g., the reward of agent 1 or the specifies what is the desired metric, e.g., the reward of agent 1 or the
average reward over all agents. average reward over all agents.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; :param BaseLogger logger: A logger that logs statistics during
if None is given, it will not write logs to TensorBoard. Default to None. training/testing/updating. Default to a logger that doesn't log anything.
:param int log_interval: the log interval of the writer. Default to 1.
:param bool verbose: whether to print the information. Default to True. :param bool verbose: whether to print the information. Default to True.
:param bool test_in_train: whether to test in the training phase. Default to True. :param bool test_in_train: whether to test in the training phase. Default to True.
:return: See :func:`~tianshou.trainer.gather_info`. :return: See :func:`~tianshou.trainer.gather_info`.
""" """
env_step, gradient_step = 0, 0 env_step, gradient_step = 0, 0
last_rew, last_len = 0.0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg) stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time() start_time = time.time()
train_collector.reset_stat() train_collector.reset_stat()
test_collector.reset_stat() test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy test_in_train = test_in_train and train_collector.policy == policy
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
writer, env_step, reward_metric) logger, env_step, reward_metric)
best_epoch = 0 best_epoch = 0
best_reward = test_result["rews"].mean() best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
best_reward_std = test_result["rews"].std()
for epoch in range(1, 1 + max_epoch): for epoch in range(1, 1 + max_epoch):
# train # train
policy.train() policy.train()
@@ -99,34 +96,32 @@ def offpolicy_trainer(
if train_fn: if train_fn:
train_fn(epoch, env_step) train_fn(epoch, env_step)
result = train_collector.collect(n_step=step_per_collect) result = train_collector.collect(n_step=step_per_collect)
if len(result["rews"]) > 0 and reward_metric: if result["n/ep"] > 0 and reward_metric:
result["rews"] = reward_metric(result["rews"]) result["rews"] = reward_metric(result["rews"])
env_step += int(result["n/st"]) env_step += int(result["n/st"])
t.update(result["n/st"]) t.update(result["n/st"])
logger.log_train_data(result, env_step)
last_rew = result['rew'] if 'rew' in result else last_rew
last_len = result['len'] if 'len' in result else last_len
data = { data = {
"env_step": str(env_step), "env_step": str(env_step),
"rew": f"{result['rews'].mean():.2f}", "rew": f"{last_rew:.2f}",
"len": str(result["lens"].mean()), "len": str(last_len),
"n/ep": str(int(result["n/ep"])), "n/ep": str(int(result["n/ep"])),
"n/st": str(int(result["n/st"])), "n/st": str(int(result["n/st"])),
} }
if result["n/ep"] > 0: if result["n/ep"] > 0:
if writer and env_step % log_interval == 0: if test_in_train and stop_fn and stop_fn(result["rew"]):
writer.add_scalar(
"train/rew", result['rews'].mean(), global_step=env_step)
writer.add_scalar(
"train/len", result['lens'].mean(), global_step=env_step)
if test_in_train and stop_fn and stop_fn(result["rews"].mean()):
test_result = test_episode( test_result = test_episode(
policy, test_collector, test_fn, policy, test_collector, test_fn,
epoch, episode_per_test, writer, env_step) epoch, episode_per_test, logger, env_step)
if stop_fn(test_result["rews"].mean()): if stop_fn(test_result["rew"]):
if save_fn: if save_fn:
save_fn(policy) save_fn(policy)
t.set_postfix(**data) t.set_postfix(**data)
return gather_info( return gather_info(
start_time, train_collector, test_collector, start_time, train_collector, test_collector,
test_result["rews"].mean(), test_result["rews"].std()) test_result["rew"], test_result["rew_std"])
else: else:
policy.train() policy.train()
for i in range(round(update_per_step * result["n/st"])): for i in range(round(update_per_step * result["n/st"])):
@@ -134,26 +129,24 @@ def offpolicy_trainer(
losses = policy.update(batch_size, train_collector.buffer) losses = policy.update(batch_size, train_collector.buffer)
for k in losses.keys(): for k in losses.keys():
stat[k].add(losses[k]) stat[k].add(losses[k])
data[k] = f"{stat[k].get():.6f}" losses[k] = stat[k].get()
if writer and gradient_step % log_interval == 0: data[k] = f"{losses[k]:.6f}"
writer.add_scalar( logger.log_update_data(losses, gradient_step)
k, stat[k].get(), global_step=gradient_step)
t.set_postfix(**data) t.set_postfix(**data)
if t.n <= t.total: if t.n <= t.total:
t.update() t.update()
# test # test
test_result = test_episode(policy, test_collector, test_fn, epoch, test_result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, writer, env_step, reward_metric) episode_per_test, logger, env_step, reward_metric)
if best_epoch == -1 or best_reward < test_result["rews"].mean(): rew, rew_std = test_result["rew"], test_result["rew_std"]
best_reward = test_result["rews"].mean() if best_epoch == -1 or best_reward < rew:
best_reward_std = test_result['rews'].std() best_reward, best_reward_std = rew, rew_std
best_epoch = epoch best_epoch = epoch
if save_fn: if save_fn:
save_fn(policy) save_fn(policy)
if verbose: if verbose:
print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± " print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
f"{best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward): if stop_fn and stop_fn(best_reward):
break break
return gather_info(start_time, train_collector, test_collector, return gather_info(start_time, train_collector, test_collector,


@@ -2,12 +2,11 @@ import time
import tqdm import tqdm
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, Union, Callable, Optional from typing import Dict, Union, Callable, Optional
from tianshou.data import Collector from tianshou.data import Collector
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.trainer import test_episode, gather_info from tianshou.trainer import test_episode, gather_info
@@ -27,8 +26,7 @@ def onpolicy_trainer(
stop_fn: Optional[Callable[[float], bool]] = None, stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
writer: Optional[SummaryWriter] = None, logger: BaseLogger = LazyLogger(),
log_interval: int = 1,
verbose: bool = True, verbose: bool = True,
test_in_train: bool = True, test_in_train: bool = True,
) -> Dict[str, Union[float, str]]: ) -> Dict[str, Union[float, str]]:
@@ -72,9 +70,8 @@ def onpolicy_trainer(
result to monitor training in the multi-agent RL setting. This function result to monitor training in the multi-agent RL setting. This function
specifies what is the desired metric, e.g., the reward of agent 1 or the specifies what is the desired metric, e.g., the reward of agent 1 or the
average reward over all agents. average reward over all agents.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter; :param BaseLogger logger: A logger that logs statistics during
if None is given, it will not write logs to TensorBoard. Default to None. training/testing/updating. Default to a logger that doesn't log anything.
:param int log_interval: the log interval of the writer. Default to 1.
:param bool verbose: whether to print the information. Default to True. :param bool verbose: whether to print the information. Default to True.
:param bool test_in_train: whether to test in the training phase. Default to True. :param bool test_in_train: whether to test in the training phase. Default to True.
@@ -85,16 +82,16 @@ def onpolicy_trainer(
Only either one of step_per_collect and episode_per_collect can be specified. Only either one of step_per_collect and episode_per_collect can be specified.
""" """
env_step, gradient_step = 0, 0 env_step, gradient_step = 0, 0
last_rew, last_len = 0.0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg) stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time() start_time = time.time()
train_collector.reset_stat() train_collector.reset_stat()
test_collector.reset_stat() test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy test_in_train = test_in_train and train_collector.policy == policy
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
writer, env_step, reward_metric) logger, env_step, reward_metric)
best_epoch = 0 best_epoch = 0
best_reward = test_result["rews"].mean() best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
best_reward_std = test_result["rews"].std()
for epoch in range(1, 1 + max_epoch): for epoch in range(1, 1 + max_epoch):
# train # train
policy.train() policy.train()
@@ -110,29 +107,27 @@ def onpolicy_trainer(
result["rews"] = reward_metric(result["rews"]) result["rews"] = reward_metric(result["rews"])
env_step += int(result["n/st"]) env_step += int(result["n/st"])
t.update(result["n/st"]) t.update(result["n/st"])
logger.log_train_data(result, env_step)
last_rew = result['rew'] if 'rew' in result else last_rew
last_len = result['len'] if 'len' in result else last_len
data = { data = {
"env_step": str(env_step), "env_step": str(env_step),
"rew": f"{result['rews'].mean():.2f}", "rew": f"{last_rew:.2f}",
"len": str(int(result["lens"].mean())), "len": str(last_len),
"n/ep": str(int(result["n/ep"])), "n/ep": str(int(result["n/ep"])),
"n/st": str(int(result["n/st"])), "n/st": str(int(result["n/st"])),
} }
if writer and env_step % log_interval == 0: if test_in_train and stop_fn and stop_fn(result["rew"]):
writer.add_scalar(
"train/rew", result['rews'].mean(), global_step=env_step)
writer.add_scalar(
"train/len", result['lens'].mean(), global_step=env_step)
if test_in_train and stop_fn and stop_fn(result["rews"].mean()):
test_result = test_episode( test_result = test_episode(
policy, test_collector, test_fn, policy, test_collector, test_fn,
epoch, episode_per_test, writer, env_step) epoch, episode_per_test, logger, env_step)
if stop_fn(test_result["rews"].mean()): if stop_fn(test_result["rew"]):
if save_fn: if save_fn:
save_fn(policy) save_fn(policy)
t.set_postfix(**data) t.set_postfix(**data)
return gather_info( return gather_info(
start_time, train_collector, test_collector, start_time, train_collector, test_collector,
test_result["rews"].mean(), test_result["rews"].std()) test_result["rew"], test_result["rew_std"])
else: else:
policy.train() policy.train()
losses = policy.update( losses = policy.update(
@@ -144,26 +139,24 @@ def onpolicy_trainer(
gradient_step += step gradient_step += step
for k in losses.keys(): for k in losses.keys():
stat[k].add(losses[k]) stat[k].add(losses[k])
data[k] = f"{stat[k].get():.6f}" losses[k] = stat[k].get()
if writer and gradient_step % log_interval == 0: data[k] = f"{losses[k]:.6f}"
writer.add_scalar( logger.log_update_data(losses, gradient_step)
k, stat[k].get(), global_step=gradient_step)
t.set_postfix(**data) t.set_postfix(**data)
if t.n <= t.total: if t.n <= t.total:
t.update() t.update()
# test # test
test_result = test_episode(policy, test_collector, test_fn, epoch, test_result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, writer, env_step) episode_per_test, logger, env_step)
if best_epoch == -1 or best_reward < test_result["rews"].mean(): rew, rew_std = test_result["rew"], test_result["rew_std"]
best_reward = test_result["rews"].mean() if best_epoch == -1 or best_reward < rew:
best_reward_std = test_result['rews'].std() best_reward, best_reward_std = rew, rew_std
best_epoch = epoch best_epoch = epoch
if save_fn: if save_fn:
save_fn(policy) save_fn(policy)
if verbose: if verbose:
print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± " print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± " f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
f"{best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward): if stop_fn and stop_fn(best_reward):
break break
return gather_info(start_time, train_collector, test_collector, return gather_info(start_time, train_collector, test_collector,


@@ -1,10 +1,10 @@
import time import time
import numpy as np import numpy as np
from torch.utils.tensorboard import SummaryWriter
from typing import Any, Dict, Union, Callable, Optional from typing import Any, Dict, Union, Callable, Optional
from tianshou.data import Collector from tianshou.data import Collector
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.utils import BaseLogger
def test_episode( def test_episode(
@@ -13,7 +13,7 @@ def test_episode(
test_fn: Optional[Callable[[int, Optional[int]], None]], test_fn: Optional[Callable[[int, Optional[int]], None]],
epoch: int, epoch: int,
n_episode: int, n_episode: int,
writer: Optional[SummaryWriter] = None, logger: Optional[BaseLogger] = None,
global_step: Optional[int] = None, global_step: Optional[int] = None,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@@ -26,12 +26,8 @@ def test_episode(
result = collector.collect(n_episode=n_episode) result = collector.collect(n_episode=n_episode)
if reward_metric: if reward_metric:
result["rews"] = reward_metric(result["rews"]) result["rews"] = reward_metric(result["rews"])
if writer is not None and global_step is not None: if logger and global_step is not None:
rews, lens = result["rews"], result["lens"] logger.log_test_data(result, global_step)
writer.add_scalar("test/rew", rews.mean(), global_step=global_step)
writer.add_scalar("test/rew_std", rews.std(), global_step=global_step)
writer.add_scalar("test/len", lens.mean(), global_step=global_step)
writer.add_scalar("test/len_std", lens.std(), global_step=global_step)
return result return result


@@ -1,9 +1,11 @@
from tianshou.utils.config import tqdm_config from tianshou.utils.config import tqdm_config
from tianshou.utils.moving_average import MovAvg from tianshou.utils.moving_average import MovAvg
from tianshou.utils.log_tools import SummaryWriter from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger
__all__ = [ __all__ = [
"MovAvg", "MovAvg",
"tqdm_config", "tqdm_config",
"SummaryWriter", "BaseLogger",
"BasicLogger",
"LazyLogger",
] ]


@@ -1,47 +1,159 @@
import threading import numpy as np
from torch.utils import tensorboard from numbers import Number
from typing import Any, Dict, Optional from typing import Any, Union
from abc import ABC, abstractmethod
from torch.utils.tensorboard import SummaryWriter
class SummaryWriter(tensorboard.SummaryWriter): class BaseLogger(ABC):
"""A more convenient Summary Writer(`tensorboard.SummaryWriter`). """The base class for any logger which is compatible with trainer."""
You can get the same instance of summary writer everywhere after you def __init__(self, writer: Any) -> None:
created one. super().__init__()
:: self.writer = writer
>>> writer1 = SummaryWriter.get_instance( @abstractmethod
key="first", log_dir="log/test_sw/first") def write(
>>> writer2 = SummaryWriter.get_instance() self,
>>> writer1 is writer2 key: str,
True x: Union[Number, np.number, np.ndarray],
>>> writer4 = SummaryWriter.get_instance( y: Union[Number, np.number, np.ndarray],
key="second", log_dir="log/test_sw/second") **kwargs: Any,
>>> writer5 = SummaryWriter.get_instance(key="second") ) -> None:
>>> writer1 is not writer4 """Specify how the writer is used to log data.
True
>>> writer4 is writer5 :param key: namespace to which the input data tuple belongs.
True :param x: stands for the abscissa of the input data tuple.
:param y: stands for the ordinate of the input data tuple.
"""
pass
def log_train_data(self, collect_result: dict, step: int) -> None:
"""Use writer to log statistics generated during training.
:param collect_result: a dict containing information of data collected in
the training stage, i.e., the return of collector.collect().
:param int step: the timestep at which the collect_result is logged.
"""
pass
def log_update_data(self, update_result: dict, step: int) -> None:
"""Use writer to log statistics generated during updating.
:param update_result: a dict containing information of data collected in
the updating stage, i.e., the return of policy.update().
:param int step: the timestep at which the update_result is logged.
"""
pass
def log_test_data(self, collect_result: dict, step: int) -> None:
"""Use writer to log statistics generated during evaluating.
:param collect_result: a dict containing information of data collected in
evaluating stage, i.e., returns of collector.collect().
:param int step: stands for the timestep the collect_result being logged.
"""
pass
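Since `write()` is the only abstract method, plugging in a different backend is a matter of subclassing. The minimal standalone sketch below mirrors the interface rather than importing tianshou, and `BaseLoggerSketch`/`ListLogger` are invented names for illustration:

```python
from abc import ABC, abstractmethod

class BaseLoggerSketch(ABC):
    """Stand-in mirroring the BaseLogger interface above (illustrative only)."""

    def __init__(self, writer):
        self.writer = writer

    @abstractmethod
    def write(self, key, x, y, **kwargs):
        """Log one (key, x, y) data tuple."""

class ListLogger(BaseLoggerSketch):
    """Hypothetical backend: append data tuples to a plain list."""

    def __init__(self):
        super().__init__(writer=None)
        self.history = []

    def write(self, key, x, y, **kwargs):
        self.history.append((key, x, y))

logger = ListLogger()
logger.write("train/rew", 100, 1.5)
print(logger.history)  # [('train/rew', 100, 1.5)]
```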
class BasicLogger(BaseLogger):
"""A loggger that relies on tensorboard SummaryWriter by default to visualize \
and log statistics.
You can also override write() to use your own writer.
:param SummaryWriter writer: the writer to log data.
:param int train_interval: the log interval in log_train_data(). Default to 1.
:param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data(). Default to 1000.
""" """
_mutex_lock = threading.Lock() def __init__(
_default_key: str self,
_instance: Optional[Dict[str, "SummaryWriter"]] = None writer: SummaryWriter,
train_interval: int = 1,
test_interval: int = 1,
update_interval: int = 1000,
) -> None:
super().__init__(writer)
self.train_interval = train_interval
self.test_interval = test_interval
self.update_interval = update_interval
self.last_log_train_step = -1
self.last_log_test_step = -1
self.last_log_update_step = -1
@classmethod def write(
def get_instance( self,
cls, key: str,
key: Optional[str] = None, x: Union[Number, np.number, np.ndarray],
*args: Any, y: Union[Number, np.number, np.ndarray],
**kwargs: Any, **kwargs: Any,
) -> "SummaryWriter": ) -> None:
"""Get instance of torch.utils.tensorboard.SummaryWriter by key.""" self.writer.add_scalar(key, y, global_step=x)
with SummaryWriter._mutex_lock:
if key is None: def log_train_data(self, collect_result: dict, step: int) -> None:
key = SummaryWriter._default_key """Use writer to log statistics generated during training.
if SummaryWriter._instance is None:
SummaryWriter._instance = {} :param collect_result: a dict containing information of data collected in
SummaryWriter._default_key = key training stage, i.e., returns of collector.collect().
if key not in SummaryWriter._instance.keys(): :param int step: the timestep at which the collect_result is logged.
SummaryWriter._instance[key] = SummaryWriter(*args, **kwargs)
return SummaryWriter._instance[key] .. note::
``collect_result`` will be modified in-place with "rew" and "len" keys.
"""
if collect_result["n/ep"] > 0:
collect_result["rew"] = collect_result["rews"].mean()
collect_result["len"] = collect_result["lens"].mean()
if step - self.last_log_train_step >= self.train_interval:
self.write("train/n/ep", step, collect_result["n/ep"])
self.write("train/rew", step, collect_result["rew"])
self.write("train/len", step, collect_result["len"])
self.last_log_train_step = step
def log_test_data(self, collect_result: dict, step: int) -> None:
"""Use writer to log statistics generated during evaluating.
:param collect_result: a dict containing information of data collected in
evaluating stage, i.e., returns of collector.collect().
:param int step: stands for the timestep the collect_result being logged.
.. note::
``collect_result`` will be modified in-place with "rew", "rew_std", "len",
and "len_std" keys.
"""
assert collect_result["n/ep"] > 0
rews, lens = collect_result["rews"], collect_result["lens"]
rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
if step - self.last_log_test_step >= self.test_interval:
self.write("test/rew", step, rew)
self.write("test/len", step, len_)
self.write("test/rew_std", step, rew_std)
self.write("test/len_std", step, len_std)
self.last_log_test_step = step
def log_update_data(self, update_result: dict, step: int) -> None:
if step - self.last_log_update_step >= self.update_interval:
for k, v in update_result.items():
self.write("train/" + k, step, v) # save in train/
self.last_log_update_step = step
class LazyLogger(BasicLogger):
"""A loggger that does nothing. Used as the placeholder in trainer."""
def __init__(self) -> None:
super().__init__(None) # type: ignore
def write(
self,
key: str,
x: Union[Number, np.number, np.ndarray],
y: Union[Number, np.number, np.ndarray],
**kwargs: Any,
) -> None:
"""The LazyLogger writes nothing."""
pass
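The interval bookkeeping above (`last_log_update_step` plus `update_interval`) is the mechanism behind the log-interval fix this PR targets. The standalone sketch below re-implements the same gating with a stub writer in place of tensorboard; `StubWriter` and `IntervalLogger` are illustrative names, not part of tianshou's API:

```python
class StubWriter:
    """Records (tag, step, value) tuples instead of writing event files."""

    def __init__(self):
        self.records = []

    def add_scalar(self, tag, value, global_step):
        self.records.append((tag, global_step, value))

class IntervalLogger:
    """Standalone re-implementation of BasicLogger's update-interval gating."""

    def __init__(self, writer, update_interval=1000):
        self.writer = writer
        self.update_interval = update_interval
        self.last_log_update_step = -1

    def log_update_data(self, update_result, step):
        # skip logging until at least update_interval steps have elapsed
        if step - self.last_log_update_step >= self.update_interval:
            for k, v in update_result.items():
                self.writer.add_scalar("train/" + k, v, global_step=step)
            self.last_log_update_step = step

writer = StubWriter()
logger = IntervalLogger(writer, update_interval=1000)
for step in (1000, 1500, 2000):
    logger.log_update_data({"loss": 0.1}, step)
print(len(writer.records))  # 2 -- the call at step 1500 was gated out
```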