The Fifth Elephant blog

Vidya Ramakrishnan

@vidya_ramki

Simrat Hanspal

@simrathanspal

Decoding Llama3: Part 3 - Normalisation

Submitted Jun 18, 2024

We continue our Decoding Llama3 series with normalisation in this post.

Figure: Llama3 architecture

Introduction

Let us take a step back and ask: why do we need normalisation? When activations vary widely in scale, the model takes longer to converge. Large variation affects not only training time but also performance when the input falls outside the training distribution: unusually large or small values can push the model towards very different outputs. That is why normalisation is an important part of the architecture.

The two plots below, taken from the RMSNorm paper, compare loss against training steps and against training time, with and without normalisation (LayerNorm).

Figure: loss vs. training steps and vs. training time, with and without LayerNorm (RMSNorm paper)

Layer Normalisation (LayerNorm) is a popular choice of normalisation. It re-centres the values of each vector to a mean of 0 and rescales them by their standard deviation (the square root of the variance), so the normalised values have zero mean and unit variance; a learnable gain and bias are then applied. A small epsilon is added to the variance to avoid a division by zero when the variance is zero.
Refer to line number 28 of model.py, where eps is defined as norm_eps: float = 1e-5.
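To make the description concrete, here is a minimal sketch of LayerNorm over the last dimension, assuming PyTorch is installed. The helper `layer_norm_manual` is our own name, not something from model.py; we only use it to check the arithmetic against PyTorch's built-in `torch.nn.functional.layer_norm`.

```python
import torch
import torch.nn.functional as F


def layer_norm_manual(x, gamma, beta, eps=1e-5):
    # Hypothetical helper (not from model.py): LayerNorm over the last dimension.
    mean = x.mean(dim=-1, keepdim=True)                 # statistic 1: mean
    var = (x - mean).pow(2).mean(dim=-1, keepdim=True)  # statistic 2: variance
    x_hat = (x - mean) / torch.sqrt(var + eps)          # re-centre, then re-scale
    return x_hat * gamma + beta                         # learnable gain and bias


torch.manual_seed(0)
dim = 8
x = torch.randn(2, 4, dim) * 3 + 5  # (batch, seq_len, dim), deliberately off-centre
gamma, beta = torch.ones(dim), torch.zeros(dim)

out = layer_norm_manual(x, gamma, beta)
ref = F.layer_norm(x, (dim,), weight=gamma, bias=beta, eps=1e-5)

print(torch.allclose(out, ref, atol=1e-5))  # True
print(out.mean(dim=-1).abs().max().item())  # close to 0: every row is re-centred
```

Note that two statistics (mean and variance) are computed for every row; this is the cost RMSNorm sets out to reduce.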

Figure: Layer Normalisation formula
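Written out for an input vector a of size n, with learnable gain g and bias b (notation following the RMSNorm paper, with epsilon placed inside the square root as is common in implementations), LayerNorm computes:

```latex
\mu = \frac{1}{n}\sum_{i=1}^{n} a_i, \qquad
\sigma = \sqrt{\frac{1}{n}\sum_{i=1}^{n} (a_i - \mu)^2 + \epsilon}, \qquad
\bar{a}_i = \frac{a_i - \mu}{\sigma}\, g_i + b_i
```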

RMSNorm

Llama 3 uses Root Mean Square Normalisation (RMSNorm paper), a simplification of Layer Normalisation. The paper argues that LayerNorm's benefit comes mainly from re-scaling the values rather than from re-centring them.
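One way to see what "re-scaling rather than re-centring" means: RMSNorm's output does not change when the whole input vector is multiplied by a positive constant, but it does change when a constant is added to every feature. A minimal sketch, assuming PyTorch; `rms_norm` is a hypothetical helper with the learnable gain left out.

```python
import torch


def rms_norm(x, eps=1e-6):
    # Hypothetical helper: RMSNorm over the last dimension, without the gain.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms


torch.manual_seed(0)
x = torch.randn(4, 8)

scaled = rms_norm(10.0 * x)   # multiply every feature by the same positive constant
shifted = rms_norm(x + 10.0)  # add the same constant to every feature

print(torch.allclose(rms_norm(x), scaled, atol=1e-5))   # True: re-scaling invariant
print(torch.allclose(rms_norm(x), shifted, atol=1e-5))  # False: no re-centring
```

The invariance is only approximate because of eps, but for inputs of ordinary magnitude the difference is far below the tolerance used here.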

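Hence, RMSNorm only re-scales the values. It drops the mean computation and the re-centring step, and divides by the root mean square of the inputs instead of the standard deviation, so it computes one statistic per vector instead of two (mean and variance).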
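The saving can be checked directly. In the sketch below (PyTorch assumed; both helpers are our own, without gain or bias), LayerNorm computes a mean and a variance per row while RMSNorm computes only the mean of squares. For a row whose mean is already zero, the variance equals the mean of squares, so the two give the same result.

```python
import torch


def layer_norm(x, eps=1e-6):
    # Two statistics per row: the mean, then the variance around that mean.
    mean = x.mean(dim=-1, keepdim=True)
    var = (x - mean).pow(2).mean(dim=-1, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


def rms_norm(x, eps=1e-6):
    # One statistic per row: the mean of squares (no mean subtraction).
    ms = x.pow(2).mean(dim=-1, keepdim=True)
    return x / torch.sqrt(ms + eps)


torch.manual_seed(0)
x = torch.randn(4, 8) + 2.0                   # rows with a non-zero mean
x_centred = x - x.mean(dim=-1, keepdim=True)  # the same rows, re-centred to mean 0

print(torch.allclose(layer_norm(x), rms_norm(x), atol=1e-5))                  # False
print(torch.allclose(layer_norm(x_centred), rms_norm(x_centred), atol=1e-5))  # True
```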

Figure: RMSNorm formula
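For an input vector a of size n and learnable gain g, RMSNorm can be written as follows. Epsilon is placed inside the square root, where the Llama3 code adds it, and there is no bias term, matching the code:

```latex
\mathrm{RMS}(a) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2 + \epsilon}, \qquad
\bar{a}_i = \frac{a_i}{\mathrm{RMS}(a)}\, g_i
```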

Code walkthrough

Now, let us go back to the RMSNorm class in model.py. It is a direct application of the formula.

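import torch
from torch import nn


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # learnable per-feature gain g_i, initialised to 1
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # divide by the RMS over the last (feature) dimension;
        # rsqrt(v) = 1 / sqrt(v), and eps keeps the denominator non-zero
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # normalise in float32, cast back to the input dtype, then apply the gain
        output = self._norm(x.float()).type_as(x)
        return output * self.weight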
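
  1. eps: A very small number (1e-6 by default; Llama3 passes norm_eps = 1e-5) added to the mean of squares so that the denominator is never zero.

  2. gi: The learnable gain that boosts or suppresses each feature. It is called self.weight in the code, and its shape is (dim,), i.e. the model dimension.

  3. def _norm: Normalises each row, i.e. the feature vector of one token, across its last dimension (the model dimension). Because there is one gain per feature, self.weight has size dim.

  4. def forward: The forward pass. It runs _norm in float32 for numerical stability, casts the result back to the input's dtype, and multiplies it by self.weight.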
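To see the shapes and the dtype handling described above in action, here is a small usage sketch. To stay self-contained it restates the same computation as a plain function; `rms_norm_forward` is our own name, not Llama3's, and we assume a PyTorch build that supports bfloat16 on CPU.

```python
import torch


def rms_norm_forward(x, weight, eps=1e-6):
    # Hypothetical restatement of RMSNorm._norm + RMSNorm.forward.
    xf = x.float()                                        # upcast for the statistics
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    return normed.type_as(x) * weight                     # cast back, apply gain


torch.manual_seed(0)
dim = 16                                         # stand-in for the model dimension
weight = torch.ones(dim, dtype=torch.bfloat16)   # g_i, initialised to 1
x = torch.randn(2, 5, dim).to(torch.bfloat16)    # (batch, seq_len, dim) activations

y = rms_norm_forward(x, weight)
print(y.shape, y.dtype)                  # torch.Size([2, 5, 16]) torch.bfloat16
rms = y.float().pow(2).mean(-1).sqrt()   # per-token RMS after normalisation
print(rms.min().item(), rms.max().item())  # both close to 1.0
```

The output keeps the input's shape and dtype, and with the gain at 1 every token vector ends up with an RMS of about 1. If `weight` were left in float32 while `x` is bfloat16, PyTorch's type promotion would make the final product float32.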

Up next

Next, we will cover a very interesting topic: Rotary Positional Embeddings. We will try to understand why they replaced the positional embeddings used in the original Transformer architecture. Head next to Decoding Llama3: Part 4 - Rotary Positional Embeddings.


Hosted by

NLP specialist with 14 years of experience building products. Exploring productivity hacks.