
Paying Attention to Attention: A compilation

A compilation of LinkedIn posts on Transformer Attention Mechanisms

Vijayasri Iyer
7 min read · Dec 23, 2023


LinkedIn is an excellent platform for educational short-form content, and I love using it to share AI content, keep up with the latest news, and learn from experts. Recently, I wrote a series of posts on attention mechanisms for transformers that were well received by the community. Since posts get lost on social media over time, I am putting together this short compilation so it can serve as a quick overview for anyone looking to learn or brush up on this topic. Enjoy!

🛬Self-attention and Cross-attention

What are self-attention and cross-attention mechanisms and when should they be used?

💡 Self-Attention: Each element in the sequence calculates its attention weights with respect to all other elements. This allows the model to weigh the importance of different elements for each position.

When is self-attention used?

1. Tasks involving single-sequence data (e.g. classification).

2. Keep in mind that self-attention is computationally expensive, with a time complexity of O(n² * d) for a sequence of length n and input dimension d (a minimal sketch follows this list).
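To make this concrete, here is a minimal NumPy sketch of scaled dot-product self-attention over a single sequence. The shapes, random weights, and helper names are illustrative assumptions, not code from the original posts.

```python
# Minimal self-attention sketch (illustrative): every position attends to every
# other position in the same sequence, which is where the O(n^2 * d) cost comes from.
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, Wq, Wk, Wv):
    Q, K, V = x @ Wq, x @ Wk, x @ Wv          # project the same sequence three ways
    scores = Q @ K.T / np.sqrt(K.shape[-1])   # (n, n) score matrix
    weights = softmax(scores, axis=-1)        # attention weights per position
    return weights @ V                        # weighted sum of value vectors

rng = np.random.default_rng(0)
n, d = 8, 16                                  # sequence length and model dimension
x = rng.normal(size=(n, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
out = self_attention(x, Wq, Wk, Wv)           # shape (n, d)
```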

💡 Cross-Attention: Cross-attention is like self-attention, but extended to tasks that involve two different sequences. For example, cross-attention is used in the decoder of an encoder-decoder transformer to attend to the encoder’s output. When generating the next element of the output sequence, the model considers different parts of the input sequence, giving higher weights to the relevant parts.

When is cross-attention used?

1. Typically used when transferring information from one sequence to another, as in machine translation, summarization, etc.

2. Useful for capturing complex relationships between elements of two different sequences.

3. The time complexity is O(m * n * d), where m is the length of the output sequence, n is the length of the input sequence, and d is the input dimension. This can be somewhat better than self-attention, depending on the relative lengths of the input and output sequences (see the sketch below).
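The only structural change relative to self-attention is where the queries and the keys/values come from. A hedged sketch, using the same illustrative conventions as above:

```python
# Minimal cross-attention sketch (illustrative): queries come from the target
# (decoder) sequence of length m, keys and values from the source (encoder)
# output of length n, giving the O(m * n * d) cost.
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(decoder_x, encoder_out, Wq, Wk, Wv):
    Q = decoder_x @ Wq                          # (m, d) queries from the target side
    K, V = encoder_out @ Wk, encoder_out @ Wv   # (n, d) keys/values from the source side
    scores = Q @ K.T / np.sqrt(K.shape[-1])     # (m, n) score matrix
    return softmax(scores, axis=-1) @ V         # each target position mixes source values

rng = np.random.default_rng(0)
m, n, d = 6, 10, 16
dec = rng.normal(size=(m, d))                   # decoder states
enc = rng.normal(size=(n, d))                   # encoder output
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
out = cross_attention(dec, enc, Wq, Wk, Wv)     # shape (m, d)
```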

Of course, most models use a combination of self and cross-attention in their architecture. Several derivatives of these attention mechanisms are incredibly popular in both the Computer Vision and NLP communities.

An excellent blog post on the topic by Sebastian Raschka, PhD: https://lnkd.in/g7nb3_Fh

🛬Flash Attention

In the LLM world, we often speak of a model’s context length (basically the length of the input sequence) and how to extend it. Now, think of self-attention with its quadratic complexity with respect to the context length. Hard to scale! Here’s where “Flash Attention” comes in: an IO-aware, exact computation of attention with sub-quadratic (in fact, linear in sequence length) memory complexity. But how does it work?

Before that, a quick overview of the preliminaries, i.e. the GPU memory hierarchy, from fastest to slowest (the figures below are for an A100, as reported in the FlashAttention paper):

  • SRAM: fastest memory (~19 TB/s, ~20 MB capacity; this is where the actual computation happens)
  • HBM: GPU main memory, slower than SRAM (~1.5 TB/s, ~40 GB capacity)
  • DRAM: CPU main memory, very slow in comparison (~12.8 GB/s, >1 TB capacity)

Simple rule of the memory hierarchy: the faster the memory, the smaller its capacity.

Three main operations make up the attention computation, given the Q (query), K (key), and V (value) matrices:

(eq 1) S = QKᵀ

(eq 2) P = softmax(S)

(eq 3) O = PV

with O being the final output.
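Written out naively, the three equations look like the short (illustrative) NumPy sketch below; the max-subtraction inside the softmax is the usual numerical-stability trick and not part of the equations above. The important point is that both S and P are full n x n matrices, and these are exactly the intermediates a standard implementation keeps shuttling between HBM and SRAM.

```python
# Naive attention, mirroring (eq 1)-(eq 3); S and P are materialized as full n x n matrices.
import numpy as np

def naive_attention(Q, K, V):
    S = Q @ K.T                                    # (eq 1): n x n score matrix
    P = np.exp(S - S.max(axis=-1, keepdims=True))  # (eq 2): row-wise softmax...
    P = P / P.sum(axis=-1, keepdims=True)          # ...numerator over denominator
    return P @ V                                   # (eq 3): final output O
```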

Let’s see how standard self-attention (SA) uses this memory hierarchy:

  1. Load the matrices Q, K, and V into HBM.
  2. Compute S from (eq 1) and write it back to HBM.
  3. Load S from HBM again, compute P from (eq 2), and write P back to HBM.
  4. Load P and V from HBM, compute O from (eq 3), and write it back to HBM.

See this pattern? With SA, there is a significant overhead due to the back-and-forth (excessive reads and writes) between the HBM and SRAM.

Let’s see how Flash Attention (FA) uses two concepts to avoid too many reads and writes to the HBM.

💡Tiling: In FA, Q, K, and V are divided into blocks, and the computation of (eq 1) and (eq 2), i.e. the “softmax reduction”, is performed incrementally without ever holding the full input, so that SRAM can be used to its full capacity.

💡Recomputation: Instead of storing the full n x n attention matrix for the backward pass, store only the softmax normalization statistics from the forward pass and recompute S and P on-chip when they are needed.

I’ve oversimplified the flow into four steps (a toy sketch follows the list):

  1. Split the softmax into a numerator and a denominator; both can be accumulated block by block, which is what makes the partial computation possible.
  2. So, we have three partial statistics stored in HBM: O_i is the partial output, L_i is the running softmax denominator (the row-wise sum of exponentials), and M_i is the running row-wise maximum used for numerical stability.
  3. In a two-loop structure (outer loop going over blocks of K and V, inner loop going over blocks of Q), use K_i, V_i, and Q_j to update O_j, L_j, and M_j in HBM.
  4. In the backward pass, recompute the blocks of S and P on-chip from the stored O_j, L_j, and M_j instead of reading a full attention matrix from HBM.
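Here is a toy “online softmax” version of the same computation: K and V are processed block by block, and only a running maximum (M), a running denominator (L), and a running output (O) are kept per query row, so the full n x n score matrix is never materialized. This is a single-query-block simplification with my own loop order and variable names, not the actual FlashAttention kernel.

```python
# Toy tiled/online-softmax attention (illustrative simplification of the FA idea).
import numpy as np

def tiled_attention(Q, K, V, block=32):
    n, d = Q.shape
    O = np.zeros((n, V.shape[1]))              # running (unnormalized) output
    L = np.zeros((n, 1))                       # running softmax denominator
    M = np.full((n, 1), -np.inf)               # running row-wise maximum
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d)              # scores against this block only
        M_new = np.maximum(M, S.max(axis=-1, keepdims=True))
        P = np.exp(S - M_new)                  # block-local unnormalized softmax
        scale = np.exp(M - M_new)              # rescale previously accumulated stats
        L = L * scale + P.sum(axis=-1, keepdims=True)
        O = O * scale + P @ Vb
        M = M_new
    return O / L                               # normalize once at the end

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(128, 32)) for _ in range(3))
out = tiled_attention(Q, K, V)                 # same result as full softmax(QKᵀ/√d)V
```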

Some excellent resources that I used to learn flash attention:

  1. This excellent YouTube video by Manish Gupta: https://www.youtube.com/watch?v=1RaIS98jj1Q&t=720s
  2. An ELI5-style blogpost by Aleksa Gordic : https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad

🛬Soft vs Hard Attention

Any introduction to attention is incomplete without briefly discussing the difference between soft and hard attention. So, let’s look at the key differences between the two and why soft attention has become by far the more popular choice among AI researchers.

Soft attention:

💡 All our current attention mechanisms are soft attention mechanisms (think self, cross, global, local …)

💡Soft attention allows the model to assign weights to all parts of the input, typically using a softmax function. This means the model computes a weighted sum of all input elements, where each weight is the importance of the corresponding input element.

💡Soft attention provides a smooth, probabilistic (and hence differentiable) way to decide where to focus, enabling the model to consider multiple parts of the input simultaneously and adjust their weights based on the task.

Now, modern neural networks use backpropagation and gradient descent as their main mechanisms for learning. For these mechanisms to work, you need to make sure all your components, including loss functions, activations, and attention mechanisms, are continuous and differentiable. Differentiable means that a function has a defined derivative at every point in its domain.

Hard attention:

💡Hard attention makes a discrete choice at each step, directly selecting a single part of the input to focus on.

💡Unlike soft attention’s continuous distribution of attention weights, hard attention results in a binary or categorical decision.

💡This approach can be more computationally efficient, since only the selected part of the input is processed, but it suffers from training difficulties: the discrete selection is non-differentiable, so it typically requires reinforcement-learning-style techniques rather than plain backpropagation.

Alright, so the choice between hard and soft attention essentially boils down to basic calculus principles of continuity and differentiability. Hopefully, you can now appreciate the meaning of the word smooth in deep-learning papers even more!
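A tiny illustrative sketch of the contrast (the scores and values are made up): soft attention mixes all inputs with differentiable weights, while hard attention commits to a single index, which is why it cannot be trained with plain backpropagation.

```python
# Soft vs. hard attention over the same scores (illustrative NumPy sketch).
import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(size=5)                  # one query's scores over 5 input elements
values = rng.normal(size=(5, 4))             # 5 input vectors of dimension 4

# Soft attention: a differentiable weighted sum over *all* inputs.
weights = np.exp(scores - scores.max())
weights = weights / weights.sum()            # softmax -> smooth, gradients flow everywhere
soft_out = weights @ values

# Hard attention: a discrete choice of a single input. argmax (or sampling from the
# categorical distribution defined by `weights`) gives no useful gradient, which is
# why training usually falls back on reinforcement-learning-style estimators.
hard_out = values[np.argmax(scores)]
```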

Some excellent resources:

  1. Blog by Nikolas Adaloglou: https://lnkd.in/gVkcxvsD
  2. Blog by Jonathan Hui : https://lnkd.in/gEvkrVPV
  3. Continuity and differentiability by Khan Academy: https://lnkd.in/gDg_Xpxh

🛬Deformable Attention

The impact of transformers and attention mechanisms in computer vision is not as widely appreciated, even though they are causing just as many ripples there as in NLP. Deformable attention has become a widely used concept in 2D and 3D computer vision and a recurring actor in most top-tier computer vision papers.

How does it work?

Deformable attention (DA) is an adaptation of deformable convolution, introduced in 2017 (Dai et al.) as a means to improve CNN performance by allowing the network to learn a dynamic receptive field. Similarly, in the case of a transformer model, deformable attention is an extension of the traditional self-attention mechanism, giving the model the ability to dynamically adjust its focus based on learned deformable offsets. Think of this as shape-shifting your focus based on the size and shape of the objects in an image. In a nutshell, DA adapts its focus (the offsets) based on important features in the data (the query). It can also reduce computational load by sampling reference points from the grid (feature map) to determine the offsets.

Finally, we get to the working:

Consider an image/feature map as a grid with dimensions HxWxC (Height, Width, Channels).

Step 1: Obtain q, k, and v

  1. Downsample the feature map into a uniform grid of dimensions Hg x Wg x 2 (Hg = H/r, Wg = W/r, with r pre-determined) and sample reference points p from this grid.
  2. The reference points p are linearly projected onto a 2D coordinate system normalized to [-1, +1].
  3. Obtain the query token by a linear projection of the feature map, q = xWq.
  4. Obtain the offsets (Δp) by passing the query token q through a subnetwork (two conv layers). The deformed reference points are (p + Δp).
  5. Perform bilinear interpolation at the deformed reference points (p + Δp) to sample the features x̃.
  6. Similar to (3), obtain the key token k̃ = x̃Wk and the value token ṽ = x̃Wv from the sampled features x̃.

Note: Wg refers to the width of the downsampled grid, whereas Wq, Wk, and Wv are projection matrices. The outputs of step 1 are q, k̃, and ṽ.

Step 2: Apply self-attention to calculate the final output. The deformable offsets (Δp) are learned through backpropagation.
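Here is a condensed, hedged PyTorch sketch of the Step 1 pipeline above. The layer sizes, stride r, the exact form of the offset subnetwork, and the tanh/clamp details are my own assumptions; see the lucidrains repository linked below for a faithful implementation.

```python
# Illustrative sketch of deformable attention's sampling step (not the official code).
import torch
import torch.nn.functional as F

B, C, H, W, r = 1, 64, 32, 32, 4
x = torch.randn(B, C, H, W)                        # input feature map (H x W x C)

Wq, Wk, Wv = (torch.nn.Conv2d(C, C, 1) for _ in range(3))  # 1x1 convs as linear projections
offset_net = torch.nn.Sequential(                  # small subnetwork predicting 2D offsets
    torch.nn.Conv2d(C, C, 3, padding=1), torch.nn.GELU(),
    torch.nn.Conv2d(C, 2, 1),
)

q = Wq(x)                                          # (3) query from the full feature map

# (1)-(2) uniform reference grid of size Hg x Wg, normalized to [-1, +1]
Hg, Wg = H // r, W // r
ys, xs = torch.linspace(-1, 1, Hg), torch.linspace(-1, 1, Wg)
ref = torch.stack(torch.meshgrid(xs, ys, indexing="xy"), dim=-1)   # (Hg, Wg, 2)
ref = ref.unsqueeze(0).expand(B, -1, -1, -1)

# (4) offsets predicted from the (downsampled) query
dp = offset_net(F.avg_pool2d(q, r)).permute(0, 2, 3, 1)            # (B, Hg, Wg, 2)
deformed = (ref + dp.tanh()).clamp(-1, 1)          # deformed reference points p + Δp

# (5) bilinear interpolation of the features at the deformed points
x_tilde = F.grid_sample(x, deformed, align_corners=True)           # (B, C, Hg, Wg)

# (6) keys and values from the sampled features; Step 2 is then ordinary self-attention
k_tilde, v_tilde = Wk(x_tilde), Wv(x_tilde)
```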

Some excellent resources to learn about Deformable Attention in detail:
💡 Blog by Joe E.: https://lnkd.in/dMNNQVd5
💡 Implementation of 1D, 2D, and 3D Deformable Attention by Phil Wang (lucidrains): https://lnkd.in/d4qc-zsd

