added save and load for optimizers

This commit is contained in:
NM512 2023-09-27 09:15:37 +09:00
parent 16635df3e4
commit d3576c5a98
3 changed files with 43 additions and 5 deletions

View File

@ -319,8 +319,10 @@ def main(config):
train_dataset, train_dataset,
).to(config.device) ).to(config.device)
agent.requires_grad_(requires_grad=False) agent.requires_grad_(requires_grad=False)
if (logdir / "latest_model.pt").exists(): if (logdir / "latest.pt").exists():
agent.load_state_dict(torch.load(logdir / "latest_model.pt")) checkpoint = torch.load(logdir / "latest.pt")
agent.load_state_dict(checkpoint["agent_state_dict"])
tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
agent._should_pretrain._once = False agent._should_pretrain._once = False
# make sure eval will be executed once after config.steps # make sure eval will be executed once after config.steps
@ -352,7 +354,11 @@ def main(config):
steps=config.eval_every, steps=config.eval_every,
state=state, state=state,
) )
torch.save(agent.state_dict(), logdir / "latest_model.pt") items_to_save = {
"agent_state_dict": agent.state_dict(),
"optims_state_dict": tools.recursively_collect_optim_state_dict(agent),
}
torch.save(items_to_save, logdir / "latest.pt")
for env in train_envs + eval_envs: for env in train_envs + eval_envs:
try: try:
env.close() env.close()

View File

@ -70,7 +70,7 @@ class Plan2Explore(nn.Module):
[networks.MLP(**kw) for _ in range(config.disag_models)] [networks.MLP(**kw) for _ in range(config.disag_models)]
) )
kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp) kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
self._model_opt = tools.Optimizer( self._expl_opt = tools.Optimizer(
"explorer", "explorer",
self.parameters(), self.parameters(),
config.model_lr, config.model_lr,
@ -129,5 +129,5 @@ class Plan2Explore(nn.Module):
[torch.mean(pred.log_prob(targets))[None] for pred in preds], 0 [torch.mean(pred.log_prob(targets))[None] for pred in preds], 0
) )
loss = -torch.mean(likes) loss = -torch.mean(likes)
metrics = self._model_opt(loss, self.parameters()) metrics = self._expl_opt(loss, self.parameters())
return metrics return metrics

View File

@ -970,3 +970,35 @@ def enable_deterministic_run():
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True) torch.use_deterministic_algorithms(True)
def recursively_collect_optim_state_dict(
    obj, path="", optimizers_state_dicts=None, visited=None
):
    """Collect the state dict of every torch optimizer reachable from ``obj``.

    The object graph is walked through instance ``__dict__`` entries and, for
    ``torch.nn.Module`` objects, through their top-level sub-modules as well.
    Every ``torch.optim.Optimizer`` found is recorded under the dotted
    attribute path that leads to it (e.g. ``"a.b._opt"``), which is the same
    path format ``recursively_load_optim_state_dict`` resolves with
    ``getattr``.

    Args:
        obj: Object to inspect; it must expose a ``__dict__``.
        path: Dotted path of ``obj`` relative to the root of the walk; empty
            for the root itself.
        optimizers_state_dicts: Accumulator mapping path -> state dict. A new
            dict is created when ``None``.
        visited: Set of ``id()`` values of objects already walked. A new set
            is created when ``None``; callers normally leave it unset.

    Returns:
        The accumulator dict, mapping each dotted path to the corresponding
        ``optimizer.state_dict()``.
    """
    if optimizers_state_dicts is None:
        optimizers_state_dicts = {}
    if visited is None:
        visited = set()
    # Walk every container only once: this terminates on cyclic references
    # (a -> b -> a) instead of recursing until the interpreter gives up.
    if id(obj) in visited:
        return optimizers_state_dicts
    visited.add(id(obj))

    # Work on a copy. Updating obj.__dict__ in place would inject the
    # sub-modules into the live object's attribute dict as a side effect.
    attrs = dict(obj.__dict__)
    if isinstance(obj, torch.nn.Module):
        # Sub-modules are not direct __dict__ entries, so add them explicitly.
        # `"." not in name` keeps only top-level names from named_modules();
        # the identity check drops obj itself, which is yielded under "".
        attrs.update(
            {
                name: module
                for name, module in obj.named_modules()
                if "." not in name and module is not obj
            }
        )
    for name, attr in attrs.items():
        new_path = path + "." + name if path else name
        if isinstance(attr, torch.optim.Optimizer):
            optimizers_state_dicts[new_path] = attr.state_dict()
        elif hasattr(attr, "__dict__"):
            # The accumulator is shared, so the recursive call fills it
            # directly; no merge of the return value is needed.
            recursively_collect_optim_state_dict(
                attr, new_path, optimizers_state_dicts, visited
            )
    return optimizers_state_dicts
def recursively_load_optim_state_dict(obj, optimizers_state_dicts):
    """Restore optimizer states addressed by dotted attribute paths.

    Args:
        obj: Root object from which every path is resolved.
        optimizers_state_dicts: Mapping of dotted attribute path (e.g.
            ``"a.b._opt"``) to the state dict to load at that path.

    Raises:
        AttributeError: If a path segment does not exist on the object
            reached so far (raised by ``getattr``).
    """
    for path, state_dict in optimizers_state_dicts.items():
        # Follow the path one attribute at a time, starting at the root.
        target = obj
        for key in path.split("."):
            target = getattr(target, key)
        # No debug printing here: the state dicts hold full optimizer state
        # and would flood stdout on every checkpoint load.
        target.load_state_dict(state_dict)