Simrat Hanspal
Decoding Llama3: Part 3 - Normalisation
We are continuing our Decoding Llama3 series, covering normalisation in this blog.
Introduction
Let us take a step back and ask: why do we need normalization? When activation values vary widely in scale, the model takes longer to converge. Large variation not only slows training; it also hurts performance when the input falls outside the training data distribution, because small shifts in the input can then produce very different outputs. Normalization keeps activations in a stable range, which is why it is an important part of the architecture.
Below are two plots from the RMSNorm paper comparing loss against training time and training steps, with and without normalization (LayerNorm).
Layer Normalisation is a popular choice of normalization. It re-centers the values so that their mean is 0 and re-scales them by the standard deviation, giving the output zero mean and unit variance. A small epsilon is added to the variance to avoid a division-by-zero error when the variance is zero.
Refer to line number 28 of model.py, where eps is defined: norm_eps: float = 1e-5.
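As a quick illustration, here is a minimal pure-Python sketch of the LayerNorm computation described above (the function and variable names are ours for illustration, not from model.py):

```python
import math

def layer_norm(values, eps=1e-5):
    """Center values to zero mean and scale them to unit variance."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    # eps guards against division by zero when the variance is 0
    return [(v - mean) / math.sqrt(var + eps) for v in values]

out = layer_norm([2.0, 4.0, 6.0, 8.0])
# out has mean ~0 and variance ~1
```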
RMSNorm
Llama 3 uses Root Mean Square Normalisation (RMSNorm paper), an optimization over Layer Normalisation. The paper claims that the real impact of LayerNorm comes from re-scaling rather than re-centering.
Hence, RMSNorm only re-scales the values and eliminates the two statistics computations that LayerNorm needs: the mean and the variance.
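Concretely, RMSNorm replaces the mean and variance statistics with a single root-mean-square statistic. Following the RMSNorm paper (with epsilon placed inside the square root, as in the Llama code), the computation is:

$$
\bar{x}_i = \frac{x_i}{\mathrm{RMS}(x)} \, g_i, \qquad \mathrm{RMS}(x) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} x_i^2 + \epsilon}
$$

where $g_i$ is a learnable gain parameter.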
Code walkthrough
Now, let us go back to the code. As you can see, it is a direct application of the formula.
import torch
import torch.nn as nn

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Scale by the reciprocal root mean square over the last dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

eps: A very small number (1e-6) added to avoid division by zero.
gi: Weights used for boosting and suppressing values accordingly. This is called self.weight in the code, and its shape is model_dim.
_norm: Method that normalizes values across the row or, in other words, the values going to one neuron, which is why self.weight has size model_dim.
forward: The forward pass (execution) method.
Up next
Next, we will cover a very interesting topic: Rotary Positional Embeddings. We will try to understand why they replaced the default positional embeddings used in the Transformer architecture. Head over to Decoding Llama3: Part 4 - Rotary Positional Embeddings.