added save and load for optimizers
This commit is contained in:
parent
16635df3e4
commit
d3576c5a98
12
dreamer.py
12
dreamer.py
@ -319,8 +319,10 @@ def main(config):
|
|||||||
train_dataset,
|
train_dataset,
|
||||||
).to(config.device)
|
).to(config.device)
|
||||||
agent.requires_grad_(requires_grad=False)
|
agent.requires_grad_(requires_grad=False)
|
||||||
if (logdir / "latest_model.pt").exists():
|
if (logdir / "latest.pt").exists():
|
||||||
agent.load_state_dict(torch.load(logdir / "latest_model.pt"))
|
checkpoint = torch.load(logdir / "latest.pt")
|
||||||
|
agent.load_state_dict(checkpoint["agent_state_dict"])
|
||||||
|
tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
|
||||||
agent._should_pretrain._once = False
|
agent._should_pretrain._once = False
|
||||||
|
|
||||||
# make sure eval will be executed once after config.steps
|
# make sure eval will be executed once after config.steps
|
||||||
@ -352,7 +354,11 @@ def main(config):
|
|||||||
steps=config.eval_every,
|
steps=config.eval_every,
|
||||||
state=state,
|
state=state,
|
||||||
)
|
)
|
||||||
torch.save(agent.state_dict(), logdir / "latest_model.pt")
|
items_to_save = {
|
||||||
|
"agent_state_dict": agent.state_dict(),
|
||||||
|
"optims_state_dict": tools.recursively_collect_optim_state_dict(agent),
|
||||||
|
}
|
||||||
|
torch.save(items_to_save, logdir / "latest.pt")
|
||||||
for env in train_envs + eval_envs:
|
for env in train_envs + eval_envs:
|
||||||
try:
|
try:
|
||||||
env.close()
|
env.close()
|
||||||
|
@ -70,7 +70,7 @@ class Plan2Explore(nn.Module):
|
|||||||
[networks.MLP(**kw) for _ in range(config.disag_models)]
|
[networks.MLP(**kw) for _ in range(config.disag_models)]
|
||||||
)
|
)
|
||||||
kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
|
kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
|
||||||
self._model_opt = tools.Optimizer(
|
self._expl_opt = tools.Optimizer(
|
||||||
"explorer",
|
"explorer",
|
||||||
self.parameters(),
|
self.parameters(),
|
||||||
config.model_lr,
|
config.model_lr,
|
||||||
@ -129,5 +129,5 @@ class Plan2Explore(nn.Module):
|
|||||||
[torch.mean(pred.log_prob(targets))[None] for pred in preds], 0
|
[torch.mean(pred.log_prob(targets))[None] for pred in preds], 0
|
||||||
)
|
)
|
||||||
loss = -torch.mean(likes)
|
loss = -torch.mean(likes)
|
||||||
metrics = self._model_opt(loss, self.parameters())
|
metrics = self._expl_opt(loss, self.parameters())
|
||||||
return metrics
|
return metrics
|
||||||
|
32
tools.py
32
tools.py
@ -970,3 +970,35 @@ def enable_deterministic_run():
|
|||||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
torch.use_deterministic_algorithms(True)
|
torch.use_deterministic_algorithms(True)
|
||||||
|
|
||||||
|
|
||||||
|
def recursively_collect_optim_state_dict(obj, path="", optimizers_state_dicts=None):
|
||||||
|
if optimizers_state_dicts is None:
|
||||||
|
optimizers_state_dicts = {}
|
||||||
|
attrs = obj.__dict__
|
||||||
|
if isinstance(obj, torch.nn.Module):
|
||||||
|
attrs.update(
|
||||||
|
{k: attr for k, attr in obj.named_modules() if "." not in k and obj != attr}
|
||||||
|
)
|
||||||
|
for name, attr in attrs.items():
|
||||||
|
new_path = path + "." + name if path else name
|
||||||
|
if isinstance(attr, torch.optim.Optimizer):
|
||||||
|
optimizers_state_dicts[new_path] = attr.state_dict()
|
||||||
|
elif hasattr(attr, "__dict__"):
|
||||||
|
optimizers_state_dicts.update(
|
||||||
|
recursively_collect_optim_state_dict(
|
||||||
|
attr, new_path, optimizers_state_dicts
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return optimizers_state_dicts
|
||||||
|
|
||||||
|
|
||||||
|
def recursively_load_optim_state_dict(obj, optimizers_state_dicts):
|
||||||
|
print(optimizers_state_dicts)
|
||||||
|
for path, state_dict in optimizers_state_dicts.items():
|
||||||
|
keys = path.split(".")
|
||||||
|
obj_now = obj
|
||||||
|
for key in keys:
|
||||||
|
obj_now = getattr(obj_now, key)
|
||||||
|
print(keys)
|
||||||
|
obj_now.load_state_dict(state_dict)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user