The Fifth Elephant blog

Simrat Hanspal

@simrathanspal

Decoding Llama3: Part 6 - Feed Forward Network

Submitted Jun 18, 2024

We continue our Decoding Llama3 series with the Feed Forward Network (FFN) and the SiLU activation function.

Introduction

The Feed Forward Network transforms the embeddings coming out of the attention layer into an internal embedding that generalizes well.

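Let us recall the FFN architecture used in the original Transformer:

FFN(x) = ReLU(x.W1 + b1).W2 + b2

ReLU passes positive values through unchanged and maps everything else to 0, i.e. ReLU(z) = max(0, z).
Figure: ReLU activation function. Picture from PyTorch.

Hence we can rewrite our FFN layers as

FFN(x) = max(0, x.W1 + b1).W2 + b2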
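
To make the formula concrete, here is a minimal sketch of the ReLU-based FFN in plain Python. The helper names (linear, relu_ffn) and the tiny weight values are made up for illustration; a real model uses tensors and learned weights.

```python
def relu(z):
    # ReLU(z) = max(0, z)
    return max(0.0, z)

def linear(x, W, b):
    """Compute x.W + b for a vector x and a matrix W stored as a list of rows."""
    return [
        sum(x[i] * W[i][j] for i in range(len(x))) + b[j]
        for j in range(len(b))
    ]

def relu_ffn(x, W1, b1, W2, b2):
    # expand to the hidden size, apply ReLU, project back to the model size
    hidden = [relu(h) for h in linear(x, W1, b1)]
    return linear(hidden, W2, b2)

# toy sizes: model_dim = 2, hidden_dim = 3 (illustrative values only)
x = [1.0, -2.0]
W1 = [[1.0, 0.0, -1.0],
      [0.0, 1.0, 1.0]]
b1 = [0.0, 0.0, 0.0]
W2 = [[1.0, 0.0],
      [0.0, 1.0],
      [1.0, 1.0]]
b2 = [0.5, -0.5]

out = relu_ffn(x, W1, b1, W2, b2)
print(out)  # [1.5, -0.5]
```

Two of the three hidden values are negative before the activation, so ReLU zeroes them out and only the first hidden unit contributes to the output.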

SiLU - activation function

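Llama3 uses SwiGLU instead, with three weight matrices and no bias terms.

FFN(x) = (Swish(x.W1) * x.W3).W2

Here * is element-wise multiplication. Swish(x) = x * sigmoid(beta * x), which is the same as SiLU (Sigmoid Linear Unit) when beta = 1.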

Figure: Swish activation function. Picture from Umar Jamil’s lecture on Llama.

Figure: SiLU activation function. Picture from PyTorch.
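
The relationship between Swish and SiLU, and the gating in SwiGLU, can be sketched in plain Python. This is an illustration under our own assumptions: the helper names (swish, silu, matvec, swiglu_ffn) and the toy weights are hypothetical and not part of the Llama3 code.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def swish(z, beta=1.0):
    # Swish_beta(z) = z * sigmoid(beta * z)
    return z * sigmoid(beta * z)

def silu(z):
    # SiLU(z) = z * sigmoid(z), i.e. Swish with beta = 1
    return z * sigmoid(z)

def matvec(x, W):
    """Compute x.W for a vector x and a matrix W stored as a list of rows."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def swiglu_ffn(x, W1, W2, W3):
    # FFN(x) = (SiLU(x.W1) * x.W3).W2, with * applied element-wise
    gate = [silu(h) for h in matvec(x, W1)]
    up = matvec(x, W3)
    gated = [g * u for g, u in zip(gate, up)]
    return matvec(gated, W2)

# toy sizes: model_dim = 2, hidden_dim = 3 (illustrative values only)
x = [1.0, -2.0]
W1 = [[1.0, 0.0, -1.0], [0.0, 1.0, 1.0]]
W3 = [[0.5, 0.5, 0.5], [1.0, 0.0, -1.0]]
W2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

y = swiglu_ffn(x, W1, W2, W3)
```

Unlike ReLU, SiLU does not cut negative inputs to exactly 0: silu(-1.0) is a small negative number, which gives the curve its smooth dip below the axis.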

Code walkthrough

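from typing import Optional

import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from torch import nn


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        # SwiGLU has three weight matrices instead of two, so the hidden size
        # is scaled by 2/3 to keep the parameter count comparable
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # round hidden_dim up to the next multiple of multiple_of
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        # w1 and w3 project dim -> hidden_dim; their outputs stay sharded
        # across model-parallel ranks (gather_output=False)
        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        # w2 projects hidden_dim -> dim and consumes the sharded input directly
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        # SwiGLU: gate SiLU(x.W1) with x.W3 element-wise, then project back with W2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
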
  1. The Feed Forward Network at the end of the attention block consolidates what has been learned about the sentence for the output task, which in decoders (all text generation models like Llama) is predicting the next token.
  2. The Transformer and many models following it have kept the same FFN architecture: a single hidden layer that expands to 4x the size of model_dim and then projects back to model_dim to consolidate the learning.
  3. Grouped-Query Attention reduces the number of parameters in the attention layers. The Llama3 architects spent those freed-up parameters on the FFN, keeping the overall parameter count of the network comparable with the older releases.
    Doing so has potentially helped the performance.
  4. At line number 194 (the __init__ signature), we start by taking the inputs:
    dim: model embedding dimension
    hidden_dim: hidden layer dimension
    multiple_of: explained below
    ffn_dim_multiplier: multiplier explained below
  5. The new FFN multiplier is an interesting configurable setting.
    ffn_dim_multiplier is the multiplier value applied to hidden_dim.
    To keep hidden_dim at the expected size, i.e. a multiple of a large power of 2 such as 256, multiple_of was introduced. The idea is that hidden_dim gets rounded up to the next multiple of multiple_of.
  6. At line number 206 (the rounding step), let us take an example to understand this.
    Let us say
    multiple_of is 64
    hidden_dim is 100
    64 * ((100 + 64 - 1) // 64) = 64 * (163 // 64) = 64 * 2 = 128
    We got a hidden size that is a multiple of 64.
    Recall that SwiGLU requires 3 weight matrices - self.w1, self.w2 and self.w3.
  7. At line number 218 (the forward method), we put it all together.
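
The hidden size computation from points 5 and 6 can be replayed on its own. The sketch below copies the three sizing steps from __init__ into a standalone function; the function name ffn_hidden_dim is ours. The numbers in the last two calls (dim = 4096, multiple_of = 256 or 1024, ffn_dim_multiplier = 1.3) are assumed for illustration and are not taken from this post, and we assume the caller passes hidden_dim = 4 * dim, as point 2 describes.

```python
from typing import Optional

def round_up(value: int, multiple_of: int) -> int:
    # round value up to the next multiple of multiple_of (integer division, not modulo)
    return multiple_of * ((value + multiple_of - 1) // multiple_of)

def ffn_hidden_dim(
    hidden_dim: int,
    multiple_of: int,
    ffn_dim_multiplier: Optional[float] = None,
) -> int:
    # same three steps as FeedForward.__init__
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    return round_up(hidden_dim, multiple_of)

# the worked example from point 6
print(round_up(100, 64))   # 128
# a value that is already a multiple stays unchanged
print(round_up(128, 64))   # 128

# assumed illustrative settings, with hidden_dim = 4 * dim and dim = 4096
print(ffn_hidden_dim(4 * 4096, multiple_of=256))                            # 11008
print(ffn_hidden_dim(4 * 4096, multiple_of=1024, ffn_dim_multiplier=1.3))   # 14336
```

Note that the rounding always goes up, so the final hidden size is never smaller than the value produced by the multiplier step.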

Up next

We have covered all the core components of the Transformer Block. Next, we put them together into the machinery called Llama 3 😀
Head over to Decoding Llama3: Part 7 - TransformerBlock

Hosted by

NLP specialist with 14 years of experience building products. Exploring productivity hacks.