Grouped-Query Attention (GQA) vs. Multi-Head Attention (MHA): Optimizing LLM Inference Serving


Large Language Models (LLMs) are pushing the boundaries of what machines can understand and generate, excelling at tasks from writing coherent articles to generating creative content. However, serving these LLMs can be extremely memory-intensive, which makes scaling them and deploying them in the real world challenging.

One bottleneck in LLM serving is the memory bandwidth and capacity needed to load and store the attention keys and values of every token. Grouped-query attention (GQA) (J. Ainslie et al.) is one technique that tackles this: it lets groups of query heads share a single key-value head, reducing the memory load of the attention mechanism while still capturing the complex relationships and patterns in the text. In this article, we explain what GQA is and observe its effects firsthand with the Friendli Engine.

The Attention Mechanism

The attention mechanism in transformer models allows the model to focus on the most relevant parts of the input, improving its understanding of intricate texts for tasks such as summarization, language translation, and beyond. It works like a database query: a word (the “query”) is compared for relevance against all other words (the “keys”), and the result is a weighted sum of the corresponding “values,” with the relevance scores serving as the weights. Since each word is compared against all other words in the sequence, the queries, keys, and values can all be thought of as the words themselves, but they are differentiated by learnable weight matrices (Wq, Wk, and Wv) that the network learns during training to capture better context.
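The query-key-value lookup described above can be sketched in plain Python. This is a minimal, illustrative sketch of attention for a single query, using the standard scaled dot-product form (scores divided by the square root of the vector dimension). It assumes the query, key, and value vectors have already been produced by multiplying token embeddings by Wq, Wk, and Wv; the function names are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Scaled dot-product attention for one query vector.

    query:  list[float] of length d (already projected by Wq)
    keys:   one key vector per token (already projected by Wk)
    values: one value vector per token (already projected by Wv)
    Returns the relevance-weighted sum of the value vectors.
    """
    d = len(query)
    # Relevance of every key to the query, scaled by sqrt(d).
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)  # non-negative weights that sum to 1
    # Weighted sum of the value vectors, one output dimension at a time.
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]


out = attend([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [[10.0, 0.0], [0.0, 10.0]])
print(out)
```

Because the query is aligned with the first key, that token receives most of the weight and the output is roughly [6.7, 3.3]; the weights still sum to 1, so the two components add up to 10.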

Multi-head attention (MHA) and memory bottlenecks

In a sentence like “I helped my brother move the couch,” there is a relationship between “I” and “my brother,” as well as another important connection between “I” and the action of “[moving] the couch.” To capture both, multi-head attention (MHA), used in models such as Llama 13B and Llama 2 7B, applies the attention mechanism described above multiple times in parallel, so that each application can capture a different type of relationship in the data.

Multi-head attention runs multiple attention “heads” in parallel, each with its own query, key, and value weight matrices. While this captures more nuances, the big downside of MHA is the pressure it puts on memory bandwidth during inference: at every decoding step, the keys and values of all previous tokens, for every head, must be loaded from memory, which can become a severe bottleneck.

Multi-query attention (MQA) and the quality trade-off

To address this memory bandwidth problem, multi-query attention (MQA) was proposed, in which all query heads share a single key-value head. While MQA significantly reduces the memory load and improves inference speed, it comes at the cost of lower quality and training instability.

Grouped-query attention (GQA)

Grouped-query attention (GQA) strikes a balance between the quality of MHA and the speed of MQA. GQA divides the query heads into groups, and each group shares one key-value head, so the number of key-value heads is an intermediate value between 1 (MQA) and the number of query heads (MHA). With fewer key-value heads to load, both the memory traffic and the computation of the key and value projections are reduced. Plus, since GQA needs less memory to store the attention keys and values of each token, we can use larger batch sizes to achieve higher throughput.
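The grouping can be illustrated with a small sketch of a single GQA decoding step. It is a minimal, illustrative version, assuming query heads are assigned to key-value heads in contiguous blocks (query head h uses key-value head h // group_size, which is how common open-source implementations repeat key-value heads); the helper names are hypothetical. Setting the number of key-value heads to 1 gives MQA, and setting it equal to the number of query heads gives MHA.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Scaled dot-product attention for one query vector."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]


def kv_head_for(query_head, n_query_heads, n_kv_heads):
    """Index of the key-value head shared by a given query head."""
    if n_query_heads % n_kv_heads != 0:
        raise ValueError("query heads must split evenly into groups")
    group_size = n_query_heads // n_kv_heads
    return query_head // group_size


def grouped_query_attention(queries, kv_cache):
    """One decoding step of GQA.

    queries:  one query vector per query head
    kv_cache: one (keys, values) pair per key-value head; keys and values
              each hold one vector per cached token.
    Only len(kv_cache) key-value heads are stored and loaded,
    however many query heads there are.
    """
    n_q, n_kv = len(queries), len(kv_cache)
    outputs = []
    for h, q in enumerate(queries):
        keys, values = kv_cache[kv_head_for(h, n_q, n_kv)]
        outputs.append(attend(q, keys, values))
    return outputs


# 8 query heads sharing 2 key-value heads: two groups of 4.
print([kv_head_for(h, 8, 2) for h in range(8)])
```

With 8 query heads and 2 key-value heads, the mapping is [0, 0, 0, 0, 1, 1, 1, 1], so the cache that must be loaded at each step is a quarter of the MHA cache.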

The benefits of GQA are especially evident in larger models, where memory bandwidth and capacity concerns are most prominent and MQA’s reduction to a single key-value head can be too severe. GQA is currently used in foundation models such as Llama 2 70B and Mistral 7B. In the next section, we quantify the benefits of GQA running on our Friendli Engine.
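A quick back-of-the-envelope calculation shows how much memory GQA saves per token. This sketch assumes FP16/BF16 storage (2 bytes per element) and uses the layer count, key-value head count, and head dimension from the public model configurations; the “Llama 2 70B as MHA” entry and the 8 GiB KV-cache budget are hypothetical, included only for comparison.

```python
def kv_cache_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """Bytes of KV cache stored (and re-read at every decode step) per token.

    The factor 2 counts one key vector and one value vector per
    key-value head per layer; bytes_per_elem=2 assumes FP16/BF16.
    """
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem


# (layers, key-value heads, head dim). The last entry is a hypothetical
# MHA variant of Llama 2 70B with one key-value head per query head.
CONFIGS = {
    "Llama 2 7B (MHA)": (32, 32, 128),
    "Mistral 7B (GQA)": (32, 8, 128),
    "Llama 2 70B (GQA)": (80, 8, 128),
    "Llama 2 70B as MHA (hypothetical)": (80, 64, 128),
}

KV_BUDGET = 8 * 1024**3  # hypothetical 8 GiB of GPU memory reserved for the KV cache

for name, (layers, kv_heads, head_dim) in CONFIGS.items():
    per_token = kv_cache_bytes_per_token(layers, kv_heads, head_dim)
    print(f"{name}: {per_token // 1024} KiB/token, "
          f"{KV_BUDGET // per_token} cached tokens fit in the budget")
```

Under these assumptions, Llama 2 7B stores 512 KiB of keys and values per token versus 128 KiB for Mistral 7B, and Llama 2 70B’s 8 key-value heads need 320 KiB per token instead of 2,560 KiB with full MHA. A fixed memory budget therefore holds several times more cached tokens, which is what allows the larger batch sizes mentioned above.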

Evaluation

We tested GQA’s impact on inference speed by comparing Meta’s Llama 2 7B with Mistral AI’s Mistral 7B on a single NVIDIA A10G GPU, using the Databricks Dolly 15K dataset under varying workloads. Both models have about 7B parameters, but Llama 2 7B uses MHA while Mistral 7B uses GQA.

We simulated varying load intensities by increasing the number of requests per second. To mimic real-world traffic, request arrivals follow a Poisson process, and we varied its rate parameter from 1N to 4N requests/sec over the course of the evaluation. The performance metric is the 90th-percentile time per output token (TPOT); lower is better.
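The load generation and metric described above can be approximated with Python’s standard library. This is a sketch under stated assumptions, not the harness used for these results: the helper names are hypothetical, the nearest-rank percentile is one of several percentile conventions, and the TPOT formula (decoding time divided by the number of tokens generated after the first) is a common definition rather than one given in this post. Gaps between arrivals in a Poisson process with rate λ are exponentially distributed with mean 1/λ, so arrival times are a running sum of `random.expovariate` draws.

```python
import math
import random


def poisson_arrival_times(rate, n_requests, seed=0):
    """Arrival timestamps (seconds) of a Poisson process with `rate` requests/sec."""
    rng = random.Random(seed)
    t, times = 0.0, []
    for _ in range(n_requests):
        t += rng.expovariate(rate)  # exponential gap with mean 1/rate
        times.append(t)
    return times


def percentile(samples, p):
    """Nearest-rank percentile (0 < p <= 100) of a non-empty list of numbers."""
    ordered = sorted(samples)
    rank = math.ceil(p * len(ordered) / 100)
    return ordered[max(rank, 1) - 1]


def time_per_output_token(request_latency, time_to_first_token, n_output_tokens):
    """Assumed TPOT definition: decode time averaged over tokens after the first."""
    return (request_latency - time_to_first_token) / max(n_output_tokens - 1, 1)


arrivals = poisson_arrival_times(rate=2.0, n_requests=5)
print([round(t, 3) for t in arrivals])
print(percentile(list(range(1, 11)), 90))
```

Raising `rate` packs requests closer together, which is how a sweep over request rates increases the load on the serving engine; the p90 TPOT is then `percentile` applied to the per-request TPOT values.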

Under very light workloads, the two models have similar TPOTs, but as the workload intensifies, Mistral 7B consistently achieves lower TPOT than Llama 2 7B. In fact, at the heaviest load, Llama 2 7B fails to respond within the allotted time. The GQA used in Mistral 7B outperforms the MHA in Llama 2 7B, especially under demanding workloads.

Conclusion

By optimizing the transformer attention mechanism, GQA enables efficient LLM inference with little sacrifice in quality. GQA is just one of the many features well supported by Friendli Engine. We invite you to explore the power and efficiency of Friendli Engine, our cutting-edge LLM serving engine, through Friendli Dedicated Endpoints or Friendli Container.
