Fix save_checkpoint_fn return value (#659)

- Fix save_checkpoint_fn so that it returns checkpoint_path (see the sketch below);
- Fix a wrong link in the docs;
- Fix an off-by-one bug in the trainer iterator.
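For reference, a minimal sketch of the callback shape this commit settles on: the function writes a checkpoint and returns the path it wrote, matching the new Callable[[int, int, int], str] annotation used throughout the trainers and loggers below. The log_path, policy, and optim names here are placeholders standing in for the objects the test scripts build.

import os

import torch
import torch.nn as nn

log_path = "log"  # placeholder; the tests build this from args.logdir/args.task
policy = nn.Linear(4, 2)  # stand-in for a real tianshou policy network
optim = torch.optim.Adam(policy.parameters(), lr=1e-3)


def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
    # Save whatever you need, then return the checkpoint path so the
    # trainer/logger (e.g. WandbLogger in this commit) can record it.
    os.makedirs(log_path, exist_ok=True)
    ckpt_path = os.path.join(log_path, "checkpoint.pth")
    torch.save({"model": policy.state_dict(), "optim": optim.state_dict()}, ckpt_path)
    return ckpt_path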
Jiayi Weng 2022-06-02 12:07:07 -05:00 committed by GitHub
parent 6ad5b520fa
commit 5ecea2402e
15 changed files with 116 additions and 100 deletions

View File

@@ -48,7 +48,7 @@ And to successfully resume from a checkpoint:
 1. Load everything needed in the training process **before trainer initialization**, i.e., policy, optim, buffer;
 2. Set ``resume_from_log=True`` with trainer;
-We provide an example to show how these steps work: checkout `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_ or `test_il_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_il_bcq.py>`_ by running
+We provide an example to show how these steps work: checkout `test_c51.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_c51.py>`_, `test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/continuous/test_ppo.py>`_ or `test_discrete_bcq.py <https://github.com/thu-ml/tianshou/blob/master/test/offline/test_discrete_bcq.py>`_ by running
 .. code-block:: console

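As a rough illustration of those two steps (hypothetical names, condensed from the test scripts changed below; the real scripts first build the policy, optimizer, and collectors from the command-line args):

import os
import pickle

import torch
import torch.nn as nn

log_path = "log/CartPole-v0/c51"   # hypothetical; mirrors the paths used in the tests
policy = nn.Linear(4, 2)           # stand-in for the real policy
optim = torch.optim.Adam(policy.parameters(), lr=1e-3)

# Step 1: restore policy/optim (and, for off-policy methods, the replay buffer)
# *before* the trainer is constructed.
ckpt_path = os.path.join(log_path, "checkpoint.pth")
if os.path.exists(ckpt_path):
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    policy.load_state_dict(checkpoint["model"])
    optim.load_state_dict(checkpoint["optim"])

buffer_path = os.path.join(log_path, "train_buffer.pkl")
if os.path.exists(buffer_path):
    with open(buffer_path, "rb") as f:
        train_buffer = pickle.load(f)

# Step 2: construct the trainer with resume_from_log=True (and the same
# save_checkpoint_fn), as the test scripts below do.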
View File

@@ -192,7 +192,7 @@ def test_dqn(args=get_args()):
 def save_checkpoint_fn(epoch, env_step, gradient_step):
 # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
-ckpt_path = os.path.join(log_path, "checkpoint.pth")
+ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
 torch.save({"model": policy.state_dict()}, ckpt_path)
 return ckpt_path

View File

@@ -222,7 +222,7 @@ def test_ppo(args=get_args()):
 def save_checkpoint_fn(epoch, env_step, gradient_step):
 # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
-ckpt_path = os.path.join(log_path, "checkpoint.pth")
+ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
 torch.save({"model": policy.state_dict()}, ckpt_path)
 return ckpt_path

View File

@@ -117,7 +117,7 @@ def test_ppo(args=get_args()):
 dual_clip=args.dual_clip,
 value_clip=args.value_clip,
 gae_lambda=args.gae_lambda,
-action_space=env.action_space
+action_space=env.action_space,
 )
 # collector
 train_collector = Collector(
@@ -125,33 +125,37 @@ def test_ppo(args=get_args()):
 )
 test_collector = Collector(policy, test_envs)
 # log
-log_path = os.path.join(args.logdir, args.task, 'ppo')
+log_path = os.path.join(args.logdir, args.task, "ppo")
 writer = SummaryWriter(log_path)
 logger = TensorboardLogger(writer, save_interval=args.save_interval)
 def save_best_fn(policy):
-torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
 def stop_fn(mean_rewards):
 return mean_rewards >= args.reward_threshold
 def save_checkpoint_fn(epoch, env_step, gradient_step):
 # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+ckpt_path = os.path.join(log_path, "checkpoint.pth")
+# Example: saving by epoch num
+# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
 torch.save(
 {
-'model': policy.state_dict(),
-'optim': optim.state_dict(),
-}, os.path.join(log_path, 'checkpoint.pth')
+"model": policy.state_dict(),
+"optim": optim.state_dict(),
+}, ckpt_path
 )
+return ckpt_path
 if args.resume:
 # load from existing checkpoint
 print(f"Loading agent under {log_path}")
-ckpt_path = os.path.join(log_path, 'checkpoint.pth')
+ckpt_path = os.path.join(log_path, "checkpoint.pth")
 if os.path.exists(ckpt_path):
 checkpoint = torch.load(ckpt_path, map_location=args.device)
-policy.load_state_dict(checkpoint['model'])
-optim.load_state_dict(checkpoint['optim'])
+policy.load_state_dict(checkpoint["model"])
+optim.load_state_dict(checkpoint["optim"])
 print("Successfully restore policy and optim.")
 else:
 print("Fail to restore policy and optim.")
@@ -171,7 +175,7 @@ def test_ppo(args=get_args()):
 save_best_fn=save_best_fn,
 logger=logger,
 resume_from_log=args.resume,
-save_checkpoint_fn=save_checkpoint_fn
+save_checkpoint_fn=save_checkpoint_fn,
 )
 for epoch, epoch_stat, info in trainer:
@@ -181,7 +185,7 @@ def test_ppo(args=get_args()):
 assert stop_fn(info["best_reward"])
-if __name__ == '__main__':
+if __name__ == "__main__":
 pprint.pprint(info)
 # Let's watch its performance!
 env = gym.make(args.task)
@@ -197,5 +201,5 @@ def test_ppo_resume(args=get_args()):
 test_ppo(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
 test_ppo()

View File

@@ -85,7 +85,7 @@ def test_c51(args=get_args()):
 hidden_sizes=args.hidden_sizes,
 device=args.device,
 softmax=True,
-num_atoms=args.num_atoms
+num_atoms=args.num_atoms,
 )
 optim = torch.optim.Adam(net.parameters(), lr=args.lr)
 policy = C51Policy(
@@ -96,7 +96,7 @@ def test_c51(args=get_args()):
 args.v_min,
 args.v_max,
 args.n_step,
-target_update_freq=args.target_update_freq
+target_update_freq=args.target_update_freq,
 ).to(args.device)
 # buffer
 if args.prioritized_replay:
@@ -104,7 +104,7 @@ def test_c51(args=get_args()):
 args.buffer_size,
 buffer_num=len(train_envs),
 alpha=args.alpha,
-beta=args.beta
+beta=args.beta,
 )
 else:
 buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
@@ -114,12 +114,12 @@ def test_c51(args=get_args()):
 # policy.set_eps(1)
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
-log_path = os.path.join(args.logdir, args.task, 'c51')
+log_path = os.path.join(args.logdir, args.task, "c51")
 writer = SummaryWriter(log_path)
 logger = TensorboardLogger(writer, save_interval=args.save_interval)
 def save_best_fn(policy):
-torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
 def stop_fn(mean_rewards):
 return mean_rewards >= args.reward_threshold
@@ -140,29 +140,31 @@ def test_c51(args=get_args()):
 def save_checkpoint_fn(epoch, env_step, gradient_step):
 # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+ckpt_path = os.path.join(log_path, "checkpoint.pth")
+# Example: saving by epoch num
+# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
 torch.save(
 {
-'model': policy.state_dict(),
-'optim': optim.state_dict(),
-}, os.path.join(log_path, 'checkpoint.pth')
-)
-pickle.dump(
-train_collector.buffer,
-open(os.path.join(log_path, 'train_buffer.pkl'), "wb")
+"model": policy.state_dict(),
+"optim": optim.state_dict(),
+}, ckpt_path
 )
+buffer_path = os.path.join(log_path, "train_buffer.pkl")
+pickle.dump(train_collector.buffer, open(buffer_path, "wb"))
+return ckpt_path
 if args.resume:
 # load from existing checkpoint
 print(f"Loading agent under {log_path}")
-ckpt_path = os.path.join(log_path, 'checkpoint.pth')
+ckpt_path = os.path.join(log_path, "checkpoint.pth")
 if os.path.exists(ckpt_path):
 checkpoint = torch.load(ckpt_path, map_location=args.device)
-policy.load_state_dict(checkpoint['model'])
-policy.optim.load_state_dict(checkpoint['optim'])
+policy.load_state_dict(checkpoint["model"])
+policy.optim.load_state_dict(checkpoint["optim"])
 print("Successfully restore policy and optim.")
 else:
 print("Fail to restore policy and optim.")
-buffer_path = os.path.join(log_path, 'train_buffer.pkl')
+buffer_path = os.path.join(log_path, "train_buffer.pkl")
 if os.path.exists(buffer_path):
 train_collector.buffer = pickle.load(open(buffer_path, "rb"))
 print("Successfully restore buffer.")
@@ -186,11 +188,11 @@ def test_c51(args=get_args()):
 save_best_fn=save_best_fn,
 logger=logger,
 resume_from_log=args.resume,
-save_checkpoint_fn=save_checkpoint_fn
+save_checkpoint_fn=save_checkpoint_fn,
 )
-assert stop_fn(result['best_reward'])
+assert stop_fn(result["best_reward"])
-if __name__ == '__main__':
+if __name__ == "__main__":
 pprint.pprint(result)
 # Let's watch its performance!
 env = gym.make(args.task)
@@ -214,5 +216,5 @@ def test_pc51(args=get_args()):
 test_c51(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
 test_c51(get_args())

View File

@@ -98,7 +98,7 @@ def test_rainbow(args=get_args()):
 "linear_layer": noisy_linear
 }, {
 "linear_layer": noisy_linear
-})
+}),
 )
 optim = torch.optim.Adam(net.parameters(), lr=args.lr)
 policy = RainbowPolicy(
@@ -109,7 +109,7 @@ def test_rainbow(args=get_args()):
 args.v_min,
 args.v_max,
 args.n_step,
-target_update_freq=args.target_update_freq
+target_update_freq=args.target_update_freq,
 ).to(args.device)
 # buffer
 if args.prioritized_replay:
@@ -118,7 +118,7 @@ def test_rainbow(args=get_args()):
 buffer_num=len(train_envs),
 alpha=args.alpha,
 beta=args.beta,
-weight_norm=True
+weight_norm=True,
 )
 else:
 buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
@@ -128,12 +128,12 @@ def test_rainbow(args=get_args()):
 # policy.set_eps(1)
 train_collector.collect(n_step=args.batch_size * args.training_num)
 # log
-log_path = os.path.join(args.logdir, args.task, 'rainbow')
+log_path = os.path.join(args.logdir, args.task, "rainbow")
 writer = SummaryWriter(log_path)
 logger = TensorboardLogger(writer, save_interval=args.save_interval)
 def save_best_fn(policy):
-torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
 def stop_fn(mean_rewards):
 return mean_rewards >= args.reward_threshold
@@ -164,21 +164,23 @@ def test_rainbow(args=get_args()):
 def save_checkpoint_fn(epoch, env_step, gradient_step):
 # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+ckpt_path = os.path.join(log_path, "checkpoint.pth")
+# Example: saving by epoch num
+# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
 torch.save(
 {
-'model': policy.state_dict(),
-'optim': optim.state_dict(),
-}, os.path.join(log_path, 'checkpoint.pth')
-)
-pickle.dump(
-train_collector.buffer,
-open(os.path.join(log_path, 'train_buffer.pkl'), "wb")
+"model": policy.state_dict(),
+"optim": optim.state_dict(),
+}, ckpt_path
 )
+buffer_path = os.path.join(log_path, "train_buffer.pkl")
+pickle.dump(train_collector.buffer, open(buffer_path, "wb"))
+return ckpt_path
 if args.resume:
 # load from existing checkpoint
 print(f"Loading agent under {log_path}")
-ckpt_path = os.path.join(log_path, 'checkpoint.pth')
+ckpt_path = os.path.join(log_path, "checkpoint.pth")
 if os.path.exists(ckpt_path):
 checkpoint = torch.load(ckpt_path, map_location=args.device)
 policy.load_state_dict(checkpoint['model'])
@@ -186,7 +188,7 @@ def test_rainbow(args=get_args()):
 print("Successfully restore policy and optim.")
 else:
 print("Fail to restore policy and optim.")
-buffer_path = os.path.join(log_path, 'train_buffer.pkl')
+buffer_path = os.path.join(log_path, "train_buffer.pkl")
 if os.path.exists(buffer_path):
 train_collector.buffer = pickle.load(open(buffer_path, "rb"))
 print("Successfully restore buffer.")
@@ -210,11 +212,11 @@ def test_rainbow(args=get_args()):
 save_best_fn=save_best_fn,
 logger=logger,
 resume_from_log=args.resume,
-save_checkpoint_fn=save_checkpoint_fn
+save_checkpoint_fn=save_checkpoint_fn,
 )
-assert stop_fn(result['best_reward'])
+assert stop_fn(result["best_reward"])
-if __name__ == '__main__':
+if __name__ == "__main__":
 pprint.pprint(result)
 # Let's watch its performance!
 env = gym.make(args.task)
@@ -238,5 +240,5 @@ def test_prainbow(args=get_args()):
 test_rainbow(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
 test_rainbow(get_args())

View File

@@ -25,7 +25,7 @@ else: # pytest
 def get_args():
 parser = argparse.ArgumentParser()
 parser.add_argument("--task", type=str, default="CartPole-v0")
-parser.add_argument('--reward-threshold', type=float, default=None)
+parser.add_argument("--reward-threshold", type=float, default=None)
 parser.add_argument("--seed", type=int, default=1626)
 parser.add_argument("--eps-test", type=float, default=0.001)
 parser.add_argument("--lr", type=float, default=3e-4)
@@ -37,7 +37,7 @@ def get_args():
 parser.add_argument("--epoch", type=int, default=5)
 parser.add_argument("--update-per-epoch", type=int, default=2000)
 parser.add_argument("--batch-size", type=int, default=64)
-parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
+parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
 parser.add_argument("--test-num", type=int, default=100)
 parser.add_argument("--logdir", type=str, default="log")
 parser.add_argument("--render", type=float, default=0.)
@@ -104,33 +104,37 @@ def test_discrete_bcq(args=get_args()):
 # collector
 test_collector = Collector(policy, test_envs, exploration_noise=True)
-log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
+log_path = os.path.join(args.logdir, args.task, "discrete_bcq")
 writer = SummaryWriter(log_path)
 logger = TensorboardLogger(writer, save_interval=args.save_interval)
 def save_best_fn(policy):
-torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
 def stop_fn(mean_rewards):
 return mean_rewards >= args.reward_threshold
 def save_checkpoint_fn(epoch, env_step, gradient_step):
 # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+ckpt_path = os.path.join(log_path, "checkpoint.pth")
+# Example: saving by epoch num
+# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
 torch.save(
 {
-'model': policy.state_dict(),
-'optim': optim.state_dict(),
-}, os.path.join(log_path, 'checkpoint.pth')
+"model": policy.state_dict(),
+"optim": optim.state_dict(),
+}, ckpt_path
 )
+return ckpt_path
 if args.resume:
 # load from existing checkpoint
 print(f"Loading agent under {log_path}")
-ckpt_path = os.path.join(log_path, 'checkpoint.pth')
+ckpt_path = os.path.join(log_path, "checkpoint.pth")
 if os.path.exists(ckpt_path):
 checkpoint = torch.load(ckpt_path, map_location=args.device)
-policy.load_state_dict(checkpoint['model'])
-optim.load_state_dict(checkpoint['optim'])
+policy.load_state_dict(checkpoint["model"])
+optim.load_state_dict(checkpoint["optim"])
 print("Successfully restore policy and optim.")
 else:
 print("Fail to restore policy and optim.")
@@ -147,11 +151,11 @@ def test_discrete_bcq(args=get_args()):
 save_best_fn=save_best_fn,
 logger=logger,
 resume_from_log=args.resume,
-save_checkpoint_fn=save_checkpoint_fn
+save_checkpoint_fn=save_checkpoint_fn,
 )
-assert stop_fn(result['best_reward'])
+assert stop_fn(result["best_reward"])
-if __name__ == '__main__':
+if __name__ == "__main__":
 pprint.pprint(result)
 # Let's watch its performance!
 env = gym.make(args.task)

View File

@@ -163,33 +163,37 @@ def test_gail(args=get_args()):
 )
 test_collector = Collector(policy, test_envs)
 # log
-log_path = os.path.join(args.logdir, args.task, 'gail')
+log_path = os.path.join(args.logdir, args.task, "gail")
 writer = SummaryWriter(log_path)
 logger = TensorboardLogger(writer, save_interval=args.save_interval)
 def save_best_fn(policy):
-torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
 def stop_fn(mean_rewards):
 return mean_rewards >= args.reward_threshold
 def save_checkpoint_fn(epoch, env_step, gradient_step):
 # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+ckpt_path = os.path.join(log_path, "checkpoint.pth")
+# Example: saving by epoch num
+# ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
 torch.save(
 {
-'model': policy.state_dict(),
-'optim': optim.state_dict(),
-}, os.path.join(log_path, 'checkpoint.pth')
+"model": policy.state_dict(),
+"optim": optim.state_dict(),
+}, ckpt_path
 )
+return ckpt_path
 if args.resume:
 # load from existing checkpoint
 print(f"Loading agent under {log_path}")
-ckpt_path = os.path.join(log_path, 'checkpoint.pth')
+ckpt_path = os.path.join(log_path, "checkpoint.pth")
 if os.path.exists(ckpt_path):
 checkpoint = torch.load(ckpt_path, map_location=args.device)
-policy.load_state_dict(checkpoint['model'])
-optim.load_state_dict(checkpoint['optim'])
+policy.load_state_dict(checkpoint["model"])
+optim.load_state_dict(checkpoint["optim"])
 print("Successfully restore policy and optim.")
 else:
 print("Fail to restore policy and optim.")
@@ -211,9 +215,9 @@ def test_gail(args=get_args()):
 resume_from_log=args.resume,
 save_checkpoint_fn=save_checkpoint_fn,
 )
-assert stop_fn(result['best_reward'])
+assert stop_fn(result["best_reward"])
-if __name__ == '__main__':
+if __name__ == "__main__":
 pprint.pprint(result)
 # Let's watch its performance!
 env = gym.make(args.task)
@@ -224,5 +228,5 @@ def test_gail(args=get_args()):
 print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
-if __name__ == '__main__':
+if __name__ == "__main__":
 test_gail()

View File

@@ -58,9 +58,9 @@ class BaseTrainer(ABC):
 :param function save_best_fn: a hook called when the undiscounted average mean
 reward in evaluation phase gets better, with the signature
 ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
-:param function save_checkpoint_fn: a function to save training process, with
-the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
-you can save whatever you want.
+:param function save_checkpoint_fn: a function to save training process and
+return the saved checkpoint path, with the signature ``f(epoch: int,
+env_step: int, gradient_step: int) -> str``; you can save whatever you want.
 :param bool resume_from_log: resume env_step/gradient_step and other metadata
 from existing tensorboard log. Default to False.
 :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
@@ -147,7 +147,7 @@ class BaseTrainer(ABC):
 test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
 stop_fn: Optional[Callable[[float], bool]] = None,
 save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
-save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
 resume_from_log: bool = False,
 reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
 logger: BaseLogger = LazyLogger(),
@@ -259,7 +259,7 @@ class BaseTrainer(ABC):
 if self.iter_num > 1:
 # iterator exhaustion check
-if self.epoch >= self.max_epoch:
+if self.epoch > self.max_epoch:
 raise StopIteration
 # exit flag 1, when stop_fn succeeds in train_step or test_step
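The one-character change above (">=" to ">") is the off-by-one fix from the commit message: the epoch counter appears to be advanced before the exhaustion check runs, so stopping at epoch >= max_epoch would skip the final epoch. A toy iterator (not Tianshou's actual trainer) showing that boundary:

class EpochIterator:
    """Toy counter mirroring the exhaustion check touched above.

    Assumes the epoch counter is incremented before the check, which is
    exactly what makes ">=" stop one epoch early.
    """

    def __init__(self, max_epoch: int) -> None:
        self.max_epoch = max_epoch
        self.epoch = 0

    def __iter__(self) -> "EpochIterator":
        return self

    def __next__(self) -> int:
        self.epoch += 1
        if self.epoch > self.max_epoch:  # with ">=" the last epoch would never run
            raise StopIteration
        return self.epoch


assert list(EpochIterator(3)) == [1, 2, 3]  # ">=" would yield only [1, 2]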

View File

@@ -31,10 +31,10 @@ class OfflineTrainer(BaseTrainer):
 :param function save_best_fn: a hook called when the undiscounted average mean
 reward in evaluation phase gets better, with the signature
 ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
-:param function save_checkpoint_fn: a function to save training process,
-with the signature ``f(epoch: int, env_step: int, gradient_step: int) ->
-None``; you can save whatever you want. Because offline-RL doesn't have
-env_step, the env_step is always 0 here.
+:param function save_checkpoint_fn: a function to save training process and
+return the saved checkpoint path, with the signature ``f(epoch: int,
+env_step: int, gradient_step: int) -> str``; you can save whatever you want.
+Because offline-RL doesn't have env_step, the env_step is always 0 here.
 :param bool resume_from_log: resume gradient_step and other metadata from
 existing tensorboard log. Default to False.
 :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
@@ -67,7 +67,7 @@ class OfflineTrainer(BaseTrainer):
 test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
 stop_fn: Optional[Callable[[float], bool]] = None,
 save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
-save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
 resume_from_log: bool = False,
 reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
 logger: BaseLogger = LazyLogger(),

View File

@@ -40,9 +40,9 @@ class OffpolicyTrainer(BaseTrainer):
 :param function save_best_fn: a hook called when the undiscounted average mean
 reward in evaluation phase gets better, with the signature
 ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
-:param function save_checkpoint_fn: a function to save training process, with
-the signature ``f(epoch: int, env_step: int, gradient_step: int) -> None``;
-you can save whatever you want.
+:param function save_checkpoint_fn: a function to save training process and
+return the saved checkpoint path, with the signature ``f(epoch: int,
+env_step: int, gradient_step: int) -> str``; you can save whatever you want.
 :param bool resume_from_log: resume env_step/gradient_step and other metadata
 from existing tensorboard log. Default to False.
 :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
@@ -80,7 +80,7 @@ class OffpolicyTrainer(BaseTrainer):
 test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
 stop_fn: Optional[Callable[[float], bool]] = None,
 save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
-save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
 resume_from_log: bool = False,
 reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
 logger: BaseLogger = LazyLogger(),

View File

@@ -42,9 +42,9 @@ class OnpolicyTrainer(BaseTrainer):
 :param function save_best_fn: a hook called when the undiscounted average mean
 reward in evaluation phase gets better, with the signature
 ``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
-:param function save_checkpoint_fn: a function to save training process,
-with the signature ``f(epoch: int, env_step: int, gradient_step: int)
--> None``; you can save whatever you want.
+:param function save_checkpoint_fn: a function to save training process and
+return the saved checkpoint path, with the signature ``f(epoch: int,
+env_step: int, gradient_step: int) -> str``; you can save whatever you want.
 :param bool resume_from_log: resume env_step/gradient_step and other metadata
 from existing tensorboard log. Default to False.
 :param function stop_fn: a function with signature ``f(mean_rewards: float) ->
@@ -88,7 +88,7 @@ class OnpolicyTrainer(BaseTrainer):
 test_fn: Optional[Callable[[int, Optional[int]], None]] = None,
 stop_fn: Optional[Callable[[float], bool]] = None,
 save_best_fn: Optional[Callable[[BasePolicy], None]] = None,
-save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
 resume_from_log: bool = False,
 reward_metric: Optional[Callable[[np.ndarray], np.ndarray]] = None,
 logger: BaseLogger = LazyLogger(),

View File

@@ -94,7 +94,7 @@ class BaseLogger(ABC):
 epoch: int,
 env_step: int,
 gradient_step: int,
-save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
 ) -> None:
 """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.

View File

@@ -47,7 +47,7 @@ class TensorboardLogger(BaseLogger):
 epoch: int,
 env_step: int,
 gradient_step: int,
-save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
 ) -> None:
 if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
 self.last_save_step = epoch

View File

@@ -97,7 +97,7 @@ class WandbLogger(BaseLogger):
 epoch: int,
 env_step: int,
 gradient_step: int,
-save_checkpoint_fn: Optional[Callable[[int, int, int], None]] = None,
+save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
 ) -> None:
 """Use writer to log metadata when calling ``save_checkpoint_fn`` in trainer.
@@ -118,7 +118,7 @@ class WandbLogger(BaseLogger):
 "save/epoch": epoch,
 "save/env_step": env_step,
 "save/gradient_step": gradient_step,
-"checkpoint_path": str(checkpoint_path)
+"checkpoint_path": str(checkpoint_path),
 }
 )
 checkpoint_artifact.add_file(str(checkpoint_path))
@@ -126,7 +126,7 @@ class WandbLogger(BaseLogger):
 def restore_data(self) -> Tuple[int, int, int]:
 checkpoint_artifact = self.wandb_run.use_artifact( # type: ignore
-'run_' + self.wandb_run.id + '_checkpoint:latest' # type: ignore
+f"run_{self.wandb_run.id}_checkpoint:latest" # type: ignore
 )
 assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist"