Decoding Llama3: An explainer for tinkerers
A not-so-quick 7-part guide to using the Llama3 open source AI model
In this post, we continue our Decoding Llama3 series with Normalisation.
Let us take a step back and ask: why do we need normalization at all? When the activations vary widely in scale, the model takes longer to converge. Large variation not only increases training time, it also hurts performance when the input falls outside the training data distribution: small differences in the input can produce very different outputs. That is why normalization is an important part of the architecture.
Below are two plots from the RMSNorm paper comparing loss against training time and against training steps, with and without normalization (LayerNorm).
Layer Normalisation (LayerNorm) is a popular choice of normalization. It re-centres each input vector to zero mean and re-scales it by the standard deviation, so the result has unit variance. A small epsilon is added to the variance to avoid a division-by-zero error when the variance is zero.
Refer to line number 28 of model.py, where the epsilon is defined: norm_eps: float = 1e-5.
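As a quick illustration (this is not Llama 3 code, just a minimal sketch of what LayerNorm computes on a made-up tensor), the centring and scaling can be written by hand and checked against PyTorch's built-in layer_norm:

import torch

# Hypothetical input: a batch of 2 vectors with model_dim = 4
x = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                  [10.0, 20.0, 30.0, 40.0]])
eps = 1e-5

# LayerNorm by hand: centre each row to zero mean, scale to unit variance
mean = x.mean(-1, keepdim=True)
var = x.var(-1, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + eps)

# Same result using PyTorch's built-in functional LayerNorm
builtin = torch.nn.functional.layer_norm(x, normalized_shape=(4,), eps=eps)
print(torch.allclose(manual, builtin, atol=1e-5))  # True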
Llama 3 uses Root Mean Square Normalisation (RMSNorm paper), an optimization over Layer Normalisation. The paper argues that the real benefit comes from re-scaling rather than re-centring.
Hence, RMSNorm only re-scales the values, eliminating the two statistics LayerNorm computes (mean and variance) in favour of a single root-mean-square statistic.
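In formula terms (as given in the RMSNorm paper; the epsilon inside the square root is the numerical-stability trick used in the code below):

\bar{x}_i = \frac{x_i}{\mathrm{RMS}(x)} \cdot g_i, \qquad \mathrm{RMS}(x) = \sqrt{\frac{1}{d} \sum_{j=1}^{d} x_j^{2} + \epsilon}

where d is the model dimension and g_i is a learnable gain (self.weight in the code).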
Now, let us go back to the code. As you can see, it is a direct application of the formula.
import torch
from torch import nn


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable gain g_i, one per feature, initialised to 1
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide each value by the root mean square of its row
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalise in float32 for stability, then cast back to the input dtype
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
eps: a very small number (1e-6) added inside the square root to avoid a division-by-zero error when the mean of the squares is zero.
g_i: weights used for boosting and suppressing values accordingly. This is called self.weight in the code, and its shape is model_dim.
def _norm: function that normalizes the values across a row, in other words the values going into one neuron, which is why self.weight has size model_dim.
def forward: the forward pass (execution) method; it normalizes in float32 and then multiplies by self.weight.
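To see it in action, here is a minimal usage sketch (the dimensions below are made up for illustration; in Llama 3, dim would be the model dimension):

import torch

# Assumes the RMSNorm class above is already defined in scope
norm = RMSNorm(dim=8, eps=1e-6)

# Hypothetical batch: 2 sequences of 3 tokens, each an 8-dimensional vector
x = torch.randn(2, 3, 8)
y = norm(x)

print(y.shape)            # torch.Size([2, 3, 8])
print(y.pow(2).mean(-1))  # close to 1.0 everywhere: each row now has unit root mean square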
Next, we will cover a very interesting topic: Rotary Positional Embeddings. We will try to understand why they replaced the default positional encodings used in the original Transformer architecture. Head next to Decoding Llama3: Part 4 - Rotary Positional Embeddings.