From d3576c5a98793a60447af6963ccf251adc216ad7 Mon Sep 17 00:00:00 2001
From: NM512
Date: Wed, 27 Sep 2023 09:15:37 +0900
Subject: [PATCH] added save and load for optimizers

---
Review notes (text between the "---" marker and the first "diff --git" line
is ignored by git-am):

- This patch had been collapsed onto two lines. Line breaks and indentation
  were rebuilt from the hunk headers; the leading whitespace of context
  lines is a best guess and must be checked against the target tree before
  applying.
- tools.py: recursively_load_optim_state_dict no longer prints the whole
  optimizer state (leftover debug output).
- tools.py: recursively_collect_optim_state_dict no longer writes child
  modules into the live obj.__dict__, uses named_children() for direct
  children, and visits each object only once so shared objects and
  reference cycles cannot cause repeated or endless traversal.
- The tools.py hunk length and the diffstat were updated to match; the
  post-image blob id on its "index" line is left over from the original.
- NOTE(review): torch.load() in dreamer.py is called without map_location;
  confirm that resuming on a different device is not required.

 dreamer.py     | 12 +++++++++---
 exploration.py |  4 ++--
 tools.py       | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 67 insertions(+), 5 deletions(-)

diff --git a/dreamer.py b/dreamer.py
index 2b61b90..6855c8c 100644
--- a/dreamer.py
+++ b/dreamer.py
@@ -319,8 +319,10 @@ def main(config):
         train_dataset,
     ).to(config.device)
     agent.requires_grad_(requires_grad=False)
-    if (logdir / "latest_model.pt").exists():
-        agent.load_state_dict(torch.load(logdir / "latest_model.pt"))
+    if (logdir / "latest.pt").exists():
+        checkpoint = torch.load(logdir / "latest.pt")
+        agent.load_state_dict(checkpoint["agent_state_dict"])
+        tools.recursively_load_optim_state_dict(agent, checkpoint["optims_state_dict"])
         agent._should_pretrain._once = False
 
     # make sure eval will be executed once after config.steps
@@ -352,7 +354,11 @@ def main(config):
             steps=config.eval_every,
             state=state,
         )
-        torch.save(agent.state_dict(), logdir / "latest_model.pt")
+        items_to_save = {
+            "agent_state_dict": agent.state_dict(),
+            "optims_state_dict": tools.recursively_collect_optim_state_dict(agent),
+        }
+        torch.save(items_to_save, logdir / "latest.pt")
     for env in train_envs + eval_envs:
         try:
             env.close()
diff --git a/exploration.py b/exploration.py
index bb2e60b..135366c 100644
--- a/exploration.py
+++ b/exploration.py
@@ -70,7 +70,7 @@ class Plan2Explore(nn.Module):
             [networks.MLP(**kw) for _ in range(config.disag_models)]
         )
         kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
-        self._model_opt = tools.Optimizer(
+        self._expl_opt = tools.Optimizer(
             "explorer",
             self.parameters(),
             config.model_lr,
@@ -129,5 +129,5 @@ class Plan2Explore(nn.Module):
                 [torch.mean(pred.log_prob(targets))[None] for pred in preds], 0
             )
             loss = -torch.mean(likes)
-        metrics = self._model_opt(loss, self.parameters())
+        metrics = self._expl_opt(loss, self.parameters())
         return metrics
diff --git a/tools.py b/tools.py
index f9edc1d..b36ed40 100644
--- a/tools.py
+++ b/tools.py
@@ -970,3 +970,59 @@ def enable_deterministic_run():
     os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
     torch.backends.cudnn.benchmark = False
     torch.use_deterministic_algorithms(True)
+
+
+def recursively_collect_optim_state_dict(obj, path="", optimizers_state_dicts=None):
+    """Collect the state dict of every torch optimizer reachable from ``obj``.
+
+    Attributes are walked recursively through ``__dict__`` (plus the direct
+    child modules of any ``torch.nn.Module``), and each
+    ``torch.optim.Optimizer`` found is stored under its dotted attribute path
+    so that ``recursively_load_optim_state_dict`` can find it again.
+
+    Args:
+        obj: Root object to search.
+        path: Dotted prefix prepended to every key.
+        optimizers_state_dicts: Optional dict to add the results to.
+
+    Returns:
+        Dict mapping dotted attribute paths to optimizer state dicts.
+    """
+    if optimizers_state_dicts is None:
+        optimizers_state_dicts = {}
+    visited = set()
+
+    def _walk(node, prefix):
+        # Visit each object once: shared sub-objects and reference cycles
+        # would otherwise be traversed repeatedly (or forever).
+        if id(node) in visited:
+            return
+        visited.add(id(node))
+        # Copy so that adding child modules never mutates the live __dict__.
+        attrs = dict(node.__dict__)
+        if isinstance(node, torch.nn.Module):
+            attrs.update(node.named_children())
+        for name, attr in attrs.items():
+            new_path = prefix + "." + name if prefix else name
+            if isinstance(attr, torch.optim.Optimizer):
+                optimizers_state_dicts[new_path] = attr.state_dict()
+            elif hasattr(attr, "__dict__"):
+                _walk(attr, new_path)
+
+    _walk(obj, path)
+    return optimizers_state_dicts
+
+
+def recursively_load_optim_state_dict(obj, optimizers_state_dicts):
+    """Restore optimizer states saved by ``recursively_collect_optim_state_dict``.
+
+    Args:
+        obj: Root object the dotted paths are resolved against.
+        optimizers_state_dicts: Dict mapping dotted attribute paths to
+            optimizer state dicts.
+    """
+    for path, state_dict in optimizers_state_dicts.items():
+        optimizer = obj
+        for key in path.split("."):
+            optimizer = getattr(optimizer, key)
+        optimizer.load_state_dict(state_dict)