'''
Define the encoder and decoder layers.

Derived from:
https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/132907dd272e2cc92e3c10e6c4e783a87ff8893d/transformer/Layers.py
'''
import torch
import torch.nn as nn
import torch.utils.checkpoint

from SubLayers import MultiHeadAttention, PositionwiseFeedForward


class EncoderLayer(nn.Module):
    ''' Single encoder layer, consisting of a multi-head self-attention
    sub-layer and a position-wise feed-forward sub-layer. '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v):
        '''
        Initialize the module.

        :param d_model: Dimension of the input/output of this layer.
        :param d_inner: Dimension of the hidden layer of the position-wise
            feed-forward sub-layer.
        :param n_head: Number of attention heads.
        :param d_k: Dimension of each key.
        :param d_v: Dimension of each value.
        '''
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner)

    def forward(self, enc_input, slf_attn_mask=None):
        '''
        The forward pass.

        :param enc_input: Input to the encoder layer.
        :param slf_attn_mask: Optional mask applied to the self-attention scores.
        '''
        # With gradient checkpointing (trades compute for memory):
        # enc_output = torch.utils.checkpoint.checkpoint(
        #     self.slf_attn, enc_input, enc_input, enc_input, slf_attn_mask)

        # Without gradient checkpointing:
        enc_output = self.slf_attn(
            enc_input, enc_input, enc_input, slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output


class DecoderLayer(nn.Module):
    ''' Single decoder layer. The decoder self-attention sub-layer is
    commented out in this derived version, leaving the encoder-decoder
    attention sub-layer and the position-wise feed-forward sub-layer. '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        '''
        Initialize the layer.

        :param d_model: Dimension of the input/output of this layer.
        :param d_inner: Dimension of the hidden layer of the position-wise
            feed-forward sub-layer.
        :param n_head: Number of attention heads.
        :param d_k: Dimension of each key.
        :param d_v: Dimension of each value.
        :param dropout: Argument to the dropout layers.
        '''
        super(DecoderLayer, self).__init__()
        # self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, dec_input, enc_output,
                slf_attn_mask=None, dec_enc_attn_mask=None):
        '''
        The forward pass.

        :param dec_input: Input to the decoder layer.
        :param enc_output: Output of the encoder stack, used as the keys and
            values of the encoder-decoder attention.
        :param slf_attn_mask: Mask for the decoder self-attention (unused
            while the self-attention sub-layer is commented out).
        :param dec_enc_attn_mask: Mask applied to the encoder-decoder
            attention scores.
        '''
        # dec_output, dec_slf_attn = self.slf_attn(
        #     dec_input, dec_input, dec_input, mask=slf_attn_mask)
        dec_output, dec_enc_attn = self.enc_attn(
            dec_input, enc_output, enc_output, mask=dec_enc_attn_mask)
        dec_output = self.pos_ffn(dec_output)
        return dec_output, dec_enc_attn
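

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the derived source).
# It assumes the local SubLayers module provides MultiHeadAttention and
# PositionwiseFeedForward compatible with how they are called above (the
# encoder attention returning a single tensor, the decoder attention
# returning (output, attention)). The dimensions below are the "base"
# Transformer sizes, chosen purely for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, seq_len = 2, 10
    d_model, d_inner, n_head, d_k, d_v = 512, 2048, 8, 64, 64

    enc_layer = EncoderLayer(d_model, d_inner, n_head, d_k, d_v)
    dec_layer = DecoderLayer(d_model, d_inner, n_head, d_k, d_v)

    src = torch.rand(batch_size, seq_len, d_model)  # dummy encoder input
    tgt = torch.rand(batch_size, seq_len, d_model)  # dummy decoder input

    enc_out = enc_layer(src)                         # self-attention + FFN
    dec_out, dec_enc_attn = dec_layer(tgt, enc_out)  # enc-dec attention + FFN
    print(enc_out.shape, dec_out.shape)

    # Optional: route the whole encoder layer through gradient checkpointing,
    # mirroring the commented-out variant in EncoderLayer.forward. Recent
    # PyTorch versions also accept a use_reentrant= keyword here.
    src.requires_grad_(True)
    enc_out_ckpt = torch.utils.checkpoint.checkpoint(enc_layer, src)
    print(enc_out_ckpt.shape)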