Implement args/kwargs for init of norm_layers and activation (#788)

As mentioned in #770, I have fixed the mismatch of args between Net and MLP. In addition, norm_args and act_args have been added to miniblock and the related classes so that the norm layers and activations can be initialized with custom arguments.
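
A quick usage sketch of the new arguments (illustrative only; the layer and activation choices below are arbitrary, not prescribed by this change):

import torch
from torch import nn
from tianshou.utils.net.common import MLP

# kwargs go to each norm layer, positional args to each activation
net = MLP(
    input_dim=4,
    output_dim=2,
    hidden_sizes=[64, 64],
    norm_layer=nn.LayerNorm,
    norm_args={"eps": 1e-6},   # -> nn.LayerNorm(64, eps=1e-6)
    activation=nn.LeakyReLU,
    act_args=(0.1,),           # -> nn.LeakyReLU(0.1)
)
out = net(torch.zeros(1, 4))   # output shape (1, 2)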
janofsssun 2022-12-26 19:58:03 -08:00 committed by GitHub
parent 1037627a5b
commit 774d3d8e83


@@ -18,22 +18,36 @@ from torch import nn
 from tianshou.data.batch import Batch
 
 ModuleType = Type[nn.Module]
+ArgsType = Union[Tuple[Any, ...], Dict[Any, Any], Sequence[Tuple[Any, ...]],
+                 Sequence[Dict[Any, Any]]]
 
 
 def miniblock(
     input_size: int,
     output_size: int = 0,
     norm_layer: Optional[ModuleType] = None,
+    norm_args: Optional[Union[Tuple[Any, ...], Dict[Any, Any]]] = None,
     activation: Optional[ModuleType] = None,
+    act_args: Optional[Union[Tuple[Any, ...], Dict[Any, Any]]] = None,
     linear_layer: Type[nn.Linear] = nn.Linear,
 ) -> List[nn.Module]:
     """Construct a miniblock with given input/output-size, norm layer and \
     activation."""
     layers: List[nn.Module] = [linear_layer(input_size, output_size)]
     if norm_layer is not None:
-        layers += [norm_layer(output_size)]  # type: ignore
+        if isinstance(norm_args, tuple):
+            layers += [norm_layer(output_size, *norm_args)]  # type: ignore
+        elif isinstance(norm_args, dict):
+            layers += [norm_layer(output_size, **norm_args)]  # type: ignore
+        else:
+            layers += [norm_layer(output_size)]  # type: ignore
     if activation is not None:
-        layers += [activation()]
+        if isinstance(act_args, tuple):
+            layers += [activation(*act_args)]
+        elif isinstance(act_args, dict):
+            layers += [activation(**act_args)]
+        else:
+            layers += [activation()]
     return layers
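
The branches above splat a tuple as positional args (after output_size) and unpack a dict as keyword args, falling back to the old no-argument call when None. A minimal sketch (module choices are arbitrary, not part of the diff):

from torch import nn
from tianshou.utils.net.common import miniblock

# dict -> keyword args, tuple -> positional args (after output_size)
layers = miniblock(
    8, 16,
    norm_layer=nn.LayerNorm, norm_args={"eps": 1e-6},  # LayerNorm(16, eps=1e-6)
    activation=nn.LeakyReLU, act_args=(0.2,),          # LeakyReLU(0.2)
)
block = nn.Sequential(*layers)  # Linear(8, 16) -> LayerNorm -> LeakyReLU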
@@ -68,7 +82,9 @@ class MLP(nn.Module):
         output_dim: int = 0,
         hidden_sizes: Sequence[int] = (),
         norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
+        norm_args: Optional[ArgsType] = None,
         activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
+        act_args: Optional[ArgsType] = None,
         device: Optional[Union[str, int, torch.device]] = None,
         linear_layer: Type[nn.Linear] = nn.Linear,
         flatten_input: bool = True,
@@ -79,24 +95,41 @@ class MLP(nn.Module):
             if isinstance(norm_layer, list):
                 assert len(norm_layer) == len(hidden_sizes)
                 norm_layer_list = norm_layer
+                if isinstance(norm_args, list):
+                    assert len(norm_args) == len(hidden_sizes)
+                    norm_args_list = norm_args
+                else:
+                    norm_args_list = [norm_args for _ in range(len(hidden_sizes))]
             else:
                 norm_layer_list = [norm_layer for _ in range(len(hidden_sizes))]
+                norm_args_list = [norm_args for _ in range(len(hidden_sizes))]
         else:
             norm_layer_list = [None] * len(hidden_sizes)
+            norm_args_list = [None] * len(hidden_sizes)
         if activation:
             if isinstance(activation, list):
                 assert len(activation) == len(hidden_sizes)
                 activation_list = activation
+                if isinstance(act_args, list):
+                    assert len(act_args) == len(hidden_sizes)
+                    act_args_list = act_args
+                else:
+                    act_args_list = [act_args for _ in range(len(hidden_sizes))]
             else:
                 activation_list = [activation for _ in range(len(hidden_sizes))]
+                act_args_list = [act_args for _ in range(len(hidden_sizes))]
         else:
             activation_list = [None] * len(hidden_sizes)
+            act_args_list = [None] * len(hidden_sizes)
         hidden_sizes = [input_dim] + list(hidden_sizes)
         model = []
-        for in_dim, out_dim, norm, activ in zip(
-            hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, activation_list
+        for in_dim, out_dim, norm, norm_args, activ, act_args in zip(
+            hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, norm_args_list,
+            activation_list, act_args_list
         ):
-            model += miniblock(in_dim, out_dim, norm, activ, linear_layer)
+            model += miniblock(
+                in_dim, out_dim, norm, norm_args, activ, act_args, linear_layer
+            )
         if output_dim > 0:
             model += [linear_layer(hidden_sizes[-1], output_dim)]
         self.output_dim = output_dim or hidden_sizes[-1]
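
The list branches keep the existing per-layer behaviour, so the new args can also be given per hidden layer; a sketch (layer choices are arbitrary):

from torch import nn
from tianshou.utils.net.common import MLP

# one norm layer / activation (and its init args) per hidden layer
mlp = MLP(
    input_dim=4,
    output_dim=2,
    hidden_sizes=[32, 32],
    norm_layer=[nn.LayerNorm, nn.BatchNorm1d],
    norm_args=[{"eps": 1e-6}, {"momentum": 0.2}],
    activation=[nn.LeakyReLU, nn.Tanh],
    act_args=[(0.1,), None],   # None keeps the default Tanh()
)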
@@ -161,8 +194,10 @@ class Net(nn.Module):
         state_shape: Union[int, Sequence[int]],
         action_shape: Union[int, Sequence[int]] = 0,
         hidden_sizes: Sequence[int] = (),
-        norm_layer: Optional[ModuleType] = None,
-        activation: Optional[ModuleType] = nn.ReLU,
+        norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
+        norm_args: Optional[ArgsType] = None,
+        activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
+        act_args: Optional[ArgsType] = None,
         device: Union[str, int, torch.device] = "cpu",
         softmax: bool = False,
         concat: bool = False,
@@ -181,8 +216,8 @@ class Net(nn.Module):
         self.use_dueling = dueling_param is not None
         output_dim = action_dim if not self.use_dueling and not concat else 0
         self.model = MLP(
-            input_dim, output_dim, hidden_sizes, norm_layer, activation, device,
-            linear_layer
+            input_dim, output_dim, hidden_sizes, norm_layer, norm_args, activation,
+            act_args, device, linear_layer
         )
         self.output_dim = self.model.output_dim
         if self.use_dueling:  # dueling DQN
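
With the pass-through above, Net accepts the same knobs and forwards them to its internal MLP; a sketch (values are arbitrary):

from torch import nn
from tianshou.utils.net.common import Net

net = Net(
    state_shape=4,
    action_shape=2,
    hidden_sizes=[128, 128],
    norm_layer=nn.LayerNorm,
    norm_args={"elementwise_affine": False},
    activation=nn.ELU,
    act_args={"alpha": 0.5},
)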
@@ -406,7 +441,9 @@ class BranchingNet(nn.Module):
         value_hidden_sizes: List[int] = [],
         action_hidden_sizes: List[int] = [],
         norm_layer: Optional[ModuleType] = None,
+        norm_args: Optional[ArgsType] = None,
         activation: Optional[ModuleType] = nn.ReLU,
+        act_args: Optional[ArgsType] = None,
         device: Union[str, int, torch.device] = "cpu",
     ) -> None:
         super().__init__()
@@ -418,14 +455,14 @@ class BranchingNet(nn.Module):
         common_output_dim = 0
         self.common = MLP(
             common_input_dim, common_output_dim, common_hidden_sizes, norm_layer,
-            activation, device
+            norm_args, activation, act_args, device
         )
         # value network
         value_input_dim = common_hidden_sizes[-1]
         value_output_dim = 1
         self.value = MLP(
             value_input_dim, value_output_dim, value_hidden_sizes, norm_layer,
-            activation, device
+            norm_args, activation, act_args, device
         )
         # action branching network
         action_input_dim = common_hidden_sizes[-1]
@@ -434,7 +471,7 @@
             [
                 MLP(
                     action_input_dim, action_output_dim, action_hidden_sizes,
-                    norm_layer, activation, device
+                    norm_layer, norm_args, activation, act_args, device
                 ) for _ in range(self.num_branches)
             ]
         )
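
BranchingNet forwards the same arguments to its common, value, and action-branch MLPs; a sketch, assuming its other constructor parameters (state_shape, num_branches, action_per_branch, not shown in this diff) keep their existing names:

from torch import nn
from tianshou.utils.net.common import BranchingNet

net = BranchingNet(
    state_shape=4,
    num_branches=3,
    action_per_branch=5,
    common_hidden_sizes=[64],
    value_hidden_sizes=[64],
    action_hidden_sizes=[64],
    norm_layer=nn.LayerNorm,
    norm_args={"eps": 1e-6},
    activation=nn.LeakyReLU,
    act_args=(0.1,),
)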