Implement args/kwargs for init of norm_layers and activation (#788)
As mentioned in #770, I have fixed the mismatch of arguments between Net and MLP. Also, to allow the norm layers and activations to be initialized with custom arguments, norm_args and act_args are added to miniblock and the related classes.
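For illustration, a minimal usage sketch of the new arguments on Net (assuming these classes live in tianshou.utils.net.common as in the current codebase; the hyperparameter values below are arbitrary examples, not part of this commit):

import torch
from torch import nn
from tianshou.utils.net.common import Net

# norm_args / act_args are forwarded to the norm_layer / activation
# constructors: a dict is expanded as keyword arguments, a tuple positionally.
net = Net(
    state_shape=4,
    action_shape=2,
    hidden_sizes=[64, 64],
    norm_layer=nn.LayerNorm,
    norm_args={"eps": 1e-6},           # each hidden layer gets LayerNorm(64, eps=1e-6)
    activation=nn.LeakyReLU,
    act_args={"negative_slope": 0.1},  # and LeakyReLU(negative_slope=0.1)
)
logits, state = net(torch.zeros(1, 4))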
parent 1037627a5b
commit 774d3d8e83
@@ -18,21 +18,35 @@ from torch import nn
 from tianshou.data.batch import Batch

 ModuleType = Type[nn.Module]
+ArgsType = Union[Tuple[Any, ...], Dict[Any, Any], Sequence[Tuple[Any, ...]],
+                 Sequence[Dict[Any, Any]]]


 def miniblock(
     input_size: int,
     output_size: int = 0,
     norm_layer: Optional[ModuleType] = None,
+    norm_args: Optional[Union[Tuple[Any, ...], Dict[Any, Any]]] = None,
     activation: Optional[ModuleType] = None,
+    act_args: Optional[Union[Tuple[Any, ...], Dict[Any, Any]]] = None,
     linear_layer: Type[nn.Linear] = nn.Linear,
 ) -> List[nn.Module]:
     """Construct a miniblock with given input/output-size, norm layer and \
     activation."""
     layers: List[nn.Module] = [linear_layer(input_size, output_size)]
     if norm_layer is not None:
-        layers += [norm_layer(output_size)]  # type: ignore
+        if isinstance(norm_args, tuple):
+            layers += [norm_layer(output_size, *norm_args)]  # type: ignore
+        elif isinstance(norm_args, dict):
+            layers += [norm_layer(output_size, **norm_args)]  # type: ignore
+        else:
+            layers += [norm_layer(output_size)]  # type: ignore
     if activation is not None:
-        layers += [activation()]
+        if isinstance(act_args, tuple):
+            layers += [activation(*act_args)]
+        elif isinstance(act_args, dict):
+            layers += [activation(**act_args)]
+        else:
+            layers += [activation()]
     return layers

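A small sketch of how the updated miniblock expands tuple and dict args (the concrete layer choices are illustrative, and the import path is assumed from the current codebase):

from torch import nn
from tianshou.utils.net.common import miniblock

# dict args become keyword arguments of the norm/activation constructors
layers = miniblock(
    4, 8,
    norm_layer=nn.LayerNorm, norm_args={"eps": 1e-6},
    activation=nn.LeakyReLU, act_args={"negative_slope": 0.1},
)
# -> [Linear(4, 8), LayerNorm(8, eps=1e-6), LeakyReLU(negative_slope=0.1)]

# tuple args are expanded positionally after output_size
layers = miniblock(4, 8, nn.BatchNorm1d, (1e-5, 0.2), nn.ReLU)
# -> [Linear(4, 8), BatchNorm1d(8, eps=1e-5, momentum=0.2), ReLU()]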
@@ -68,7 +82,9 @@ class MLP(nn.Module):
         output_dim: int = 0,
         hidden_sizes: Sequence[int] = (),
         norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
+        norm_args: Optional[ArgsType] = None,
         activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
+        act_args: Optional[ArgsType] = None,
         device: Optional[Union[str, int, torch.device]] = None,
         linear_layer: Type[nn.Linear] = nn.Linear,
         flatten_input: bool = True,
@@ -79,24 +95,41 @@ class MLP(nn.Module):
             if isinstance(norm_layer, list):
                 assert len(norm_layer) == len(hidden_sizes)
                 norm_layer_list = norm_layer
+                if isinstance(norm_args, list):
+                    assert len(norm_args) == len(hidden_sizes)
+                    norm_args_list = norm_args
+                else:
+                    norm_args_list = [norm_args for _ in range(len(hidden_sizes))]
             else:
                 norm_layer_list = [norm_layer for _ in range(len(hidden_sizes))]
+                norm_args_list = [norm_args for _ in range(len(hidden_sizes))]
         else:
             norm_layer_list = [None] * len(hidden_sizes)
+            norm_args_list = [None] * len(hidden_sizes)
         if activation:
             if isinstance(activation, list):
                 assert len(activation) == len(hidden_sizes)
                 activation_list = activation
+                if isinstance(act_args, list):
+                    assert len(act_args) == len(hidden_sizes)
+                    act_args_list = act_args
+                else:
+                    act_args_list = [act_args for _ in range(len(hidden_sizes))]
             else:
                 activation_list = [activation for _ in range(len(hidden_sizes))]
+                act_args_list = [act_args for _ in range(len(hidden_sizes))]
         else:
             activation_list = [None] * len(hidden_sizes)
+            act_args_list = [None] * len(hidden_sizes)
         hidden_sizes = [input_dim] + list(hidden_sizes)
         model = []
-        for in_dim, out_dim, norm, activ in zip(
-            hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, activation_list
+        for in_dim, out_dim, norm, norm_args, activ, act_args in zip(
+            hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, norm_args_list,
+            activation_list, act_args_list
         ):
-            model += miniblock(in_dim, out_dim, norm, activ, linear_layer)
+            model += miniblock(
+                in_dim, out_dim, norm, norm_args, activ, act_args, linear_layer
+            )
         if output_dim > 0:
             model += [linear_layer(hidden_sizes[-1], output_dim)]
         self.output_dim = output_dim or hidden_sizes[-1]
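As a sketch of the per-layer branch above: norm_args/act_args can be a single tuple/dict applied to every hidden layer, or a list matching hidden_sizes (the values here are illustrative, and the import path is assumed from the current codebase):

from torch import nn
from tianshou.utils.net.common import MLP

mlp = MLP(
    input_dim=4,
    output_dim=2,
    hidden_sizes=[128, 128],
    norm_layer=[nn.LayerNorm, nn.LayerNorm],
    norm_args=[{"eps": 1e-6}, {"eps": 1e-5}],   # one entry per hidden layer
    activation=[nn.LeakyReLU, nn.ReLU],
    act_args=[{"negative_slope": 0.2}, None],   # None falls back to activation()
)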
@@ -161,8 +194,10 @@ class Net(nn.Module):
         state_shape: Union[int, Sequence[int]],
         action_shape: Union[int, Sequence[int]] = 0,
         hidden_sizes: Sequence[int] = (),
-        norm_layer: Optional[ModuleType] = None,
-        activation: Optional[ModuleType] = nn.ReLU,
+        norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
+        norm_args: Optional[ArgsType] = None,
+        activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
+        act_args: Optional[ArgsType] = None,
         device: Union[str, int, torch.device] = "cpu",
         softmax: bool = False,
         concat: bool = False,
@@ -181,8 +216,8 @@ class Net(nn.Module):
         self.use_dueling = dueling_param is not None
         output_dim = action_dim if not self.use_dueling and not concat else 0
         self.model = MLP(
-            input_dim, output_dim, hidden_sizes, norm_layer, activation, device,
-            linear_layer
+            input_dim, output_dim, hidden_sizes, norm_layer, norm_args, activation,
+            act_args, device, linear_layer
         )
         self.output_dim = self.model.output_dim
         if self.use_dueling:  # dueling DQN
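A note on the call above: norm_args and act_args are inserted before device in MLP's parameter list, so positional call sites like this one in Net must be updated in lockstep with the signature (which is what this hunk does). An equivalent keyword spelling, shown here only as a sketch and not part of the commit, avoids that coupling:

# keyword form, equivalent to the positional call in the diff above
self.model = MLP(
    input_dim, output_dim, hidden_sizes,
    norm_layer=norm_layer, norm_args=norm_args,
    activation=activation, act_args=act_args,
    device=device, linear_layer=linear_layer,
)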
@@ -406,7 +441,9 @@ class BranchingNet(nn.Module):
         value_hidden_sizes: List[int] = [],
         action_hidden_sizes: List[int] = [],
         norm_layer: Optional[ModuleType] = None,
+        norm_args: Optional[ArgsType] = None,
         activation: Optional[ModuleType] = nn.ReLU,
+        act_args: Optional[ArgsType] = None,
         device: Union[str, int, torch.device] = "cpu",
     ) -> None:
         super().__init__()
@@ -418,14 +455,14 @@ class BranchingNet(nn.Module):
         common_output_dim = 0
         self.common = MLP(
             common_input_dim, common_output_dim, common_hidden_sizes, norm_layer,
-            activation, device
+            norm_args, activation, act_args, device
         )
         # value network
         value_input_dim = common_hidden_sizes[-1]
         value_output_dim = 1
         self.value = MLP(
             value_input_dim, value_output_dim, value_hidden_sizes, norm_layer,
-            activation, device
+            norm_args, activation, act_args, device
         )
         # action branching network
         action_input_dim = common_hidden_sizes[-1]
@@ -434,7 +471,7 @@ class BranchingNet(nn.Module):
             [
                 MLP(
                     action_input_dim, action_output_dim, action_hidden_sizes,
-                    norm_layer, activation, device
+                    norm_layer, norm_args, activation, act_args, device
                 ) for _ in range(self.num_branches)
             ]
         )
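Finally, a hedged sketch of constructing BranchingNet after this change, where a single norm_args/act_args pair is shared by the common, value, and per-branch MLPs built above (the constructor arguments not shown in this diff, such as state_shape, num_branches, and action_per_branch, are assumed from the existing BranchingNet API; values are illustrative):

from torch import nn
from tianshou.utils.net.common import BranchingNet

net = BranchingNet(
    state_shape=8,
    num_branches=4,
    action_per_branch=3,
    common_hidden_sizes=[128, 128],
    value_hidden_sizes=[64],
    action_hidden_sizes=[64],
    norm_layer=nn.LayerNorm,
    norm_args={"eps": 1e-6},
    activation=nn.LeakyReLU,
    act_args={"negative_slope": 0.1},
)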