Making Pytorch Transformer Twice as Fast on Sequence Generation.
At Scale AI, we use Machine Learning models in a wide range of applications to
empower our data labeling pipeline. We strive for speed and efficiency, and
always try to get the best out of the models. Here, we will discuss some
tricks we discovered that drastically improve over the PyTorch Transformer
implementation in just a few lines of code.
Transformers are Here to Stay
Transformers have become ubiquitous. They were first introduced in Attention
is All You Need (Ashish Vaswani et al., 2017) and were quickly added to
PyTorch. Their popularity increased even more with the
development of HuggingFace, which made large NLP pre-trained models such as
BERT
widely accessible, and created recipes to enable simple fine-tuning on a wide
range of tasks. They’ve been successfully applied to a wide variety of
sequence-to-sequence (Seq2Seq) tasks including machine translation, text
summarization, or even image captioning (an image is just a sequence of
pixels!). This popularity is completely warranted, because Transformers have
some significant upsides:
- The Transformer architecture is non-sequential, making it distributable.
  Other traditional methods for sequence modeling, such as RNNs, are limited
  to processing sequences one token at a time. This sequential nature prevents
  parallelization and makes training slow. Transformers process entire
  sequences at once in a highly parallel fashion. This makes them incredibly
  fast on GPUs and helps handle long-range dependencies elegantly.
- Transformers make few assumptions about the data. Traditionally, ML
  practitioners have tailored their networks to process specific types of
  data. Constraints such as forcing RNNs to process text sequentially from
  left to right allowed these networks to perform well even on scarce training
  data, spearheading breakthroughs. However, these constraints introduce
  biases in the model, as sequential ordering is rarely the optimal way to
  understand text. Transformers keep data representations generic, which makes
  them capable of learning more subtle interactions between words. Recent
  papers (Alexey Dosovitskiy et al., 2020) show that the same story could also
  be true in computer vision, with Transformers outperforming the long
  established CNNs once trained on the huge datasets that have recently been
  collected. With enough data, Transformers learn more complex and accurate
  representations than the constrained networks which used to be the only
  viable option.
Sequence-to-Sequence with Transformers
But Transformers also have their weaknesses. When generating sequences for
Seq2Seq tasks at inference time, Transformers are constrained because each
item in the output sequence can only be predicted one at a time. This,
combined with the quadratic complexity of attention, can make them slower than
their counterparts. (For training, this is not an issue thanks to teacher
forcing.)
Seq2Seq models typically create an internal high-level representation of the
input sequence and then decode (i.e. generate) the output sentence. Given the
high-level representation of the input sentence and the words that have
already been decoded, Seq2Seq models estimate the most likely words to
complete the sentence. This phenomenon is called auto-regression and
the phase corresponding to generating a new word (or token) is called a
timestep.
When a Transformer is used as a Seq2Seq model, the input sequence is fed
through an Encoder, and the output sequence is then generated by a Decoder, as
illustrated in Figures 1 and 2.
Decoding Inefficiency of the PyTorch Transformers
The Transformer class in PyTorch is generic, which is great because it gives
the ML researchers at Scale AI fine-grained control, but it also means the
class isn't optimized for speed. Let's take a deeper look.
First, it can be seen in Figure 1 that the encoder output can be computed
separately from the decoder. This means that the encoder outputs can be
computed once and re-used for every timestep thereafter. But PyTorch does NOT
save this for you: it recomputes the encoder output at every decoding
timestep, wasting compute. To fix this, the Transformer Encoder and Decoder
should always be separated.
# THIS IS THE NAIVE WAY TO USE TRANSFORMERS

# INITIALIZATION
transformer = nn.Transformer(
    d_model=hdim,
    nhead=nhead,
    num_encoder_layers=num_layers,
    num_decoder_layers=num_layers,
    dim_feedforward=dim_feedforward,
).to(device=device)
transformer.eval()

# INFERENCE LOOP
decoded_tokens = first_token
for i in range(len_output_to_decode):  # generate `len_output_to_decode` tokens
    mask_dec = generate_square_subsequent_mask(
        i + 1, device=first_token.device
    )  # create mask for autoregressive decoding
    decoded_embeddings = embedding(decoded_tokens)
    output = transformer(src, decoded_embeddings, tgt_mask=mask_dec)
    logits = to_vocab(output)  # projection to vocab size

    # keep most likely tokens
    top_indices = torch.argmax(logits, dim=-1)
    # we only care about the last token that was decoded
    top_indices_last_token = top_indices[-1:]
    # add most likely token to the already decoded tokens
    decoded_tokens = torch.cat(
        [decoded_tokens, top_indices_last_token], dim=0
    )
The code below is a much more efficient way to get the same results by
decoupling the encoder and the decoder. Note that the code corresponding to
the inference loop barely changes.
# INITIALIZATION
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=hdim, nhead=nhead, dim_feedforward=dim_feedforward
    ),
    num_layers=num_layers,
).to(device=device)
encoder.eval()
decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(
        d_model=hdim, nhead=nhead, dim_feedforward=dim_feedforward
    ),
    num_layers=num_layers,
).to(device=device)
decoder.eval()

# INFERENCE LOOP
decoded_tokens = first_token
src_embeddings = encoder(src)
for i in range(len_output_to_decode):
    mask_dec = generate_square_subsequent_mask(
        i + 1, device=first_token.device
    )  # create mask for autoregressive decoding
    decoded_embeddings = embedding(decoded_tokens)
    # the decoder uses the encoder output `src_embeddings`
    output = decoder(decoded_embeddings, src_embeddings, tgt_mask=mask_dec)
    logits = to_vocab(output)  # projection to vocab size

    # keep most likely tokens
    top_indices = torch.argmax(logits, dim=-1)
    # we only care about the last token that was decoded
    top_indices_last_token = top_indices[-1:]
    # add most likely token to the already decoded tokens
    decoded_tokens = torch.cat(
        [decoded_tokens, top_indices_last_token], dim=0
    )
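Both snippets assume a few helpers that are not defined above: embedding (an
nn.Embedding for the output vocabulary), to_vocab (a linear projection back to
vocabulary size), and generate_square_subsequent_mask. As a reference point, a
minimal sketch of the latter, mirroring PyTorch's built-in
nn.Transformer.generate_square_subsequent_mask but taking an explicit device
argument, could look like this:

import torch

def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
    # Additive attention mask: 0.0 on allowed positions (current and past),
    # -inf on strictly future positions, so decoding stays autoregressive.
    mask = torch.triu(torch.ones(sz, sz, device=device), diagonal=1)
    return mask.masked_fill(mask == 1, float("-inf"))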
The main inefficiency builds on the previous point. It can be seen in Figure
2 that the embedding of a decoded token only depends on the tokens that were
decoded before it. This is a direct benefit of the Transformer model being
autoregressive. Thus, it is unnecessary to recompute the embeddings of the
already decoded tokens repeatedly and instead, we can again cache them. Each
timestep then only consists of computing the attention for the newest token’s
embedding.
Figure 3: Decoder self-attention links when decoding tokens.
The boxes at the bottom represent the embeddings of the output tokens
before self-attention, the top boxes represent the embeddings of the output
tokens after self-attention. Using our trick (right side), most of the
embeddings are not recomputed as they are cached. The number of links to take
into account becomes linear instead of quadratic.
From a complexity perspective, generating the n-th output token without our
trick involves computing self-attention over the entire current output
(O(n²)) and computing encoder-decoder attention between the whole
input (whose length we will denote M) and the current output (O(Mn)).
Hence, the complexity of each timestep is O(Mn + n²). Given that we
want to decode N tokens, N timesteps are needed, and the final complexity is
O(MN² + N³).
Our trick accelerates each timestep. Only the parts of the self-attention and
encoder-decoder attention responsible for updating the last token are
computed. Figure 3 shows how this works for the self-attention. The new
complexity of each timestep is O(M + N), so with N timesteps the
final complexity is sped up to O(MN + N²).
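Summing the per-timestep costs over the N timesteps makes the two totals
explicit:

\sum_{n=1}^{N} \left( Mn + n^2 \right) = O\!\left( MN^2 + N^3 \right)
\qquad \text{versus} \qquad
\sum_{n=1}^{N} \left( M + n \right) = O\!\left( MN + N^2 \right)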
The PyTorch Transformer decoder is not assumed to be autoregressive, so it
does not implement this caching. However, by inheriting from the PyTorch
TransformerDecoder and TransformerDecoderLayer classes, we introduce a
CausalTransformerDecoder which uses a cache to implement the improvement
above. Our implementation differs from the PyTorch one by only a few lines.
Our new decoder works similarly to the original TransformerDecoder, except
that we now have to take into account the cache:
causal_decoder = CausalTransformerDecoder(
    CausalTransformerDecoderLayer(
        d_model=hdim,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
    ),
    num_layers=6,
).to(device=device)
causal_decoder.eval()

decoded_tokens = first_token
src_embeddings = encoder(src)
cache = None
for i in range(len_output_to_decode):
    mask_dec = generate_square_subsequent_mask(
        i + 1, device=first_token.device
    )  # create mask for autoregressive decoding
    decoded_embeddings = embedding(decoded_tokens)
    # only change here: we add the cache as an extra parameter
    output, cache = causal_decoder(decoded_embeddings, src_embeddings, cache)
    logits = to_vocab(output)  # projection to vocab size

    # keep most likely tokens
    top_indices = torch.argmax(logits, dim=-1)
    # we only care about the last token that was decoded
    top_indices_last_token = top_indices[-1:]
    # add most likely token to the already decoded tokens
    decoded_tokens = torch.cat(
        [decoded_tokens, top_indices_last_token], dim=0
    )
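To make the caching concrete, here is a simplified sketch of the per-layer
trick (inference path only; names and details are illustrative, and our actual
implementation handles a few more cases). The layer inherits from
nn.TransformerDecoderLayer and only updates the last token's representation,
attending over the full decoded sequence and the precomputed encoder output:

import torch
import torch.nn as nn

class CausalTransformerDecoderLayer(nn.TransformerDecoderLayer):
    """Sketch: at inference, only the last token's embedding is updated."""

    def forward(self, tgt: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        # tgt: embeddings of every token decoded so far, shape (seq, batch, hdim).
        # Thanks to causality, earlier positions never change and can be cached,
        # so we only compute the update for the last position.
        tgt_last = tgt[-1:, :, :]

        # Self-attention: the last token (query) attends to the full sequence
        # decoded so far (keys and values).
        tmp = self.self_attn(tgt_last, tgt, tgt)[0]
        tgt_last = self.norm1(tgt_last + self.dropout1(tmp))

        # Encoder-decoder attention over the precomputed encoder output.
        tmp = self.multihead_attn(tgt_last, memory, memory)[0]
        tgt_last = self.norm2(tgt_last + self.dropout2(tmp))

        # Position-wise feed-forward, applied to the last token only.
        tmp = self.linear2(self.dropout(self.activation(self.linear1(tgt_last))))
        tgt_last = self.norm3(tgt_last + self.dropout3(tmp))

        return tgt_last  # shape (1, batch, hdim)

The CausalTransformerDecoder wrapper then maintains one cached tensor per
layer (that layer's outputs for all tokens decoded so far) and concatenates
each newly computed token to it, so that subsequent timesteps only pay for the
last position.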
Experiments
We put our changes to the test to see how much faster we could get. We present
two different scenarios: translation and generation of long texts.
We compare three different implementations:
- The most naive PyTorch implementation (defined in the first piece of code),
  which uses nn.Transformer.
- The PyTorch encoder-decoder implementation (second piece of code).
- Our CausalTransformerDecoder (third piece of code).
As a reminder, these are three different implementations of the same model.
When initialized with the same weights, they return the same outputs.
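The equivalence between the first two can be checked directly, since
nn.Transformer exposes its encoder and decoder as submodules. Here is a
minimal sanity-check sketch (shapes and hyperparameters are arbitrary, and we
reuse the submodules rather than instantiating separate modules as in the
second snippet):

import torch
import torch.nn as nn

# Arbitrary hyperparameters and shapes for the check (sequence-first layout).
hdim, nhead, num_layers, dim_feedforward = 512, 8, 6, 2048
src = torch.rand(10, 2, hdim)  # (source length, batch, hidden)
tgt = torch.rand(7, 2, hdim)   # (target length, batch, hidden)

transformer = nn.Transformer(
    d_model=hdim,
    nhead=nhead,
    num_encoder_layers=num_layers,
    num_decoder_layers=num_layers,
    dim_feedforward=dim_feedforward,
).eval()
tgt_mask = transformer.generate_square_subsequent_mask(tgt.size(0))

with torch.no_grad():
    # Monolithic forward pass...
    out_full = transformer(src, tgt, tgt_mask=tgt_mask)
    # ...versus explicitly reusing the encoder output, as in the second snippet.
    memory = transformer.encoder(src)
    out_split = transformer.decoder(tgt, memory, tgt_mask=tgt_mask)

assert torch.allclose(out_full, out_split, atol=1e-6)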
Text Translation
The first setting corresponds to translation. In this setting the input and
output sequences are generally short and of similar lengths.
The non-linear curves show that the attention mechanisms progressively
become the most compute-intensive parts of the model as the number of input
and output tokens increases.
Our causal implementation is up to 40% faster than the Pytorch
Encoder-Decoder implementation, and 150% faster than the Pytorch
nn.Transformer implementation for 500 input/output tokens.
Long Text Generation
We now ask the model to generate long sequences from a fixed size input. Such
a situation might arise when generating a story from an image or from an
initial prompt.
The results below were obtained with a fixed input size of 500 tokens.
Increasing the number of input tokens makes the models slower but doesn’t
change the overall trends observed.
Our causal model is twice as fast as the PyTorch encoder-decoder
implementation when the number of tokens to generate exceeds 1,000.
When decoding more than 500 tokens, the time ratio between the causal model
and the other implementations grows linearly with the number of decoded
tokens. This confirms the theoretical prediction that the overall decoding
complexity is reduced by a factor of N.
Finally, our CausalTransformerDecoder can also be used without any input
sentence (i.e. without an encoder), as is the case in some common generation
settings where the model is asked to complete a story or an article. More
information about this type of generation can be found in the GPT papers
(Alec Radford et al., 2018). The results we find for this case are similar to
the ones above.
Digging deeper…
One might notice that caching the output of each layer is sub-optimal. Indeed,
the first stage of the attention layers consists of projecting the embeddings
to the key, value and query spaces. In both the PyTorch implementation and the
proposed implementation, the same embeddings get projected repeatedly.
Instead, the queries, keys and values could be cached directly. However, this
requires substantially more changes, which could become unstable with new
PyTorch upgrades. Moreover, the estimated gains are minor (less than 5% in our
experiments).
For standard NLP use cases, the HuggingFace repository already embeds these
optimizations. Notably, it caches keys and values. It also comes with
different decoding flavors, such as beam search or nucleus sampling.
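As an illustration, here is a minimal sketch using the HuggingFace generate
API (the model name, prompt, and generation parameters are arbitrary choices
for the example):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

input_ids = tokenizer.encode("Once upon a time", return_tensors="pt")

with torch.no_grad():
    # Nucleus sampling; keys and values are cached across timesteps
    # (use_cache=True is the default behaviour).
    output_ids = model.generate(
        input_ids,
        max_length=50,
        do_sample=True,
        top_p=0.9,
        use_cache=True,
    )

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Passing num_beams greater than 1 (with do_sample=False) would switch to beam
search instead.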
Conclusion
The simple tricks proposed here take advantage of the fact that the overall
PyTorch implementation of the Transformer is too generic. The changes provide
modest speed improvements when generating a few hundred tokens, which become
significant boosts over the original PyTorch implementation when the output
length nears a thousand tokens. These gains are naturally proportional to the
number of output tokens to decode. And best of all, they can be implemented in
just a few lines using our repo.
References:
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan
N Gomez, Łukasz Kaiser, and Illia Polosukhin.
Attention is all you need. In Advances in Neural Information
Processing Systems 30, pp. 5998–6008. Curran Associates, Inc., 2017. URL:
https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova.
BERT: pre-training of deep bidirectional transformers for language
understanding. CoRR, abs/1810.04805, 2018. URL:
http://arxiv.org/abs/1810.04805.
Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn,
Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani et al.
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.
arXiv preprint, 2020. URL:
https://arxiv.org/pdf/2010.11929
Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Clement Delangue,
Anthony Moi, Pierric Cistac et al.
Transformers: State-of-the-art natural language processing. arXiv preprint,
2019. URL:
https://arxiv.org/pdf/1910.03771.pdf
Alec Radford, Karthik Narasimhan, Tim Salimans, and Ilya Sutskever.
Improving language understanding by generative pre-training. Preprint, 2018.
URL:
https://www.cs.ubc.ca/~amuham01/LING530/papers/radford2018improving.pdf
Footnotes:
1 - The vocabulary size used for this experiment was 30,000. It corresponds to
the vocabulary size chosen in the original BERT paper, which uses BPE
tokenization. We also ran the same experiments with a much smaller vocabulary
(128 tokens), to imitate a character-level setting, which did not show
significant benefits.
2 - The hyperparameters used for the Transformer architecture are the ones of
the original paper (6 layers, 8 heads, 512 hidden dimensions, 2048
feed-forward hidden dimensions for both encoder/decoder). The results should
be similar with other configurations, provided that the encoder and decoder
have the same size.
3 - The displayed results correspond to a batch-size of 8 sequences, but we
made sure that a batch-size of 1 gives the same trend. A batch-size of 1 is
usually the most common in inference as requests are sent asynchronously.
However, in our specific setup the model takes several seconds to generate
sentences, so it is more natural to batch requests.
4 - As a safety check, we benchmarked the HuggingFace GPT-2 implementation
against our CausalTransformerDecoder, using the same set of hyperparameters.
We generated up to 1,000 tokens with the two models. The speed ratio between
the two was close to 1, oscillating between 0.85 and 1.10.
5 - All the experiments were run on a V100 GPU.