diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 7bf5ef4..540eec9 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -189,36 +189,27 @@ class ParamTransformerAutoAlpha(ParamTransformer): kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device) -class ParamTransformerNoiseFactory(ParamTransformer): - def __init__(self, key: str): - self.key = key - - def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None: - value = params[self.key] +class ParamTransformerNoiseFactory(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: if isinstance(value, NoiseFactory): - params[self.key] = value.create_noise(data.envs) + value = value.create_noise(data.envs) + return value -class ParamTransformerFloatEnvParamFactory(ParamTransformer): - def __init__(self, key: str): - self.key = key - - def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: - value = kwargs[self.key] +class ParamTransformerFloatEnvParamFactory(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: if isinstance(value, EnvValueFactory): - kwargs[self.key] = value.create_value(data.envs) + value = value.create_value(data.envs) + return value -class ParamTransformerDistributionFunction(ParamTransformer): - def __init__(self, key: str): - self.key = key - - def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None: - value = kwargs[self.key] +class ParamTransformerDistributionFunction(ParamTransformerChangeValue): + def change_value(self, value: Any, data: ParamTransformerData) -> Any: if value == "default": - kwargs[self.key] = DistributionFunctionFactoryDefault().create_dist_fn(data.envs) + value = DistributionFunctionFactoryDefault().create_dist_fn(data.envs) elif isinstance(value, DistributionFunctionFactory): - kwargs[self.key] = value.create_dist_fn(data.envs) + value = value.create_dist_fn(data.envs) + return value class ParamTransformerActionScaling(ParamTransformerChangeValue):