Use the new BaseCollector interface for annotations
This commit is contained in:
parent
07a97c7d93
commit
2eaf1f37c2
@ -24,7 +24,7 @@ from tianshou.data.stats import (
|
|||||||
SequenceSummaryStats,
|
SequenceSummaryStats,
|
||||||
TimingStats,
|
TimingStats,
|
||||||
)
|
)
|
||||||
from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase
|
from tianshou.data.collector import Collector, AsyncCollector, CollectStats, CollectStatsBase, BaseCollector
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Batch",
|
"Batch",
|
||||||
@ -50,4 +50,5 @@ __all__ = [
|
|||||||
"InfoStats",
|
"InfoStats",
|
||||||
"SequenceSummaryStats",
|
"SequenceSummaryStats",
|
||||||
"TimingStats",
|
"TimingStats",
|
||||||
|
"BaseCollector",
|
||||||
]
|
]
|
||||||
|
@ -6,6 +6,7 @@ from typing import Any, Generic, TypeVar, cast
|
|||||||
import gymnasium
|
import gymnasium
|
||||||
|
|
||||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||||
|
from tianshou.data.collector import BaseCollector
|
||||||
from tianshou.highlevel.config import SamplingConfig
|
from tianshou.highlevel.config import SamplingConfig
|
||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.highlevel.module.actor import (
|
from tianshou.highlevel.module.actor import (
|
||||||
@ -94,7 +95,7 @@ class AgentFactory(ABC, ToStringMixin):
|
|||||||
policy: BasePolicy,
|
policy: BasePolicy,
|
||||||
envs: Environments,
|
envs: Environments,
|
||||||
reset_collectors: bool = True,
|
reset_collectors: bool = True,
|
||||||
) -> tuple[Collector, Collector]:
|
) -> tuple[BaseCollector, BaseCollector]:
|
||||||
""":param policy:
|
""":param policy:
|
||||||
:param envs:
|
:param envs:
|
||||||
:param reset_collectors: Whether to reset the collectors before returning them.
|
:param reset_collectors: Whether to reset the collectors before returning them.
|
||||||
|
@ -3,7 +3,7 @@ from dataclasses import dataclass
|
|||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tianshou.data import Collector
|
from tianshou.data import BaseCollector
|
||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.highlevel.logger import TLogger
|
from tianshou.highlevel.logger import TLogger
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
@ -16,8 +16,8 @@ class World:
|
|||||||
|
|
||||||
envs: "Environments"
|
envs: "Environments"
|
||||||
policy: "BasePolicy"
|
policy: "BasePolicy"
|
||||||
train_collector: "Collector"
|
train_collector: "BaseCollector"
|
||||||
test_collector: "Collector"
|
test_collector: "BaseCollector"
|
||||||
logger: "TLogger"
|
logger: "TLogger"
|
||||||
persist_directory: str
|
persist_directory: str
|
||||||
restore_directory: str | None
|
restore_directory: str | None
|
||||||
|
@ -10,14 +10,13 @@ import tqdm
|
|||||||
|
|
||||||
from tianshou.data import (
|
from tianshou.data import (
|
||||||
AsyncCollector,
|
AsyncCollector,
|
||||||
Collector,
|
|
||||||
CollectStats,
|
CollectStats,
|
||||||
EpochStats,
|
EpochStats,
|
||||||
InfoStats,
|
InfoStats,
|
||||||
ReplayBuffer,
|
ReplayBuffer,
|
||||||
SequenceSummaryStats,
|
SequenceSummaryStats,
|
||||||
)
|
)
|
||||||
from tianshou.data.collector import CollectStatsBase
|
from tianshou.data.collector import BaseCollector, CollectStatsBase
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.policy.base import TrainingStats
|
from tianshou.policy.base import TrainingStats
|
||||||
from tianshou.trainer.utils import gather_info, test_episode
|
from tianshou.trainer.utils import gather_info, test_episode
|
||||||
@ -152,8 +151,8 @@ class BaseTrainer(ABC):
|
|||||||
policy: BasePolicy,
|
policy: BasePolicy,
|
||||||
max_epoch: int,
|
max_epoch: int,
|
||||||
batch_size: int | None,
|
batch_size: int | None,
|
||||||
train_collector: Collector | None = None,
|
train_collector: BaseCollector | None = None,
|
||||||
test_collector: Collector | None = None,
|
test_collector: BaseCollector | None = None,
|
||||||
buffer: ReplayBuffer | None = None,
|
buffer: ReplayBuffer | None = None,
|
||||||
step_per_epoch: int | None = None,
|
step_per_epoch: int | None = None,
|
||||||
repeat_per_collect: int | None = None,
|
repeat_per_collect: int | None = None,
|
||||||
|
@ -5,17 +5,17 @@ from dataclasses import asdict
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tianshou.data import (
|
from tianshou.data import (
|
||||||
Collector,
|
|
||||||
CollectStats,
|
CollectStats,
|
||||||
InfoStats,
|
InfoStats,
|
||||||
SequenceSummaryStats,
|
SequenceSummaryStats,
|
||||||
TimingStats,
|
TimingStats,
|
||||||
)
|
)
|
||||||
|
from tianshou.data.collector import BaseCollector
|
||||||
from tianshou.utils import BaseLogger
|
from tianshou.utils import BaseLogger
|
||||||
|
|
||||||
|
|
||||||
def test_episode(
|
def test_episode(
|
||||||
collector: Collector,
|
collector: BaseCollector,
|
||||||
test_fn: Callable[[int, int | None], None] | None,
|
test_fn: Callable[[int, int | None], None] | None,
|
||||||
epoch: int,
|
epoch: int,
|
||||||
n_episode: int,
|
n_episode: int,
|
||||||
@ -44,8 +44,8 @@ def gather_info(
|
|||||||
gradient_step: int,
|
gradient_step: int,
|
||||||
best_reward: float,
|
best_reward: float,
|
||||||
best_reward_std: float,
|
best_reward_std: float,
|
||||||
train_collector: Collector | None = None,
|
train_collector: BaseCollector | None = None,
|
||||||
test_collector: Collector | None = None,
|
test_collector: BaseCollector | None = None,
|
||||||
) -> InfoStats:
|
) -> InfoStats:
|
||||||
"""A simple wrapper of gathering information from collectors.
|
"""A simple wrapper of gathering information from collectors.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user