Decoding Llama3: An explainer for tinkerers
A not-so-quick 7-part guide to using the Llama3 open source AI model
We are continuing our Decoding Llama3 series with the Transformer block and the Transformer module.
The Transformer consists of multiple transformer blocks stacked on top of each other. Every layer performs the same operations, but each one learns something different, helping the model capture relationships ranging from simple to complex.
Each block puts together everything we have covered in earlier blogs as per the diagram (seen at the top). So, let us directly jump into the code.
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
Let us review the init code:

self.n_heads: number of heads in every block
self.dim: model embedding dimension
self.head_dim: embedding size for attention computation within each head
self.attention: attention block
self.feed_forward: FFN block
self.layer_id: layer id
self.attention_norm: RMSNorm
self.ffn_norm: RMSNorm

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
Let us review the forward code now:

x: the input embeddings
start_pos: the position of the first token in this pass; during inference with a KV cache we typically feed one new token at a time, so start_pos is that token's index in the full sequence
freqs_cis: the rotary positional embedding frequencies
mask: used to prevent the network from attending to future tokens, because autoregressive models only have access to previously generated tokens

Putting these together, the forward pass is just the pre-norm residual pattern sketched below.
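Here is a minimal, self-contained sketch of that pattern. Note that nn.LayerNorm and nn.Linear are only stand-ins for the RMSNorm, Attention and FeedForward modules we covered in earlier parts, and the dimensions are toy values, not the real Llama3 sizes.

import torch
import torch.nn as nn

dim = 16                                               # toy embedding size
norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)    # stand-ins for RMSNorm
attn = nn.Linear(dim, dim)                             # stand-in for Attention
ffn = nn.Linear(dim, dim)                              # stand-in for FeedForward

x = torch.randn(2, 5, dim)                             # (batch, seq_len, dim)
h = x + attn(norm1(x))                                 # normalize -> attend -> residual add
out = h + ffn(norm2(h))                                # normalize -> FFN -> residual add
print(out.shape)                                       # torch.Size([2, 5, 16])

The Transformer module below simply stacks n_layers of these blocks.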
class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = VocabParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
        )
Let us review the init code:

vocab_size, n_layers: stored from the model params
self.layers: a ModuleList that stores the n_layers of the transformer (decoder-only) architecture; we loop n_layers times, creating a TransformerBlock each time and adding it to self.layers
self.norm: RMSNorm applied to the output of the final layer
self.output: a Linear layer that maps the final layer's output to the vocabulary, i.e. from (batch_size, seq_len, model_dim) to (batch_size, seq_len, vocabulary_size)
self.freqs_cis: precomputed here, as it depends only on the model params, i.e. the dimension and max_seq_len
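As a quick sanity check of the shapes in the self.output step, a plain nn.Linear (standing in here for ColumnParallelLinear) maps the last dimension from model_dim to vocabulary_size. The numbers below are toy values, not the real Llama3 sizes.

import torch
import torch.nn as nn

model_dim, vocab_size = 16, 100                         # toy values
output = nn.Linear(model_dim, vocab_size, bias=False)   # stand-in for ColumnParallelLinear
h = torch.randn(2, 5, model_dim)                        # (batch_size, seq_len, model_dim)
print(output(h).shape)                                  # torch.Size([2, 5, 100])

Now, the forward pass of the Transformer: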
    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=1)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack(
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
            ).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h).float()
        return output
Let us walk through the forward code:

Collect the tokens as input.
Convert them to embeddings using self.tok_embeddings.
Select freqs_cis based on the sequence length, which can be smaller than max_seq_len.
mask is set to None if seqlen is 1; otherwise we define the mask.
mask is a matrix of dimension (seqlen, seqlen) with a default value of -inf.

The reason we set the default value to -inf is that attention scores are passed through the softmax function:

softmax(x_i) = e^(x_i) / sum_j e^(x_j), and e^(-inf) = 0

So the attention weight for a masked position becomes 0, and the network will not attend to future tokens.
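A quick, runnable illustration of this, with toy scores (not taken from the model):

import torch
import torch.nn.functional as F

scores = torch.tensor([1.2, 0.7, 0.1])                      # toy attention scores for one query
mask = torch.tensor([0.0, float("-inf"), float("-inf")])    # block the two "future" positions
print(F.softmax(scores + mask, dim=-1))                     # tensor([1., 0., 0.])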
Next, we convert the matrix into an upper triangular matrix using torch.triu:
[[ -inf -inf -inf ]
 [ -inf -inf -inf ]
 [ -inf -inf -inf ]]

gets converted to

[[    0 -inf -inf ]
 [    0    0 -inf ]
 [    0    0    0 ]]
diagonal=1 means the main diagonal itself is zeroed out rather than kept at -inf, so every token can still attend to itself but not to any future token.
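The same example as runnable code, using a toy seqlen of 3:

import torch

seqlen = 3
mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)   # keep -inf only strictly above the main diagonal
print(mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])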
Next, we update this mask based on how many tokens are already in the KV cache: we prepend start_pos columns of zeros, so the cached tokens are always visible. If start_pos is 0 we get the same mask as above, but if start_pos is 1 (with seqlen = 3), the mask becomes a (3, 4) matrix:

[[ 0    0 -inf -inf ]
 [ 0    0    0 -inf ]
 [ 0    0    0    0 ]]
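We can verify the shape with the same toy numbers (seqlen = 3, start_pos = 1):

import torch

seqlen, start_pos = 3, 1                                # toy values
mask = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)
mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])
print(mask.shape)                                       # torch.Size([3, 4]), i.e. (seqlen, start_pos + seqlen)
print(mask)
# tensor([[0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])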
Finally, we pass the hidden states through all the layers in a loop.
After the loop, we normalize the output of the final layer with RMSNorm before the last step.
The last step is the output layer, which projects the hidden states to logits over the vocabulary; the softmax (and sampling) over these logits happens later, during generation.
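To see how start_pos ties everything together, here is a hedged sketch of a greedy decoding loop. model (an initialized Transformer from above) and prompt_tokens (a (1, prompt_len) tensor of token ids) are assumed to exist; the actual Llama3 repo wraps this logic in its own generation code, so treat this only as an illustration.

import torch

# `model` and `prompt_tokens` are assumptions, not defined in this post.
tokens = prompt_tokens

# Prefill: process the whole prompt at once (seqlen > 1, so the causal mask is built).
logits = model(tokens, start_pos=0)
next_token = logits[:, -1].argmax(dim=-1, keepdim=True)

for _ in range(32):                                     # generate up to 32 new tokens
    tokens = torch.cat([tokens, next_token], dim=1)
    # Decode step: only the newest token is fed in; start_pos points just past the KV cache.
    logits = model(next_token, start_pos=tokens.shape[1] - 1)
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)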
The Llama3 models are a major milestone in the open-source (or open-weights, as some like to call it) space. They not only catch up with some of the best private models but outperform them in many cases. I am confident we will see many fine-tuned variants of Llama3, and I hope this blog series was helpful in decoding the underlying architecture.