From c7d0cbb5d3ca4c077afda329f7e196f4804f6509 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Wed, 18 Oct 2023 12:26:49 +0200 Subject: [PATCH] Experiment: Fix return type annotation, remove unused type arguments --- tianshou/highlevel/experiment.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 730bc61..29488c0 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -4,7 +4,7 @@ from abc import abstractmethod from collections.abc import Callable, Sequence from dataclasses import dataclass from pprint import pformat -from typing import Any, Generic, Self, TypeVar +from typing import Any, Self import numpy as np import torch @@ -78,14 +78,11 @@ from tianshou.highlevel.trainer import ( ) from tianshou.highlevel.world import World from tianshou.policy import BasePolicy -from tianshou.trainer import BaseTrainer from tianshou.utils import LazyLogger, logging from tianshou.utils.logging import datetime_tag from tianshou.utils.string import ToStringMixin log = logging.getLogger(__name__) -TPolicy = TypeVar("TPolicy", bound=BasePolicy) -TTrainer = TypeVar("TTrainer", bound=BaseTrainer) @dataclass @@ -124,7 +121,7 @@ class ExperimentResult: """dictionary of results as returned by the trained (if any)""" -class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): +class Experiment(ToStringMixin): """Represents a reinforcement learning experiment. An experiment is composed only of configuration and factory objects, which themselves @@ -152,7 +149,7 @@ class Experiment(Generic[TPolicy, TTrainer], ToStringMixin): self.env_config = env_config @classmethod - def from_directory(cls, directory: str, restore_policy: bool = True) -> Self: + def from_directory(cls, directory: str, restore_policy: bool = True) -> "Experiment": """Restores an experiment from a previously stored pickle. :param directory: persistence directory of a previous run, in which a pickled experiment is found