Decoding Llama3: An explainer for tinkerers
A not-so-quick 7-part guide to using the Llama3 open source AI model
We continue our Decoding Llama3 series with Grouped Query Attention in this post.
The Transformer introduced self-attention: across multiple heads, we match Keys against the incoming Query to decide where attention should be placed.
In the self-attention computation, we want to identify the tokens most related to the query. We get this from a dot product between Q and K. We then scale the result by sqrt(head_dim), which keeps the dot products in a numerically stable range so the softmax does not saturate and training converges more smoothly.
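Concretely, for a single head the computation looks like this (a minimal PyTorch sketch for illustration, not the Llama3 code itself):

import math
import torch
import torch.nn.functional as F

def single_head_attention(q, k, v):
    # q, k, v: (seq_len, head_dim) for a single head
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])  # (seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)                        # attention weights per query token
    return weights @ v                                         # (seq_len, head_dim)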
Below is the complete execution flow.
Reference: Jay Alammar's blog.
The shape of input is (batch_size, sequence_length, number_of_heads, head_dimension)
The image below shows (sequence_length, number_of_heads, head_dimension).
Batch size is an additional dimension for increasing throughput.
If we have a sequence_length of 5, we process each of the 5 tokens. In the next iteration, when the sequence_length becomes 6, we process each of those tokens again. This is computationally expensive, but we don't need to regenerate the Key and Value embeddings for the first 5 tokens every time. We can cache them instead: this is the KV cache.
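A minimal sketch of the idea (toy sizes and plain nn.Linear projections, not the Llama3 code): at each decoding step we project only the newest token, write its K and V into a cache, and let attention read the cached history.

import torch
from torch import nn

# Toy sizes for illustration only (not Llama3's real configuration).
max_seq_len, dim, n_kv_heads, head_dim = 16, 32, 2, 16
wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False)

cache_k = torch.zeros(max_seq_len, n_kv_heads, head_dim)
cache_v = torch.zeros(max_seq_len, n_kv_heads, head_dim)

@torch.no_grad()  # inference-only sketch
def decode_step(token_embedding, pos):
    # Project only the newest token and write it into the cache at its position.
    cache_k[pos] = wk(token_embedding).view(n_kv_heads, head_dim)
    cache_v[pos] = wv(token_embedding).view(n_kv_heads, head_dim)
    # Attention for this step reads the cached keys/values for positions 0..pos.
    return cache_k[: pos + 1], cache_v[: pos + 1]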
This introduces a different problem: the KV cache makes decoding memory-bound. Reading the cache from VRAM is slower than the compute itself. Refer to the blog "Throughput is all you need".
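To get a feel for why this matters, here is a rough back-of-the-envelope calculation. The numbers assume a Llama3-8B-like configuration (32 layers, 8 KV heads, head_dim of 128, fp16 values); treat them as an illustration rather than an exact figure.

# Rough KV-cache size estimate for one sequence (illustrative assumptions).
n_layers, n_kv_heads, head_dim = 32, 8, 128
bytes_per_value = 2            # fp16
kv = 2                         # one K and one V entry per position

bytes_per_token = kv * n_layers * n_kv_heads * head_dim * bytes_per_value
print(bytes_per_token)                      # 131072 bytes, i.e. 128 KiB per token
print(bytes_per_token * 8192 / 2**30)       # about 1 GiB for an 8K-token context

All of this has to be streamed from VRAM at every decoding step, which is why the cache, not the matmuls, becomes the bottleneck.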
Multi-Query Attention is an optimization in which we keep a single KV head. All the (different) query heads then share one KV cache, which drastically reduces the overhead of loading multiple KV caches into memory. However, this sharing comes with some degradation in model quality.
Another alternative is Grouped Query Attention, where a small group of query heads (two in the illustration below) shares one KV head and its cache. This is a good middle ground between quality and memory traffic, and it is what Llama3 uses.
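Concretely, the grouping only determines how many query heads read from the same KV head. A small sketch with illustrative head counts (Llama3 variants configure n_heads and n_kv_heads in ModelArgs):

# Number of query heads that share each KV head.
n_heads, n_kv_heads = 32, 8         # illustrative values
n_rep = n_heads // n_kv_heads       # 4 query heads per KV head

# Multi-Head:    n_kv_heads == n_heads  -> n_rep == 1 (one KV cache per query head)
# Multi-Query:   n_kv_heads == 1        -> n_rep == n_heads (one shared KV cache)
# Grouped-Query: 1 < n_kv_heads < n_heads, as above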
Picture from Umar Jamil’s lecture on Llama
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

# As in Meta's Llama3 reference implementation; fairscale provides the model-parallel layers.
# ModelArgs, apply_rotary_emb, and repeat_kv are defined elsewhere in the same source file.
import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        # Fall back to the number of query heads if KV heads are not configured.
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # How many times each KV head must be repeated to match the query heads.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        # Pre-allocated KV cache: (max_batch_size, max_seq_len, n_local_kv_heads, head_dim).
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
A few things to note in __init__:
- self.n_kv_heads: if the number of KV heads is not explicitly defined, we fall back to the total number of heads.
- self.n_rep tells us the factor by which KV heads must be repeated. This is needed for Multi-Query / Grouped-Query attention, because every query head needs a matching KV head.
- self.head_dim is the dimension of each head, which is nothing but the model dimension divided by the total number of heads.
- wq, wk, wv, and wo are defined as ParallelLinear layers (ColumnParallelLinear / RowParallelLinear), so the projections can be split across model-parallel GPUs.
- The KV caches cache_k and cache_v are pre-allocated with shape (batch_size, max_seq_length, n_KV_heads, head_dimension).
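To make the bookkeeping concrete, here is a small worked example; the values (dim=4096, 32 query heads, 8 KV heads, a single GPU) are assumptions for illustration, not taken from the blog:

# Worked example with assumed values.
dim, n_heads, n_kv_heads, model_parallel_size = 4096, 32, 8, 1

n_local_heads = n_heads // model_parallel_size         # 32 query heads on this GPU
n_local_kv_heads = n_kv_heads // model_parallel_size   # 8 KV heads on this GPU
n_rep = n_local_heads // n_local_kv_heads              # 4 query heads share each KV head
head_dim = dim // n_heads                               # 4096 // 32 = 128

# cache_k / cache_v shape: (max_batch_size, max_seq_len, n_local_kv_heads, head_dim) = (B, S, 8, 128)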
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)      # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)          # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)      # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
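The forward pass calls repeat_kv, which is defined elsewhere in the Llama3 source and not shown in this post. Functionally, it repeats each KV head n_rep times along the head dimension so the keys and values line up with the query heads; a minimal sketch of equivalent behaviour (assuming it matches torch.repeat_interleave along dim=2) is:

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: (bs, seqlen, n_kv_heads, head_dim) -> (bs, seqlen, n_kv_heads * n_rep, head_dim)
    if n_rep == 1:
        return x
    return torch.repeat_interleave(x, repeats=n_rep, dim=2)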
Here is what happens in forward:
- We apply wq, wk, and wv to the input to get the Query, Key, and Value embeddings.
- We reshape the embeddings from (batch_size, seq_len, model_dim) to (batch_size, seq_len, n_heads, head_dim). This step splits model_dim into n_heads heads of dimension head_dim, where model_dim = n_heads * head_dim.
- We store xk and xv in the caches self.cache_k and self.cache_v. If start_pos is 0 and seqlen > 1, we are processing every token on every iteration, and hence xk and xv are computed every time. During inference, the optimized strategy is to use the KV cache (self.cache_k and self.cache_v), in which case start_pos is the index of the token we are processing and seqlen = 1.
- The tensors have shape (batch_size, seq_len, n_heads, head_dim), and we transpose them to (batch_size, n_heads, seq_len, head_dim). We do this so the matrix multiplications of the attention computation run per head over the whole sequence, since attention needs to look at all the tokens in the sequence. Per head, the shapes are (a small sanity-check sketch follows this list):
  - Query: (seq_len, head_dim)
  - Key (transposed): (head_dim, seq_len)
  - Their product, the scores: (seq_len, seq_len)
  - Value: (seq_len, head_dim)
  - Product of scores and Value, the attention output: (seq_len, head_dim)
- The attention output comes back with n_heads swapped with seq_len: (batch_size, n_heads, seq_len, head_dim). We want to concatenate the head_dim of all n_heads, hence we transpose n_heads and seq_len to get back the old layout (batch_size, seq_len, n_heads, head_dim) and then call contiguous().view(bsz, seqlen, -1) to perform the concatenation.
- Finally, we multiply the output with self.wo. This learned projection is useful if the model wants to adjust the weights applied to the concatenated head outputs.
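To sanity-check the shape bookkeeping, here is a small toy example with arbitrary sizes, using plain tensors instead of the Llama3 module and repeat_interleave standing in for repeat_kv:

import math
import torch
import torch.nn.functional as F

bs, seqlen, n_heads, n_kv_heads, head_dim = 1, 5, 4, 2, 8   # toy sizes
n_rep = n_heads // n_kv_heads

xq = torch.randn(bs, seqlen, n_heads, head_dim)
keys = torch.randn(bs, seqlen, n_kv_heads, head_dim).repeat_interleave(n_rep, dim=2)
values = torch.randn(bs, seqlen, n_kv_heads, head_dim).repeat_interleave(n_rep, dim=2)

xq = xq.transpose(1, 2)                                               # (bs, n_heads, seqlen, head_dim)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)

scores = xq @ keys.transpose(2, 3) / math.sqrt(head_dim)              # (bs, n_heads, seqlen, seqlen)
output = F.softmax(scores, dim=-1) @ values                           # (bs, n_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bs, seqlen, -1)     # (bs, seqlen, n_heads * head_dim)
print(output.shape)                                                   # torch.Size([1, 5, 32])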
Attention is the heart of the Transformer, and with this post we are done with one of the most important topics. Up next, we will discuss the Feed Forward Network and the activation function used by the Llama3 architecture. So, head over to Decoding Llama3: Part 6 - Feed Forward Network.