From 305b30a6c17d9cef2c069449f808b51f6f72d314 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 10 Oct 2023 16:12:29 +0200 Subject: [PATCH] Simplify parameter transformers by applying ParamTransformerChangeValue --- tianshou/highlevel/params/policy_params.py | 35 ++++++++-------------- 1 file changed, 13 insertions(+), 22 deletions(-) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 7bf5ef4..540eec9 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -189,36 +189,27 @@ class ParamTransformerAutoAlpha(ParamTransformer): kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device) -class ParamTransformerNoiseFactory(ParamTransformer): - def __init__(self, key: str): - self.key = key - - def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: - value = params[self.key] +class ParamTransformerNoiseFactory(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: if isinstance(value, NoiseFactory): - params[self.key] = value.create_noise(data.envs) + value = value.create_noise(data.envs) + return value -class ParamTransformerFloatEnvParamFactory(ParamTransformer): - def __init__(self, key: str): - self.key = key - - def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: - value = kwargs[self.key] +class ParamTransformerFloatEnvParamFactory(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: if isinstance(value, EnvValueFactory): - kwargs[self.key] = value.create_value(data.envs) + value = value.create_value(data.envs) + return value -class ParamTransformerDistributionFunction(ParamTransformer): - def __init__(self, key: str): - self.key = key - - def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: - value = kwargs[self.key] +class ParamTransformerDistributionFunction(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: if value == "default": - kwargs[self.key] = DistributionFunctionFactoryDefault().create_dist_fn(data.envs) + value = DistributionFunctionFactoryDefault().create_dist_fn(data.envs) elif isinstance(value, DistributionFunctionFactory): - kwargs[self.key] = value.create_dist_fn(data.envs) + value = value.create_dist_fn(data.envs) + return value class ParamTransformerActionScaling(ParamTransformerChangeValue):