Make trainer resumable (#350)

- specify tensorboard >= 2.5.0
- add `save_checkpoint_fn` and `resume_from_log` in trainer

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>

parent f4e05d585a, commit 84f58636eb
@@ -30,6 +30,34 @@ Customize Training Process

See :ref:`customized_trainer`.


.. _resume_training:

Resume Training Process
-----------------------

This is related to `Issue 349 <https://github.com/thu-ml/tianshou/issues/349>`_.

To resume training from an existing checkpoint, you need to do the following during training:

1. Make sure you write a ``save_checkpoint_fn`` that saves everything needed to resume (i.e., policy, optim, buffer) and pass it to the trainer (a sketch is given below);
2. Use ``BasicLogger``, which wraps a tensorboard ``SummaryWriter``;
3. To adjust the save frequency, specify ``save_interval`` when initializing ``BasicLogger``.
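
For illustration, here is a minimal sketch of the saving side (it assumes ``policy`` and ``optim`` have already been constructed, and uses a hypothetical ``log_path``; what you put inside the checkpoint is entirely up to you):

.. code-block:: python

    import os

    import torch
    from torch.utils.tensorboard import SummaryWriter
    from tianshou.utils import BasicLogger

    log_path = os.path.join('log', 'CartPole-v0', 'ppo')  # hypothetical log directory
    writer = SummaryWriter(log_path)
    # save_interval=4: write a checkpoint every 4 epochs
    logger = BasicLogger(writer, save_interval=4)

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
        torch.save({
            'model': policy.state_dict(),
            'optim': optim.state_dict(),
        }, os.path.join(log_path, 'checkpoint.pth'))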

And to successfully resume from a checkpoint:

1. Load everything needed in the training process **before trainer initialization**, i.e., policy, optim, buffer;
2. Set ``resume_from_log=True`` when calling the trainer (see the sketch after this list).
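
For example, the resuming side could look like the following sketch, mirroring the ``test_ppo.py`` example below (``args``, ``policy``, ``optim``, the collectors and the remaining trainer arguments are assumed to be set up already):

.. code-block:: python

    if args.resume:
        # restore the training state saved by save_checkpoint_fn
        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
        if os.path.exists(ckpt_path):
            checkpoint = torch.load(ckpt_path, map_location=args.device)
            policy.load_state_dict(checkpoint['model'])
            optim.load_state_dict(checkpoint['optim'])

    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
        args.repeat_per_collect, args.test_num, args.batch_size,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger,
        resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)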

We provide examples showing how these steps work: check out `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_ or `test_il_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_il_bcq.py>`_ by running

.. code-block:: console

    $ python3 test/discrete/test_c51.py  # train some epochs
    $ python3 test/discrete/test_c51.py --resume  # restore from the existing log and continue training

To correctly render the data (including several tfevent files), we highly recommend using ``tensorboard >= 2.5.0`` (see `here <https://github.com/thu-ml/tianshou/pull/350#issuecomment-829123378>`_ for the reason). Otherwise it may cause an overlapping issue that you would need to handle manually.

.. _parallel_sampling:

Parallel Sampling
@ -85,8 +85,7 @@ def test_discrete_bcq(args=get_args()):
|
||||
feature_net, args.action_shape, device=args.device,
|
||||
hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
|
||||
optim = torch.optim.Adam(
|
||||
set(policy_net.parameters()).union(imitation_net.parameters()),
|
||||
lr=args.lr)
|
||||
list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr)
|
||||
# define policy
|
||||
policy = DiscreteBCQPolicy(
|
||||
policy_net, imitation_net, optim, args.gamma, args.n_step,
|
||||
|
@ -101,7 +101,7 @@ def test_a2c(args=get_args()):
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
m.weight.data.copy_(0.01 * m.weight.data)
|
||||
|
||||
optim = torch.optim.RMSprop(set(actor.parameters()).union(critic.parameters()),
|
||||
optim = torch.optim.RMSprop(list(actor.parameters()) + list(critic.parameters()),
|
||||
lr=args.lr, eps=1e-5, alpha=0.99)
|
||||
|
||||
lr_scheduler = None
|
||||
|
@ -106,8 +106,8 @@ def test_ppo(args=get_args()):
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
m.weight.data.copy_(0.01 * m.weight.data)
|
||||
|
||||
optim = torch.optim.Adam(set(
|
||||
actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||
optim = torch.optim.Adam(
|
||||
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
||||
|
||||
lr_scheduler = None
|
||||
if args.lr_decay:
|
||||
|
setup.py

@@ -48,7 +48,7 @@ setup(
"gym>=0.15.4",
"tqdm",
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard",
"tensorboard>=2.5.0",
"torch>=1.4.0",
"numba>=0.51.0",
"h5py>=2.10.0", # to match tensorflow's minimal requirements
@ -105,6 +105,7 @@ def test_ddpg(args=get_args()):
|
||||
update_per_step=args.update_per_step, stop_fn=stop_fn,
|
||||
save_fn=save_fn, logger=logger)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
|
@ -80,8 +80,7 @@ def test_npg(args=get_args()):
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.orthogonal_(m.weight)
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
optim = torch.optim.Adam(set(
|
||||
actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||
optim = torch.optim.Adam(critic.parameters(), lr=args.lr)
|
||||
|
||||
# replace DiagGuassian with Independent(Normal) which is equivalent
|
||||
# pass *logits to be consistent with policy.forward
|
||||
|
@@ -47,6 +47,8 @@ def get_args():
parser.add_argument('--value-clip', type=int, default=1)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--recompute-adv', type=int, default=0)
parser.add_argument('--resume', action="store_true")
parser.add_argument("--save-interval", type=int, default=4)
args = parser.parse_known_args()[0]
return args

@@ -83,8 +85,8 @@ def test_ppo(args=get_args()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)

# replace DiagGuassian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
@@ -114,7 +116,7 @@ def test_ppo(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = BasicLogger(writer, save_interval=args.save_interval)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -122,13 +124,34 @@ def test_ppo(args=get_args()):
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")

# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
logger=logger, resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
@@ -140,5 +163,10 @@ def test_ppo(args=get_args()):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_ppo_resume(args=get_args()):
args.resume = True
test_ppo(args)


if __name__ == '__main__':
test_ppo()
@ -124,6 +124,7 @@ def test_sac_with_il(args=get_args()):
|
||||
update_per_step=args.update_per_step, stop_fn=stop_fn,
|
||||
save_fn=save_fn, logger=logger)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
|
@ -119,6 +119,7 @@ def test_td3(args=get_args()):
|
||||
update_per_step=args.update_per_step, stop_fn=stop_fn,
|
||||
save_fn=save_fn, logger=logger)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
|
@ -27,7 +27,8 @@ def get_args():
|
||||
parser.add_argument('--epoch', type=int, default=5)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=50000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=2048)
|
||||
parser.add_argument('--repeat-per-collect', type=int, default=1)
|
||||
parser.add_argument('--repeat-per-collect', type=int,
|
||||
default=2) # theoretically it should be 1
|
||||
parser.add_argument('--batch-size', type=int, default=99999)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
|
||||
parser.add_argument('--training-num', type=int, default=16)
|
||||
@ -82,8 +83,7 @@ def test_trpo(args=get_args()):
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.orthogonal_(m.weight)
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
optim = torch.optim.Adam(set(
|
||||
actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||
optim = torch.optim.Adam(critic.parameters(), lr=args.lr)
|
||||
|
||||
# replace DiagGuassian with Independent(Normal) which is equivalent
|
||||
# pass *logits to be consistent with policy.forward
|
||||
|
@ -74,8 +74,8 @@ def test_a2c_with_il(args=get_args()):
|
||||
device=args.device)
|
||||
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
|
||||
critic = Critic(net, device=args.device).to(args.device)
|
||||
optim = torch.optim.Adam(set(
|
||||
actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||
optim = torch.optim.Adam(
|
||||
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
||||
dist = torch.distributions.Categorical
|
||||
policy = A2CPolicy(
|
||||
actor, critic, optim, dist,
|
||||
@ -106,6 +106,7 @@ def test_a2c_with_il(args=get_args()):
|
||||
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
|
||||
logger=logger)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -135,6 +136,7 @@ def test_a2c_with_il(args=get_args()):
|
||||
args.il_step_per_epoch, args.step_per_collect, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
|
@@ -1,6 +1,7 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
@@ -43,9 +44,11 @@ def get_args():
action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument('--resume', action="store_true")
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument("--save-interval", type=int, default=4)
args = parser.parse_known_args()[0]
return args

@@ -90,7 +93,7 @@ def test_c51(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = BasicLogger(writer, save_interval=args.save_interval)

def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@@ -112,14 +115,42 @@ def test_c51(args=get_args()):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)

def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))
pickle.dump(train_collector.buffer,
open(os.path.join(log_path, 'train_buffer.pkl'), "wb"))

if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
policy.optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
buffer_path = os.path.join(log_path, 'train_buffer.pkl')
if os.path.exists(buffer_path):
train_collector.buffer = pickle.load(open(buffer_path, "rb"))
print("Successfully restore buffer.")
else:
print("Fail to restore buffer.")

# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)

test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward'])

if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
@@ -132,6 +163,11 @@ def test_c51(args=get_args()):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


def test_c51_resume(args=get_args()):
args.resume = True
test_c51(args)


def test_pc51(args=get_args()):
args.prioritized_replay = True
args.gamma = .95
@ -120,7 +120,6 @@ def test_dqn(args=get_args()):
|
||||
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
|
||||
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
|
||||
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -99,8 +99,8 @@ def test_drqn(args=get_args()):
|
||||
args.batch_size, update_per_step=args.update_per_step,
|
||||
train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn,
|
||||
save_fn=save_fn, logger=logger)
|
||||
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
|
@ -42,6 +42,8 @@ def get_args():
|
||||
"--device", type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
)
|
||||
parser.add_argument("--resume", action="store_true")
|
||||
parser.add_argument("--save-interval", type=int, default=4)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
@ -67,7 +69,7 @@ def test_discrete_bcq(args=get_args()):
|
||||
args.state_shape, args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
|
||||
optim = torch.optim.Adam(
|
||||
set(policy_net.parameters()).union(imitation_net.parameters()),
|
||||
list(policy_net.parameters()) + list(imitation_net.parameters()),
|
||||
lr=args.lr)
|
||||
|
||||
policy = DiscreteBCQPolicy(
|
||||
@ -85,7 +87,7 @@ def test_discrete_bcq(args=get_args()):
|
||||
|
||||
log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
|
||||
writer = SummaryWriter(log_path)
|
||||
logger = BasicLogger(writer)
|
||||
logger = BasicLogger(writer, save_interval=args.save_interval)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
@ -93,11 +95,30 @@ def test_discrete_bcq(args=get_args()):
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= env.spec.reward_threshold
|
||||
|
||||
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
torch.save({
|
||||
'model': policy.state_dict(),
|
||||
'optim': optim.state_dict(),
|
||||
}, os.path.join(log_path, 'checkpoint.pth'))
|
||||
|
||||
if args.resume:
|
||||
# load from existing checkpoint
|
||||
print(f"Loading agent under {log_path}")
|
||||
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
|
||||
if os.path.exists(ckpt_path):
|
||||
checkpoint = torch.load(ckpt_path, map_location=args.device)
|
||||
policy.load_state_dict(checkpoint['model'])
|
||||
optim.load_state_dict(checkpoint['optim'])
|
||||
print("Successfully restore policy and optim.")
|
||||
else:
|
||||
print("Fail to restore policy and optim.")
|
||||
|
||||
result = offline_trainer(
|
||||
policy, buffer, test_collector,
|
||||
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
|
||||
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
|
||||
|
||||
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
|
||||
resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -112,5 +133,10 @@ def test_discrete_bcq(args=get_args()):
|
||||
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||
|
||||
|
||||
def test_discrete_bcq_resume(args=get_args()):
|
||||
args.resume = True
|
||||
test_discrete_bcq(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_discrete_bcq(get_args())
|
||||
|
@ -93,6 +93,7 @@ def test_pg(args=get_args()):
|
||||
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
|
||||
logger=logger)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
|
@ -75,8 +75,8 @@ def test_ppo(args=get_args()):
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.orthogonal_(m.weight)
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
optim = torch.optim.Adam(set(
|
||||
actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||
optim = torch.optim.Adam(
|
||||
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
|
||||
dist = torch.distributions.Categorical
|
||||
policy = PPOPolicy(
|
||||
actor, critic, optim, dist,
|
||||
@ -114,6 +114,7 @@ def test_ppo(args=get_args()):
|
||||
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
|
||||
logger=logger)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
|
@ -117,8 +117,8 @@ def test_qrdqn(args=get_args()):
|
||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
|
||||
update_per_step=args.update_per_step)
|
||||
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
|
@ -112,6 +112,7 @@ def test_discrete_sac(args=get_args()):
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
|
||||
update_per_step=args.update_per_step, test_in_train=False)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
|
@ -1,5 +1,6 @@
|
||||
import time
|
||||
import tqdm
|
||||
import warnings
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Union, Callable, Optional
|
||||
@ -21,6 +22,8 @@ def offline_trainer(
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
resume_from_log: bool = False,
|
||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||
logger: BaseLogger = LazyLogger(),
|
||||
verbose: bool = True,
|
||||
@ -44,6 +47,12 @@ def offline_trainer(
|
||||
:param function save_fn: a hook called when the undiscounted average mean reward in
|
||||
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
|
||||
None``.
|
||||
:param function save_checkpoint_fn: a function to save training process, with the
|
||||
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
|
||||
save whatever you want. Because offline-RL doesn't have env_step, the env_step
|
||||
is always 0 here.
|
||||
:param bool resume_from_log: resume gradient_step and other metadata from existing
|
||||
tensorboard log. Default to False.
|
||||
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||
bool``, receives the average undiscounted returns of the testing result,
|
||||
returns a boolean which indicates whether reaching the goal.
|
||||
@ -59,15 +68,22 @@ def offline_trainer(
|
||||
|
||||
:return: See :func:`~tianshou.trainer.gather_info`.
|
||||
"""
|
||||
gradient_step = 0
|
||||
if save_fn:
|
||||
warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.")
|
||||
|
||||
start_epoch, gradient_step = 0, 0
|
||||
if resume_from_log:
|
||||
start_epoch, _, gradient_step = logger.restore_data()
|
||||
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
|
||||
start_time = time.time()
|
||||
test_collector.reset_stat()
|
||||
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
|
||||
logger, gradient_step, reward_metric)
|
||||
best_epoch = 0
|
||||
|
||||
test_result = test_episode(policy, test_collector, test_fn, start_epoch,
|
||||
episode_per_test, logger, gradient_step, reward_metric)
|
||||
best_epoch = start_epoch
|
||||
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
|
||||
for epoch in range(1, 1 + max_epoch):
|
||||
|
||||
for epoch in range(1 + start_epoch, 1 + max_epoch):
|
||||
policy.train()
|
||||
with tqdm.trange(
|
||||
update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
|
||||
@ -87,15 +103,14 @@ def offline_trainer(
|
||||
policy, test_collector, test_fn, epoch, episode_per_test,
|
||||
logger, gradient_step, reward_metric)
|
||||
rew, rew_std = test_result["rew"], test_result["rew_std"]
|
||||
if best_epoch == -1 or best_reward < rew:
|
||||
best_reward, best_reward_std = rew, rew_std
|
||||
best_epoch = epoch
|
||||
if best_epoch < 0 or best_reward < rew:
|
||||
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn)
|
||||
if verbose:
|
||||
print(
|
||||
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
|
||||
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
|
||||
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
|
||||
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
|
||||
if stop_fn and stop_fn(best_reward):
|
||||
break
|
||||
return gather_info(start_time, None, test_collector, best_reward, best_reward_std)
|
||||
|
@ -1,13 +1,14 @@
|
||||
import time
|
||||
import tqdm
|
||||
import warnings
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Union, Callable, Optional
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
|
||||
from tianshou.trainer import test_episode, gather_info
|
||||
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
|
||||
|
||||
|
||||
def offpolicy_trainer(
|
||||
@ -24,6 +25,8 @@ def offpolicy_trainer(
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
resume_from_log: bool = False,
|
||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||
logger: BaseLogger = LazyLogger(),
|
||||
verbose: bool = True,
|
||||
@ -57,8 +60,13 @@ def offpolicy_trainer(
|
||||
It can be used to perform custom additional operations, with the signature ``f(
|
||||
num_epoch: int, step_idx: int) -> None``.
|
||||
:param function save_fn: a hook called when the undiscounted average mean reward in
|
||||
evaluation phase gets better, with the signature ``f(policy:BasePolicy) ->
|
||||
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
|
||||
None``.
|
||||
:param function save_checkpoint_fn: a function to save training process, with the
|
||||
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
|
||||
save whatever you want.
|
||||
:param bool resume_from_log: resume env_step/gradient_step and other metadata from
|
||||
existing tensorboard log. Default to False.
|
||||
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||
bool``, receives the average undiscounted returns of the testing result,
|
||||
returns a boolean which indicates whether reaching the goal.
|
||||
@ -75,18 +83,24 @@ def offpolicy_trainer(
|
||||
|
||||
:return: See :func:`~tianshou.trainer.gather_info`.
|
||||
"""
|
||||
env_step, gradient_step = 0, 0
|
||||
if save_fn:
|
||||
warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.")
|
||||
|
||||
start_epoch, env_step, gradient_step = 0, 0, 0
|
||||
if resume_from_log:
|
||||
start_epoch, env_step, gradient_step = logger.restore_data()
|
||||
last_rew, last_len = 0.0, 0
|
||||
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
|
||||
start_time = time.time()
|
||||
train_collector.reset_stat()
|
||||
test_collector.reset_stat()
|
||||
test_in_train = test_in_train and train_collector.policy == policy
|
||||
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
|
||||
logger, env_step, reward_metric)
|
||||
best_epoch = 0
|
||||
test_result = test_episode(policy, test_collector, test_fn, start_epoch,
|
||||
episode_per_test, logger, env_step, reward_metric)
|
||||
best_epoch = start_epoch
|
||||
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
|
||||
for epoch in range(1, 1 + max_epoch):
|
||||
|
||||
for epoch in range(1 + start_epoch, 1 + max_epoch):
|
||||
# train
|
||||
policy.train()
|
||||
with tqdm.tqdm(
|
||||
@ -118,6 +132,8 @@ def offpolicy_trainer(
|
||||
if stop_fn(test_result["rew"]):
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
logger.save_data(
|
||||
epoch, env_step, gradient_step, save_checkpoint_fn)
|
||||
t.set_postfix(**data)
|
||||
return gather_info(
|
||||
start_time, train_collector, test_collector,
|
||||
@ -139,15 +155,14 @@ def offpolicy_trainer(
|
||||
test_result = test_episode(policy, test_collector, test_fn, epoch,
|
||||
episode_per_test, logger, env_step, reward_metric)
|
||||
rew, rew_std = test_result["rew"], test_result["rew_std"]
|
||||
if best_epoch == -1 or best_reward < rew:
|
||||
best_reward, best_reward_std = rew, rew_std
|
||||
best_epoch = epoch
|
||||
if best_epoch < 0 or best_reward < rew:
|
||||
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
|
||||
if verbose:
|
||||
print(
|
||||
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
|
||||
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
|
||||
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
|
||||
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
|
||||
if stop_fn and stop_fn(best_reward):
|
||||
break
|
||||
return gather_info(start_time, train_collector, test_collector,
|
||||
|
@ -1,13 +1,14 @@
|
||||
import time
|
||||
import tqdm
|
||||
import warnings
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Union, Callable, Optional
|
||||
|
||||
from tianshou.data import Collector
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
|
||||
from tianshou.trainer import test_episode, gather_info
|
||||
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
|
||||
|
||||
|
||||
def onpolicy_trainer(
|
||||
@ -25,6 +26,8 @@ def onpolicy_trainer(
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
resume_from_log: bool = False,
|
||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||
logger: BaseLogger = LazyLogger(),
|
||||
verbose: bool = True,
|
||||
@ -61,6 +64,11 @@ def onpolicy_trainer(
|
||||
:param function save_fn: a hook called when the undiscounted average mean reward in
|
||||
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
|
||||
None``.
|
||||
:param function save_checkpoint_fn: a function to save training process, with the
|
||||
signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
|
||||
save whatever you want.
|
||||
:param bool resume_from_log: resume env_step/gradient_step and other metadata from
|
||||
existing tensorboard log. Default to False.
|
||||
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||
bool``, receives the average undiscounted returns of the testing result,
|
||||
returns a boolean which indicates whether reaching the goal.
|
||||
@ -81,18 +89,24 @@ def onpolicy_trainer(
|
||||
|
||||
Only either one of step_per_collect and episode_per_collect can be specified.
|
||||
"""
|
||||
env_step, gradient_step = 0, 0
|
||||
if save_fn:
|
||||
warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.")
|
||||
|
||||
start_epoch, env_step, gradient_step = 0, 0, 0
|
||||
if resume_from_log:
|
||||
start_epoch, env_step, gradient_step = logger.restore_data()
|
||||
last_rew, last_len = 0.0, 0
|
||||
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
|
||||
start_time = time.time()
|
||||
train_collector.reset_stat()
|
||||
test_collector.reset_stat()
|
||||
test_in_train = test_in_train and train_collector.policy == policy
|
||||
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
|
||||
logger, env_step, reward_metric)
|
||||
best_epoch = 0
|
||||
test_result = test_episode(policy, test_collector, test_fn, start_epoch,
|
||||
episode_per_test, logger, env_step, reward_metric)
|
||||
best_epoch = start_epoch
|
||||
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
|
||||
for epoch in range(1, 1 + max_epoch):
|
||||
|
||||
for epoch in range(1 + start_epoch, 1 + max_epoch):
|
||||
# train
|
||||
policy.train()
|
||||
with tqdm.tqdm(
|
||||
@ -125,6 +139,8 @@ def onpolicy_trainer(
|
||||
if stop_fn(test_result["rew"]):
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
logger.save_data(
|
||||
epoch, env_step, gradient_step, save_checkpoint_fn)
|
||||
t.set_postfix(**data)
|
||||
return gather_info(
|
||||
start_time, train_collector, test_collector,
|
||||
@ -150,15 +166,14 @@ def onpolicy_trainer(
|
||||
test_result = test_episode(policy, test_collector, test_fn, epoch,
|
||||
episode_per_test, logger, env_step, reward_metric)
|
||||
rew, rew_std = test_result["rew"], test_result["rew_std"]
|
||||
if best_epoch == -1 or best_reward < rew:
|
||||
best_reward, best_reward_std = rew, rew_std
|
||||
best_epoch = epoch
|
||||
if best_epoch < 0 or best_reward < rew:
|
||||
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
|
||||
if verbose:
|
||||
print(
|
||||
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
|
||||
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
|
||||
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
|
||||
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
|
||||
if stop_fn and stop_fn(best_reward):
|
||||
break
|
||||
return gather_info(start_time, train_collector, test_collector,
|
||||
|
@@ -1,8 +1,12 @@
import numpy as np
from numbers import Number
from typing import Any, Union
from abc import ABC, abstractmethod
from torch.utils.tensorboard import SummaryWriter
from typing import Any, Tuple, Union, Callable, Optional
from tensorboard.backend.event_processing import event_accumulator


WRITE_TYPE = Union[int, Number, np.number, np.ndarray]


class BaseLogger(ABC):
@@ -13,9 +17,7 @@ class BaseLogger(ABC):
self.writer = writer

@abstractmethod
def write(
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
) -> None:
def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
"""Specify how the writer is used to log data.

:param str key: namespace which the input data tuple belongs to.
@@ -51,6 +53,33 @@ class BaseLogger(ABC):
"""
pass

def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
) -> None:
"""Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.

:param int epoch: the epoch in trainer.
:param int env_step: the env_step in trainer.
:param int gradient_step: the gradient_step in trainer.
:param function save_checkpoint_fn: a hook defined by user, see trainer
documentation for detail.
"""
pass

def restore_data(self) -> Tuple[int, int, int]:
"""Return the metadata from existing log.

If it finds nothing or an error occurs during the recover process, it will
return the default parameters.

:return: epoch, env_step, gradient_step.
"""
pass


class BasicLogger(BaseLogger):
"""A loggger that relies on tensorboard SummaryWriter by default to visualize \
@@ -62,6 +91,8 @@ class BasicLogger(BaseLogger):
:param int train_interval: the log interval in log_train_data(). Default to 1.
:param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data(). Default to 1000.
:param int save_interval: the save interval in save_data(). Default to 1 (save at
the end of each epoch).
"""

def __init__(
@@ -70,18 +101,19 @@ class BasicLogger(BaseLogger):
train_interval: int = 1,
test_interval: int = 1,
update_interval: int = 1000,
save_interval: int = 1,
) -> None:
super().__init__(writer)
self.train_interval = train_interval
self.test_interval = test_interval
self.update_interval = update_interval
self.save_interval = save_interval
self.last_log_train_step = -1
self.last_log_test_step = -1
self.last_log_update_step = -1
self.last_save_step = -1

def write(
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
) -> None:
def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
self.writer.add_scalar(key, y, global_step=x)

def log_train_data(self, collect_result: dict, step: int) -> None:
@@ -133,6 +165,39 @@ class BasicLogger(BaseLogger):
self.write(k, step, v)
self.last_log_update_step = step

def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
) -> None:
if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
self.last_save_step = epoch
save_checkpoint_fn(epoch, env_step, gradient_step)
self.write("save/epoch", epoch, epoch)
self.write("save/env_step", env_step, env_step)
self.write("save/gradient_step", gradient_step, gradient_step)

def restore_data(self) -> Tuple[int, int, int]:
ea = event_accumulator.EventAccumulator(self.writer.log_dir)
ea.Reload()

try:  # epoch / gradient_step
epoch = ea.scalars.Items("save/epoch")[-1].step
self.last_save_step = self.last_log_test_step = epoch
gradient_step = ea.scalars.Items("save/gradient_step")[-1].step
self.last_log_update_step = gradient_step
except KeyError:
epoch, gradient_step = 0, 0
try:  # offline trainer doesn't have env_step
env_step = ea.scalars.Items("save/env_step")[-1].step
self.last_log_train_step = env_step
except KeyError:
env_step = 0

return epoch, env_step, gradient_step


class LazyLogger(BasicLogger):
"""A loggger that does nothing. Used as the placeholder in trainer."""
@@ -140,8 +205,6 @@ class LazyLogger(BasicLogger):
def __init__(self) -> None:
super().__init__(None)  # type: ignore

def write(
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
) -> None:
def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
"""The LazyLogger writes nothing."""
pass