From 2eaf1f37c216553b1347240e3a3fe01e1f9d0623 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 26 Apr 2024 17:53:27 +0200 Subject: [PATCH] Use the new BaseCollector interface for annotations --- tianshou/data/__init__.py | 3 ++- tianshou/highlevel/agent.py | 3 ++- tianshou/highlevel/world.py | 6 +++--- tianshou/trainer/base.py | 7 +++---- tianshou/trainer/utils.py | 8 ++++---- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index 6230798..d8c5410 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -24,7 +24,7 @@ from tianshou.data.stats import ( SequenceSummaryStats, TimingStats, ) -from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase +from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase, BaseCollector __all__ = [ "Batch", @@ -50,4 +50,5 @@ __all__ = [ "InfoStats", "SequenceSummaryStats", "TimingStats", + "BaseCollector", ] diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index f71a7f9..c131326 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -6,6 +6,7 @@ from typing import Any, Generic, TypeVar, cast import gymnasium from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.data.collector import BaseCollector from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.env import Environments from tianshou.highlevel.module.actor import ( @@ -94,7 +95,7 @@ class AgentFactory(ABC, ToStringMixin): policy: BasePolicy, envs: Environments, reset_collectors: bool = True, - ) -> tuple[Collector, Collector]: + ) -> tuple[BaseCollector, BaseCollector]: """:param policy: :param envs: :param reset_collectors: Whether to reset the collectors before returning them. diff --git a/tianshou/highlevel/world.py b/tianshou/highlevel/world.py index 1a8d648..c32ef9c 100644 --- a/tianshou/highlevel/world.py +++ b/tianshou/highlevel/world.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: - from tianshou.data import Collector + from tianshou.data import BaseCollector from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import TLogger from tianshou.policy import BasePolicy @@ -16,8 +16,8 @@ class World: envs: "Environments" policy: "BasePolicy" - train_collector: "Collector" - test_collector: "Collector" + train_collector: "BaseCollector" + test_collector: "BaseCollector" logger: "TLogger" persist_directory: str restore_directory: str | None diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index 3c8a27e..f657f63 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -10,14 +10,13 @@ import tqdm from tianshou.data import ( AsyncCollector, - Collector, CollectStats, EpochStats, InfoStats, ReplayBuffer, SequenceSummaryStats, ) -from tianshou.data.collector import CollectStatsBase +from tianshou.data.collector import BaseCollector, CollectStatsBase from tianshou.policy import BasePolicy from tianshou.policy.base import TrainingStats from tianshou.trainer.utils import gather_info, test_episode @@ -152,8 +151,8 @@ class BaseTrainer(ABC): policy: BasePolicy, max_epoch: int, batch_size: int | None, - train_collector: Collector | None = None, - test_collector: Collector | None = None, + train_collector: BaseCollector | None = None, + test_collector: BaseCollector | None = None, buffer: ReplayBuffer | None = None, step_per_epoch: int | None = None, repeat_per_collect: int | None = None, diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index b790fdd..0c2bf18 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -5,17 +5,17 @@ from dataclasses import asdict import numpy as np from tianshou.data import ( - Collector, CollectStats, InfoStats, SequenceSummaryStats, TimingStats, ) +from tianshou.data.collector import BaseCollector from tianshou.utils import BaseLogger def test_episode( - collector: Collector, + collector: BaseCollector, test_fn: Callable[[int, int | None], None] | None, epoch: int, n_episode: int, @@ -44,8 +44,8 @@ def gather_info( gradient_step: int, best_reward: float, best_reward_std: float, - train_collector: Collector | None = None, - test_collector: Collector | None = None, + train_collector: BaseCollector | None = None, + test_collector: BaseCollector | None = None, ) -> InfoStats: """A simple wrapper of gathering information from collectors.