avoid cyclic reference
This commit is contained in:
parent
f35480f2a6
commit
1002d8b115
11
tools.py
11
tools.py
@ -972,9 +972,16 @@ def enable_deterministic_run():
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
|
||||
def recursively_collect_optim_state_dict(obj, path="", optimizers_state_dicts=None):
|
||||
def recursively_collect_optim_state_dict(
|
||||
obj, path="", optimizers_state_dicts=None, visited=set()
|
||||
):
|
||||
if optimizers_state_dicts is None:
|
||||
optimizers_state_dicts = {}
|
||||
# avoid cyclic reference
|
||||
if id(obj) in visited:
|
||||
return optimizers_state_dicts
|
||||
else:
|
||||
visited.add(id(obj))
|
||||
attrs = obj.__dict__
|
||||
if isinstance(obj, torch.nn.Module):
|
||||
attrs.update(
|
||||
@ -987,7 +994,7 @@ def recursively_collect_optim_state_dict(obj, path="", optimizers_state_dicts=No
|
||||
elif hasattr(attr, "__dict__"):
|
||||
optimizers_state_dicts.update(
|
||||
recursively_collect_optim_state_dict(
|
||||
attr, new_path, optimizers_state_dicts
|
||||
attr, new_path, optimizers_state_dicts, visited
|
||||
)
|
||||
)
|
||||
return optimizers_state_dicts
|
||||
|
Loading…
x
Reference in New Issue
Block a user