Accelerating LLaMA-2 Inference with ONNX Runtime

By: Kunal Vaishnavi and Parinita Rahi

14TH NOVEMBER, 2023

Interested in running Llama2 faster? Let us explore how ONNX Runtime can propel your Llama2 variants for faster inference!

You can now experience significant inference gains, up to 4X faster, for the 7B, 13B, and 70B models, thanks to state-of-the-art fusion and kernel optimizations in ONNX Runtime. This blog details the performance improvements, dives into ONNX Runtime's fusion optimizations and multi-GPU inference support, and shows how to take advantage of ONNX Runtime's cross-platform support for seamless inference on different platforms. It is the first in a series of blogs; upcoming posts will cover efficient memory usage through ONNX Runtime quantization updates, as well as cross-platform usage scenarios.

Background: Llama2 and Microsoft

Llama2 is a state-of-the-art open-source LLM from Meta, available in 7B, 13B, and 70B parameter sizes. Microsoft and Meta announced their AI on Azure and Windows collaboration in July 2023. As part of the announcement, Llama2 was added to the Azure AI model catalog, a hub of foundation models that helps developers and machine learning (ML) professionals discover, evaluate, customize, and deploy pre-built large AI models at scale.

With ONNX Runtime, you can easily integrate this generative AI model into your apps and services, with optimizations that deliver faster inference and lower costs.

Faster Inferencing with New ONNX Runtime Optimizations

The new ONNX Runtime 1.16.2 release includes several built-in optimizations for Llama2, including graph fusions and kernel optimizations. The figures below compare ONNX Runtime's CUDA FP16 performance, including prompt latency, against the Hugging Face (HF) variants of Llama2 running in PyTorch compile mode. We see ~3X gains in end-to-end throughput for both the 7B and 13B models, and up to 4.5X higher E2E throughput compared with PyTorch compile. End-to-end (E2E), or wall-clock, throughput is defined as batch size * (prompt length + token generation length) / wall-clock latency, where wall-clock latency is the time to run the full end-to-end generation and the token generation length is 256 generated tokens.
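The end-to-end throughput definition above is plain arithmetic, so a minimal Python sketch of it may help when reproducing the numbers. The function name and the sample inputs are illustrative assumptions, not measurements taken from the figures.

```python
def e2e_throughput(batch_size: int, prompt_length: int, wall_clock_latency_s: float,
                   generated_tokens: int = 256) -> float:
    """E2E (wall-clock) throughput in tokens per second.

    Counts prompt tokens plus generated tokens for every sequence in the batch,
    divided by the wall-clock time of the whole end-to-end run.
    """
    if wall_clock_latency_s <= 0:
        raise ValueError("wall-clock latency must be positive")
    return batch_size * (prompt_length + generated_tokens) / wall_clock_latency_s


# Hypothetical run: batch of 2, 128-token prompts, 256 generated tokens, 8 s wall clock.
print(e2e_throughput(batch_size=2, prompt_length=128, wall_clock_latency_s=8.0))  # 96.0
```

Because the numerator counts prompt tokens as well as generated ones, this metric is not the same as the decoding rate alone.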

E2E Throughput Comparisons - Llama-2-7b
E2E Throughput Comparisons - Llama-2-13b
Figure 1: E2E Throughput Comparisons

Latency and Throughput

The graphs below compare the latency of the ONNX Runtime and PyTorch variants of the Llama2 7B and 13B models on CUDA FP16. Latency here is the time it takes to complete one pass through the model to produce the logits and synchronize the outputs.
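Because GPU work is typically launched asynchronously, a timing loop has to synchronize before reading the clock; otherwise it may measure little more than kernel launch time. The sketch below times an arbitrary callable under the definition above. The function name, the warm-up count, and the idea of passing `torch.cuda.synchronize` as the hook for a PyTorch model are assumptions for illustration, not the benchmark scripts behind the figures.

```python
import time
from typing import Callable


def measure_latency(run_once: Callable[[], object],
                    synchronize: Callable[[], None] = lambda: None,
                    warmup: int = 3, iterations: int = 10) -> float:
    """Average seconds per forward pass, including time to synchronize outputs."""
    for _ in range(warmup):          # warm-up passes are excluded from timing
        run_once()
    synchronize()                    # drain any pending device work first
    start = time.perf_counter()
    for _ in range(iterations):
        run_once()
        synchronize()                # count asynchronous device work in each pass
    return (time.perf_counter() - start) / iterations
```

For an ONNX Runtime session, `run_once` could wrap `session.run(None, feeds)`; for a PyTorch model on CUDA, `synchronize` could be `torch.cuda.synchronize`.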

Prompt Latency Comparisons - Llama-2-7b
Prompt Latency Comparisons - Llama-2-13b
Figure 2: Prompt Latency Comparisons

Token generation throughput below is the average throughput over the first 128 generated tokens. We see up to 3.5X gains in token generation throughput compared with both PyTorch eager and compile modes.
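A sketch of that average follows. It assumes each decoding step emits one token per sequence and that the average is total tokens divided by total time over the first 128 steps; the exact averaging used for the figures is not spelled out here, and the function name and inputs are hypothetical.

```python
from typing import Sequence


def token_generation_throughput(step_latencies_s: Sequence[float], batch_size: int = 1,
                                num_tokens: int = 128) -> float:
    """Average tokens/s over the first `num_tokens` decoding steps.

    Assumes each step emits one new token for every sequence in the batch.
    """
    if len(step_latencies_s) < num_tokens:
        raise ValueError("need at least num_tokens step latencies")
    total_time = sum(step_latencies_s[:num_tokens])
    return batch_size * num_tokens / total_time


# Hypothetical per-step latencies: 20 ms per step for 200 steps.
print(round(token_generation_throughput([0.02] * 200), 6))  # 50.0
```

Dividing total tokens by total time weights slow steps correctly; averaging per-step rates instead would overstate throughput whenever step times vary.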

Tokens Generated Throughput Comparisons - Llama-2-7b
Tokens Generated Throughput Comparisons - Llama-2-13b
Figure 3: Tokens Generated Throughput Comparisons

More details on these metrics can be found here.

ONNX Runtime with Multi-GPU Inference

ONNX Runtime supports multi-GPU inference for serving large models. Even in FP16 precision, the LLaMA-2 70B model requires 140GB of memory, which is more than a single NVIDIA A100 80GB GPU provides, so loading the model for inference requires multiple GPUs.
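The 140GB figure follows from storing each FP16 parameter in 2 bytes. The sketch below repeats that arithmetic and divides it across the 4 partitions used for the 70B model. It counts weights only (no KV cache, activations, or runtime overhead), treats "70B" as a nominal parameter count, and the function name is hypothetical.

```python
def weight_memory_gb(num_params: float, num_gpus: int = 1, bytes_per_param: int = 2) -> float:
    """Approximate per-GPU weight memory in decimal GB (weights only, even split)."""
    return num_params * bytes_per_param / num_gpus / 1e9


total = weight_memory_gb(70e9)                # FP16 weights on a single device
per_gpu = weight_memory_gb(70e9, num_gpus=4)  # evenly split across 4 partitions
print(total, per_gpu)  # 140.0 35.0
```

At 140 GB the weights alone exceed one 80 GB A100, so at least two devices are needed; with 4 partitions each GPU holds about 35 GB of weights, leaving the rest of its memory for the KV cache and activations.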

ONNX Runtime applied Megatron-LM tensor parallelism to the 70B model to split the original model weights across different GPUs. Megatron sharding splits the FP16 PyTorch model into 4 partitions, converts each partition to ONNX format, and then applies a new ONNX Runtime graph fusion to the converted ONNX models. With these optimizations, the 70B model reaches ~30 tokens per second for token generation at batch size 1, and end-to-end latency starts at 30 ms for shorter sequence lengths. You can find additional example scripts here.
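Megatron-LM style tensor parallelism partitions an MLP so that the first projections are split by columns and the last projection by rows; each GPU computes its slice independently, and an all-reduce sums the partial outputs. The NumPy sketch below simulates this on CPU for a LLaMA-style SwiGLU MLP and checks it against the unsharded computation. The function names are hypothetical, attention-layer sharding is omitted, and this illustrates the technique rather than ONNX Runtime's sharding code.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def mlp_reference(x, w_gate, w_up, w_down):
    # LLaMA-style SwiGLU MLP computed on a single device.
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down


def mlp_tensor_parallel(x, w_gate, w_up, w_down, num_shards):
    # Column-split gate/up projections and row-split the down projection,
    # so each shard owns a matching slice of the intermediate dimension.
    gate_shards = np.array_split(w_gate, num_shards, axis=1)
    up_shards = np.array_split(w_up, num_shards, axis=1)
    down_shards = np.array_split(w_down, num_shards, axis=0)
    partial_outputs = []
    for g, u, d in zip(gate_shards, up_shards, down_shards):
        h = silu(x @ g) * (x @ u)       # local to one simulated GPU; no communication
        partial_outputs.append(h @ d)   # partial sum of the final output
    return np.sum(partial_outputs, axis=0)  # stands in for the all-reduce


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 16))
w_gate, w_up = rng.standard_normal((16, 32)), rng.standard_normal((16, 32))
w_down = rng.standard_normal((32, 16))
print(np.allclose(mlp_reference(x, w_gate, w_up, w_down),
                  mlp_tensor_parallel(x, w_gate, w_up, w_down, num_shards=4)))  # True
```

The split works because SiLU and the elementwise product act independently on each intermediate column, so only one reduction is needed per MLP.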

70B Llama2 Model Throughput
Figure 4: 70B Llama2 Model Throughput

ONNX Runtime Optimizations

LLaMA-2 Optimization Diagram
Figure 5: LLaMA-2 Optimization Diagram

The optimization techniques that ONNX Runtime uses, such as graph fusions, apply to state-of-the-art models. As these models become more complex, the way the fusions are applied is adapted to handle the extra complexity. For example, instead of manually matching fusion patterns in the graph, ONNX Runtime now supports automated pattern matching. Rather than detecting large subgraphs by hand and matching the many paths they form, fusion opportunities can be identified by exporting a large module as a function and then pattern matching against the function's spec.

Example of Rotary Embedding Function
Figure 6: Example of Rotary Embedding Function

As a concrete example, Figure 6 shows the nodes that make up the rotary embedding computation. Pattern matching against this subgraph is cumbersome because of the number of paths to verify. When this computation is exported as a function, the parent graph shows only its inputs and outputs and represents all of these nodes as a single operator.

Example of Rotary Embedding Function in Parent Graph
Figure 7: Example of Rotary Embedding Function in Parent Graph

This approach makes it easier to maintain and support future versions of the rotary embedding computation, because the pattern matching depends only on the operator's inputs and outputs instead of its internal implementation. It also allows existing rotary embedding implementations in similar models, such as GPT-NeoX, Falcon, Mistral, and Zephyr, to be pattern matched and fused with minimal or no changes.
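The idea of matching on a function's call site rather than on its internal nodes can be sketched in a few lines. The toy graph representation, the function name `LlamaRotaryEmbedding`, and the fused op name are hypothetical stand-ins chosen for illustration; this is not ONNX Runtime's fusion code, which operates on real ONNX graphs.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Node:
    op_type: str
    inputs: List[str]
    outputs: List[str]
    domain: str = ""


def fuse_function_calls(nodes: List[Node], func_domain: str, func_name: str,
                        num_inputs: int, num_outputs: int,
                        fused_op: str, fused_domain: str) -> List[Node]:
    """Replace each call to an exported function with a single fused op.

    Only the call site is inspected (domain, name, input/output counts), so
    changes inside the function body do not break the pattern.
    """
    result = []
    for node in nodes:
        is_match = (node.domain == func_domain and node.op_type == func_name
                    and len(node.inputs) == num_inputs and len(node.outputs) == num_outputs)
        if is_match:
            result.append(Node(fused_op, list(node.inputs), list(node.outputs), fused_domain))
        else:
            result.append(node)
    return result


graph = [
    Node("MatMul", ["hidden", "w_q"], ["q"]),
    Node("LlamaRotaryEmbedding", ["q", "position_ids", "cos_cache", "sin_cache"], ["q_rot"], "custom"),
]
fused = fuse_function_calls(graph, "custom", "LlamaRotaryEmbedding", 4, 1,
                            fused_op="RotaryEmbedding", fused_domain="com.microsoft")
print([n.op_type for n in fused])  # ['MatMul', 'RotaryEmbedding']
```

A model whose rotary function has a different body but the same signature would be fused by the same rule, which is the maintenance benefit described above.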

ONNX Runtime also adds support for the GroupQueryAttention (GQA) operator, which uses the new Flash Attention V2 algorithm and its optimized kernels to compute attention efficiently. The GQA operator supports past-present buffer sharing between the past key/value cache (past KV cache) and the present key/value cache (present KV cache). By binding the present KV caches to the past KV caches, there is no need to allocate separate on-device memory for both. Instead, the past KV caches can be pre-allocated with enough on-device memory that no new on-device memory needs to be requested during inference. This reduces memory usage when the KV caches grow large during compute-intensive workloads, and it lowers latency by eliminating on-device memory allocation requests. Past-present buffer sharing can be enabled or disabled without changing the ONNX model, giving end users the flexibility to choose the approach that works best for them.
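The buffer-sharing idea can be illustrated without a GPU: allocate one KV buffer sized for the maximum sequence length up front, write each step's new keys or values into it in place, and return the valid prefix as the "present" cache. The class below is a hypothetical NumPy sketch of that bookkeeping for a single key or value tensor; it is not ONNX Runtime's implementation, where sharing is toggled without changing the ONNX model as described above.

```python
import numpy as np


class SharedKVCache:
    """Hypothetical sketch: one pre-allocated buffer acts as both past and present cache."""

    def __init__(self, batch, num_kv_heads, max_seq_len, head_dim, dtype=np.float16):
        # Allocated once for the maximum sequence length; never resized.
        self.buffer = np.zeros((batch, num_kv_heads, max_seq_len, head_dim), dtype=dtype)
        self.length = 0

    def append(self, new_kv):
        """Write new_kv (batch, heads, new_tokens, head_dim) in place; return a view."""
        n = new_kv.shape[2]
        if self.length + n > self.buffer.shape[2]:
            raise ValueError("KV cache capacity exceeded")
        self.buffer[:, :, self.length:self.length + n, :] = new_kv
        self.length += n
        return self.buffer[:, :, :self.length, :]  # a view of the buffer, not a copy


cache = SharedKVCache(batch=1, num_kv_heads=2, max_seq_len=8, head_dim=4)
present = cache.append(np.ones((1, 2, 3, 4)))      # prompt step
present = cache.append(np.ones((1, 2, 1, 4)))      # one decoding step
print(present.shape, np.shares_memory(present, cache.buffer))  # (1, 2, 4, 4) True
```

Because the returned cache is a view into the same allocation, growing the sequence never triggers a new allocation, which mirrors the latency benefit described above.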

In addition to these fusions and kernel optimizations, ONNX Runtime reduces the model's memory usage. Besides quantization improvements (which will be covered in a future post), ONNX Runtime halves the size of the cosine and sine caches used in each rotary embedding. The ONNX Runtime compute kernels that run the rotary embedding computations recognize this format and use their parallelized implementations to calculate the rotary embeddings more efficiently and with less memory. The rotary embedding kernels also support both interleaved and non-interleaved formats, covering the Microsoft and Hugging Face versions of LLaMA-2 respectively, while sharing the same calculations.
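One way to see where a 50% saving can come from: each rotated pair of features uses a single angle, so a cache with head_dim / 2 columns is enough, whereas the Hugging Face implementation builds full-width caches by concatenating the frequency table with itself. The NumPy sketch below builds half-width caches and applies them in both layouts, pairing features (2i, 2i + 1) in the interleaved format and (i, i + head_dim/2) in the non-interleaved format. Function names, the base of 10000, and the sizes are assumptions for illustration; this is not ONNX Runtime's kernel code.

```python
import numpy as np


def build_half_caches(max_seq_len: int, head_dim: int, base: float = 10000.0):
    """cos/sin caches with head_dim // 2 columns: one angle per rotated pair."""
    inv_freq = base ** (-np.arange(0, head_dim, 2) / head_dim)  # (head_dim/2,)
    angles = np.outer(np.arange(max_seq_len), inv_freq)         # (max_seq_len, head_dim/2)
    return np.cos(angles), np.sin(angles)


def rope_non_interleaved(x, cos, sin):
    """Rotate pairs (i, i + head_dim/2): the Hugging Face layout."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)


def rope_interleaved(x, cos, sin):
    """Rotate adjacent pairs (2i, 2i + 1): the Microsoft layout."""
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_odd * cos + x_even * sin
    return out


def to_interleaved(x):
    """Reorder features from the non-interleaved layout to the interleaved one."""
    half = x.shape[-1] // 2
    return np.stack([x[..., :half], x[..., half:]], axis=-1).reshape(x.shape)


seq_len, head_dim = 4, 8
cos, sin = build_half_caches(max_seq_len=16, head_dim=head_dim)
print(cos.shape)  # (16, 4): half the width of a (16, 8) full cache
x = np.random.default_rng(0).standard_normal((seq_len, head_dim))
a = rope_interleaved(to_interleaved(x), cos[:seq_len], sin[:seq_len])
b = to_interleaved(rope_non_interleaved(x, cos[:seq_len], sin[:seq_len]))
print(np.allclose(a, b))  # True: same math, different pair layout
```

Both layouts read the same half-width caches, which is consistent with one kernel serving both model versions.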

The optimizations work for the Hugging Face versions (models ending with -hf) and the Microsoft versions. You can download the optimized HF versions from Microsoft's LLaMA-2 ONNX repository. Stay tuned for newer Microsoft versions coming soon!

Optimize your own model using Olive

Olive is a hardware-aware model optimization tool that incorporates advanced techniques such as model compression, optimization, and compilation. We have made ONNX Runtime optimizations available through Olive, so you can streamline the entire optimization process for your target hardware through a simple workflow.

Here is an example of Llama2 optimization with Olive, which uses the ONNX Runtime optimizations highlighted in this blog. Separate optimization flows cater to different requirements. For instance, you can choose different data types for quantization for CPU and GPU inference, based on your accuracy tolerance. You can also fine-tune your own Llama2 model with Olive-QLoRA on client GPUs and run inference with ONNX Runtime optimizations.

Usage Example

Here is a sample notebook that shows you an end-to-end example of how you can use the above ONNX Runtime optimizations in your application.

Conclusion

The advancements discussed in this blog provide faster Llama2 inference with ONNX Runtime, offering exciting possibilities for AI applications and research. With improved performance and efficiency, the horizon is wide open for innovation, and we eagerly await the new applications that the vibrant developer community will build with Llama2 and ONNX Runtime. Stay tuned for more updates!