Use the new BaseCollector interface for annotations

This commit is contained in:
Michael Panchenko 2024-04-26 17:53:27 +02:00
parent 07a97c7d93
commit 2eaf1f37c2
5 changed files with 14 additions and 13 deletions

View File

@ -24,7 +24,7 @@ from tianshou.data.stats import (
SequenceSummaryStats, SequenceSummaryStats,
TimingStats, TimingStats,
) )
from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase, BaseCollector
__all__ = [ __all__ = [
"Batch", "Batch",
@ -50,4 +50,5 @@ __all__ = [
"InfoStats", "InfoStats",
"SequenceSummaryStats", "SequenceSummaryStats",
"TimingStats", "TimingStats",
"BaseCollector",
] ]

View File

@ -6,6 +6,7 @@ from typing import Any, Generic, TypeVar, cast
import gymnasium import gymnasium
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.data.collector import BaseCollector
from tianshou.highlevel.config import SamplingConfig from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import Environments from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import ( from tianshou.highlevel.module.actor import (
@ -94,7 +95,7 @@ class AgentFactory(ABC, ToStringMixin):
policy: BasePolicy, policy: BasePolicy,
envs: Environments, envs: Environments,
reset_collectors: bool = True, reset_collectors: bool = True,
) -> tuple[Collector, Collector]: ) -> tuple[BaseCollector, BaseCollector]:
""":param policy: """:param policy:
:param envs: :param envs:
:param reset_collectors: Whether to reset the collectors before returning them. :param reset_collectors: Whether to reset the collectors before returning them.

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from tianshou.data import Collector from tianshou.data import BaseCollector
from tianshou.highlevel.env import Environments from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import TLogger from tianshou.highlevel.logger import TLogger
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
@ -16,8 +16,8 @@ class World:
envs: "Environments" envs: "Environments"
policy: "BasePolicy" policy: "BasePolicy"
train_collector: "Collector" train_collector: "BaseCollector"
test_collector: "Collector" test_collector: "BaseCollector"
logger: "TLogger" logger: "TLogger"
persist_directory: str persist_directory: str
restore_directory: str | None restore_directory: str | None

View File

@ -10,14 +10,13 @@ import tqdm
from tianshou.data import ( from tianshou.data import (
AsyncCollector, AsyncCollector,
Collector,
CollectStats, CollectStats,
EpochStats, EpochStats,
InfoStats, InfoStats,
ReplayBuffer, ReplayBuffer,
SequenceSummaryStats, SequenceSummaryStats,
) )
from tianshou.data.collector import CollectStatsBase from tianshou.data.collector import BaseCollector, CollectStatsBase
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.policy.base import TrainingStats from tianshou.policy.base import TrainingStats
from tianshou.trainer.utils import gather_info, test_episode from tianshou.trainer.utils import gather_info, test_episode
@ -152,8 +151,8 @@ class BaseTrainer(ABC):
policy: BasePolicy, policy: BasePolicy,
max_epoch: int, max_epoch: int,
batch_size: int | None, batch_size: int | None,
train_collector: Collector | None = None, train_collector: BaseCollector | None = None,
test_collector: Collector | None = None, test_collector: BaseCollector | None = None,
buffer: ReplayBuffer | None = None, buffer: ReplayBuffer | None = None,
step_per_epoch: int | None = None, step_per_epoch: int | None = None,
repeat_per_collect: int | None = None, repeat_per_collect: int | None = None,

View File

@ -5,17 +5,17 @@ from dataclasses import asdict
import numpy as np import numpy as np
from tianshou.data import ( from tianshou.data import (
Collector,
CollectStats, CollectStats,
InfoStats, InfoStats,
SequenceSummaryStats, SequenceSummaryStats,
TimingStats, TimingStats,
) )
from tianshou.data.collector import BaseCollector
from tianshou.utils import BaseLogger from tianshou.utils import BaseLogger
def test_episode( def test_episode(
collector: Collector, collector: BaseCollector,
test_fn: Callable[[int, int | None], None] | None, test_fn: Callable[[int, int | None], None] | None,
epoch: int, epoch: int,
n_episode: int, n_episode: int,
@ -44,8 +44,8 @@ def gather_info(
gradient_step: int, gradient_step: int,
best_reward: float, best_reward: float,
best_reward_std: float, best_reward_std: float,
train_collector: Collector | None = None, train_collector: BaseCollector | None = None,
test_collector: Collector | None = None, test_collector: BaseCollector | None = None,
) -> InfoStats: ) -> InfoStats:
"""A simple wrapper of gathering information from collectors. """A simple wrapper of gathering information from collectors.