miniblock: Fix type annotation of linear_layer

This commit is contained in:
Dominik Jain 2023-10-18 20:57:43 +02:00
parent 9c5ee55644
commit cc6f0162ff

View File

@ -22,7 +22,7 @@ def miniblock(
norm_args: tuple[Any, ...] | dict[Any, Any] | None = None,
activation: ModuleType | None = None,
act_args: tuple[Any, ...] | dict[Any, Any] | None = None,
linear_layer: type[nn.Linear] = nn.Linear,
linear_layer: TLinearLayer = nn.Linear,
) -> list[nn.Module]:
"""Construct a miniblock with given input/output-size, norm layer and activation."""
layers: list[nn.Module] = [linear_layer(input_size, output_size)]