Remove warnings about the use of save_fn across trainers (#408)

This commit is contained in:
Andriy Drozdyuk 2021-08-03 21:56:00 -04:00 committed by GitHub
parent c19876179a
commit 18d2f25eff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 1 additions and 13 deletions

View File

@ -67,7 +67,7 @@ setup(
"pydocstyle", "pydocstyle",
"doc8", "doc8",
], ],
"atari": ["atari_py", "cv2"], "atari": ["atari_py", "opencv-python"],
"mujoco": ["mujoco_py"], "mujoco": ["mujoco_py"],
"pybullet": ["pybullet"], "pybullet": ["pybullet"],
}, },

View File

@ -1,6 +1,5 @@
import time import time
import tqdm import tqdm
import warnings
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from typing import Dict, Union, Callable, Optional from typing import Dict, Union, Callable, Optional
@ -68,9 +67,6 @@ def offline_trainer(
:return: See :func:`~tianshou.trainer.gather_info`. :return: See :func:`~tianshou.trainer.gather_info`.
""" """
if save_fn:
warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.")
start_epoch, gradient_step = 0, 0 start_epoch, gradient_step = 0, 0
if resume_from_log: if resume_from_log:
start_epoch, _, gradient_step = logger.restore_data() start_epoch, _, gradient_step = logger.restore_data()

View File

@ -1,6 +1,5 @@
import time import time
import tqdm import tqdm
import warnings
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from typing import Dict, Union, Callable, Optional from typing import Dict, Union, Callable, Optional
@ -83,9 +82,6 @@ def offpolicy_trainer(
:return: See :func:`~tianshou.trainer.gather_info`. :return: See :func:`~tianshou.trainer.gather_info`.
""" """
if save_fn:
warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.")
start_epoch, env_step, gradient_step = 0, 0, 0 start_epoch, env_step, gradient_step = 0, 0, 0
if resume_from_log: if resume_from_log:
start_epoch, env_step, gradient_step = logger.restore_data() start_epoch, env_step, gradient_step = logger.restore_data()

View File

@ -1,6 +1,5 @@
import time import time
import tqdm import tqdm
import warnings
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from typing import Dict, Union, Callable, Optional from typing import Dict, Union, Callable, Optional
@ -89,9 +88,6 @@ def onpolicy_trainer(
Only either one of step_per_collect and episode_per_collect can be specified. Only either one of step_per_collect and episode_per_collect can be specified.
""" """
if save_fn:
warnings.warn("Please consider using save_checkpoint_fn instead of save_fn.")
start_epoch, env_step, gradient_step = 0, 0, 0 start_epoch, env_step, gradient_step = 0, 0, 0
if resume_from_log: if resume_from_log:
start_epoch, env_step, gradient_step = logger.restore_data() start_epoch, env_step, gradient_step = logger.restore_data()