avoid cyclic reference
This commit is contained in:
parent
f35480f2a6
commit
1002d8b115
11
tools.py
11
tools.py
@ -972,9 +972,16 @@ def enable_deterministic_run():
|
|||||||
torch.use_deterministic_algorithms(True)
|
torch.use_deterministic_algorithms(True)
|
||||||
|
|
||||||
|
|
||||||
def recursively_collect_optim_state_dict(obj, path="", optimizers_state_dicts=None):
|
def recursively_collect_optim_state_dict(
|
||||||
|
obj, path="", optimizers_state_dicts=None, visited=set()
|
||||||
|
):
|
||||||
if optimizers_state_dicts is None:
|
if optimizers_state_dicts is None:
|
||||||
optimizers_state_dicts = {}
|
optimizers_state_dicts = {}
|
||||||
|
# avoid cyclic reference
|
||||||
|
if id(obj) in visited:
|
||||||
|
return optimizers_state_dicts
|
||||||
|
else:
|
||||||
|
visited.add(id(obj))
|
||||||
attrs = obj.__dict__
|
attrs = obj.__dict__
|
||||||
if isinstance(obj, torch.nn.Module):
|
if isinstance(obj, torch.nn.Module):
|
||||||
attrs.update(
|
attrs.update(
|
||||||
@ -987,7 +994,7 @@ def recursively_collect_optim_state_dict(obj, path="", optimizers_state_dicts=No
|
|||||||
elif hasattr(attr, "__dict__"):
|
elif hasattr(attr, "__dict__"):
|
||||||
optimizers_state_dicts.update(
|
optimizers_state_dicts.update(
|
||||||
recursively_collect_optim_state_dict(
|
recursively_collect_optim_state_dict(
|
||||||
attr, new_path, optimizers_state_dicts
|
attr, new_path, optimizers_state_dicts, visited
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return optimizers_state_dicts
|
return optimizers_state_dicts
|
||||||
|
Loading…
x
Reference in New Issue
Block a user