Use the new BaseCollector interface for annotations

This commit is contained in:
Michael Panchenko 2024-04-26 17:53:27 +02:00
parent 07a97c7d93
commit 2eaf1f37c2
5 changed files with 14 additions and 13 deletions

View File

@ -24,7 +24,7 @@ from tianshou.data.stats import (
SequenceSummaryStats,
TimingStats,
)
from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase
from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase, BaseCollector
__all__ = [
"Batch",
@ -50,4 +50,5 @@ __all__ = [
"InfoStats",
"SequenceSummaryStats",
"TimingStats",
"BaseCollector",
]

View File

@ -6,6 +6,7 @@ from typing import Any, Generic, TypeVar, cast
import gymnasium
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.data.collector import BaseCollector
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.actor import (
@ -94,7 +95,7 @@ class AgentFactory(ABC, ToStringMixin):
policy: BasePolicy,
envs: Environments,
reset_collectors: bool = True,
) -> tuple[Collector, Collector]:
) -> tuple[BaseCollector, BaseCollector]:
""":param policy:
:param envs:
:param reset_collectors: Whether to reset the collectors before returning them.

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tianshou.data import Collector
from tianshou.data import BaseCollector
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import TLogger
from tianshou.policy import BasePolicy
@ -16,8 +16,8 @@ class World:
envs: "Environments"
policy: "BasePolicy"
train_collector: "Collector"
test_collector: "Collector"
train_collector: "BaseCollector"
test_collector: "BaseCollector"
logger: "TLogger"
persist_directory: str
restore_directory: str | None

View File

@ -10,14 +10,13 @@ import tqdm
from tianshou.data import (
AsyncCollector,
Collector,
CollectStats,
EpochStats,
InfoStats,
ReplayBuffer,
SequenceSummaryStats,
)
from tianshou.data.collector import CollectStatsBase
from tianshou.data.collector import BaseCollector, CollectStatsBase
from tianshou.policy import BasePolicy
from tianshou.policy.base import TrainingStats
from tianshou.trainer.utils import gather_info, test_episode
@ -152,8 +151,8 @@ class BaseTrainer(ABC):
policy: BasePolicy,
max_epoch: int,
batch_size: int | None,
train_collector: Collector | None = None,
test_collector: Collector | None = None,
train_collector: BaseCollector | None = None,
test_collector: BaseCollector | None = None,
buffer: ReplayBuffer | None = None,
step_per_epoch: int | None = None,
repeat_per_collect: int | None = None,

View File

@ -5,17 +5,17 @@ from dataclasses import asdict
import numpy as np
from tianshou.data import (
Collector,
CollectStats,
InfoStats,
SequenceSummaryStats,
TimingStats,
)
from tianshou.data.collector import BaseCollector
from tianshou.utils import BaseLogger
def test_episode(
collector: Collector,
collector: BaseCollector,
test_fn: Callable[[int, int | None], None] | None,
epoch: int,
n_episode: int,
@ -44,8 +44,8 @@ def gather_info(
gradient_step: int,
best_reward: float,
best_reward_std: float,
train_collector: Collector | None = None,
test_collector: Collector | None = None,
train_collector: BaseCollector | None = None,
test_collector: BaseCollector | None = None,
) -> InfoStats:
"""A simple wrapper of gathering information from collectors.