fixed logger test

This commit is contained in:
Maximilian Huettenrauch 2024-03-27 13:55:22 +01:00
parent 9055eb5924
commit ce5fa0dfac

View File

@ -2,11 +2,12 @@ from typing import Literal
import numpy as np
import pytest
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BaseLogger
from tianshou.utils import TensorboardLogger
class TestBaseLogger:
class TestTensorBoardLogger:
@staticmethod
@pytest.mark.parametrize(
"input_dict, expected_output",
@ -20,7 +21,8 @@ class TestBaseLogger:
| dict[str, dict[str, dict[str, int]]],
expected_output: dict[str, int],
) -> None:
result = BaseLogger.prepare_dict_for_logging(input_dict)
logger = TensorboardLogger(SummaryWriter('log/logger'))
result = logger.prepare_dict_for_logging(input_dict)
assert result == expected_output
@staticmethod
@ -36,7 +38,8 @@ class TestBaseLogger:
delimiter: Literal["|", "."],
expected_output: dict[str, int],
) -> None:
result = BaseLogger.prepare_dict_for_logging(input_dict, delimiter=delimiter)
logger = TensorboardLogger(SummaryWriter('log/logger'))
result = logger.prepare_dict_for_logging(input_dict, delimiter=delimiter)
assert result == expected_output
@staticmethod
@ -56,7 +59,8 @@ class TestBaseLogger:
exclude_arrays: bool,
expected_output: dict[str, np.ndarray],
) -> None:
result = BaseLogger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays)
logger = TensorboardLogger(SummaryWriter('log/logger'))
result = logger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays)
assert result.keys() == expected_output.keys()
for val1, val2 in zip(result.values(), expected_output.values(), strict=True):
assert np.all(val1 == val2)
@ -72,5 +76,6 @@ class TestBaseLogger:
input_dict: dict[str, tuple[Literal[1]] | dict[str, str | dict[str, int]]],
expected_output: dict[str, int],
) -> None:
result = BaseLogger.prepare_dict_for_logging(input_dict)
logger = TensorboardLogger(SummaryWriter('log/logger'))
result = logger.prepare_dict_for_logging(input_dict)
assert result == expected_output