Deleted long deprecated functionality, removed unused warning module
There's better ways to deal with deprecations that we shall use in the future
This commit is contained in:
parent
49c750fb09
commit
829fd9c7a5
20
tianshou/env/worker/base.py
vendored
20
tianshou/env/worker/base.py
vendored
@ -6,7 +6,6 @@ import gymnasium as gym
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tianshou.env.utils import gym_new_venv_step_type
|
from tianshou.env.utils import gym_new_venv_step_type
|
||||||
from tianshou.utils import deprecation
|
|
||||||
|
|
||||||
|
|
||||||
class EnvWorker(ABC):
|
class EnvWorker(ABC):
|
||||||
@ -27,6 +26,7 @@ class EnvWorker(ABC):
|
|||||||
def set_env_attr(self, key: str, value: Any) -> None:
|
def set_env_attr(self, key: str, value: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def send(self, action: np.ndarray | None) -> None:
|
def send(self, action: np.ndarray | None) -> None:
|
||||||
"""Send action signal to low-level worker.
|
"""Send action signal to low-level worker.
|
||||||
|
|
||||||
@ -34,17 +34,6 @@ class EnvWorker(ABC):
|
|||||||
it indicates "step" signal. The paired return value from "recv"
|
it indicates "step" signal. The paired return value from "recv"
|
||||||
function is determined by such kind of different signal.
|
function is determined by such kind of different signal.
|
||||||
"""
|
"""
|
||||||
if hasattr(self, "send_action"):
|
|
||||||
deprecation(
|
|
||||||
"send_action will soon be deprecated. "
|
|
||||||
"Please use send and recv for your own EnvWorker.",
|
|
||||||
)
|
|
||||||
if action is None:
|
|
||||||
self.is_reset = True
|
|
||||||
self.result = self.reset()
|
|
||||||
else:
|
|
||||||
self.is_reset = False
|
|
||||||
self.send_action(action)
|
|
||||||
|
|
||||||
def recv(self) -> gym_new_venv_step_type | tuple[np.ndarray, dict]:
|
def recv(self) -> gym_new_venv_step_type | tuple[np.ndarray, dict]:
|
||||||
"""Receive result from low-level worker.
|
"""Receive result from low-level worker.
|
||||||
@ -54,13 +43,6 @@ class EnvWorker(ABC):
|
|||||||
info) or (obs, rew, terminated, truncated, info), based on whether
|
info) or (obs, rew, terminated, truncated, info), based on whether
|
||||||
the environment is using the old step API or the new one.
|
the environment is using the old step API or the new one.
|
||||||
"""
|
"""
|
||||||
if hasattr(self, "get_result"):
|
|
||||||
deprecation(
|
|
||||||
"get_result will soon be deprecated. "
|
|
||||||
"Please use send and recv for your own EnvWorker.",
|
|
||||||
)
|
|
||||||
if not self.is_reset:
|
|
||||||
self.result = self.get_result()
|
|
||||||
return self.result
|
return self.result
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -26,7 +26,6 @@ from tianshou.utils import (
|
|||||||
DummyTqdm,
|
DummyTqdm,
|
||||||
LazyLogger,
|
LazyLogger,
|
||||||
MovAvg,
|
MovAvg,
|
||||||
deprecation,
|
|
||||||
tqdm_config,
|
tqdm_config,
|
||||||
)
|
)
|
||||||
from tianshou.utils.logging import set_numerical_fields_to_precision
|
from tianshou.utils.logging import set_numerical_fields_to_precision
|
||||||
@ -76,7 +75,7 @@ class BaseTrainer(ABC):
|
|||||||
signature ``f(num_epoch: int, step_idx: int) -> None``.
|
signature ``f(num_epoch: int, step_idx: int) -> None``.
|
||||||
:param save_best_fn: a hook called when the undiscounted average mean
|
:param save_best_fn: a hook called when the undiscounted average mean
|
||||||
reward in evaluation phase gets better, with the signature
|
reward in evaluation phase gets better, with the signature
|
||||||
``f(policy: BasePolicy) -> None``. It was ``save_fn`` previously.
|
``f(policy: BasePolicy) -> None``.
|
||||||
:param save_checkpoint_fn: a function to save training process and
|
:param save_checkpoint_fn: a function to save training process and
|
||||||
return the saved checkpoint path, with the signature ``f(epoch: int,
|
return the saved checkpoint path, with the signature ``f(epoch: int,
|
||||||
env_step: int, gradient_step: int) -> str``; you can save whatever you want.
|
env_step: int, gradient_step: int) -> str``; you can save whatever you want.
|
||||||
@ -173,16 +172,7 @@ class BaseTrainer(ABC):
|
|||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
show_progress: bool = True,
|
show_progress: bool = True,
|
||||||
test_in_train: bool = True,
|
test_in_train: bool = True,
|
||||||
save_fn: Callable[[BasePolicy], None] | None = None,
|
|
||||||
):
|
):
|
||||||
if save_fn:
|
|
||||||
deprecation(
|
|
||||||
"save_fn in trainer is marked as deprecated and will be "
|
|
||||||
"removed in the future. Please use save_best_fn instead.",
|
|
||||||
)
|
|
||||||
assert save_best_fn is None
|
|
||||||
save_best_fn = save_fn
|
|
||||||
|
|
||||||
self.policy = policy
|
self.policy = policy
|
||||||
|
|
||||||
if buffer is not None:
|
if buffer is not None:
|
||||||
|
@ -6,7 +6,6 @@ from tianshou.utils.logger.wandb import WandbLogger
|
|||||||
from tianshou.utils.lr_scheduler import MultipleLRSchedulers
|
from tianshou.utils.lr_scheduler import MultipleLRSchedulers
|
||||||
from tianshou.utils.progress_bar import DummyTqdm, tqdm_config
|
from tianshou.utils.progress_bar import DummyTqdm, tqdm_config
|
||||||
from tianshou.utils.statistics import MovAvg, RunningMeanStd
|
from tianshou.utils.statistics import MovAvg, RunningMeanStd
|
||||||
from tianshou.utils.warning import deprecation
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MovAvg",
|
"MovAvg",
|
||||||
@ -17,6 +16,5 @@ __all__ = [
|
|||||||
"TensorboardLogger",
|
"TensorboardLogger",
|
||||||
"LazyLogger",
|
"LazyLogger",
|
||||||
"WandbLogger",
|
"WandbLogger",
|
||||||
"deprecation",
|
|
||||||
"MultipleLRSchedulers",
|
"MultipleLRSchedulers",
|
||||||
]
|
]
|
||||||
|
@ -1,8 +0,0 @@
|
|||||||
import warnings
|
|
||||||
|
|
||||||
warnings.simplefilter("once", DeprecationWarning)
|
|
||||||
|
|
||||||
|
|
||||||
def deprecation(msg: str) -> None:
|
|
||||||
"""Deprecation warning wrapper."""
|
|
||||||
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
|
|
Loading…
x
Reference in New Issue
Block a user