Fix save_checkpoint_fn return value (#659)
- Fix save_checkpoint_fn return value to checkpoint_path; - Fix wrong link in doc; - Fix an off-by-one bug in trainer iterator.
This commit is contained in:
parent
6ad5b520fa
commit
5ecea2402e
@ -48,7 +48,7 @@ And to successfully resume from a checkpoint:
|
||||
1. Load everything needed in the training process **before trainer initialization**, i.e., policy, optim, buffer;
|
||||
2. Set ``resume_from_log=True`` with trainer;
|
||||
|
||||
We provide an example to show how these steps work: checkout `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_ or `test_il_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_il_bcq.py>`_ by running
|
||||
We provide an example to show how these steps work: checkout `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_ or `test_discrete_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/offline/test_discrete_bcq.py>`_ by running
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
|
@ -192,7 +192,7 @@ def test_dqn(args=get_args()):
|
||||
|
||||
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
|
||||
torch.save({"model": policy.state_dict()}, ckpt_path)
|
||||
return ckpt_path
|
||||
|
||||
|
@ -222,7 +222,7 @@ def test_ppo(args=get_args()):
|
||||
|
||||
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
|
||||
torch.save({"model": policy.state_dict()}, ckpt_path)
|
||||
return ckpt_path
|
||||
|
||||
|
@ -117,7 +117,7 @@ def test_ppo(args=get_args()):
|
||||
dual_clip=args.dual_clip,
|
||||
value_clip=args.value_clip,
|
||||
gae_lambda=args.gae_lambda,
|
||||
action_space=env.action_space
|
||||
action_space=env.action_space,
|
||||
)
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
@ -125,33 +125,37 @@ def test_ppo(args=get_args()):
|
||||
)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'ppo')
|
||||
log_path = os.path.join(args.logdir, args.task, "ppo")
|
||||
writer = SummaryWriter(log_path)
|
||||
logger = TensorboardLogger(writer, save_interval=args.save_interval)
|
||||
|
||||
def save_best_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= args.reward_threshold
|
||||
|
||||
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
# Example: saving by epoch num
|
||||
# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
|
||||
torch.save(
|
||||
{
|
||||
'model': policy.state_dict(),
|
||||
'optim': optim.state_dict(),
|
||||
}, os.path.join(log_path, 'checkpoint.pth')
|
||||
"model": policy.state_dict(),
|
||||
"optim": optim.state_dict(),
|
||||
}, ckpt_path
|
||||
)
|
||||
return ckpt_path
|
||||
|
||||
if args.resume:
|
||||
# load from existing checkpoint
|
||||
print(f"Loading agent under {log_path}")
|
||||
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
if os.path.exists(ckpt_path):
|
||||
checkpoint = torch.load(ckpt_path, map_location=args.device)
|
||||
policy.load_state_dict(checkpoint['model'])
|
||||
optim.load_state_dict(checkpoint['optim'])
|
||||
policy.load_state_dict(checkpoint["model"])
|
||||
optim.load_state_dict(checkpoint["optim"])
|
||||
print("Successfully restore policy and optim.")
|
||||
else:
|
||||
print("Fail to restore policy and optim.")
|
||||
@ -171,7 +175,7 @@ def test_ppo(args=get_args()):
|
||||
save_best_fn=save_best_fn,
|
||||
logger=logger,
|
||||
resume_from_log=args.resume,
|
||||
save_checkpoint_fn=save_checkpoint_fn
|
||||
save_checkpoint_fn=save_checkpoint_fn,
|
||||
)
|
||||
|
||||
for epoch, epoch_stat, info in trainer:
|
||||
@ -181,7 +185,7 @@ def test_ppo(args=get_args()):
|
||||
|
||||
assert stop_fn(info["best_reward"])
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(info)
|
||||
# Let's watch its performance!
|
||||
env = gym.make(args.task)
|
||||
@ -197,5 +201,5 @@ def test_ppo_resume(args=get_args()):
|
||||
test_ppo(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_ppo()
|
||||
|
@ -85,7 +85,7 @@ def test_c51(args=get_args()):
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
device=args.device,
|
||||
softmax=True,
|
||||
num_atoms=args.num_atoms
|
||||
num_atoms=args.num_atoms,
|
||||
)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
policy = C51Policy(
|
||||
@ -96,7 +96,7 @@ def test_c51(args=get_args()):
|
||||
args.v_min,
|
||||
args.v_max,
|
||||
args.n_step,
|
||||
target_update_freq=args.target_update_freq
|
||||
target_update_freq=args.target_update_freq,
|
||||
).to(args.device)
|
||||
# buffer
|
||||
if args.prioritized_replay:
|
||||
@ -104,7 +104,7 @@ def test_c51(args=get_args()):
|
||||
args.buffer_size,
|
||||
buffer_num=len(train_envs),
|
||||
alpha=args.alpha,
|
||||
beta=args.beta
|
||||
beta=args.beta,
|
||||
)
|
||||
else:
|
||||
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
|
||||
@ -114,12 +114,12 @@ def test_c51(args=get_args()):
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'c51')
|
||||
log_path = os.path.join(args.logdir, args.task, "c51")
|
||||
writer = SummaryWriter(log_path)
|
||||
logger = TensorboardLogger(writer, save_interval=args.save_interval)
|
||||
|
||||
def save_best_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= args.reward_threshold
|
||||
@ -140,29 +140,31 @@ def test_c51(args=get_args()):
|
||||
|
||||
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
# Example: saving by epoch num
|
||||
# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
|
||||
torch.save(
|
||||
{
|
||||
'model': policy.state_dict(),
|
||||
'optim': optim.state_dict(),
|
||||
}, os.path.join(log_path, 'checkpoint.pth')
|
||||
)
|
||||
pickle.dump(
|
||||
train_collector.buffer,
|
||||
open(os.path.join(log_path, 'train_buffer.pkl'), "wb")
|
||||
"model": policy.state_dict(),
|
||||
"optim": optim.state_dict(),
|
||||
}, ckpt_path
|
||||
)
|
||||
buffer_path = os.path.join(log_path, "train_buffer.pkl")
|
||||
pickle.dump(train_collector.buffer, open(buffer_path, "wb"))
|
||||
return ckpt_path
|
||||
|
||||
if args.resume:
|
||||
# load from existing checkpoint
|
||||
print(f"Loading agent under {log_path}")
|
||||
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
if os.path.exists(ckpt_path):
|
||||
checkpoint = torch.load(ckpt_path, map_location=args.device)
|
||||
policy.load_state_dict(checkpoint['model'])
|
||||
policy.optim.load_state_dict(checkpoint['optim'])
|
||||
policy.load_state_dict(checkpoint["model"])
|
||||
policy.optim.load_state_dict(checkpoint["optim"])
|
||||
print("Successfully restore policy and optim.")
|
||||
else:
|
||||
print("Fail to restore policy and optim.")
|
||||
buffer_path = os.path.join(log_path, 'train_buffer.pkl')
|
||||
buffer_path = os.path.join(log_path, "train_buffer.pkl")
|
||||
if os.path.exists(buffer_path):
|
||||
train_collector.buffer = pickle.load(open(buffer_path, "rb"))
|
||||
print("Successfully restore buffer.")
|
||||
@ -186,11 +188,11 @@ def test_c51(args=get_args()):
|
||||
save_best_fn=save_best_fn,
|
||||
logger=logger,
|
||||
resume_from_log=args.resume,
|
||||
save_checkpoint_fn=save_checkpoint_fn
|
||||
save_checkpoint_fn=save_checkpoint_fn,
|
||||
)
|
||||
assert stop_fn(result['best_reward'])
|
||||
assert stop_fn(result["best_reward"])
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
env = gym.make(args.task)
|
||||
@ -214,5 +216,5 @@ def test_pc51(args=get_args()):
|
||||
test_c51(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_c51(get_args())
|
||||
|
@ -98,7 +98,7 @@ def test_rainbow(args=get_args()):
|
||||
"linear_layer": noisy_linear
|
||||
}, {
|
||||
"linear_layer": noisy_linear
|
||||
})
|
||||
}),
|
||||
)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
policy = RainbowPolicy(
|
||||
@ -109,7 +109,7 @@ def test_rainbow(args=get_args()):
|
||||
args.v_min,
|
||||
args.v_max,
|
||||
args.n_step,
|
||||
target_update_freq=args.target_update_freq
|
||||
target_update_freq=args.target_update_freq,
|
||||
).to(args.device)
|
||||
# buffer
|
||||
if args.prioritized_replay:
|
||||
@ -118,7 +118,7 @@ def test_rainbow(args=get_args()):
|
||||
buffer_num=len(train_envs),
|
||||
alpha=args.alpha,
|
||||
beta=args.beta,
|
||||
weight_norm=True
|
||||
weight_norm=True,
|
||||
)
|
||||
else:
|
||||
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
|
||||
@ -128,12 +128,12 @@ def test_rainbow(args=get_args()):
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'rainbow')
|
||||
log_path = os.path.join(args.logdir, args.task, "rainbow")
|
||||
writer = SummaryWriter(log_path)
|
||||
logger = TensorboardLogger(writer, save_interval=args.save_interval)
|
||||
|
||||
def save_best_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= args.reward_threshold
|
||||
@ -164,21 +164,23 @@ def test_rainbow(args=get_args()):
|
||||
|
||||
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
# Example: saving by epoch num
|
||||
# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
|
||||
torch.save(
|
||||
{
|
||||
'model': policy.state_dict(),
|
||||
'optim': optim.state_dict(),
|
||||
}, os.path.join(log_path, 'checkpoint.pth')
|
||||
)
|
||||
pickle.dump(
|
||||
train_collector.buffer,
|
||||
open(os.path.join(log_path, 'train_buffer.pkl'), "wb")
|
||||
"model": policy.state_dict(),
|
||||
"optim": optim.state_dict(),
|
||||
}, ckpt_path
|
||||
)
|
||||
buffer_path = os.path.join(log_path, "train_buffer.pkl")
|
||||
pickle.dump(train_collector.buffer, open(buffer_path, "wb"))
|
||||
return ckpt_path
|
||||
|
||||
if args.resume:
|
||||
# load from existing checkpoint
|
||||
print(f"Loading agent under {log_path}")
|
||||
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
if os.path.exists(ckpt_path):
|
||||
checkpoint = torch.load(ckpt_path, map_location=args.device)
|
||||
policy.load_state_dict(checkpoint['model'])
|
||||
@ -186,7 +188,7 @@ def test_rainbow(args=get_args()):
|
||||
print("Successfully restore policy and optim.")
|
||||
else:
|
||||
print("Fail to restore policy and optim.")
|
||||
buffer_path = os.path.join(log_path, 'train_buffer.pkl')
|
||||
buffer_path = os.path.join(log_path, "train_buffer.pkl")
|
||||
if os.path.exists(buffer_path):
|
||||
train_collector.buffer = pickle.load(open(buffer_path, "rb"))
|
||||
print("Successfully restore buffer.")
|
||||
@ -210,11 +212,11 @@ def test_rainbow(args=get_args()):
|
||||
save_best_fn=save_best_fn,
|
||||
logger=logger,
|
||||
resume_from_log=args.resume,
|
||||
save_checkpoint_fn=save_checkpoint_fn
|
||||
save_checkpoint_fn=save_checkpoint_fn,
|
||||
)
|
||||
assert stop_fn(result['best_reward'])
|
||||
assert stop_fn(result["best_reward"])
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
env = gym.make(args.task)
|
||||
@ -238,5 +240,5 @@ def test_prainbow(args=get_args()):
|
||||
test_rainbow(args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_rainbow(get_args())
|
||||
|
@ -25,7 +25,7 @@ else: # pytest
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="CartPole-v0")
|
||||
parser.add_argument('--reward-threshold', type=float, default=None)
|
||||
parser.add_argument("--reward-threshold", type=float, default=None)
|
||||
parser.add_argument("--seed", type=int, default=1626)
|
||||
parser.add_argument("--eps-test", type=float, default=0.001)
|
||||
parser.add_argument("--lr", type=float, default=3e-4)
|
||||
@ -37,7 +37,7 @@ def get_args():
|
||||
parser.add_argument("--epoch", type=int, default=5)
|
||||
parser.add_argument("--update-per-epoch", type=int, default=2000)
|
||||
parser.add_argument("--batch-size", type=int, default=64)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
|
||||
parser.add_argument("--test-num", type=int, default=100)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=0.)
|
||||
@ -104,33 +104,37 @@ def test_discrete_bcq(args=get_args()):
|
||||
# collector
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
|
||||
log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
|
||||
log_path = os.path.join(args.logdir, args.task, "discrete_bcq")
|
||||
writer = SummaryWriter(log_path)
|
||||
logger = TensorboardLogger(writer, save_interval=args.save_interval)
|
||||
|
||||
def save_best_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= args.reward_threshold
|
||||
|
||||
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
# Example: saving by epoch num
|
||||
# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
|
||||
torch.save(
|
||||
{
|
||||
'model': policy.state_dict(),
|
||||
'optim': optim.state_dict(),
|
||||
}, os.path.join(log_path, 'checkpoint.pth')
|
||||
"model": policy.state_dict(),
|
||||
"optim": optim.state_dict(),
|
||||
}, ckpt_path
|
||||
)
|
||||
return ckpt_path
|
||||
|
||||
if args.resume:
|
||||
# load from existing checkpoint
|
||||
print(f"Loading agent under {log_path}")
|
||||
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
if os.path.exists(ckpt_path):
|
||||
checkpoint = torch.load(ckpt_path, map_location=args.device)
|
||||
policy.load_state_dict(checkpoint['model'])
|
||||
optim.load_state_dict(checkpoint['optim'])
|
||||
policy.load_state_dict(checkpoint["model"])
|
||||
optim.load_state_dict(checkpoint["optim"])
|
||||
print("Successfully restore policy and optim.")
|
||||
else:
|
||||
print("Fail to restore policy and optim.")
|
||||
@ -147,11 +151,11 @@ def test_discrete_bcq(args=get_args()):
|
||||
save_best_fn=save_best_fn,
|
||||
logger=logger,
|
||||
resume_from_log=args.resume,
|
||||
save_checkpoint_fn=save_checkpoint_fn
|
||||
save_checkpoint_fn=save_checkpoint_fn,
|
||||
)
|
||||
assert stop_fn(result['best_reward'])
|
||||
assert stop_fn(result["best_reward"])
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
env = gym.make(args.task)
|
||||
|
@ -163,33 +163,37 @@ def test_gail(args=get_args()):
|
||||
)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'gail')
|
||||
log_path = os.path.join(args.logdir, args.task, "gail")
|
||||
writer = SummaryWriter(log_path)
|
||||
logger = TensorboardLogger(writer, save_interval=args.save_interval)
|
||||
|
||||
def save_best_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= args.reward_threshold
|
||||
|
||||
def save_checkpoint_fn(epoch, env_step, gradient_step):
|
||||
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
# Example: saving by epoch num
|
||||
# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
|
||||
torch.save(
|
||||
{
|
||||
'model': policy.state_dict(),
|
||||
'optim': optim.state_dict(),
|
||||
}, os.path.join(log_path, 'checkpoint.pth')
|
||||
"model": policy.state_dict(),
|
||||
"optim": optim.state_dict(),
|
||||
}, ckpt_path
|
||||
)
|
||||
return ckpt_path
|
||||
|
||||
if args.resume:
|
||||
# load from existing checkpoint
|
||||
print(f"Loading agent under {log_path}")
|
||||
ckpt_path = os.path.join(log_path, 'checkpoint.pth')
|
||||
ckpt_path = os.path.join(log_path, "checkpoint.pth")
|
||||
if os.path.exists(ckpt_path):
|
||||
checkpoint = torch.load(ckpt_path, map_location=args.device)
|
||||
policy.load_state_dict(checkpoint['model'])
|
||||
optim.load_state_dict(checkpoint['optim'])
|
||||
policy.load_state_dict(checkpoint["model"])
|
||||
optim.load_state_dict(checkpoint["optim"])
|
||||
print("Successfully restore policy and optim.")
|
||||
else:
|
||||
print("Fail to restore policy and optim.")
|
||||
@ -211,9 +215,9 @@ def test_gail(args=get_args()):
|
||||
resume_from_log=args.resume,
|
||||
save_checkpoint_fn=save_checkpoint_fn,
|
||||
)
|
||||
assert stop_fn(result['best_reward'])
|
||||
assert stop_fn(result["best_reward"])
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
env = gym.make(args.task)
|
||||
@ -224,5 +228,5 @@ def test_gail(args=get_args()):
|
||||
print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_gail()
|
||||
|
@ -58,9 +58,9 @@ class BaseTrainer(ABC):
|
||||
:param function save_best_fn: a hook called when the undiscounted average mean
|
||||
reward in evaluation phase gets better, with the signature
|
||||
``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
|
||||
:param function save_checkpoint_fn: a function to save training process, with
|
||||
the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
|
||||
you can save whatever you want.
|
||||
:param function save_checkpoint_fn: a function to save training process and
|
||||
return the saved checkpoint path, with the signature ``f(epoch: int,
|
||||
env_step: int, gradient_step: int) -> str``; you can save whatever you want.
|
||||
:param bool resume_from_log: resume env_step/gradient_step and other metadata
|
||||
from existing tensorboard log. Default to False.
|
||||
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||
@ -147,7 +147,7 @@ class BaseTrainer(ABC):
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
|
||||
resume_from_log: bool = False,
|
||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||
logger: BaseLogger = LazyLogger(),
|
||||
@ -259,7 +259,7 @@ class BaseTrainer(ABC):
|
||||
if self.iter_num > 1:
|
||||
|
||||
# iterator exhaustion check
|
||||
if self.epoch >= self.max_epoch:
|
||||
if self.epoch > self.max_epoch:
|
||||
raise StopIteration
|
||||
|
||||
# exit flag 1, when stop_fn succeeds in train_step or test_step
|
||||
|
@ -31,10 +31,10 @@ class OfflineTrainer(BaseTrainer):
|
||||
:param function save_best_fn: a hook called when the undiscounted average mean
|
||||
reward in evaluation phase gets better, with the signature
|
||||
``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
|
||||
:param function save_checkpoint_fn: a function to save training process,
|
||||
with the signature ``f(epoch: int, env_step: int, gradient_step: int) ->
|
||||
None``; you can save whatever you want. Because offline-RL doesn't have
|
||||
env_step, the env_step is always 0 here.
|
||||
:param function save_checkpoint_fn: a function to save training process and
|
||||
return the saved checkpoint path, with the signature ``f(epoch: int,
|
||||
env_step: int, gradient_step: int) -> str``; you can save whatever you want.
|
||||
Because offline-RL doesn't have env_step, the env_step is always 0 here.
|
||||
:param bool resume_from_log: resume gradient_step and other metadata from
|
||||
existing tensorboard log. Default to False.
|
||||
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||
@ -67,7 +67,7 @@ class OfflineTrainer(BaseTrainer):
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
|
||||
resume_from_log: bool = False,
|
||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||
logger: BaseLogger = LazyLogger(),
|
||||
|
@ -40,9 +40,9 @@ class OffpolicyTrainer(BaseTrainer):
|
||||
:param function save_best_fn: a hook called when the undiscounted average mean
|
||||
reward in evaluation phase gets better, with the signature
|
||||
``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
|
||||
:param function save_checkpoint_fn: a function to save training process, with
|
||||
the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
|
||||
you can save whatever you want.
|
||||
:param function save_checkpoint_fn: a function to save training process and
|
||||
return the saved checkpoint path, with the signature ``f(epoch: int,
|
||||
env_step: int, gradient_step: int) -> str``; you can save whatever you want.
|
||||
:param bool resume_from_log: resume env_step/gradient_step and other metadata
|
||||
from existing tensorboard log. Default to False.
|
||||
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||
@ -80,7 +80,7 @@ class OffpolicyTrainer(BaseTrainer):
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
|
||||
resume_from_log: bool = False,
|
||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||
logger: BaseLogger = LazyLogger(),
|
||||
|
@ -42,9 +42,9 @@ class OnpolicyTrainer(BaseTrainer):
|
||||
:param function save_best_fn: a hook called when the undiscounted average mean
|
||||
reward in evaluation phase gets better, with the signature
|
||||
``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
|
||||
:param function save_checkpoint_fn: a function to save training process,
|
||||
with the signature ``f(epoch: int, env_step: int, gradient_step: int)
|
||||
-> None``; you can save whatever you want.
|
||||
:param function save_checkpoint_fn: a function to save training process and
|
||||
return the saved checkpoint path, with the signature ``f(epoch: int,
|
||||
env_step: int, gradient_step: int) -> str``; you can save whatever you want.
|
||||
:param bool resume_from_log: resume env_step/gradient_step and other metadata
|
||||
from existing tensorboard log. Default to False.
|
||||
:param function stop_fn: a function with signature ``f(mean_rewards: float) ->
|
||||
@ -88,7 +88,7 @@ class OnpolicyTrainer(BaseTrainer):
|
||||
test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
|
||||
stop_fn: Optional[Callable[[float], bool]] = None,
|
||||
save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
|
||||
resume_from_log: bool = False,
|
||||
reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
|
||||
logger: BaseLogger = LazyLogger(),
|
||||
|
@ -94,7 +94,7 @@ class BaseLogger(ABC):
|
||||
epoch: int,
|
||||
env_step: int,
|
||||
gradient_step: int,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
|
||||
) -> None:
|
||||
"""Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
|
||||
|
||||
|
@ -47,7 +47,7 @@ class TensorboardLogger(BaseLogger):
|
||||
epoch: int,
|
||||
env_step: int,
|
||||
gradient_step: int,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
|
||||
) -> None:
|
||||
if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
|
||||
self.last_save_step = epoch
|
||||
|
@ -97,7 +97,7 @@ class WandbLogger(BaseLogger):
|
||||
epoch: int,
|
||||
env_step: int,
|
||||
gradient_step: int,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
|
||||
save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
|
||||
) -> None:
|
||||
"""Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
|
||||
|
||||
@ -118,7 +118,7 @@ class WandbLogger(BaseLogger):
|
||||
"save/epoch": epoch,
|
||||
"save/env_step": env_step,
|
||||
"save/gradient_step": gradient_step,
|
||||
"checkpoint_path": str(checkpoint_path)
|
||||
"checkpoint_path": str(checkpoint_path),
|
||||
}
|
||||
)
|
||||
checkpoint_artifact.add_file(str(checkpoint_path))
|
||||
@ -126,7 +126,7 @@ class WandbLogger(BaseLogger):
|
||||
|
||||
def restore_data(self) -> Tuple[int, int, int]:
|
||||
checkpoint_artifact = self.wandb_run.use_artifact( # type: ignore
|
||||
'run_' + self.wandb_run.id + '_checkpoint:latest' # type: ignore
|
||||
f"run_{self.wandb_run.id}_checkpoint:latest" # type: ignore
|
||||
)
|
||||
assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist"
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user