fixed logger test
This commit is contained in:
parent
9055eb5924
commit
ce5fa0dfac
@ -2,11 +2,12 @@ from typing import Literal
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.utils import BaseLogger
|
from tianshou.utils import TensorboardLogger
|
||||||
|
|
||||||
|
|
||||||
class TestBaseLogger:
|
class TestTensorBoardLogger:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"input_dict, expected_output",
|
"input_dict, expected_output",
|
||||||
@ -20,7 +21,8 @@ class TestBaseLogger:
|
|||||||
| dict[str, dict[str, dict[str, int]]],
|
| dict[str, dict[str, dict[str, int]]],
|
||||||
expected_output: dict[str, int],
|
expected_output: dict[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
result = BaseLogger.prepare_dict_for_logging(input_dict)
|
logger = TensorboardLogger(SummaryWriter('log/logger'))
|
||||||
|
result = logger.prepare_dict_for_logging(input_dict)
|
||||||
assert result == expected_output
|
assert result == expected_output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -36,7 +38,8 @@ class TestBaseLogger:
|
|||||||
delimiter: Literal["|", "."],
|
delimiter: Literal["|", "."],
|
||||||
expected_output: dict[str, int],
|
expected_output: dict[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
result = BaseLogger.prepare_dict_for_logging(input_dict, delimiter=delimiter)
|
logger = TensorboardLogger(SummaryWriter('log/logger'))
|
||||||
|
result = logger.prepare_dict_for_logging(input_dict, delimiter=delimiter)
|
||||||
assert result == expected_output
|
assert result == expected_output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -56,7 +59,8 @@ class TestBaseLogger:
|
|||||||
exclude_arrays: bool,
|
exclude_arrays: bool,
|
||||||
expected_output: dict[str, np.ndarray],
|
expected_output: dict[str, np.ndarray],
|
||||||
) -> None:
|
) -> None:
|
||||||
result = BaseLogger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays)
|
logger = TensorboardLogger(SummaryWriter('log/logger'))
|
||||||
|
result = logger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays)
|
||||||
assert result.keys() == expected_output.keys()
|
assert result.keys() == expected_output.keys()
|
||||||
for val1, val2 in zip(result.values(), expected_output.values(), strict=True):
|
for val1, val2 in zip(result.values(), expected_output.values(), strict=True):
|
||||||
assert np.all(val1 == val2)
|
assert np.all(val1 == val2)
|
||||||
@ -72,5 +76,6 @@ class TestBaseLogger:
|
|||||||
input_dict: dict[str, tuple[Literal[1]] | dict[str, str | dict[str, int]]],
|
input_dict: dict[str, tuple[Literal[1]] | dict[str, str | dict[str, int]]],
|
||||||
expected_output: dict[str, int],
|
expected_output: dict[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
result = BaseLogger.prepare_dict_for_logging(input_dict)
|
logger = TensorboardLogger(SummaryWriter('log/logger'))
|
||||||
|
result = logger.prepare_dict_for_logging(input_dict)
|
||||||
assert result == expected_output
|
assert result == expected_output
|
||||||
|
Loading…
x
Reference in New Issue
Block a user