Closes #952 - `SamplingConfig` supports `batch_size=None`. #1077 - tests and examples are covered by `mypy`. #1077 - `NetBase` is used more widely and has stricter typing by being made generic. #1077 - `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077 --------- Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
50 lines
1.8 KiB
Python
50 lines
1.8 KiB
Python
import pytest
|
|
|
|
from tianshou.policy.base import TrainingStats, TrainingStatsWrapper
|
|
|
|
|
|
class DummyTrainingStatsWrapper(TrainingStatsWrapper):
    """Minimal ``TrainingStatsWrapper`` subclass that adds one extra field, ``dummy_field``, for testing."""

    def __init__(self, wrapped_stats: TrainingStats, *, dummy_field: int) -> None:
        """Store ``dummy_field`` on the wrapper, then initialize the base wrapper.

        :param wrapped_stats: the stats object being wrapped; forwarded to ``TrainingStatsWrapper.__init__``.
        :param dummy_field: value stored on the wrapper itself as ``self.dummy_field``.
        """
        # NOTE: the new field is assigned *before* calling super().__init__(). Presumably the base
        # wrapper forbids creating new attributes once its __init__ has run (adding attributes after
        # construction is expected to raise AttributeError), so this ordering must be kept.
        # TODO confirm against TrainingStatsWrapper.
        self.dummy_field = dummy_field
        super().__init__(wrapped_stats)
|
|
|
|
|
|
class TestStats:
    """Tests for the training-stats containers from ``tianshou.policy.base``."""

    @staticmethod
    def test_training_stats_wrapper() -> None:
        """Check readout, mutation, loss-dict export and attribute freezing of a ``TrainingStatsWrapper``."""
        train_stats = TrainingStats(train_time=1.0)
        # ``loss_field`` is attached dynamically via setattr with a constant name (hence the B010
        # suppression) -- presumably so static type checkers don't flag an undeclared attribute.
        setattr(train_stats, "loss_field", 12)  # noqa: B010

        wrapped_train_stats = DummyTrainingStatsWrapper(train_stats, dummy_field=42)

        # basic readout: attributes of the wrapped stats are visible through the wrapper
        assert wrapped_train_stats.train_time == 1.0
        assert wrapped_train_stats.loss_field == 12

        # mutation of TrainingStats fields
        wrapped_train_stats.train_time = 2.0
        wrapped_train_stats.smoothed_loss["foo"] = 50
        assert wrapped_train_stats.train_time == 2.0
        assert wrapped_train_stats.smoothed_loss["foo"] == 50

        # loss stats dict holds both the wrapped value (loss_field) and the wrapper's own (dummy_field)
        assert wrapped_train_stats.get_loss_stats_dict() == {"loss_field": 12, "dummy_field": 42}

        # new fields can't be added after construction
        with pytest.raises(AttributeError):
            wrapped_train_stats.new_loss_field = 90

        # existing fields, wrapped and not-wrapped, can be mutated
        wrapped_train_stats.loss_field = 13
        wrapped_train_stats.dummy_field = 43
        assert hasattr(
            wrapped_train_stats.wrapped_stats,
            "loss_field",
        ), "Attribute `loss_field` not found in `wrapped_train_stats.wrapped_stats`."
        assert hasattr(
            wrapped_train_stats,
            "loss_field",
        ), "Attribute `loss_field` not found in `wrapped_train_stats`."
        # a write to a wrapped field must reach the underlying stats object and stay in sync
        assert wrapped_train_stats.wrapped_stats.loss_field == wrapped_train_stats.loss_field == 13
        # a write to the wrapper's own (not-wrapped) field must take effect as well
        assert wrapped_train_stats.dummy_field == 43
|