From ce5fa0dfacc6644adc8439b06136e308a25df481 Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Wed, 27 Mar 2024 13:55:22 +0100 Subject: [PATCH] fixed logger test --- test/base/test_logger.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/test/base/test_logger.py b/test/base/test_logger.py index 69c02a2..40dfbbb 100644 --- a/test/base/test_logger.py +++ b/test/base/test_logger.py @@ -2,11 +2,12 @@ from typing import Literal import numpy as np import pytest +from torch.utils.tensorboard import SummaryWriter -from tianshou.utils import BaseLogger +from tianshou.utils import TensorboardLogger -class TestBaseLogger: +class TestTensorBoardLogger: @staticmethod @pytest.mark.parametrize( "input_dict, expected_output", @@ -20,7 +21,8 @@ class TestBaseLogger: | dict[str, dict[str, dict[str, int]]], expected_output: dict[str, int], ) -> None: - result = BaseLogger.prepare_dict_for_logging(input_dict) + logger = TensorboardLogger(SummaryWriter('log/logger')) + result = logger.prepare_dict_for_logging(input_dict) assert result == expected_output @staticmethod @@ -36,7 +38,8 @@ class TestBaseLogger: delimiter: Literal["|", "."], expected_output: dict[str, int], ) -> None: - result = BaseLogger.prepare_dict_for_logging(input_dict, delimiter=delimiter) + logger = TensorboardLogger(SummaryWriter('log/logger')) + result = logger.prepare_dict_for_logging(input_dict, delimiter=delimiter) assert result == expected_output @staticmethod @@ -56,7 +59,8 @@ class TestBaseLogger: exclude_arrays: bool, expected_output: dict[str, np.ndarray], ) -> None: - result = BaseLogger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays) + logger = TensorboardLogger(SummaryWriter('log/logger')) + result = logger.prepare_dict_for_logging(input_dict, exclude_arrays=exclude_arrays) assert result.keys() == expected_output.keys() for val1, val2 in zip(result.values(), expected_output.values(), strict=True): assert np.all(val1 == val2) @@ -72,5 +76,6 @@ class TestBaseLogger: input_dict: dict[str, tuple[Literal[1]] | dict[str, str | dict[str, int]]], expected_output: dict[str, int], ) -> None: - result = BaseLogger.prepare_dict_for_logging(input_dict) + logger = TensorboardLogger(SummaryWriter('log/logger')) + result = logger.prepare_dict_for_logging(input_dict) assert result == expected_output