
The Fifth Elephant blog

Vidya Ramakrishnan

@vidya_ramki

Simrat Hanspal

@simrathanspal

Decoding Llama3: Part 2 - Understanding the configuration

Submitted Jun 18, 2024

We continue our Decoding Llama3 series with an overview of the model architecture.

Getting started

model.py contains the complete Llama3 architecture, and we will use this code to understand it.

You can find model.py for Llama3 at this link.

Ten-thousand-foot view

The image below maps the modules in model.py to the components of the Llama3 architecture (Llama architecture drawn by Umar Jamil).

Llama3 architecture

As the figure shows, execution begins with the input: the token embeddings.
The embeddings pass through Nx decoder blocks (the TransformerBlock class). The output of the last block is normalized (RMSNorm) and fed to a Linear layer, which projects the model's internal embedding dimension to the vocabulary size, producing one score (logit) per token. Finally, softmax converts these scores into probabilities, from which the next token is chosen, for example the most probable one.
Don't worry about the details; we will cover everything in depth.
Let us begin with the model configuration: the ModelArgs class.
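To make the last two steps of the overview concrete (the Linear layer and the softmax), here is a minimal pure-Python sketch. It assumes the hidden vector for the last position has already passed through the final normalization and uses toy sizes (dim = 3, vocab_size = 4) with made-up numbers rather than real Llama3 weights. The function names are hypothetical; model.py does this with PyTorch tensors.

```python
import math
from typing import List


def linear(hidden: List[float], weight: List[List[float]]) -> List[float]:
    """Bias-free projection from dim to vocab_size; weight has shape (vocab_size, dim)."""
    return [sum(w * h for w, h in zip(row, hidden)) for row in weight]


def softmax(logits: List[float]) -> List[float]:
    """Turn logits into probabilities (subtracting the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_next_token(hidden: List[float], weight: List[List[float]]) -> int:
    """Return the index of the most probable token."""
    probs = softmax(linear(hidden, weight))
    return max(range(len(probs)), key=lambda i: probs[i])


# Toy numbers: dim = 3, vocab_size = 4 (Llama3 itself uses 4096 and 128,256).
hidden = [0.5, -1.0, 2.0]
weight = [
    [0.1, 0.2, 0.3],
    [1.0, 0.0, 0.5],
    [-0.5, 0.3, 0.1],
    [0.0, -0.4, 0.9],
]
print(linear(hidden, weight))             # one score (logit) per token
print(greedy_next_token(hidden, weight))  # 3
```

Always taking the maximum is greedy decoding; Llama3's reference generation code can also sample from these probabilities using a temperature and top-p.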

Code walkthrough

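from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    dim: int = 4096                             # embedding (model) dimension
    n_layers: int = 32                          # number of TransformerBlocks (Nx)
    n_heads: int = 32                           # attention heads for queries
    n_kv_heads: Optional[int] = None            # key/value heads; None means same as n_heads
    vocab_size: int = -1                        # placeholder; must be set (e.g. from the tokenizer)
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None  # optional scale for the FFN hidden size
    norm_eps: float = 1e-5                      # epsilon used in RMSNorm
    rope_theta: float = 500000                  # base frequency for rotary embeddings

    max_batch_size: int = 32                    # also sizes the KV cache
    max_seq_len: int = 2048                     # default context length; usually overridden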

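The ModelArgs class defines all the configurable parameters of the Llama3 model.
An instance of this class is passed to the different blocks. It is declared with the @dataclass decorator, which generates __init__ (along with __repr__ and __eq__) from the annotated fields, so any default above can be overridden when the instance is created. In practice, the values come from the checkpoint's params.json, plus the max_seq_len and max_batch_size chosen when the model is loaded.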
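A minimal sketch of how a @dataclass config like ModelArgs is used: an instance is created, and any field can be overridden at construction time while the others keep their defaults. The params dictionary below is an illustrative stand-in for values loaded from a checkpoint, not the exact loading code.

```python
from dataclasses import dataclass, asdict
from typing import Optional


@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    rope_theta: float = 500000

    max_batch_size: int = 32
    max_seq_len: int = 2048


# Illustrative overrides, standing in for values read from a checkpoint's params.json.
params = {"n_kv_heads": 8, "vocab_size": 128256}
args = ModelArgs(max_seq_len=8192, max_batch_size=4, **params)

print(args.max_seq_len, args.n_kv_heads)  # 8192 8
print(ModelArgs().max_seq_len)            # 2048 (a fresh instance keeps the default)
print(asdict(args)["dim"])                # 4096 (not overridden)
```

Because __eq__ is generated too, two configs with the same field values compare equal, which is handy when checking that a loaded config matches expectations.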

Let us review the configurable parameters. Don't worry if they don't all make sense yet; we will go through each one in detail in its own section. Think of this as a glossary:

  1. dim: Model dimension, i.e. the size of the embedding vector the model uses for each token (default 4096).

  2. n_layers: Number of decoder layers, the Nx TransformerBlocks in the figure (default 32).

  3. n_heads: Number of attention heads for the queries (default 32). Unlike the original Transformer, where keys and values have as many heads as queries, Llama3 can use fewer key/value heads (grouped-query attention). This reduces the memory consumed by the KV cache.

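  4. n_kv_heads: Number of key/value heads, used when it should differ from the number of query heads. When left as None, it falls back to n_heads (standard multi-head attention); otherwise n_heads must be divisible by it, so that each KV head is shared by a group of n_heads // n_kv_heads query heads.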

  5. vocab_size: Size of the vocabulary. It defaults to -1, a placeholder meaning it must be set before the model is built, typically to match the tokenizer (128,256 tokens for the Llama3 tokenizer). Keeping it configurable is useful when the vocabulary differs or is extended, for example to support other languages.

  6. multiple_of: The hidden layer size of the feed-forward (SwiGLU) network is rounded up to a multiple of this value (default 256). As the code comment says, the goal is a multiple of a large power of 2, which keeps the matrix shapes efficient on hardware.

  7. ffn_dim_multiplier: Optional factor that scales the hidden dimension of the feed-forward network (the FeedForward module). Because it can take arbitrary values, such as 1.3, the scaled size is then rounded up using multiple_of, so the final hidden size is still a multiple of a large power of 2.

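  8. norm_eps: Epsilon used in normalization (RMSNorm). It is added to the mean of the squared activations before taking the square root, which avoids division by zero when the input is all (or nearly all) zeros.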

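  9. rope_theta: Base frequency θ used to compute the rotary positional embeddings (RoPE). Llama3 sets it to 500,000, up from the 10,000 used in Llama2, which makes the slowest rotation frequencies even slower. More on rotary embeddings later in the series.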

  10. max_batch_size: Maximum number of sequences processed together. Loading the whole model into memory is expensive, so processing only one sequence at a time would underuse it; batching parallelises execution over several sequences, bounded by max_batch_size. It also sets the batch dimension of the pre-allocated KV cache.

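  11. max_seq_len: Maximum context length, in tokens. The 2048 in ModelArgs is only a default: the value actually used is passed in when the model is built, and it sizes the KV cache and the precomputed rotary frequencies. That is how the Llama3 model card can list an 8K-token context even for the 8B model (the model was also trained at that length). RoPE makes this easy to configure because positions are encoded by rotations computed from rope_theta rather than by a learned position table of fixed size.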
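The multiple_of and ffn_dim_multiplier entries are easiest to understand with numbers. This sketch follows, to our reading, the hidden-size arithmetic in the FeedForward module of model.py; the function name is hypothetical. The second call assumes the Llama3 8B checkpoint's settings of multiple_of=1024 and ffn_dim_multiplier=1.3.

```python
from typing import Optional


def ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float] = None) -> int:
    """Hidden size of the SwiGLU feed-forward layer, derived from the config."""
    hidden_dim = 4 * dim                    # classic Transformer FFN width
    hidden_dim = int(2 * hidden_dim / 3)    # SwiGLU has 3 weight matrices, so shrink by 2/3
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)  # optional custom scaling
    # Round UP to the nearest multiple of multiple_of.
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


print(ffn_hidden_dim(4096, 256))        # 11008 (ModelArgs defaults)
print(ffn_hidden_dim(4096, 1024, 1.3))  # 14336 (assumed Llama3 8B settings)
```

Without the rounding step, the 8B case would end at 14198, an awkward size; rounding up to a multiple of 1024 gives 14336.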
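The n_heads, n_kv_heads, max_batch_size and max_seq_len entries all meet in the KV cache. A rough sketch, assuming a single device (no model parallelism), a 2-byte (bf16) cache, the default dim=4096, n_heads=32 and n_layers=32, and n_kv_heads=8 with an 8K context chosen for illustration. The function names are hypothetical; the cache shape follows, to our reading, the per-layer cache_k and cache_v tensors that model.py pre-allocates.

```python
def kv_head_for_query_head(q_head: int, n_heads: int, n_kv_heads: int) -> int:
    """Return the KV head shared by a given query head under grouped-query attention."""
    if n_heads % n_kv_heads != 0:
        raise ValueError("n_heads must be divisible by n_kv_heads")
    n_rep = n_heads // n_kv_heads  # query heads per KV head
    return q_head // n_rep


def kv_cache_bytes(n_layers: int, n_kv_heads: int, head_dim: int,
                   max_batch_size: int, max_seq_len: int,
                   bytes_per_value: int = 2) -> int:
    """Total size of the K and V caches across all layers.

    Each layer stores cache_k and cache_v of shape
    (max_batch_size, max_seq_len, n_kv_heads, head_dim).
    """
    per_tensor = max_batch_size * max_seq_len * n_kv_heads * head_dim * bytes_per_value
    return n_layers * 2 * per_tensor  # 2 = keys + values


dim, n_heads, n_layers = 4096, 32, 32
head_dim = dim // n_heads  # 128

print([kv_head_for_query_head(q, n_heads, 8) for q in range(8)])
# [0, 0, 0, 0, 1, 1, 1, 1]

GiB = 2 ** 30
mha = kv_cache_bytes(n_layers, n_kv_heads=32, head_dim=head_dim, max_batch_size=1, max_seq_len=8192)
gqa = kv_cache_bytes(n_layers, n_kv_heads=8, head_dim=head_dim, max_batch_size=1, max_seq_len=8192)
print(mha / GiB, gqa / GiB)  # 4.0 1.0
```

Under these assumptions, sharing each KV head across 4 query heads cuts the cache from 4 GiB to 1 GiB per sequence, and the cache grows linearly with max_batch_size.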
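To see what norm_eps guards against, here is a minimal pure-Python sketch of RMS normalization as used in Llama: x divided by the square root of (mean of x squared plus eps), times a learnable weight. The function name and the list-based representation are assumptions for illustration; model.py implements this with PyTorch tensors in an RMSNorm module.

```python
import math
from typing import List, Optional


def rms_norm(x: List[float], eps: float = 1e-5, weight: Optional[List[float]] = None) -> List[float]:
    """Normalize x by its root mean square; eps keeps the denominator non-zero."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)           # would divide by zero if eps were 0 and x all zeros
    w = weight if weight is not None else [1.0] * len(x)  # learnable gain, 1.0 at init
    return [v * scale * wi for v, wi in zip(x, w)]


print(rms_norm([3.0, 4.0]))       # output has root mean square close to 1
print(rms_norm([0.0, 0.0, 0.0]))  # [0.0, 0.0, 0.0] instead of a ZeroDivisionError
```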
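The role of rope_theta shows up in how RoPE's rotation frequencies are computed: pair i of a head's dimensions rotates with frequency theta^(-2i / head_dim), where head_dim = dim // n_heads = 128 with the defaults. A minimal sketch with hypothetical function names, comparing Llama2's theta of 10,000 with Llama3's 500,000; model.py computes the same frequencies with PyTorch and stores the rotations as complex numbers.

```python
import math
from typing import List, Tuple


def rope_frequencies(head_dim: int, theta: float) -> List[float]:
    """One rotation frequency per pair of dimensions: theta ** (-2i / head_dim)."""
    return [1.0 / (theta ** (j / head_dim)) for j in range(0, head_dim, 2)]


def rotate_pair(x0: float, x1: float, position: int, freq: float) -> Tuple[float, float]:
    """Rotate one (x0, x1) pair by the angle position * freq."""
    angle = position * freq
    c, s = math.cos(angle), math.sin(angle)
    return x0 * c - x1 * s, x0 * s + x1 * c


head_dim = 4096 // 32  # 128
llama2 = rope_frequencies(head_dim, 10_000.0)
llama3 = rope_frequencies(head_dim, 500_000.0)

print(len(llama3))  # 64 (one frequency per pair of dimensions)
# Longest wavelength, in positions, is 2*pi / slowest frequency; a larger theta makes it longer.
print(2 * math.pi / llama2[-1] < 2 * math.pi / llama3[-1])  # True

# A rotation changes a pair's direction, not its length.
x0, x1 = rotate_pair(1.0, 0.0, position=10, freq=llama3[0])
print(round(math.hypot(x0, x1), 6))  # 1.0
```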

Up next

Now that we understand the configuration of the Llama3 architecture, let us go through each module in detail, starting with Decoding Llama3: Part 3 - Normalisation.

Hosted by

NLP specialist with 14 years of experience building products. Exploring productivity hacks.