Simplify parameter transformers by applying ParamTransformerChangeValue

This commit is contained in:
Dominik Jain 2023-10-10 16:12:29 +02:00
parent 17ef4dd5eb
commit 305b30a6c1

View File

@ -189,36 +189,27 @@ class ParamTransformerAutoAlpha(ParamTransformer):
kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device)
class ParamTransformerNoiseFactory(ParamTransformer):
def __init__(self, key: str):
self.key = key
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
value = params[self.key]
class ParamTransformerNoiseFactory(ParamTransformerChangeValue):
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
if isinstance(value, NoiseFactory):
params[self.key] = value.create_noise(data.envs)
value = value.create_noise(data.envs)
return value
class ParamTransformerFloatEnvParamFactory(ParamTransformer):
def __init__(self, key: str):
self.key = key
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
value = kwargs[self.key]
class ParamTransformerFloatEnvParamFactory(ParamTransformerChangeValue):
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
if isinstance(value, EnvValueFactory):
kwargs[self.key] = value.create_value(data.envs)
value = value.create_value(data.envs)
return value
class ParamTransformerDistributionFunction(ParamTransformer):
def __init__(self, key: str):
self.key = key
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
value = kwargs[self.key]
class ParamTransformerDistributionFunction(ParamTransformerChangeValue):
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
if value == "default":
kwargs[self.key] = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
value = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
elif isinstance(value, DistributionFunctionFactory):
kwargs[self.key] = value.create_dist_fn(data.envs)
value = value.create_dist_fn(data.envs)
return value
class ParamTransformerActionScaling(ParamTransformerChangeValue):