add logger (#295)

This PR focuses on refactoring the logging mechanism to fix the NaN-reward and log-interval bugs. After these two PRs, the fundamental changes to tianshou/data should hopefully be finished, and we can finally concentrate on building benchmarks for Tianshou.

Things changed:

1. the trainer now accepts a logger (BasicLogger or LazyLogger) instead of a writer;
2. utils.SummaryWriter is removed;
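
For downstream code, the change is mechanical. A minimal migration sketch, assuming only the `BasicLogger` wrapper introduced in this PR:

```python
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger

writer = SummaryWriter('log/dqn')
logger = BasicLogger(writer)  # wrap the SummaryWriter in a logger

# before: result = ts.trainer.offpolicy_trainer(..., writer=writer)
# after:  result = ts.trainer.offpolicy_trainer(..., logger=logger)
```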
This commit is contained in:
ChenDRAG 2021-02-24 14:48:42 +08:00
parent e99e1b0fdd
commit 9b61bc620c
45 changed files with 406 additions and 249 deletions

View File

@ -197,6 +197,7 @@ buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, step_per_collect = 10000, 10
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
logger = ts.utils.BasicLogger(writer)
```
Make environments:
@ -237,7 +238,7 @@ result = ts.trainer.offpolicy_trainer(
train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
writer=writer)
logger=logger)
print(f'Finished training! Use {result["duration"]}')
```

View File

@ -7,3 +7,4 @@ We always welcome contributions to help make Tianshou better. Below are an incom
* Minghao Zhang (`Mehooz <https://github.com/Mehooz>`_)
* Alexis Duburcq (`duburcqa <https://github.com/duburcqa>`_)
* Kaichao You (`youkaichao <https://github.com/youkaichao>`_)
* Huayu Chen (`ChenDRAG <https://github.com/ChenDRAG>`_)

View File

@ -130,7 +130,7 @@ Tianshou provides :func:`~tianshou.trainer.onpolicy_trainer`, :func:`~tianshou.t
train_fn=lambda epoch, env_step: policy.set_eps(0.1),
test_fn=lambda epoch, env_step: policy.set_eps(0.05),
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
writer=None)
logger=None)
print(f'Finished training! Use {result["duration"]}')
The meaning of each parameter is as follows (full description can be found at :func:`~tianshou.trainer.offpolicy_trainer`):
@ -143,15 +143,17 @@ The meaning of each parameter is as follows (full description can be found at :f
* ``train_fn``: A function that receives the current epoch number and step index and performs some operations at the beginning of training in this epoch. For example, the code above means "reset the epsilon to 0.1 in DQN before training".
* ``test_fn``: A function that receives the current epoch number and step index and performs some operations at the beginning of testing in this epoch. For example, the code above means "reset the epsilon to 0.05 in DQN before testing".
* ``stop_fn``: A function that receives the average undiscounted return of the test result and returns a boolean indicating whether the goal has been reached.
* ``writer``: See below.
* ``logger``: See below.
The trainer supports `TensorBoard <https://www.tensorflow.org/tensorboard>`_ for logging. It can be used as:
::
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
writer = SummaryWriter('log/dqn')
logger = BasicLogger(writer)
Pass the writer into the trainer, and the training result will be recorded into the TensorBoard.
Pass the logger into the trainer, and the training result will be recorded into the TensorBoard.
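
The log interval now lives on the logger instead of the trainer. A hedged sketch, using the `train_interval`/`update_interval` keyword arguments that appear in the example scripts changed below (the specific values here are illustrative assumptions):

```python
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger

writer = SummaryWriter('log/dqn')
# write training statistics every 1000 env steps and update statistics
# every 100 gradient steps; kwarg names taken from the scripts in this commit
logger = BasicLogger(writer, train_interval=1000, update_interval=100)
```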
The returned result is a dictionary as follows:
::

View File

@ -176,6 +176,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul
import numpy as np
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
@ -319,11 +320,10 @@ With the above preparation, we are close to the first learned agent. The followi
train_collector.collect(n_step=args.batch_size * args.training_num)
# ======== tensorboard logging setup =========
if not hasattr(args, 'writer'):
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
writer = SummaryWriter(log_path)
else:
writer = args.writer
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
# ======== callback functions used during training =========
@ -359,7 +359,7 @@ With the above preparation, we are close to the first learned agent. The followi
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
writer=writer, test_in_train=False, reward_metric=reward_metric)
logger=logger, test_in_train=False, reward_metric=reward_metric)
agent = policy.policies[args.agent_id - 1]
# let's watch the match!

View File

@ -2,10 +2,12 @@ import os
import torch
import pickle
import pprint
import datetime
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.utils.net.discrete import Actor
@ -39,7 +41,7 @@ def get_args():
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--watch", default=False, action="store_true",
help="watch the play of pre-trained policy only")
parser.add_argument("--log-interval", type=int, default=1000)
parser.add_argument("--log-interval", type=int, default=100)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_DQN_PongNoFrameskip-v4.hdf5",
@ -113,8 +115,13 @@ def test_discrete_bcq(args=get_args()):
# collector
test_collector = Collector(policy, test_envs, exploration_noise=True)
log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
# log
log_path = os.path.join(
args.logdir, args.task, 'bcq',
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=args.log_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -141,7 +148,7 @@ def test_discrete_bcq(args=get_args()):
result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
log_interval=args.log_interval,
)

View File

@ -6,6 +6,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import C51Policy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@ -98,6 +99,8 @@ def test_c51(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -118,7 +121,7 @@ def test_c51(args=get_args()):
else:
eps = args.eps_train_final
policy.set_eps(eps)
writer.add_scalar('train/eps', eps, global_step=env_step)
logger.write('train/eps', env_step, eps)
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
@ -144,7 +147,7 @@ def test_c51(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)
pprint.pprint(result)
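
Note the argument order in the new call: judging from the one-line replacements above, `BasicLogger.write` appears to take the tag, then the step, then the value, unlike `add_scalar`'s value-then-step order. A small sketch:

```python
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger

logger = BasicLogger(SummaryWriter('log/dqn'))
env_step, eps = 10000, 0.05
# old: writer.add_scalar('train/eps', eps, global_step=env_step)
logger.write('train/eps', env_step, eps)  # tag, step, value
```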

View File

@ -6,6 +6,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@ -94,6 +95,8 @@ def test_dqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -114,7 +117,7 @@ def test_dqn(args=get_args()):
else:
eps = args.eps_train_final
policy.set_eps(eps)
writer.add_scalar('train/eps', eps, global_step=env_step)
logger.write('train/eps', env_step, eps)
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
@ -154,7 +157,7 @@ def test_dqn(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)
pprint.pprint(result)

View File

@ -5,6 +5,7 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.policy import QRDQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
@ -96,6 +97,8 @@ def test_qrdqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -116,7 +119,7 @@ def test_qrdqn(args=get_args()):
else:
eps = args.eps_train_final
policy.set_eps(eps)
writer.add_scalar('train/eps', eps, global_step=env_step)
logger.write('train/eps', env_step, eps)
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
@ -142,7 +145,7 @@ def test_qrdqn(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)
pprint.pprint(result)

View File

@ -4,7 +4,9 @@ import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import A2CPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -79,7 +81,9 @@ def test_a2c(args=get_args()):
preprocess_fn=preprocess_fn, exploration_noise=True)
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# log
writer = SummaryWriter(os.path.join(args.logdir, args.task, 'a2c'))
log_path = os.path.join(args.logdir, args.task, 'a2c')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def stop_fn(mean_rewards):
if env.env.spec.reward_threshold:
@ -91,7 +95,7 @@ def test_a2c(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer)
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, logger=logger)
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!

View File

@ -6,11 +6,12 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PPOPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.discrete import Actor, Critic
from tianshou.data import Collector, VectorReplayBuffer
from atari import create_atari_environment, preprocess_fn
@ -84,7 +85,9 @@ def test_ppo(args=get_args()):
preprocess_fn=preprocess_fn, exploration_noise=True)
test_collector = Collector(policy, test_envs, preprocess_fn=preprocess_fn)
# log
writer = SummaryWriter(os.path.join(args.logdir, args.task, 'ppo'))
log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def stop_fn(mean_rewards):
if env.env.spec.reward_threshold:
@ -96,7 +99,8 @@ def test_ppo(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer)
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, logger=logger)
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!

View File

@ -7,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -81,6 +82,7 @@ def test_dqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -106,7 +108,7 @@ def test_dqn(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
update_per_step=args.update_per_step, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':

View File

@ -7,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import SACPolicy
from tianshou.utils import BasicLogger
from tianshou.utils.net.common import Net
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
@ -134,6 +135,7 @@ def test_sac_bipedal(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -146,7 +148,7 @@ def test_sac_bipedal(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
update_per_step=args.update_per_step, test_in_train=False,
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
if __name__ == '__main__':
pprint.pprint(result)

View File

@ -7,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.utils import BasicLogger
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
@ -83,6 +84,7 @@ def test_dqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -102,7 +104,7 @@ def test_dqn(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
update_per_step=args.update_per_step, stop_fn=stop_fn, train_fn=train_fn,
test_fn=test_fn, save_fn=save_fn, writer=writer)
test_fn=test_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':

View File

@ -7,11 +7,12 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import SACPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.exploration import OUNoise
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic
@ -103,6 +104,7 @@ def test_sac(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -115,7 +117,7 @@ def test_sac(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, writer=writer)
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':

View File

@ -2,11 +2,13 @@ import os
import gym
import torch
import pprint
import datetime
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import SACPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -114,8 +116,11 @@ def test_sac(args=get_args()):
exploration_noise=True)
test_collector = Collector(policy, test_envs)
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(
args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, train_interval=args.log_interval)
def watch():
# watch agent's performance
@ -141,8 +146,8 @@ def test_sac(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
update_per_step=args.update_per_step, log_interval=args.log_interval)
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step)
pprint.pprint(result)
watch()

View File

@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@ -6,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DDPGPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -79,7 +81,9 @@ def test_ddpg(args=get_args()):
exploration_noise=True)
test_collector = Collector(policy, test_envs)
# log
writer = SummaryWriter(args.logdir + '/' + 'ddpg')
log_path = os.path.join(args.logdir, args.task, 'ddpg')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
@ -88,7 +92,7 @@ def test_ddpg(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer)
args.batch_size, stop_fn=stop_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)

View File

@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@ -6,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import TD3Policy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.exploration import GaussianNoise
@ -88,7 +90,9 @@ def test_td3(args=get_args()):
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
writer = SummaryWriter(args.logdir + '/' + 'td3')
log_path = os.path.join(args.logdir, args.task, 'td3')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
@ -97,7 +101,7 @@ def test_td3(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer)
args.batch_size, stop_fn=stop_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)

View File

@ -2,12 +2,14 @@ import os
import gym
import torch
import pprint
import datetime
import argparse
import numpy as np
import pybullet_envs
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import SACPolicy
from tianshou.utils import BasicLogger
from tianshou.utils.net.common import Net
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
@ -88,8 +90,10 @@ def test_sac(args=get_args()):
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
log_path = os.path.join(args.logdir, args.task, 'sac', args.run_id)
log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(
args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
writer = SummaryWriter(log_path)
logger = BasicLogger(writer, train_interval=args.log_interval)
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
@ -99,7 +103,7 @@ def test_sac(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, stop_fn=stop_fn,
writer=writer, log_interval=args.log_interval)
logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)

View File

@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@ -6,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import TD3Policy
from tianshou.utils import BasicLogger
from tianshou.utils.net.common import Net
from tianshou.env import SubprocVectorEnv
from tianshou.exploration import GaussianNoise
@ -93,7 +95,9 @@ def test_td3(args=get_args()):
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.buffer_size)
# log
writer = SummaryWriter(args.logdir + '/' + 'td3')
log_path = os.path.join(args.logdir, args.task, 'td3')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
@ -105,7 +109,7 @@ def test_td3(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer)
args.batch_size, stop_fn=stop_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)

View File

@ -2,7 +2,6 @@ import torch
import numpy as np
from tianshou.utils import MovAvg
from tianshou.utils import SummaryWriter
from tianshou.utils.net.common import MLP, Net
from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic
@ -77,25 +76,7 @@ def test_net():
assert list(net(data, act).shape) == [bsz, 1]
def test_summary_writer():
# get first instance by key of `default` or your own key
writer1 = SummaryWriter.get_instance(
key="first", log_dir="log/test_sw/first")
assert writer1.log_dir == "log/test_sw/first"
writer2 = SummaryWriter.get_instance()
assert writer1 is writer2
# create new instance by specify a new key
writer3 = SummaryWriter.get_instance(
key="second", log_dir="log/test_sw/second")
assert writer3.log_dir == "log/test_sw/second"
writer4 = SummaryWriter.get_instance(key="second")
assert writer3 is writer4
assert writer1 is not writer3
assert writer1.log_dir != writer4.log_dir
if __name__ == '__main__':
test_noise()
test_moving_average()
test_net()
test_summary_writer()

View File

@ -7,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DDPGPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -93,6 +94,7 @@ def test_ddpg(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'ddpg')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -105,7 +107,7 @@ def test_ddpg(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, writer=writer)
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)

View File

@ -8,6 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
from tianshou.policy import PPOPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -111,6 +112,7 @@ def test_ppo(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -123,7 +125,7 @@ def test_ppo(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
writer=writer)
logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)

View File

@ -6,6 +6,7 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -102,6 +103,7 @@ def test_sac_with_il(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'sac')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -114,7 +116,7 @@ def test_sac_with_il(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, writer=writer)
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
@ -146,7 +148,7 @@ def test_sac_with_il(args=get_args()):
result = offpolicy_trainer(
il_policy, train_collector, il_test_collector, args.epoch,
args.il_step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)

View File

@ -7,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import TD3Policy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -106,6 +107,7 @@ def test_td3(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'td3')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -118,7 +120,7 @@ def test_td3(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num, args.batch_size,
update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, writer=writer)
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)

View File

@ -6,6 +6,7 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.data import Collector, VectorReplayBuffer
@ -89,6 +90,7 @@ def test_a2c_with_il(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'a2c')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -101,7 +103,7 @@ def test_a2c_with_il(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
writer=writer)
logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
@ -130,7 +132,7 @@ def test_a2c_with_il(args=get_args()):
result = offpolicy_trainer(
il_policy, train_collector, il_test_collector, args.epoch,
args.il_step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)

View File

@ -7,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import C51Policy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -89,6 +90,7 @@ def test_c51(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -115,7 +117,7 @@ def test_c51(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':

View File

@ -8,6 +8,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -91,6 +92,7 @@ def test_dqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'dqn')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -117,7 +119,7 @@ def test_dqn(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])

View File

@ -7,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.common import Recurrent
@ -77,6 +78,7 @@ def test_drqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'drqn')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -96,7 +98,7 @@ def test_drqn(args=get_args()):
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step,
train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn,
save_fn=save_fn, writer=writer)
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':

View File

@ -8,6 +8,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offline_trainer
@ -82,6 +83,7 @@ def test_discrete_bcq(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -92,7 +94,7 @@ def test_discrete_bcq(args=get_args()):
result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])

View File

@ -7,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PGPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -72,6 +73,7 @@ def test_pg(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'pg')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -84,7 +86,7 @@ def test_pg(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
writer=writer)
logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)

View File

@ -7,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PPOPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
@ -98,6 +99,7 @@ def test_ppo(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -110,7 +112,7 @@ def test_ppo(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
writer=writer)
logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)

View File

@ -6,6 +6,7 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.policy import QRDQNPolicy
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
@ -87,6 +88,7 @@ def test_qrdqn(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'qrdqn')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -113,7 +115,7 @@ def test_qrdqn(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step)
assert stop_fn(result['best_reward'])

View File

@ -6,12 +6,13 @@ import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import DiscreteSACPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils.net.discrete import Actor, Critic
from tianshou.data import Collector, VectorReplayBuffer
def get_args():
@ -99,6 +100,7 @@ def test_discrete_sac(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'discrete_sac')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -110,7 +112,7 @@ def test_discrete_sac(args=get_args()):
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':

View File

@ -1,3 +1,4 @@
import os
import gym
import torch
import pprint
@ -6,6 +7,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PSRLPolicy
# from tianshou.utils import BasicLogger
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
@ -66,7 +68,10 @@ def test_psrl(args=get_args()):
exploration_noise=True)
test_collector = Collector(policy, test_envs)
# log
writer = SummaryWriter(args.logdir + '/' + args.task)
log_path = os.path.join(args.logdir, args.task, 'psrl')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
# logger = BasicLogger(writer)
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
@ -75,11 +80,12 @@ def test_psrl(args=get_args()):
return False
train_collector.collect(n_step=args.buffer_size, random=True)
# trainer
# trainer, test it without logger
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, 1, args.test_num, 0,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, writer=writer,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn,
# logger=logger,
test_in_train=False)
if __name__ == '__main__':

View File

@ -7,6 +7,7 @@ from torch.utils.tensorboard import SummaryWriter
from tianshou.env import DummyVectorEnv
from tianshou.data import Collector
from tianshou.policy import RandomPolicy
from tianshou.utils import BasicLogger
from tic_tac_toe_env import TicTacToeEnv
from tic_tac_toe import get_parser, get_agents, train_agent, watch
@ -31,7 +32,8 @@ def gomoku(args=get_args()):
# log
log_path = os.path.join(args.logdir, 'Gomoku', 'dqn')
args.writer = SummaryWriter(log_path)
writer = SummaryWriter(log_path)
args.logger = BasicLogger(writer)
opponent_pool = [agent_opponent]

View File

@ -6,6 +6,7 @@ from copy import deepcopy
from typing import Optional, Tuple
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
@ -131,12 +132,10 @@ def train_agent(
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
if not hasattr(args, 'writer'):
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
writer = SummaryWriter(log_path)
args.writer = writer
else:
writer = args.writer
log_path = os.path.join(args.logdir, 'tic_tac_toe', 'dqn')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
def save_fn(policy):
if hasattr(args, 'model_save_path'):
@ -166,7 +165,7 @@ def train_agent(
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, update_per_step=args.update_per_step,
writer=writer, test_in_train=False, reward_metric=reward_metric)
logger=logger, test_in_train=False, reward_metric=reward_metric)
return result, policy.policies[args.agent_id - 1]

View File

@ -92,8 +92,6 @@ class ReplayBuffer:
("buffer.__getattr__" is customized).
"""
self.__dict__.update(state)
# compatible with version == 0.3.1's HDF5 data format
self._indices = np.arange(self.maxsize)
def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value."""

View File

@ -184,9 +184,9 @@ class Collector(object):
* ``n/ep`` collected number of episodes.
* ``n/st`` collected number of steps.
* ``rews`` list of episode reward over collected episodes.
* ``lens`` list of episode length over collected episodes.
* ``idxs`` list of episode start index in buffer over collected episodes.
* ``rews`` array of episode reward over collected episodes.
* ``lens`` array of episode length over collected episodes.
* ``idxs`` array of episode start index in buffer over collected episodes.
"""
assert not self.env.is_async, "Please use AsyncCollector if using async venv."
if n_step is not None:
@ -379,9 +379,9 @@ class AsyncCollector(Collector):
* ``n/ep`` collected number of episodes.
* ``n/st`` collected number of steps.
* ``rews`` list of episode reward over collected episodes.
* ``lens`` list of episode length over collected episodes.
* ``idxs`` list of episode start index in buffer over collected episodes.
* ``rews`` array of episode reward over collected episodes.
* ``lens`` array of episode length over collected episodes.
* ``idxs`` array of episode start index in buffer over collected episodes.
"""
# collect at least n_step or n_episode
if n_step is not None:
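
For orientation, a sketch of the documented return value; the array contents are made up, and note that the trainer diffs below additionally read scalar aggregates such as `result['rew']` and `result['rew_std']`, which this excerpt does not document:

```python
import numpy as np

# hypothetical result of collector.collect(n_episode=3), per the docstring above
result = {
    "n/ep": 3,                                # collected number of episodes
    "n/st": 487,                              # collected number of steps
    "rews": np.array([200.0, 195.0, 210.0]),  # episode rewards (arrays now, not lists)
    "lens": np.array([160, 155, 172]),        # episode lengths
    "idxs": np.array([0, 160, 315]),          # episode start indices in the buffer
}
```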

View File

@ -4,7 +4,7 @@ import numpy as np
from torch import nn
from numba import njit
from abc import ABC, abstractmethod
from typing import Any, List, Union, Mapping, Optional, Callable
from typing import Any, Dict, Union, Optional, Callable
from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
@ -124,12 +124,10 @@ class BasePolicy(ABC, nn.Module):
return batch
@abstractmethod
def learn(
self, batch: Batch, **kwargs: Any
) -> Mapping[str, Union[float, List[float]]]:
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
"""Update policy with a given batch of data.
:return: A dict which includes loss and its corresponding label.
:return: A dict, including the data needed to be logged (e.g., loss).
.. note::
@ -162,18 +160,20 @@ class BasePolicy(ABC, nn.Module):
def update(
self, sample_size: int, buffer: Optional[ReplayBuffer], **kwargs: Any
) -> Mapping[str, Union[float, List[float]]]:
) -> Dict[str, Any]:
"""Update the policy network and replay buffer.
It includes 3 function steps: process_fn, learn, and post_process_fn.
In addition, this function will change the value of ``self.updating``:
it will be False before this function and will be True when executing
:meth:`update`. Please refer to :ref:`policy_state` for more detailed
explanation.
It includes 3 function steps: process_fn, learn, and post_process_fn. In
addition, this function will change the value of ``self.updating``: it will be
False before this function and will be True when executing :meth:`update`.
Please refer to :ref:`policy_state` for more detailed explanation.
:param int sample_size: 0 means it will extract all the data from the
buffer, otherwise it will sample a batch with given sample_size.
:param int sample_size: 0 means it will extract all the data from the buffer,
otherwise it will sample a batch with given sample_size.
:param ReplayBuffer buffer: the corresponding replay buffer.
:return: A dict, including the data needed to be logged (e.g., loss) from
``policy.learn()``.
"""
if buffer is None:
return {}
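
To make the new contract concrete, a minimal sketch (not from this commit) of a `learn()` that satisfies it: return plain scalars keyed by name so the trainer can hand them to `logger.log_update_data`; `compute_loss` is a hypothetical helper, not a tianshou API:

```python
from typing import Any, Dict
from tianshou.data import Batch

def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
    loss = self.compute_loss(batch)  # hypothetical helper
    self.optim.zero_grad()
    loss.backward()
    self.optim.step()
    # the returned dict is exactly what gets logged (e.g., loss curves)
    return {"loss": loss.item()}
```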

View File

@ -2,11 +2,10 @@ import time
import tqdm
import numpy as np
from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, Union, Callable, Optional
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.data import Collector, ReplayBuffer
from tianshou.trainer import test_episode, gather_info
@ -23,8 +22,7 @@ def offline_trainer(
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: int = 1,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
) -> Dict[str, Union[float, str]]:
"""A wrapper for offline trainer procedure.
@ -55,9 +53,8 @@ def offline_trainer(
result to monitor training in the multi-agent RL setting. This function
specifies what is the desired metric, e.g., the reward of agent 1 or the
average reward over all agents.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter;
if None is given, it will not write logs to TensorBoard. Default to None.
:param int log_interval: the log interval of the writer. Default to 1.
:param BaseLogger logger: A logger that logs statistics during updating/testing.
Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
:return: See :func:`~tianshou.trainer.gather_info`.
@ -67,10 +64,9 @@ def offline_trainer(
start_time = time.time()
test_collector.reset_stat()
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
writer, gradient_step, reward_metric)
logger, gradient_step, reward_metric)
best_epoch = 0
best_reward = test_result["rews"].mean()
best_reward_std = test_result["rews"].std()
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
policy.train()
with tqdm.trange(
@ -82,27 +78,23 @@ def offline_trainer(
data = {"gradient_step": str(gradient_step)}
for k in losses.keys():
stat[k].add(losses[k])
data[k] = f"{stat[k].get():.6f}"
if writer and gradient_step % log_interval == 0:
writer.add_scalar(
"train/" + k, stat[k].get(),
global_step=gradient_step)
losses[k] = stat[k].get()
data[k] = f"{losses[k]:.6f}"
logger.log_update_data(losses, gradient_step)
t.set_postfix(**data)
# test
test_result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, writer, gradient_step,
reward_metric)
if best_epoch == -1 or best_reward < test_result["rews"].mean():
best_reward = test_result["rews"].mean()
best_reward_std = test_result['rews'].std()
test_result = test_episode(
policy, test_collector, test_fn, epoch, episode_per_test,
logger, gradient_step, reward_metric)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch == -1 or best_reward < rew:
best_reward, best_reward_std = rew, rew_std
best_epoch = epoch
if save_fn:
save_fn(policy)
if verbose:
print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± "
f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
f"{best_reward_std:.6f} in #{best_epoch}")
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, None, test_collector,
best_reward, best_reward_std)
return gather_info(start_time, None, test_collector, best_reward, best_reward_std)

View File

@ -2,12 +2,11 @@ import time
import tqdm
import numpy as np
from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, Union, Callable, Optional
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.trainer import test_episode, gather_info
@ -26,8 +25,7 @@ def offpolicy_trainer(
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: int = 1,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
@ -70,25 +68,24 @@ def offpolicy_trainer(
result to monitor training in the multi-agent RL setting. This function
specifies what is the desired metric, e.g., the reward of agent 1 or the
average reward over all agents.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter;
if None is given, it will not write logs to TensorBoard. Default to None.
:param int log_interval: the log interval of the writer. Default to 1.
:param BaseLogger logger: A logger that logs statistics during
training/testing/updating. Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
:param bool test_in_train: whether to test in the training phase. Default to True.
:return: See :func:`~tianshou.trainer.gather_info`.
"""
env_step, gradient_step = 0, 0
last_rew, last_len = 0.0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
writer, env_step, reward_metric)
logger, env_step, reward_metric)
best_epoch = 0
best_reward = test_result["rews"].mean()
best_reward_std = test_result["rews"].std()
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
# train
policy.train()
@ -99,34 +96,32 @@ def offpolicy_trainer(
if train_fn:
train_fn(epoch, env_step)
result = train_collector.collect(n_step=step_per_collect)
if len(result["rews"]) > 0 and reward_metric:
if result["n/ep"] > 0 and reward_metric:
result["rews"] = reward_metric(result["rews"])
env_step += int(result["n/st"])
t.update(result["n/st"])
logger.log_train_data(result, env_step)
last_rew = result['rew'] if 'rew' in result else last_rew
last_len = result['len'] if 'len' in result else last_len
data = {
"env_step": str(env_step),
"rew": f"{result['rews'].mean():.2f}",
"len": str(result["lens"].mean()),
"rew": f"{last_rew:.2f}",
"len": str(last_len),
"n/ep": str(int(result["n/ep"])),
"n/st": str(int(result["n/st"])),
}
if result["n/ep"] > 0:
if writer and env_step % log_interval == 0:
writer.add_scalar(
"train/rew", result['rews'].mean(), global_step=env_step)
writer.add_scalar(
"train/len", result['lens'].mean(), global_step=env_step)
if test_in_train and stop_fn and stop_fn(result["rews"].mean()):
if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
policy, test_collector, test_fn,
epoch, episode_per_test, writer, env_step)
if stop_fn(test_result["rews"].mean()):
epoch, episode_per_test, logger, env_step)
if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
test_result["rews"].mean(), test_result["rews"].std())
test_result["rew"], test_result["rew_std"])
else:
policy.train()
for i in range(round(update_per_step * result["n/st"])):
@ -134,26 +129,24 @@ def offpolicy_trainer(
losses = policy.update(batch_size, train_collector.buffer)
for k in losses.keys():
stat[k].add(losses[k])
data[k] = f"{stat[k].get():.6f}"
if writer and gradient_step % log_interval == 0:
writer.add_scalar(
k, stat[k].get(), global_step=gradient_step)
losses[k] = stat[k].get()
data[k] = f"{losses[k]:.6f}"
logger.log_update_data(losses, gradient_step)
t.set_postfix(**data)
if t.n <= t.total:
t.update()
# test
test_result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, writer, env_step, reward_metric)
if best_epoch == -1 or best_reward < test_result["rews"].mean():
best_reward = test_result["rews"].mean()
best_reward_std = test_result['rews'].std()
episode_per_test, logger, env_step, reward_metric)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch == -1 or best_reward < rew:
best_reward, best_reward_std = rew, rew_std
best_epoch = epoch
if save_fn:
save_fn(policy)
if verbose:
print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± "
f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
f"{best_reward_std:.6f} in #{best_epoch}")
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, train_collector, test_collector,

View File

@ -2,12 +2,11 @@ import time
import tqdm
import numpy as np
from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
from typing import Dict, Union, Callable, Optional
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.trainer import test_episode, gather_info
@ -27,8 +26,7 @@ def onpolicy_trainer(
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
writer: Optional[SummaryWriter] = None,
log_interval: int = 1,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
test_in_train: bool = True,
) -> Dict[str, Union[float, str]]:
@ -72,9 +70,8 @@ def onpolicy_trainer(
result to monitor training in the multi-agent RL setting. This function
specifies what is the desired metric, e.g., the reward of agent 1 or the
average reward over all agents.
:param torch.utils.tensorboard.SummaryWriter writer: a TensorBoard SummaryWriter;
if None is given, it will not write logs to TensorBoard. Default to None.
:param int log_interval: the log interval of the writer. Default to 1.
:param BaseLogger logger: A logger that logs statistics during
training/testing/updating. Default to a logger that doesn't log anything.
:param bool verbose: whether to print the information. Default to True.
:param bool test_in_train: whether to test in the training phase. Default to True.
@ -85,16 +82,16 @@ def onpolicy_trainer(
Only either one of step_per_collect and episode_per_collect can be specified.
"""
env_step, gradient_step = 0, 0
last_rew, last_len = 0.0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
writer, env_step, reward_metric)
logger, env_step, reward_metric)
best_epoch = 0
best_reward = test_result["rews"].mean()
best_reward_std = test_result["rews"].std()
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
# train
policy.train()
@ -110,29 +107,27 @@ def onpolicy_trainer(
result["rews"] = reward_metric(result["rews"])
env_step += int(result["n/st"])
t.update(result["n/st"])
logger.log_train_data(result, env_step)
last_rew = result['rew'] if 'rew' in result else last_rew
last_len = result['len'] if 'len' in result else last_len
data = {
"env_step": str(env_step),
"rew": f"{result['rews'].mean():.2f}",
"len": str(int(result["lens"].mean())),
"rew": f"{last_rew:.2f}",
"len": str(last_len),
"n/ep": str(int(result["n/ep"])),
"n/st": str(int(result["n/st"])),
}
if writer and env_step % log_interval == 0:
writer.add_scalar(
"train/rew", result['rews'].mean(), global_step=env_step)
writer.add_scalar(
"train/len", result['lens'].mean(), global_step=env_step)
if test_in_train and stop_fn and stop_fn(result["rews"].mean()):
if test_in_train and stop_fn and stop_fn(result["rew"]):
test_result = test_episode(
policy, test_collector, test_fn,
epoch, episode_per_test, writer, env_step)
if stop_fn(test_result["rews"].mean()):
epoch, episode_per_test, logger, env_step)
if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
test_result["rews"].mean(), test_result["rews"].std())
test_result["rew"], test_result["rew_std"])
else:
policy.train()
losses = policy.update(
@ -144,26 +139,24 @@ def onpolicy_trainer(
gradient_step += step
for k in losses.keys():
stat[k].add(losses[k])
data[k] = f"{stat[k].get():.6f}"
if writer and gradient_step % log_interval == 0:
writer.add_scalar(
k, stat[k].get(), global_step=gradient_step)
losses[k] = stat[k].get()
data[k] = f"{losses[k]:.6f}"
logger.log_update_data(losses, gradient_step)
t.set_postfix(**data)
if t.n <= t.total:
t.update()
# test
test_result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, writer, env_step)
if best_epoch == -1 or best_reward < test_result["rews"].mean():
best_reward = test_result["rews"].mean()
best_reward_std = test_result['rews'].std()
episode_per_test, logger, env_step)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch == -1 or best_reward < rew:
best_reward, best_reward_std = rew, rew_std
best_epoch = epoch
if save_fn:
save_fn(policy)
if verbose:
print(f"Epoch #{epoch}: test_reward: {test_result['rews'].mean():.6f} ± "
f"{test_result['rews'].std():.6f}, best_reward: {best_reward:.6f} ± "
f"{best_reward_std:.6f} in #{best_epoch}")
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, train_collector, test_collector,
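In the update loop above, each loss key is smoothed with `MovAvg` before `logger.log_update_data` writes it; a quick sketch of that smoothing (loss values are made up):

```python
from tianshou.utils import MovAvg

stat = MovAvg()
for loss in (0.9, 0.7, 0.5):  # illustrative per-step losses
    stat.add(loss)
print(stat.get())  # 0.7: the running mean is what gets logged
```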

View File

@ -1,10 +1,10 @@
import time
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from typing import Any, Dict, Union, Callable, Optional
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import BaseLogger
def test_episode(
@ -13,7 +13,7 @@ def test_episode(
test_fn: Optional[Callable[[int, Optional[int]], None]],
epoch: int,
n_episode: int,
writer: Optional[SummaryWriter] = None,
logger: Optional[BaseLogger] = None,
global_step: Optional[int] = None,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> Dict[str, Any]:
@ -26,12 +26,8 @@ def test_episode(
result = collector.collect(n_episode=n_episode)
if reward_metric:
result["rews"] = reward_metric(result["rews"])
if writer is not None and global_step is not None:
rews, lens = result["rews"], result["lens"]
writer.add_scalar("test/rew", rews.mean(), global_step=global_step)
writer.add_scalar("test/rew_std", rews.std(), global_step=global_step)
writer.add_scalar("test/len", lens.mean(), global_step=global_step)
writer.add_scalar("test/len_std", lens.std(), global_step=global_step)
if logger and global_step is not None:
logger.log_test_data(result, global_step)
return result

View File

@ -1,9 +1,11 @@
from tianshou.utils.config import tqdm_config
from tianshou.utils.moving_average import MovAvg
from tianshou.utils.log_tools import SummaryWriter
from tianshou.utils.log_tools import BasicLogger, LazyLogger, BaseLogger
__all__ = [
"MovAvg",
"tqdm_config",
"SummaryWriter",
"BaseLogger",
"BasicLogger",
"LazyLogger",
]
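A one-line sanity check of the refactored public surface:

```python
from tianshou.utils import BaseLogger, BasicLogger, LazyLogger, MovAvg, tqdm_config
```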

View File

@ -1,47 +1,159 @@
import threading
from torch.utils import tensorboard
from typing import Any, Dict, Optional
import numpy as np
from numbers import Number
from typing import Any, Union
from abc import ABC, abstractmethod
from torch.utils.tensorboard import SummaryWriter
class SummaryWriter(tensorboard.SummaryWriter):
"""A more convenient Summary Writer(`tensorboard.SummaryWriter`).
class BaseLogger(ABC):
"""The base class for any logger which is compatible with trainer."""
You can get the same instance of summary writer everywhere after you
created one.
::
def __init__(self, writer: Any) -> None:
super().__init__()
self.writer = writer
>>> writer1 = SummaryWriter.get_instance(
key="first", log_dir="log/test_sw/first")
>>> writer2 = SummaryWriter.get_instance()
>>> writer1 is writer2
True
>>> writer4 = SummaryWriter.get_instance(
key="second", log_dir="log/test_sw/second")
>>> writer5 = SummaryWriter.get_instance(key="second")
>>> writer1 is not writer4
True
>>> writer4 is writer5
True
@abstractmethod
def write(
self,
key: str,
x: Union[Number, np.number, np.ndarray],
y: Union[Number, np.number, np.ndarray],
**kwargs: Any,
) -> None:
"""Specify how the writer is used to log data.
:param key: the namespace to which the input data tuple belongs.
:param x: stands for the abscissa of the input data tuple.
:param y: stands for the ordinate of the input data tuple.
"""
pass
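Since `write()` is the only abstract hook, a custom backend just overrides it. A minimal sketch (this `PrintLogger` is hypothetical, not part of the PR; the base `log_*` methods stay no-ops, so only direct `write` calls print):

```python
from numbers import Number
from typing import Any, Union

import numpy as np

from tianshou.utils import BaseLogger


class PrintLogger(BaseLogger):
    """Hypothetical logger that prints scalars instead of using TensorBoard."""

    def write(
        self,
        key: str,
        x: Union[Number, np.number, np.ndarray],
        y: Union[Number, np.number, np.ndarray],
        **kwargs: Any,
    ) -> None:
        print(f"{key} @ {x}: {y}")


logger = PrintLogger(writer=None)      # BaseLogger only stores the writer
logger.write("train/rew", 1000, 3.14)  # -> train/rew @ 1000: 3.14
```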
def log_train_data(self, collect_result: dict, step: int) -> None:
"""Use writer to log statistics generated during training.
:param collect_result: a dict containing information about the data
collected in the training stage, i.e., the return of collector.collect().
:param int step: the timestep at which the collect_result is being logged.
"""
pass
def log_update_data(self, update_result: dict, step: int) -> None:
"""Use writer to log statistics generated during updating.
:param update_result: a dict containing information about the data
collected in the updating stage, i.e., the return of policy.update().
:param int step: the timestep at which the update_result is being logged.
"""
pass
def log_test_data(self, collect_result: dict, step: int) -> None:
"""Use writer to log statistics generated during evaluating.
:param collect_result: a dict containing information of data collected in
evaluating stage, i.e., returns of collector.collect().
:param int step: stands for the timestep the collect_result being logged.
"""
pass
class BasicLogger(BaseLogger):
"""A loggger that relies on tensorboard SummaryWriter by default to visualize \
and log statistics.
You can also rewrite write() func to use your own writer.
:param SummaryWriter writer: the writer to log data.
:param int train_interval: the log interval in log_train_data(). Default to 1.
:param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data(). Default to 1000.
"""
_mutex_lock = threading.Lock()
_default_key: str
_instance: Optional[Dict[str, "SummaryWriter"]] = None
def __init__(
self,
writer: SummaryWriter,
train_interval: int = 1,
test_interval: int = 1,
update_interval: int = 1000,
) -> None:
super().__init__(writer)
self.train_interval = train_interval
self.test_interval = test_interval
self.update_interval = update_interval
self.last_log_train_step = -1
self.last_log_test_step = -1
self.last_log_update_step = -1
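The three intervals throttle how often each `log_*` call actually writes. A construction sketch with illustrative values:

```python
from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import BasicLogger

writer = SummaryWriter('log/experiment')  # illustrative path
# write train stats at most every 100 env steps, losses every 500 grad steps
logger = BasicLogger(writer, train_interval=100, update_interval=500)
```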
@classmethod
def get_instance(
cls,
key: Optional[str] = None,
*args: Any,
def write(
self,
key: str,
x: Union[Number, np.number, np.ndarray],
y: Union[Number, np.number, np.ndarray],
**kwargs: Any,
) -> "SummaryWriter":
"""Get instance of torch.utils.tensorboard.SummaryWriter by key."""
with SummaryWriter._mutex_lock:
if key is None:
key = SummaryWriter._default_key
if SummaryWriter._instance is None:
SummaryWriter._instance = {}
SummaryWriter._default_key = key
if key not in SummaryWriter._instance.keys():
SummaryWriter._instance[key] = SummaryWriter(*args, **kwargs)
return SummaryWriter._instance[key]
) -> None:
self.writer.add_scalar(key, y, global_step=x)
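`write()` keeps SummaryWriter semantics with the arguments reordered; continuing the sketch above, these two calls record the same scalar:

```python
env_step, mean_rew = 1000, 3.14  # illustrative values
logger.write("train/rew", env_step, mean_rew)
writer.add_scalar("train/rew", mean_rew, global_step=env_step)
```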
def log_train_data(self, collect_result: dict, step: int) -> None:
"""Use writer to log statistics generated during training.
:param collect_result: a dict containing information about the data
collected in the training stage, i.e., the return of collector.collect().
:param int step: the timestep at which the collect_result is being logged.
.. note::
``collect_result`` will be modified in-place with "rew" and "len" keys.
"""
if collect_result["n/ep"] > 0:
collect_result["rew"] = collect_result["rews"].mean()
collect_result["len"] = collect_result["lens"].mean()
if step - self.last_log_train_step >= self.train_interval:
self.write("train/n/ep", step, collect_result["n/ep"])
self.write("train/rew", step, collect_result["rew"])
self.write("train/len", step, collect_result["len"])
self.last_log_train_step = step
def log_test_data(self, collect_result: dict, step: int) -> None:
"""Use writer to log statistics generated during evaluating.
:param collect_result: a dict containing information of data collected in
evaluating stage, i.e., returns of collector.collect().
:param int step: stands for the timestep the collect_result being logged.
.. note::
``collect_result`` will be modified in-place with "rew", "rew_std", "len",
and "len_std" keys.
"""
assert collect_result["n/ep"] > 0
rews, lens = collect_result["rews"], collect_result["lens"]
rew, rew_std, len_, len_std = rews.mean(), rews.std(), lens.mean(), lens.std()
collect_result.update(rew=rew, rew_std=rew_std, len=len_, len_std=len_std)
if step - self.last_log_test_step >= self.test_interval:
self.write("test/rew", step, rew)
self.write("test/len", step, len_)
self.write("test/rew_std", step, rew_std)
self.write("test/len_std", step, len_std)
self.last_log_test_step = step
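Note the in-place update: after `log_test_data` returns, the caller can read the scalar summaries from the same dict, which is exactly what the trainer now does. A sketch with a made-up collect result (reusing the `logger` from the sketch above):

```python
import numpy as np

result = {"n/ep": 4,
          "rews": np.array([1.0, 2.0, 3.0, 4.0]),
          "lens": np.array([100, 120, 140, 152])}
logger.log_test_data(result, step=2000)
print(result["rew"], result["rew_std"])  # 2.5 1.118...
```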
def log_update_data(self, update_result: dict, step: int) -> None:
if step - self.last_log_update_step >= self.update_interval:
for k, v in update_result.items():
self.write("train/" + k, step, v) # save in train/
self.last_log_update_step = step
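Losses therefore land in the same `train/` namespace as the collect statistics, e.g.:

```python
# writes scalars "train/loss" and "train/grad_norm" (illustrative keys)
logger.log_update_data({"loss": 0.5, "grad_norm": 1.2}, step=3000)
```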
class LazyLogger(BasicLogger):
"""A loggger that does nothing. Used as the placeholder in trainer."""
def __init__(self) -> None:
super().__init__(None) # type: ignore
def write(
self,
key: str,
x: Union[Number, np.number, np.ndarray],
y: Union[Number, np.number, np.ndarray],
**kwargs: Any,
) -> None:
"""The LazyLogger writes nothing."""
pass