@ 22:03 - Looks like you meant to have high latency on MHA and low latency on MQA? Great video!
That is correct! Thanks for pointing it out.
Wow, I searched so much for a proper explanation for how a transformer works. Thank you for this simplified, to-the-point explanation. Hard to find such simple and to-the-point explanations these days!
You even covered where the issue arises with fitting the operations on a GPU, how that relates to sequence length, and even key-value caches! I struggled to get this point even while watching a proper university course on HW-SW co-design for LLMs. Such model-efficiency points are really valuable!
You also kept the evolution of "Attention" very simple; I was expecting that topic to go over my head with information overload before I found your video.
I love how knowledge-dense this video is without compromising on simplicity and ease of understanding.
Thanks for the feedback. This is exactly the viewer experience I was aiming to go for, and I am so glad you found it helpful. 😊
I'm on-boarding onto a GenAI team at Meta and have been using your videos to learn about LLMs. They are incredibly helpful and well done, thank you!
That is awesome! Good luck at your job! Glad the videos are helping. :)
Subscribed immediately. You either have a natural knack for pedagogy, or have learned some great teaching & active learning techniques. In either case, much appreciated! This is a great compliment to, and reinforcement of 3B1B's machine learning series.
Such an amazing explanation! So simple and intuitive, and walking through the changes made it even easier to understand. Thanks a lot!
I’m glad you found it helpful! 😊
Excellent content as always.
8:17 - I understand why "queries" and "keys" are needed - they essentially "say" how similar words are to each other, and this can be interpreted as "understanding the context" in the attention block.
But "values" are a bit unclear to me. To me they seem like something akin to "normalization" applied after obtaining the attention mask (the result of multiplying "queries" and "keys").
The way I like to think about it is: the attention mask just tells us how much each query embedding pays attention to each key embedding. If the queries and keys are [N x d] matrices (where N = number of tokens in the sequence, d = size of the embedding dimension), multiplying Q with the transpose of K gives us the N x N attention mask.
From the attention mask, we must now form a new representation of "what the contextual embedding is for each token". This is where the value embeddings are used. By multiplying the attention mask with the value embedding matrix, you combine information from the values to form the contextual embedding for each token. I.e. if the value matrix is of shape N x d, multiplying the N x N attention mask with it once again gives you back an N x d matrix of contextual embeddings.
So the attention mask can only tell you how much attention one token pays to another - it is just a probability distribution. The multiplication with the value matrix aggregates the value embeddings together to form the new contextual embedding. Step 1 is calculating how much attention to pay, Step 2 is aggregating information according to the calculated attention scores. Hope that makes sense.
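If it helps, here is a minimal PyTorch sketch of those two steps (the names and sizes are just for illustration, not code from the video):

import torch
import torch.nn.functional as F

N, d = 5, 16                       # 5 tokens, embedding size 16
Q = torch.randn(N, d)              # query embeddings, one row per token
K = torch.randn(N, d)              # key embeddings
V = torch.randn(N, d)              # value embeddings

# Step 1: how much attention each token pays to every other token
scores = Q @ K.T / d ** 0.5        # [N x N] scaled dot-product similarities
attn = F.softmax(scores, dim=-1)   # [N x N] attention mask, each row sums to 1

# Step 2: aggregate the value embeddings using those attention weights
context = attn @ V                 # [N x d] contextual embedding for each token

The last line is exactly the aggregation described above: each output row is a weighted average of the value rows.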
@@avb_fj Yes, thanks for the explanation, it helped. Of course I need to look at a specific example (it seems to me you will need a matrix transposition - yeah, that came up later in the video), but the general sense is clear.
In essence, first we get a matrix that contains information about the "similarity" of the vectors, a.k.a. the attention map.
Then this matrix (the attention map), which has the inconvenient dimension of n x n, is combined with the value matrix using matrix multiplication, and we end up with a better way of processing tokens in the transformer (we have n x d as input, and we once again get n x d out, but this time it carries info about the relationships between tokens).
I won't raise the question of whether it is possible to do without values; even if it is, apparently it is not used, either because it is inconvenient or because the key-query-value way is simply technically simpler and better.
@ Earlier implementations of attention, before transformers, often reused the query embedding itself as a replacement for values. Look up Bahdanau attention, which was used for language translation / seq2seq tasks. I also have a 3-part series on that, and a coding tutorial (you will find them in the description).
You are awesome man! Thanks a lot. Really helpful.
Your best video on attention yet!
Excellent explanations, thanks!!
At 14:50 - did you mean splitting the Wq, Wk, Wv matrices into heads instead of splitting the input embeddings? Great video as usual btw!
So this is an implementation detail. Let’s say your input word+pos embedding size is 512 and num-attention-heads is 8 (both are hyperparams). You can do one of the two things below:
1. Project to K, Q, V embedding vectors of size 512 through Wq, Wk, Wv… and then split each vector up into 8 sections of 64 dims. Run attention on each and concat them to get a 512-dim context vector.
2. Write it as 8 different sets of Wq, Wk, Wv that convert the 512-dim input into 64-dimensional vectors for each head. Run attention and concat them.
In practice, both approaches are the same since they generate 8 64-dimensional vectors for your MHA calculation.
There is one other approach, but it is never taken in practice. You can train 8 different Wq, Wk, Wv that project the 512 input dims to 512 output dims, run attention on each 512-dimensional embedding, and then ADD them (instead of concatenating) to get your final context vector. This will also generate a 512-dimensional context vector with MHA… but this approach is never taken because it makes the number of params in your network scale with the number of attention heads. Meaning if you increase the number of attention heads from 8 to 16, you will now be generating 16 context vectors of size 512. The splitting method avoids this because the total number of params and FLOPs stays constant and independent of the number of attention heads.
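If it helps, here is a rough PyTorch sketch of approach 1, using the 512/8 numbers above (a toy example, not the exact code of any particular library):

import torch
import torch.nn as nn

d_model, n_heads = 512, 8
d_head = d_model // n_heads                    # 64 dims per head
n_tokens = 10

Wq = nn.Linear(d_model, d_model, bias=False)   # one big projection for Q
Wk = nn.Linear(d_model, d_model, bias=False)   # ... for K
Wv = nn.Linear(d_model, d_model, bias=False)   # ... for V

x = torch.randn(n_tokens, d_model)             # word+pos embeddings

# project once, then split into 8 heads of size 64
q = Wq(x).view(n_tokens, n_heads, d_head).transpose(0, 1)   # [8 x 10 x 64]
k = Wk(x).view(n_tokens, n_heads, d_head).transpose(0, 1)
v = Wv(x).view(n_tokens, n_heads, d_head).transpose(0, 1)

# attention per head, then concat back to 512 dims
scores = q @ k.transpose(-2, -1) / d_head ** 0.5                  # [8 x 10 x 10]
attn = torch.softmax(scores, dim=-1)
context = (attn @ v).transpose(0, 1).reshape(n_tokens, d_model)   # [10 x 512]

Approach 2 would just store the same parameters as 8 separate 512-to-64 matrices. Real implementations usually also apply one final output projection after the concat.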
Hope that makes sense. Feel free to ask if something wasn’t clear. Your question was very practical and deep, so I went for this long answer.
@@avb_fj That part I understand, yes. I think I wasn't clear in my question; I was talking about what you said in the video, not the slides. You said "the input embeddings are split into multiple heads", which I took to mean the word+pos embeddings are split. I thought it should be "the key, query, value weight matrices are split into heads", no?
Btw the extra approach of ADDING instead of concatenating is very informative, I didn't know that people also did that!
Great video, a lot of quality content!
I have a question regarding multi-headed attention. How does the varying causal attention work in the individual heads? Meaning how do you determine one head to focus on semantics and the other to focus on grammar etc.?
By the magic of Deep Learning and optimization, the network tries to learn these behaviours from the data itself. We do not impose any additional training guidance to make the network focus on semantics/grammar etc. We just build an architecture (i.e. MHA) that could support such behaviour, and then throw a LOT of data at the model during training. The network tries to fit the data as best it can for next-word prediction, and empirically this usually results in different heads learning different modes of attention.
Check out: www.comet.com/site/blog/explainable-ai-for-transformers/
The example in that website talks about a different transformer model called BERT which is not trained on next word prediction, but the ideas remain the same.
@@avb_fj But if I am not mistaken, if all the heads have the same initial weights and the gradients are calculated with the same loss function, shouldn't the updates then be the same for all the heads' weights? I must be missing something here.
@@kr8432 So, your intuition is correct based on your initial assumption. However, "all the heads having the same weights" is not exactly true.
Let's say you have 8 heads each of embedding size 64. This means that from your original input embeddings, you generate large K,Q,V embedding vectors of size (8x64=) 512. This embedding vector now gets split up into 8 parts and each of them go through their own attention heads.
This means that the computation graph for each of the K, Q, V slices in each attention head is different, since each of the 512 original values has its own computation graph (and its own slice of the projection weights, which are initialized randomly rather than identically). So... weights are not shared.
There are also additional things, like Layer Normalization and residual connections, that standard implementations wrap around MHA, which I didn't include in the video.
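If it helps, here is a tiny toy sketch (made-up sizes matching the 8x64 example above) showing that each head touches its own slice of the big projection matrix:

import torch

d_model, d_head = 512, 64
Wq = torch.randn(d_model, d_model, requires_grad=True)   # one big query projection, randomly initialized
x = torch.randn(1, d_model)                              # a single token embedding

q = x @ Wq                          # [1 x 512] query vector
head0 = q[:, :d_head]               # the 64-dim slice that goes to head 0
head0.sum().backward()              # backprop something that depends only on head 0

print(Wq.grad[:, :d_head].abs().sum())   # non-zero: head 0's columns of Wq get gradient
print(Wq.grad[:, d_head:].abs().sum())   # zero: the other heads' columns are untouched

So each head effectively has its own parameters and its own gradients, even though they live inside one big matrix.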
I've understood all of this.. thank you for making this video. Keep making this type of video, bro..🙂
15:40 so what actually is a head? Is it the embedding vector divided evenly into smaller vectors?
Correct, a common implementation of multi-head attention is to basically slice up the K/Q/V embedding vectors into multiple smaller ones. Each of these "children vectors" goes through its own attention computation graph to generate a context vector, and finally these are concatenated together to output the final combined context. By separating out the vectors like that, MHA allows you to learn several different attention maps instead of just one, as explained in the video. These processes are what we call "heads".
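To make the slicing concrete (a toy sketch, sizes are just for illustration):

import torch

q = torch.randn(10, 512)      # query embeddings for 10 tokens
heads = q.view(10, 8, 64)     # each 512-dim vector split evenly into 8 "children vectors" of size 64

Each heads[:, i, :] then goes through its own attention computation, and the 8 resulting 64-dim context vectors are concatenated back into a 512-dim one.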
Great Video
Awesome video
Great video man, one question though: is the input not token embeddings instead of word embeddings?
Later on in the video you do mention token embeddings instead of word embeddings, saw it after the comment :)
Yeah… the term word embeddings is less used today since most language models have moved on from pure word-based tokenizers and instead use subwords, byte-pair encodings, character-level encoders etc. Token embeddings would be the more correct/universal terminology, but I used word embeddings in the earlier parts of the video so newer learners can grasp the concept more easily.
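For anyone curious, here is a quick way to see subword tokenization in action (this assumes the tiktoken library is installed; any BPE tokenizer would illustrate the same idea):

import tiktoken

enc = tiktoken.get_encoding("gpt2")        # GPT-2's byte-pair encoder
ids = enc.encode("untranslatability")      # one word in, a list of token ids out
print([enc.decode([i]) for i in ids])      # several subword pieces (the exact split depends on the vocab)

Each of those ids is what gets mapped to a token embedding, which is why "token embeddings" is the more general term.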
@avb_fj Yes, I thought so :)
Meta just solved this with the Byte Transformers lol
This is the next video I’m working on. Assuming you are talking about the Byte Latent Transformers, the main innovation there is smarter tokenization strategies. So none of the content in this video is outdated or anything. Besides, nothing is truly "solved" yet in AI. 😊