diff --git a/tools.py b/tools.py index f80384e..87f2633 100644 --- a/tools.py +++ b/tools.py @@ -963,10 +963,12 @@ def enable_deterministic_run(): def recursively_collect_optim_state_dict( - obj, path="", optimizers_state_dicts=None, visited=set() + obj, path="", optimizers_state_dicts=None, visited=None ): if optimizers_state_dicts is None: optimizers_state_dicts = {} + if visited is None: + visited = set() # avoid cyclic reference if id(obj) in visited: return optimizers_state_dicts