Simrat Hanspal
Decoding Llama3: Part 7 - Transformer Block & Module
We are continuing our Decoding Llama3 series with the Transformer block and the Transformer module.
Introduction
The Transformer consists of multiple transformer blocks stacked on top of each other. Every layer performs the same operations, but what each layer learns is different, helping the model capture relationships from simple to complex.
Each block puts together everything we have covered in the earlier blogs, as per the diagram (seen at the top). So, let us directly jump into the code.
Code walkthrough
Transformer block - init
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
Let us review the init code:
- self.n_heads: number of heads in every block
- self.dim: model embedding dimension
- self.head_dim: embedding size for attention computation within each head (see the small example below)
- self.attention: attention block
- self.feed_forward: FFN block
- self.layer_id: layer id of the block
- self.attention_norm: RMSNorm
- self.ffn_norm: RMSNorm
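For intuition on these sizes, here is a tiny sketch using Llama3-8B-like hyperparameters (dim=4096, n_heads=32; treat these numbers as illustrative):

# Illustrative values, close to the Llama3-8B configuration
dim = 4096               # model embedding dimension
n_heads = 32             # attention heads per block
head_dim = dim // n_heads
print(head_dim)          # 128 -> each head works on a 128-dimensional slice of the embedding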
Transformer block - forward method
    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
Let us review the forward code now.
The inputs to the transformer block are:
a. x, the input embedding
b. start_pos, the starting position of the current tokens within the full sequence (used for indexing into the KV cache during inference)
c. freqs_cis, for rotary positional embedding
d. mask, used to prevent the network from attending to future tokens, because an autoregressive model only has access to the previously generated tokens.
In the forward method, we can see the execution flow as depicted in the first Llama architecture diagram (a minimal sketch follows this list):
a. Normalize the input (attention_norm)
b. Compute attention and add the residual connection
c. Normalize the result (ffn_norm)
d. Compute the FFN output and add the residual connection
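Here is a minimal, self-contained sketch of this pre-norm residual pattern. It is illustrative only: nn.LayerNorm stands in for RMSNorm, and plain nn.Linear layers stand in for the real attention and FFN modules.

import torch
import torch.nn as nn

class ToyPreNormBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.attention_norm = nn.LayerNorm(dim)   # stand-in for RMSNorm
        self.ffn_norm = nn.LayerNorm(dim)         # stand-in for RMSNorm
        self.attention = nn.Linear(dim, dim)      # stand-in for multi-head attention
        self.feed_forward = nn.Linear(dim, dim)   # stand-in for the FFN

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x + self.attention(self.attention_norm(x))    # a. normalize, b. attend + residual
        return h + self.feed_forward(self.ffn_norm(h))    # c. normalize, d. FFN + residual

x = torch.randn(1, 5, 16)              # (batch, seq_len, dim)
print(ToyPreNormBlock(16)(x).shape)    # torch.Size([1, 5, 16]) - the shape is preserved

Because the block preserves the (batch, seq_len, dim) shape, any number of such blocks can be stacked, which is exactly what the Transformer module does next.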
Transformer module - init
class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        self.tok_embeddings = VocabParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )
        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
        )
Let us review the init code:
- We set variables like vocab_size and n_layers.
- self.tok_embeddings: the embedding module (VocabParallelEmbedding) that maps token ids to embeddings.
- self.layers: a ModuleList that stores the n_layers layers of the transformer (decoder-only) architecture. We loop n_layers times, creating a TransformerBlock for each layer and appending it to self.layers.
- self.norm: RMSNorm applied to the output of the final layer.
- self.output: a linear layer that maps the final layer's output to the vocabulary, i.e. from (batch_size, seq_len, model_dim) to (batch_size, seq_len, vocabulary_size) (see the shape sketch below).
- self.freqs_cis: precomputed here, as it depends only on the model params, i.e. the head dimension and max_seq_len.
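To make the output projection concrete, here is a small shape sketch. The numbers are illustrative, and nn.Linear stands in for ColumnParallelLinear (which computes the same projection but shards it across model-parallel GPUs).

import torch
import torch.nn as nn

batch_size, seq_len, model_dim, vocab_size = 2, 8, 16, 100   # illustrative sizes

output = nn.Linear(model_dim, vocab_size, bias=False)        # stand-in for ColumnParallelLinear
h = torch.randn(batch_size, seq_len, model_dim)              # output of the final transformer block
logits = output(h)
print(logits.shape)   # torch.Size([2, 8, 100]) -> (batch_size, seq_len, vocabulary_size)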
Transformer module - forward method
    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
            mask = torch.triu(mask, diagonal=1)
            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack(
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
            ).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        output = self.output(h).float()
        return output

Let us review the forward code:
- Collect the tokens as input.
- Convert them to embeddings using self.tok_embeddings.
- Select freqs_cis based on the sequence length, which can be less than max_seq_len.
- mask is set to None if seqlen is 1; otherwise we define the mask.
- mask is a matrix of dimension (seqlen, seqlen) with a default value of -inf.
The reason we set the default value to -inf is that when we compute attention, we apply the softmax function to the scores:
softmax(x_i) = e^(x_i) / sum_j e^(x_j)
and e^(-inf) = 0.
Thereby the attention weight becomes 0, and the network will not attend to the future tokens (a quick numeric check follows).
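A quick check in plain PyTorch (the scores are illustrative values only):

import torch

scores = torch.tensor([1.0, 2.0, float("-inf")])   # the last position is a future token
print(torch.softmax(scores, dim=0))                # tensor([0.2689, 0.7311, 0.0000])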
Next, we convert the matrix into an upper triangular matrix using torch.triu:
[[ -inf -inf -inf]
 [ -inf -inf -inf]
 [ -inf -inf -inf]]  gets converted to
[[ 0 -inf -inf]
 [ 0    0 -inf]
 [ 0    0    0]]
diagonal=1 means the main diagonal is excluded from the retained upper triangle (it gets zeroed out), so each token can still attend to itself.
Next, we update this mask based on how many tokens have already been processed: torch.hstack prepends start_pos columns of zeros, so the new tokens can freely attend to all cached positions. If start_pos is 0 we get the same (seqlen, seqlen) mask as above, but if start_pos is 1 the mask becomes a (seqlen, start_pos + seqlen) matrix:
[[ 0  0 -inf -inf]
 [ 0  0    0 -inf]
 [ 0  0    0    0]]
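Putting the two steps together in a small sketch (seqlen = 3, start_pos = 1; the values are illustrative):

import torch

seqlen, start_pos = 3, 1   # 3 new tokens, 1 token already in the KV cache

mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)                            # -inf kept only above the diagonal
mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])  # cached positions are never masked

print(mask)
# tensor([[0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])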
Then we run the hidden states through all the layers in a loop.

After the final layer, we normalize the result with self.norm.

The last step is to apply the output linear layer, which maps the final hidden states to logits over the vocabulary (the softmax over these logits is applied later, when sampling the next token). A short sketch of how this forward method is used during generation follows below.
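To see how start_pos is used in practice, here is a greedy decoding sketch. toy_model is a hypothetical stand-in with the same forward signature as the Transformer module; everything here is illustrative and not the actual generation code from the Llama3 repository.

import torch

# Stand-in "model" with the signature (tokens, start_pos) -> logits, so the loop is runnable.
# In practice this would be the Transformer module defined above.
vocab_size = 128
def toy_model(tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
    bsz, seqlen = tokens.shape
    return torch.randn(bsz, seqlen, vocab_size)

prompt_tokens = torch.tensor([[1, 5, 9]])            # (batch=1, seq_len=3)

# Prefill: the whole prompt goes through in one pass with start_pos = 0,
# so seqlen > 1 and the causal mask is built.
logits = toy_model(prompt_tokens, start_pos=0)
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)

generated = [next_token]
# Decode: with the KV cache filled, each step feeds only the newest token
# (seqlen = 1, so mask stays None) and start_pos advances by one.
for step in range(5):
    start_pos = prompt_tokens.shape[1] + step
    logits = toy_model(next_token, start_pos=start_pos)
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated.append(next_token)

print(torch.cat(generated, dim=1))                   # the greedily decoded token ids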
Up next
The Llama3 models are a big milestone in the OpenSource (or OpenWeights, as some like to call it) space. The models not only catch up with the performance of some of the best private models but go on to perform better in many cases. I am confident we will see many fine-tuned variants of Llama3, and I hope this blog series was helpful in decoding the underlying architecture.