Fast Transformer training with long sequences

2023-10-01 06:23:00

Since FlashAttention was released 6 months ago, it has been adopted by many organizations and research labs to speed up their training and inference (see this page for a partial list).

For the last 2 months I've been collaborating with Adept as a part-time research fellow, and we've been developing some improvements to FlashAttention to make it even better! In this post, we describe one key improvement that we're particularly excited about: making FlashAttention fast for long sequences to enable training large language models with longer context.

As an example, for sequence length 8k, FlashAttention is now up to 2.7x faster than a standard PyTorch implementation, and up to 2.2x faster than the optimized implementation from Megatron-LM, even at small batch sizes. As we'll see, training with longer context yields higher quality models. As we've discussed before, we believe that modeling longer sequences could help us take the next leap in AI, and FlashAttention is one component to scale Transformers to longer context. At Adept, we've been training large Transformers (ACT-1) to take actions with the goal of building an AI teammate. Understanding webpages, software tool interfaces, and multi-turn user interactions can require contexts that far exceed the common 2k standard.

Motivation: Long sequences

Scaling up the context length of Transformers is a challenge, since the multihead attention layer at their heart has runtime and memory requirements quadratic in the input sequence length. Ideally, we would like to go beyond the standard 2k sequence length limit to train models to understand books, high resolution images, webpages, multi-turn user interactions, and long-form videos.
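To make the quadratic cost concrete, here is a minimal PyTorch sketch of standard attention (illustrative only): the score matrix it materializes has shape (seqlen, seqlen), so doubling the sequence length quadruples both the memory for that matrix and the work to compute it.

```python
import torch

def standard_attention(q, k, v):
    # q, k, v: (batch, heads, seqlen, head_dim)
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale   # (batch, heads, seqlen, seqlen): quadratic in seqlen
    probs = torch.softmax(scores, dim=-1)
    return probs @ v                             # (batch, heads, seqlen, head_dim)

# At seqlen 2k the score matrix already has over 4M entries per (batch, head) pair.
q = k = v = torch.randn(1, 12, 2048, 128)
out = standard_attention(q, k, v)
```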

FlashAttention is an algorithm that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length. This works great in most cases, but it was not optimized for the case of very long sequences (where batch sizes and numbers of heads are small) due to insufficient parallelism. If one trains large Transformers on long sequences with modern parallelism techniques (data parallel, pipeline parallel, tensor parallel) to split the data and model among many GPUs, the batch size can get very small (e.g. batch size of 1 with pipeline parallelism, and number of heads around 8-12 with tensor parallelism). This is the case we would like to optimize for.

Attention parallelism to optimize for long sequences

For each attention head, to reduce memory reads/writes, FlashAttention uses classical tiling techniques to load blocks of query, key, and value from GPU HBM (its main memory) to SRAM (its fast cache), compute attention with respect to that block, and write back the output to HBM. This reduction in memory reads/writes brings significant speedup (2-4x) in most cases.
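Here is a minimal PyTorch sketch of that tiling idea for a single head (illustrative only, not the CUDA kernel): one block of query rows is processed by streaming over key/value blocks, keeping running softmax statistics so the full attention matrix is never materialized. The block size and tensor shapes below are made up for the example.

```python
import torch

def attention_row_block(q_blk, k, v, block_n=128):
    """Attention output for one block of query rows of a single head, streaming
    over key/value blocks (a sketch of the tiling, not the actual kernel)."""
    scale = q_blk.shape[-1] ** -0.5
    acc = torch.zeros_like(q_blk)                            # running (unnormalized) output
    row_max = torch.full((q_blk.shape[0],), float("-inf"))   # running max of scores per row
    row_sum = torch.zeros(q_blk.shape[0])                    # running softmax denominator per row

    for start in range(0, k.shape[0], block_n):
        k_blk = k[start:start + block_n]                     # load one key/value tile ("SRAM block")
        v_blk = v[start:start + block_n]
        scores = (q_blk @ k_blk.T) * scale                   # one tile of the score matrix
        new_max = torch.maximum(row_max, scores.max(dim=-1).values)
        rescale = torch.exp(row_max - new_max)               # correct previously accumulated results
        p = torch.exp(scores - new_max[:, None])
        row_sum = row_sum * rescale + p.sum(dim=-1)
        acc = acc * rescale[:, None] + p @ v_blk
        row_max = new_max
    return acc / row_sum[:, None]

# Sanity check against standard attention for one head.
torch.manual_seed(0)
q, k, v = torch.randn(3, 256, 64).unbind(0)
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(attention_row_block(q[:128], k, v), ref[:128], atol=1e-4)
```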

The first version of FlashAttention parallelizes over batch size and number of heads. For those familiar with CUDA programming, we use 1 thread block to process one attention head, and there are overall batch_size * num_heads thread blocks. Each thread block is scheduled to run on a streaming multiprocessor (SM), and there are 108 of these SMs on an A100 GPU, for example. This scheduling is efficient when batch_size * num_heads is large (say >= 80), since we can effectively use almost all of the compute resources on the GPU.
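A back-of-the-envelope way to see when this scheduling runs out of work, using the 108-SM A100 figure above (the example batch sizes and head counts are hypothetical):

```python
NUM_SMS = 108   # streaming multiprocessors on an A100

def thread_blocks(batch_size, num_heads):
    # First version of FlashAttention: one thread block per (batch, head) pair.
    blocks = batch_size * num_heads
    return blocks, min(blocks / NUM_SMS, 1.0)

print(thread_blocks(batch_size=32, num_heads=12))  # (384, 1.0): plenty of parallelism
print(thread_blocks(batch_size=1, num_heads=8))    # (8, ~0.07): most SMs would sit idle
```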

In the case of long sequences (which usually means small batch sizes or small numbers of heads), to make better use of the multiprocessors on the GPU, we now additionally parallelize over the sequence length dimension. This results in significant speedup for this regime.

Here is the forward pass computation expressed schematically. We have multiple workers (i.e. thread blocks) processing one attention head, and each worker takes care of a block of rows of the attention matrix. Since the rows of the attention matrix don't depend on each other, we don't need to communicate between the workers.
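A sketch of that row-block parallelism in PyTorch (the loop below stands in for independent thread blocks; on the GPU the iterations run concurrently, and each worker also tiles over keys/values as sketched above):

```python
import torch

def forward_row_parallel(q, k, v, block_m=128):
    """Each 'worker' owns one block of query rows; since output rows depend only on
    their own queries (plus all keys/values), no communication is needed."""
    scale = q.shape[-1] ** -0.5
    out = torch.empty_like(q)
    for start in range(0, q.shape[0], block_m):      # each iteration = one independent worker
        q_blk = q[start:start + block_m]
        scores = (q_blk @ k.T) * scale               # this worker's block of rows of the attention matrix
        out[start:start + block_m] = torch.softmax(scores, dim=-1) @ v
    return out

q, k, v = torch.randn(3, 8192, 128).unbind(0)
out = forward_row_parallel(q, k, v)                  # long sequence, single head, still 64 workers
```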

In the backward pass, we parallelize things slightly differently: each worker now takes care of a block of columns of the attention matrix. The workers need to communicate to aggregate the gradient with respect to the query, which can be done with atomic operations. We found that parallelizing by columns here is faster than parallelizing by rows due to the reduced communication between the workers (parallelizing by columns requires aggregating the gradient of the query, whereas parallelizing by rows requires aggregating the gradients of the key and value).
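A sketch of the column-parallel backward pass (illustrative PyTorch for a single head; it assumes the attention probabilities p are available, e.g. recomputed from saved softmax statistics, and the shared dq accumulation is the step that uses atomic adds in the actual kernel):

```python
import torch

def backward_column_parallel(q, k, v, p, d_out, block_n=128):
    """Each 'worker' owns a block of key/value columns: dK and dV for that block are
    computed locally, while the partial dQ contributions must be summed across workers."""
    scale = q.shape[-1] ** -0.5
    out = p @ v
    delta = (d_out * out).sum(dim=-1)              # rowsum(dO * O) = rowsum(P * dP), precomputed per row
    dq = torch.zeros_like(q)                       # shared accumulator (atomic adds on the GPU)
    dk, dv = torch.zeros_like(k), torch.zeros_like(v)

    for start in range(0, k.shape[0], block_n):    # each iteration = one independent worker
        cols = slice(start, start + block_n)
        p_blk = p[:, cols]
        dv[cols] = p_blk.T @ d_out                 # local to this worker's columns
        dp_blk = d_out @ v[cols].T
        ds_blk = p_blk * (dp_blk - delta[:, None]) # softmax backward, restricted to this column block
        dk[cols] = ds_blk.T @ q * scale
        dq += ds_blk @ k[cols] * scale             # this worker's contribution to the query gradient
    return dq, dk, dv

# Gradient check against autograd for a single head.
q, k, v = (torch.randn(256, 64, requires_grad=True) for _ in range(3))
p = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1)
out = p @ v
d_out = torch.randn_like(out)
out.backward(d_out)
dq, dk, dv = backward_column_parallel(q.detach(), k.detach(), v.detach(), p.detach(), d_out)
assert torch.allclose(dq, q.grad, atol=1e-4)
```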

Figure caption: In the forward pass, we parallelize the workers (thread blocks) so that each worker takes care of a block of rows of the attention matrix. In the backward pass, each worker takes care of a block of columns of the attention matrix.

Attention layer benchmark: We compare here the time taken by the forward + backward pass as we increase the sequence length (and decrease the batch size to keep the total number of tokens the same). We keep the number of heads at 12 and head dimension at 128. Time is measured on an A100 40GB GPU. Compared to the PyTorch and Megatron-LM attention implementations, FlashAttention is between 2.2x and 2.7x faster for longer sequences (8k).
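For reference, here is roughly how such a measurement could be set up with PyTorch's built-in scaled_dot_product_attention standing in for the attention layer (a sketch only: the exact batch sizes, dtypes, and timing methodology of the benchmark above may differ, and the total_tokens value here is just an example).

```python
import time

import torch
import torch.nn.functional as F

def bench_attention(seqlen, total_tokens=16384, heads=12, head_dim=128, iters=10):
    """Time forward + backward of an attention layer, shrinking the batch as the
    sequence length grows so the total number of tokens stays constant."""
    batch = total_tokens // seqlen
    q, k, v = (torch.randn(batch, heads, seqlen, head_dim, device="cuda",
                           dtype=torch.float16, requires_grad=True)
               for _ in range(3))
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        out = F.scaled_dot_product_attention(q, k, v)   # forward
        out.sum().backward()                            # backward
    torch.cuda.synchronize()
    return (time.time() - start) / iters

for seqlen in (1024, 2048, 4096, 8192):
    print(seqlen, f"{bench_attention(seqlen) * 1e3:.1f} ms/iter")
```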

End-to-end training benchmark: when we use FlashAttention to train Transformers of size up to 2.7B on sequences of length 8k, we achieve a training efficiency of up to 175 TFLOPs/sec per A100 (equivalent to model FLOPs utilization of 56%; we don't need to do any activation checkpointing). This is 2.2 times faster than Megatron-LM, as shown in the figure below. Moreover, training with 8k context length with FlashAttention is only 7% less hardware efficient than training with 2k context length, compared to Megatron-LM, where increasing the context length from 2k to 8k drops hardware efficiency by 1.9x. FlashAttention has made it much easier to train on long sequences.
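The model FLOPs utilization figure follows from the A100's dense BF16 peak of 312 TFLOPs/s:

```python
achieved_tflops = 175        # per A100, from the end-to-end training benchmark above
a100_bf16_peak_tflops = 312  # dense BF16 peak throughput of an A100

mfu = achieved_tflops / a100_bf16_peak_tflops
print(f"model FLOPs utilization ≈ {mfu:.0%}")   # ≈ 56%
```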

Evaluation: better language models with longer sequence lengths

We train GPT3 models with 1.3B and 2.7B parameters for 400B tokens on the Pile, with either 2K or 8K context. On both pretraining metrics (validation perplexity) and downstream evaluation (e.g. accuracy on the ChapterBreak challenge dataset), models with longer context outperform models with shorter context.

[Table: evaluation of GPT3 1.3B and 2.7B models trained with 2K vs 8K context]

We evaluate these models on the ChapterBreak dataset (a challenge dataset for long-range language models where one is supposed to distinguish the correct text that follows a chapter break). As one increases the context length, the accuracy of the models increases.

On both metrics, increasing the context length beyond the standard 2K yields consistent quality improvement.

Looking forward: more use cases for long sequence lengths

FlashAttention is just one step towards equipping models with long context, by making it fast to train models on long sequences. ML models are now broadly deployed, interacting with billions of users a day. As these models become more personalized, capturing the history of user interaction becomes essential. Future AI agents should be able to remember their past actions and users' feedback. Moreover, as ML models become multi-modal (e.g., text, vision, speech, etc.), long context modeling will play an even bigger role. Long context will allow models to understand books, high resolution images, and videos.

We're excited about this vision! If you have an application that you think could benefit from these ideas, please let us know!

Tri Dao

