How I Finally Understood Self-Attention (With PyTorch)

  • Published 6 Jan 2025

COMMENTS • 57

  • @BitsOfChris · 17 days ago

    Thank you all for the feedback on this video!
    I just want to highlight a few things for transparency and completeness.
    I deliberately chose to simplify certain aspects of self-attention in this video to focus on conceptual clarity and make it approachable for beginners.
    For example:
    - I didn’t dive deeply into the query, key, and value matrices
    - I don't discuss causal masking (which ensures that when a model predicts the next word in a sentence, it only looks at the words that came before it and ignores anything that comes after). In fact, I do the opposite of this just to illustrate how context changes meaning.
    Going forward I will be sure to include these disclaimers and instructional shortcuts in the video itself rather than afterwards here as a comment.
    Thank you all for the feedback, it's been incredibly humbling and motivating to see :)
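The two simplifications named above, the query/key/value matrices and causal masking, can be seen together in a minimal single-head sketch. This is an illustrative PyTorch sketch, not the code from the video: the class name SelfAttention, the sizes, and the random input are assumptions. With causal=True each token only attends to itself and earlier tokens; with causal=False (closer to what the video does) every token can see the whole sentence.

```python
import math

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention (illustrative sketch)."""

    def __init__(self, d_model: int, causal: bool = False):
        super().__init__()
        # Learned projections that turn each token embedding into a query, key, and value
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.causal = causal

    def forward(self, x: torch.Tensor):
        # x: (batch, seq_len, d_model)
        q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)
        # Similarity of every query with every key, scaled by sqrt(d_k)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        if self.causal:
            # Hide future positions: token i may only attend to tokens 0..i
            seq_len = x.size(1)
            future = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
            scores = scores.masked_fill(future, float("-inf"))
        weights = torch.softmax(scores, dim=-1)  # each row sums to 1
        return weights @ v, weights


torch.manual_seed(0)
x = torch.randn(1, 4, 8)  # one "sentence" of 4 tokens with 8-dim embeddings
out, weights = SelfAttention(d_model=8, causal=True)(x)
print(out.shape)   # torch.Size([1, 4, 8])
print(weights[0])  # entries above the diagonal are all zero
```

PyTorch also ships a fused version, torch.nn.functional.scaled_dot_product_attention, which accepts is_causal=True; the hand-written version just makes each step visible.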

  • @SkegAudio · 19 days ago · +2

    It's paramount to have a great understanding of word2vec (a method for learning word embedding vectors), but even more important is understanding n-grams, to grasp why word embeddings were such a significant advancement in NLP.

    • @BitsOfChris · 18 days ago

      Great point about the progression from n-grams to word embeddings!
      While my work focuses more on time series data than NLP, these concepts have interesting parallels, like how we handle sequential dependencies in financial or climate data compared to text. I'd love to hear your thoughts: how might these foundational NLP concepts apply to non-text sequences?
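To make the n-gram point concrete: a count-based bigram model assigns zero probability to any word pair it has never seen, while embeddings place related words near each other so that similarity carries over. This is a toy sketch under loud assumptions: the corpus is made up, and the three-dimensional vectors are hand-picked values, not real word2vec output.

```python
import math
from collections import Counter

# Tiny made-up corpus, for illustration only
corpus = "the cat sat on the mat . the dog ran on the grass .".split()

# Bigram model: P(next | prev) = count(prev, next) / count(prev)
bigrams = Counter(zip(corpus, corpus[1:]))
unigrams = Counter(corpus[:-1])


def bigram_prob(prev: str, nxt: str) -> float:
    return bigrams[(prev, nxt)] / unigrams[prev] if unigrams[prev] else 0.0


print(bigram_prob("cat", "sat"))  # 1.0: seen in the corpus
print(bigram_prob("dog", "sat"))  # 0.0: never seen, even though dogs can sit

# Hand-picked toy vectors (NOT learned): related words get nearby vectors
emb = {"cat": [0.9, 0.8, 0.1], "dog": [0.85, 0.75, 0.2], "mat": [0.1, 0.2, 0.9]}


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


print(round(cosine(emb["cat"], emb["dog"]), 3))  # close to 1
print(round(cosine(emb["cat"], emb["mat"]), 3))  # much smaller
```

A real embedding model learns such vectors from co-occurrence statistics, which is what lets knowledge about "cat sat" transfer to "dog sat".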

  • @adeadetayo · 23 days ago · +1

    Better explanation than most, keep up the good work.
    There is no requirement for math to be boring or inscrutable.

    • @BitsOfChris · 23 days ago

      @adeadetayo Really appreciate that, thank you for the compliment!

  • @sapdalf · 22 days ago

    I think this is a very good, basic explanation. It's quite illustrative. I believe it's an excellent introduction to understanding how self-attention works in transformers and how it's implemented in large language models.

    • @BitsOfChris · 21 days ago

      Thank you for the kind words :)

  • @coastofkonkan · 6 days ago

    Great. Please use examples like "How to make Thai curry?", so the audience knows "curry", then "Thai", and then where to focus and where to pay attention.

  • @thenextension9160 · 15 days ago

    Subbed. Keep it coming.

  • @قيمنقعبود-ب2ل · 15 days ago · +1

    Excellent

  • @ryparks · 22 days ago

    These whiteboard-style videos are really helpful. Keep it up! You got a subscriber in me, and I look forward to seeing the channel grow!

    • @BitsOfChris · 22 days ago

      Thank you! I really appreciate hearing that :).
      It's helpful to hear any sort of feedback, please keep it coming!

  • @FinanceLogic · 10 days ago

    I agree with you that you seem dialed in. Nice video.

  • @mshonle · 22 days ago

    Great explainer! As for other videos I’d be interested in… what is the deal with positional encoding, specifically the current state of the art; how does a text embedding actually guide the diffusion process in image generation; and, how is there even a gradient that can be useful in training these attention matrices.

    • @BitsOfChris · 21 days ago · +1

      Thank you for the ideas! Positional encoding is something I need to go deeper on myself.
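For reference, this is the fixed sinusoidal encoding from the original Transformer paper, where even dimensions use sin(pos / 10000^(2i/d_model)) and odd dimensions use the matching cos. It is not the current state of the art the question asks about: many recent LLMs use other schemes such as rotary position embeddings (RoPE), which this sketch does not cover. The function name and sizes are illustrative.

```python
import math

import torch


def sinusoidal_positional_encoding(seq_len: int, d_model: int) -> torch.Tensor:
    """Fixed sin/cos position encodings of shape (seq_len, d_model); d_model must be even."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (seq_len, 1)
    # 1 / 10000^(2i / d_model) for each pair of dimensions i
    div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return pe


pe = sinusoidal_positional_encoding(seq_len=10, d_model=16)
token_embeddings = torch.randn(10, 16)
# Position information is injected by adding the encoding to the token embeddings
x = token_embeddings + pe
print(pe[0, :4])  # position 0: sin(0) = 0 and cos(0) = 1 -> tensor([0., 1., 0., 1.])
```

Because each position gets a distinct pattern of values, the same word at two different positions ends up with two different input vectors.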

  • @coopernik · 22 days ago

    Subbed. I really like how you presented the topic. The software you use is great for breaking ideas down. I would have loved it if you had gone through the paper at the same time, breaking down the complicated equations into what "they're essentially saying".

    • @BitsOfChris · 21 days ago · +1

      Thanks for the feedback :)
      I like that idea of breaking a paper down section by section like that. Thanks for the suggestion, I'll try that one.
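As an example of that section-by-section approach, here is the central equation of "Attention Is All You Need" (Vaswani et al., 2017) with a plain-English reading of each piece. The annotations are a beginner-oriented paraphrase, not the paper's wording; X stands for the matrix of token embeddings, one row per token.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V,
\qquad Q = X W^{Q}, \quad K = X W^{K}, \quad V = X W^{V}
```

- X W^Q, X W^K, X W^V: each token is projected into a query (what it is looking for), a key (what it offers to others), and a value (the information it passes on).
- Q K^T: every query is compared with every key, giving a relevance score for each pair of tokens.
- Division by sqrt(d_k): keeps the scores from growing with the key dimension d_k, so the softmax does not saturate.
- softmax: turns each row of scores into weights that sum to 1.
- Multiplying by V: each token's output is a weighted average of all the value vectors.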

  • @yassinebahou4088 · 21 days ago

    Very good explanation

  • @AB-wf8ek · 21 days ago

    Thanks for the explanation. I've been thinking about how to better describe LLMs in general, and one angle I came up with is that it's more like a language calculator.
    Similar to how you wouldn't say a calculator understands arithmetic, it does arithmetic - LLMs don't understand language, they do language.
    I don't know if others would agree with that or not.

    • @BitsOfChris · 21 days ago

      I think that's an excellent analogy. I've heard that used a few times as well. And really, LLMs are just a tool.
      My feeling is that most people who hear "AI" are really picturing the Hollywood version of AGI that is there to replace them. This, I think, turns a lot of people off from the whole topic in the first place.
      It's this school of thought that complained that ChatGPT couldn't count the number of "r"s in the word strawberry, which IMO completely misses the point.
      LLMs are just tools.
      Tools that are useful when used correctly.

    • @AB-wf8ek · 21 days ago

      @BitsOfChris Nice, really glad to get your feedback. Thanks again!

  • @michael_gaio · 20 days ago · +1

    Subbed.
    What tool are you using for your presentation?

    • @BitsOfChris · 19 days ago · +2

      Thanks :)
      I'm using Descript to record the screen and Excalidraw to make the visuals.

  • @timmygilbert4102 · 17 days ago

    I'm still bummed out that people aren't able to draw connections with old parsers and chatbots. I strongly feel that knowing how the Stanford Parser or ChatScript works gives great insight into how LLMs work. LLMs would feel a lot less black-boxy, because they improve on those systems rather than exactly replacing them.

    • @BitsOfChris · 17 days ago

      Hey, thanks for the suggestion. I've never heard of the Stanford Parser. Is this (nlp.stanford.edu/software/lex-parser.shtml) what you're referring to?
      I agree, LLMs and neural networks in general are very black-boxy. What do you mean by "they improve and do not replace"?

    • @timmygilbert4102 · 17 days ago

      @BitsOfChris It's a bit long, so I'll go through a very simplistic illustration. Neural networks are graphs, which imposes constraints on how they process data. So the question is: given an immediately observable property of LLMs, such as sequentiality, what shape should a neural DAG have to implement one instance of a sequence, purely as a DAG? You will find that a neuron can't do it, because its operations are commutative: order doesn't matter to it, yet order is needed for sequence-dependent operations like those in language. Manually constructing a sequence in a DAG that respects ANN architecture reveals that you need layers to encode even basic sequences.
      If we turn to neurons, instead of treating them as magic statistical pixie fairies, we can break them down into atomic operations and reason about them. A neuron takes a dot product with the input vector, but that's not really saying much. Let's go more basic: it picks an input and multiplies it by a weight. That's a mask, because (in this simplified picture) the weight is in the range 0-1, so this first operation filters out unneeded data. Then we sum the results (let's call it the "decibel" level), and then, with the bias, we threshold through the activation function. Sounds trivial, so what's the big deal? It's about shifting the mental image to make it clearer: masking is equivalent to a bag-of-words solution, aka set detection. The sum and activation then implement logic on that set, a generalization of binary logic to sets: OR is ANY, with a low bias (any input will trigger it); AND is ALL (all inputs are needed); and the spectrum between the two is SOME.
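The mask-sum-threshold picture in this paragraph can be made concrete with a classic threshold unit (McCulloch-Pitts style). This is a simplified sketch, not how trained LLM neurons look: the weights are fixed at 1, the activation is a hard step instead of a smooth function, and only the bias changes to move the neuron between OR (ANY), SOME, and AND (ALL).

```python
def threshold_neuron(inputs, weights, bias):
    """Weighted sum plus bias, followed by a hard step activation (outputs 0 or 1)."""
    total = sum(x * w for x, w in zip(inputs, weights)) + bias
    return 1 if total > 0 else 0


weights = [1, 1, 1]  # all three inputs count equally (the "mask")
OR_BIAS = -0.5       # ANY: one active input is enough
SOME_BIAS = -1.5     # SOME: at least two of the three inputs
AND_BIAS = -2.5      # ALL: every input is needed

print("inputs      OR SOME AND")
for bits in [(0, 0, 0), (1, 0, 0), (1, 1, 0), (1, 1, 1)]:
    print(bits,
          threshold_neuron(bits, weights, OR_BIAS),
          threshold_neuron(bits, weights, SOME_BIAS),
          threshold_neuron(bits, weights, AND_BIAS))
# (0, 0, 0) 0 0 0
# (1, 0, 0) 1 0 0
# (1, 1, 0) 1 1 0
# (1, 1, 1) 1 1 1
```

Raising the threshold (a more negative bias) makes the neuron demand more of its inputs before it fires.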
      It's a shortcut to say this, but ChatScript and the Stanford Parser are bag-of-words solutions, as they need a dictionary of words organized into classes or templates. The difference lies in how neural networks encode atomic lexemes: not on a word basis but on a token basis. But the fundamentals are the same. The token vector encoding is just a black-box representation of the statistical proximity of word tokens, which is a type of bag of words; the difference is that in a bag of words the relationships are ad hoc, by virtue of being in the same bag, whereas statistics solves the dictionary-creation part.
      If this is true, we should be able to backport ideas from LLMs to regular chatbots by generating bags of words from tokens (loosely, byte pair encoding), or by using ontology classes, like WordNet's, as embeddings, and get a less nuanced LLM as proof.
      It also explains why we can quantize LLMs down to 3 bits: 1 being set detection, -1 being anti-set. Think of it like a naive character detector: pixels in the character need to be turned on for detection, and pixels not part of the character turned off. That's two neurons, one for the set and one for the anti-set, plus a third neuron that takes the set and anti-set and declares the final result, with the biases of the set and anti-set neurons modulating sensitivity to uncertainty, as in the generalized logic above. It's more than just composition, it's circuitry, such that looking at the semantics of a single neuron is misleading, the same way that looking for "the addition transistor" is misleading when you should be looking for an adder structure.
      That's a very cartoonish presentation to fit in a comment, but I'll let you fill in the blanks for the complete picture. Here is an exercise: create a program that turns embeddings into a WordNet encoding, i.e. a human-readable ontology, which will also pick up words not encoded in that ontology.
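A related, well-established property supports the point that order has to come from somewhere: self-attention on its own, without positional information, is permutation-equivariant, meaning that shuffling the input tokens only shuffles the outputs in the same way. That is why transformers add positional encodings. This sketch uses random matrices as stand-ins for trained weights; the sizes and the permutation are arbitrary.

```python
import torch

torch.manual_seed(0)
d = 8
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))  # random stand-ins for learned weights


def self_attention(x: torch.Tensor) -> torch.Tensor:
    """Single-head scaled dot-product self-attention with no positional information."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    weights = torch.softmax(q @ k.T / d ** 0.5, dim=-1)
    return weights @ v


x = torch.randn(5, d)                 # 5 token embeddings
perm = torch.tensor([3, 0, 4, 1, 2])  # shuffle the word order

out = self_attention(x)
out_shuffled = self_attention(x[perm])
# Shuffling the input just shuffles the output the same way
print(torch.allclose(out[perm], out_shuffled, atol=1e-5))  # True
```

Each token's output depends only on which other tokens are present, not on where they sit, until position information is added to the inputs.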

  • @ehza · 19 days ago

    Thanks

  • @MudroZvon · 17 days ago

    Why do you do this? Are you working at OpenAI/Anthropic developing LLMs?

    • @BitsOfChris · 17 days ago · +1

      I like making content as a way to learn things deeper - publishing something is a forcing function for me to make sure I understand something well enough to explain it.
      Personally, I find the field fascinating and recently joined an Applied AI Research team working on time series foundation models. My background is data engineering so I'm focused on learning the modeling side now :)

  • @Jeremy-Ai · 22 days ago

    What happens when we choose to speak no more?
    Who then is paying attention?

    • @BitsOfChris · 21 days ago

      If a tree falls in the woods... and no one is there to hear it... does it even make a sound?

  • @yvesbernas1772 · 22 days ago · +12

    Just another video that describes but does not explain. Why is being able to describe so often confused with understanding?

    • @BitsOfChris · 22 days ago · +1

      Sorry this video didn’t meet your expectations. What part specifically do you think needs more depth?

    • @aiamfree · 22 days ago

      @yvesbernas1772 I got you bro. It's like this:
      Simplified:
      The Encoder makes the inputs compatible with the model dimensions, then the Decoder finds the best response by applying the best weights to the encoded input using attention.
      Detailed:
      There's the DataLoader, Encoder, Decoder, and Attention.
      - DataLoader takes the data, prepared beforehand with methods like sliding windows, and creates batch inputs and batch outputs (you specify the batch size, i.e. how many examples go into each batch).
      - Encoder is trained to transform (hint hint) these batch inputs into the model's dimensionality via embeddings and normalization. Its weights are initialized when it is created, and it returns the encoder outputs and a hidden state (which capture relationships between the tokens).
      - Decoder takes the batch outputs from the DataLoader (as tensors) together with the encoder outputs and hidden state from the Encoder, and this is where Attention comes in. The Decoder embeds its inputs (the batch outputs from the DataLoader), then passes its hidden state and the encoder outputs to the Attention mechanism, which computes the attention weights attn_weights and a context vector (context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)) and returns both. Here's example code:
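      def forward(self, hidden, encoder_outputs):
          # hidden: (batch, hidden_size); encoder_outputs: (batch, src_len, hidden_size)
          encoder_outputs = self.layer_norm(encoder_outputs)
          hidden = self.layer_norm(hidden)
          # Repeat the decoder state once per source position
          hidden = hidden.unsqueeze(1).repeat(1, encoder_outputs.size(1), 1)
          combined = torch.cat((hidden, encoder_outputs), dim=2)
          energy = torch.tanh(self.attn(combined))  # (batch, src_len, hidden_size)
          energy = torch.matmul(energy, self.v)  # self.v has shape (hidden_size,) -> (batch, src_len)
          attn_weights = torch.softmax(energy, dim=1)
          attn_weights = torch.clamp(attn_weights, min=1e-9, max=1 - 1e-9)
          context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # (batch, 1, hidden_size)
          return context, attn_weights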
      Finally, the Decoder produces a token output by combining the embedding it created with the attention context it received, and returns the output and the hidden state for the next token prediction, like this:
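      def forward(self, decoder_input, hidden, encoder_outputs, trg_mask=None):
          # trg_mask is passed by the training call below but is not used in this excerpt
          # Embed the decoder input: (batch, 1) -> (batch, 1, embedding_size)
          embedded = self.embedding(decoder_input)
          # Compute attention using the top GRU layer's hidden state
          context, attn_weights = self.attention(hidden[-1], encoder_outputs)
          # Concatenate embedded input and context
          rnn_input = torch.cat((embedded, context), dim=2)
          # Pass through the GRU (assumes batch_first=True)
          output, hidden = self.rnn(rnn_input, hidden)
          # Compute vocabulary logits; the next token (for training or prediction) is chosen from these
          output = self.fc_out(output.squeeze(1))
          return output, hidden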
      Here's some (heavily truncated) code of the full process:
      # all_pairs is the training data, e.g. [[x, y]] where x = "The cat is", y = "cat is playing"
      dataloader = DataLoader(all_pairs, batch_size=batch_size,
                              shuffle=True, collate_fn=collate_fn_stride, pin_memory=True)
      embedding_size = 768
      hidden_size = 768
      num_layers = 1
      encoder = Encoder(vocab_size, embedding_size, hidden_size, num_layers).to(device)
      decoder = Decoder(vocab_size, embedding_size, hidden_size, num_layers, Attention(hidden_size)).to(device)

      # train_model must be defined before it is called
      def train_model(encoder, decoder, dataloader, num_epochs, learning_rate):
          # Other steps (optimizer, cross-entropy loss, backward pass) would be set up and run here as well
          encoder.train()
          decoder.train()
          for epoch in range(num_epochs):
              # Assumed batch layout; the actual tuple depends on collate_fn_stride
              for batch_inputs, input_lengths, batch_outputs, trg_mask in dataloader:
                  encoder_outputs, hidden = encoder(batch_inputs.to(device), input_lengths)
                  # Start every sequence in the batch with the SOS token
                  decoder_input = torch.tensor([special_token_ids[SOS_TOKEN]] * batch_outputs.size(0), device=device).unsqueeze(1)
                  # At this step, this batch of inputs and outputs is trained for the epoch
                  decoder_output, hidden = decoder(decoder_input, hidden, encoder_outputs, trg_mask)

      train_model(encoder, decoder, dataloader, num_epochs=10, learning_rate=0.0001)
      How they are used at inference:
      The main difference between training and actually using the LLM is that, instead of batch inputs and outputs from the DataLoader (training data), you pass prompt inputs as tensors to the Encoder, and the Decoder feeds each token it predicts back in as its next input. The attention parameters are already trained, so they are not updated during generation, although the attention weights themselves are still computed fresh for each prompt.
      There are also things like source and target masks for improved attention (i.e., context weighting).
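The excerpts in this comment describe a GRU encoder-decoder with additive (Bahdanau-style) attention, an older architecture than the transformer self-attention in the video. A self-contained, untrained sketch of how such a model is used at inference: the class TinySeq2Seq, its sizes, and the token ids are hypothetical, and greedy decoding is just one choice (sampling or beam search are alternatives).

```python
import torch
import torch.nn as nn


class AdditiveAttention(nn.Module):
    """Bahdanau-style attention: score = v^T tanh(W [hidden; encoder_output])."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.attn = nn.Linear(2 * hidden_size, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, hidden, encoder_outputs):
        # hidden: (batch, hidden); encoder_outputs: (batch, src_len, hidden)
        hidden = hidden.unsqueeze(1).repeat(1, encoder_outputs.size(1), 1)
        energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        attn_weights = torch.softmax(energy @ self.v, dim=1)             # (batch, src_len)
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)  # (batch, 1, hidden)
        return context, attn_weights


class TinySeq2Seq(nn.Module):
    """Hypothetical minimal GRU encoder-decoder with attention (untrained)."""

    def __init__(self, vocab_size: int, emb: int = 16, hidden: int = 32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb)
        self.encoder = nn.GRU(emb, hidden, batch_first=True)
        self.attention = AdditiveAttention(hidden)
        self.decoder = nn.GRU(emb + hidden, hidden, batch_first=True)
        self.fc_out = nn.Linear(hidden, vocab_size)

    @torch.no_grad()
    def greedy_decode(self, prompt_ids, sos_id: int, max_new_tokens: int = 5):
        encoder_outputs, hidden = self.encoder(self.embedding(prompt_ids))
        decoder_input = torch.full((prompt_ids.size(0), 1), sos_id, dtype=torch.long)
        generated = []
        for _ in range(max_new_tokens):
            # Attention weights are recomputed at every step from the (fixed) trained parameters
            context, _ = self.attention(hidden[-1], encoder_outputs)
            rnn_input = torch.cat((self.embedding(decoder_input), context), dim=2)
            output, hidden = self.decoder(rnn_input, hidden)
            next_token = self.fc_out(output.squeeze(1)).argmax(dim=-1, keepdim=True)
            generated.append(next_token)
            decoder_input = next_token  # feed the prediction back in as the next input
        return torch.cat(generated, dim=1)


torch.manual_seed(0)
model = TinySeq2Seq(vocab_size=20).eval()
prompt = torch.randint(3, 20, (1, 6))  # stand-in for a tokenized prompt
print(model.greedy_decode(prompt, sos_id=1).shape)  # torch.Size([1, 5])
```

Since the model is untrained, the generated ids are meaningless; the point is the data flow: encode once, then attend, decode, and feed back one token at a time.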

    • @justinpeter5752 · 22 days ago

      How do the attention scores get set for each word and how do they get updated?

    • @aiamfree · 22 days ago · +4

      @justinpeter5752 Well, the attention usually takes an input sequence and an output sequence (i.e., each input offset by a stride, or a sliding window of, let's say, 10 words), which are then transformed into tensors. It doesn't work on a word-by-word basis, but as a weighted score of the tokens' relationships to what comes next (hence next-token prediction). Any deeper than that and you're beginning to think about how to build transformer tech, not what it's doing, which is OK if that's your thing.
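The sliding-window pairing described here, where each target is the input shifted one token to the right (as in the all_pairs example from the training code, x = "The cat is", y = "cat is playing"), can be sketched as follows. The function name and the whitespace tokenization are simplifying assumptions; real models work on subword tokens.

```python
def sliding_window_pairs(tokens, window: int, stride: int = 1):
    """Input/target pairs where the target is the input shifted one token to the right."""
    pairs = []
    for start in range(0, len(tokens) - window, stride):
        x = tokens[start:start + window]
        y = tokens[start + 1:start + window + 1]
        pairs.append((x, y))
    return pairs


tokens = "The cat is playing with the ball".split()
for x, y in sliding_window_pairs(tokens, window=3):
    print(" ".join(x), "->", " ".join(y))
# The cat is -> cat is playing
# cat is playing -> is playing with
# is playing with -> playing with the
# playing with the -> with the ball
```

A larger stride produces fewer, less overlapping pairs; stride=1 uses every possible window.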

    • @tristanreid5770 · 21 days ago

      @justinpeter5752 In a neural net, anything that is called a parameter is set through a process called backpropagation. Strictly speaking, the attention scores are not parameters: they are computed fresh for each input from learned parameters (the query, key, and value weight matrices), and it's those parameters that are set this way. The way it works is that, going backwards through the neural net, you send corrections derived from your training data, and the amount each parameter changes is determined using calculus (the chain rule), based on its function and all the nodes attached to it. The main purpose of doing all this with PyTorch is that it automatically sets up all that calculus: for example, the derivative of the softmax function in the video gets chained together with those of all the other functions in the neural net to determine how much to alter the weights when the network is exposed to training data.
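A tiny autograd sketch of the backpropagation described above: the learnable parameters are the projection matrices W_Q, W_K, W_V (here w_q, w_k, w_v), and the attention scores are recomputed from them on every forward pass. The data, sizes, and learning rate are made up for illustration, and a real model would use an optimizer such as torch.optim.SGD or torch.optim.Adam instead of the manual update.

```python
import torch

torch.manual_seed(0)
d = 4
x = torch.randn(3, d)       # 3 made-up token embeddings
target = torch.randn(3, d)  # made-up training target

# The learnable parameters are the projection matrices, not the attention scores
w_q = torch.randn(d, d, requires_grad=True)
w_k = torch.randn(d, d, requires_grad=True)
w_v = torch.randn(d, d, requires_grad=True)


def attention(x: torch.Tensor) -> torch.Tensor:
    # The scores are recomputed from the input on every forward pass
    scores = (x @ w_q) @ (x @ w_k).T / d ** 0.5
    weights = torch.softmax(scores, dim=-1)
    return weights @ (x @ w_v)


loss = ((attention(x) - target) ** 2).mean()
loss.backward()  # autograd chains the derivatives of every step, softmax included
print(w_q.grad.shape)  # torch.Size([4, 4])

# One manual gradient-descent step on each projection matrix
with torch.no_grad():
    for w in (w_q, w_k, w_v):
        w -= 0.01 * w.grad
new_loss = ((attention(x) - target) ** 2).mean()
print(loss.item(), new_loss.item())  # for a small enough step, the loss usually decreases
```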

  • @WiseMan_1 · 23 days ago

    First 😉

  • @shivrajnag12 · 23 days ago · +1

    Great explanation, I think watching this video ua-cam.com/video/KJtZARuO3JY/v-deo.html after yours can build a great intuition about transformers.

    • @BitsOfChris · 22 days ago

      Thank you, and thank you for that video. I love 3Blue1Brown; added this one to the watchlist.

  • @dancar2537 · 14 days ago

    Sorry, but I am confident I saw this idea of words being related to each other long before the paper you mention; it is not a new idea. What I am not sure of is whether that is what the brain does, meaning whether it does only that. I really like your explanation, it is clear. W_query is the possibilities available for "lightly", W_keys is the possibilities available for the other words, and W_value is probably the other meanings possible for the same words in the sentence in other contexts, or something close to that. Maybe I am wrong, but it does not matter, I am close. I don't see why this is necessary; probably the brain just skips it. Thanks. Bigger NN is all you need.

    • @BitsOfChris · 13 days ago · +1

      Yes, it's definitely not a new idea. I never claimed it to be novel, and I'm not sure why that matters :)
      I’m just sharing what I learned about self attention and trying to explain it as I understand it for beginners or folks new to the concept.
      To your point about the brain doing this: yes, I think we do this without realizing it. Agreed, though: in general it seems more data, more parameters, and/or more compute at inference time matter for the performance of models.
      Thanks for watching!

  • @davide0965 · 18 days ago · +1

    It repeats the same phrases and words as the paper and other videos. It doesn't explain or add anything.

    • @BitsOfChris · 18 days ago

      Thanks for your feedback! I'm not going to lie, this comment stings a little bit, but I don't disagree.
      Sorry the video wasn't what you were looking for.
      I'm a data engineer who recently joined an Applied AI research team building time series foundation models. I use content to help me learn topics deeper.
      Right now, I'm an advanced beginner in these more nuanced AI topics and try to share what helped me as I learn.
      I hope you'll give the channel a second chance in the future. Anyway, have a great day!

  • @gagandwaz6823 · 17 days ago · +1

    Very good explanation