avoid cyclic reference

This commit is contained in:
NM512 2023-10-02 07:27:26 +09:00
parent f35480f2a6
commit 1002d8b115

View File

@ -972,9 +972,16 @@ def enable_deterministic_run():
torch.use_deterministic_algorithms(True)
def recursively_collect_optim_state_dict(obj, path="", optimizers_state_dicts=None):
def recursively_collect_optim_state_dict(
obj, path="", optimizers_state_dicts=None, visited=set()
):
if optimizers_state_dicts is None:
optimizers_state_dicts = {}
# avoid cyclic reference
if id(obj) in visited:
return optimizers_state_dicts
else:
visited.add(id(obj))
attrs = obj.__dict__
if isinstance(obj, torch.nn.Module):
attrs.update(
@ -987,7 +994,7 @@ def recursively_collect_optim_state_dict(obj, path="", optimizers_state_dicts=No
elif hasattr(attr, "__dict__"):
optimizers_state_dicts.update(
recursively_collect_optim_state_dict(
attr, new_path, optimizers_state_dicts
attr, new_path, optimizers_state_dicts, visited
)
)
return optimizers_state_dicts