Tianshou/test/base/test_stats.py
Daniel Plop 8a0629ded6
Fix mypy issues in tests and examples (#1077)
Closes #952 

- `SamplingConfig` supports `batch_size=None`. #1077
- tests and examples are covered by `mypy`. #1077
- `NetBase` is used more widely and is now generic, enabling stricter typing. #1077
- `utils.net.common.Recurrent` now receives and returns a
`RecurrentStateBatch` instead of a dict. #1077

---------

Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2024-04-03 18:07:51 +02:00

50 lines
1.8 KiB
Python

import pytest
from tianshou.policy.base import TrainingStats, TrainingStatsWrapper
class DummyTrainingStatsWrapper(TrainingStatsWrapper):
    """Minimal ``TrainingStatsWrapper`` subclass used by the tests in this module.

    It wraps an existing ``TrainingStats`` object and adds exactly one field of its own,
    ``dummy_field``. This lets the tests distinguish between fields stored on the wrapper
    itself and fields that are read from or written to the wrapped stats object.
    """

    def __init__(self, wrapped_stats: TrainingStats, *, dummy_field: int) -> None:
        """Store the wrapper's own field, then initialise the base wrapper.

        :param wrapped_stats: the stats object to wrap; passed unchanged to the base class.
        :param dummy_field: value of the wrapper's own extra field (keyword-only).
        """
        # NOTE: the own field is assigned BEFORE super().__init__(). The test in TestStats
        # expects that setting an unknown attribute after construction raises AttributeError,
        # so the base wrapper presumably freezes attribute creation at the end of its
        # __init__ -- keep this order, or the assignment below may be rejected.
        self.dummy_field = dummy_field
        super().__init__(wrapped_stats)
class TestStats:
    """Tests for the attribute-delegation behaviour of ``TrainingStatsWrapper``."""

    @staticmethod
    def test_training_stats_wrapper() -> None:
        """Check readout, mutation and attribute-creation rules of a ``TrainingStatsWrapper``.

        Exercises fields that live on the wrapped ``TrainingStats`` (``train_time``,
        ``smoothed_loss`` and the dynamically attached ``loss_field``) as well as a field
        defined on the wrapper subclass itself (``dummy_field``).
        """
        train_stats = TrainingStats(train_time=1.0)

        # Attach an attribute that TrainingStats does not declare. setattr() is used instead
        # of plain assignment, presumably so the type checker does not reject the undeclared
        # attribute; B010 is suppressed for that reason.
        setattr(train_stats, "loss_field", 12)  # noqa: B010

        wrapped_train_stats = DummyTrainingStatsWrapper(train_stats, dummy_field=42)

        # basic readout
        assert wrapped_train_stats.train_time == 1.0
        assert wrapped_train_stats.loss_field == 12

        # mutation of TrainingStats fields
        wrapped_train_stats.train_time = 2.0
        wrapped_train_stats.smoothed_loss["foo"] = 50
        assert wrapped_train_stats.train_time == 2.0
        assert wrapped_train_stats.smoothed_loss["foo"] == 50

        # loss stats dict: contains both the wrapped field and the wrapper's own field
        assert wrapped_train_stats.get_loss_stats_dict() == {"loss_field": 12, "dummy_field": 42}

        # new fields can't be added
        with pytest.raises(AttributeError):
            wrapped_train_stats.new_loss_field = 90

        # existing fields, wrapped and not-wrapped, can be mutated
        wrapped_train_stats.loss_field = 13
        wrapped_train_stats.dummy_field = 43
        # the not-wrapped field is stored on the wrapper itself; check that the write took effect
        assert wrapped_train_stats.dummy_field == 43
        # the wrapped field must be visible both on the wrapped object and through the wrapper
        assert hasattr(
            wrapped_train_stats.wrapped_stats,
            "loss_field",
        ), "Attribute `loss_field` not found in `wrapped_train_stats.wrapped_stats`."
        assert hasattr(
            wrapped_train_stats,
            "loss_field",
        ), "Attribute `loss_field` not found in `wrapped_train_stats`."
        assert wrapped_train_stats.wrapped_stats.loss_field == wrapped_train_stats.loss_field == 13