Great video as always
Thank you! Appreciate that!
agree
@@siddaarthprasanna Thank you so much!
Thank you for this video! Amazing explanation!
You’re welcome! Happy that you liked it.
Wow! What a beautiful explanation!!!!
Thank you so much!
Thank you for this video! Clearly explained!
Thank you!!
Clearly explained! Highly recommended.
Thank you, Larry!
It's a very good explanation and video in general as well!
Glad it was helpful!
@@jbhuang0604 Love all of the content on this channel! Thank you so much for doing this. I am spreading info about this great channel everywhere :))
Great explanation and animation!
Glad you liked it!
Really amazing video! May I ask what tools you use to create this video?
Thanks! The animations come from PowerPoint. I edit the video with Adobe Premiere Pro.
Awesome!
Thanks!
Combining this video with Umar Jamil's implementation is useful
That’s COOL!
Thanks for making this! The notation is a bit confusing @8:36: if S = Q.K^T and S = {x_1, x_2, ..., x_n}, then x_1, x_2, ... should be column vectors, but here:
m_0 = -inf
...
m_i = max(m_{i-1}, x_i)
they are handled as values, perhaps there's a missing outer loop? q_j where j goes through 1..N (if square matrices).
but then S would be S={x_{1,1}, x_{1,2},...x_{1,N}
x_{2,1}, x_{2,2},...x_{2,N}
....
x_{N,1}, x_{N,2},...x_{N,N}}
essentially I'm confused about whether O_N is a vector or a value.
Thanks again for this content, I really enjoyed it!
Thanks for the question. Yes, in general S is an N x N matrix, where N is the number of tokens.
When explaining the online softmax, we only look at the attention coming from one query vector and all key vectors, so S = q * K^T. Here the query vector is of size 1 x d_k and the key matrix is of size d_k x N, so the "matrix" S is just a 1 x N vector.
O_N is a vector of size 1 x d, where d is the dimension of the value vectors (it's a weighted average of the value vectors, where the weights come from the attention).
We only need to look at one query vector to understand the key idea of online softmax and FlashAttention. We can process multiple query vectors and key/value vectors at the same time in parallel (depending on the size of the on-chip SRAM).
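To make this single-query view concrete, here is a small NumPy sketch of the online softmax accumulation (just illustrative, not the code from the video; the names and the row-major K/V layout are my own choices):

import numpy as np

# Online softmax attention for a single query (illustrative sketch).
# q: (d_k,), K: (N, d_k), V: (N, d_v) -> output o: (d_v,), the 1 x d_v vector O_N above.
def online_softmax_attention(q, K, V):
    d_k = q.shape[0]
    m = -np.inf                   # running max of the scores x_1..x_i
    l = 0.0                       # running sum of exp(x_j - m)
    o = np.zeros(V.shape[1])      # running unnormalized weighted sum of value rows
    for i in range(K.shape[0]):
        x_i = q @ K[i] / np.sqrt(d_k)      # scalar score for key i
        m_new = max(m, x_i)
        alpha = np.exp(m - m_new)          # rescale old accumulators to the new max
        l = l * alpha + np.exp(x_i - m_new)
        o = o * alpha + np.exp(x_i - m_new) * V[i]
        m = m_new
    return o / l                  # softmax-weighted average of the value vectors

# Sanity check against the standard two-pass softmax attention
q, K, V = np.random.randn(8), np.random.randn(5, 8), np.random.randn(5, 4)
s = q @ K.T / np.sqrt(8)
w = np.exp(s - s.max()); w /= w.sum()
assert np.allclose(online_softmax_attention(q, K, V), w @ V)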
@@jbhuang0604 O_N is of shape 1 by d_v, thank you so much for this answer and making this video! You really made it click!
Thanks a lot!
@11:18 not HMB but HBM
Good catch! Clearly I was just making sure people are paying attention. :-p
Very very clear explanation! Thank you Professor, learned a lot! ps: May I ask what software you use to make the animation?
Thank you! It’s mostly morph transition from MS PowerPoint.
This video is so cool!!! May I ask how you make these fantastic slides? Do you use Google Docs, Beamer, or something else?
Thanks! I used PowerPoint. The animation comes from the morph transition.
@@jbhuang0604 Thx a lot!
Thank you for your video. I have a simple question.
The paper explained outer and inner loops, so is the order right at around 10:40?
Yes, good catch! In FlashAttention-1, the Q is in the inner loop and the KV is in the outer loop, so the partial outputs have to be repeatedly written to and read back from the HBM. In FlashAttention-2, they swap the order and put Q in the outer loop and KV in the inner loop, which avoids those repeated writes and parallelizes better (the partial outputs stay in the on-chip SRAM until a query block is finished).
I intentionally use the loop order from FlashAttention-2 to better illustrate how the partial results accumulate into the full output. I think it makes the core concept easier to understand.
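For anyone who wants to see the loop order concretely, here is a rough NumPy sketch of the FlashAttention-2 ordering (query blocks outside, key/value blocks inside); the block sizes and names are illustrative, and the real kernels keep the per-block accumulators in on-chip SRAM rather than plain arrays:

import numpy as np

# Blocked attention in the FlashAttention-2 loop order (illustrative, not the actual kernel).
# Outer loop: query blocks. Inner loop: key/value blocks. The running max m, running sum l,
# and partial output Oi for a query block stay "on chip" until the inner loop finishes,
# so each output block is written out exactly once.
def blocked_attention(Q, K, V, Bq=2, Bkv=2):
    N, d_k = Q.shape
    O = np.empty((N, V.shape[1]))
    for qs in range(0, N, Bq):                    # outer loop over query blocks
        Qi = Q[qs:qs + Bq]
        m = np.full(Qi.shape[0], -np.inf)         # running row-wise max
        l = np.zeros(Qi.shape[0])                 # running row-wise sum of exponentials
        Oi = np.zeros((Qi.shape[0], V.shape[1]))  # running unnormalized output block
        for ks in range(0, N, Bkv):               # inner loop over key/value blocks
            S = Qi @ K[ks:ks + Bkv].T / np.sqrt(d_k)
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])        # unnormalized probabilities for this tile
            alpha = np.exp(m - m_new)             # rescale factor for the old accumulators
            l = l * alpha + P.sum(axis=1)
            Oi = Oi * alpha[:, None] + P @ V[ks:ks + Bkv]
            m = m_new
        O[qs:qs + Bq] = Oi / l[:, None]           # one write per query block ("to HBM")
    return O

# Sanity check against plain softmax attention
Q, K, V = np.random.randn(6, 4), np.random.randn(6, 4), np.random.randn(6, 3)
S = Q @ K.T / np.sqrt(4)
W = np.exp(S - S.max(axis=1, keepdims=True)); W /= W.sum(axis=1, keepdims=True)
assert np.allclose(blocked_attention(Q, K, V), W @ V)

With the FlashAttention-1 order (KV outside, Q inside), the roles of the two loops would be swapped, and m, l, and Oi would have to be re-read and re-written for every key/value block.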
Uncle Roger???
haha sorry, good explanation, cheers
Thank you! Cheers!