From 1002d8b11567c6e7b5973a460f9f2165945ad426 Mon Sep 17 00:00:00 2001 From: NM512 Date: Mon, 2 Oct 2023 07:27:26 +0900 Subject: [PATCH] avoid cyclic reference --- tools.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tools.py b/tools.py index 9572fe1..8265a57 100644 --- a/tools.py +++ b/tools.py @@ -972,9 +972,16 @@ def enable_deterministic_run(): torch.use_deterministic_algorithms(True) -def recursively_collect_optim_state_dict(obj, path="", optimizers_state_dicts=None): +def recursively_collect_optim_state_dict( + obj, path="", optimizers_state_dicts=None, visited=set() +): if optimizers_state_dicts is None: optimizers_state_dicts = {} + # avoid cyclic reference + if id(obj) in visited: + return optimizers_state_dicts + else: + visited.add(id(obj)) attrs = obj.__dict__ if isinstance(obj, torch.nn.Module): attrs.update( @@ -987,7 +994,7 @@ def recursively_collect_optim_state_dict(obj, path="", optimizers_state_dicts=No elif hasattr(attr, "__dict__"): optimizers_state_dicts.update( recursively_collect_optim_state_dict( - attr, new_path, optimizers_state_dicts + attr, new_path, optimizers_state_dicts, visited ) ) return optimizers_state_dicts