Simrat Hanspal

## Decoding Llama3: Part 5 - Grouped Query Attention

In this blog, we continue our Decoding Llama3 series with Grouped Query Attention.

# Introduction

The Transformer made self-attention its core building block: in each of multiple heads, we match the incoming Query against the Keys to understand where the attention should be placed.

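With the self-attention computation, we want to identify the tokens most related to the query. We get this using a dot product between Q and K. Next, we divide the result by `sqrt(head_dim)`: dot products grow in magnitude with the dimension, and unscaled scores push the softmax into regions with very small gradients, which slows training.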
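In equation form, the computation described above is the standard scaled dot-product attention, written here with $d_k$ for the per-head dimension (`head_dim` in the code later in this post):

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```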

Below is the complete execution flow.

Reference: Jay Alammar's blog.

The shape of the input is `(batch_size, sequence_length, number_of_heads, head_dimension)`.

The image below shows `(sequence_length, number_of_heads, head_dimension)`.

Batch size is an additional dimension for increasing throughput.

# Is caching all we need?

If we have a `sequence_length` of 5, we process each token.

In the next iteration, where the `sequence_length` is 6, we process each of the tokens again. This is computationally expensive, but we don't need to regenerate the K and V embeddings for each of those 5 tokens every time. We can cache them: this is the KV cache.
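To make the saving concrete, here is a small counting sketch (the function names are illustrative and not part of the Llama3 code). It counts how many tokens must be pushed through the K and V projections during generation, with and without a cache, assuming one new token is produced per step.

```python
def projections_without_cache(prompt_len: int, steps: int) -> int:
    """Tokens projected to K/V when every step re-processes the whole sequence."""
    total = 0
    for step in range(steps):
        # At each step, all tokens seen so far are projected again.
        total += prompt_len + step
    return total


def projections_with_cache(prompt_len: int, steps: int) -> int:
    """Tokens projected to K/V when earlier K/V rows are read from the cache."""
    if steps <= 0:
        return 0
    # The prompt is projected once; every later step projects one new token.
    return prompt_len + (steps - 1)
```

For the example in the text (5 tokens, then 6), the uncached count is 5 + 6 = 11 projections, while the cached count is 5 + 1 = 6. The gap widens quickly as more tokens are generated.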

# No, caching is not all we need!

We encountered a different problem here: decoding with a KV cache is memory-bound. Reading the cache from VRAM is slow (slower than the compute it feeds). Refer to this blog, Throughput is all you need.

Multi-Query Attention is an optimization in which we use a single head for K and V. All of the (different) query heads then share a single KV cache, which drastically reduces the overhead of loading multiple KV caches from memory. However, sharing one KV head across all queries can degrade model quality.

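Another alternative is Grouped Query Attention, where a group of query heads shares one KV head (for example, every two query heads share one KV head). The group size is the ratio of `n_heads` to `n_kv_heads`. This is a good middle ground between Multi-Head and Multi-Query attention and is used by Llama3.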
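A rough size estimate shows why fewer KV heads matter. The sketch below is an illustration under stated assumptions, not a measurement: it assumes 16-bit values, an 8192-token sequence, a batch size of 1, and the publicly documented Llama 3 8B shape (32 layers, 32 query heads, 8 KV heads, head dimension 128). `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(
    n_layers: int,
    n_kv_heads: int,
    head_dim: int,
    seq_len: int,
    batch_size: int = 1,
    bytes_per_value: int = 2,  # assumption: fp16/bf16 storage
) -> int:
    """Size of the K and V caches across all layers, in bytes."""
    # Factor 2: one cache for keys and one for values in every layer.
    return 2 * n_layers * batch_size * seq_len * n_kv_heads * head_dim * bytes_per_value


# Assumed Llama 3 8B-like shape; only the number of KV heads changes.
multi_head = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=8192)
grouped_query = kv_cache_bytes(n_layers=32, n_kv_heads=8, head_dim=128, seq_len=8192)
multi_query = kv_cache_bytes(n_layers=32, n_kv_heads=1, head_dim=128, seq_len=8192)
```

Under these assumptions, the cache shrinks from 4 GiB with 32 KV heads to 1 GiB with 8 KV heads, and to 128 MiB with a single KV head.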

Picture from Umar Jamil’s lecture on Llama

# Code walkthrough

### Attention class

```
import fairscale.nn.model_parallel.initialize as fs_init
import torch
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from torch import nn


class Attention(nn.Module):
    # ModelArgs is the model configuration dataclass defined elsewhere in the Llama3 code.
    def __init__(self, args: ModelArgs):
        super().__init__()
        # Default to one KV head per query head (plain multi-head attention).
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Heads are split evenly across model-parallel workers.
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # Number of query heads that share each KV head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        # K and V project to n_kv_heads heads, which may be fewer than n_heads.
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        # Pre-allocated KV caches, filled in as tokens are processed.
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
```

- If `self.n_kv_heads`, the number of KV heads, is not explicitly defined, then we use the total number of heads.
- `model_parallel_size`, `self.n_local_heads`, and `self.n_local_kv_heads` are for parallelization: they split the heads across model-parallel workers.
- `self.n_rep` is the factor by which KV heads are repeated. This is required for Multi-Query / Grouped-Query attention because every query head needs a KV head to attend with.
- `self.head_dim` is the dimension of each head, which is simply the model dimension divided by the total number of heads.
- Next, we define the weights `wq`, `wk`, `wv`, and `wo` as parallel linear layers (`ColumnParallelLinear` for the first three and `RowParallelLinear` for `wo`).
- Finally, we create the K and V caches of dimension `(max_batch_size, max_seq_len, n_local_kv_heads, head_dim)`.
- We will be updating the values at the relevant indexes as we process tokens.
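The bookkeeping in `__init__` can be checked with plain arithmetic. The sketch below mirrors the formulas from the constructor without creating any tensors. `ToyArgs` and `attention_shapes` are hypothetical names, and the default values (4096, 32, 8, and so on) are assumed for illustration only.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyArgs:
    """Hypothetical stand-in for ModelArgs with only the fields used here."""
    dim: int = 4096
    n_heads: int = 32
    n_kv_heads: Optional[int] = 8
    max_batch_size: int = 4
    max_seq_len: int = 2048


def attention_shapes(args: ToyArgs, model_parallel_size: int = 1) -> dict:
    """Reproduce the head and cache arithmetic from Attention.__init__."""
    n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
    n_local_heads = args.n_heads // model_parallel_size
    n_local_kv_heads = n_kv_heads // model_parallel_size
    n_rep = n_local_heads // n_local_kv_heads
    head_dim = args.dim // args.n_heads
    return {
        "n_rep": n_rep,
        "head_dim": head_dim,
        "wq_out_features": args.n_heads * head_dim,
        "wk_out_features": n_kv_heads * head_dim,
        "cache_shape": (
            args.max_batch_size,
            args.max_seq_len,
            n_local_kv_heads,
            head_dim,
        ),
    }


shapes = attention_shapes(ToyArgs())
```

With these assumed values, `shapes["n_rep"]` is 4 and `shapes["head_dim"]` is 128, and the key projection outputs 1024 features compared with 4096 for the query projection. Setting `n_kv_heads=None` gives `n_rep == 1`, which is ordinary multi-head attention.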

### Attention computation - forward method

```
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
```
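`forward` calls a `repeat_kv` helper that is not shown in this post. As a framework-free sketch of what that repetition does to the head axis, the code below uses plain Python lists in place of tensors. `repeat_kv_heads` and `kv_head_for_query` are illustrative names, not the Llama3 functions. Each KV head is copied `n_rep` times in place, so consecutive query heads share the same KV head.

```python
from typing import List, TypeVar

T = TypeVar("T")


def repeat_kv_heads(kv_heads: List[T], n_rep: int) -> List[T]:
    """Repeat each KV head n_rep times so there is one entry per query head."""
    if n_rep == 1:
        return list(kv_heads)
    # Interleaved repetition: [kv0, kv1] with n_rep=2 -> [kv0, kv0, kv1, kv1].
    return [head for head in kv_heads for _ in range(n_rep)]


def kv_head_for_query(query_head: int, n_rep: int) -> int:
    """Index of the KV head that a given query head ends up attending with."""
    return query_head // n_rep


expanded = repeat_kv_heads(["kv0", "kv1"], n_rep=2)
```

Here `expanded` is `["kv0", "kv0", "kv1", "kv1"]`: query heads 0 and 1 use the first KV head, and query heads 2 and 3 use the second.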

- First, we take a token embedding (or several, depending on `seqlen`) and multiply it with the weights `wq`, `wk`, and `wv` to get the Query, Key, and Value embeddings.
- With the three `view` calls, we go from dimension `(batch_size, seq_len, model_dim)` to `(batch_size, seq_len, n_heads, head_dim)`. This step splits `model_dim` into `n_heads` heads of dimension `head_dim`.
  We must note that `model_dim = n_heads * head_dim`. For the Keys and Values, the head count is `n_kv_heads`, so they are reshaped to `(batch_size, seq_len, n_kv_heads, head_dim)`.
- We apply rotary embeddings to the Query and Key because we are looking for the Keys related to a Query. The positional encoding is required only for attention, and within the attention computation only for the Query-to-Key match, so the Values are left untouched.
- Next, we write the new Keys and Values into `self.cache_k` and `self.cache_v` at positions `start_pos` to `start_pos + seqlen`.
- Then we fetch all the keys and values up to the current position. These are the same as `xk` and `xv` if `start_pos` is 0 and `seqlen` > 1, which is the case when we process every token in every iteration, so `xk` and `xv` are computed every time. During inference, however, the optimized strategy is to use the KV cache (`self.cache_k` and `self.cache_v`), so `start_pos` is the index of the token we are processing and `seqlen` = 1.
- With the two `repeat_kv` calls, we repeat the KV heads. Recall that attention is computed at every head and concatenated at the end. Therefore, every Query head must have a KV head associated with it. In Grouped-Query or Multi-Query attention, the mapping is not 1:1: multiple Query heads have to work with shared KV heads. The code does not implement this sharing inside the matrix multiplication itself. What we see instead is an adjustment where the KV heads are repeated so that their shape matches the Query heads.
  This is still an optimization even though it doesn't look like it. From a memory-bottleneck point of view, the `wk` and `wv` projections and the KV cache only hold `n_kv_heads` heads, so there is less to store and load than before.
- With the three `transpose(1, 2)` calls, we transpose the dimensions. Let us understand this.

  We start with tensors of dimension `(batch_size, seq_len, n_heads, head_dim)` and transpose them to `(batch_size, n_heads, seq_len, head_dim)`.

  We do this for the matrix multiplications that the attention computation requires: within each head, attention needs to review all the tokens in the sequence, so `seq_len` and `head_dim` must be the last two dimensions.

  Per head, the dimensions are as follows (with a KV cache, the sequence dimension of the Keys and Values is `cache_len + seq_len`):

  - Query: `(seq_len, head_dim)`
  - Key, transposed: `(head_dim, seq_len)`
  - Their product (the scores): `(seq_len, seq_len)`
  - Value: `(seq_len, head_dim)`
  - Product of scores and Value (the attention output): `(seq_len, head_dim)`

- With the `scores` and `output` computations, we compute attention for every head according to the attention function we described above: scaled dot product, optional mask, softmax, and multiplication with the Values.
- With `output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)`, we concatenate the attention outputs from all the heads.
  For this, we first transpose `n_heads` with `seq_len`. Recall that the result has dimension `(batch_size, n_heads, seq_len, head_dim)`. We need to concatenate `head_dim` across `n_heads`, hence we transpose `n_heads` and `seq_len` to get back our old dimension `(batch_size, seq_len, n_heads, head_dim)`.
  Next, we use `contiguous().view(bsz, seqlen, -1)` to perform the concatenation.
- Finally, we multiply the output with `self.wo`. This learned output projection mixes the information from the different heads back into the model dimension.
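To tie the steps above together, the following sketch traces tensor shapes through `forward` using plain tuples. It performs no attention math. `forward_shapes` is a hypothetical helper, it assumes a single model-parallel worker, and the sizes in the usage note are assumed for illustration.

```python
def forward_shapes(
    bsz: int,
    seqlen: int,
    start_pos: int,
    n_heads: int,
    n_kv_heads: int,
    head_dim: int,
) -> dict:
    """Shapes of the main tensors at each stage of Attention.forward."""
    n_rep = n_heads // n_kv_heads
    total_len = start_pos + seqlen  # tokens already cached plus tokens processed now
    return {
        "xq": (bsz, seqlen, n_heads, head_dim),
        "xk": (bsz, seqlen, n_kv_heads, head_dim),
        "keys_from_cache": (bsz, total_len, n_kv_heads, head_dim),
        "keys_repeated": (bsz, total_len, n_kv_heads * n_rep, head_dim),
        "xq_transposed": (bsz, n_heads, seqlen, head_dim),
        "keys_transposed": (bsz, n_heads, total_len, head_dim),
        "scores": (bsz, n_heads, seqlen, total_len),
        "output_per_head": (bsz, n_heads, seqlen, head_dim),
        "output_concat": (bsz, seqlen, n_heads * head_dim),
    }


prefill = forward_shapes(bsz=1, seqlen=5, start_pos=0, n_heads=32, n_kv_heads=8, head_dim=128)
decode = forward_shapes(bsz=1, seqlen=1, start_pos=5, n_heads=32, n_kv_heads=8, head_dim=128)
```

When processing a 5-token prompt (`prefill`), the scores have shape `(1, 32, 5, 5)`. When decoding the sixth token from the cache (`decode`), only one query row is computed, so the scores have shape `(1, 32, 1, 6)`, while the concatenated output is `(1, 1, 4096)`.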

# Up next

Attention is the heart of the Transformer, and with this blog we are done with one of the most important topics. Up next, we will discuss the Feed Forward Network and the activation function used by the Llama3 architecture. So, head over to Decoding Llama3: Part 6 - Feed Forward Network.
