
The Fifth Elephant blog

Simrat Hanspal

@simrathanspal

Decoding Llama3: Part 4 - Rotary Positional Embeddings

Submitted Jun 18, 2024

In this blog, we continue our Decoding Llama3 series with Rotary Positional Embeddings.
RotaryEmbedding

Introduction

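A transformer processes all tokens in parallel rather than one after another, so it has no built-in notion of token order. To make that parallel execution possible, we need to encode each token's position into its word/token embedding. The original transformer model used sinusoidal waves to derive these positional embeddings.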
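For reference, here is a minimal plain-Python sketch of sinusoidal positional embeddings. It assumes the formulation from the original transformer paper, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i + 1) = cos(pos / 10000^(2i/d)). The function name sinusoidal_embedding is hypothetical.

```python
import math
from typing import List


def sinusoidal_embedding(pos: int, d_model: int, base: float = 10000.0) -> List[float]:
    """Sinusoidal positional embedding for one position (hypothetical helper)."""
    if d_model % 2 != 0:
        raise ValueError("d_model must be even")
    pe = [0.0] * d_model
    for i in range(d_model // 2):
        angle = pos / (base ** (2 * i / d_model))
        pe[2 * i] = math.sin(angle)      # even dimensions: sine
        pe[2 * i + 1] = math.cos(angle)  # odd dimensions: cosine
    return pe


# The result is added to the token embedding before the first layer
print([round(v, 3) for v in sinusoidal_embedding(pos=1, d_model=4)])  # [0.841, 0.54, 0.01, 1.0]
```

The position information is *added* to the token embedding. ROPE, covered below, takes a different approach and *rotates* the query and key vectors instead.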

Sinusoidal positional embeddings
Image source: Link. You can also learn more about sinusoidal positional embeddings here.

It turns out that sinusoidal positional embeddings are quite erratic, and their pattern is hard for the model to capture. As a result, models trained with sinusoidal positional embeddings tend to memorize positions rather than learn position-relevant patterns.

In the image below, you can see the vectors for different position values m = 0, m = 1, and so on. There is no discernible pattern.

Erratic Sinusoidal Positional Embedding
For more details, watch the video explanation by DeepLearning Hero.

Plotting perplexity against input sequence length shows that perplexity explodes as soon as the input exceeds the training sequence length. This indicates that the model has memorized positions instead of learning patterns in the text that carry over to sequences longer than those seen during training.

Exploding perplexity

ROPE Embeddings - The solution

ROPE embeddings were introduced in the RoFormer paper. The authors argue that for the attention score Q·Kᵀ to be meaningful, we need a function that captures not only the similarity between embeddings but also their relative position. In other words, embeddings at similar positions should have similar orientations, and the way the orientation changes with position should follow a predictable pattern.
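Formally, the RoFormer paper states this requirement as follows. The inner product between a query x_m encoded at position m and a key x_n encoded at position n should depend only on the two embeddings and their relative offset. Let q and k be the projected query and key, and let R(φ) be the 2D rotation matrix by angle φ. Rotating the query by mθ and the key by nθ meets this requirement, because composing rotations adds their angles:

```latex
\langle f_q(\mathbf{x}_m, m),\; f_k(\mathbf{x}_n, n) \rangle = g(\mathbf{x}_m, \mathbf{x}_n, m - n)

\langle R(m\theta)\,\mathbf{q},\; R(n\theta)\,\mathbf{k} \rangle
  = \mathbf{q}^\top R(m\theta)^\top R(n\theta)\,\mathbf{k}
  = \mathbf{q}^\top R\big((n - m)\theta\big)\,\mathbf{k},
\qquad
R(\varphi) = \begin{pmatrix} \cos\varphi & -\sin\varphi \\ \sin\varphi & \cos\varphi \end{pmatrix}
```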

In the image below, the position vectors are evenly spaced and predictable.

Rotary Position Embedding

This leads to the concept of Rotary Positional Embeddings (ROPE), where each embedding vector is rotated by an angle proportional to its position.
Below is the 2D rotation matrix. It rotates a vector by a fixed angle θ, so a token at position m is rotated by mθ, which gives the predictable pattern we want.

Rotation matrix
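As a quick numerical check, the sketch below applies the 2D rotation matrix [[cos a, -sin a], [sin a, cos a]] to a query at position m and a key at position n. The helper names, the vectors, and the angle θ = 0.5 are illustrative assumptions. When two pairs of positions have the same offset, the dot products match.

```python
import math
from typing import Tuple

Vec2 = Tuple[float, float]


def rotate(v: Vec2, angle: float) -> Vec2:
    """Multiply v by the 2D rotation matrix [[cos a, -sin a], [sin a, cos a]]."""
    c, s = math.cos(angle), math.sin(angle)
    return (c * v[0] - s * v[1], s * v[0] + c * v[1])


def dot(a: Vec2, b: Vec2) -> float:
    return a[0] * b[0] + a[1] * b[1]


theta = 0.5  # illustrative angle per position step
q, k = (1.0, 2.0), (0.5, -1.0)

# Rotate the query by m * theta and the key by n * theta
score_3_5 = dot(rotate(q, 3 * theta), rotate(k, 5 * theta))
score_10_12 = dot(rotate(q, 10 * theta), rotate(k, 12 * theta))

# Both pairs have offset n - m = 2, so the attention scores agree
print(abs(score_3_5 - score_10_12) < 1e-9)  # True
```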

The challenge is that this rotation matrix works only in 2D. The authors' solution is to split the embedding into pairs of dimensions (not pairs of tokens) and rotate each pair independently, each at its own frequency. This is why ROPE requires an even embedding dimension.
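Here is a minimal plain-Python sketch of that extension to d dimensions. It splits the vector into d/2 consecutive pairs (x0, x1), (x2, x3), and so on, then rotates pair i by m·θ_i, where θ_i = 10000^(-2i/d) as in the RoFormer paper. Pairing adjacent dimensions matches the Llama code below. The function name rope_rotate is hypothetical.

```python
import math
from typing import List


def rope_rotate(x: List[float], m: int, base: float = 10000.0) -> List[float]:
    """Rotate each consecutive pair (x[2i], x[2i+1]) by the angle m * theta_i."""
    d = len(x)
    if d % 2 != 0:
        raise ValueError("ROPE needs an even embedding dimension")
    out: List[float] = []
    for i in range(d // 2):
        theta_i = base ** (-2 * i / d)  # 1.0 for the first pair, smaller for later pairs
        c, s = math.cos(m * theta_i), math.sin(m * theta_i)
        a, b = x[2 * i], x[2 * i + 1]
        out += [c * a - s * b, s * a + c * b]  # 2D rotation of this pair
    return out


def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


q = [0.1, -0.4, 0.3, 0.8, -0.2, 0.5]
k = [0.7, 0.2, -0.6, 0.1, 0.4, -0.3]

# Positions (1, 5) and (20, 24) have the same offset, so the scores match
s1 = dot(rope_rotate(q, 1), rope_rotate(k, 5))
s2 = dot(rope_rotate(q, 20), rope_rotate(k, 24))
print(abs(s1 - s2) < 1e-9)  # True
```

Each pair contributes a term that depends only on (n - m)·θ_i, so their sum also depends only on the offset. Because every step is a rotation, the vector's length never changes.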

The paper has a very neat visualization of the process.

ROPE illustration

With ROPE, the model generalizes better and can perform well even on sequences longer than those used in training. This is also why many custom models have been able to extend their context length.

Perplexity with ROPE

Code walkthrough

Compute rotation matrix freqs_cis

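import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # One frequency per dimension pair: theta^(-2i/dim) for i = 0 .. dim/2 - 1
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Positions m = 0 .. end - 1
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    # Angle m * theta_i for every (position, pair) combination: shape (end, dim // 2)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64: cos + i*sin
    return freqs_cis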
  1. As depicted in the visualization above, we take 2 dimensions at a time and treat them as a complex number (a + ib). Rotating that number by an angle φ is the same as multiplying it by e^(iφ), which, by Euler's formula, equals cos φ + i sin φ.
  2. We rotate each pair of dimensions by an angle m·θ_i. Here m is the token's position in the sequence, and θ_i = theta^(-2i/dim) is the angle assigned to the i-th dimension pair (i = 0, 1, ..., dim/2 - 1). The first pairs rotate quickly with position, and later pairs rotate slowly.
    Theta
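  3. The first statement, which assigns freqs, computes all dim // 2 frequencies θ_i.
  4. The t = torch.arange(end, ...) statement generates all positions m = 0, 1, ..., end - 1.
    Next, torch.outer(t, freqs) computes the outer product, which gives every combination of position and frequency (m1·θ1, m1·θ2, and so on) as a tensor of shape (end, dim // 2).
    Finally, torch.polar(torch.ones_like(freqs), freqs) converts each angle into a unit-magnitude complex number, cos(mθ_i) + i·sin(mθ_i). This is why the result is complex64.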
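To see what precompute_freqs_cis returns without PyTorch, here is a plain-Python sketch that mirrors it. It uses cmath.rect, which, like torch.polar with magnitude 1, builds cos(angle) + i·sin(angle). The function name precompute_freqs_cis_py and the toy sizes dim = 4, end = 3 are illustrative. The sketch works in double precision, while the original works in float32/complex64.

```python
import cmath
from typing import List


def precompute_freqs_cis_py(dim: int, end: int, theta: float = 10000.0) -> List[List[complex]]:
    """Plain-Python mirror of precompute_freqs_cis: an end x (dim // 2) table of e^(i*m*theta_i)."""
    # theta_i for each dimension pair, with the same slicing as torch.arange(0, dim, 2)[: dim // 2]
    freqs = [1.0 / (theta ** (j / dim)) for j in range(0, dim, 2)][: dim // 2]
    # cmath.rect(1.0, angle) = cos(angle) + i*sin(angle), like torch.polar with magnitude 1
    return [[cmath.rect(1.0, m * f) for f in freqs] for m in range(end)]


table = precompute_freqs_cis_py(dim=4, end=3)  # frequencies are [1.0, 0.01]
for m, row in enumerate(table):
    print(m, [f"{z.real:+.3f}{z.imag:+.3f}j" for z in row])
```

Row m holds one rotation per dimension pair. Row 0 is all ones (no rotation), and each later row advances every pair by its own frequency.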

Reshape for broadcast

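import torch


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim  # x must have at least 2 dimensions
    # freqs_cis is (seq_len, head_dim // 2); x is (batch, seq_len, n_heads, head_dim // 2)
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # Keep dim 1 (seq_len) and the last dim (pairs), set the rest to 1: (1, seq_len, 1, head_dim // 2)
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)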

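This helper lets a single element-wise multiplication rotate the whole batch in parallel. It reshapes freqs_cis from (seq_len, head_dim // 2) to (1, seq_len, 1, head_dim // 2), so that it broadcasts against the complex query/key tensor of shape (batch, seq_len, n_heads, head_dim // 2). Every sequence in the batch and every head then receives the same position-dependent rotations.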
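As a worked example of the shape logic, the sketch below reproduces the list comprehension from reshape_for_broadcast on plain tuples. The sizes (batch 2, sequence length 5, 8 heads, 64 complex pairs) are illustrative assumptions and do not reflect Llama3's actual configuration. The helper name broadcast_shape is hypothetical.

```python
from typing import List, Tuple


def broadcast_shape(x_shape: Tuple[int, ...]) -> List[int]:
    """Shape that freqs_cis is viewed as: keep dim 1 and the last dim, use 1 elsewhere."""
    ndim = len(x_shape)
    return [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x_shape)]


# Complex query tensor: (batch, seq_len, n_heads, head_dim // 2)
x_shape = (2, 5, 8, 64)
print(broadcast_shape(x_shape))  # [1, 5, 1, 64]
```

Multiplying a (1, 5, 1, 64) tensor with a (2, 5, 8, 64) tensor broadcasts along the size-1 batch and head axes, so no rotation values need to be copied explicitly.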

Apply rotary embedding to Query and Key


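from typing import Tuple

import torch


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Pair up the last dimension and view each pair as one complex number:
    # (batch, seq_len, n_heads, head_dim) -> (batch, seq_len, n_heads, head_dim // 2)
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # (seq_len, head_dim // 2) -> (1, seq_len, 1, head_dim // 2)
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Complex multiplication rotates each pair; view_as_real gives (..., head_dim // 2, 2)
    # and flatten(3) merges it back to (..., head_dim)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # Cast back to the input dtype
    return xq_out.type_as(xq), xk_out.type_as(xk)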
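  1. This function takes the Query and Key vectors and rotates them, injecting the position information that the attention computation needs. ROPE is applied only to queries and keys, not to values.
  2. As inputs, it takes xq (the query embeddings), xk (the key embeddings), and freqs_cis (the rotations, determined by position m and the per-pair angle θ_i). Because freqs_cis depends only on the maximum sequence length and the head dimension, it is precomputed once and reused instead of being recomputed on every call.
  3. The xq_ statement groups the query's last dimension into pairs and turns each pair into a complex number.
  4. The xk_ statement does the same for the key embedding.
  5. After freqs_cis is reshaped for broadcasting, the xq_out and xk_out statements multiply by it, which rotates every pair by m·θ_i. view_as_real turns each complex number back into two real numbers, flatten(3) restores the original head dimension, and the results are cast back to the input dtype.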
    Rotating embeddings
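The complex-number trick in apply_rotary_emb can be checked in plain Python. Pairing adjacent values into complex numbers, multiplying by e^(i·m·θ_i), and unpacking back to reals gives the same result as applying the 2D rotation matrix to each pair. The sketch below handles a single head vector at one position. The name apply_rotary_py is hypothetical, and batching, heads, and dtypes are left out.

```python
import cmath
import math
from typing import List


def apply_rotary_py(x: List[float], m: int, theta: float = 10000.0) -> List[float]:
    """Rotate one head vector at position m using complex multiplication."""
    d = len(x)
    if d % 2 != 0:
        raise ValueError("head dimension must be even")
    # Like view_as_complex on reshape(..., -1, 2): (x0, x1) -> x0 + i*x1
    pairs = [complex(x[j], x[j + 1]) for j in range(0, d, 2)]
    # Like one row of freqs_cis: e^(i * m * theta_i)
    rots = [cmath.rect(1.0, m / theta ** (j / d)) for j in range(0, d, 2)]
    rotated = [p * r for p, r in zip(pairs, rots)]
    # Like view_as_real(...).flatten(): back to [real, imag, real, imag, ...]
    return [v for z in rotated for v in (z.real, z.imag)]


x = [0.3, -1.2, 0.8, 0.5]
m = 7
out = apply_rotary_py(x, m)

# Compare the first pair with the rotation matrix [[cos, -sin], [sin, cos]]
angle0 = m * 1.0  # theta_0 = 10000^0 = 1
expected0 = (math.cos(angle0) * x[0] - math.sin(angle0) * x[1],
             math.sin(angle0) * x[0] + math.cos(angle0) * x[1])
print(abs(out[0] - expected0[0]) < 1e-12, abs(out[1] - expected0[1]) < 1e-12)  # True True
```

This equivalence follows from (a + ib)(cos φ + i sin φ) = (a cos φ - b sin φ) + i(a sin φ + b cos φ), which is exactly the rotation matrix applied to (a, b).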

Up next

In this blog, we covered one of the core components of the Llama3 architecture: Rotary Positional Embeddings. In the next blog, we will discuss another big topic, Grouped Query Attention.
So, head over to Decoding Llama3: Part 5 - Grouped Query Attention


Hosted by

NLP specialist with 14 years of experience building products. Exploring productivity hacks.