Simplify parameter transformers by applying ParamTransformerChangeValue
This commit is contained in:
parent
17ef4dd5eb
commit
305b30a6c1
@ -189,36 +189,27 @@ class ParamTransformerAutoAlpha(ParamTransformer):
|
||||
kwargs[self.key] = alpha.create_auto_alpha(data.envs, data.optim_factory, data.device)
|
||||
|
||||
|
||||
class ParamTransformerNoiseFactory(ParamTransformer):
|
||||
def __init__(self, key: str):
|
||||
self.key = key
|
||||
|
||||
def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
|
||||
value = params[self.key]
|
||||
class ParamTransformerNoiseFactory(ParamTransformerChangeValue):
|
||||
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
|
||||
if isinstance(value, NoiseFactory):
|
||||
params[self.key] = value.create_noise(data.envs)
|
||||
value = value.create_noise(data.envs)
|
||||
return value
|
||||
|
||||
|
||||
class ParamTransformerFloatEnvParamFactory(ParamTransformer):
|
||||
def __init__(self, key: str):
|
||||
self.key = key
|
||||
|
||||
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
|
||||
value = kwargs[self.key]
|
||||
class ParamTransformerFloatEnvParamFactory(ParamTransformerChangeValue):
|
||||
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
|
||||
if isinstance(value, EnvValueFactory):
|
||||
kwargs[self.key] = value.create_value(data.envs)
|
||||
value = value.create_value(data.envs)
|
||||
return value
|
||||
|
||||
|
||||
class ParamTransformerDistributionFunction(ParamTransformer):
|
||||
def __init__(self, key: str):
|
||||
self.key = key
|
||||
|
||||
def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
|
||||
value = kwargs[self.key]
|
||||
class ParamTransformerDistributionFunction(ParamTransformerChangeValue):
|
||||
def change_value(self, value: Any, data: ParamTransformerData) -> Any:
|
||||
if value == "default":
|
||||
kwargs[self.key] = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
|
||||
value = DistributionFunctionFactoryDefault().create_dist_fn(data.envs)
|
||||
elif isinstance(value, DistributionFunctionFactory):
|
||||
kwargs[self.key] = value.create_dist_fn(data.envs)
|
||||
value = value.create_dist_fn(data.envs)
|
||||
return value
|
||||
|
||||
|
||||
class ParamTransformerActionScaling(ParamTransformerChangeValue):
|
||||
|
Loading…
x
Reference in New Issue
Block a user