Decoding Llama3: An explainer for tinkerers
A not-so-quick 7-part guide to using the Llama3 open source AI model
We are continuing our Decoding Llama3 series with Rotary Positional Embeddings in this blog.
To parallelize execution over tokens, the transformer model needs to encode position information into the word/token embeddings. The original transformer used sinusoidal waves to derive these positional embeddings.
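For reference, here is a minimal sketch of how those sinusoidal embeddings are typically computed (illustrative code, not taken from the Llama3 repository; the function name is mine):

import torch

def sinusoidal_positional_embedding(seq_len: int, dim: int) -> torch.Tensor:
    # Classic "Attention Is All You Need" scheme: sine on even dims, cosine on odd dims.
    # Assumes an even embedding dimension.
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)          # (seq_len, 1)
    div = 10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # (dim // 2,)
    pe = torch.zeros(seq_len, dim)
    pe[:, 0::2] = torch.sin(pos / div)
    pe[:, 1::2] = torch.cos(pos / div)
    return pe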
It turns out that sinusoidal positional embeddings are quite erratic, and their pattern is hard for the model to capture. Because of this, models trained with sinusoidal positional embeddings tended to memorise positions rather than learn position-relevant patterns.
Take a look at the image below: you can see different vectors for different position values m=0, m=1, and so on. There is no discernible pattern.
For more details, watch the video explanation by DeepLearning Hero.
Plotting perplexity for various input sequence lengths shows that perplexity explodes immediately beyond the training sequence length, indicating that the model is memorising positions and has not learned a pattern that generalises beyond the training sequence length.
RoPE embeddings were introduced in the RoFormer paper. The authors argued that for the attention score Q·Kᵀ to be accurate, we need a function that captures not only similarity in the embeddings but also similarity in position. That is, embeddings at similar positions should have a similar inclination, and this inclination should follow a predictable pattern.
Take a look at the image below: the position vectors are evenly spaced and predictable.
This led to the concept of Rotary Positional Embeddings (RoPE), where the embedding vector is rotated by an angle corresponding to its position.
Below is the rotation matrix, which follows a predictable pattern.
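For a single pair of dimensions, this is just the standard 2D rotation by the angle m*theta. A minimal sketch (the helper name and numbers are mine, purely for illustration):

import math
import torch

def rotation_matrix(m: int, theta: float) -> torch.Tensor:
    # Standard 2D rotation by angle m * theta: nearby positions get nearby rotations,
    # which is the predictable pattern RoPE relies on.
    a = m * theta
    return torch.tensor([[math.cos(a), -math.sin(a)],
                         [math.sin(a),  math.cos(a)]])

# Rotating a 2D embedding pair for position m = 3 with base angle theta = 0.5:
pair = torch.tensor([1.0, 0.0])
print(rotation_matrix(3, 0.5) @ pair)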
The challenge is that this rotation only works in two dimensions. Hence, the authors applied it to pairs of embedding dimensions, rotating each pair independently. This is why RoPE requires an embedding dimension of even length.
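To make the dimension-pair idea concrete, here is a small sketch (the rotate_pairs helper is illustrative, not the Llama3 implementation) of the key property: the dot product of two rotated vectors depends only on their relative position.

import torch

def rotate_pairs(x: torch.Tensor, pos: int, theta: float = 10000.0) -> torch.Tensor:
    # Rotate each (even, odd) dimension pair of x by angle pos * theta_i,
    # using the RoFormer angle schedule theta_i = theta^(-2i/dim).
    dim = x.shape[-1]
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # (dim // 2,)
    angles = pos * freqs
    cos, sin = torch.cos(angles), torch.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q, k = torch.randn(8), torch.randn(8)
# The score depends only on the relative offset (5 - 2 == 10 - 7 == 3):
score_a = rotate_pairs(q, 5) @ rotate_pairs(k, 2)
score_b = rotate_pairs(q, 10) @ rotate_pairs(k, 7)
print(torch.allclose(score_a, score_b, atol=1e-5))  # True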
The paper has a very neat visualization of the process.
With RoPE, the model not only generalizes better, but many custom models also extend their context length, because the model learns positional patterns that hold even for sequence lengths beyond those used in training.
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # Base angles theta_i = theta^(-2i/dim), one per dimension pair
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Positions m = 0 .. end-1; the outer product gives the angle m * theta_i for every pair
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    # Represent each rotation as a unit complex number cos(angle) + i*sin(angle)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
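A quick sanity check with toy arguments (dim is the per-head dimension and end is the maximum sequence length; the numbers here are only for illustration):

freqs_cis = precompute_freqs_cis(dim=128, end=4096)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([4096, 64]) torch.complex64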
The rotation angle is m*theta, where m is the token's position in the sequence and theta is the base angle assigned to each dimension pair: the token at position m1 gets the angles m1*theta1, m1*theta2, and so on across its dimension pairs.

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    # Reshape freqs_cis from (seq_len, head_dim // 2) to (1, seq_len, 1, head_dim // 2)
    # so it broadcasts over the batch and head dimensions of x
    ndim = x.ndim
    assert 0 <= 1 < ndim  # x must have at least two dimensions
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
The above code reshapes freqs_cis so that it broadcasts correctly over the batch and head dimensions of the query/key tensors.
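To see the broadcasting at work, here is a toy example (the shapes are made up for illustration; xq_ stands for the complex view of the queries, with shape (batch, seq_len, n_heads, head_dim // 2)):

xq_ = torch.randn(2, 16, 4, 4, dtype=torch.complex64)
freqs_cis = precompute_freqs_cis(dim=8, end=16)      # (16, 4)
print(reshape_for_broadcast(freqs_cis, xq_).shape)   # torch.Size([1, 16, 1, 4])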
from typing import Tuple

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # View the last dimension as head_dim // 2 complex numbers (one per dimension pair)
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Complex multiplication rotates each pair by its position-dependent angle
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
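Putting it together, a minimal end-to-end sketch with made-up shapes (batch of 2, sequence of 16, 4 heads of dimension 8):

bsz, seq_len, n_heads, head_dim = 2, 16, 4, 8
xq = torch.randn(bsz, seq_len, n_heads, head_dim)
xk = torch.randn(bsz, seq_len, n_heads, head_dim)
freqs_cis = precompute_freqs_cis(head_dim, seq_len)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
print(xq_rot.shape)  # torch.Size([2, 16, 4, 8]), same shape with positions now encoded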
xq is the query embedding, xk is the key embedding, and freqs_cis is the precomputed rotation, defined by the position m and the angles derived from the dimension size and theta. Since freqs_cis depends only on the maximum sequence length and the dimension size, we don't need to recompute it every time.

In this blog, we have covered one of the core components of the Llama3 architecture: Rotary Positional Embeddings. In the next blog, we will discuss another big topic: Grouped Query Attention.
So, head over to Decoding Llama3: Part 5 - Grouped Query Attention