Fix save_checkpoint_fn return value (#659)
- Fix save_checkpoint_fn return value to checkpoint_path;
- Fix wrong link in doc;
- Fix an off-by-one bug in trainer iterator.
This commit is contained in:
parent 6ad5b520fa
commit 5ecea2402e
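For reference, the updated checkpoint hook contract looks as follows (a minimal sketch assuming the usual test-script setup where `log_path`, `policy`, and `optim` already exist; not the verbatim library code):

```python
import os

import torch


def save_checkpoint_fn(epoch, env_step, gradient_step):
    # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
    ckpt_path = os.path.join(log_path, "checkpoint.pth")
    torch.save({"model": policy.state_dict(), "optim": optim.state_dict()}, ckpt_path)
    # the hook now returns the path it wrote, so loggers such as WandbLogger
    # can record and upload the checkpoint file
    return ckpt_path
```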
@@ -48,7 +48,7 @@ And to successfully resume from a checkpoint:
 1. Load everything needed in the training process **before trainer initialization**, i.e., policy, optim, buffer;
 2. Set ``resume_from_log=True`` with trainer;

-We provide an example to show how these steps work: checkout `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_ or `test_il_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_il_bcq.py>`_ by running
+We provide an example to show how these steps work: checkout `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_ or `test_discrete_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/offline/test_discrete_bcq.py>`_ by running

 .. code-block:: console

@@ -192,7 +192,7 @@ def test_dqn(args=get_args()):

     def save_checkpoint_fn(epoch, env_step, gradient_step):
         # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
-        ckpt_path = os.path.join(log_path, "checkpoint.pth")
+        ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
         torch.save({"model": policy.state_dict()}, ckpt_path)
         return ckpt_path

@@ -222,7 +222,7 @@ def test_ppo(args=get_args()):

     def save_checkpoint_fn(epoch, env_step, gradient_step):
         # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
-        ckpt_path = os.path.join(log_path, "checkpoint.pth")
+        ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
         torch.save({"model": policy.state_dict()}, ckpt_path)
         return ckpt_path

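To complement the snippets above, resuming then loads the saved state **before trainer initialization** and passes ``resume_from_log=True`` (a condensed sketch drawn from the test scripts changed below; `log_path`, `policy`, `optim`, and `args` are assumed to exist):

```python
import os

import torch

ckpt_path = os.path.join(log_path, "checkpoint.pth")
if os.path.exists(ckpt_path):
    # restore the training state written by save_checkpoint_fn
    checkpoint = torch.load(ckpt_path, map_location=args.device)
    policy.load_state_dict(checkpoint["model"])
    optim.load_state_dict(checkpoint["optim"])
```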
@@ -117,7 +117,7 @@ def test_ppo(args=get_args()):
         dual_clip=args.dual_clip,
         value_clip=args.value_clip,
         gae_lambda=args.gae_lambda,
-        action_space=env.action_space
+        action_space=env.action_space,
     )
     # collector
     train_collector = Collector(
@@ -125,33 +125,37 @@ def test_ppo(args=get_args()):
     )
     test_collector = Collector(policy, test_envs)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'ppo')
+    log_path = os.path.join(args.logdir, args.task, "ppo")
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer, save_interval=args.save_interval)

     def save_best_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
         return mean_rewards >= args.reward_threshold

     def save_checkpoint_fn(epoch, env_step, gradient_step):
         # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+        ckpt_path = os.path.join(log_path, "checkpoint.pth")
+        # Example: saving by epoch num
+        # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
         torch.save(
             {
-                'model': policy.state_dict(),
-                'optim': optim.state_dict(),
-            }, os.path.join(log_path, 'checkpoint.pth')
+                "model": policy.state_dict(),
+                "optim": optim.state_dict(),
+            }, ckpt_path
         )
+        return ckpt_path

     if args.resume:
         # load from existing checkpoint
         print(f"Loading agent under {log_path}")
-        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
+        ckpt_path = os.path.join(log_path, "checkpoint.pth")
         if os.path.exists(ckpt_path):
             checkpoint = torch.load(ckpt_path, map_location=args.device)
-            policy.load_state_dict(checkpoint['model'])
-            optim.load_state_dict(checkpoint['optim'])
+            policy.load_state_dict(checkpoint["model"])
+            optim.load_state_dict(checkpoint["optim"])
             print("Successfully restore policy and optim.")
         else:
             print("Fail to restore policy and optim.")
@@ -171,7 +175,7 @@ def test_ppo(args=get_args()):
         save_best_fn=save_best_fn,
         logger=logger,
         resume_from_log=args.resume,
-        save_checkpoint_fn=save_checkpoint_fn
+        save_checkpoint_fn=save_checkpoint_fn,
     )

     for epoch, epoch_stat, info in trainer:
@@ -181,7 +185,7 @@ def test_ppo(args=get_args()):

     assert stop_fn(info["best_reward"])

-    if __name__ == '__main__':
+    if __name__ == "__main__":
         pprint.pprint(info)
         # Let's watch its performance!
         env = gym.make(args.task)
@@ -197,5 +201,5 @@ def test_ppo_resume(args=get_args()):
     test_ppo(args)


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_ppo()
@@ -85,7 +85,7 @@ def test_c51(args=get_args()):
         hidden_sizes=args.hidden_sizes,
         device=args.device,
         softmax=True,
-        num_atoms=args.num_atoms
+        num_atoms=args.num_atoms,
     )
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = C51Policy(
@@ -96,7 +96,7 @@ def test_c51(args=get_args()):
         args.v_min,
         args.v_max,
         args.n_step,
-        target_update_freq=args.target_update_freq
+        target_update_freq=args.target_update_freq,
     ).to(args.device)
     # buffer
     if args.prioritized_replay:
@@ -104,7 +104,7 @@ def test_c51(args=get_args()):
             args.buffer_size,
             buffer_num=len(train_envs),
             alpha=args.alpha,
-            beta=args.beta
+            beta=args.beta,
         )
     else:
         buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
@@ -114,12 +114,12 @@ def test_c51(args=get_args()):
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'c51')
+    log_path = os.path.join(args.logdir, args.task, "c51")
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer, save_interval=args.save_interval)

     def save_best_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
         return mean_rewards >= args.reward_threshold
@@ -140,29 +140,31 @@ def test_c51(args=get_args()):

     def save_checkpoint_fn(epoch, env_step, gradient_step):
         # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+        ckpt_path = os.path.join(log_path, "checkpoint.pth")
+        # Example: saving by epoch num
+        # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
         torch.save(
             {
-                'model': policy.state_dict(),
-                'optim': optim.state_dict(),
-            }, os.path.join(log_path, 'checkpoint.pth')
-        )
-        pickle.dump(
-            train_collector.buffer,
-            open(os.path.join(log_path, 'train_buffer.pkl'), "wb")
+                "model": policy.state_dict(),
+                "optim": optim.state_dict(),
+            }, ckpt_path
         )
+        buffer_path = os.path.join(log_path, "train_buffer.pkl")
+        pickle.dump(train_collector.buffer, open(buffer_path, "wb"))
+        return ckpt_path

     if args.resume:
         # load from existing checkpoint
         print(f"Loading agent under {log_path}")
-        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
+        ckpt_path = os.path.join(log_path, "checkpoint.pth")
         if os.path.exists(ckpt_path):
             checkpoint = torch.load(ckpt_path, map_location=args.device)
-            policy.load_state_dict(checkpoint['model'])
-            policy.optim.load_state_dict(checkpoint['optim'])
+            policy.load_state_dict(checkpoint["model"])
+            policy.optim.load_state_dict(checkpoint["optim"])
             print("Successfully restore policy and optim.")
         else:
             print("Fail to restore policy and optim.")
-        buffer_path = os.path.join(log_path, 'train_buffer.pkl')
+        buffer_path = os.path.join(log_path, "train_buffer.pkl")
         if os.path.exists(buffer_path):
             train_collector.buffer = pickle.load(open(buffer_path, "rb"))
             print("Successfully restore buffer.")
@@ -186,11 +188,11 @@ def test_c51(args=get_args()):
         save_best_fn=save_best_fn,
         logger=logger,
         resume_from_log=args.resume,
-        save_checkpoint_fn=save_checkpoint_fn
+        save_checkpoint_fn=save_checkpoint_fn,
     )
-    assert stop_fn(result['best_reward'])
+    assert stop_fn(result["best_reward"])

-    if __name__ == '__main__':
+    if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
@@ -214,5 +216,5 @@ def test_pc51(args=get_args()):
     test_c51(args)


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_c51(get_args())
@@ -98,7 +98,7 @@ def test_rainbow(args=get_args()):
             "linear_layer": noisy_linear
         }, {
             "linear_layer": noisy_linear
-        })
+        }),
     )
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = RainbowPolicy(
@@ -109,7 +109,7 @@ def test_rainbow(args=get_args()):
         args.v_min,
         args.v_max,
         args.n_step,
-        target_update_freq=args.target_update_freq
+        target_update_freq=args.target_update_freq,
     ).to(args.device)
     # buffer
     if args.prioritized_replay:
@@ -118,7 +118,7 @@ def test_rainbow(args=get_args()):
             buffer_num=len(train_envs),
             alpha=args.alpha,
             beta=args.beta,
-            weight_norm=True
+            weight_norm=True,
         )
     else:
         buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
@@ -128,12 +128,12 @@ def test_rainbow(args=get_args()):
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'rainbow')
+    log_path = os.path.join(args.logdir, args.task, "rainbow")
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer, save_interval=args.save_interval)

     def save_best_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
         return mean_rewards >= args.reward_threshold
@@ -164,21 +164,23 @@ def test_rainbow(args=get_args()):

     def save_checkpoint_fn(epoch, env_step, gradient_step):
         # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+        ckpt_path = os.path.join(log_path, "checkpoint.pth")
+        # Example: saving by epoch num
+        # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
         torch.save(
             {
-                'model': policy.state_dict(),
-                'optim': optim.state_dict(),
-            }, os.path.join(log_path, 'checkpoint.pth')
-        )
-        pickle.dump(
-            train_collector.buffer,
-            open(os.path.join(log_path, 'train_buffer.pkl'), "wb")
+                "model": policy.state_dict(),
+                "optim": optim.state_dict(),
+            }, ckpt_path
         )
+        buffer_path = os.path.join(log_path, "train_buffer.pkl")
+        pickle.dump(train_collector.buffer, open(buffer_path, "wb"))
+        return ckpt_path

     if args.resume:
         # load from existing checkpoint
         print(f"Loading agent under {log_path}")
-        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
+        ckpt_path = os.path.join(log_path, "checkpoint.pth")
         if os.path.exists(ckpt_path):
             checkpoint = torch.load(ckpt_path, map_location=args.device)
             policy.load_state_dict(checkpoint['model'])
@@ -186,7 +188,7 @@ def test_rainbow(args=get_args()):
             print("Successfully restore policy and optim.")
         else:
             print("Fail to restore policy and optim.")
-        buffer_path = os.path.join(log_path, 'train_buffer.pkl')
+        buffer_path = os.path.join(log_path, "train_buffer.pkl")
         if os.path.exists(buffer_path):
             train_collector.buffer = pickle.load(open(buffer_path, "rb"))
             print("Successfully restore buffer.")
@@ -210,11 +212,11 @@ def test_rainbow(args=get_args()):
         save_best_fn=save_best_fn,
         logger=logger,
         resume_from_log=args.resume,
-        save_checkpoint_fn=save_checkpoint_fn
+        save_checkpoint_fn=save_checkpoint_fn,
     )
-    assert stop_fn(result['best_reward'])
+    assert stop_fn(result["best_reward"])

-    if __name__ == '__main__':
+    if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
@@ -238,5 +240,5 @@ def test_prainbow(args=get_args()):
     test_rainbow(args)


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_rainbow(get_args())
@@ -25,7 +25,7 @@ else: # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, default="CartPole-v0")
-    parser.add_argument('--reward-threshold', type=float, default=None)
+    parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.001)
     parser.add_argument("--lr", type=float, default=3e-4)
@@ -37,7 +37,7 @@ def get_args():
     parser.add_argument("--epoch", type=int, default=5)
     parser.add_argument("--update-per-epoch", type=int, default=2000)
     parser.add_argument("--batch-size", type=int, default=64)
-    parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
+    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
     parser.add_argument("--test-num", type=int, default=100)
     parser.add_argument("--logdir", type=str, default="log")
     parser.add_argument("--render", type=float, default=0.)
@@ -104,33 +104,37 @@ def test_discrete_bcq(args=get_args()):
     # collector
     test_collector = Collector(policy, test_envs, exploration_noise=True)

-    log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
+    log_path = os.path.join(args.logdir, args.task, "discrete_bcq")
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer, save_interval=args.save_interval)

     def save_best_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
         return mean_rewards >= args.reward_threshold

     def save_checkpoint_fn(epoch, env_step, gradient_step):
         # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+        ckpt_path = os.path.join(log_path, "checkpoint.pth")
+        # Example: saving by epoch num
+        # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
         torch.save(
             {
-                'model': policy.state_dict(),
-                'optim': optim.state_dict(),
-            }, os.path.join(log_path, 'checkpoint.pth')
+                "model": policy.state_dict(),
+                "optim": optim.state_dict(),
+            }, ckpt_path
         )
+        return ckpt_path

     if args.resume:
         # load from existing checkpoint
         print(f"Loading agent under {log_path}")
-        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
+        ckpt_path = os.path.join(log_path, "checkpoint.pth")
         if os.path.exists(ckpt_path):
             checkpoint = torch.load(ckpt_path, map_location=args.device)
-            policy.load_state_dict(checkpoint['model'])
-            optim.load_state_dict(checkpoint['optim'])
+            policy.load_state_dict(checkpoint["model"])
+            optim.load_state_dict(checkpoint["optim"])
             print("Successfully restore policy and optim.")
         else:
             print("Fail to restore policy and optim.")
@@ -147,11 +151,11 @@ def test_discrete_bcq(args=get_args()):
         save_best_fn=save_best_fn,
         logger=logger,
         resume_from_log=args.resume,
-        save_checkpoint_fn=save_checkpoint_fn
+        save_checkpoint_fn=save_checkpoint_fn,
     )
-    assert stop_fn(result['best_reward'])
+    assert stop_fn(result["best_reward"])

-    if __name__ == '__main__':
+    if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
@@ -163,33 +163,37 @@ def test_gail(args=get_args()):
     )
     test_collector = Collector(policy, test_envs)
     # log
-    log_path = os.path.join(args.logdir, args.task, 'gail')
+    log_path = os.path.join(args.logdir, args.task, "gail")
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer, save_interval=args.save_interval)

     def save_best_fn(policy):
-        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

     def stop_fn(mean_rewards):
         return mean_rewards >= args.reward_threshold

     def save_checkpoint_fn(epoch, env_step, gradient_step):
         # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+        ckpt_path = os.path.join(log_path, "checkpoint.pth")
+        # Example: saving by epoch num
+        # ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
         torch.save(
             {
-                'model': policy.state_dict(),
-                'optim': optim.state_dict(),
-            }, os.path.join(log_path, 'checkpoint.pth')
+                "model": policy.state_dict(),
+                "optim": optim.state_dict(),
+            }, ckpt_path
         )
+        return ckpt_path

     if args.resume:
         # load from existing checkpoint
         print(f"Loading agent under {log_path}")
-        ckpt_path = os.path.join(log_path, 'checkpoint.pth')
+        ckpt_path = os.path.join(log_path, "checkpoint.pth")
         if os.path.exists(ckpt_path):
             checkpoint = torch.load(ckpt_path, map_location=args.device)
-            policy.load_state_dict(checkpoint['model'])
-            optim.load_state_dict(checkpoint['optim'])
+            policy.load_state_dict(checkpoint["model"])
+            optim.load_state_dict(checkpoint["optim"])
             print("Successfully restore policy and optim.")
         else:
             print("Fail to restore policy and optim.")
@@ -211,9 +215,9 @@ def test_gail(args=get_args()):
         resume_from_log=args.resume,
         save_checkpoint_fn=save_checkpoint_fn,
     )
-    assert stop_fn(result['best_reward'])
+    assert stop_fn(result["best_reward"])

-    if __name__ == '__main__':
+    if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
@@ -224,5 +228,5 @@ def test_gail(args=get_args()):
         print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_gail()
@@ -58,9 +58,9 @@ class BaseTrainer(ABC):
     :param function save_best_fn: a hook called when the undiscounted average mean
         reward in evaluation phase gets better, with the signature
         ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
-    :param function save_checkpoint_fn: a function to save training process, with
-        the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
-        you can save whatever you want.
+    :param function save_checkpoint_fn: a function to save training process and
+        return the saved checkpoint path, with the signature ``f(epoch: int,
+        env_step: int, gradient_step: int) -> str``; you can save whatever you want.
     :param bool resume_from_log: resume env_step/gradient_step and other metadata
         from existing tensorboard log. Default to False.
     :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
@@ -147,7 +147,7 @@ class BaseTrainer(ABC):
         test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
         save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
-        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+        save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
         resume_from_log: bool = False,
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
@@ -259,7 +259,7 @@ class BaseTrainer(ABC):
         if self.iter_num > 1:

             # iterator exhaustion check
-            if self.epoch >= self.max_epoch:
+            if self.epoch > self.max_epoch:
                 raise StopIteration

             # exit flag 1, when stop_fn succeeds in train_step or test_step
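The `>` comparison above is the off-by-one fix from the commit message: when the trainer is consumed as an iterator, the old `>=` check raised StopIteration one epoch early. A hedged sketch of the consuming loop, assuming `trainer` is an on-policy or off-policy trainer configured with `max_epoch` epochs as in test_ppo.py above:

```python
# each iteration runs one epoch and yields (epoch, epoch_stat, info);
# with the previous `>=` check the loop ended after max_epoch - 1 epochs
for epoch, epoch_stat, info in trainer:
    print(f"Epoch #{epoch}: best_reward={info['best_reward']}")
```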
@@ -31,10 +31,10 @@ class OfflineTrainer(BaseTrainer):
     :param function save_best_fn: a hook called when the undiscounted average mean
         reward in evaluation phase gets better, with the signature
         ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
-    :param function save_checkpoint_fn: a function to save training process,
-        with the signature ``f(epoch: int, env_step: int, gradient_step: int) ->
-        None``; you can save whatever you want. Because offline-RL doesn't have
-        env_step, the env_step is always 0 here.
+    :param function save_checkpoint_fn: a function to save training process and
+        return the saved checkpoint path, with the signature ``f(epoch: int,
+        env_step: int, gradient_step: int) -> str``; you can save whatever you want.
+        Because offline-RL doesn't have env_step, the env_step is always 0 here.
     :param bool resume_from_log: resume gradient_step and other metadata from
         existing tensorboard log. Default to False.
     :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
@@ -67,7 +67,7 @@ class OfflineTrainer(BaseTrainer):
         test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
         save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
-        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+        save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
         resume_from_log: bool = False,
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
@@ -40,9 +40,9 @@ class OffpolicyTrainer(BaseTrainer):
     :param function save_best_fn: a hook called when the undiscounted average mean
         reward in evaluation phase gets better, with the signature
         ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
-    :param function save_checkpoint_fn: a function to save training process, with
-        the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
-        you can save whatever you want.
+    :param function save_checkpoint_fn: a function to save training process and
+        return the saved checkpoint path, with the signature ``f(epoch: int,
+        env_step: int, gradient_step: int) -> str``; you can save whatever you want.
     :param bool resume_from_log: resume env_step/gradient_step and other metadata
         from existing tensorboard log. Default to False.
     :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
@@ -80,7 +80,7 @@ class OffpolicyTrainer(BaseTrainer):
         test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
         save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
-        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+        save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
         resume_from_log: bool = False,
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
@@ -42,9 +42,9 @@ class OnpolicyTrainer(BaseTrainer):
     :param function save_best_fn: a hook called when the undiscounted average mean
         reward in evaluation phase gets better, with the signature
         ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
-    :param function save_checkpoint_fn: a function to save training process,
-        with the signature ``f(epoch: int, env_step: int, gradient_step: int)
-        -> None``; you can save whatever you want.
+    :param function save_checkpoint_fn: a function to save training process and
+        return the saved checkpoint path, with the signature ``f(epoch: int,
+        env_step: int, gradient_step: int) -> str``; you can save whatever you want.
     :param bool resume_from_log: resume env_step/gradient_step and other metadata
         from existing tensorboard log. Default to False.
     :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
@@ -88,7 +88,7 @@ class OnpolicyTrainer(BaseTrainer):
         test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
         stop_fn: Optional[Callable[[float], bool]] = None,
         save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
-        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+        save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
         resume_from_log: bool = False,
         reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
         logger: BaseLogger = LazyLogger(),
@@ -94,7 +94,7 @@ class BaseLogger(ABC):
         epoch: int,
         env_step: int,
         gradient_step: int,
-        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+        save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
     ) -> None:
         """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.

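Since ``save_checkpoint_fn`` now returns the checkpoint path, a logger's ``save_data`` hook can record where the checkpoint landed. A minimal sketch of a custom logger using the new return value (the metadata keys mirror the WandbLogger fields shown below; the class itself is illustrative, not the library implementation):

```python
from typing import Callable, Optional


class PrintCheckpointLogger:
    """Toy logger: call the checkpoint hook and report the returned path."""

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
    ) -> None:
        if save_checkpoint_fn is None:
            return
        # the hook writes the checkpoint and tells us where it put it
        checkpoint_path = save_checkpoint_fn(epoch, env_step, gradient_step)
        print({
            "save/epoch": epoch,
            "save/env_step": env_step,
            "save/gradient_step": gradient_step,
            "checkpoint_path": str(checkpoint_path),
        })
```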
@@ -47,7 +47,7 @@ class TensorboardLogger(BaseLogger):
         epoch: int,
         env_step: int,
         gradient_step: int,
-        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+        save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
     ) -> None:
         if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
             self.last_save_step = epoch
@@ -97,7 +97,7 @@ class WandbLogger(BaseLogger):
         epoch: int,
         env_step: int,
         gradient_step: int,
-        save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+        save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
     ) -> None:
         """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.

@@ -118,7 +118,7 @@ class WandbLogger(BaseLogger):
                     "save/epoch": epoch,
                     "save/env_step": env_step,
                     "save/gradient_step": gradient_step,
-                    "checkpoint_path": str(checkpoint_path)
+                    "checkpoint_path": str(checkpoint_path),
                 }
             )
             checkpoint_artifact.add_file(str(checkpoint_path))
@@ -126,7 +126,7 @@ class WandbLogger(BaseLogger):

     def restore_data(self) -> Tuple[int, int, int]:
         checkpoint_artifact = self.wandb_run.use_artifact(  # type: ignore
-            'run_' + self.wandb_run.id + '_checkpoint:latest'  # type: ignore
+            f"run_{self.wandb_run.id}_checkpoint:latest"  # type: ignore
         )
         assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist"
