
The Fifth Elephant blog

Simrat Hanspal

@simrathanspal

Decoding Llama3: Part 7 - Transformer Block & Module

Submitted Jun 18, 2024

We continue our Decoding Llama3 series with the Transformer block and the Transformer module.

[Figure: Transformer intro]

Introduction

A Transformer consists of multiple transformer blocks stacked on top of each other. Every layer performs the same operations, but each layer learns its own weights, which helps the model capture relationships that go from simple to complex as we move up the stack.
[Figure: Decoder stack]

Each block puts together everything we covered in the earlier parts of this series, as shown in the diagram at the top. So let us jump directly into the code.

Code walkthrough

Transformer block - init


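# Attention, FeedForward, RMSNorm and ModelArgs are defined earlier in model.py
class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads  # per-head embedding size
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,  # FeedForward rescales this internally
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        # Pre-norm: one RMSNorm before attention, one before the FFN
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)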

Let us review the init code:

  1. self.n_heads: number of attention heads in the block
  2. self.dim: model embedding dimension
  3. self.head_dim: embedding size of each attention head (dim // n_heads)
  4. self.attention: the attention module
  5. self.feed_forward: the feed-forward network (FFN) module
  6. self.layer_id: index of this block within the stack
  7. self.attention_norm: RMSNorm applied to the input before attention
  8. self.ffn_norm: RMSNorm applied before the FFN
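To make these sizes concrete, here is a small arithmetic sketch. It assumes FeedForward applies the sizing rule from Meta's reference Llama code (take 2/3 of the 4 * dim passed in, scale by ffn_dim_multiplier, then round up to a multiple of multiple_of). The numbers plugged in (dim=4096, n_heads=32, multiple_of=1024, ffn_dim_multiplier=1.3) are what I believe the Llama3 8B config uses; treat them as illustrative. The helper name ffn_hidden_dim is hypothetical.

```python
from typing import Optional


def ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float]) -> int:
    """Hypothetical helper: FFN sizing rule, assumed to mirror the reference FeedForward."""
    hidden_dim = 4 * dim  # what TransformerBlock passes in
    # The 2/3 shrink keeps SwiGLU's three matrices near the parameter count
    # of a classic two-matrix FFN with a 4 * dim hidden layer
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Round up to the nearest multiple of `multiple_of`
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


# Illustrative values, assumed to match the Llama3 8B config
dim, n_heads = 4096, 32
head_dim = dim // n_heads
hidden = ffn_hidden_dim(dim, multiple_of=1024, ffn_dim_multiplier=1.3)
print(head_dim, hidden)  # 128 14336
```

Rounding up to multiple_of keeps the hidden size friendly to hardware and to sharding across devices.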

Transformer block - forward method

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
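        # Attention sublayer: normalize, attend, add the residual
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        # FFN sublayer: normalize, feed forward, add the residual
        out = h + self.feed_forward(self.ffn_norm(h))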
        return out

Let us review the forward code now

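  1. Inputs to the transformer block are
    a. x, the input embeddings (or the output of the previous block)
    b. start_pos, the position of the first token of x within the full sequence. It is 0 when the prompt is processed and then advances as new tokens are generated; the KV cache uses it to know where to store the new keys and values
    c. freqs_cis, the precomputed rotary positional embedding frequencies for these positions
    d. mask, which prevents the network from attending to future tokens, because an autoregressive model only has access to the previously generated tokens
  2. The two lines in the body of forward follow the execution flow depicted in the Llama architecture diagram
    a. Normalize the input with attention_norm
    b. Compute attention and add the input back (residual connection) to get h
    c. Normalize h with ffn_norm
    d. Compute the FFN output and add h back (residual connection)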
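The pre-norm residual flow of a block can be sketched with NumPy. This is a minimal illustration, not the Llama code: toy_attention and toy_ffn are hypothetical shape-preserving stand-ins for Attention and FeedForward, and the sketch drops start_pos, freqs_cis and mask because the toy attention does not use them.

```python
import numpy as np


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Divide each vector by its root-mean-square, then apply a learned per-dimension gain
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def block_forward(x, attention, feed_forward, attn_weight, ffn_weight):
    """Pre-norm residual flow of one block: normalize -> sublayer -> add residual."""
    h = x + attention(rms_norm(x, attn_weight))      # attention sublayer + residual
    out = h + feed_forward(rms_norm(h, ffn_weight))  # FFN sublayer + residual
    return out


rng = np.random.default_rng(0)
batch, seq_len, dim = 2, 3, 8
x = rng.standard_normal((batch, seq_len, dim))
ones = np.ones(dim)  # RMSNorm gains initialized to 1
W = rng.standard_normal((dim, dim)) * 0.1


def toy_attention(t: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in: a fixed linear map that preserves the shape
    return t @ W


def toy_ffn(t: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in: an elementwise nonlinearity
    return np.tanh(t)


out = block_forward(x, toy_attention, toy_ffn, ones, ones)
print(out.shape)  # (2, 3, 8)
```

Because each sublayer's output is added to its input, the block preserves the (batch_size, seq_len, dim) shape, which is what lets identical blocks be stacked.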

Transformer module - init

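import torch
from torch import nn
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, VocabParallelEmbedding

# ModelArgs, RMSNorm, TransformerBlock and precompute_freqs_cis are defined earlier in model.py

class Transformer(nn.Module):
    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        # Token id -> embedding vector, sharded across model-parallel ranks
        self.tok_embeddings = VocabParallelEmbedding(
            params.vocab_size, params.dim, init_method=lambda x: x
        )

        # Stack of n_layers decoder blocks, each with its own weights
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        # Final normalization, then projection from dim to vocabulary logits
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = ColumnParallelLinear(
            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
        )

        # Rotary frequencies for every position up to 2 * max_seq_len
        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads,
            params.max_seq_len * 2,
            params.rope_theta,
        )
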
  1. We store parameters such as vocab_size and n_layers.
  2. We define the token embedding module, tok_embeddings, which maps token ids to vectors of size dim. VocabParallelEmbedding and ColumnParallelLinear are fairscale's model-parallel versions of an embedding and a linear layer.
  3. We create self.layers as a ModuleList, the container that holds the n_layers blocks of the transformer (decoder-only) architecture.
  4. We loop n_layers times.
  5. In each iteration, we create a TransformerBlock and append it to self.layers.
  6. We define self.norm, the final RMSNorm applied to the output of the last block.
  7. We define self.output, the linear layer that maps the final hidden states to vocabulary logits,
    (batch_size, seq_len, model_dim) to (batch_size, seq_len, vocabulary_size).
  8. We precompute freqs_cis once, because it depends only on the model parameters: the head dimension (dim // n_heads), twice max_seq_len, and rope_theta.
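For reference, here is a NumPy re-sketch of what the precomputation produces, assuming it follows the usual rotary embedding recipe: one frequency theta^(-2i/dim) per pair of dimensions, multiplied by every position and stored as unit complex numbers. The toy sizes are illustrative; rope_theta=500000.0 is the value I believe Llama3 uses.

```python
import numpy as np


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> np.ndarray:
    """NumPy re-sketch of rotary frequency precomputation (assumed rotary recipe)."""
    # One rotation frequency per pair of dimensions: theta^(-2i/dim), i = 0 .. dim/2 - 1
    freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: dim // 2] / dim))
    t = np.arange(end)             # positions 0 .. end-1
    angles = np.outer(t, freqs)    # (end, dim // 2): angle per position and frequency
    return np.exp(1j * angles)     # unit complex numbers e^{i * angle}


# Toy sizes; the model passes dim // n_heads, max_seq_len * 2 and rope_theta
head_dim, max_seq_len, rope_theta = 8, 4, 500000.0
freqs_cis = precompute_freqs_cis(head_dim, max_seq_len * 2, rope_theta)
print(freqs_cis.shape)  # (8, 4)
```

Each row is one position, so forward only needs to slice out the rows for start_pos to start_pos + seqlen - 1.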

Transformer module - forward method

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        _bsz, seqlen = tokens.shape
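        h = self.tok_embeddings(tokens)  # (bsz, seqlen) -> (bsz, seqlen, dim)
        self.freqs_cis = self.freqs_cis.to(h.device)
        # Rotary frequencies for positions start_pos .. start_pos + seqlen - 1
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]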

        mask = None
        if seqlen > 1:
            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)

            mask = torch.triu(mask, diagonal=1)

            # When performing key-value caching, we compute the attention scores
            # only for the new sequence. Thus, the matrix of scores is of size
            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
            # j > cache_len + i, since row i corresponds to token cache_len + i.
            mask = torch.hstack(
                [torch.zeros((seqlen, start_pos), device=tokens.device), mask]
            ).type_as(h)

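        # Run the hidden states through every block in order
        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)  # final RMSNorm, applied once
        output = self.output(h).float()  # logits: (bsz, seqlen, vocab_size)
        return output
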
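  1. Take the token ids as input, with shape (batch_size, seq_len).

  2. Convert them to embeddings using self.tok_embeddings.

  3. Slice freqs_cis to the positions start_pos to start_pos + seq_len - 1, so that only the rotary frequencies for the tokens in this call are used (seq_len is usually much shorter than max_seq_len).

  4. If seq_len is 1, mask stays None: a single new token may attend to every cached token and to itself, so there is nothing to hide. Otherwise, we build the mask.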

  5. mask starts as a (seq_len, seq_len) matrix filled with -inf.
    We use -inf because the mask is added to the raw attention scores, which are then passed through softmax, $\mathrm{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$.
    Since $e^{-\infty} = 0$, every masked score receives an attention weight of 0, so the network cannot attend to future tokens.

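  6. Next, we convert the matrix into an upper triangular matrix using torch.triu
    [[ -inf -inf -inf]
    [ -inf -inf -inf]
    [ -inf -inf -inf]] -> gets converted to

    [[ 0 -inf -inf]
    [ 0 0 -inf]
    [ 0 0 0]]
    diagonal=1 means that only the entries strictly above the main diagonal keep their -inf; the main diagonal and everything below it are set to 0, so each token can still attend to itself and to earlier tokens.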

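  7. Next, we prepend a block of zeros with start_pos columns, one column per token already in the KV cache. Those tokens are all in the past, so they can always be attended to. The result has shape (seq_len, start_pos + seq_len). If start_pos is 0, we get the same mask as above; if start_pos is 1 (with seq_len 3), we get:
    [[ 0 0 -inf -inf]
    [ 0 0 0 -inf]
    [ 0 0 0 0]]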

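  8. We pass h through every TransformerBlock in self.layers, one after the other.

  9. We apply the final RMSNorm (self.norm) once, to the output of the last block. The per-block normalization happens inside each TransformerBlock.

  10. Finally, self.output projects the normalized hidden states to logits over the vocabulary, cast to float32. The softmax that turns these logits into next-token probabilities is applied later, during sampling, not inside forward.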
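The mask steps in items 4 to 7 can be checked with a small NumPy re-sketch. NumPy stands in for PyTorch here (np.triu's k argument plays the role of torch.triu's diagonal), and causal_mask and softmax are hypothetical helper names for illustration. The example uses seqlen=3 and start_pos=1 and adds the mask to uniform raw scores, which shows that masked positions end up with zero attention weight.

```python
import numpy as np


def causal_mask(seqlen: int, start_pos: int) -> np.ndarray:
    """Hypothetical NumPy re-sketch of the mask built in Transformer.forward."""
    mask = np.full((seqlen, seqlen), float("-inf"))
    mask = np.triu(mask, k=1)  # keep -inf strictly above the diagonal, zeros elsewhere
    # Cached positions 0 .. start_pos-1 are all in the past, so they are never masked
    return np.hstack([np.zeros((seqlen, start_pos)), mask])


def softmax(scores: np.ndarray) -> np.ndarray:
    shifted = scores - scores.max(axis=-1, keepdims=True)  # subtract row max for stability
    e = np.exp(shifted)  # exp(-inf) == 0, so masked entries vanish
    return e / e.sum(axis=-1, keepdims=True)


mask = causal_mask(seqlen=3, start_pos=1)
print(mask)  # shape (3, 4): one cached column of zeros, then the causal pattern

scores = np.zeros((3, 4)) + mask  # uniform raw scores, then apply the mask
weights = softmax(scores)
# Row 0 (absolute position 1) splits its weight between position 0 and itself
print(weights[0])
```

Every row keeps at least one 0 entry (the token itself), so the row maximum is finite and the softmax never divides by zero.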

Up next

Llama3 models are big milestones in the Open-Source (or Open-Weights, as some like to call them) space. The models not only catch up with the performance of some of the best private models but go on to perform better in many cases. I am confident we will see many fine-tuned variants of Llama3, and I hope this blog series was helpful in decoding the underlying architecture.


Hosted by

NLP specialist with 14 years of experience building products. Exploring productivity hacks.