Make trainer resumable (#350)

- specify tensorboard >= 2.5.0
- add `save_checkpoint_fn` and `resume_from_log` in trainer

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
Authored by Ark on 2021-05-06 08:53:53 +08:00, committed by GitHub
parent f4e05d585a
commit 84f58636eb
24 changed files with 308 additions and 77 deletions

View File

@ -30,6 +30,34 @@ Customize Training Process
See :ref:`customized_trainer`.

.. _resume_training:

Resume Training Process
-----------------------

This is related to `Issue 349 <https://github.com/thu-ml/tianshou/issues/349>`_.

To resume a training process from an existing checkpoint, set things up as follows during training:

1. Write a ``save_checkpoint_fn`` that saves everything needed to resume (policy, optim, buffer) and pass it to the trainer;
2. Use ``BasicLogger``, which wraps a tensorboard ``SummaryWriter``;
3. To adjust the save frequency, specify ``save_interval`` when initializing ``BasicLogger``.

Then, to successfully resume from a checkpoint:

1. Load everything needed (policy, optim, buffer) **before trainer initialization**;
2. Pass ``resume_from_log=True`` to the trainer.

We provide an example to show how these steps work: check out `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_ or `test_il_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_il_bcq.py>`_ by running

.. code-block:: console

    $ python3 test/discrete/test_c51.py           # train some epochs
    $ python3 test/discrete/test_c51.py --resume  # restore from the existing log and continue training

To correctly render the data (including several tfevent files), we highly recommend using ``tensorboard >= 2.5.0`` (see `here <https://github.com/thu-ml/tianshou/pull/350#issuecomment-829123378>`_ for the reason). Otherwise, overlapping log entries may appear and you will have to handle them manually. A condensed sketch of the save/resume pieces is shown below.
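The following sketch is adapted from the ``test_c51.py`` changes in this pull request; it assumes ``policy``, ``optim``, ``train_collector``, ``test_collector``, ``logger``, ``args`` and ``log_path`` are already built as in that script.

.. code-block:: python

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # save whatever is needed to resume: model, optimizer, replay buffer
        torch.save({'model': policy.state_dict(), 'optim': optim.state_dict()},
                   os.path.join(log_path, 'checkpoint.pth'))
        with open(os.path.join(log_path, 'train_buffer.pkl'), 'wb') as f:
            pickle.dump(train_collector.buffer, f)

    if args.resume:  # restore everything *before* calling the trainer
        checkpoint = torch.load(os.path.join(log_path, 'checkpoint.pth'),
                                map_location=args.device)
        policy.load_state_dict(checkpoint['model'])
        policy.optim.load_state_dict(checkpoint['optim'])
        with open(os.path.join(log_path, 'train_buffer.pkl'), 'rb') as f:
            train_collector.buffer = pickle.load(f)

    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
        args.step_per_collect, args.test_num, args.batch_size,
        logger=logger, resume_from_log=args.resume,
        save_checkpoint_fn=save_checkpoint_fn)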
.. _parallel_sampling:

Parallel Sampling

View File

@ -85,8 +85,7 @@ def test_discrete_bcq(args=get_args()):
feature_net, args.action_shape, device=args.device, feature_net, args.action_shape, device=args.device,
hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device) hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
optim = torch.optim.Adam( optim = torch.optim.Adam(
set(policy_net.parameters()).union(imitation_net.parameters()), list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr)
lr=args.lr)
# define policy # define policy
policy = DiscreteBCQPolicy( policy = DiscreteBCQPolicy(
policy_net, imitation_net, optim, args.gamma, args.n_step, policy_net, imitation_net, optim, args.gamma, args.n_step,

View File

@ -101,7 +101,7 @@ def test_a2c(args=get_args()):
torch.nn.init.zeros_(m.bias) torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data) m.weight.data.copy_(0.01 * m.weight.data)
optim = torch.optim.RMSprop(set(actor.parameters()).union(critic.parameters()), optim = torch.optim.RMSprop(list(actor.parameters()) + list(critic.parameters()),
lr=args.lr, eps=1e-5, alpha=0.99) lr=args.lr, eps=1e-5, alpha=0.99)
lr_scheduler = None lr_scheduler = None

View File

@ -106,8 +106,8 @@ def test_ppo(args=get_args()):
torch.nn.init.zeros_(m.bias) torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data) m.weight.data.copy_(0.01 * m.weight.data)
optim = torch.optim.Adam(set( optim = torch.optim.Adam(
actor.parameters()).union(critic.parameters()), lr=args.lr) list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
lr_scheduler = None lr_scheduler = None
if args.lr_decay: if args.lr_decay:

View File

@ -48,7 +48,7 @@ setup(
"gym>=0.15.4", "gym>=0.15.4",
"tqdm", "tqdm",
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793 "numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard", "tensorboard>=2.5.0",
"torch>=1.4.0", "torch>=1.4.0",
"numba>=0.51.0", "numba>=0.51.0",
"h5py>=2.10.0", # to match tensorflow's minimal requirements "h5py>=2.10.0", # to match tensorflow's minimal requirements

View File

@ -105,6 +105,7 @@ def test_ddpg(args=get_args()):
update_per_step=args.update_per_step, stop_fn=stop_fn, update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, logger=logger) save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!

View File

@ -80,8 +80,7 @@ def test_npg(args=get_args()):
if isinstance(m, torch.nn.Linear): if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight) torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias) torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set( optim = torch.optim.Adam(critic.parameters(), lr=args.lr)
actor.parameters()).union(critic.parameters()), lr=args.lr)
# replace DiagGuassian with Independent(Normal) which is equivalent # replace DiagGuassian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward # pass *logits to be consistent with policy.forward

View File

@ -47,6 +47,8 @@ def get_args():
parser.add_argument('--value-clip', type=int, default=1) parser.add_argument('--value-clip', type=int, default=1)
parser.add_argument('--norm-adv', type=int, default=1) parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--recompute-adv', type=int, default=0) parser.add_argument('--recompute-adv', type=int, default=0)
parser.add_argument('--resume', action="store_true")
parser.add_argument("--save-interval", type=int, default=4)
args = parser.parse_known_args()[0] args = parser.parse_known_args()[0]
return args return args
@ -83,8 +85,8 @@ def test_ppo(args=get_args()):
if isinstance(m, torch.nn.Linear): if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight) torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias) torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set( optim = torch.optim.Adam(
actor.parameters()).union(critic.parameters()), lr=args.lr) list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
# replace DiagGuassian with Independent(Normal) which is equivalent # replace DiagGuassian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward # pass *logits to be consistent with policy.forward
@ -114,7 +116,7 @@ def test_ppo(args=get_args()):
# log # log
log_path = os.path.join(args.logdir, args.task, 'ppo') log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)
logger = BasicLogger(writer) logger = BasicLogger(writer, save_interval=args.save_interval)
def save_fn(policy): def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -122,13 +124,34 @@ def test_ppo(args=get_args()):
def stop_fn(mean_rewards): def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold return mean_rewards >= env.spec.reward_threshold
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))
if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
# trainer # trainer
result = onpolicy_trainer( result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch, policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger) logger=logger, resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
@ -140,5 +163,10 @@ def test_ppo(args=get_args()):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}") print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_ppo_resume(args=get_args()):
args.resume = True
test_ppo(args)
if __name__ == '__main__': if __name__ == '__main__':
test_ppo() test_ppo()

View File

@ -124,6 +124,7 @@ def test_sac_with_il(args=get_args()):
update_per_step=args.update_per_step, stop_fn=stop_fn, update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, logger=logger) save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!

View File

@ -119,6 +119,7 @@ def test_td3(args=get_args()):
update_per_step=args.update_per_step, stop_fn=stop_fn, update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, logger=logger) save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!

View File

@ -27,7 +27,8 @@ def get_args():
parser.add_argument('--epoch', type=int, default=5) parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=50000) parser.add_argument('--step-per-epoch', type=int, default=50000)
parser.add_argument('--step-per-collect', type=int, default=2048) parser.add_argument('--step-per-collect', type=int, default=2048)
parser.add_argument('--repeat-per-collect', type=int, default=1) parser.add_argument('--repeat-per-collect', type=int,
default=2) # theoretically it should be 1
parser.add_argument('--batch-size', type=int, default=99999) parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=16) parser.add_argument('--training-num', type=int, default=16)
@ -82,8 +83,7 @@ def test_trpo(args=get_args()):
if isinstance(m, torch.nn.Linear): if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight) torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias) torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set( optim = torch.optim.Adam(critic.parameters(), lr=args.lr)
actor.parameters()).union(critic.parameters()), lr=args.lr)
# replace DiagGuassian with Independent(Normal) which is equivalent # replace DiagGuassian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward # pass *logits to be consistent with policy.forward

View File

@ -74,8 +74,8 @@ def test_a2c_with_il(args=get_args()):
device=args.device) device=args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device)
critic = Critic(net, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device)
optim = torch.optim.Adam(set( optim = torch.optim.Adam(
actor.parameters()).union(critic.parameters()), lr=args.lr) list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical dist = torch.distributions.Categorical
policy = A2CPolicy( policy = A2CPolicy(
actor, critic, optim, dist, actor, critic, optim, dist,
@ -106,6 +106,7 @@ def test_a2c_with_il(args=get_args()):
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger) logger=logger)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
@ -135,6 +136,7 @@ def test_a2c_with_il(args=get_args()):
args.il_step_per_epoch, args.step_per_collect, args.test_num, args.il_step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger) args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!

View File

@ -1,6 +1,7 @@
import os import os
import gym import gym
import torch import torch
import pickle
import pprint import pprint
import argparse import argparse
import numpy as np import numpy as np
@ -43,9 +44,11 @@ def get_args():
action="store_true", default=False) action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6) parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4) parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument('--resume', action="store_true")
parser.add_argument( parser.add_argument(
'--device', type=str, '--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu') default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument("--save-interval", type=int, default=4)
args = parser.parse_known_args()[0] args = parser.parse_known_args()[0]
return args return args
@ -90,7 +93,7 @@ def test_c51(args=get_args()):
# log # log
log_path = os.path.join(args.logdir, args.task, 'c51') log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)
logger = BasicLogger(writer) logger = BasicLogger(writer, save_interval=args.save_interval)
def save_fn(policy): def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -112,14 +115,42 @@ def test_c51(args=get_args()):
def test_fn(epoch, env_step): def test_fn(epoch, env_step):
policy.set_eps(args.eps_test) policy.set_eps(args.eps_test)
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))
pickle.dump(train_collector.buffer,
open(os.path.join(log_path, 'train_buffer.pkl'), "wb"))
if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
policy.optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
buffer_path = os.path.join(log_path, 'train_buffer.pkl')
if os.path.exists(buffer_path):
train_collector.buffer = pickle.load(open(buffer_path, "rb"))
print("Successfully restore buffer.")
else:
print("Fail to restore buffer.")
# trainer # trainer
result = offpolicy_trainer( result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch, policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num, args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger) test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
@ -132,6 +163,11 @@ def test_c51(args=get_args()):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}") print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_c51_resume(args=get_args()):
args.resume = True
test_c51(args)
def test_pc51(args=get_args()): def test_pc51(args=get_args()):
args.prioritized_replay = True args.prioritized_replay = True
args.gamma = .95 args.gamma = .95

View File

@ -120,7 +120,6 @@ def test_dqn(args=get_args()):
args.step_per_epoch, args.step_per_collect, args.test_num, args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn, args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger) test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -99,8 +99,8 @@ def test_drqn(args=get_args()):
args.batch_size, update_per_step=args.update_per_step, args.batch_size, update_per_step=args.update_per_step,
train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn, train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn,
save_fn=save_fn, logger=logger) save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!

View File

@ -42,6 +42,8 @@ def get_args():
"--device", type=str, "--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu", default="cuda" if torch.cuda.is_available() else "cpu",
) )
parser.add_argument("--resume", action="store_true")
parser.add_argument("--save-interval", type=int, default=4)
args = parser.parse_known_args()[0] args = parser.parse_known_args()[0]
return args return args
@ -67,7 +69,7 @@ def test_discrete_bcq(args=get_args()):
args.state_shape, args.action_shape, args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device).to(args.device) hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
optim = torch.optim.Adam( optim = torch.optim.Adam(
set(policy_net.parameters()).union(imitation_net.parameters()), list(policy_net.parameters()) + list(imitation_net.parameters()),
lr=args.lr) lr=args.lr)
policy = DiscreteBCQPolicy( policy = DiscreteBCQPolicy(
@ -85,7 +87,7 @@ def test_discrete_bcq(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'discrete_bcq') log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)
logger = BasicLogger(writer) logger = BasicLogger(writer, save_interval=args.save_interval)
def save_fn(policy): def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth')) torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -93,11 +95,30 @@ def test_discrete_bcq(args=get_args()):
def stop_fn(mean_rewards): def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold return mean_rewards >= env.spec.reward_threshold
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))
if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
result = offline_trainer( result = offline_trainer(
policy, buffer, test_collector, policy, buffer, test_collector,
args.epoch, args.update_per_epoch, args.test_num, args.batch_size, args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger) stop_fn=stop_fn, save_fn=save_fn, logger=logger,
resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':
@ -112,5 +133,10 @@ def test_discrete_bcq(args=get_args()):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}") print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_discrete_bcq_resume(args=get_args()):
args.resume = True
test_discrete_bcq(args)
if __name__ == "__main__": if __name__ == "__main__":
test_discrete_bcq(get_args()) test_discrete_bcq(get_args())

View File

@ -93,6 +93,7 @@ def test_pg(args=get_args()):
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger) logger=logger)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!

View File

@ -75,8 +75,8 @@ def test_ppo(args=get_args()):
if isinstance(m, torch.nn.Linear): if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight) torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias) torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set( optim = torch.optim.Adam(
actor.parameters()).union(critic.parameters()), lr=args.lr) list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical dist = torch.distributions.Categorical
policy = PPOPolicy( policy = PPOPolicy(
actor, critic, optim, dist, actor, critic, optim, dist,
@ -114,6 +114,7 @@ def test_ppo(args=get_args()):
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger) logger=logger)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!

View File

@ -117,8 +117,8 @@ def test_qrdqn(args=get_args()):
args.batch_size, train_fn=train_fn, test_fn=test_fn, args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, logger=logger, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step) update_per_step=args.update_per_step)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!

View File

@ -112,6 +112,7 @@ def test_discrete_sac(args=get_args()):
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger, args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False) update_per_step=args.update_per_step, test_in_train=False)
assert stop_fn(result['best_reward']) assert stop_fn(result['best_reward'])
if __name__ == '__main__': if __name__ == '__main__':
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!

View File

@ -1,5 +1,6 @@
import time import time
import tqdm import tqdm
import warnings
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from typing import Dict, Union, Callable, Optional from typing import Dict, Union, Callable, Optional
@ -21,6 +22,8 @@ def offline_trainer(
test_fn: Optional[Callable[[int, Optional[int]], None]] = None, test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None, stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(), logger: BaseLogger = LazyLogger(),
verbose: bool = True, verbose: bool = True,
@ -44,6 +47,12 @@ def offline_trainer(
:param function save_fn: a hook called when the undiscounted average mean reward in :param function save_fn: a hook called when the undiscounted average mean reward in
evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
None``. None``.
:param function save_checkpoint_fn: a function to save training process, with the
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
save whatever you want. Because offline-RL doesn't have env_step, the env_step
is always 0 here.
:param bool resume_from_log: resume gradient_step and other metadata from existing
tensorboard log. Default to False.
:param function stop_fn: a function with signature ``f(mean_rewards: float) -> :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
bool``, receives the average undiscounted returns of the testing result, bool``, receives the average undiscounted returns of the testing result,
returns a boolean which indicates whether reaching the goal. returns a boolean which indicates whether reaching the goal.
@ -59,15 +68,22 @@ def offline_trainer(
:return: See :func:`~tianshou.trainer.gather_info`. :return: See :func:`~tianshou.trainer.gather_info`.
""" """
gradient_step = 0 if save_fn:
warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.")
start_epoch, gradient_step = 0, 0
if resume_from_log:
start_epoch, _, gradient_step = logger.restore_data()
stat: Dict[str, MovAvg] = defaultdict(MovAvg) stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time() start_time = time.time()
test_collector.reset_stat() test_collector.reset_stat()
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
logger, gradient_step, reward_metric) test_result = test_episode(policy, test_collector, test_fn, start_epoch,
best_epoch = 0 episode_per_test, logger, gradient_step, reward_metric)
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
for epoch in range(1 + start_epoch, 1 + max_epoch):
policy.train() policy.train()
with tqdm.trange( with tqdm.trange(
update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
@ -87,15 +103,14 @@ def offline_trainer(
policy, test_collector, test_fn, epoch, episode_per_test, policy, test_collector, test_fn, epoch, episode_per_test,
logger, gradient_step, reward_metric) logger, gradient_step, reward_metric)
rew, rew_std = test_result["rew"], test_result["rew_std"] rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch == -1 or best_reward < rew: if best_epoch < 0 or best_reward < rew:
best_reward, best_reward_std = rew, rew_std best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
best_epoch = epoch
if save_fn: if save_fn:
save_fn(policy) save_fn(policy)
logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn)
if verbose: if verbose:
print( print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward): if stop_fn and stop_fn(best_reward):
break break
return gather_info(start_time, None, test_collector, best_reward, best_reward_std) return gather_info(start_time, None, test_collector, best_reward, best_reward_std)
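
As the docstring above notes, offline RL has no collector-driven env_step, so the checkpoint hook always receives env_step == 0. A hypothetical minimal wiring, mirroring the test_il_bcq.py changes in this commit (policy, buffer, test_collector, optim, stop_fn, save_fn, logger, args and log_path are assumed to exist):

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # env_step is always 0 in the offline setting
        torch.save({'model': policy.state_dict(), 'optim': optim.state_dict()},
                   os.path.join(log_path, 'checkpoint.pth'))

    result = offline_trainer(
        policy, buffer, test_collector,
        args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger,
        resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)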

View File

@ -1,13 +1,14 @@
import time import time
import tqdm import tqdm
import warnings
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from typing import Dict, Union, Callable, Optional from typing import Dict, Union, Callable, Optional
from tianshou.data import Collector from tianshou.data import Collector
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.trainer import test_episode, gather_info from tianshou.trainer import test_episode, gather_info
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
def offpolicy_trainer( def offpolicy_trainer(
@ -24,6 +25,8 @@ def offpolicy_trainer(
test_fn: Optional[Callable[[int, Optional[int]], None]] = None, test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None, stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(), logger: BaseLogger = LazyLogger(),
verbose: bool = True, verbose: bool = True,
@ -57,8 +60,13 @@ def offpolicy_trainer(
It can be used to perform custom additional operations, with the signature ``f( It can be used to perform custom additional operations, with the signature ``f(
num_epoch: int, step_idx: int) -> None``. num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean reward in :param function save_fn: a hook called when the undiscounted average mean reward in
evaluation phase gets better, with the signature ``f(policy:BasePolicy) -> evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
None``. None``.
:param function save_checkpoint_fn: a function to save training process, with the
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
save whatever you want.
:param bool resume_from_log: resume env_step/gradient_step and other metadata from
existing tensorboard log. Default to False.
:param function stop_fn: a function with signature ``f(mean_rewards: float) -> :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
bool``, receives the average undiscounted returns of the testing result, bool``, receives the average undiscounted returns of the testing result,
returns a boolean which indicates whether reaching the goal. returns a boolean which indicates whether reaching the goal.
@ -75,18 +83,24 @@ def offpolicy_trainer(
:return: See :func:`~tianshou.trainer.gather_info`. :return: See :func:`~tianshou.trainer.gather_info`.
""" """
env_step, gradient_step = 0, 0 if save_fn:
warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.")
start_epoch, env_step, gradient_step = 0, 0, 0
if resume_from_log:
start_epoch, env_step, gradient_step = logger.restore_data()
last_rew, last_len = 0.0, 0 last_rew, last_len = 0.0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg) stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time() start_time = time.time()
train_collector.reset_stat() train_collector.reset_stat()
test_collector.reset_stat() test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy test_in_train = test_in_train and train_collector.policy == policy
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, test_result = test_episode(policy, test_collector, test_fn, start_epoch,
logger, env_step, reward_metric) episode_per_test, logger, env_step, reward_metric)
best_epoch = 0 best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
for epoch in range(1 + start_epoch, 1 + max_epoch):
# train # train
policy.train() policy.train()
with tqdm.tqdm( with tqdm.tqdm(
@ -118,6 +132,8 @@ def offpolicy_trainer(
if stop_fn(test_result["rew"]): if stop_fn(test_result["rew"]):
if save_fn: if save_fn:
save_fn(policy) save_fn(policy)
logger.save_data(
epoch, env_step, gradient_step, save_checkpoint_fn)
t.set_postfix(**data) t.set_postfix(**data)
return gather_info( return gather_info(
start_time, train_collector, test_collector, start_time, train_collector, test_collector,
@ -139,15 +155,14 @@ def offpolicy_trainer(
test_result = test_episode(policy, test_collector, test_fn, epoch, test_result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, logger, env_step, reward_metric) episode_per_test, logger, env_step, reward_metric)
rew, rew_std = test_result["rew"], test_result["rew_std"] rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch == -1 or best_reward < rew: if best_epoch < 0 or best_reward < rew:
best_reward, best_reward_std = rew, rew_std best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
best_epoch = epoch
if save_fn: if save_fn:
save_fn(policy) save_fn(policy)
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
if verbose: if verbose:
print( print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward): if stop_fn and stop_fn(best_reward):
break break
return gather_info(start_time, train_collector, test_collector, return gather_info(start_time, train_collector, test_collector,

View File

@ -1,13 +1,14 @@
import time import time
import tqdm import tqdm
import warnings
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from typing import Dict, Union, Callable, Optional from typing import Dict, Union, Callable, Optional
from tianshou.data import Collector from tianshou.data import Collector
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.trainer import test_episode, gather_info from tianshou.trainer import test_episode, gather_info
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
def onpolicy_trainer( def onpolicy_trainer(
@ -25,6 +26,8 @@ def onpolicy_trainer(
test_fn: Optional[Callable[[int, Optional[int]], None]] = None, test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None, stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None, save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None, reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(), logger: BaseLogger = LazyLogger(),
verbose: bool = True, verbose: bool = True,
@ -61,6 +64,11 @@ def onpolicy_trainer(
:param function save_fn: a hook called when the undiscounted average mean reward in :param function save_fn: a hook called when the undiscounted average mean reward in
evaluation phase gets better, with the signature ``f(policy: BasePolicy) -> evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
None``. None``.
:param function save_checkpoint_fn: a function to save training process, with the
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
save whatever you want.
:param bool resume_from_log: resume env_step/gradient_step and other metadata from
existing tensorboard log. Default to False.
:param function stop_fn: a function with signature ``f(mean_rewards: float) -> :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
bool``, receives the average undiscounted returns of the testing result, bool``, receives the average undiscounted returns of the testing result,
returns a boolean which indicates whether reaching the goal. returns a boolean which indicates whether reaching the goal.
@ -81,18 +89,24 @@ def onpolicy_trainer(
Only either one of step_per_collect and episode_per_collect can be specified. Only either one of step_per_collect and episode_per_collect can be specified.
""" """
env_step, gradient_step = 0, 0 if save_fn:
warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.")
start_epoch, env_step, gradient_step = 0, 0, 0
if resume_from_log:
start_epoch, env_step, gradient_step = logger.restore_data()
last_rew, last_len = 0.0, 0 last_rew, last_len = 0.0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg) stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time() start_time = time.time()
train_collector.reset_stat() train_collector.reset_stat()
test_collector.reset_stat() test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy test_in_train = test_in_train and train_collector.policy == policy
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test, test_result = test_episode(policy, test_collector, test_fn, start_epoch,
logger, env_step, reward_metric) episode_per_test, logger, env_step, reward_metric)
best_epoch = 0 best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"] best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
for epoch in range(1 + start_epoch, 1 + max_epoch):
# train # train
policy.train() policy.train()
with tqdm.tqdm( with tqdm.tqdm(
@ -125,6 +139,8 @@ def onpolicy_trainer(
if stop_fn(test_result["rew"]): if stop_fn(test_result["rew"]):
if save_fn: if save_fn:
save_fn(policy) save_fn(policy)
logger.save_data(
epoch, env_step, gradient_step, save_checkpoint_fn)
t.set_postfix(**data) t.set_postfix(**data)
return gather_info( return gather_info(
start_time, train_collector, test_collector, start_time, train_collector, test_collector,
@ -150,15 +166,14 @@ def onpolicy_trainer(
test_result = test_episode(policy, test_collector, test_fn, epoch, test_result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, logger, env_step, reward_metric) episode_per_test, logger, env_step, reward_metric)
rew, rew_std = test_result["rew"], test_result["rew_std"] rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch == -1 or best_reward < rew: if best_epoch < 0 or best_reward < rew:
best_reward, best_reward_std = rew, rew_std best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
best_epoch = epoch
if save_fn: if save_fn:
save_fn(policy) save_fn(policy)
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
if verbose: if verbose:
print( print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:" f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward): if stop_fn and stop_fn(best_reward):
break break
return gather_info(start_time, train_collector, test_collector, return gather_info(start_time, train_collector, test_collector,

View File

@ -1,8 +1,12 @@
import numpy as np import numpy as np
from numbers import Number from numbers import Number
from typing import Any, Union
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from typing import Any, Tuple, Union, Callable, Optional
from tensorboard.backend.event_processing import event_accumulator
WRITE_TYPE = Union[int, Number, np.number, np.ndarray]
class BaseLogger(ABC): class BaseLogger(ABC):
@ -13,9 +17,7 @@ class BaseLogger(ABC):
self.writer = writer self.writer = writer
@abstractmethod @abstractmethod
def write( def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
) -> None:
"""Specify how the writer is used to log data. """Specify how the writer is used to log data.
:param str key: namespace which the input data tuple belongs to. :param str key: namespace which the input data tuple belongs to.
@ -51,6 +53,33 @@ class BaseLogger(ABC):
""" """
pass pass
def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
) -> None:
"""Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
:param int epoch: the epoch in trainer.
:param int env_step: the env_step in trainer.
:param int gradient_step: the gradient_step in trainer.
:param function save_checkpoint_fn: a hook defined by user, see trainer
documentation for detail.
"""
pass
def restore_data(self) -> Tuple[int, int, int]:
"""Return the metadata from existing log.
If it finds nothing or an error occurs during the recover process, it will
return the default parameters.
:return: epoch, env_step, gradient_step.
"""
pass
class BasicLogger(BaseLogger): class BasicLogger(BaseLogger):
"""A loggger that relies on tensorboard SummaryWriter by default to visualize \ """A loggger that relies on tensorboard SummaryWriter by default to visualize \
@ -62,6 +91,8 @@ class BasicLogger(BaseLogger):
:param int train_interval: the log interval in log_train_data(). Default to 1. :param int train_interval: the log interval in log_train_data(). Default to 1.
:param int test_interval: the log interval in log_test_data(). Default to 1. :param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data(). Default to 1000. :param int update_interval: the log interval in log_update_data(). Default to 1000.
:param int save_interval: the save interval in save_data(). Default to 1 (save at
the end of each epoch).
""" """
def __init__( def __init__(
@ -70,18 +101,19 @@ class BasicLogger(BaseLogger):
train_interval: int = 1, train_interval: int = 1,
test_interval: int = 1, test_interval: int = 1,
update_interval: int = 1000, update_interval: int = 1000,
save_interval: int = 1,
) -> None: ) -> None:
super().__init__(writer) super().__init__(writer)
self.train_interval = train_interval self.train_interval = train_interval
self.test_interval = test_interval self.test_interval = test_interval
self.update_interval = update_interval self.update_interval = update_interval
self.save_interval = save_interval
self.last_log_train_step = -1 self.last_log_train_step = -1
self.last_log_test_step = -1 self.last_log_test_step = -1
self.last_log_update_step = -1 self.last_log_update_step = -1
self.last_save_step = -1
def write( def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
) -> None:
self.writer.add_scalar(key, y, global_step=x) self.writer.add_scalar(key, y, global_step=x)
def log_train_data(self, collect_result: dict, step: int) -> None: def log_train_data(self, collect_result: dict, step: int) -> None:
@ -133,6 +165,39 @@ class BasicLogger(BaseLogger):
self.write(k, step, v) self.write(k, step, v)
self.last_log_update_step = step self.last_log_update_step = step
def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
) -> None:
if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
self.last_save_step = epoch
save_checkpoint_fn(epoch, env_step, gradient_step)
self.write("save/epoch", epoch, epoch)
self.write("save/env_step", env_step, env_step)
self.write("save/gradient_step", gradient_step, gradient_step)
def restore_data(self) -> Tuple[int, int, int]:
ea = event_accumulator.EventAccumulator(self.writer.log_dir)
ea.Reload()
try: # epoch / gradient_step
epoch = ea.scalars.Items("save/epoch")[-1].step
self.last_save_step = self.last_log_test_step = epoch
gradient_step = ea.scalars.Items("save/gradient_step")[-1].step
self.last_log_update_step = gradient_step
except KeyError:
epoch, gradient_step = 0, 0
try: # offline trainer doesn't have env_step
env_step = ea.scalars.Items("save/env_step")[-1].step
self.last_log_train_step = env_step
except KeyError:
env_step = 0
return epoch, env_step, gradient_step
class LazyLogger(BasicLogger): class LazyLogger(BasicLogger):
"""A loggger that does nothing. Used as the placeholder in trainer.""" """A loggger that does nothing. Used as the placeholder in trainer."""
@ -140,8 +205,6 @@ class LazyLogger(BasicLogger):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__(None) # type: ignore super().__init__(None) # type: ignore
def write( def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
) -> None:
"""The LazyLogger writes nothing.""" """The LazyLogger writes nothing."""
pass pass
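
To summarize how the new logger pieces fit together: BasicLogger(save_interval=...) writes the "save/epoch", "save/env_step" and "save/gradient_step" scalars whenever the trainer calls save_data, and restore_data reads them back from the tfevent file when resume_from_log=True. A small illustrative sketch (the log path and interval are placeholders):

    from torch.utils.tensorboard import SummaryWriter
    from tianshou.utils import BasicLogger

    writer = SummaryWriter('log/CartPole-v0/c51')  # illustrative path
    logger = BasicLogger(writer, save_interval=4)  # checkpoint every 4 epochs

    # the trainer calls this after each epoch (and on early stop):
    #   logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)

    # with resume_from_log=True, the trainer starts from the restored counters:
    epoch, env_step, gradient_step = logger.restore_data()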