The Fifth Elephant blog

Simrat Hanspal


Decoding Llama3: Part 5 - Grouped Query Attention

Submitted Jun 18, 2024

In this blog, we continue our Decoding Llama3 series with Grouped Query Attention.

Grouped Query Attention intro


The Transformer introduced self-attention: in each of several heads, we match the incoming Query against the Keys to decide where attention should be placed.

Self Attention

With self-attention, we want to identify the tokens most related to a Query. We get this from the dot product between Q and K. Next, we scale the result by sqrt(head_dim): without scaling, the dot products grow with the dimension and push the softmax into regions with extremely small gradients, which makes training harder.
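In equation form, this is the scaled dot-product attention of the original Transformer, where d_k is the dimension of each head (head_dim):

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```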

Below is the complete execution flow.

Figure: the complete self-attention flow (reference: Jay Alammar's blog).
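To make the flow concrete, here is a minimal NumPy sketch of scaled dot-product attention for a single head. The function name and the toy sizes are our own assumptions for illustration; Llama3 itself is written in PyTorch, as the code walkthrough below shows.

```python
import numpy as np


def scaled_dot_product_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Single-head attention: q is (seq_q, d); k and v are (seq_k, d)."""
    d = q.shape[-1]
    # Dot product of every query with every key, scaled by sqrt(d)
    scores = q @ k.T / np.sqrt(d)  # (seq_q, seq_k)
    # Softmax over the key axis (subtracting the row max for numerical stability)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # Each output row is a weighted average of the value vectors
    return weights @ v  # (seq_q, d)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((5, 8)) for _ in range(3))
out = scaled_dot_product_attention(q, k, v)
print(out.shape)  # (5, 8)
```

If all keys were identical, every query would give equal weight to every token, and the output would simply be the average of the value vectors.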

After splitting into heads, the input to attention has shape (batch_size, sequence_length, number_of_heads, head_dimension).
The image below shows (sequence_length, number_of_heads, head_dimension).
The batch size is an additional dimension that lets us process several sequences together for higher throughput.

Figure: multi-head self-attention.

Is caching all we need?

Suppose we have a sequence_length of 5 and process each token.
In the next iteration, where the sequence_length is 6, we would process each of the tokens again. This is computationally expensive, and much of it is unnecessary: the Key and Value embeddings of the first 5 tokens do not change, so we don't need to regenerate them every time. We can cache them instead. This is the KV cache.
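A toy NumPy sketch of the idea, assuming random projection matrices w_k and w_v (hypothetical, not Llama3's weights). Without a cache we re-project every token at each step; with a cache we project only the newest token and append it. Both give identical keys and values, but over 6 steps the number of token projections drops from 1 + 2 + ... + 6 = 21 to 6.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 16
w_k = rng.standard_normal((d_model, d_model))  # hypothetical key projection
w_v = rng.standard_normal((d_model, d_model))  # hypothetical value projection
tokens = rng.standard_normal((6, d_model))     # embeddings of a 6-token sequence

cache_k, cache_v = [], []  # the KV cache: one entry per processed token
projected_no_cache = projected_with_cache = 0

for step in range(len(tokens)):
    # Without a cache: re-project every token seen so far
    seen = tokens[: step + 1]
    k_full, v_full = seen @ w_k, seen @ w_v
    projected_no_cache += len(seen)

    # With a cache: project only the newest token and append it
    cache_k.append(tokens[step] @ w_k)
    cache_v.append(tokens[step] @ w_v)
    projected_with_cache += 1

    # Both approaches yield the same keys and values
    assert np.allclose(k_full, np.stack(cache_k))
    assert np.allclose(v_full, np.stack(cache_v))

print(projected_no_cache, projected_with_cache)  # 21 6
```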

No, caching is not all we need!

This introduces a different problem: decoding with a KV cache is memory-bound. Reading the cache from GPU memory (VRAM) takes longer than the computation performed on it. For more detail, refer to the blog Throughput is all you need.
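A rough sketch of why, counting only the two attention matrix multiplications and the KV-cache reads when one new token attends over the cache in one layer, and assuming a 16-bit cache. The function name and the sizes (32 heads of dimension 128, a 4096-token cache) are illustrative assumptions.

```python
def decode_attention_intensity(n_heads, n_kv_heads, head_dim, cache_len, bytes_per_elem=2):
    """Rough FLOPs per byte of KV cache read for one new token in one layer."""
    # q @ K^T and weights @ V: each costs one multiply-add (2 FLOPs)
    # per head, per cached token, per dimension
    flops = 2 * (2 * n_heads * cache_len * head_dim)
    # Every cached key and value of every KV head is read from memory once
    bytes_read = 2 * n_kv_heads * cache_len * head_dim * bytes_per_elem
    return flops / bytes_read


# Standard multi-head attention: one KV head per query head
intensity = decode_attention_intensity(n_heads=32, n_kv_heads=32, head_dim=128, cache_len=4096)
print(intensity)  # 1.0
```

The cache length cancels out: with one KV head per query head, each byte read from the cache supports only about one floating-point operation, while a GPU can perform far more operations per byte it loads, so most of the time is spent waiting on memory.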

Multi-Query Attention is an optimization in which all query heads share a single Key/Value head. The queries still differ per head, but only one KV cache has to be stored and loaded, which drastically reduces the memory traffic. The trade-off is some degradation in model quality.

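Another alternative is Grouped Query Attention: the query heads are split into groups, and each group shares one KV head, for example one KV cache for every two query heads. The group size is n_heads / n_kv_heads. This is a good middle ground between multi-head and multi-query attention, and it is what Llama3 uses.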

Picture from Umar Jamil’s lecture on Llama
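A back-of-the-envelope sketch of the memory savings. We assume a Llama3 8B-like configuration (32 layers, 32 query heads, 8 KV heads, model dimension 4096), a 16-bit cache, batch size 1, and a context of 8192 tokens; the multi-head and multi-query rows are hypothetical variants of the same model with 32 and 1 KV heads.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    # Factor 2: the cache holds one tensor for keys and one for values
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


# Assumed Llama3 8B-like configuration
n_layers, n_heads, dim = 32, 32, 4096
head_dim = dim // n_heads  # 128
seq_len = 8192

cache_gib = {}
for name, n_kv_heads in [("multi-head", 32), ("grouped-query", 8), ("multi-query", 1)]:
    group_size = n_heads // n_kv_heads  # query heads sharing one KV head
    cache_gib[name] = kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len) / 2**30
    print(f"{name}: group size {group_size}, KV cache {cache_gib[name]:.3f} GiB")
```

Under these assumptions, the cache shrinks from 4 GiB (multi-head) to 1 GiB (grouped-query) and 0.125 GiB (multi-query) per sequence. Grouped-query attention still keeps 8 distinct KV heads instead of collapsing them to one, which is the intuition for why it loses less quality than multi-query attention.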

Code walkthrough

Attention class

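import math
from typing import Optional

import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from torch import nn

# ModelArgs, apply_rotary_emb, and repeat_kv are defined elsewhere in Llama3's model.py


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        # Fall back to standard multi-head attention if n_kv_heads is not set
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Split the heads across model-parallel GPUs
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # Number of query heads that share each KV head
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        # Projections: K and V produce only n_kv_heads * head_dim features
        self.wq = ColumnParallelLinear(
            args.dim, args.n_heads * self.head_dim,
            bias=False, gather_output=False, init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim, self.n_kv_heads * self.head_dim,
            bias=False, gather_output=False, init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim, self.n_kv_heads * self.head_dim,
            bias=False, gather_output=False, init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim, args.dim,
            bias=False, input_is_parallel=True, init_method=lambda x: x,
        )

        # KV cache of shape (max_batch_size, max_seq_len, n_local_kv_heads, head_dim)
        cache_shape = (args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim)
        self.cache_k = torch.zeros(cache_shape).cuda()
        self.cache_v = torch.zeros(cache_shape).cuda()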

  1. If the number of KV heads (args.n_kv_heads) is not explicitly defined, self.n_kv_heads falls back to the total number of heads, which gives standard multi-head attention.
  2. model_parallel_size, self.n_local_heads, and self.n_local_kv_heads handle model parallelism: the heads are split evenly across GPUs, and each GPU holds only its local share.
  3. self.n_rep is the factor by which each KV head is repeated. It is required for Multi-Query and Grouped-Query attention because every Query head needs a KV head to attend over.
  4. self.head_dim is the dimension of each head: the model dimension divided by the total number of heads.
  5. We define the projection weights wq, wk, and wv as ColumnParallelLinear layers and wo as a RowParallelLinear layer. Note that wk and wv output only n_kv_heads * head_dim features.
  6. self.cache_k and self.cache_v are the K and V caches, of shape (max_batch_size, max_seq_len, n_local_kv_heads, head_dim).
  7. We update the entries of these caches at the corresponding positions as we process tokens.
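A quick plain-Python sketch of the head bookkeeping in __init__, using assumed Llama3 8B-like values (32 query heads, 8 KV heads, model dimension 4096) on a single GPU. The variables are a hypothetical stand-in for ModelArgs. The last step assumes that each KV head is repeated n_rep times in a row, which is our understanding of how repeat_kv orders the heads in the forward pass.

```python
# Hypothetical stand-in for ModelArgs with Llama3 8B-like values
n_heads, n_kv_heads, dim = 32, 8, 4096
model_parallel_size = 1  # assume a single GPU

n_local_heads = n_heads // model_parallel_size
n_local_kv_heads = n_kv_heads // model_parallel_size
n_rep = n_local_heads // n_local_kv_heads  # query heads per KV head
head_dim = dim // n_heads

# If each KV head is repeated n_rep times in a row,
# query head h ends up paired with KV head h // n_rep
kv_head_for_query = [h // n_rep for h in range(n_local_heads)]

print(n_rep, head_dim)        # 4 128
print(kv_head_for_query[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]
```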

Attention computation - forward method

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

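        # Match the caches' dtype and device to the incoming activations
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)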

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

  1. First, we take the token embeddings x (one token or several, depending on seqlen) and multiply them by the weights wq, wk, and wv to get the Query, Key, and Value embeddings.
  2. The view calls reshape xq from (batch_size, seq_len, model_dim) to (batch_size, seq_len, n_heads, head_dim), splitting model_dim into n_heads chunks of size head_dim.
    We must note that model_dim = n_heads * head_dim. xk and xv are reshaped the same way, but into n_kv_heads heads.
  3. We apply rotary embeddings to the Query and Key only, because position matters when finding the Keys related to a Query. The position encoding is needed only inside attention, and there only for the Query-to-Key match; the Values are left unchanged.
  4. We then write xk and xv into self.cache_k and self.cache_v at positions start_pos to start_pos + seqlen.
  5. Next, we fetch all keys and values from position 0 up to start_pos + seqlen. When start_pos is 0, for example when the whole prompt is processed in a single pass, these are exactly xk and xv. During token-by-token generation, however, we use the KV cache (self.cache_k and self.cache_v): start_pos is the index of the token being processed and seqlen = 1, so only one new Key/Value pair is computed and the rest are read from the cache.
  6. We then repeat the KV heads with repeat_kv. Recall that attention is computed for every head and the results are concatenated at the end, so every Query head needs a KV head to attend over. In Grouped-Query or Multi-Query attention the mapping is not 1:1: several Query heads share one KV head. The code does not implement this sharing directly; instead, it repeats each KV head n_rep times so that the standard per-head computation can be reused.
    This is still an optimization, even though it doesn't look like one. The cache stores only n_kv_heads heads, and wk and wv produce only n_kv_heads heads, so far less data has to be stored and loaded from memory than in standard multi-head attention.
  7. Next, we transpose the dimensions. Let us understand this.
    We start with tensors of shape (batch_size, seq_len, n_heads, head_dim) and transpose them to (batch_size, n_heads, seq_len, head_dim).

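We do this so that the matrix multiplications run per head. With n_heads placed before seq_len, torch.matmul treats (batch_size, n_heads) as batch dimensions and multiplies the last two dimensions, so within every head each token's Query is compared with the Keys of all tokens in the sequence.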

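For a single head (ignoring the batch_size and n_heads dimensions):

Dimension of Query: (seq_len, head_dim)
Dimension of Key, transposed: (head_dim, seq_len)
Dimension of the product (attention scores): (seq_len, seq_len)

Dimension of Value: (seq_len, head_dim)
Dimension of the product (attention output): (seq_len, head_dim)

With the KV cache, the Key and Value length is cache_len + seq_len rather than seq_len, as the shape comments in the code indicate.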
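The following NumPy sketch traces these shapes with arbitrary toy sizes (our assumption). swapaxes plays the role of PyTorch's transpose, and the final check confirms that the batched matmul gives the same scores as computing one head of one sequence by hand.

```python
import numpy as np

bsz, seqlen, n_heads, head_dim = 2, 5, 4, 8
dim = n_heads * head_dim  # model_dim = n_heads * head_dim
rng = np.random.default_rng(0)

# Like xq.view(...): split model_dim into n_heads chunks of head_dim
xq = rng.standard_normal((bsz, seqlen, dim)).reshape(bsz, seqlen, n_heads, head_dim)
keys = rng.standard_normal((bsz, seqlen, n_heads, head_dim))

# Like .transpose(1, 2): move heads in front of the sequence axis
xq_t = xq.swapaxes(1, 2)      # (bsz, n_heads, seqlen, head_dim)
keys_t = keys.swapaxes(1, 2)  # (bsz, n_heads, seqlen, head_dim)

# matmul treats (bsz, n_heads) as batch dims and multiplies the last two axes
scores = xq_t @ keys_t.swapaxes(2, 3) / np.sqrt(head_dim)
print(scores.shape)  # (2, 4, 5, 5)

# Same as computing a single head of a single sequence by hand
b, h = 1, 3
manual = xq[b, :, h, :] @ keys[b, :, h, :].T / np.sqrt(head_dim)
assert np.allclose(scores[b, h], manual)
```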

  1. We then compute attention for every head following the attention function described above: the scaled dot product of Queries and Keys, the optional causal mask, a softmax computed in float32, and a weighted sum of the Values.
  2. Next, we concatenate the attention outputs of all heads.
    For this, we first swap n_heads and seq_len.
    Recall that the result has shape (batch_size, n_heads, seq_len, head_dim).
    To concatenate the head_dim vectors of all n_heads, we transpose back to (batch_size, seq_len, n_heads, head_dim).
    Then contiguous().view(bsz, seqlen, -1) merges the last two dimensions into n_heads * head_dim, which performs the concatenation. contiguous() is needed because view requires contiguous memory after a transpose.
  3. Finally, we multiply the output by self.wo, the output projection. It mixes the information from the different heads and maps the result back to the model dimension.
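To tie the walkthrough together, here is a hypothetical NumPy re-creation of the forward pass. It assumes a single device (no model parallelism), skips rotary embeddings, uses random weights, and builds the causal mask inline. Its repeat_kv uses np.repeat, which repeats each KV head n_rep times in a row; to our understanding this matches the head ordering of repeat_kv in the Llama3 code. The usage at the end checks that decoding token by token with the KV cache gives the same output as processing the whole sequence at once.

```python
import numpy as np


def repeat_kv(x: np.ndarray, n_rep: int) -> np.ndarray:
    # (bsz, seq, n_kv_heads, head_dim) -> (bsz, seq, n_kv_heads * n_rep, head_dim),
    # repeating each KV head n_rep times in a row
    return x if n_rep == 1 else np.repeat(x, n_rep, axis=2)


class ToyGQA:
    """Hypothetical NumPy version of the forward pass: no RoPE, no model parallelism."""

    def __init__(self, dim, n_heads, n_kv_heads, max_batch, max_seq, seed=0):
        rng = np.random.default_rng(seed)
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.n_rep = n_heads // n_kv_heads
        self.head_dim = dim // n_heads
        scale = 1 / np.sqrt(dim)
        self.wq = rng.standard_normal((dim, n_heads * self.head_dim)) * scale
        self.wk = rng.standard_normal((dim, n_kv_heads * self.head_dim)) * scale
        self.wv = rng.standard_normal((dim, n_kv_heads * self.head_dim)) * scale
        self.wo = rng.standard_normal((n_heads * self.head_dim, dim)) * scale
        # The cache holds only n_kv_heads heads
        self.cache_k = np.zeros((max_batch, max_seq, n_kv_heads, self.head_dim))
        self.cache_v = np.zeros_like(self.cache_k)

    def forward(self, x: np.ndarray, start_pos: int) -> np.ndarray:
        bsz, seqlen, _ = x.shape
        xq = (x @ self.wq).reshape(bsz, seqlen, self.n_heads, self.head_dim)
        xk = (x @ self.wk).reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)
        xv = (x @ self.wv).reshape(bsz, seqlen, self.n_kv_heads, self.head_dim)

        # Write the new keys/values into the cache, then read everything so far
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
        keys = repeat_kv(self.cache_k[:bsz, : start_pos + seqlen], self.n_rep)
        values = repeat_kv(self.cache_v[:bsz, : start_pos + seqlen], self.n_rep)

        # (bsz, n_heads, seq, head_dim) for per-head matmuls
        xq, keys, values = (t.swapaxes(1, 2) for t in (xq, keys, values))
        scores = xq @ keys.swapaxes(2, 3) / np.sqrt(self.head_dim)
        # Causal mask: a query may not attend to keys at later positions
        q_pos = np.arange(start_pos, start_pos + seqlen)[:, None]
        k_pos = np.arange(start_pos + seqlen)[None, :]
        scores = np.where(k_pos > q_pos, -np.inf, scores)
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)
        output = weights @ values  # (bsz, n_heads, seqlen, head_dim)
        # Concatenate heads: (bsz, seqlen, n_heads * head_dim), then project
        output = output.swapaxes(1, 2).reshape(bsz, seqlen, -1)
        return output @ self.wo


attn = ToyGQA(dim=16, n_heads=4, n_kv_heads=2, max_batch=1, max_seq=8)
x = np.random.default_rng(1).standard_normal((1, 6, 16))

full = attn.forward(x, start_pos=0)  # all 6 tokens in one pass

# A 3-token prompt, then one token per step, reusing the cache
steps = [attn.forward(x[:, :3], start_pos=0)]
for pos in range(3, 6):
    steps.append(attn.forward(x[:, pos : pos + 1], start_pos=pos))
incremental = np.concatenate(steps, axis=1)

print(full.shape, np.allclose(full, incremental))  # (1, 6, 16) True
```

Note that the cache has only 2 KV heads while attention runs over 4 query heads; the repetition happens on the slice read from the cache, not in what is stored.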

Up next

Attention is the heart of the Transformer, and with this blog we have covered one of its most important topics. Up next, we will discuss the Feed Forward Network and the activation function used in the Llama3 architecture. So, head over to Decoding Llama3: Part 6 - Feed Forward Network.



Hosted by

NLP specialist with 14 years of experience building products. Exploring productivity hacks.