Simplify parameter transformers by applying ParamTransformerChangeValue
This commit is contained in:
parent
17ef4dd5eb
commit
305b30a6c1
@ -189,36 +189,27 @@ class ParamTransformerAutoAlpha(ParamTransformer):
|
|||||||
kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device)
|
kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device)
|
||||||
|
|
||||||
|
|
||||||
class ParamTransformerNoiseFactory(ParamTransformer):
|
class ParamTransformerNoiseFactory(ParamTransformerChangeValue):
|
||||||
def __init__(self, key: str):
|
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
|
||||||
self.key = key
|
|
||||||
|
|
||||||
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
|
||||||
value = params[self.key]
|
|
||||||
if isinstance(value, NoiseFactory):
|
if isinstance(value, NoiseFactory):
|
||||||
params[self.key] = value.create_noise(data.envs)
|
value = value.create_noise(data.envs)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class ParamTransformerFloatEnvParamFactory(ParamTransformer):
|
class ParamTransformerFloatEnvParamFactory(ParamTransformerChangeValue):
|
||||||
def __init__(self, key: str):
|
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
|
||||||
self.key = key
|
|
||||||
|
|
||||||
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
|
|
||||||
value = kwargs[self.key]
|
|
||||||
if isinstance(value, EnvValueFactory):
|
if isinstance(value, EnvValueFactory):
|
||||||
kwargs[self.key] = value.create_value(data.envs)
|
value = value.create_value(data.envs)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class ParamTransformerDistributionFunction(ParamTransformer):
|
class ParamTransformerDistributionFunction(ParamTransformerChangeValue):
|
||||||
def __init__(self, key: str):
|
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
|
||||||
self.key = key
|
|
||||||
|
|
||||||
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
|
|
||||||
value = kwargs[self.key]
|
|
||||||
if value == "default":
|
if value == "default":
|
||||||
kwargs[self.key] = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
|
value = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
|
||||||
elif isinstance(value, DistributionFunctionFactory):
|
elif isinstance(value, DistributionFunctionFactory):
|
||||||
kwargs[self.key] = value.create_dist_fn(data.envs)
|
value = value.create_dist_fn(data.envs)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class ParamTransformerActionScaling(ParamTransformerChangeValue):
|
class ParamTransformerActionScaling(ParamTransformerChangeValue):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user