Simrat Hanspal

## Decoding Llama3: Part 4 - Rotary Positional Embeddings

In this blog, we continue our Decoding Llama3 series with Rotary Positional Embeddings.

# Introduction

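Unlike a recurrent model, a transformer processes all tokens in parallel, so it has no built-in notion of word order. To parallelize execution over tokens, we therefore need to encode each token's position into its word/token embedding. The original transformer model used sinusoidal waves to derive these positional embeddings.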

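For reference, the sinusoidal scheme from the original transformer paper gives position `pos` the value `sin(pos / 10000^(2i/d))` in even dimensions and `cos(pos / 10000^(2i/d))` in odd dimensions, and the resulting table is added to the token embeddings. Below is a minimal PyTorch sketch of that table; the function name `sinusoidal_positions` is hypothetical and not part of the Llama code.

```python
import math

import torch


def sinusoidal_positions(seq_len: int, dim: int, base: float = 10000.0) -> torch.Tensor:
    """Return a (seq_len, dim) table of sinusoidal positional encodings (dim must be even)."""
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)
    # 1 / base^(2i/dim) for i = 0 .. dim/2 - 1
    inv_freq = torch.exp(-math.log(base) * torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    pe = torch.zeros(seq_len, dim)
    pe[:, 0::2] = torch.sin(pos * inv_freq)  # even dimensions: sine
    pe[:, 1::2] = torch.cos(pos * inv_freq)  # odd dimensions: cosine
    return pe


pe = sinusoidal_positions(seq_len=8, dim=16)
print(pe.shape)  # torch.Size([8, 16])
```

Each row is an absolute position code that is simply added to the embedding; ROPE, discussed next, instead rotates the query and key vectors.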

It turns out that sinusoidal positional embeddings are quite erratic, and their pattern is very hard for the model to capture. As a result, models trained with sinusoidal positional embeddings were memorizing positions instead of learning position-relevant patterns.

In the image below, you can see the vectors for different position values `m=0`, `m=1`, and so on. There is no discernible pattern.

For more details, watch the video explanation by DeepLearning Hero.

Plotting perplexity against input sequence length showed that perplexity exploded as soon as the input exceeded the training sequence length. This indicates that the model is memorizing and has not learned patterns in the text that carry over to sequences longer than those seen in training.

# ROPE Embeddings - The solution

ROPE embeddings were introduced in the RoFormer paper. The authors argued that for the attention score `Q·K^T` to be accurate, we need a function that captures not only the similarity between embeddings but also their similarity in position. That is, embeddings at similar positions should have a similar inclination (angle), and this inclination should follow a predictable pattern.

In the image below, the position vectors are evenly spaced and predictable.

This leads to the concept of ROtary Positional Embeddings (ROPE), where the embedding vector is rotated by an angle proportional to its position.

Below is the rotation matrix that is known to have a predictable pattern.
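For reference, here is the standard 2D rotation matrix for an angle $m\theta$, where $m$ is the position and $\theta$ a fixed angle. Because rotation matrices satisfy $R(\phi)^\top = R(-\phi)$ and compose by adding angles, the dot product between a query rotated for position $m$ and a key rotated for position $n$ depends only on the relative offset $n - m$:

$$
R(m\theta) = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix},
\qquad
\big(R(m\theta)\,q\big)^\top \big(R(n\theta)\,k\big) = q^\top R(m\theta)^\top R(n\theta)\,k = q^\top R\big((n-m)\theta\big)\,k
$$

This is the predictable pattern: shifting both positions by the same amount leaves the attention score unchanged.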

The challenge with this solution is that a rotation matrix like this works only in 2D. Hence, the authors split the embedding into pairs of dimensions and rotate each pair in its own 2D plane, with a different angle for each pair. This is why ROPE requires an even embedding dimension.

The paper has a very neat visualization of the process.
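The pairing idea can be sketched with plain real arithmetic. The snippet below (the function name `rotate_pairs` is hypothetical) splits a vector into adjacent pairs `(x[2i], x[2i+1])` and rotates pair `i` by `m * theta_i` with `theta_i = 10000^(-2i/dim)`, the frequencies used in the RoFormer paper and in the Llama code later in this post. The final check illustrates the relative-position property: the same offset between query and key positions gives the same dot product, regardless of absolute position.

```python
import torch


def rotate_pairs(x: torch.Tensor, m: int, base: float = 10000.0) -> torch.Tensor:
    """Rotate each adjacent pair (x[2i], x[2i+1]) of a 1-D vector by angle m * theta_i."""
    dim = x.shape[-1]
    assert dim % 2 == 0, "ROPE needs an even dimension"
    theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)  # (dim/2,)
    angle = m * theta
    cos, sin = torch.cos(angle), torch.sin(angle)
    x_even, x_odd = x[0::2], x[1::2]
    out = torch.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin  # first row of the 2D rotation matrix
    out[1::2] = x_even * sin + x_odd * cos  # second row of the 2D rotation matrix
    return out


torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)

# Same offset (n - m = 3) at different absolute positions gives the same score
s1 = rotate_pairs(q, 2) @ rotate_pairs(k, 5)
s2 = rotate_pairs(q, 10) @ rotate_pairs(k, 13)
print(torch.allclose(s1, s2))  # True
```

The Llama code below does the same rotation, but with complex-number multiplication instead of explicit cosine/sine formulas.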

With ROPE, models generalize better. Because the model learns a predictable positional pattern rather than memorizing positions, it can perform well even on sequence lengths beyond those used in training, which is why many custom models have been able to extend their context length.

# Code walkthrough

### Compute rotation matrix freqs_cis

```
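import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # One frequency per dimension pair: theta_i = theta^(-2i/dim), i = 0 .. dim/2 - 1
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Positions m = 0 .. end - 1
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    # Outer product: angle m * theta_i for every (position, pair), shape (end, dim // 2)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64: cos + i sin
    return freqs_cis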
```

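- As depicted in the visualization above, we take 2 dimensions at a time and treat them as a complex number `a + ib`. Rotating this pair by an angle `phi` is then just multiplication by `e^(i*phi)`, which Euler's formula lets us write in terms of cosine and sine: `e^(i*phi) = cos(phi) + i*sin(phi)`.
- We rotate every pair of dimensions by an angle of `m*theta`, where `m` is the token's position in the sequence and `theta` is the angle (frequency) assigned to that dimension pair.
- The `freqs = 1.0 / (theta ** ...)` line computes all the per-pair frequencies `theta_i = theta^(-2i/dim)`; here the `theta` argument is the base, 10000 by default.
- The `t = torch.arange(end, ...)` line computes all positions `m`, from `0` to `end - 1`.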

Next, `torch.outer` takes the outer product of positions and frequencies, since we need every combination of `m` with `theta`, i.e., `m1*theta1`, `m1*theta2`, and so on. The result has shape `(end, dim // 2)`.

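After generating all the angles, `torch.polar` converts each one into a unit-magnitude complex number in polar form, `cos(m*theta) + i*sin(m*theta)`, so `freqs_cis` holds the cosine and sine components of every rotation.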
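A quick way to check these claims is to inspect the output directly. This sketch repeats `precompute_freqs_cis` from above so it runs on its own; `dim=8` and `end=16` are arbitrary example sizes. It verifies the shape, that every entry has magnitude 1, that entry `[m, i]` equals `cos(m * theta_i) + i*sin(m * theta_i)`, and that multiplying a pair `a + ib` by `e^(i*phi)` is the same as applying the 2D rotation matrix.

```python
import cmath
import math

import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # Same as the function above
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)


freqs_cis = precompute_freqs_cis(dim=8, end=16)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([16, 4]) torch.complex64

# Every entry has magnitude 1, so multiplying by it rotates without scaling
assert torch.allclose(freqs_cis.abs(), torch.ones(16, 4))

# Entry [m, i] is cos(m * theta_i) + i*sin(m * theta_i), with theta_i = 10000^(-2i/8)
m, i = 5, 1
angle = m * 10000.0 ** (-2 * i / 8)
assert math.isclose(freqs_cis[m, i].real.item(), math.cos(angle), abs_tol=1e-6)
assert math.isclose(freqs_cis[m, i].imag.item(), math.sin(angle), abs_tol=1e-6)

# Euler's formula: multiplying a + ib by e^(i*phi) equals applying the 2D rotation matrix
a, b, phi = 1.0, 2.0, 0.5
z = complex(a, b) * cmath.exp(1j * phi)
assert math.isclose(z.real, a * math.cos(phi) - b * math.sin(phi))
assert math.isclose(z.imag, a * math.sin(phi) + b * math.cos(phi))
```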

### Reshape for broadcast

```
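import torch


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    # freqs_cis: (seq_len, head_dim // 2); x: (batch, seq_len, n_heads, head_dim // 2)
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    # Keep dim 1 (seq_len) and the last dim, set the rest to 1: (1, seq_len, 1, head_dim // 2)
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)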
```

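This helper exists for broadcasting: it reshapes `freqs_cis` from `(seq_len, head_dim // 2)` to `(1, seq_len, 1, head_dim // 2)`, so that a single element-wise multiplication rotates every token of every head in every batch element in parallel.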
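To see the shapes involved, here is a small check. It assumes Llama's layout for queries and keys after `view_as_complex`, `(batch, seq_len, n_heads, head_dim // 2)`; the sizes are arbitrary example values, and `freqs_cis` is filled with random numbers because only its shape matters here.

```python
import torch


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


batch, seq_len, n_heads, half_dim = 2, 6, 4, 8
x = torch.randn(batch, seq_len, n_heads, half_dim, dtype=torch.complex64)
freqs_cis = torch.randn(seq_len, half_dim, dtype=torch.complex64)  # stand-in values

f = reshape_for_broadcast(freqs_cis, x)
print(f.shape)  # torch.Size([1, 6, 1, 8])
print((x * f).shape)  # torch.Size([2, 6, 4, 8])
```

The size-1 axes let PyTorch reuse the same per-position rotations across the batch and head dimensions without copying them.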

### Apply rotary embedding to Query and Key

```
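from typing import Tuple

import torch


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Pair up the last dimension and view each pair (a, b) as a complex number a + ib
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Uses reshape_for_broadcast defined above
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Complex multiplication rotates each pair; view_as_real + flatten(3) restore the shape
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    # Cast back to the input dtype
    return xq_out.type_as(xq), xk_out.type_as(xk)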
```

- This function takes the Query and Key vectors and rotates them, applying the rotary positional embedding before the attention scores are computed. The rotation is applied only to queries and keys, not to values.
- As inputs, we take `xq` (the query embedding), `xk` (the key embedding), and `freqs_cis` (the rotations, defined by position `m` and the per-pair angle `theta`, which depends on the dimension size). As `freqs_cis` depends only on the maximum sequence length and dimension size, we don't need to recompute it every time.
- The `xq_` line groups the query embedding's dimensions into pairs and creates a complex number from each pair.
- The `xk_` line does the same for the key embedding.
- The `xq_out` and `xk_out` lines multiply by `freqs_cis`, which performs the rotation, then convert back to real pairs with `torch.view_as_real` and flatten the last two dimensions to restore the original shape. Finally, `type_as` casts the results back to the input dtype.
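Putting the three functions together, the sketch below (copies of the functions above, so it runs on its own) feeds a hypothetical toy input through `apply_rotary_emb`: the same random query `q` and key `k` placed at every position, in the layout `(batch=1, seq_len, n_heads=1, head_dim)`. It checks that the rotation preserves vector length, leaves position 0 unchanged, and that the query-key score depends only on the offset between positions.

```python
from typing import Tuple

import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


torch.manual_seed(0)
seqlen, head_dim = 16, 8
# Toy input: the same query q and key k at every position
q, k = torch.randn(head_dim), torch.randn(head_dim)
xq = q.expand(1, seqlen, 1, head_dim).contiguous()
xk = k.expand(1, seqlen, 1, head_dim).contiguous()

xq_rot, xk_rot = apply_rotary_emb(xq, xk, precompute_freqs_cis(head_dim, seqlen))
print(xq_rot.shape)  # torch.Size([1, 16, 1, 8])

# Rotation preserves each vector's length
assert torch.allclose(xq_rot.norm(dim=-1), xq.norm(dim=-1), atol=1e-5)


def score(m: int, n: int) -> torch.Tensor:
    """Dot product between the rotated query at position m and rotated key at position n."""
    return xq_rot[0, m, 0] @ xk_rot[0, n, 0]


# Same offset (n - m = 3) at different absolute positions gives the same score
print(torch.allclose(score(2, 5), score(10, 13), atol=1e-4))  # True
```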

# Up next

In this blog, we have covered one of the core components of the Llama3 architecture: Rotary Positional Embeddings. In the next blog, we will discuss another big topic, Grouped Query Attention.

So, head over to Decoding Llama3: Part 5 - Grouped Query Attention
