Decoding Llama3: An explainer for tinkerers
A not-so-quick 7-part guide to using the Llama3 open source AI model
We are continuing our Decoding Llama3 series with the Feed Forward Network (FFN) and the SiLU activation function. The FFN transforms the token embeddings into an internal embedding which generalizes well.
Let us recall the FFN architecture used in the original Transformer:
FFN(x) = ReLU(x.W1 + b1).W2 + b2
ReLU passes through only values greater than 0; negative values are zeroed out.
[Figure: the ReLU activation. Picture from PyTorch.]
Hence we can rewrite our FFN layers as
FFN(x) = max(0, x.W1 + b1).W2 + b2
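To make the shapes concrete, here is a minimal sketch of this vanilla FFN in PyTorch. VanillaFFN and the 512/2048 sizes are my own illustration, not part of the Llama code.

import torch
import torch.nn as nn

class VanillaFFN(nn.Module):
    """Classic Transformer FFN: expand, apply ReLU, project back."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)  # dim -> hidden_dim, with bias
        self.w2 = nn.Linear(hidden_dim, dim)  # hidden_dim -> dim, with bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # FFN(x) = max(0, x.W1 + b1).W2 + b2
        return self.w2(torch.relu(self.w1(x)))

x = torch.randn(2, 8, 512)             # (batch, seq_len, dim)
print(VanillaFFN(512, 2048)(x).shape)  # torch.Size([2, 8, 512])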
Llama3 uses SwiGLU.
FFN(x) = (Swish(x.W1) * x.W3).W2
Swish is the same as SiLU (Sigmoid Linear Unit) when beta = 1, i.e. SiLU(x) = x * sigmoid(x).
[Picture from Umar Jamil's lecture on Llama.]
[Figure: the SiLU activation. Picture from PyTorch.]
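A quick way to convince yourself of the equivalence (a small check I added, not from the original post): with beta = 1, Swish reduces to x * sigmoid(x), which is exactly what PyTorch's F.silu computes.

import torch
import torch.nn.functional as F

x = torch.linspace(-5, 5, steps=11)

beta = 1.0
swish = x * torch.sigmoid(beta * x)      # Swish(x) = x * sigmoid(beta * x)

print(torch.allclose(swish, F.silu(x)))  # True

With that out of the way, here is the FeedForward module: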
from typing import Optional

import torch.nn as nn
import torch.nn.functional as F
# Tensor-parallel linear layers from fairscale.
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        # Scale down by 2/3 so the three-matrix SwiGLU FFN keeps a parameter
        # count comparable to a two-matrix FFN of size hidden_dim.
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # Round hidden_dim up to the nearest multiple of multiple_of.
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )

    def forward(self, x):
        # FFN(x) = (SiLU(x.W1) * (x.W3)).W2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
The FFN expands the input from model_dim to a larger hidden dimension and then converges back to model_dim to consolidate the learning.

The constructor arguments:
dim: model embedding dimension
hidden_dim: hidden layer dimension
multiple_of: explained below
ffn_dim_multiplier: multiplier explained below

ffn_dim_multiplier is the multiplier value applied to hidden_dim. To round hidden_dim up to an expected size, like a multiple of a large power of 2 such as 256, multiple_of was introduced. The idea here is that it will get you to the closest multiple of multiple_of (rounding up).

For example, if multiple_of is 64 and hidden_dim is 100:
64 * ((100 + 64 - 1) // 64) = 64 * (163 // 64) = 64 * 2 = 128
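As a quick sanity check (a throwaway snippet I wrote, not from the Llama repo), here is that arithmetic on its own. The dim=4096, hidden_dim=4*dim, ffn_dim_multiplier=1.3 and multiple_of=1024 values are the ones commonly reported for the Llama 3 8B config; treat them as assumptions.

def llama_ffn_hidden_dim(hidden_dim, multiple_of, ffn_dim_multiplier=None):
    # Mirrors the arithmetic in FeedForward.__init__ above.
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

# The rounding step alone: 100 rounds up to the next multiple of 64.
print(64 * ((100 + 64 - 1) // 64))                # 128

# Assumed Llama 3 8B settings: hidden_dim passed in as 4 * dim = 4 * 4096.
print(llama_ffn_hidden_dim(4 * 4096, 1024, 1.3))  # 14336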
In the code, the three SwiGLU weight matrices W1, W2 and W3 are self.w1, self.w2 and self.w3.
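If you want to play with the forward pass without fairscale's tensor-parallel layers, here is a single-GPU sketch with plain nn.Linear standing in for ColumnParallelLinear and RowParallelLinear. It is a simplified stand-in, not the original implementation, and the 4096/14336 sizes are only illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate branch (goes through SiLU)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # linear branch
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # project back to model dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # FFN(x) = (Swish(x.W1) * (x.W3)).W2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

x = torch.randn(1, 16, 4096)            # (batch, seq_len, dim)
print(SwiGLUFFN(4096, 14336)(x).shape)  # torch.Size([1, 16, 4096])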
We have covered all the core components of the Transformer Block. Next, we put them together into the machinery called Llama 3 😀
Head over to Decoding Llama3: Part 7 - TransformerBlock