2023-10-23 15:45:14 +02:00

26 lines
612 B
Python

from torch import nn
from mpd.models.layers.layers import MLP
class MLPModel(nn.Module):
    """Dictionary-in / dictionary-out wrapper around an ``MLP`` network.

    The model reads one entry from the input mapping, passes it through the
    wrapped ``MLP`` and returns the result in a new single-entry dictionary.
    """

    def __init__(self, in_dim: int = 16, out_dim: int = 16,
                 input_field: str = 'x',
                 output_field: str = 'y',
                 pretrained_dir=None,
                 **kwargs):
        """Build the wrapped network and remember the field names.

        Args:
            in_dim: First positional argument given to ``MLP``
                (presumably the input feature size -- confirm against MLP).
            out_dim: Second positional argument given to ``MLP``
                (presumably the output feature size -- confirm against MLP).
            input_field: Key looked up in the mapping passed to ``forward``.
            output_field: Key under which ``forward`` returns the result.
            pretrained_dir: Accepted but not used by this class. Naming it
                explicitly keeps it out of ``kwargs``, so it is never
                forwarded to ``MLP``.
            **kwargs: Extra keyword arguments forwarded unchanged to ``MLP``.
        """
        super().__init__()

        self.input_field = input_field
        self.output_field = output_field

        self.in_dim = in_dim
        self.out_dim = out_dim

        self._net = MLP(in_dim, out_dim, **kwargs)

    def forward(self, input):
        """Run the wrapped network on ``input[self.input_field]``.

        Args:
            input: Mapping that contains the key ``self.input_field``. The
                parameter name shadows the ``input`` builtin but is kept
                because callers may pass it by keyword.

        Returns:
            A new dict with the single key ``self.output_field`` mapped to
            the network output.

        Raises:
            KeyError: If ``self.input_field`` is missing from a dict
                ``input``.
        """
        x = input[self.input_field]
        out = self._net(x)
        return {self.output_field: out}