Make trainer resumable (#350)

- specify tensorboard >= 2.5.0 in setup.py
- add `save_checkpoint_fn` and `resume_from_log` to the trainers

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
Ark 2021-05-06 08:53:53 +08:00 committed by GitHub
parent f4e05d585a
commit 84f58636eb
24 changed files with 308 additions and 77 deletions

View File

@ -30,6 +30,34 @@ Customize Training Process
See :ref:`customized_trainer`.
.. _resume_training:
Resume Training Process
-----------------------
This is related to `Issue 349 <https://github.com/thu-ml/tianshou/issues/349>`_.
To resume training from an existing checkpoint, you need to do the following things during training:
1. Write a ``save_checkpoint_fn`` that saves everything needed to resume, i.e., policy, optim, buffer, and pass it to the trainer;
2. Use ``BasicLogger``, which writes to a tensorboard ``SummaryWriter``;
3. To adjust the save frequency, specify ``save_interval`` when initializing ``BasicLogger``.
And to successfully resume from a checkpoint:
1. Load everything needed in the training process **before trainer initialization**, i.e., policy, optim, buffer;
2. Pass ``resume_from_log=True`` to the trainer (see the sketch below).
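Putting the two parts together, a minimal sketch looks like this (``policy``, ``optim``, ``train_collector``, ``test_collector``, ``log_path`` and ``args`` are assumed to come from your usual setup code, exactly as in the test scripts referenced below):
.. code-block:: python

    import os
    import pickle

    import torch
    from torch.utils.tensorboard import SummaryWriter

    from tianshou.trainer import offpolicy_trainer
    from tianshou.utils import BasicLogger

    logger = BasicLogger(SummaryWriter(log_path), save_interval=1)

    def save_checkpoint_fn(epoch, env_step, gradient_step):
        # invoked by the trainer (through the logger) at most once every
        # save_interval epochs; save whatever is needed to resume
        torch.save({
            'model': policy.state_dict(),
            'optim': optim.state_dict(),
        }, os.path.join(log_path, 'checkpoint.pth'))
        with open(os.path.join(log_path, 'train_buffer.pkl'), "wb") as f:
            pickle.dump(train_collector.buffer, f)

    if args.resume:
        # restore everything *before* constructing the trainer
        checkpoint = torch.load(
            os.path.join(log_path, 'checkpoint.pth'), map_location=args.device)
        policy.load_state_dict(checkpoint['model'])
        optim.load_state_dict(checkpoint['optim'])
        with open(os.path.join(log_path, 'train_buffer.pkl'), "rb") as f:
            train_collector.buffer = pickle.load(f)

    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
        args.step_per_collect, args.test_num, args.batch_size,
        logger=logger, resume_from_log=args.resume,
        save_checkpoint_fn=save_checkpoint_fn)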
We provide full examples showing how these steps work: check out `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_ or `test_il_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_il_bcq.py>`_ by running
.. code-block:: console
$ python3 test/discrete/test_c51.py # train some epoch
$ python3 test/discrete/test_c51.py --resume # restore from the existing log and continue training
To correctly render the data (including several tfevent files), we highly recommend using ``tensorboard >= 2.5.0`` (see `here <https://github.com/thu-ml/tianshou/pull/350#issuecomment-829123378>`_ for the reason). Otherwise, you may run into an overlapping issue that has to be handled manually.
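If your installed version is older, upgrading is a one-liner with pip, for example:
.. code-block:: console

    $ pip install -U "tensorboard>=2.5.0"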
.. _parallel_sampling:
Parallel Sampling

View File

@ -85,8 +85,7 @@ def test_discrete_bcq(args=get_args()):
feature_net, args.action_shape, device=args.device,
hidden_sizes=args.hidden_sizes, softmax_output=False).to(args.device)
optim = torch.optim.Adam(
set(policy_net.parameters()).union(imitation_net.parameters()),
lr=args.lr)
list(policy_net.parameters()) + list(imitation_net.parameters()), lr=args.lr)
# define policy
policy = DiscreteBCQPolicy(
policy_net, imitation_net, optim, args.gamma, args.n_step,

View File

@ -101,7 +101,7 @@ def test_a2c(args=get_args()):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)
optim = torch.optim.RMSprop(set(actor.parameters()).union(critic.parameters()),
optim = torch.optim.RMSprop(list(actor.parameters()) + list(critic.parameters()),
lr=args.lr, eps=1e-5, alpha=0.99)
lr_scheduler = None

View File

@ -106,8 +106,8 @@ def test_ppo(args=get_args()):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
lr_scheduler = None
if args.lr_decay:

View File

@ -48,7 +48,7 @@ setup(
"gym>=0.15.4",
"tqdm",
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard",
"tensorboard>=2.5.0",
"torch>=1.4.0",
"numba>=0.51.0",
"h5py>=2.10.0", # to match tensorflow's minimal requirements

View File

@ -105,6 +105,7 @@ def test_ddpg(args=get_args()):
update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!

View File

@ -80,8 +80,7 @@ def test_npg(args=get_args()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(critic.parameters(), lr=args.lr)
# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward

View File

@ -47,6 +47,8 @@ def get_args():
parser.add_argument('--value-clip', type=int, default=1)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--recompute-adv', type=int, default=0)
parser.add_argument('--resume', action="store_true")
parser.add_argument("--save-interval", type=int, default=4)
args = parser.parse_known_args()[0]
return args
@ -83,8 +85,8 @@ def test_ppo(args=get_args()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
@ -114,7 +116,7 @@ def test_ppo(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'ppo')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = BasicLogger(writer, save_interval=args.save_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -122,13 +124,34 @@ def test_ppo(args=get_args()):
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))
if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
logger=logger, resume_from_log=args.resume,
save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
@ -140,5 +163,10 @@ def test_ppo(args=get_args()):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_ppo_resume(args=get_args()):
args.resume = True
test_ppo(args)
if __name__ == '__main__':
test_ppo()

View File

@ -124,6 +124,7 @@ def test_sac_with_il(args=get_args()):
update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!

View File

@ -119,6 +119,7 @@ def test_td3(args=get_args()):
update_per_step=args.update_per_step, stop_fn=stop_fn,
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!

View File

@ -27,7 +27,8 @@ def get_args():
parser.add_argument('--epoch', type=int, default=5)
parser.add_argument('--step-per-epoch', type=int, default=50000)
parser.add_argument('--step-per-collect', type=int, default=2048)
parser.add_argument('--repeat-per-collect', type=int, default=1)
parser.add_argument('--repeat-per-collect', type=int,
default=2) # theoretically it should be 1
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=16)
@ -82,8 +83,7 @@ def test_trpo(args=get_args()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(critic.parameters(), lr=args.lr)
# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward

View File

@ -74,8 +74,8 @@ def test_a2c_with_il(args=get_args()):
device=args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
critic = Critic(net, device=args.device).to(args.device)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = A2CPolicy(
actor, critic, optim, dist,
@ -106,6 +106,7 @@ def test_a2c_with_il(args=get_args()):
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
@ -135,6 +136,7 @@ def test_a2c_with_il(args=get_args()):
args.il_step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!

View File

@ -1,6 +1,7 @@
import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
@ -43,9 +44,11 @@ def get_args():
action="store_true", default=False)
parser.add_argument('--alpha', type=float, default=0.6)
parser.add_argument('--beta', type=float, default=0.4)
parser.add_argument('--resume', action="store_true")
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument("--save-interval", type=int, default=4)
args = parser.parse_known_args()[0]
return args
@ -90,7 +93,7 @@ def test_c51(args=get_args()):
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = BasicLogger(writer, save_interval=args.save_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -112,14 +115,42 @@ def test_c51(args=get_args()):
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))
pickle.dump(train_collector.buffer,
open(os.path.join(log_path, 'train_buffer.pkl'), "wb"))
if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
policy.optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
buffer_path = os.path.join(log_path, 'train_buffer.pkl')
if os.path.exists(buffer_path):
train_collector.buffer = pickle.load(open(buffer_path, "rb"))
print("Successfully restore buffer.")
else:
print("Fail to restore buffer.")
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!
@ -132,6 +163,11 @@ def test_c51(args=get_args()):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_c51_resume(args=get_args()):
args.resume = True
test_c51(args)
def test_pc51(args=get_args()):
args.prioritized_replay = True
args.gamma = .95

View File

@ -120,7 +120,6 @@ def test_dqn(args=get_args()):
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, update_per_step=args.update_per_step, train_fn=train_fn,
test_fn=test_fn, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':

View File

@ -99,8 +99,8 @@ def test_drqn(args=get_args()):
args.batch_size, update_per_step=args.update_per_step,
train_fn=train_fn, test_fn=test_fn, stop_fn=stop_fn,
save_fn=save_fn, logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!

View File

@ -42,6 +42,8 @@ def get_args():
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
parser.add_argument("--resume", action="store_true")
parser.add_argument("--save-interval", type=int, default=4)
args = parser.parse_known_args()[0]
return args
@ -67,7 +69,7 @@ def test_discrete_bcq(args=get_args()):
args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
optim = torch.optim.Adam(
set(policy_net.parameters()).union(imitation_net.parameters()),
list(policy_net.parameters()) + list(imitation_net.parameters()),
lr=args.lr)
policy = DiscreteBCQPolicy(
@ -85,7 +87,7 @@ def test_discrete_bcq(args=get_args()):
log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
writer = SummaryWriter(log_path)
logger = BasicLogger(writer)
logger = BasicLogger(writer, save_interval=args.save_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
@ -93,11 +95,30 @@ def test_discrete_bcq(args=get_args()):
def stop_fn(mean_rewards):
return mean_rewards >= env.spec.reward_threshold
def save_checkpoint_fn(epoch, env_step, gradient_step):
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
torch.save({
'model': policy.state_dict(),
'optim': optim.state_dict(),
}, os.path.join(log_path, 'checkpoint.pth'))
if args.resume:
# load from existing checkpoint
print(f"Loading agent under {log_path}")
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
if os.path.exists(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=args.device)
policy.load_state_dict(checkpoint['model'])
optim.load_state_dict(checkpoint['optim'])
print("Successfully restore policy and optim.")
else:
print("Fail to restore policy and optim.")
result = offline_trainer(
policy, buffer, test_collector,
args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
resume_from_log=args.resume, save_checkpoint_fn=save_checkpoint_fn)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
@ -112,5 +133,10 @@ def test_discrete_bcq(args=get_args()):
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
def test_discrete_bcq_resume(args=get_args()):
args.resume = True
test_discrete_bcq(args)
if __name__ == "__main__":
test_discrete_bcq(get_args())

View File

@ -93,6 +93,7 @@ def test_pg(args=get_args()):
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!

View File

@ -75,8 +75,8 @@ def test_ppo(args=get_args()):
if isinstance(m, torch.nn.Linear):
torch.nn.init.orthogonal_(m.weight)
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(
actor.parameters()).union(critic.parameters()), lr=args.lr)
optim = torch.optim.Adam(
list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
dist = torch.distributions.Categorical
policy = PPOPolicy(
actor, critic, optim, dist,
@ -114,6 +114,7 @@ def test_ppo(args=get_args()):
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!

View File

@ -117,8 +117,8 @@ def test_qrdqn(args=get_args()):
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!

View File

@ -112,6 +112,7 @@ def test_discrete_sac(args=get_args()):
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)
assert stop_fn(result['best_reward'])
if __name__ == '__main__':
pprint.pprint(result)
# Let's watch its performance!

View File

@ -1,5 +1,6 @@
import time
import tqdm
import warnings
import numpy as np
from collections import defaultdict
from typing import Dict, Union, Callable, Optional
@ -21,6 +22,8 @@ def offline_trainer(
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
@ -44,6 +47,12 @@ def offline_trainer(
:param function save_fn: a hook called when the undiscounted average mean reward in
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
None``.
:param function save_checkpoint_fn: a function to save the training process, with
the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you can
save whatever you want. Because offline RL has no env_step, env_step is always 0
here.
:param bool resume_from_log: resume gradient_step and other metadata from the
existing tensorboard log. Default to False.
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
bool``, receives the average undiscounted returns of the testing result,
returns a boolean which indicates whether reaching the goal.
@ -59,15 +68,22 @@ def offline_trainer(
:return: See :func:`~tianshou.trainer.gather_info`.
"""
gradient_step = 0
if save_fn:
warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.")
start_epoch, gradient_step = 0, 0
if resume_from_log:
start_epoch, _, gradient_step = logger.restore_data()
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
test_collector.reset_stat()
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
logger, gradient_step, reward_metric)
best_epoch = 0
test_result = test_episode(policy, test_collector, test_fn, start_epoch,
episode_per_test, logger, gradient_step, reward_metric)
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
for epoch in range(1 + start_epoch, 1 + max_epoch):
policy.train()
with tqdm.trange(
update_per_epoch, desc=f"Epoch #{epoch}", **tqdm_config
@ -87,15 +103,14 @@ def offline_trainer(
policy, test_collector, test_fn, epoch, episode_per_test,
logger, gradient_step, reward_metric)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch == -1 or best_reward < rew:
best_reward, best_reward_std = rew, rew_std
best_epoch = epoch
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
logger.save_data(epoch, 0, gradient_step, save_checkpoint_fn)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, None, test_collector, best_reward, best_reward_std)

View File

@ -1,13 +1,14 @@
import time
import tqdm
import warnings
import numpy as np
from collections import defaultdict
from typing import Dict, Union, Callable, Optional
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.trainer import test_episode, gather_info
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
def offpolicy_trainer(
@ -24,6 +25,8 @@ def offpolicy_trainer(
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
@ -57,8 +60,13 @@ def offpolicy_trainer(
It can be used to perform custom additional operations, with the signature ``f(
num_epoch: int, step_idx: int) -> None``.
:param function save_fn: a hook called when the undiscounted average mean reward in
evaluation phase gets better, with the signature ``f(policy:BasePolicy) ->
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
None``.
:param function save_checkpoint_fn: a function to save the training process, with
the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you
can save whatever you want.
:param bool resume_from_log: resume env_step/gradient_step and other metadata from
the existing tensorboard log. Default to False.
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
bool``, receives the average undiscounted returns of the testing result,
returns a boolean which indicates whether reaching the goal.
@ -75,18 +83,24 @@ def offpolicy_trainer(
:return: See :func:`~tianshou.trainer.gather_info`.
"""
env_step, gradient_step = 0, 0
if save_fn:
warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.")
start_epoch, env_step, gradient_step = 0, 0, 0
if resume_from_log:
start_epoch, env_step, gradient_step = logger.restore_data()
last_rew, last_len = 0.0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
logger, env_step, reward_metric)
best_epoch = 0
test_result = test_episode(policy, test_collector, test_fn, start_epoch,
episode_per_test, logger, env_step, reward_metric)
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
for epoch in range(1 + start_epoch, 1 + max_epoch):
# train
policy.train()
with tqdm.tqdm(
@ -118,6 +132,8 @@ def offpolicy_trainer(
if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
logger.save_data(
epoch, env_step, gradient_step, save_checkpoint_fn)
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
@ -139,15 +155,14 @@ def offpolicy_trainer(
test_result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, logger, env_step, reward_metric)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch == -1 or best_reward < rew:
best_reward, best_reward_std = rew, rew_std
best_epoch = epoch
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, train_collector, test_collector,

View File

@ -1,13 +1,14 @@
import time
import tqdm
import warnings
import numpy as np
from collections import defaultdict
from typing import Dict, Union, Callable, Optional
from tianshou.data import Collector
from tianshou.policy import BasePolicy
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
from tianshou.trainer import test_episode, gather_info
from tianshou.utils import tqdm_config, MovAvg, BaseLogger, LazyLogger
def onpolicy_trainer(
@ -25,6 +26,8 @@ def onpolicy_trainer(
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
stop_fn: Optional[Callable[[float], bool]] = None,
save_fn: Optional[Callable[[BasePolicy], None]] = None,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
resume_from_log: bool = False,
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
logger: BaseLogger = LazyLogger(),
verbose: bool = True,
@ -61,6 +64,11 @@ def onpolicy_trainer(
:param function save_fn: a hook called when the undiscounted average mean reward in
evaluation phase gets better, with the signature ``f(policy: BasePolicy) ->
None``.
:param function save_checkpoint_fn: a function to save the training process, with
the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``; you
can save whatever you want.
:param bool resume_from_log: resume env_step/gradient_step and other metadata from
the existing tensorboard log. Default to False.
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
bool``, receives the average undiscounted returns of the testing result,
returns a boolean which indicates whether reaching the goal.
@ -81,18 +89,24 @@ def onpolicy_trainer(
Only either one of step_per_collect and episode_per_collect can be specified.
"""
env_step, gradient_step = 0, 0
if save_fn:
warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.")
start_epoch, env_step, gradient_step = 0, 0, 0
if resume_from_log:
start_epoch, env_step, gradient_step = logger.restore_data()
last_rew, last_len = 0.0, 0
stat: Dict[str, MovAvg] = defaultdict(MovAvg)
start_time = time.time()
train_collector.reset_stat()
test_collector.reset_stat()
test_in_train = test_in_train and train_collector.policy == policy
test_result = test_episode(policy, test_collector, test_fn, 0, episode_per_test,
logger, env_step, reward_metric)
best_epoch = 0
test_result = test_episode(policy, test_collector, test_fn, start_epoch,
episode_per_test, logger, env_step, reward_metric)
best_epoch = start_epoch
best_reward, best_reward_std = test_result["rew"], test_result["rew_std"]
for epoch in range(1, 1 + max_epoch):
for epoch in range(1 + start_epoch, 1 + max_epoch):
# train
policy.train()
with tqdm.tqdm(
@ -125,6 +139,8 @@ def onpolicy_trainer(
if stop_fn(test_result["rew"]):
if save_fn:
save_fn(policy)
logger.save_data(
epoch, env_step, gradient_step, save_checkpoint_fn)
t.set_postfix(**data)
return gather_info(
start_time, train_collector, test_collector,
@ -150,15 +166,14 @@ def onpolicy_trainer(
test_result = test_episode(policy, test_collector, test_fn, epoch,
episode_per_test, logger, env_step, reward_metric)
rew, rew_std = test_result["rew"], test_result["rew_std"]
if best_epoch == -1 or best_reward < rew:
best_reward, best_reward_std = rew, rew_std
best_epoch = epoch
if best_epoch < 0 or best_reward < rew:
best_epoch, best_reward, best_reward_std = epoch, rew, rew_std
if save_fn:
save_fn(policy)
logger.save_data(epoch, env_step, gradient_step, save_checkpoint_fn)
if verbose:
print(
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, train_collector, test_collector,

View File

@ -1,8 +1,12 @@
import numpy as np
from numbers import Number
from typing import Any, Union
from abc import ABC, abstractmethod
from torch.utils.tensorboard import SummaryWriter
from typing import Any, Tuple, Union, Callable, Optional
from tensorboard.backend.event_processing import event_accumulator
WRITE_TYPE = Union[int, Number, np.number, np.ndarray]
class BaseLogger(ABC):
@ -13,9 +17,7 @@ class BaseLogger(ABC):
self.writer = writer
@abstractmethod
def write(
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
) -> None:
def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
"""Specify how the writer is used to log data.
:param str key: namespace which the input data tuple belongs to.
@ -51,6 +53,33 @@ class BaseLogger(ABC):
"""
pass
def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
) -> None:
"""Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
:param int epoch: the epoch in trainer.
:param int env_step: the env_step in trainer.
:param int gradient_step: the gradient_step in trainer.
:param function save_checkpoint_fn: a hook defined by user, see trainer
documentation for detail.
"""
pass
def restore_data(self) -> Tuple[int, int, int]:
"""Return the metadata from existing log.
If it finds nothing or an error occurs during the recover process, it will
return the default parameters.
:return: epoch, env_step, gradient_step.
"""
pass
class BasicLogger(BaseLogger):
"""A loggger that relies on tensorboard SummaryWriter by default to visualize \
@ -62,6 +91,8 @@ class BasicLogger(BaseLogger):
:param int train_interval: the log interval in log_train_data(). Default to 1.
:param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data(). Default to 1000.
:param int save_interval: the save interval in save_data(). Default to 1 (save at
the end of each epoch).
"""
def __init__(
@ -70,18 +101,19 @@ class BasicLogger(BaseLogger):
train_interval: int = 1,
test_interval: int = 1,
update_interval: int = 1000,
save_interval: int = 1,
) -> None:
super().__init__(writer)
self.train_interval = train_interval
self.test_interval = test_interval
self.update_interval = update_interval
self.save_interval = save_interval
self.last_log_train_step = -1
self.last_log_test_step = -1
self.last_log_update_step = -1
self.last_save_step = -1
def write(
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
) -> None:
def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
self.writer.add_scalar(key, y, global_step=x)
def log_train_data(self, collect_result: dict, step: int) -> None:
@ -133,6 +165,39 @@ class BasicLogger(BaseLogger):
self.write(k, step, v)
self.last_log_update_step = step
def save_data(
self,
epoch: int,
env_step: int,
gradient_step: int,
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
) -> None:
if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
self.last_save_step = epoch
save_checkpoint_fn(epoch, env_step, gradient_step)
self.write("save/epoch", epoch, epoch)
self.write("save/env_step", env_step, env_step)
self.write("save/gradient_step", gradient_step, gradient_step)
def restore_data(self) -> Tuple[int, int, int]:
ea = event_accumulator.EventAccumulator(self.writer.log_dir)
ea.Reload()
try: # epoch / gradient_step
epoch = ea.scalars.Items("save/epoch")[-1].step
self.last_save_step = self.last_log_test_step = epoch
gradient_step = ea.scalars.Items("save/gradient_step")[-1].step
self.last_log_update_step = gradient_step
except KeyError:
epoch, gradient_step = 0, 0
try: # offline trainer doesn't have env_step
env_step = ea.scalars.Items("save/env_step")[-1].step
self.last_log_train_step = env_step
except KeyError:
env_step = 0
return epoch, env_step, gradient_step
class LazyLogger(BasicLogger):
"""A loggger that does nothing. Used as the placeholder in trainer."""
@ -140,8 +205,6 @@ class LazyLogger(BasicLogger):
def __init__(self) -> None:
super().__init__(None) # type: ignore
def write(
self, key: str, x: int, y: Union[Number, np.number, np.ndarray], **kwargs: Any
) -> None:
def write(self, key: str, x: int, y: WRITE_TYPE, **kwargs: Any) -> None:
"""The LazyLogger writes nothing."""
pass