Flash attention (Fast and Memory-Efficient Exact Attention with IO-Awareness): A deep dive

Flash attention is a power-optimized transformer attention mechanism that provides roughly a 15% efficiency gain

Photo by sander traa on Unsplash

Flash attention is a power-optimized transformer attention mechanism that provides roughly a 15% improvement in wall-clock speed with no approximation.

Context

Given that transformer models are slow and memory-hungry on long sequences (time and memory complexity are quadratic in sequence length), flash attention (paper) provides a 15% end-to-end wall-clock speedup on BERT-large and a 3x speedup on GPT-2.

Considering the enormous amount of energy consumed in training these large models, the 15% efficiency gain that flash attention achieves through combined software and hardware optimization is a huge win.

The discussion below explains some of the basic concepts behind flash attention and how it is implemented.

Basic concepts around compute & memory

Before we dive deeper into compute and memory, let’s revisit them:

What is Compute?

  • Time spent on your GPU performing actual floating-point operations (FLOPs)

What is Memory?

  • Time spent transferring tensors within a GPU

Ideally, we want our GPU to be performing matrix multiplication all the time, not waiting on memory. But in reality, compute has made more progress than memory, and we live in a world where the GPU sits idle waiting for data to be loaded. This is usually called a memory-bound operation. Refer to the illustrative diagram below depicting this (a small timing sketch follows the figure). Matrix multiplication is the compute, and memory stores the data (think of it as a warehouse). Compute needs data to process, and the memory bandwidth has to keep up with that demand.

Photo from https://horace.io/brrr_intro.html
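To feel the difference, here is a minimal timing sketch (PyTorch assumed; the 4096x4096 shapes and iteration counts are arbitrary) that contrasts a compute-bound matrix multiplication with a memory-bound elementwise addition on the same tensors:

```python
# Rough comparison of a compute-bound op (matmul) vs a memory-bound op
# (elementwise add) on tensors of the same size.
import time
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
a = torch.randn(4096, 4096, device=device)
b = torch.randn(4096, 4096, device=device)

def timed(fn, iters=50):
    # Warm up, then time; synchronize so GPU kernels are fully accounted for.
    for _ in range(5):
        fn()
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

t_matmul = timed(lambda: a @ b)   # ~2 * 4096^3 FLOPs, high arithmetic intensity
t_add = timed(lambda: a + b)      # ~4096^2 FLOPs, limited by memory bandwidth
print(f"matmul: {t_matmul * 1e3:.2f} ms, elementwise add: {t_add * 1e3:.2f} ms")
```

Despite doing thousands of times more floating-point work, the matmul is nowhere near thousands of times slower, because the elementwise add is bottlenecked by memory traffic rather than compute.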

What is Memory hierarchy ?

The A100 GPU has 40–80 GB of high-bandwidth memory (HBM) with a bandwidth of 1.5–2.0 TB/s, and 192 KB of on-chip SRAM per each of its 108 streaming multiprocessors, with aggregate SRAM bandwidth estimated around 19 TB/s.

Photo from https://arxiv.org/abs/2205.14135

What is the problem with self attention architecture ?

With the above context in mind, the self-attention architecture is memory-bound.

Photo by the Author

Looking at the attention math, it is the softmax operation that makes it memory-bound.

  • Quantitative evidence: as shown below, operations like softmax, dropout and masking take the majority of the time compared to matrix multiplication (Matmul); a rough profiling sketch follows the figure.
Photo from https://arxiv.org/abs/2205.14135
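To see a similar breakdown yourself, here is a rough profiling sketch (PyTorch with torch.profiler assumed; the sequence length, head dimension, causal mask and dropout rate are illustrative choices, not the paper's exact setup):

```python
# Profile a vanilla attention forward pass to compare matmul time against
# softmax / masking / dropout time.
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
N, d = 2048, 64
q = torch.randn(N, d, device=device)
k = torch.randn(N, d, device=device)
v = torch.randn(N, d, device=device)

def vanilla_attention(q, k, v, p_drop=0.1):
    s = (q @ k.T) / d ** 0.5                                 # matmul
    mask = torch.tril(torch.ones(N, N, device=device, dtype=torch.bool))
    s = s.masked_fill(~mask, float("-inf"))                  # masking
    p = F.softmax(s, dim=-1)                                 # softmax
    p = F.dropout(p, p=p_drop)                               # dropout
    return p @ v                                             # matmul

activities = [torch.profiler.ProfilerActivity.CPU]
if device == "cuda":
    activities.append(torch.profiler.ProfilerActivity.CUDA)

with torch.profiler.profile(activities=activities) as prof:
    for _ in range(10):
        vanilla_attention(q, k, v)

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
```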

Why does softmax become a memory bound operation ?

The scale at which it operates is the biggest bottleneck. In the diagram below:

  • N -> number of tokens
  • d -> number of embedding dimensions
  • When Query and Key’ are multiplied, the attention matrix explodes to N * N, which takes a lot of memory. For reference: d ~ 128; N ~ 128k tokens; Google Gemini handles ~1 million tokens. A back-of-the-envelope sketch of this blow-up follows the figure.
Photo from FlashAttention — Tri Dao | Stanford MLSys #67
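A back-of-the-envelope calculation makes the blow-up concrete (fp16 storage assumed; N and d are the reference numbers above):

```python
# Memory needed to materialize the N x N attention matrix vs the inputs themselves.
N = 128_000        # tokens
d = 128            # embedding dimension
bytes_per_el = 2   # fp16

qkv_bytes = 3 * N * d * bytes_per_el   # Q, K and V together
attn_bytes = N * N * bytes_per_el      # S = QK'
print(f"Q/K/V together: {qkv_bytes / 1e9:.2f} GB")           # ~0.10 GB
print(f"N x N attention matrix: {attn_bytes / 1e9:.2f} GB")  # ~32.77 GB
```

The inputs comfortably fit on chip, but the intermediate N x N matrix does not, which is why it has to live in HBM.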

[Algorithm] How is self attention implemented ?

Below is the algorithm for implementing the self-attention mechanism.

Photo from https://arxiv.org/abs/2205.14135

As noted in the above section, transferring information to HBM (writing S to HBM), loading it back from HBM to the GPU to compute softmax, and then writing the result back to HBM means a lot of data traveling back and forth, making this a memory-bound operation.
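Here is a minimal NumPy sketch of this standard algorithm (single head; scaling, masking and dropout omitted). The comments mark where the N*N intermediates S and P would round-trip through HBM on a real GPU:

```python
import numpy as np

def standard_attention(Q, K, V):
    # Q, K, V: (N, d)
    S = Q @ K.T                              # write S (N x N) to HBM
    S_max = S.max(axis=1, keepdims=True)     # read S back from HBM ...
    P = np.exp(S - S_max)
    P = P / P.sum(axis=1, keepdims=True)     # ... compute softmax, write P (N x N) to HBM
    return P @ V                             # read P back from HBM, compute output

N, d = 4, 3
rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, N, d))
print(standard_attention(Q, K, V))           # (N, d) output
```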

[Matrix multiplication] How is self attention implemented ?

Along with the diagram, the steps below help explain how self-attention is computed through matrix multiplication.

Step 1:

  • I have simplified this. In practice, each token embedding is combined with a positional encoding and fed into linear layers to generate <key, query and value>. For illustration I used a dimension of 3 (in practice it ranges from 64–128). This is the standard transformer architecture input.

Step 2

  • Key -> Key’ (transpose) is computed and multiplied with Query to give QK’, which is N*N. This contains the attention of each token with respect to every other token. The diagram below shows this relationship as well. Since we need to compute the importance of each token with respect to every other token, a softmax operation is applied row-wise to normalize the scores to the range 0–1.
  • This step requires movement to and from HBM and, as we discussed, is the most expensive operation. The entire flash attention paper is about how to optimize this process.

Step 3

  • Softmax(QK’) * V is computed as the final output matrix. Its dimensions are the same as the input embeddings of key, query and value.
  • Reading the rows of the final output matrix:
  • Row 1 (1*5) means the embedding of “this” is updated to incorporate its relations with the other tokens.
  • Row 2 (2*5) means the embedding of “is” is updated to incorporate its relations with the other tokens.
  • The same holds for the rest of the rows (a worked sketch follows the diagram below).
Photo by the Author: Illustrative diagram of how self attention mechanism works
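Here is a small NumPy sketch of steps 1 to 3 with the 5 tokens and d = 3 from the illustration. The embeddings and projection matrices are random stand-ins, not trained weights:

```python
import numpy as np

rng = np.random.default_rng(0)
tokens = ["this", "is", "flash", "attention", "paper"]
N, d = len(tokens), 3

# Step 1: token embeddings (plus positional encoding in a real model),
# projected through linear layers into query, key and value.
X = rng.normal(size=(N, d))
W_q, W_k, W_v = rng.normal(size=(3, d, d))
Q, K, V = X @ W_q, X @ W_k, X @ W_v

# Step 2: QK' is N x N (attention of each token with every other token),
# normalized row-wise with softmax so each row sums to 1.
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
P = P / P.sum(axis=1, keepdims=True)

# Step 3: the output has the same shape as the input embeddings (N x d);
# row i is the updated embedding of token i.
O = P @ V
print(O.shape)   # (5, 3)
```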

Basic idea behind the flash attention paper

The basic idea is explained through the diagram below, where blocks of key, query and value are moved from HBM to SRAM; through some mathematical tricks (explained below), the computation performed on these blocks is not an approximation but the exact answer.

With this implementation, the paper is able to reduce wall-clock time by accessing information in blocks, without sacrificing correctness.

Photo from https://arxiv.org/abs/2205.14135

Algorithm behind the paper: How is Flash attention implemented ?

This is the most complex part of the paper. Let’s break this problem into sub-aspects and dive deeper.

The diagram below breaks the matrices into blocks and shows how each block is used to compute a partial softmax and then the correct softmax.

  • Initial input tokens: “This is flash attention paper”
  • Key: 4 (tokens) X 3(dimensions), Query: 4 (tokens) X 3(dimensions) and Value: 4 (tokens) X 3(dimensions)
Image modified by author. Original image from https://arxiv.org/abs/2205.14135

Step 0

  • Assume on-chip memory (M) is 24 bytes
  • SRAM will be divided into 4 blocks (Query, Key, Value and the output matrix)
  • Query, Key, Value and Output each get 6 bytes to store their info (24 bytes / 4)
  • Each embedding has dimension 3 and cannot be split, so
  • Query: 6 bytes / 3 (dimension) = 2 rows. The same holds for key, value and output
  • Hence, [M/4d] gives the size of each block. In this case the block size is 2, meaning 2 rows can be fetched into SRAM at a time
  • In general, the block size is [M/4d] and the number of blocks is [N*4d/M] (a small sketch of this arithmetic follows)
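A tiny sketch of this block-size arithmetic, using the toy numbers above (the 24 "bytes" are treated as 24 scalar slots):

```python
M = 24   # on-chip SRAM budget (scalar slots in this toy example)
d = 3    # embedding dimension
N = 4    # number of tokens

block_size = M // (4 * d)      # M / 4d -> rows of Q, K, V and O per block
num_blocks = N // block_size   # equivalently N * 4d / M
print(block_size, num_blocks)  # 2 rows per block, 2 blocks
```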

Step 1 & 2: The table below illustrates steps 1 and 2 of how flash attention works and compares their memory and computation footprint.

Photo by the Author: Step by step break-down of memory & computation usage in Flash attention

The diagram below helps visualize the block-by-block matrix multiplication used in flash attention; a simplified code sketch follows the diagram.

Photo by the Author: Illustrative diagram of how flash attention mechanism works
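Below is a simplified NumPy sketch of the tiled forward pass. It only tiles K and V (the paper tiles Q as well) and omits scaling, masking and dropout, but it shows the core trick: a running row max m and a running denominator l are used to rescale previously accumulated results so the final output equals exact attention.

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_size=2):
    N, d = Q.shape
    O = np.zeros((N, d))            # unnormalized output accumulator
    l = np.zeros((N, 1))            # running softmax denominator per row
    m = np.full((N, 1), -np.inf)    # running row max per row

    for start in range(0, N, block_size):        # stream K/V block by block
        Kj = K[start:start + block_size]
        Vj = V[start:start + block_size]
        S = Q @ Kj.T                              # partial scores (N x block_size)

        m_new = np.maximum(m, S.max(axis=1, keepdims=True))
        P = np.exp(S - m_new)                     # partial, unnormalized softmax
        scale = np.exp(m - m_new)                 # rescale older contributions
        l = l * scale + P.sum(axis=1, keepdims=True)
        O = O * scale + P @ Vj
        m = m_new

    return O / l                                  # final normalization

# Check against exact (fully materialized) attention.
rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 4, 3))
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
exact = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.allclose(flash_attention_forward(Q, K, V), exact))   # True
```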

What is the mathematical aspect of softmax ?

One of the most critical aspects of the paper is how breaking the matrices into blocks still yields the exact softmax. The mathematical example below shows how the partial results of two different blocks can be combined to recover the full softmax.

Intuition

  • A beautiful property of exponents is leveraged here.
  • Each partial softmax is computed individually, and along with it the maximum value of the row and the summed exponent value are stored.
  • When merging with another block, we check how much the local max differs from the global max of the two blocks. Because of the exponent, both the numerator and the denominator are adjusted by e^(current_max - global_max) to incorporate this.

The logic is fairly involved, so an example to work through is left below. Once you are familiar with the example, the intuition above will make a lot of sense.

Photo by the Author: Example to demonstrate how breaking matrix into sub-components and eventually combining them to compute softmax
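The same merge rule in a few lines of NumPy (a single 4-element row split into two halves; the numbers are arbitrary):

```python
import numpy as np

row = np.array([1.0, 3.0, 2.0, 5.0])
left, right = row[:2], row[2:]

# Partial statistics for each half: local max and sum of exponents.
m1, m2 = left.max(), right.max()
s1 = np.exp(left - m1).sum()
s2 = np.exp(right - m2).sum()

# Merge: rescale each partial sum by exp(local_max - global_max).
m = max(m1, m2)
denom = s1 * np.exp(m1 - m) + s2 * np.exp(m2 - m)
merged = np.concatenate([np.exp(left - m), np.exp(right - m)]) / denom

exact = np.exp(row - row.max()) / np.exp(row - row.max()).sum()
print(np.allclose(merged, exact))   # True
```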

Complexity analysis

Let’s look at complexity analysis to get a sense of how things changed

Self attention

  • While computing S = QK’, an N*N matrix is produced which needs to be written out to HBM and then pulled back from HBM.
  • Hence HBM accesses are O(N*N + N*N) = O(N*N)

Flash attention

  • Outer loop: Key and Query will be accessed O(Nd) times
  • Inner loop: only O(Nd/M) passes are needed to load from HBM, since we operate on blocks
  • Overall: O(N*N*d*d/M)
  • Practically, d is much smaller than M: d ranges from 64–128 while M is around 100 KB, so HBM access is greatly reduced (see the sketch below)
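Plugging in representative numbers (following the article's O(N*N) count for standard attention; M is treated as roughly 100K scalar slots):

```python
N = 4096       # sequence length
d = 128        # head dimension
M = 100_000    # on-chip SRAM size, ~100 KB order of magnitude

standard_hbm = N * N               # O(N^2) accesses for S and P
flash_hbm = N * N * d * d / M      # O(N^2 * d^2 / M) accesses
print(f"reduction factor: {standard_hbm / flash_hbm:.1f}x")   # ~6.1x, i.e. M / d^2
```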

Conclusion

  • We started with the objective of optimizing HBM access, and with this complexity analysis we see that the paper cuts HBM accesses to a (d*d/M) fraction of the original (a reduction by a factor of M/(d*d)) with no approximation.

Such a complex paper, with a huge improvement in efficiency. I hope the explanation above gives some intuition on how flash attention optimizes and improves performance. I haven't covered block-sparse flash attention, how this compares with other optimization techniques, forward pass optimization, etc. I hope to cover those in a future post.

References