Implement args/kwargs for init of norm_layers and activation (#788)
As mentioned in #770, I have fixed the argument mismatch between Net and MLP. Additionally, norm_args and act_args are added to miniblock and the related classes so that the norm layers and activations can be initialized with custom arguments.
parent 1037627a5b, commit 774d3d8e83
@@ -18,22 +18,36 @@ from torch import nn
 from tianshou.data.batch import Batch
 
 ModuleType = Type[nn.Module]
+ArgsType = Union[Tuple[Any, ...], Dict[Any, Any], Sequence[Tuple[Any, ...]],
+                 Sequence[Dict[Any, Any]]]
 
 
 def miniblock(
     input_size: int,
     output_size: int = 0,
     norm_layer: Optional[ModuleType] = None,
+    norm_args: Optional[Union[Tuple[Any, ...], Dict[Any, Any]]] = None,
     activation: Optional[ModuleType] = None,
+    act_args: Optional[Union[Tuple[Any, ...], Dict[Any, Any]]] = None,
     linear_layer: Type[nn.Linear] = nn.Linear,
 ) -> List[nn.Module]:
     """Construct a miniblock with given input/output-size, norm layer and \
     activation."""
     layers: List[nn.Module] = [linear_layer(input_size, output_size)]
     if norm_layer is not None:
-        layers += [norm_layer(output_size)]  # type: ignore
+        if isinstance(norm_args, tuple):
+            layers += [norm_layer(output_size, *norm_args)]  # type: ignore
+        elif isinstance(norm_args, dict):
+            layers += [norm_layer(output_size, **norm_args)]  # type: ignore
+        else:
+            layers += [norm_layer(output_size)]  # type: ignore
     if activation is not None:
-        layers += [activation()]
+        if isinstance(act_args, tuple):
+            layers += [activation(*act_args)]
+        elif isinstance(act_args, dict):
+            layers += [activation(**act_args)]
+        else:
+            layers += [activation()]
     return layers
 
 
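With the new arguments, a tuple in norm_args/act_args is expanded as positional arguments and a dict as keyword arguments when the layer is constructed. Illustrative usage only (the import path and argument values are assumptions, not part of this diff):

import torch.nn as nn
from tianshou.utils.net.common import miniblock

# LayerNorm is built as nn.LayerNorm(128, eps=1e-6) via the kwargs dict,
# LeakyReLU as nn.LeakyReLU(0.1) via the positional tuple.
layers = miniblock(
    input_size=64,
    output_size=128,
    norm_layer=nn.LayerNorm,
    norm_args={"eps": 1e-6},
    activation=nn.LeakyReLU,
    act_args=(0.1,),
)
block = nn.Sequential(*layers)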
@@ -68,7 +82,9 @@ class MLP(nn.Module):
         output_dim: int = 0,
         hidden_sizes: Sequence[int] = (),
         norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
+        norm_args: Optional[ArgsType] = None,
         activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
+        act_args: Optional[ArgsType] = None,
         device: Optional[Union[str, int, torch.device]] = None,
         linear_layer: Type[nn.Linear] = nn.Linear,
         flatten_input: bool = True,
@@ -79,24 +95,41 @@ class MLP(nn.Module):
             if isinstance(norm_layer, list):
                 assert len(norm_layer) == len(hidden_sizes)
                 norm_layer_list = norm_layer
+                if isinstance(norm_args, list):
+                    assert len(norm_args) == len(hidden_sizes)
+                    norm_args_list = norm_args
+                else:
+                    norm_args_list = [norm_args for _ in range(len(hidden_sizes))]
             else:
                 norm_layer_list = [norm_layer for _ in range(len(hidden_sizes))]
+                norm_args_list = [norm_args for _ in range(len(hidden_sizes))]
         else:
             norm_layer_list = [None] * len(hidden_sizes)
+            norm_args_list = [None] * len(hidden_sizes)
         if activation:
             if isinstance(activation, list):
                 assert len(activation) == len(hidden_sizes)
                 activation_list = activation
+                if isinstance(act_args, list):
+                    assert len(act_args) == len(hidden_sizes)
+                    act_args_list = act_args
+                else:
+                    act_args_list = [act_args for _ in range(len(hidden_sizes))]
             else:
                 activation_list = [activation for _ in range(len(hidden_sizes))]
+                act_args_list = [act_args for _ in range(len(hidden_sizes))]
         else:
             activation_list = [None] * len(hidden_sizes)
+            act_args_list = [None] * len(hidden_sizes)
         hidden_sizes = [input_dim] + list(hidden_sizes)
         model = []
-        for in_dim, out_dim, norm, activ in zip(
-            hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, activation_list
+        for in_dim, out_dim, norm, norm_args, activ, act_args in zip(
+            hidden_sizes[:-1], hidden_sizes[1:], norm_layer_list, norm_args_list,
+            activation_list, act_args_list
         ):
-            model += miniblock(in_dim, out_dim, norm, activ, linear_layer)
+            model += miniblock(
+                in_dim, out_dim, norm, norm_args, activ, act_args, linear_layer
+            )
         if output_dim > 0:
             model += [linear_layer(hidden_sizes[-1], output_dim)]
         self.output_dim = output_dim or hidden_sizes[-1]
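Since norm_args and act_args accept either a single tuple/dict or a sequence with one entry per hidden layer, MLP can now configure each hidden layer separately. A rough sketch assuming the signature above (sizes and values are only illustrative):

import torch
import torch.nn as nn
from tianshou.utils.net.common import MLP

# one norm/activation spec per hidden layer; list lengths must match hidden_sizes
net = MLP(
    input_dim=4,
    output_dim=2,
    hidden_sizes=[64, 64],
    norm_layer=[nn.LayerNorm, nn.LayerNorm],
    norm_args=[{"eps": 1e-6}, {"eps": 1e-5}],
    activation=[nn.LeakyReLU, nn.LeakyReLU],
    act_args=[(0.1,), (0.05,)],
)
out = net(torch.randn(8, 4))  # out.shape == (8, 2)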
@@ -161,8 +194,10 @@ class Net(nn.Module):
         state_shape: Union[int, Sequence[int]],
         action_shape: Union[int, Sequence[int]] = 0,
         hidden_sizes: Sequence[int] = (),
-        norm_layer: Optional[ModuleType] = None,
-        activation: Optional[ModuleType] = nn.ReLU,
+        norm_layer: Optional[Union[ModuleType, Sequence[ModuleType]]] = None,
+        norm_args: Optional[ArgsType] = None,
+        activation: Optional[Union[ModuleType, Sequence[ModuleType]]] = nn.ReLU,
+        act_args: Optional[ArgsType] = None,
         device: Union[str, int, torch.device] = "cpu",
         softmax: bool = False,
         concat: bool = False,
@@ -181,8 +216,8 @@ class Net(nn.Module):
         self.use_dueling = dueling_param is not None
         output_dim = action_dim if not self.use_dueling and not concat else 0
         self.model = MLP(
-            input_dim, output_dim, hidden_sizes, norm_layer, activation, device,
-            linear_layer
+            input_dim, output_dim, hidden_sizes, norm_layer, norm_args, activation,
+            act_args, device, linear_layer
         )
         self.output_dim = self.model.output_dim
         if self.use_dueling:  # dueling DQN
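Net simply forwards the new arguments to its MLP backbone, so the activation's constructor arguments can be set when building a policy or critic network. A hedged sketch under the same assumptions as above:

import torch
import torch.nn as nn
from tianshou.utils.net.common import Net

# every hidden layer of the backbone uses nn.ELU(alpha=0.5)
net = Net(
    state_shape=4,
    action_shape=2,
    hidden_sizes=[128, 128],
    activation=nn.ELU,
    act_args={"alpha": 0.5},
)
logits, state = net(torch.randn(8, 4))  # logits.shape == (8, 2)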
@@ -406,7 +441,9 @@ class BranchingNet(nn.Module):
         value_hidden_sizes: List[int] = [],
         action_hidden_sizes: List[int] = [],
         norm_layer: Optional[ModuleType] = None,
+        norm_args: Optional[ArgsType] = None,
         activation: Optional[ModuleType] = nn.ReLU,
+        act_args: Optional[ArgsType] = None,
         device: Union[str, int, torch.device] = "cpu",
     ) -> None:
         super().__init__()
@@ -418,14 +455,14 @@ class BranchingNet(nn.Module):
         common_output_dim = 0
         self.common = MLP(
             common_input_dim, common_output_dim, common_hidden_sizes, norm_layer,
-            activation, device
+            norm_args, activation, act_args, device
         )
         # value network
         value_input_dim = common_hidden_sizes[-1]
         value_output_dim = 1
         self.value = MLP(
             value_input_dim, value_output_dim, value_hidden_sizes, norm_layer,
-            activation, device
+            norm_args, activation, act_args, device
         )
         # action branching network
         action_input_dim = common_hidden_sizes[-1]
@@ -434,7 +471,7 @@ class BranchingNet(nn.Module):
             [
                 MLP(
                     action_input_dim, action_output_dim, action_hidden_sizes,
-                    norm_layer, activation, device
+                    norm_layer, norm_args, activation, act_args, device
                 ) for _ in range(self.num_branches)
             ]
         )
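BranchingNet passes norm_args/act_args on to its common, value, and action-branch MLPs alike. A rough usage sketch; the num_branches, action_per_branch, and *_hidden_sizes parameters belong to the existing class and are assumed here, not introduced by this diff:

import torch.nn as nn
from tianshou.utils.net.common import BranchingNet

# all sub-MLPs use LayerNorm(eps=1e-6) and the default ReLU activation
net = BranchingNet(
    state_shape=4,
    num_branches=3,
    action_per_branch=5,
    common_hidden_sizes=[64],
    value_hidden_sizes=[64],
    action_hidden_sizes=[64],
    norm_layer=nn.LayerNorm,
    norm_args={"eps": 1e-6},
    activation=nn.ReLU,
)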